@@ -3,6 +3,13 @@ package dbsql
33import (
44 "context"
55 "database/sql/driver"
6+ "encoding/json"
7+ "fmt"
8+ "io"
9+ "net/http"
10+ "os"
11+ "path/filepath"
12+ "strings"
613 "time"
714
815 "github.com/databricks/databricks-sql-go/driverctx"
@@ -140,20 +147,137 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140147 return & res , nil
141148}
142149
143- type StagingRow struct {
144- presignedUrl string
145- localFile string
146- headers string
147- operation string
150+ func Succeeded (response * http.Response ) bool {
151+ if response .StatusCode == 200 || response .StatusCode == 201 || response .StatusCode == 202 || response .StatusCode == 204 {
152+ return true
153+ }
154+ return false
155+ }
156+
157+ func (c * conn ) HandleStagingPut (presignedUrl string , headers map [string ]string , localFile string ) (driver.Result , error ) {
158+ if localFile == "" {
159+ return nil , fmt .Errorf ("cannot perform PUT without specifying a local_file" )
160+ }
161+ client := & http.Client {}
162+ req , _ := http .NewRequest ("GET" , presignedUrl , nil )
163+
164+ for k , v := range headers {
165+ req .Header .Set (k , v )
166+ }
167+ res , err := client .Do (req )
168+ if err != nil {
169+ return nil , err
170+ }
171+ defer res .Body .Close ()
172+ content , err := io .ReadAll (res .Body )
173+
174+ if err != nil || ! Succeeded (res ) {
175+ return nil , fmt .Errorf ("staging operation over HTTP was unsuccessful: %d-%s" , res .StatusCode , content )
176+ }
177+ return driver .ResultNoRows , nil
178+
179+ }
180+
181+ func (c * conn ) HandleStagingGet (presignedUrl string , headers map [string ]string , localFile string ) (driver.Result , error ) {
182+ if localFile == "" {
183+ return nil , fmt .Errorf ("cannot perform GET without specifying a local_file" )
184+ }
185+ client := & http.Client {}
186+ req , _ := http .NewRequest ("GET" , presignedUrl , nil )
187+
188+ for k , v := range headers {
189+ req .Header .Set (k , v )
190+ }
191+ res , err := client .Do (req )
192+ if err != nil {
193+ return nil , err
194+ }
195+ defer res .Body .Close ()
196+ content , err := io .ReadAll (res .Body )
197+
198+ if err != nil || ! Succeeded (res ) {
199+ return nil , fmt .Errorf ("staging operation over HTTP was unsuccessful: %d-%s" , res .StatusCode , content )
200+ }
201+
202+ err = os .WriteFile (localFile , content , 0644 )
203+ if err != nil {
204+ return nil , err
205+ }
206+ return driver .ResultNoRows , nil
207+
208+ }
209+
210+ func (c * conn ) HandleStagingDelete (presignedUrl string , headers map [string ]string ) (driver.Result , error ) {
211+ client := & http.Client {}
212+ req , _ := http .NewRequest ("DELETE" , presignedUrl , nil )
213+ for k , v := range headers {
214+ req .Header .Set (k , v )
215+ }
216+ res , err := client .Do (req )
217+ if err != nil {
218+ return nil , err
219+ }
220+ defer res .Body .Close ()
221+ content , err := io .ReadAll (res .Body )
222+
223+ if err != nil || ! Succeeded (res ) {
224+ return nil , fmt .Errorf ("staging operation over HTTP was unsuccessful: %d-%s" , res .StatusCode , content )
225+ }
226+
227+ return driver .ResultNoRows , nil
228+ }
229+
230+ func localPathIsAllowed (ctx StagingCtx , localFile string ) bool {
231+ for i := range ctx .StagingAllowedLocalPath {
232+ path := ctx .StagingAllowedLocalPath [i ]
233+ relativePath , err := filepath .Rel (path , localFile )
234+ if err != nil {
235+ return false
236+ }
237+ if ! strings .Contains (relativePath , "../" ) {
238+ return true
239+ }
240+ }
241+ return false
148242}
149243
150244func (c * conn ) ExecStagingOperation (ctx StagingCtx , query string , args []driver.NamedValue ) (driver.Result , error ) {
151245 row , err := c .QueryContext (ctx , query , args )
152246 if err != nil {
153247 return nil , err
154248 }
249+ var sqlRow []driver.Value
250+ colNames := row .Columns ()
251+ sqlRow = make ([]driver.Value , len (colNames ))
252+ row .Next (sqlRow )
253+ operation := sqlRow [0 ].(string )
254+ presignedUrl := sqlRow [1 ].(string )
255+ headersByteArr := []byte (sqlRow [2 ].(string ))
256+ var headers map [string ]string
257+ if err := json .Unmarshal (headersByteArr , & headers ); err != nil {
258+ return nil , err
259+ }
260+ localFile := sqlRow [3 ].(string )
261+ switch operation {
262+ case "PUT" :
263+ if localPathIsAllowed (ctx , localFile ) {
264+ c .HandleStagingPut (presignedUrl , headers , localFile )
265+ } else {
266+ return nil , fmt .Errorf ("local file operations are restricted to paths within the configured staging_allowed_local_path" )
267+ }
268+ case "GET" :
269+ if localPathIsAllowed (ctx , localFile ) {
270+ c .HandleStagingGet (presignedUrl , headers , localFile )
271+ } else {
272+ return nil , fmt .Errorf ("local file operations are restricted to paths within the configured staging_allowed_local_path" )
273+ }
274+ case "DELETE" :
275+ c .HandleStagingDelete (presignedUrl , headers )
276+ default :
277+ return nil , fmt .Errorf ("operation %s is not supported. Supported operations are GET, PUT, and REMOVE" , operation )
278+ }
155279
156- row . Next ()
280+ return driver . ResultNoRows , nil
157281}
158282
159283// QueryContext executes a query that may return rows, such as a
0 commit comments