11package dbsql
22
33import (
4+ "bytes"
45 "context"
56 "database/sql/driver"
67 "encoding/json"
@@ -101,12 +102,6 @@ func (c *conn) IsValid() bool {
101102// ExecContext honors the context timeout and return when it is canceled.
102103// Statement ExecContext is the same as connection ExecContext
103104func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
104- s , ok := ctx .(StagingCtx )
105- if ok {
106- if s .IsStagingOperation {
107- return c .ExecStagingOperation (s , query , args )
108- }
109- }
110105
111106 corrId := driverctx .CorrelationIdFromContext (ctx )
112107 log := logger .WithContext (c .id , corrId , "" )
@@ -122,6 +117,25 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
122117 exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
123118
124119 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
120+ req := cli_service.TGetResultSetMetadataReq {
121+ OperationHandle : exStmtResp .OperationHandle ,
122+ }
123+ resp , err2 := c .client .GetResultSetMetadata (ctx , & req )
124+ if err2 != nil {
125+ return nil , errors .New ("Error performing staging operation" )
126+ }
127+ if * resp .IsStagingOperation {
128+ if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
129+ row , err := rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
130+ if err == nil {
131+ return nil , dbsqlerrint .NewDriverError (ctx , "Error reading row." , errors .New ("Error reading row." ))
132+ }
133+ return c .ExecStagingOperation (ctx , row )
134+ } else {
135+ return nil , dbsqlerrint .NewDriverError (ctx , "Staging ctx must be provided." , errors .New ("Staging ctx must be provided." ))
136+ }
137+ }
138+
125139 // we have an operation id so update the logger
126140 log = logger .WithContext (c .id , corrId , client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID ))
127141
@@ -159,8 +173,14 @@ func (c *conn) HandleStagingPut(presignedUrl string, headers map[string]string,
159173 return nil , fmt .Errorf ("cannot perform PUT without specifying a local_file" )
160174 }
161175 client := & http.Client {}
162- req , _ := http .NewRequest ("GET" , presignedUrl , nil )
163176
177+ dat , err := os .ReadFile (localFile )
178+
179+ req , _ := http .NewRequest ("PUT" , presignedUrl , bytes .NewReader (dat ))
180+
181+ if err != nil {
182+ return nil , err
183+ }
164184 for k , v := range headers {
165185 req .Header .Set (k , v )
166186 }
@@ -227,9 +247,10 @@ func (c *conn) HandleStagingDelete(presignedUrl string, headers map[string]strin
227247 return driver .ResultNoRows , nil
228248}
229249
230- func localPathIsAllowed (ctx StagingCtx , localFile string ) bool {
231- for i := range ctx .StagingAllowedLocalPath {
232- path := ctx .StagingAllowedLocalPath [i ]
250+ func localPathIsAllowed (ctx context.Context , localFile string ) bool {
251+ stagingAllowedLocalPaths := driverctx .StagingPathsFromContext (ctx )
252+ for i := range stagingAllowedLocalPaths {
253+ path := stagingAllowedLocalPaths [i ]
233254 relativePath , err := filepath .Rel (path , localFile )
234255 if err != nil {
235256 return false
@@ -241,23 +262,30 @@ func localPathIsAllowed(ctx StagingCtx, localFile string) bool {
241262 return false
242263}
243264
244- func (c * conn ) ExecStagingOperation (ctx StagingCtx , query string , args []driver.NamedValue ) (driver.Result , error ) {
245- row , err := c .QueryContext (ctx , query , args )
246- if err != nil {
247- return nil , err
248- }
265+ func (c * conn ) ExecStagingOperation (
266+ ctx context.Context ,
267+ row driver.Rows ) (driver.Result , error ) {
268+
249269 var sqlRow []driver.Value
250270 colNames := row .Columns ()
251271 sqlRow = make ([]driver.Value , len (colNames ))
252272 row .Next (sqlRow )
253- operation := sqlRow [0 ].(string )
254- presignedUrl := sqlRow [1 ].(string )
255- headersByteArr := []byte (sqlRow [2 ].(string ))
273+ var stringValues []string = make ([]string , 4 )
274+ for i := range stringValues {
275+ if s , ok := sqlRow [i ].(string ); ok {
276+ stringValues [i ] = s
277+ } else {
278+ return nil , fmt .Errorf ("local file operations are restricted to paths within the configured staging_allowed_local_path" )
279+ }
280+ }
281+ operation := stringValues [0 ]
282+ presignedUrl := stringValues [1 ]
283+ headersByteArr := []byte (stringValues [2 ])
256284 var headers map [string ]string
257285 if err := json .Unmarshal (headersByteArr , & headers ); err != nil {
258286 return nil , err
259287 }
260- localFile := sqlRow [3 ].( string )
288+ localFile := stringValues [3 ]
261289 switch operation {
262290 case "PUT" :
263291 if localPathIsAllowed (ctx , localFile ) {
0 commit comments