Skip to content

Commit

Permalink
Merge branch 'master' into SNOW-1029631-optimize-put-mem
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-ext-simba-jy authored Dec 18, 2024
2 parents cd542f1 + 7d34091 commit 6ece113
Show file tree
Hide file tree
Showing 18 changed files with 540 additions and 124 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/jira_issue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
summary: '${{ github.event.issue.title }}'
description: |
${{ github.event.issue.body }} \\ \\ _Created from GitHub Action_ for ${{ github.event.issue.html_url }}
fields: '{ "customfield_11401": {"id": "14723"}, "assignee": {"id": "712020:3c0352b5-63f7-4e26-9afe-38f6f9f0f4c5"}, "components":[{"id":"19286"}] }'
fields: '{ "customfield_11401": {"id": "14723"}, "assignee": {"id": "712020:e1f41916-da57-4fe8-b317-116d5229aa51"}, "components":[{"id":"19286"}], "labels": ["oss"], "priority": {"id": "10001"} }'

- name: Update GitHub Issue
uses: ./jira/gajira-issue-update
Expand Down
1 change: 1 addition & 0 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ func (sr *snowflakeRestful) getAsync(
rows.errChannel <- err
return err
}
rows.format = respd.Data.QueryResultFormat
rows.errChannel <- nil // mark query status complete
}
} else {
Expand Down
33 changes: 22 additions & 11 deletions azure_storage_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
)

// snowflakeAzureClient performs file transfer operations against Azure Blob
// Storage. cfg holds the driver configuration; in this diff it is consulted
// (via withCloudStorageTimeout) to bound each individual cloud storage call.
type snowflakeAzureClient struct {
	cfg *Config
}

type azureLocation struct {
Expand Down Expand Up @@ -85,9 +86,11 @@ func (util *snowflakeAzureClient) getFileHeader(meta *fileMetadata, filename str
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
resp, err := blobClient.GetProperties(context.Background(), &blob.GetPropertiesOptions{
AccessConditions: &blob.AccessConditions{},
CPKInfo: &blob.CPKInfo{},
resp, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (blob.GetPropertiesResponse, error) {
return blobClient.GetProperties(ctx, &blob.GetPropertiesOptions{
AccessConditions: &blob.AccessConditions{},
CPKInfo: &blob.CPKInfo{},
})
})
if err != nil {
var se *azcore.ResponseError
Expand Down Expand Up @@ -203,9 +206,11 @@ func (util *snowflakeAzureClient) uploadFile(
if meta.realSrcStream != nil {
uploadSrc = meta.realSrcStream
}
_, err = blobClient.UploadStream(context.Background(), uploadSrc, &azblob.UploadStreamOptions{
BlockSize: int64(uploadSrc.Len()),
Metadata: azureMeta,
_, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.UploadStreamResponse, error) {
return blobClient.UploadStream(ctx, uploadSrc, &azblob.UploadStreamOptions{
BlockSize: int64(uploadSrc.Len()),
Metadata: azureMeta,
})
})
} else {
var f *os.File
Expand All @@ -228,7 +233,9 @@ func (util *snowflakeAzureClient) uploadFile(
if meta.options.putAzureCallback != nil {
blobOptions.Progress = meta.options.putAzureCallback.call
}
_, err = blobClient.UploadFile(context.Background(), f, blobOptions)
_, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.UploadFileResponse, error) {
return blobClient.UploadFile(ctx, f, blobOptions)
})
}
if err != nil {
var se *azcore.ResponseError
Expand Down Expand Up @@ -279,7 +286,9 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
blobClient = meta.mockAzureClient
}
if meta.options.GetFileToStream {
blobDownloadResponse, err := blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{})
blobDownloadResponse, err := withCloudStorageTimeout(util.cfg, func(ctx context.Context) (azblob.DownloadStreamResponse, error) {
return blobClient.DownloadStream(ctx, &azblob.DownloadStreamOptions{})
})
if err != nil {
return err
}
Expand All @@ -295,9 +304,11 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(
context.Background(), f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
_, err = withCloudStorageTimeout(util.cfg, func(ctx context.Context) (any, error) {
return blobClient.DownloadFile(
ctx, f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
})
if err != nil {
return err
}
Expand Down
51 changes: 48 additions & 3 deletions azure_storage_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
return azblob.UploadFileResponse{}, errors.New("unexpected error uploading file")
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

uploadMeta.realSrcFileName = uploadMeta.srcFileName
Expand Down Expand Up @@ -230,6 +235,11 @@ func TestUploadStreamWithAzureUploadFailedError(t *testing.T) {
return azblob.UploadStreamResponse{}, errors.New("unexpected error uploading file")
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

uploadMeta.realSrcStream = uploadMeta.srcStream
Expand Down Expand Up @@ -291,6 +301,11 @@ func TestUploadFileWithAzureUploadTokenExpired(t *testing.T) {
}
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

uploadMeta.realSrcFileName = uploadMeta.srcFileName
Expand Down Expand Up @@ -362,6 +377,11 @@ func TestUploadFileWithAzureUploadNeedsRetry(t *testing.T) {
}
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

uploadMeta.realSrcFileName = uploadMeta.srcFileName
Expand Down Expand Up @@ -418,6 +438,11 @@ func TestDownloadOneFileToAzureFailed(t *testing.T) {
return blob.GetPropertiesResponse{}, nil
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}
err = new(remoteStorageUtil).downloadOneFile(&downloadMeta)
if err == nil {
Expand All @@ -444,9 +469,14 @@ func TestGetFileHeaderErrorStatus(t *testing.T) {
return blob.GetPropertiesResponse{}, errors.New("failed to retrieve headers")
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
t.Fatalf("expected null header, got: %v", header)
}
if meta.resStatus != errStatus {
Expand Down Expand Up @@ -477,9 +507,14 @@ func TestGetFileHeaderErrorStatus(t *testing.T) {
}
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
t.Fatalf("expected null header, got: %v", header)
}
if meta.resStatus != notFoundFile {
Expand All @@ -505,7 +540,7 @@ func TestGetFileHeaderErrorStatus(t *testing.T) {
},
}

if header, err := new(snowflakeAzureClient).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
if header, err := (&snowflakeAzureClient{cfg: &Config{}}).getFileHeader(&meta, "file.txt"); header != nil || err == nil {
t.Fatalf("expected null header, got: %v", header)
}
if meta.resStatus != renewToken {
Expand Down Expand Up @@ -540,6 +575,11 @@ func TestUploadFileToAzureClientCastFail(t *testing.T) {
options: &SnowflakeFileTransferOptions{
MultiPartThreshold: dataSizeThreshold,
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

uploadMeta.realSrcFileName = uploadMeta.srcFileName
Expand Down Expand Up @@ -573,6 +613,11 @@ func TestAzureGetHeaderClientCastFail(t *testing.T) {
return blob.GetPropertiesResponse{}, nil
},
},
sfa: &snowflakeFileTransferAgent{
sc: &snowflakeConn{
cfg: &Config{},
},
},
}

_, err = new(snowflakeAzureClient).getFileHeader(&meta, "file.txt")
Expand Down
39 changes: 39 additions & 0 deletions chunk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -533,6 +534,44 @@ func TestWithArrowBatchesAsync(t *testing.T) {
})
}

// TestWithArrowBatchesButReturningJSON verifies that requesting Arrow batches
// for a query whose result arrives in JSON format produces an error (sync mode).
func TestWithArrowBatchesButReturningJSON(t *testing.T) {
	testWithArrowBatchesButReturningJSON(t, false)
}

// TestWithArrowBatchesButReturningJSONAsync is the async-mode variant of
// TestWithArrowBatchesButReturningJSON.
func TestWithArrowBatchesButReturningJSONAsync(t *testing.T) {
	testWithArrowBatchesButReturningJSON(t, true)
}

// testWithArrowBatchesButReturningJSON forces the server to return JSON
// results, then checks that GetArrowBatches fails with
// errJSONResponseInArrowBatchesMode, and that re-issuing the same query
// (by request ID, without Arrow-batches mode) still yields the JSON rows.
func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) {
	runSnowflakeConnTest(t, func(sct *SCTest) {
		requestID := NewUUID()
		pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
		defer pool.AssertSize(t, 0) // ensure no Arrow memory leaks from the failed batch request
		ctx := WithArrowAllocator(context.Background(), pool)
		ctx = WithArrowBatches(ctx)
		ctx = WithRequestID(ctx, requestID)
		if async {
			ctx = WithAsyncMode(ctx)
		}

		sct.mustExec(forceJSON, nil)
		rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
		defer rows.Close()
		_, err := rows.(SnowflakeRows).GetArrowBatches()
		assertNotNilF(t, err)
		var se *SnowflakeError
		// Check the errors.As result: the original ignored it, so a non-
		// *SnowflakeError would nil-pointer panic on se.Message below
		// instead of failing with a useful message.
		if !errors.As(err, &se) {
			t.Fatalf("expected a *SnowflakeError, got: %v", err)
		}
		assertEqualE(t, se.Message, errJSONResponseInArrowBatchesMode)

		// Reusing the request ID replays the same query; without Arrow-batches
		// mode the JSON result must be readable through the rows interface.
		ctx = WithRequestID(context.Background(), requestID)
		rows2 := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
		defer rows2.Close()
		scanValues := make([]driver.Value, 1)
		assertNilF(t, rows2.Next(scanValues))
		assertEqualE(t, scanValues[0], "hello")
	})
}

func TestQueryArrowStream(t *testing.T) {
runSnowflakeConnTest(t, func(sct *SCTest) {
numrows := 50000 // approximately 10 ArrowBatch objects
Expand Down
1 change: 1 addition & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ func (sc *snowflakeConn) queryContextInternal(
rows.sc = sc
rows.queryID = data.Data.QueryID
rows.ctx = ctx
rows.format = data.Data.QueryResultFormat

if isMultiStmt(&data.Data) {
// handleMultiQuery is responsible to fill rows with childResults
Expand Down
25 changes: 24 additions & 1 deletion dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout
defaultJWTTimeout = 60 * time.Second
defaultExternalBrowserTimeout = 120 * time.Second // Timeout for external browser login
defaultCloudStorageTimeout = -1 // Timeout for calling cloud storage.
defaultMaxRetryCount = 7 // specifies maximum number of subsequent retries
defaultDomain = ".snowflakecomputing.com"
cnDomain = ".snowflakecomputing.cn"
Expand Down Expand Up @@ -77,6 +78,7 @@ type Config struct {
ClientTimeout time.Duration // Timeout for network round trip + read out http response
JWTClientTimeout time.Duration // Timeout for network round trip + read out http response used when JWT token auth is taking place
ExternalBrowserTimeout time.Duration // Timeout for external browser login
CloudStorageTimeout time.Duration // Timeout for a single call to a cloud storage provider
MaxRetryCount int // Specifies how many times non-periodic HTTP request can be retried

Application string // application name.
Expand Down Expand Up @@ -139,11 +141,16 @@ func (c *Config) ocspMode() string {

// DSN constructs a DSN for Snowflake db.
func DSN(cfg *Config) (dsn string, err error) {
if cfg.Region == "us-west-2" {
if strings.ToLower(cfg.Region) == "us-west-2" {
cfg.Region = ""
}
// in case account includes region
region, posDot := extractRegionFromAccount(cfg.Account)
if strings.ToLower(region) == "us-west-2" {
region = ""
cfg.Account = cfg.Account[:posDot]
logger.Info("Ignoring default region .us-west-2 in DSN from Account configuration.")
}
if region != "" {
if cfg.Region != "" {
return "", errRegionConflict()
Expand Down Expand Up @@ -215,6 +222,9 @@ func DSN(cfg *Config) (dsn string, err error) {
if cfg.ExternalBrowserTimeout != defaultExternalBrowserTimeout {
params.Add("externalBrowserTimeout", strconv.FormatInt(int64(cfg.ExternalBrowserTimeout/time.Second), 10))
}
if cfg.CloudStorageTimeout != defaultCloudStorageTimeout {
params.Add("cloudStorageTimeout", strconv.FormatInt(int64(cfg.CloudStorageTimeout/time.Second), 10))
}
if cfg.MaxRetryCount != defaultMaxRetryCount {
params.Add("maxRetryCount", strconv.Itoa(cfg.MaxRetryCount))
}
Expand Down Expand Up @@ -498,6 +508,9 @@ func fillMissingConfigParameters(cfg *Config) error {
if cfg.ExternalBrowserTimeout == 0 {
cfg.ExternalBrowserTimeout = defaultExternalBrowserTimeout
}
if cfg.CloudStorageTimeout == 0 {
cfg.CloudStorageTimeout = defaultCloudStorageTimeout
}
if cfg.MaxRetryCount == 0 {
cfg.MaxRetryCount = defaultMaxRetryCount
}
Expand Down Expand Up @@ -579,6 +592,11 @@ func transformAccountToHost(cfg *Config) (err error) {
// account name is specified instead of host:port
cfg.Account = cfg.Host
region, posDot := extractRegionFromAccount(cfg.Account)
if strings.ToLower(region) == "us-west-2" {
region = ""
cfg.Account = cfg.Account[:posDot]
logger.Info("Ignoring default region .us-west-2 from Account configuration.")
}
if region != "" {
cfg.Region = region
cfg.Account = cfg.Account[:posDot]
Expand Down Expand Up @@ -714,6 +732,11 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return err
}
case "cloudStorageTimeout":
cfg.CloudStorageTimeout, err = parseTimeout(value)
if err != nil {
return err
}
case "maxRetryCount":
cfg.MaxRetryCount, err = strconv.Atoi(value)
if err != nil {
Expand Down
Loading

0 comments on commit 6ece113

Please sign in to comment.