@@ -3,8 +3,6 @@ package dbsql
33import (
44 "context"
55 "database/sql/driver"
6- "fmt"
7- "strconv"
86 "time"
97
108 "github.com/databricks/databricks-sql-go/driverctx"
@@ -102,9 +100,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
102100 defer log .Duration (msg , start )
103101
104102 ctx = driverctx .NewContextWithConnId (ctx , c .id )
105- if len (args ) > 0 {
106- return nil , dbsqlerrint .NewDriverError (ctx , dbsqlerr .ErrParametersNotSupported , nil )
107- }
108103
109104 exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
110105
@@ -145,9 +140,6 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
145140 msg , start := log .Track ("QueryContext" )
146141
147142 ctx = driverctx .NewContextWithConnId (ctx , c .id )
148- if len (args ) > 0 {
149- return nil , dbsqlerrint .NewDriverError (ctx , dbsqlerr .ErrParametersNotSupported , nil )
150- }
151143
152144 // first we try to get the results synchronously.
153145 // at any point in time that the context is done we must cancel and return
@@ -288,7 +280,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
288280 MaxRows : int64 (c .cfg .MaxRows ),
289281 },
290282 CanDecompressLZ4Result_ : & c .cfg .UseLz4Compression ,
291- Parameters : namedValuesToTSparkParams (args ),
283+ Parameters : convertNamedValuesToSparkParams (args ),
292284 }
293285
294286 if c .cfg .UseArrowBatches {
@@ -342,87 +334,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
342334 return resp , err
343335}
344336
345- func namedValuesToTSparkParams (args []driver.NamedValue ) []* cli_service.TSparkParameter {
346- var ts []string = []string {"STRING" , "DOUBLE" , "BOOLEAN" , "TIMESTAMP" , "FLOAT" , "INTEGER" , "TINYINT" , "SMALLINT" , "BIGINT" }
347- var params []* cli_service.TSparkParameter
348- for i := range args {
349- arg := args [i ]
350- param := cli_service.TSparkParameter {Value : & cli_service.TSparkParameterValue {}}
351- if arg .Name != "" {
352- param .Name = & arg .Name
353- } else {
354- i := int32 (arg .Ordinal )
355- param .Ordinal = & i
356- }
357-
358- switch t := arg .Value .(type ) {
359- case bool :
360- b := arg .Value .(bool )
361- param .Value .BooleanValue = & b
362- param .Type = & ts [2 ]
363- case string :
364- s := arg .Value .(string )
365- param .Value .StringValue = & s
366- param .Type = & ts [0 ]
367- case int :
368- f := float64 (t )
369- param .Value .DoubleValue = & f
370- param .Type = & ts [5 ]
371- case uint :
372- f := float64 (t )
373- param .Value .DoubleValue = & f
374- param .Type = & ts [5 ]
375- case int8 :
376- f := float64 (t )
377- param .Value .DoubleValue = & f
378- param .Type = & ts [6 ]
379- case uint8 :
380- f := float64 (t )
381- param .Value .DoubleValue = & f
382- param .Type = & ts [6 ]
383- case int16 :
384- f := float64 (t )
385- param .Value .DoubleValue = & f
386- param .Type = & ts [7 ]
387- case uint16 :
388- f := float64 (t )
389- param .Value .DoubleValue = & f
390- param .Type = & ts [7 ]
391- case int32 :
392- f := float64 (t )
393- param .Value .DoubleValue = & f
394- param .Type = & ts [5 ]
395- case uint32 :
396- f := float64 (t )
397- param .Value .DoubleValue = & f
398- param .Type = & ts [5 ]
399- case int64 :
400- s := strconv .FormatInt (t , 10 )
401- param .Value .StringValue = & s
402- param .Type = & ts [8 ]
403- case uint64 :
404- s := strconv .FormatUint (t , 10 )
405- param .Value .StringValue = & s
406- param .Type = & ts [8 ]
407- case float32 :
408- f := float64 (t )
409- param .Value .DoubleValue = & f
410- param .Type = & ts [4 ]
411- case time.Time :
412- s := t .String ()
413- param .Value .StringValue = & s
414- param .Type = & ts [3 ]
415- default :
416- s := fmt .Sprintf ("%s" , arg .Value )
417- param .Value .StringValue = & s
418- param .Type = & ts [0 ]
419- }
420-
421- params = append (params , & param )
422- }
423- return params
424- }
425-
426337func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
427338 corrId := driverctx .CorrelationIdFromContext (ctx )
428339 log := logger .WithContext (c .id , corrId , client .SprintGuid (opHandle .OperationId .GUID ))
@@ -481,6 +392,18 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
481392 return statusResp , nil
482393}
483394
395+ func (c * conn ) CheckNamedValue (nv * driver.NamedValue ) error {
396+ var err error
397+ if dbsqlParam , ok := nv .Value .(DBSqlParam ); ok {
398+ nv .Name = dbsqlParam .Name
399+ dbsqlParam .Value , err = driver .DefaultParameterConverter .ConvertValue (dbsqlParam .Value )
400+ return err
401+ }
402+
403+ nv .Value , err = driver .DefaultParameterConverter .ConvertValue (nv .Value )
404+ return err
405+ }
406+
484407var _ driver.Conn = (* conn )(nil )
485408var _ driver.Pinger = (* conn )(nil )
486409var _ driver.SessionResetter = (* conn )(nil )
@@ -489,3 +412,4 @@ var _ driver.ExecerContext = (*conn)(nil)
489412var _ driver.QueryerContext = (* conn )(nil )
490413var _ driver.ConnPrepareContext = (* conn )(nil )
491414var _ driver.ConnBeginTx = (* conn )(nil )
415+ var _ driver.NamedValueChecker = (* conn )(nil )
0 commit comments