Skip to content

Commit 73073d2

Browse files
authored
[PECO-1016-2] Add handling for special types (#158)
In this PR, we add further handling to allow for special types to be set via the DBSQLParams variable.
2 parents 34599c4 + 27d9a87 commit 73073d2

5 files changed

Lines changed: 200 additions & 102 deletions

File tree

connection.go

Lines changed: 14 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package dbsql
33
import (
44
"context"
55
"database/sql/driver"
6-
"fmt"
7-
"strconv"
86
"time"
97

108
"github.com/databricks/databricks-sql-go/driverctx"
@@ -102,9 +100,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
102100
defer log.Duration(msg, start)
103101

104102
ctx = driverctx.NewContextWithConnId(ctx, c.id)
105-
if len(args) > 0 {
106-
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
107-
}
108103

109104
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
110105

@@ -145,9 +140,6 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
145140
msg, start := log.Track("QueryContext")
146141

147142
ctx = driverctx.NewContextWithConnId(ctx, c.id)
148-
if len(args) > 0 {
149-
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
150-
}
151143

152144
// first we try to get the results synchronously.
153145
// at any point in time that the context is done we must cancel and return
@@ -288,7 +280,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
288280
MaxRows: int64(c.cfg.MaxRows),
289281
},
290282
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
291-
Parameters: namedValuesToTSparkParams(args),
283+
Parameters: convertNamedValuesToSparkParams(args),
292284
}
293285

294286
if c.cfg.UseArrowBatches {
@@ -342,87 +334,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
342334
return resp, err
343335
}
344336

345-
func namedValuesToTSparkParams(args []driver.NamedValue) []*cli_service.TSparkParameter {
346-
var ts []string = []string{"STRING", "DOUBLE", "BOOLEAN", "TIMESTAMP", "FLOAT", "INTEGER", "TINYINT", "SMALLINT", "BIGINT"}
347-
var params []*cli_service.TSparkParameter
348-
for i := range args {
349-
arg := args[i]
350-
param := cli_service.TSparkParameter{Value: &cli_service.TSparkParameterValue{}}
351-
if arg.Name != "" {
352-
param.Name = &arg.Name
353-
} else {
354-
i := int32(arg.Ordinal)
355-
param.Ordinal = &i
356-
}
357-
358-
switch t := arg.Value.(type) {
359-
case bool:
360-
b := arg.Value.(bool)
361-
param.Value.BooleanValue = &b
362-
param.Type = &ts[2]
363-
case string:
364-
s := arg.Value.(string)
365-
param.Value.StringValue = &s
366-
param.Type = &ts[0]
367-
case int:
368-
f := float64(t)
369-
param.Value.DoubleValue = &f
370-
param.Type = &ts[5]
371-
case uint:
372-
f := float64(t)
373-
param.Value.DoubleValue = &f
374-
param.Type = &ts[5]
375-
case int8:
376-
f := float64(t)
377-
param.Value.DoubleValue = &f
378-
param.Type = &ts[6]
379-
case uint8:
380-
f := float64(t)
381-
param.Value.DoubleValue = &f
382-
param.Type = &ts[6]
383-
case int16:
384-
f := float64(t)
385-
param.Value.DoubleValue = &f
386-
param.Type = &ts[7]
387-
case uint16:
388-
f := float64(t)
389-
param.Value.DoubleValue = &f
390-
param.Type = &ts[7]
391-
case int32:
392-
f := float64(t)
393-
param.Value.DoubleValue = &f
394-
param.Type = &ts[5]
395-
case uint32:
396-
f := float64(t)
397-
param.Value.DoubleValue = &f
398-
param.Type = &ts[5]
399-
case int64:
400-
s := strconv.FormatInt(t, 10)
401-
param.Value.StringValue = &s
402-
param.Type = &ts[8]
403-
case uint64:
404-
s := strconv.FormatUint(t, 10)
405-
param.Value.StringValue = &s
406-
param.Type = &ts[8]
407-
case float32:
408-
f := float64(t)
409-
param.Value.DoubleValue = &f
410-
param.Type = &ts[4]
411-
case time.Time:
412-
s := t.String()
413-
param.Value.StringValue = &s
414-
param.Type = &ts[3]
415-
default:
416-
s := fmt.Sprintf("%s", arg.Value)
417-
param.Value.StringValue = &s
418-
param.Type = &ts[0]
419-
}
420-
421-
params = append(params, &param)
422-
}
423-
return params
424-
}
425-
426337
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
427338
corrId := driverctx.CorrelationIdFromContext(ctx)
428339
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
@@ -481,6 +392,18 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
481392
return statusResp, nil
482393
}
483394

395+
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
396+
var err error
397+
if dbsqlParam, ok := nv.Value.(DBSqlParam); ok {
398+
nv.Name = dbsqlParam.Name
399+
dbsqlParam.Value, err = driver.DefaultParameterConverter.ConvertValue(dbsqlParam.Value)
400+
return err
401+
}
402+
403+
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
404+
return err
405+
}
406+
484407
var _ driver.Conn = (*conn)(nil)
485408
var _ driver.Pinger = (*conn)(nil)
486409
var _ driver.SessionResetter = (*conn)(nil)
@@ -489,3 +412,4 @@ var _ driver.ExecerContext = (*conn)(nil)
489412
var _ driver.QueryerContext = (*conn)(nil)
490413
var _ driver.ConnPrepareContext = (*conn)(nil)
491414
var _ driver.ConnBeginTx = (*conn)(nil)
415+
var _ driver.NamedValueChecker = (*conn)(nil)

errors/errors.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ const (
1212
// Driver errors
1313
ErrNotImplemented = "not implemented"
1414
ErrTransactionsNotSupported = "transactions are not supported"
15-
ErrParametersNotSupported = "query parameters are not supported"
1615
ErrReadQueryStatus = "could not read query status"
1716
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"
1817

parameter_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"strconv"
6+
"testing"
7+
"time"
8+
9+
"github.com/databricks/databricks-sql-go/internal/cli_service"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestParameter_Inference(t *testing.T) {
14+
t.Run("Should infer types correctly", func(t *testing.T) {
15+
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}}}
16+
parameters := convertNamedValuesToSparkParams(values[:])
17+
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
18+
assert.NotNil(t, parameters[1].Value.StringValue)
19+
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
20+
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
21+
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
22+
assert.Equal(t, string("DECIMAL"), *parameters[4].Type)
23+
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
24+
})
25+
}
26+
func TestParameters_Names(t *testing.T) {
27+
t.Run("Should infer types correctly", func(t *testing.T) {
28+
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}}}
29+
parameters := convertNamedValuesToSparkParams(values[:])
30+
assert.Equal(t, string("1"), *parameters[0].Name)
31+
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
32+
assert.Equal(t, string("2"), *parameters[1].Name)
33+
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
34+
assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
35+
})
36+
}

parameters.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"strconv"
7+
"time"
8+
9+
"github.com/databricks/databricks-sql-go/internal/cli_service"
10+
)
11+
12+
type DBSqlParam struct {
13+
Name string
14+
Type SqlType
15+
Value any
16+
}
17+
18+
type SqlType int64
19+
20+
const (
21+
String SqlType = iota
22+
Date
23+
Timestamp
24+
Float
25+
Decimal
26+
Double
27+
Integer
28+
BigInt
29+
SmallInt
30+
TinyInt
31+
Boolean
32+
IntervalMonth
33+
IntervalDay
34+
)
35+
36+
func (s SqlType) String() string {
37+
switch s {
38+
case String:
39+
return "STRING"
40+
case Date:
41+
return "DATE"
42+
case Timestamp:
43+
return "TIMESTAMP"
44+
case Float:
45+
return "FLOAT"
46+
case Decimal:
47+
return "DECIMAL"
48+
case Double:
49+
return "DOUBLE"
50+
case Integer:
51+
return "INTEGER"
52+
case BigInt:
53+
return "BIGINT"
54+
case SmallInt:
55+
return "SMALLINT"
56+
case TinyInt:
57+
return "TINYINT"
58+
case Boolean:
59+
return "BOOLEAN"
60+
case IntervalMonth:
61+
return "INTERVAL MONTH"
62+
case IntervalDay:
63+
return "INTERVAL DAY"
64+
}
65+
return "unknown"
66+
}
67+
68+
func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
69+
var params []DBSqlParam
70+
for i := range namedValues {
71+
namedValue := namedValues[i]
72+
param := *new(DBSqlParam)
73+
param.Name = namedValue.Name
74+
param.Value = namedValue.Value
75+
params = append(params, param)
76+
}
77+
return params
78+
}
79+
80+
func inferTypes(params []DBSqlParam) {
81+
for i := range params {
82+
param := &params[i]
83+
switch value := param.Value.(type) {
84+
case bool:
85+
param.Value = strconv.FormatBool(value)
86+
param.Type = Boolean
87+
case string:
88+
param.Value = value
89+
param.Type = String
90+
case int:
91+
param.Value = strconv.Itoa(value)
92+
param.Type = Integer
93+
case uint:
94+
param.Value = strconv.FormatUint(uint64(value), 10)
95+
param.Type = Integer
96+
case int8:
97+
param.Value = strconv.Itoa(int(value))
98+
param.Type = Integer
99+
case uint8:
100+
param.Value = strconv.FormatUint(uint64(value), 10)
101+
param.Type = Integer
102+
case int16:
103+
param.Value = strconv.Itoa(int(value))
104+
param.Type = Integer
105+
case uint16:
106+
param.Value = strconv.FormatUint(uint64(value), 10)
107+
param.Type = Integer
108+
case int32:
109+
param.Value = strconv.Itoa(int(value))
110+
param.Type = Integer
111+
case uint32:
112+
param.Value = strconv.FormatUint(uint64(value), 10)
113+
param.Type = Integer
114+
case int64:
115+
param.Value = strconv.Itoa(int(value))
116+
param.Type = Integer
117+
case uint64:
118+
param.Value = strconv.FormatUint(uint64(value), 10)
119+
param.Type = Integer
120+
case float32:
121+
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
122+
param.Type = Float
123+
case time.Time:
124+
param.Value = value.String()
125+
param.Type = Timestamp
126+
case DBSqlParam:
127+
param.Name = value.Name
128+
param.Value = value.Value
129+
param.Type = value.Type
130+
default:
131+
s := fmt.Sprintf("%s", value)
132+
param.Value = s
133+
param.Type = String
134+
}
135+
}
136+
}
137+
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
138+
var sparkParams []*cli_service.TSparkParameter
139+
140+
sqlParams := valuesToDBSQLParams(values)
141+
inferTypes(sqlParams)
142+
for i := range sqlParams {
143+
sqlParam := sqlParams[i]
144+
sparkParamValue := sqlParam.Value.(string)
145+
sparkParamType := sqlParam.Type.String()
146+
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
147+
sparkParams = append(sparkParams, &sparkParam)
148+
}
149+
return sparkParams
150+
}

statement_test.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql/driver"
66
"testing"
7-
"time"
87

98
"github.com/apache/thrift/lib/go/thrift"
109
"github.com/databricks/databricks-sql-go/internal/cli_service"
@@ -166,13 +165,3 @@ func TestStmt_QueryContext(t *testing.T) {
166165
assert.Equal(t, testQuery, savedQueryString)
167166
})
168167
}
169-
func TestParameters(t *testing.T) {
170-
t.Run("Parameter casting should be correct", func(t *testing.T) {
171-
values := [3]driver.NamedValue{{Ordinal: 1, Name: "", Value: float32(5)}, {Ordinal: 2, Name: "", Value: time.Now()}, {Ordinal: 3, Name: "", Value: int64(5)}}
172-
parameters := namedValuesToTSparkParams(values[:])
173-
assert.Equal(t, &cli_service.TSparkParameterValue{DoubleValue: thrift.Float64Ptr(5)}, parameters[0].Value)
174-
assert.NotNil(t, parameters[1].Value.StringValue)
175-
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
176-
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
177-
})
178-
}

0 commit comments

Comments
 (0)