Skip to content

Commit 6330b1b

Browse files
Added a new error type that identifies as driver.ErrBadConn. This can be added in to an error chain to signal that a connection should no longer be used and to retry with a new connection.
Updated ThriftServiceClient error handling. When a Thrift request fails we are now checking the error message for the term 'Invalid SessionHandle'. If it is present we are adding a bad connection error to the error chain. Searching in the error message is clumsy and fragile but there doesn't currently appear to be another way to get the information. Updated the WithServerHostname function to handle host names prefixed by 'http:' or 'https:' to allow users to specify which protocol to use. This is in response to github issue #140. Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
1 parent 45520d9 commit 6330b1b

7 files changed

Lines changed: 255 additions & 95 deletions

File tree

connector.go

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,37 @@ func withUserConfig(ucfg config.UserConfig) connOption {
109109
// WithServerHostname sets up the server hostname. Mandatory.
110110
func WithServerHostname(host string) connOption {
111111
return func(c *config.Config) {
112-
if host == "localhost" {
113-
c.Protocol = "http"
112+
protocol, hostname := parseHostName(host)
113+
if protocol != "" {
114+
c.Protocol = protocol
114115
}
115-
c.Host = host
116+
117+
c.Host = hostname
116118
}
117119
}
118120

121+
func parseHostName(host string) (protocol, hostname string) {
122+
hostname = host
123+
if strings.HasPrefix(host, "https") {
124+
hostname = strings.TrimPrefix(host, "https")
125+
protocol = "https"
126+
} else if strings.HasPrefix(host, "http") {
127+
hostname = strings.TrimPrefix(host, "http")
128+
protocol = "http"
129+
}
130+
131+
if protocol != "" {
132+
hostname = strings.TrimPrefix(hostname, ":")
133+
hostname = strings.TrimPrefix(hostname, "//")
134+
}
135+
136+
if hostname == "localhost" && protocol == "" {
137+
protocol = "http"
138+
}
139+
140+
return
141+
}
142+
119143
// WithPort sets up the server port. Mandatory.
120144
func WithPort(port int) connOption {
121145
return func(c *config.Config) {

connector_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,37 @@ func TestNewConnector(t *testing.T) {
130130
assert.Nil(t, err)
131131
assert.Equal(t, expectedCfg, coni.cfg)
132132
})
133+
134+
t.Run("Connector test WithServerHostname", func(t *testing.T) {
135+
cases := []struct {
136+
hostname, host, protocol string
137+
}{
138+
{"databricks-host", "databricks-host", "https"},
139+
{"http://databricks-host", "databricks-host", "http"},
140+
{"https://databricks-host", "databricks-host", "https"},
141+
{"http:databricks-host", "databricks-host", "http"},
142+
{"https:databricks-host", "databricks-host", "https"},
143+
{"htt://databricks-host", "htt://databricks-host", "https"},
144+
{"localhost", "localhost", "http"},
145+
{"http:localhost", "localhost", "http"},
146+
{"https:localhost", "localhost", "https"},
147+
}
148+
149+
for i := range cases {
150+
c := cases[i]
151+
con, err := NewConnector(
152+
WithServerHostname(c.hostname),
153+
)
154+
assert.Nil(t, err)
155+
156+
coni, ok := con.(*connector)
157+
require.True(t, ok)
158+
userConfig := coni.cfg.UserConfig
159+
require.Equal(t, c.protocol, userConfig.Protocol)
160+
require.Equal(t, c.host, userConfig.Host)
161+
}
162+
163+
})
133164
}
134165

135166
type mockRoundTripper struct{}

doc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ Use sql.OpenDB() to create a database handle via a new connector object created
7070
7171
Supported functional options include:
7272
73-
- WithServerHostname(<hostname> string): Sets up the server hostname. Mandatory
73+
- WithServerHostname(<hostname> string): Sets up the server hostname. The hostname can be prefixed with "http:" or "https:" to specify a protocol to use. Mandatory
7474
- WithPort(<port> int): Sets up the server port. Mandatory
7575
- WithAccessToken(<my_token> string): Sets up the Personal Access Token. Mandatory
7676
- WithHTTPPath(<http_path> string): Sets up the endpoint to the warehouse. Mandatory

internal/client/client.go

Lines changed: 111 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"os"
1515
"regexp"
1616
"strconv"
17+
"strings"
1718
"time"
1819

1920
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
@@ -45,116 +46,121 @@ const (
4546

4647
type clientMethod int
4748

48-
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod
49+
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod -trimprefix=clientMethod
4950

5051
const (
51-
unknown clientMethod = iota
52-
openSession
53-
closeSession
54-
fetchResults
55-
getResultSetMetadata
56-
executeStatement
57-
getOperationStatus
58-
closeOperation
59-
cancelOperation
52+
clientMethodUnknown clientMethod = iota
53+
clientMethodOpenSession
54+
clientMethodCloseSession
55+
clientMethodFetchResults
56+
clientMethodGetResultSetMetadata
57+
clientMethodExecuteStatement
58+
clientMethodGetOperationStatus
59+
clientMethodCloseOperation
60+
clientMethodCancelOperation
6061
)
6162

6263
var nonRetryableClientMethods map[clientMethod]any = map[clientMethod]any{
63-
executeStatement: struct{}{},
64-
unknown: struct{}{}}
64+
clientMethodExecuteStatement: struct{}{},
65+
clientMethodUnknown: struct{}{},
66+
}
67+
68+
var clientMethodRequestErrorMsgs map[clientMethod]string = map[clientMethod]string{
69+
clientMethodOpenSession: "open session request error",
70+
clientMethodCloseSession: "close session request error",
71+
clientMethodFetchResults: "fetch results request error",
72+
clientMethodGetResultSetMetadata: "get result set metadata request error",
73+
clientMethodExecuteStatement: "execute statement request error",
74+
clientMethodGetOperationStatus: "get operation status request error",
75+
clientMethodCloseOperation: "close operation request error",
76+
clientMethodCancelOperation: "cancel operation request error",
77+
}
6578

6679
// OpenSession is a wrapper around the thrift operation OpenSession
6780
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
6881
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
69-
ctx = context.WithValue(ctx, ClientMethod, openSession)
82+
ctx = context.WithValue(ctx, ClientMethod, clientMethodOpenSession)
7083
msg, start := logger.Track("OpenSession")
7184
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req)
7285
if err != nil {
73-
return nil, dbsqlerrint.NewRequestError(ctx, "open session request error", err)
86+
err = handleClientMethodError(ctx, err)
87+
return resp, err
7488
}
89+
90+
recordResult(ctx, resp)
91+
7592
log := logger.WithContext(SprintGuid(resp.SessionHandle.SessionId.GUID), driverctx.CorrelationIdFromContext(ctx), "")
7693
defer log.Duration(msg, start)
77-
if RecordResults {
78-
j, _ := json.MarshalIndent(resp, "", " ")
79-
_ = os.WriteFile(fmt.Sprintf("OpenSession%d.json", resultIndex), j, 0600)
80-
resultIndex++
81-
}
94+
8295
return resp, CheckStatus(resp)
8396
}
8497

8598
// CloseSession is a wrapper around the thrift operation CloseSession
8699
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
87100
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) {
88-
ctx = context.WithValue(ctx, ClientMethod, closeSession)
101+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseSession)
89102
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "")
90103
defer log.Duration(logger.Track("CloseSession"))
91104
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req)
92105
if err != nil {
93-
return resp, dbsqlerrint.NewRequestError(ctx, "close session request error", err)
94-
}
95-
if RecordResults {
96-
j, _ := json.MarshalIndent(resp, "", " ")
97-
_ = os.WriteFile(fmt.Sprintf("CloseSession%d.json", resultIndex), j, 0600)
98-
resultIndex++
106+
err = handleClientMethodError(ctx, err)
107+
return resp, err
99108
}
109+
110+
recordResult(ctx, resp)
111+
100112
return resp, CheckStatus(resp)
101113
}
102114

103115
// FetchResults is a wrapper around the thrift operation FetchResults
104116
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
105117
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
106-
ctx = context.WithValue(ctx, ClientMethod, fetchResults)
118+
ctx = context.WithValue(ctx, ClientMethod, clientMethodFetchResults)
107119
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
108120
defer log.Duration(logger.Track("FetchResults"))
109121
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req)
110122
if err != nil {
111-
return resp, dbsqlerrint.NewRequestError(ctx, "fetch results request error", err)
112-
}
113-
if RecordResults {
114-
j, _ := json.MarshalIndent(resp, "", " ")
115-
_ = os.WriteFile(fmt.Sprintf("FetchResults%d.json", resultIndex), j, 0600)
116-
resultIndex++
123+
err = handleClientMethodError(ctx, err)
124+
return resp, err
117125
}
126+
127+
recordResult(ctx, resp)
128+
118129
return resp, CheckStatus(resp)
119130
}
120131

121132
// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
122133
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
123134
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
124-
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata)
135+
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetResultSetMetadata)
125136
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
126137
defer log.Duration(logger.Track("GetResultSetMetadata"))
127138
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req)
128139
if err != nil {
129-
return resp, dbsqlerrint.NewRequestError(ctx, "get result set metadata request error", err)
130-
}
131-
if RecordResults {
132-
j, _ := json.MarshalIndent(resp, "", " ")
133-
_ = os.WriteFile(fmt.Sprintf("GetResultSetMetadata%d.json", resultIndex), j, 0600)
134-
resultIndex++
140+
err = handleClientMethodError(ctx, err)
141+
return resp, err
135142
}
143+
144+
recordResult(ctx, resp)
145+
136146
return resp, CheckStatus(resp)
137147
}
138148

139149
// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
140150
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
141151
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
142-
ctx = context.WithValue(ctx, ClientMethod, executeStatement)
152+
ctx = context.WithValue(ctx, ClientMethod, clientMethodExecuteStatement)
143153
msg, start := logger.Track("ExecuteStatement")
144154

145155
// We use context.Background to fix a problem where on context done the query would not be cancelled.
146156
resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req)
147157
if err != nil {
148-
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err)
149-
}
150-
if RecordResults {
151-
j, _ := json.MarshalIndent(resp, "", " ")
152-
_ = os.WriteFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex), j, 0600)
153-
// f, _ := os.ReadFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex))
154-
// var resp2 cli_service.TExecuteStatementResp
155-
// json.Unmarshal(f, &resp2)
156-
resultIndex++
158+
err = handleClientMethodError(ctx, err)
159+
return resp, err
157160
}
161+
162+
recordResult(ctx, resp)
163+
158164
if resp != nil && resp.OperationHandle != nil {
159165
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(resp.OperationHandle.OperationId.GUID))
160166
defer log.Duration(msg, start)
@@ -165,54 +171,51 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
165171
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
166172
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
167173
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) {
168-
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus)
174+
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetOperationStatus)
169175
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
170176
defer log.Duration(logger.Track("GetOperationStatus"))
171177
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req)
172178
if err != nil {
173-
return resp, dbsqlerrint.NewRequestError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), "databricks: get operation status request error", err)
174-
}
175-
if RecordResults {
176-
j, _ := json.MarshalIndent(resp, "", " ")
177-
_ = os.WriteFile(fmt.Sprintf("GetOperationStatus%d.json", resultIndex), j, 0600)
178-
resultIndex++
179+
err = handleClientMethodError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), err)
180+
return resp, err
179181
}
182+
183+
recordResult(ctx, resp)
184+
180185
return resp, CheckStatus(resp)
181186
}
182187

183188
// CloseOperation is a wrapper around the thrift operation CloseOperation
184189
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
185190
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
186-
ctx = context.WithValue(ctx, ClientMethod, closeOperation)
191+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseOperation)
187192
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
188193
defer log.Duration(logger.Track("CloseOperation"))
189194
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req)
190195
if err != nil {
191-
return resp, dbsqlerrint.NewRequestError(ctx, "close operation request error", err)
192-
}
193-
if RecordResults {
194-
j, _ := json.MarshalIndent(resp, "", " ")
195-
_ = os.WriteFile(fmt.Sprintf("CloseOperation%d.json", resultIndex), j, 0600)
196-
resultIndex++
196+
err = handleClientMethodError(ctx, err)
197+
return resp, err
197198
}
199+
200+
recordResult(ctx, resp)
201+
198202
return resp, CheckStatus(resp)
199203
}
200204

201205
// CancelOperation is a wrapper around the thrift operation CancelOperation
202206
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
203207
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) {
204-
ctx = context.WithValue(ctx, ClientMethod, cancelOperation)
208+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCancelOperation)
205209
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
206210
defer log.Duration(logger.Track("CancelOperation"))
207211
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req)
208212
if err != nil {
209-
return resp, dbsqlerrint.NewRequestError(ctx, "cancel operation request error", err)
210-
}
211-
if RecordResults {
212-
j, _ := json.MarshalIndent(resp, "", " ")
213-
_ = os.WriteFile(fmt.Sprintf("CancelOperation%d.json", resultIndex), j, 0600)
214-
resultIndex++
213+
err = handleClientMethodError(ctx, err)
214+
return resp, err
215215
}
216+
217+
recordResult(ctx, resp)
218+
216219
return resp, CheckStatus(resp)
217220
}
218221

@@ -283,6 +286,42 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi
283286
return tsClient, nil
284287
}
285288

289+
// handler function for errors returned by the thrift client methods
290+
func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestError {
291+
if err == nil {
292+
return nil
293+
}
294+
295+
// If the passed error indicates an invalid session we inject a bad connection error
296+
// into the error stack. This allows the for retrying with a new connection.
297+
s := err.Error()
298+
if strings.Contains(s, "Invalid SessionHandle") {
299+
err = dbsqlerrint.NewBadConnectionError(err)
300+
}
301+
302+
// the passed error will be wrapped in a DBRequestError
303+
method := getClientMethod(ctx)
304+
msg := clientMethodRequestErrorMsgs[method]
305+
306+
return dbsqlerrint.NewRequestError(ctx, msg, err)
307+
}
308+
309+
// Extract a clientMethod value from the given Context.
310+
func getClientMethod(ctx context.Context) clientMethod {
311+
v, _ := ctx.Value(ClientMethod).(clientMethod)
312+
return v
313+
}
314+
315+
// Write the result
316+
func recordResult(ctx context.Context, resp any) {
317+
if RecordResults && resp != nil {
318+
method := getClientMethod(ctx)
319+
j, _ := json.MarshalIndent(resp, "", " ")
320+
_ = os.WriteFile(fmt.Sprintf("%s%d.json", method, resultIndex), j, 0600)
321+
resultIndex++
322+
}
323+
}
324+
286325
// ThriftResponse represents the thrift rpc response
287326
type ThriftResponse interface {
288327
GetStatus() *cli_service.TStatus
@@ -507,7 +546,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
507546
return false, ctx.Err()
508547
}
509548

510-
caller, _ := ctx.Value(ClientMethod).(clientMethod)
549+
caller := getClientMethod(ctx)
511550
_, nonRetryableClientMethod := nonRetryableClientMethods[caller]
512551

513552
if err != nil {

0 commit comments

Comments
 (0)