Skip to content

Commit e93dd29

Browse files
committed
Fixed functionality a bit
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent b3c175c commit e93dd29

4 files changed

Lines changed: 69 additions & 52 deletions

File tree

connection.go

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dbsql
22

33
import (
4+
"bytes"
45
"context"
56
"database/sql/driver"
67
"encoding/json"
@@ -101,12 +102,6 @@ func (c *conn) IsValid() bool {
101102
// ExecContext honors the context timeout and return when it is canceled.
102103
// Statement ExecContext is the same as connection ExecContext
103104
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
104-
s, ok := ctx.(StagingCtx)
105-
if ok {
106-
if s.IsStagingOperation {
107-
return c.ExecStagingOperation(s, query, args)
108-
}
109-
}
110105

111106
corrId := driverctx.CorrelationIdFromContext(ctx)
112107
log := logger.WithContext(c.id, corrId, "")
@@ -122,6 +117,25 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
122117
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
123118

124119
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
120+
req := cli_service.TGetResultSetMetadataReq{
121+
OperationHandle: exStmtResp.OperationHandle,
122+
}
123+
resp, err2 := c.client.GetResultSetMetadata(ctx, &req)
124+
if err2 != nil {
125+
return nil, errors.New("Error performing staging operation")
126+
}
127+
if *resp.IsStagingOperation {
128+
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
129+
row, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
130+
if err == nil {
131+
return nil, dbsqlerrint.NewDriverError(ctx, "Error reading row.", errors.New("Error reading row."))
132+
}
133+
return c.ExecStagingOperation(ctx, row)
134+
} else {
135+
return nil, dbsqlerrint.NewDriverError(ctx, "Staging ctx must be provided.", errors.New("Staging ctx must be provided."))
136+
}
137+
}
138+
125139
// we have an operation id so update the logger
126140
log = logger.WithContext(c.id, corrId, client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
127141

@@ -159,8 +173,14 @@ func (c *conn) HandleStagingPut(presignedUrl string, headers map[string]string,
159173
return nil, fmt.Errorf("cannot perform PUT without specifying a local_file")
160174
}
161175
client := &http.Client{}
162-
req, _ := http.NewRequest("GET", presignedUrl, nil)
163176

177+
dat, err := os.ReadFile(localFile)
178+
179+
req, _ := http.NewRequest("PUT", presignedUrl, bytes.NewReader(dat))
180+
181+
if err != nil {
182+
return nil, err
183+
}
164184
for k, v := range headers {
165185
req.Header.Set(k, v)
166186
}
@@ -227,9 +247,10 @@ func (c *conn) HandleStagingDelete(presignedUrl string, headers map[string]strin
227247
return driver.ResultNoRows, nil
228248
}
229249

230-
func localPathIsAllowed(ctx StagingCtx, localFile string) bool {
231-
for i := range ctx.StagingAllowedLocalPath {
232-
path := ctx.StagingAllowedLocalPath[i]
250+
func localPathIsAllowed(ctx context.Context, localFile string) bool {
251+
stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx)
252+
for i := range stagingAllowedLocalPaths {
253+
path := stagingAllowedLocalPaths[i]
233254
relativePath, err := filepath.Rel(path, localFile)
234255
if err != nil {
235256
return false
@@ -241,23 +262,30 @@ func localPathIsAllowed(ctx StagingCtx, localFile string) bool {
241262
return false
242263
}
243264

244-
func (c *conn) ExecStagingOperation(ctx StagingCtx, query string, args []driver.NamedValue) (driver.Result, error) {
245-
row, err := c.QueryContext(ctx, query, args)
246-
if err != nil {
247-
return nil, err
248-
}
265+
func (c *conn) ExecStagingOperation(
266+
ctx context.Context,
267+
row driver.Rows) (driver.Result, error) {
268+
249269
var sqlRow []driver.Value
250270
colNames := row.Columns()
251271
sqlRow = make([]driver.Value, len(colNames))
252272
row.Next(sqlRow)
253-
operation := sqlRow[0].(string)
254-
presignedUrl := sqlRow[1].(string)
255-
headersByteArr := []byte(sqlRow[2].(string))
273+
var stringValues []string = make([]string, 4)
274+
for i := range stringValues {
275+
if s, ok := sqlRow[i].(string); ok {
276+
stringValues[i] = s
277+
} else {
278+
return nil, fmt.Errorf("local file operations are restricted to paths within the configured staging_allowed_local_path")
279+
}
280+
}
281+
operation := stringValues[0]
282+
presignedUrl := stringValues[1]
283+
headersByteArr := []byte(stringValues[2])
256284
var headers map[string]string
257285
if err := json.Unmarshal(headersByteArr, &headers); err != nil {
258286
return nil, err
259287
}
260-
localFile := sqlRow[3].(string)
288+
localFile := stringValues[3]
261289
switch operation {
262290
case "PUT":
263291
if localPathIsAllowed(ctx, localFile) {

connector.go

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,6 @@ type connector struct {
2525
client *http.Client
2626
}
2727

28-
type StagingCtx struct {
29-
IsStagingOperation bool
30-
StagingAllowedLocalPath []string
31-
}
32-
33-
func (stagingCtx StagingCtx) WithDefaults() StagingCtx {
34-
stagingCtx.IsStagingOperation = true
35-
stagingCtx.StagingAllowedLocalPath = []string{"staging/"}
36-
37-
return stagingCtx
38-
}
39-
40-
func (StagingCtx) Deadline() (deadline time.Time, ok bool) {
41-
return
42-
}
43-
44-
func (StagingCtx) Done() <-chan struct{} {
45-
return nil
46-
}
47-
48-
func (StagingCtx) Err() error {
49-
return nil
50-
}
51-
52-
func (StagingCtx) Value(key any) any {
53-
return nil
54-
}
55-
56-
func (StagingCtx) String() string {
57-
return "context.Background"
58-
}
59-
6028
// Connect returns a connection to the Databricks database from a connection pool.
6129
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6230
var catalogName *cli_service.TIdentifier

driverctx/ctx.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ const (
1414
QueryIdContextKey
1515
QueryIdCallbackKey
1616
ConnIdCallbackKey
17+
StagingAllowedLocalPathKey
1718
)
1819

1920
type IdCallbackFunc func(string)
@@ -79,10 +80,27 @@ func QueryIdFromContext(ctx context.Context) string {
7980
return queryId
8081
}
8182

83+
// QueryIdFromContext retrieves the queryId stored in context.
84+
func StagingPathsFromContext(ctx context.Context) []string {
85+
if ctx == nil {
86+
return []string{}
87+
}
88+
89+
stagingAllowedLocalPath, ok := ctx.Value(StagingAllowedLocalPathKey).([]string)
90+
if !ok {
91+
return []string{}
92+
}
93+
return stagingAllowedLocalPath
94+
}
95+
8296
func NewContextWithQueryIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
8397
return context.WithValue(ctx, QueryIdCallbackKey, callback)
8498
}
8599

86100
func NewContextWithConnIdCallback(ctx context.Context, callback IdCallbackFunc) context.Context {
87101
return context.WithValue(ctx, ConnIdCallbackKey, callback)
88102
}
103+
104+
func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []string) context.Context {
105+
return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath)
106+
}

examples/staging/main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package main
22

33
import (
4+
"context"
45
"database/sql"
56
"fmt"
67
"log"
78
"os"
89
"strconv"
910

1011
dbsql "github.com/databricks/databricks-sql-go"
12+
"github.com/databricks/databricks-sql-go/driverctx"
1113
"github.com/joho/godotenv"
1214
)
1315

@@ -39,7 +41,8 @@ func main() {
3941

4042
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
4143
// defer cancel()
42-
ctx := dbsql.StagingCtx{}.WithDefaults()
44+
45+
ctx := driverctx.NewContextWithStagingInfo(context.Background(), []string{"staging"})
4346
if err := db.Ping(); err != nil {
4447
fmt.Println(err)
4548
}

0 commit comments

Comments
 (0)