Skip to content

Commit 4be8ace

Browse files
committed
Merge branch 'cloudfetch' into main
Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com>
1 parent 65bde57 commit 4be8ace

16 files changed

Lines changed: 1085 additions & 187 deletions

File tree

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ You can set query timeout value by appending a `timeout` query parameter (in sec
4646
```
4747
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?timeout=1000&maxRows=1000
4848
```
49+
You can turn on Cloud Fetch to increase the performance of extracting large query results by fetching data in parallel via cloud storage (more info [here](https://www.databricks.com/blog/2021/08/11/how-we-achieved-high-bandwidth-connectivity-with-bi-tools.html)). To turn on Cloud Fetch, append `useCloudFetch=true`. You can also set the number of concurrently fetching goroutines by setting the `maxDownloadThreads` query parameter (default is 10):
50+
```
51+
token:[your token]@[Workspace hostname]:[Port number][Endpoint HTTP Path]?useCloudFetch=true&maxDownloadThreads=3
52+
```
4953

5054
### Connecting with a new Connector
5155

connection.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
283283
GetDirectResults: &cli_service.TSparkGetDirectResults{
284284
MaxRows: int64(c.cfg.MaxRows),
285285
},
286+
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
286287
}
287288

288289
if c.cfg.UseArrowBatches {
@@ -295,6 +296,10 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
295296
}
296297
}
297298

299+
if c.cfg.UseCloudFetch {
300+
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
301+
}
302+
298303
ctx = driverctx.NewContextWithConnId(ctx, c.id)
299304
resp, err := c.client.ExecuteStatement(ctx, &req)
300305

connector.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,17 @@ func WithTransport(t http.RoundTripper) connOption {
245245
c.Transport = t
246246
}
247247
}
248+
249+
// WithCloudFetch sets up the use of cloud fetch for query execution. Default is false.
250+
func WithCloudFetch(useCloudFetch bool) connOption {
251+
return func(c *config.Config) {
252+
c.UseCloudFetch = useCloudFetch
253+
}
254+
}
255+
256+
// WithMaxDownloadThreads sets up maximum download threads for cloud fetch. Default is 10.
257+
func WithMaxDownloadThreads(numThreads int) connOption {
258+
return func(c *config.Config) {
259+
c.MaxDownloadThreads = numThreads
260+
}
261+
}

doc.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ Supported optional connection parameters can be specified in param=value and inc
3737
- maxRows: Sets up the max rows fetched per request. Default is 100000
3838
- timeout: Adds timeout (in seconds) for the server query execution. Default is no timeout
3939
- userAgentEntry: Used to identify partners. Set as a string with format <isv-name+product-name>
40+
- useCloudFetch: Used to enable cloud fetch for the query execution. Default is false
41+
- maxDownloadThreads: Sets up the max number of concurrent workers for cloud fetch. Default is 10
4042
4143
Supported optional session parameters can be specified in param=value and include:
4244
@@ -79,6 +81,8 @@ Supported functional options include:
7981
- WithSessionParams(<params_map> map[string]string): Sets up session parameters including "timezone" and "ansi_mode". Optional
8082
- WithTimeout(<timeout> Duration). Adds timeout (in time.Duration) for the server query execution. Default is no timeout. Optional
8183
- WithUserAgentEntry(<isv-name+product-name> string). Used to identify partners. Optional
84+
- WithCloudFetch (bool). Used to enable cloud fetch for the query execution. Default is false. Optional
85+
- WithMaxDownloadThreads (<num_threads> int). Sets up the max number of concurrent workers for cloud fetch. Default is 10. Optional
8286
8387
# Query cancellation and timeout
8488

errors/errors.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,13 @@ const (
3131

3232
// Execution error messages (query failure)
3333
ErrQueryExecution = "failed to execute query"
34+
ErrLinkExpired = "link expired"
3435
)
3536

37+
func InvalidDSNFormat(param string, value string, expected string) string {
38+
return fmt.Sprintf("invalid DSN: param %s with value %s is not of type %s", param, value, expected)
39+
}
40+
3641
func ErrInvalidOperationState(state string) string {
3742
return fmt.Sprintf("invalid operation state %s. This should not have happened", state)
3843
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
dbsql "github.com/databricks/databricks-sql-go"
8+
"github.com/stretchr/testify/assert"
9+
"os"
10+
"strconv"
11+
"testing"
12+
"time"
13+
)
14+
15+
type row struct {
16+
symbol string
17+
companyName string
18+
industry string
19+
date string
20+
open float64
21+
high float64
22+
low float64
23+
close float64
24+
volume int
25+
change float64
26+
changePercentage float64
27+
upTrend bool
28+
volatile bool
29+
}
30+
31+
func runTest(withCloudFetch bool, query string) ([]row, error) {
32+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
33+
if err != nil {
34+
return nil, err
35+
}
36+
37+
connector, err := dbsql.NewConnector(
38+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
39+
dbsql.WithPort(port),
40+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
41+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
42+
dbsql.WithTimeout(10),
43+
dbsql.WithInitialNamespace("hive_metastore", "default"),
44+
dbsql.WithCloudFetch(withCloudFetch),
45+
)
46+
if err != nil {
47+
return nil, err
48+
}
49+
db := sql.OpenDB(connector)
50+
defer db.Close()
51+
52+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
53+
defer cancel()
54+
if err := db.PingContext(ctx); err != nil {
55+
return nil, err
56+
}
57+
rows, err1 := db.QueryContext(context.Background(), query)
58+
defer rows.Close()
59+
60+
if err1 != nil {
61+
if err1 == sql.ErrNoRows {
62+
fmt.Println("not found")
63+
return nil, err
64+
} else {
65+
return nil, err
66+
}
67+
}
68+
var res []row
69+
for rows.Next() {
70+
r := row{}
71+
err := rows.Scan(&r.symbol, &r.companyName, &r.industry, &r.date, &r.open, &r.high, &r.low, &r.close, &r.volume, &r.change, &r.changePercentage, &r.upTrend, &r.volatile)
72+
if err != nil {
73+
fmt.Println(err)
74+
return nil, err
75+
}
76+
res = append(res, r)
77+
}
78+
return res, nil
79+
}
80+
81+
func TestCloudFetch(t *testing.T) {
82+
t.Run("Compare local batch to cloud fetch", func(t *testing.T) {
83+
query := "select * from stock_data where date is not null and volume is not null order by date, symbol limit 10000000"
84+
85+
// Local arrow batch
86+
abRes, err := runTest(false, query)
87+
assert.NoError(t, err)
88+
89+
// Cloud fetch batch
90+
cfRes, err := runTest(true, query)
91+
assert.NoError(t, err)
92+
93+
for i := 0; i < len(abRes); i++ {
94+
assert.Equal(t, abRes[i], cfRes[i], fmt.Sprintf("not equal for row: %d", i))
95+
}
96+
})
97+
}

internal/config/config.go

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ type UserConfig struct {
9999
RetryWaitMax time.Duration
100100
RetryMax int
101101
Transport http.RoundTripper
102+
UseLz4Compression bool
103+
CloudFetchConfig
102104
}
103105

104106
// DeepCopy returns a true deep copy of UserConfig
@@ -138,6 +140,8 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
138140
RetryWaitMax: ucfg.RetryWaitMax,
139141
RetryMax: ucfg.RetryMax,
140142
Transport: ucfg.Transport,
143+
UseLz4Compression: ucfg.UseLz4Compression,
144+
CloudFetchConfig: ucfg.CloudFetchConfig,
141145
}
142146
}
143147

@@ -170,6 +174,8 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
170174
if ucfg.RetryWaitMax == 0 {
171175
ucfg.RetryWaitMax = 30 * time.Second
172176
}
177+
ucfg.UseLz4Compression = false
178+
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
173179

174180
return ucfg
175181
}
@@ -194,7 +200,7 @@ func WithDefaults() *Config {
194200

195201
}
196202

197-
// ParseDSN constructs UserConfig by parsing DSN string supplied to `sql.Open()`
203+
// ParseDSN constructs UserConfig and CloudFetchConfig by parsing DSN string supplied to `sql.Open()`
198204
func ParseDSN(dsn string) (UserConfig, error) {
199205
fullDSN := dsn
200206
if !strings.HasPrefix(dsn, "https://") && !strings.HasPrefix(dsn, "http://") {
@@ -266,6 +272,25 @@ func ParseDSN(dsn string) (UserConfig, error) {
266272
ucfg.Schema = params.Get("schema")
267273
params.Del("schema")
268274
}
275+
276+
// Cloud Fetch parameters
277+
if params.Has("useCloudFetch") {
278+
useCloudFetch, err := strconv.ParseBool(params.Get("useCloudFetch"))
279+
if err != nil {
280+
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat("useCloudFetch", params.Get("useCloudFetch"), "bool"), err)
281+
}
282+
ucfg.UseCloudFetch = useCloudFetch
283+
}
284+
params.Del("useCloudFetch")
285+
if params.Has("maxDownloadThreads") {
286+
numThreads, err := strconv.Atoi(params.Get("maxDownloadThreads"))
287+
if err != nil {
288+
return UserConfig{}, dbsqlerrint.NewRequestError(context.TODO(), dbsqlerr.InvalidDSNFormat("maxDownloadThreads", params.Get("maxDownloadThreads"), "int"), err)
289+
}
290+
ucfg.MaxDownloadThreads = numThreads
291+
}
292+
params.Del("maxDownloadThreads")
293+
269294
for k := range params {
270295
if strings.ToLower(k) == "timezone" {
271296
ucfg.Location, err = time.LoadLocation(params.Get("timezone"))
@@ -310,3 +335,37 @@ func (arrowConfig ArrowConfig) DeepCopy() ArrowConfig {
310335
UseArrowNativeIntervalTypes: arrowConfig.UseArrowNativeIntervalTypes,
311336
}
312337
}
338+
339+
type CloudFetchConfig struct {
340+
UseCloudFetch bool
341+
MaxDownloadThreads int
342+
MaxFilesInMemory int
343+
MinTimeToExpiry time.Duration
344+
}
345+
346+
func (cfg CloudFetchConfig) WithDefaults() CloudFetchConfig {
347+
cfg.UseCloudFetch = false
348+
349+
if cfg.MaxDownloadThreads <= 0 {
350+
cfg.MaxDownloadThreads = 10
351+
}
352+
353+
if cfg.MaxFilesInMemory < 1 {
354+
cfg.MaxFilesInMemory = 10
355+
}
356+
357+
if cfg.MinTimeToExpiry < 0 {
358+
cfg.MinTimeToExpiry = 0 * time.Second
359+
}
360+
361+
return cfg
362+
}
363+
364+
func (cfg CloudFetchConfig) DeepCopy() CloudFetchConfig {
365+
return CloudFetchConfig{
366+
UseCloudFetch: cfg.UseCloudFetch,
367+
MaxDownloadThreads: cfg.MaxDownloadThreads,
368+
MaxFilesInMemory: cfg.MaxFilesInMemory,
369+
MinTimeToExpiry: cfg.MinTimeToExpiry,
370+
}
371+
}

0 commit comments

Comments
 (0)