Skip to content

Commit 7a177c9

Browse files
Better handling of bad connection errors and specifying server protocol. (#152)
Added a new error type that identifies as driver.ErrBadConn. This can be added in to an error chain to signal that a connection should no longer be used and to retry with a new connection. Updated ThriftServiceClient error handling. When a Thrift request fails we are now checking the error message for the term 'Invalid SessionHandle'. If it is present we are adding a bad connection error to the error chain. Searching in the error message is clumsy and fragile but there doesn't currently appear to be another way to get the information. Updated the WithServerHostname function to handle host names prefixed by 'http:' or 'https:' to allow users to specify which protocol to use. This is in response to github issue #140.
2 parents 2c2cb73 + 6330b1b commit 7a177c9

4 files changed

Lines changed: 196 additions & 91 deletions

File tree

internal/client/client.go

Lines changed: 111 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"os"
1515
"regexp"
1616
"strconv"
17+
"strings"
1718
"time"
1819

1920
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
@@ -45,116 +46,121 @@ const (
4546

4647
type clientMethod int
4748

48-
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod
49+
//go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod -trimprefix=clientMethod
4950

5051
const (
51-
unknown clientMethod = iota
52-
openSession
53-
closeSession
54-
fetchResults
55-
getResultSetMetadata
56-
executeStatement
57-
getOperationStatus
58-
closeOperation
59-
cancelOperation
52+
clientMethodUnknown clientMethod = iota
53+
clientMethodOpenSession
54+
clientMethodCloseSession
55+
clientMethodFetchResults
56+
clientMethodGetResultSetMetadata
57+
clientMethodExecuteStatement
58+
clientMethodGetOperationStatus
59+
clientMethodCloseOperation
60+
clientMethodCancelOperation
6061
)
6162

6263
var nonRetryableClientMethods map[clientMethod]any = map[clientMethod]any{
63-
executeStatement: struct{}{},
64-
unknown: struct{}{}}
64+
clientMethodExecuteStatement: struct{}{},
65+
clientMethodUnknown: struct{}{},
66+
}
67+
68+
var clientMethodRequestErrorMsgs map[clientMethod]string = map[clientMethod]string{
69+
clientMethodOpenSession: "open session request error",
70+
clientMethodCloseSession: "close session request error",
71+
clientMethodFetchResults: "fetch results request error",
72+
clientMethodGetResultSetMetadata: "get result set metadata request error",
73+
clientMethodExecuteStatement: "execute statement request error",
74+
clientMethodGetOperationStatus: "get operation status request error",
75+
clientMethodCloseOperation: "close operation request error",
76+
clientMethodCancelOperation: "cancel operation request error",
77+
}
6578

6679
// OpenSession is a wrapper around the thrift operation OpenSession
6780
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
6881
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
69-
ctx = context.WithValue(ctx, ClientMethod, openSession)
82+
ctx = context.WithValue(ctx, ClientMethod, clientMethodOpenSession)
7083
msg, start := logger.Track("OpenSession")
7184
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req)
7285
if err != nil {
73-
return nil, dbsqlerrint.NewRequestError(ctx, "open session request error", err)
86+
err = handleClientMethodError(ctx, err)
87+
return resp, err
7488
}
89+
90+
recordResult(ctx, resp)
91+
7592
log := logger.WithContext(SprintGuid(resp.SessionHandle.SessionId.GUID), driverctx.CorrelationIdFromContext(ctx), "")
7693
defer log.Duration(msg, start)
77-
if RecordResults {
78-
j, _ := json.MarshalIndent(resp, "", " ")
79-
_ = os.WriteFile(fmt.Sprintf("OpenSession%d.json", resultIndex), j, 0600)
80-
resultIndex++
81-
}
94+
8295
return resp, CheckStatus(resp)
8396
}
8497

8598
// CloseSession is a wrapper around the thrift operation CloseSession
8699
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
87100
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) {
88-
ctx = context.WithValue(ctx, ClientMethod, closeSession)
101+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseSession)
89102
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "")
90103
defer log.Duration(logger.Track("CloseSession"))
91104
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req)
92105
if err != nil {
93-
return resp, dbsqlerrint.NewRequestError(ctx, "close session request error", err)
94-
}
95-
if RecordResults {
96-
j, _ := json.MarshalIndent(resp, "", " ")
97-
_ = os.WriteFile(fmt.Sprintf("CloseSession%d.json", resultIndex), j, 0600)
98-
resultIndex++
106+
err = handleClientMethodError(ctx, err)
107+
return resp, err
99108
}
109+
110+
recordResult(ctx, resp)
111+
100112
return resp, CheckStatus(resp)
101113
}
102114

103115
// FetchResults is a wrapper around the thrift operation FetchResults
104116
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
105117
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
106-
ctx = context.WithValue(ctx, ClientMethod, fetchResults)
118+
ctx = context.WithValue(ctx, ClientMethod, clientMethodFetchResults)
107119
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
108120
defer log.Duration(logger.Track("FetchResults"))
109121
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req)
110122
if err != nil {
111-
return resp, dbsqlerrint.NewRequestError(ctx, "fetch results request error", err)
112-
}
113-
if RecordResults {
114-
j, _ := json.MarshalIndent(resp, "", " ")
115-
_ = os.WriteFile(fmt.Sprintf("FetchResults%d.json", resultIndex), j, 0600)
116-
resultIndex++
123+
err = handleClientMethodError(ctx, err)
124+
return resp, err
117125
}
126+
127+
recordResult(ctx, resp)
128+
118129
return resp, CheckStatus(resp)
119130
}
120131

121132
// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
122133
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
123134
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
124-
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata)
135+
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetResultSetMetadata)
125136
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
126137
defer log.Duration(logger.Track("GetResultSetMetadata"))
127138
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req)
128139
if err != nil {
129-
return resp, dbsqlerrint.NewRequestError(ctx, "get result set metadata request error", err)
130-
}
131-
if RecordResults {
132-
j, _ := json.MarshalIndent(resp, "", " ")
133-
_ = os.WriteFile(fmt.Sprintf("GetResultSetMetadata%d.json", resultIndex), j, 0600)
134-
resultIndex++
140+
err = handleClientMethodError(ctx, err)
141+
return resp, err
135142
}
143+
144+
recordResult(ctx, resp)
145+
136146
return resp, CheckStatus(resp)
137147
}
138148

139149
// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
140150
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
141151
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
142-
ctx = context.WithValue(ctx, ClientMethod, executeStatement)
152+
ctx = context.WithValue(ctx, ClientMethod, clientMethodExecuteStatement)
143153
msg, start := logger.Track("ExecuteStatement")
144154

145155
// We use context.Background to fix a problem where on context done the query would not be cancelled.
146156
resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req)
147157
if err != nil {
148-
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err)
149-
}
150-
if RecordResults {
151-
j, _ := json.MarshalIndent(resp, "", " ")
152-
_ = os.WriteFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex), j, 0600)
153-
// f, _ := os.ReadFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex))
154-
// var resp2 cli_service.TExecuteStatementResp
155-
// json.Unmarshal(f, &resp2)
156-
resultIndex++
158+
err = handleClientMethodError(ctx, err)
159+
return resp, err
157160
}
161+
162+
recordResult(ctx, resp)
163+
158164
if resp != nil && resp.OperationHandle != nil {
159165
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(resp.OperationHandle.OperationId.GUID))
160166
defer log.Duration(msg, start)
@@ -165,54 +171,51 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
165171
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
166172
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
167173
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) {
168-
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus)
174+
ctx = context.WithValue(ctx, ClientMethod, clientMethodGetOperationStatus)
169175
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
170176
defer log.Duration(logger.Track("GetOperationStatus"))
171177
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req)
172178
if err != nil {
173-
return resp, dbsqlerrint.NewRequestError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), "databricks: get operation status request error", err)
174-
}
175-
if RecordResults {
176-
j, _ := json.MarshalIndent(resp, "", " ")
177-
_ = os.WriteFile(fmt.Sprintf("GetOperationStatus%d.json", resultIndex), j, 0600)
178-
resultIndex++
179+
err = handleClientMethodError(driverctx.NewContextWithQueryId(ctx, SprintGuid(req.OperationHandle.OperationId.GUID)), err)
180+
return resp, err
179181
}
182+
183+
recordResult(ctx, resp)
184+
180185
return resp, CheckStatus(resp)
181186
}
182187

183188
// CloseOperation is a wrapper around the thrift operation CloseOperation
184189
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
185190
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
186-
ctx = context.WithValue(ctx, ClientMethod, closeOperation)
191+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCloseOperation)
187192
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
188193
defer log.Duration(logger.Track("CloseOperation"))
189194
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req)
190195
if err != nil {
191-
return resp, dbsqlerrint.NewRequestError(ctx, "close operation request error", err)
192-
}
193-
if RecordResults {
194-
j, _ := json.MarshalIndent(resp, "", " ")
195-
_ = os.WriteFile(fmt.Sprintf("CloseOperation%d.json", resultIndex), j, 0600)
196-
resultIndex++
196+
err = handleClientMethodError(ctx, err)
197+
return resp, err
197198
}
199+
200+
recordResult(ctx, resp)
201+
198202
return resp, CheckStatus(resp)
199203
}
200204

201205
// CancelOperation is a wrapper around the thrift operation CancelOperation
202206
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
203207
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) {
204-
ctx = context.WithValue(ctx, ClientMethod, cancelOperation)
208+
ctx = context.WithValue(ctx, ClientMethod, clientMethodCancelOperation)
205209
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
206210
defer log.Duration(logger.Track("CancelOperation"))
207211
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req)
208212
if err != nil {
209-
return resp, dbsqlerrint.NewRequestError(ctx, "cancel operation request error", err)
210-
}
211-
if RecordResults {
212-
j, _ := json.MarshalIndent(resp, "", " ")
213-
_ = os.WriteFile(fmt.Sprintf("CancelOperation%d.json", resultIndex), j, 0600)
214-
resultIndex++
213+
err = handleClientMethodError(ctx, err)
214+
return resp, err
215215
}
216+
217+
recordResult(ctx, resp)
218+
216219
return resp, CheckStatus(resp)
217220
}
218221

@@ -283,6 +286,42 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi
283286
return tsClient, nil
284287
}
285288

289+
// handler function for errors returned by the thrift client methods
290+
func handleClientMethodError(ctx context.Context, err error) dbsqlerr.DBRequestError {
291+
if err == nil {
292+
return nil
293+
}
294+
295+
// If the passed error indicates an invalid session we inject a bad connection error
296+
// into the error stack. This allows the for retrying with a new connection.
297+
s := err.Error()
298+
if strings.Contains(s, "Invalid SessionHandle") {
299+
err = dbsqlerrint.NewBadConnectionError(err)
300+
}
301+
302+
// the passed error will be wrapped in a DBRequestError
303+
method := getClientMethod(ctx)
304+
msg := clientMethodRequestErrorMsgs[method]
305+
306+
return dbsqlerrint.NewRequestError(ctx, msg, err)
307+
}
308+
309+
// Extract a clientMethod value from the given Context.
310+
func getClientMethod(ctx context.Context) clientMethod {
311+
v, _ := ctx.Value(ClientMethod).(clientMethod)
312+
return v
313+
}
314+
315+
// Write the result
316+
func recordResult(ctx context.Context, resp any) {
317+
if RecordResults && resp != nil {
318+
method := getClientMethod(ctx)
319+
j, _ := json.MarshalIndent(resp, "", " ")
320+
_ = os.WriteFile(fmt.Sprintf("%s%d.json", method, resultIndex), j, 0600)
321+
resultIndex++
322+
}
323+
}
324+
286325
// ThriftResponse represents the thrift rpc response
287326
type ThriftResponse interface {
288327
GetStatus() *cli_service.TStatus
@@ -507,7 +546,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
507546
return false, ctx.Err()
508547
}
509548

510-
caller, _ := ctx.Value(ClientMethod).(clientMethod)
549+
caller := getClientMethod(ctx)
511550
_, nonRetryableClientMethod := nonRetryableClientMethods[caller]
512551

513552
if err != nil {

internal/client/client_test.go

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ package client
33
import (
44
"context"
55
"crypto/x509"
6+
"database/sql/driver"
67
"net/http"
78
"net/url"
9+
"strings"
810
"testing"
911
"time"
1012

13+
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1114
"github.com/pkg/errors"
1215
"github.com/stretchr/testify/require"
1316
)
@@ -126,16 +129,16 @@ func TestRetryPolicy(t *testing.T) {
126129
nonRetryableCodes := []int{200, 300, 400, 501}
127130

128131
retryableOps := []clientMethod{
129-
closeSession,
130-
getResultSetMetadata,
131-
getOperationStatus,
132-
closeOperation,
133-
cancelOperation,
134-
fetchResults,
135-
openSession,
132+
clientMethodCloseSession,
133+
clientMethodGetResultSetMetadata,
134+
clientMethodGetOperationStatus,
135+
clientMethodCloseOperation,
136+
clientMethodCancelOperation,
137+
clientMethodFetchResults,
138+
clientMethodOpenSession,
136139
}
137140

138-
nonRetryableOps := []clientMethod{executeStatement}
141+
nonRetryableOps := []clientMethod{clientMethodExecuteStatement}
139142

140143
cancelled, cancel := context.WithCancel(context.Background())
141144
cancel()
@@ -214,4 +217,41 @@ func TestRetryPolicy(t *testing.T) {
214217

215218
})
216219

220+
t.Run("test handling client method errors", func(t *testing.T) {
221+
cases := []struct {
222+
base string
223+
method clientMethod
224+
isBadConn bool
225+
}{
226+
{"generic error", clientMethodUnknown, false},
227+
{"error with Invalid SessionHandle and stuff", clientMethodUnknown, true},
228+
{"generic error", clientMethodOpenSession, false},
229+
{"error with Invalid SessionHandle and stuff", clientMethodOpenSession, true},
230+
{"generic error", clientMethodCloseSession, false},
231+
{"error with Invalid SessionHandle and stuff", clientMethodCloseSession, true},
232+
{"generic error", clientMethodFetchResults, false},
233+
{"error with Invalid SessionHandle and stuff", clientMethodFetchResults, true},
234+
{"generic error", clientMethodGetResultSetMetadata, false},
235+
{"error with Invalid SessionHandle and stuff", clientMethodGetResultSetMetadata, true},
236+
{"generic error", clientMethodExecuteStatement, false},
237+
{"error with Invalid SessionHandle and stuff", clientMethodExecuteStatement, true},
238+
{"generic error", clientMethodGetOperationStatus, false},
239+
{"error with Invalid SessionHandle and stuff", clientMethodGetOperationStatus, true},
240+
{"generic error", clientMethodCloseOperation, false},
241+
{"error with Invalid SessionHandle and stuff", clientMethodCloseOperation, true},
242+
{"generic error", clientMethodCancelOperation, false},
243+
{"error with Invalid SessionHandle and stuff", clientMethodCancelOperation, true},
244+
}
245+
246+
for i := range cases {
247+
c := cases[i]
248+
err := handleClientMethodError(context.WithValue(context.Background(), ClientMethod, c.method), errors.New(c.base))
249+
msg := clientMethodRequestErrorMsgs[c.method]
250+
require.True(t, strings.Contains(err.Error(), msg))
251+
require.True(t, strings.Contains(err.Error(), c.base))
252+
require.True(t, errors.Is(err, dbsqlerr.DatabricksError))
253+
require.True(t, errors.Is(err, dbsqlerr.RequestError))
254+
require.Equal(t, c.isBadConn, errors.Is(err, driver.ErrBadConn))
255+
}
256+
})
217257
}

0 commit comments

Comments
 (0)