Skip to content

Commit 7ce81f6

Browse files
committed
Finished impl and tests
Signed-off-by: nithinkdb <nithin.krishnamurthi@databricks.com>
1 parent c0e473a commit 7ce81f6

4 files changed

Lines changed: 195 additions & 105 deletions

File tree

connection.go

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package dbsql
33
import (
44
"context"
55
"database/sql/driver"
6-
"fmt"
7-
"strconv"
86
"time"
97

108
"github.com/databricks/databricks-sql-go/driverctx"
@@ -288,7 +286,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
288286
MaxRows: int64(c.cfg.MaxRows),
289287
},
290288
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
291-
Parameters: namedValuesToTSparkParams(args),
289+
Parameters: convertNamedValuesToSparkParams(args),
292290
}
293291

294292
if c.cfg.UseArrowBatches {
@@ -342,97 +340,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
342340
return resp, err
343341
}
344342

345-
type dbSqlParam struct {
346-
Ordinal *int32
347-
Name *string
348-
Type *string
349-
Value *string
350-
}
351-
352-
func namedValuesToTSparkParams(args []driver.NamedValue) []*cli_service.TSparkParameter {
353-
var ts []string = []string{"STRING", "DOUBLE", "BOOLEAN", "TIMESTAMP", "FLOAT", "INTEGER", "TINYINT", "SMALLINT", "BIGINT"}
354-
var params []*cli_service.TSparkParameter
355-
for i := range args {
356-
arg := args[i]
357-
param := cli_service.TSparkParameter{Value: &cli_service.TSparkParameterValue{}}
358-
if arg.Name != "" {
359-
param.Name = &arg.Name
360-
} else {
361-
i := int32(arg.Ordinal)
362-
param.Ordinal = &i
363-
}
364-
365-
switch t := arg.Value.(type) {
366-
case bool:
367-
b := arg.Value.(bool)
368-
param.Value.BooleanValue = &b
369-
param.Type = &ts[2]
370-
case string:
371-
s := arg.Value.(string)
372-
param.Value.StringValue = &s
373-
param.Type = &ts[0]
374-
case int:
375-
f := float64(t)
376-
param.Value.DoubleValue = &f
377-
param.Type = &ts[5]
378-
case uint:
379-
f := float64(t)
380-
param.Value.DoubleValue = &f
381-
param.Type = &ts[5]
382-
case int8:
383-
f := float64(t)
384-
param.Value.DoubleValue = &f
385-
param.Type = &ts[6]
386-
case uint8:
387-
f := float64(t)
388-
param.Value.DoubleValue = &f
389-
param.Type = &ts[6]
390-
case int16:
391-
f := float64(t)
392-
param.Value.DoubleValue = &f
393-
param.Type = &ts[7]
394-
case uint16:
395-
f := float64(t)
396-
param.Value.DoubleValue = &f
397-
param.Type = &ts[7]
398-
case int32:
399-
f := float64(t)
400-
param.Value.DoubleValue = &f
401-
param.Type = &ts[5]
402-
case uint32:
403-
f := float64(t)
404-
param.Value.DoubleValue = &f
405-
param.Type = &ts[5]
406-
case int64:
407-
s := strconv.FormatInt(t, 10)
408-
param.Value.StringValue = &s
409-
param.Type = &ts[8]
410-
case uint64:
411-
s := strconv.FormatUint(t, 10)
412-
param.Value.StringValue = &s
413-
param.Type = &ts[8]
414-
case float32:
415-
f := float64(t)
416-
param.Value.DoubleValue = &f
417-
param.Type = &ts[4]
418-
case time.Time:
419-
s := t.String()
420-
param.Value.StringValue = &s
421-
param.Type = &ts[3]
422-
case dbSqlParam:
423-
param.Value.StringValue = t.Value
424-
param.Type = t.Type
425-
default:
426-
s := fmt.Sprintf("%s", arg.Value)
427-
param.Value.StringValue = &s
428-
param.Type = &ts[0]
429-
}
430-
431-
params = append(params, &param)
432-
}
433-
return params
434-
}
435-
436343
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
437344
corrId := driverctx.CorrelationIdFromContext(ctx)
438345
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))

parameter_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"strconv"
6+
"testing"
7+
"time"
8+
9+
"github.com/databricks/databricks-sql-go/internal/cli_service"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestParameter_Inference(t *testing.T) {
14+
t.Run("Should infer types correctly", func(t *testing.T) {
15+
values := [5]driver.NamedValue{{Name: "", Value: float32(5.1)}, {Name: "", Value: time.Now()}, {Name: "", Value: int64(5)}, {Name: "", Value: true}, {Name: "", Value: TypeValuePair{Value: "6.2", Type: Decimal}}}
16+
parameters := convertNamedValuesToSparkParams(values[:])
17+
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
18+
assert.NotNil(t, parameters[1].Value.StringValue)
19+
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
20+
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
21+
assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
22+
assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
23+
})
24+
}
25+
func TestParameters_Names(t *testing.T) {
26+
t.Run("Should infer types correctly", func(t *testing.T) {
27+
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "2", Value: TypeValuePair{Type: Decimal, Value: "6.2"}}}
28+
parameters := convertNamedValuesToSparkParams(values[:])
29+
assert.Equal(t, string("1"), *parameters[0].Name)
30+
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
31+
assert.Equal(t, string("2"), *parameters[1].Name)
32+
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
33+
assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
34+
})
35+
}

parameters.go

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package dbsql
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"strconv"
7+
"time"
8+
9+
"github.com/databricks/databricks-sql-go/internal/cli_service"
10+
)
11+
12+
type DbSqlParam struct {
13+
Name string
14+
Type SqlType
15+
Value any
16+
}
17+
18+
type TypeValuePair struct {
19+
Type SqlType
20+
Value any
21+
}
22+
23+
type SqlType int64
24+
25+
const (
26+
Void SqlType = iota
27+
String
28+
Date
29+
Timestamp
30+
Float
31+
Decimal
32+
Double
33+
Integer
34+
Bigint
35+
Smallint
36+
Tinyint
37+
Boolean
38+
IntervalMonth
39+
IntervalDay
40+
)
41+
42+
func (s SqlType) String() string {
43+
switch s {
44+
case Void:
45+
return "VOID"
46+
case String:
47+
return "STRING"
48+
case Date:
49+
return "DATE"
50+
case Timestamp:
51+
return "TIMESTAMP"
52+
case Float:
53+
return "FLOAT"
54+
case Decimal:
55+
return "DECIMAL"
56+
case Double:
57+
return "DOUBLE"
58+
case Integer:
59+
return "INTEGER"
60+
case Bigint:
61+
return "BIGINT"
62+
case Smallint:
63+
return "SMALLINT"
64+
case Tinyint:
65+
return "TINYINT"
66+
case Boolean:
67+
return "BOOLEAN"
68+
case IntervalMonth:
69+
return "INTERVAL MONTH"
70+
case IntervalDay:
71+
return "INTERVAL DAY"
72+
}
73+
return "unknown"
74+
}
75+
76+
func valuesToDBSQLParams(namedValues []driver.NamedValue) []DbSqlParam {
77+
var params []DbSqlParam
78+
for i := range namedValues {
79+
namedValue := namedValues[i]
80+
param := *new(DbSqlParam)
81+
param.Name = namedValue.Name
82+
param.Value = namedValue.Value
83+
params = append(params, param)
84+
}
85+
return params
86+
}
87+
88+
func inferTypes(params []DbSqlParam) {
89+
for i := range params {
90+
param := &params[i]
91+
switch value := param.Value.(type) {
92+
case nil:
93+
param.Type = Void
94+
case bool:
95+
param.Value = strconv.FormatBool(value)
96+
param.Type = Boolean
97+
case string:
98+
param.Value = &value
99+
param.Type = String
100+
case int:
101+
param.Value = strconv.Itoa(value)
102+
param.Type = Integer
103+
case uint:
104+
param.Value = strconv.FormatUint(uint64(value), 10)
105+
param.Type = Integer
106+
case int8:
107+
param.Value = strconv.Itoa(int(value))
108+
param.Type = Integer
109+
case uint8:
110+
param.Value = strconv.FormatUint(uint64(value), 10)
111+
param.Type = Integer
112+
case int16:
113+
param.Value = strconv.Itoa(int(value))
114+
param.Type = Integer
115+
case uint16:
116+
param.Value = strconv.FormatUint(uint64(value), 10)
117+
param.Type = Integer
118+
case int32:
119+
param.Value = strconv.Itoa(int(value))
120+
param.Type = Integer
121+
case uint32:
122+
param.Value = strconv.FormatUint(uint64(value), 10)
123+
param.Type = Integer
124+
case int64:
125+
param.Value = strconv.Itoa(int(value))
126+
param.Type = Integer
127+
case uint64:
128+
param.Value = strconv.FormatUint(uint64(value), 10)
129+
param.Type = Integer
130+
case float32:
131+
param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
132+
param.Type = Float
133+
case time.Time:
134+
param.Value = value.String()
135+
param.Type = Timestamp
136+
case TypeValuePair:
137+
param.Value = value.Value
138+
param.Type = value.Type
139+
default:
140+
s := fmt.Sprintf("%s", value)
141+
param.Value = s
142+
param.Type = String
143+
}
144+
}
145+
}
146+
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
147+
var sparkParams []*cli_service.TSparkParameter
148+
149+
sqlParams := valuesToDBSQLParams(values)
150+
inferTypes(sqlParams)
151+
for i := range sqlParams {
152+
sqlParam := sqlParams[i]
153+
sparkParamValue := sqlParam.Value.(string)
154+
sparkParamType := sqlParam.Type.String()
155+
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue}}
156+
sparkParams = append(sparkParams, &sparkParam)
157+
}
158+
return sparkParams
159+
}

statement_test.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql/driver"
66
"testing"
7-
"time"
87

98
"github.com/apache/thrift/lib/go/thrift"
109
"github.com/databricks/databricks-sql-go/internal/cli_service"
@@ -166,13 +165,3 @@ func TestStmt_QueryContext(t *testing.T) {
166165
assert.Equal(t, testQuery, savedQueryString)
167166
})
168167
}
169-
func TestParameters(t *testing.T) {
170-
t.Run("Parameter casting should be correct", func(t *testing.T) {
171-
values := [3]driver.NamedValue{{Ordinal: 1, Name: "", Value: float32(5)}, {Ordinal: 2, Name: "", Value: time.Now()}, {Ordinal: 3, Name: "", Value: int64(5)}}
172-
parameters := namedValuesToTSparkParams(values[:])
173-
assert.Equal(t, &cli_service.TSparkParameterValue{DoubleValue: thrift.Float64Ptr(5)}, parameters[0].Value)
174-
assert.NotNil(t, parameters[1].Value.StringValue)
175-
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
176-
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
177-
})
178-
}

0 commit comments

Comments
 (0)