Skip to content

Commit e217729

Browse files
committed
Initial commit
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent 86525e6 commit e217729

3 files changed

Lines changed: 98 additions & 38 deletions

File tree

connection.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,13 @@ func (c *conn) IsValid() bool {
9494
// ExecContext honors the context timeout and return when it is canceled.
9595
// Statement ExecContext is the same as connection ExecContext
9696
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
97+
s, ok := ctx.(StagingCtx)
98+
if ok {
99+
if s.IsStagingOperation {
100+
return c.ExecStagingOperation(s, query, args)
101+
}
102+
}
103+
97104
corrId := driverctx.CorrelationIdFromContext(ctx)
98105
log := logger.WithContext(c.id, corrId, "")
99106
msg, start := logger.Track("ExecContext")
@@ -133,6 +140,22 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
133140
return &res, nil
134141
}
135142

143+
type StagingRow struct {
144+
presignedUrl string
145+
localFile string
146+
headers string
147+
operation string
148+
}
149+
150+
func (c *conn) ExecStagingOperation(ctx StagingCtx, query string, args []driver.NamedValue) (driver.Result, error) {
151+
row, err := c.QueryContext(ctx, query, args)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
row.Next()
157+
}
158+
136159
// QueryContext executes a query that may return rows, such as a
137160
// SELECT.
138161
//

connector.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,36 @@ type connector struct {
2525
client *http.Client
2626
}
2727

28+
type StagingCtx struct {
29+
IsStagingOperation bool
30+
StagingAllowedLocalPath string
31+
}
32+
33+
func (ctx *StagingCtx) WithDefaults() {
34+
ctx.IsStagingOperation = true
35+
ctx.StagingAllowedLocalPath = "staging/"
36+
}
37+
38+
func (StagingCtx) Deadline() (deadline time.Time, ok bool) {
39+
return
40+
}
41+
42+
func (StagingCtx) Done() <-chan struct{} {
43+
return nil
44+
}
45+
46+
func (StagingCtx) Err() error {
47+
return nil
48+
}
49+
50+
func (StagingCtx) Value(key any) any {
51+
return nil
52+
}
53+
54+
func (StagingCtx) String() string {
55+
return "context.Background"
56+
}
57+
2858
// Connect returns a connection to the Databricks database from a connection pool.
2959
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
3060
var catalogName *cli_service.TIdentifier

internal/config/config.go

Lines changed: 45 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,25 @@ func (c *Config) DeepCopy() *Config {
8282

8383
// UserConfig is the set of configurations exposed to users
8484
type UserConfig struct {
85-
Protocol string
86-
Host string // from databricks UI
87-
Port int // from databricks UI
88-
HTTPPath string // from databricks UI
89-
Catalog string
90-
Schema string
91-
Authenticator auth.Authenticator
92-
AccessToken string // from databricks UI
93-
MaxRows int // max rows per page
94-
QueryTimeout time.Duration // Timeout passed to server for query processing
95-
UserAgentEntry string
96-
Location *time.Location
97-
SessionParams map[string]string
98-
RetryWaitMin time.Duration
99-
RetryWaitMax time.Duration
100-
RetryMax int
101-
Transport http.RoundTripper
102-
UseLz4Compression bool
85+
Protocol string
86+
Host string // from databricks UI
87+
Port int // from databricks UI
88+
HTTPPath string // from databricks UI
89+
Catalog string
90+
Schema string
91+
Authenticator auth.Authenticator
92+
AccessToken string // from databricks UI
93+
MaxRows int // max rows per page
94+
QueryTimeout time.Duration // Timeout passed to server for query processing
95+
UserAgentEntry string
96+
Location *time.Location
97+
SessionParams map[string]string
98+
RetryWaitMin time.Duration
99+
RetryWaitMax time.Duration
100+
RetryMax int
101+
Transport http.RoundTripper
102+
UseLz4Compression bool
103+
StagingAllowedLocalPath string
103104
CloudFetchConfig
104105
}
105106

@@ -123,25 +124,26 @@ func (ucfg UserConfig) DeepCopy() UserConfig {
123124
}
124125

125126
return UserConfig{
126-
Protocol: ucfg.Protocol,
127-
Host: ucfg.Host,
128-
Port: ucfg.Port,
129-
HTTPPath: ucfg.HTTPPath,
130-
Catalog: ucfg.Catalog,
131-
Schema: ucfg.Schema,
132-
Authenticator: ucfg.Authenticator,
133-
AccessToken: ucfg.AccessToken,
134-
MaxRows: ucfg.MaxRows,
135-
QueryTimeout: ucfg.QueryTimeout,
136-
UserAgentEntry: ucfg.UserAgentEntry,
137-
Location: loccp,
138-
SessionParams: sessionParams,
139-
RetryWaitMin: ucfg.RetryWaitMin,
140-
RetryWaitMax: ucfg.RetryWaitMax,
141-
RetryMax: ucfg.RetryMax,
142-
Transport: ucfg.Transport,
143-
UseLz4Compression: ucfg.UseLz4Compression,
144-
CloudFetchConfig: ucfg.CloudFetchConfig,
127+
Protocol: ucfg.Protocol,
128+
Host: ucfg.Host,
129+
Port: ucfg.Port,
130+
HTTPPath: ucfg.HTTPPath,
131+
Catalog: ucfg.Catalog,
132+
Schema: ucfg.Schema,
133+
Authenticator: ucfg.Authenticator,
134+
AccessToken: ucfg.AccessToken,
135+
MaxRows: ucfg.MaxRows,
136+
QueryTimeout: ucfg.QueryTimeout,
137+
UserAgentEntry: ucfg.UserAgentEntry,
138+
Location: loccp,
139+
SessionParams: sessionParams,
140+
RetryWaitMin: ucfg.RetryWaitMin,
141+
RetryWaitMax: ucfg.RetryWaitMax,
142+
RetryMax: ucfg.RetryMax,
143+
Transport: ucfg.Transport,
144+
UseLz4Compression: ucfg.UseLz4Compression,
145+
StagingAllowedLocalPath: ucfg.StagingAllowedLocalPath,
146+
CloudFetchConfig: ucfg.CloudFetchConfig,
145147
}
146148
}
147149

@@ -176,7 +178,7 @@ func (ucfg UserConfig) WithDefaults() UserConfig {
176178
}
177179
ucfg.UseLz4Compression = false
178180
ucfg.CloudFetchConfig = CloudFetchConfig{}.WithDefaults()
179-
181+
ucfg.StagingAllowedLocalPath = "staging/"
180182
return ucfg
181183
}
182184

@@ -279,6 +281,11 @@ func ParseDSN(dsn string) (UserConfig, error) {
279281
ucfg.Location, err = time.LoadLocation(timezone)
280282
}
281283

284+
if params.Has("stagingAllowedLocalPath") {
285+
ucfg.StagingAllowedLocalPath = params.Get("stagingAllowedLocalPath")
286+
}
287+
params.Del("stagingAllowedLocalPath")
288+
282289
// any left over params are treated as session params
283290
if len(params.Values) > 0 {
284291
sessionParams := make(map[string]string)

0 commit comments

Comments
 (0)