Skip to content

Commit 0dd993b

Browse files
committed
Initial working commit
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent e217729 commit 0dd993b

2 files changed

Lines changed: 136 additions & 10 deletions

File tree

connection.go

Lines changed: 130 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@ package dbsql
33
import (
44
"context"
55
"database/sql/driver"
6+
"encoding/json"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"os"
11+
"path/filepath"
12+
"strings"
613
"time"
714

815
"github.com/databricks/databricks-sql-go/driverctx"
@@ -140,20 +147,137 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140147
return &res, nil
141148
}
142149

143-
type StagingRow struct {
144-
presignedUrl string
145-
localFile string
146-
headers string
147-
operation string
150+
func Succeeded(response *http.Response) bool {
151+
if response.StatusCode == 200 || response.StatusCode == 201 || response.StatusCode == 202 || response.StatusCode == 204 {
152+
return true
153+
}
154+
return false
155+
}
156+
157+
func (c *conn) HandleStagingPut(presignedUrl string, headers map[string]string, localFile string) (driver.Result, error) {
158+
if localFile == "" {
159+
return nil, fmt.Errorf("cannot perform PUT without specifying a local_file")
160+
}
161+
client := &http.Client{}
162+
req, _ := http.NewRequest("GET", presignedUrl, nil)
163+
164+
for k, v := range headers {
165+
req.Header.Set(k, v)
166+
}
167+
res, err := client.Do(req)
168+
if err != nil {
169+
return nil, err
170+
}
171+
defer res.Body.Close()
172+
content, err := io.ReadAll(res.Body)
173+
174+
if err != nil || !Succeeded(res) {
175+
return nil, fmt.Errorf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content)
176+
}
177+
return driver.ResultNoRows, nil
178+
179+
}
180+
181+
func (c *conn) HandleStagingGet(presignedUrl string, headers map[string]string, localFile string) (driver.Result, error) {
182+
if localFile == "" {
183+
return nil, fmt.Errorf("cannot perform GET without specifying a local_file")
184+
}
185+
client := &http.Client{}
186+
req, _ := http.NewRequest("GET", presignedUrl, nil)
187+
188+
for k, v := range headers {
189+
req.Header.Set(k, v)
190+
}
191+
res, err := client.Do(req)
192+
if err != nil {
193+
return nil, err
194+
}
195+
defer res.Body.Close()
196+
content, err := io.ReadAll(res.Body)
197+
198+
if err != nil || !Succeeded(res) {
199+
return nil, fmt.Errorf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content)
200+
}
201+
202+
err = os.WriteFile(localFile, content, 0644)
203+
if err != nil {
204+
return nil, err
205+
}
206+
return driver.ResultNoRows, nil
207+
208+
}
209+
210+
func (c *conn) HandleStagingDelete(presignedUrl string, headers map[string]string) (driver.Result, error) {
211+
client := &http.Client{}
212+
req, _ := http.NewRequest("DELETE", presignedUrl, nil)
213+
for k, v := range headers {
214+
req.Header.Set(k, v)
215+
}
216+
res, err := client.Do(req)
217+
if err != nil {
218+
return nil, err
219+
}
220+
defer res.Body.Close()
221+
content, err := io.ReadAll(res.Body)
222+
223+
if err != nil || !Succeeded(res) {
224+
return nil, fmt.Errorf("staging operation over HTTP was unsuccessful: %d-%s", res.StatusCode, content)
225+
}
226+
227+
return driver.ResultNoRows, nil
228+
}
229+
230+
func localPathIsAllowed(ctx StagingCtx, localFile string) bool {
231+
for i := range ctx.StagingAllowedLocalPath {
232+
path := ctx.StagingAllowedLocalPath[i]
233+
relativePath, err := filepath.Rel(path, localFile)
234+
if err != nil {
235+
return false
236+
}
237+
if !strings.Contains(relativePath, "../") {
238+
return true
239+
}
240+
}
241+
return false
148242
}
149243

150244
func (c *conn) ExecStagingOperation(ctx StagingCtx, query string, args []driver.NamedValue) (driver.Result, error) {
151245
row, err := c.QueryContext(ctx, query, args)
152246
if err != nil {
153247
return nil, err
154248
}
249+
var sqlRow []driver.Value
250+
colNames := row.Columns()
251+
sqlRow = make([]driver.Value, len(colNames))
252+
row.Next(sqlRow)
253+
operation := sqlRow[0].(string)
254+
presignedUrl := sqlRow[1].(string)
255+
headersByteArr := []byte(sqlRow[2].(string))
256+
var headers map[string]string
257+
if err := json.Unmarshal(headersByteArr, &headers); err != nil {
258+
return nil, err
259+
}
260+
localFile := sqlRow[3].(string)
261+
switch operation {
262+
case "PUT":
263+
if localPathIsAllowed(ctx, localFile) {
264+
c.HandleStagingPut(presignedUrl, headers, localFile)
265+
} else {
266+
return nil, fmt.Errorf("local file operations are restricted to paths within the configured staging_allowed_local_path")
267+
}
268+
case "GET":
269+
if localPathIsAllowed(ctx, localFile) {
270+
c.HandleStagingGet(presignedUrl, headers, localFile)
271+
} else {
272+
return nil, fmt.Errorf("local file operations are restricted to paths within the configured staging_allowed_local_path")
273+
}
274+
case "DELETE":
275+
c.HandleStagingDelete(presignedUrl, headers)
276+
default:
277+
return nil, fmt.Errorf("operation %s is not supported. Supported operations are GET, PUT, and REMOVE", operation)
278+
}
155279

156-
row.Next()
280+
return driver.ResultNoRows, nil
157281
}
158282

159283
// QueryContext executes a query that may return rows, such as a

connector.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ type connector struct {
2727

2828
type StagingCtx struct {
2929
IsStagingOperation bool
30-
StagingAllowedLocalPath string
30+
StagingAllowedLocalPath []string
3131
}
3232

33-
func (ctx *StagingCtx) WithDefaults() {
34-
ctx.IsStagingOperation = true
35-
ctx.StagingAllowedLocalPath = "staging/"
33+
func (stagingCtx StagingCtx) WithDefaults() StagingCtx {
34+
stagingCtx.IsStagingOperation = true
35+
stagingCtx.StagingAllowedLocalPath = []string{"staging/"}
36+
37+
return stagingCtx
3638
}
3739

3840
func (StagingCtx) Deadline() (deadline time.Time, ok bool) {

0 commit comments

Comments
 (0)