@@ -14,6 +14,7 @@ import (
1414 "os"
1515 "regexp"
1616 "strconv"
17+ "strings"
1718 "time"
1819
1920 dbsqlerr "github.com/databricks/databricks-sql-go/errors"
@@ -45,116 +46,121 @@ const (
4546
4647type clientMethod int
4748
48- //go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod
49+ //go:generate go run golang.org/x/tools/cmd/stringer -type=clientMethod -trimprefix=clientMethod
4950
5051const (
51- unknown clientMethod = iota
52- openSession
53- closeSession
54- fetchResults
55- getResultSetMetadata
56- executeStatement
57- getOperationStatus
58- closeOperation
59- cancelOperation
52+ clientMethodUnknown clientMethod = iota
53+ clientMethodOpenSession
54+ clientMethodCloseSession
55+ clientMethodFetchResults
56+ clientMethodGetResultSetMetadata
57+ clientMethodExecuteStatement
58+ clientMethodGetOperationStatus
59+ clientMethodCloseOperation
60+ clientMethodCancelOperation
6061)
6162
6263var nonRetryableClientMethods map [clientMethod ]any = map [clientMethod ]any {
63- executeStatement : struct {}{},
64- unknown : struct {}{}}
64+ clientMethodExecuteStatement : struct {}{},
65+ clientMethodUnknown : struct {}{},
66+ }
67+
68+ var clientMethodRequestErrorMsgs map [clientMethod ]string = map [clientMethod ]string {
69+ clientMethodOpenSession : "open session request error" ,
70+ clientMethodCloseSession : "close session request error" ,
71+ clientMethodFetchResults : "fetch results request error" ,
72+ clientMethodGetResultSetMetadata : "get result set metadata request error" ,
73+ clientMethodExecuteStatement : "execute statement request error" ,
74+ clientMethodGetOperationStatus : "get operation status request error" ,
75+ clientMethodCloseOperation : "close operation request error" ,
76+ clientMethodCancelOperation : "cancel operation request error" ,
77+ }
6578
6679// OpenSession is a wrapper around the thrift operation OpenSession
6780// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
6881func (tsc * ThriftServiceClient ) OpenSession (ctx context.Context , req * cli_service.TOpenSessionReq ) (* cli_service.TOpenSessionResp , error ) {
69- ctx = context .WithValue (ctx , ClientMethod , openSession )
82+ ctx = context .WithValue (ctx , ClientMethod , clientMethodOpenSession )
7083 msg , start := logger .Track ("OpenSession" )
7184 resp , err := tsc .TCLIServiceClient .OpenSession (ctx , req )
7285 if err != nil {
73- return nil , dbsqlerrint .NewRequestError (ctx , "open session request error" , err )
86+ err = handleClientMethodError (ctx , err )
87+ return resp , err
7488 }
89+
90+ recordResult (ctx , resp )
91+
7592 log := logger .WithContext (SprintGuid (resp .SessionHandle .SessionId .GUID ), driverctx .CorrelationIdFromContext (ctx ), "" )
7693 defer log .Duration (msg , start )
77- if RecordResults {
78- j , _ := json .MarshalIndent (resp , "" , " " )
79- _ = os .WriteFile (fmt .Sprintf ("OpenSession%d.json" , resultIndex ), j , 0600 )
80- resultIndex ++
81- }
94+
8295 return resp , CheckStatus (resp )
8396}
8497
8598// CloseSession is a wrapper around the thrift operation CloseSession
8699// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
87100func (tsc * ThriftServiceClient ) CloseSession (ctx context.Context , req * cli_service.TCloseSessionReq ) (* cli_service.TCloseSessionResp , error ) {
88- ctx = context .WithValue (ctx , ClientMethod , closeSession )
101+ ctx = context .WithValue (ctx , ClientMethod , clientMethodCloseSession )
89102 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), "" )
90103 defer log .Duration (logger .Track ("CloseSession" ))
91104 resp , err := tsc .TCLIServiceClient .CloseSession (ctx , req )
92105 if err != nil {
93- return resp , dbsqlerrint .NewRequestError (ctx , "close session request error" , err )
94- }
95- if RecordResults {
96- j , _ := json .MarshalIndent (resp , "" , " " )
97- _ = os .WriteFile (fmt .Sprintf ("CloseSession%d.json" , resultIndex ), j , 0600 )
98- resultIndex ++
106+ err = handleClientMethodError (ctx , err )
107+ return resp , err
99108 }
109+
110+ recordResult (ctx , resp )
111+
100112 return resp , CheckStatus (resp )
101113}
102114
103115// FetchResults is a wrapper around the thrift operation FetchResults
104116// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
105117func (tsc * ThriftServiceClient ) FetchResults (ctx context.Context , req * cli_service.TFetchResultsReq ) (* cli_service.TFetchResultsResp , error ) {
106- ctx = context .WithValue (ctx , ClientMethod , fetchResults )
118+ ctx = context .WithValue (ctx , ClientMethod , clientMethodFetchResults )
107119 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
108120 defer log .Duration (logger .Track ("FetchResults" ))
109121 resp , err := tsc .TCLIServiceClient .FetchResults (ctx , req )
110122 if err != nil {
111- return resp , dbsqlerrint .NewRequestError (ctx , "fetch results request error" , err )
112- }
113- if RecordResults {
114- j , _ := json .MarshalIndent (resp , "" , " " )
115- _ = os .WriteFile (fmt .Sprintf ("FetchResults%d.json" , resultIndex ), j , 0600 )
116- resultIndex ++
123+ err = handleClientMethodError (ctx , err )
124+ return resp , err
117125 }
126+
127+ recordResult (ctx , resp )
128+
118129 return resp , CheckStatus (resp )
119130}
120131
121132// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
122133// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
123134func (tsc * ThriftServiceClient ) GetResultSetMetadata (ctx context.Context , req * cli_service.TGetResultSetMetadataReq ) (* cli_service.TGetResultSetMetadataResp , error ) {
124- ctx = context .WithValue (ctx , ClientMethod , getResultSetMetadata )
135+ ctx = context .WithValue (ctx , ClientMethod , clientMethodGetResultSetMetadata )
125136 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
126137 defer log .Duration (logger .Track ("GetResultSetMetadata" ))
127138 resp , err := tsc .TCLIServiceClient .GetResultSetMetadata (ctx , req )
128139 if err != nil {
129- return resp , dbsqlerrint .NewRequestError (ctx , "get result set metadata request error" , err )
130- }
131- if RecordResults {
132- j , _ := json .MarshalIndent (resp , "" , " " )
133- _ = os .WriteFile (fmt .Sprintf ("GetResultSetMetadata%d.json" , resultIndex ), j , 0600 )
134- resultIndex ++
140+ err = handleClientMethodError (ctx , err )
141+ return resp , err
135142 }
143+
144+ recordResult (ctx , resp )
145+
136146 return resp , CheckStatus (resp )
137147}
138148
139149// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
140150// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
141151func (tsc * ThriftServiceClient ) ExecuteStatement (ctx context.Context , req * cli_service.TExecuteStatementReq ) (* cli_service.TExecuteStatementResp , error ) {
142- ctx = context .WithValue (ctx , ClientMethod , executeStatement )
152+ ctx = context .WithValue (ctx , ClientMethod , clientMethodExecuteStatement )
143153 msg , start := logger .Track ("ExecuteStatement" )
144154
145155 // We use context.Background to fix a problem where on context done the query would not be cancelled.
146156 resp , err := tsc .TCLIServiceClient .ExecuteStatement (context .Background (), req )
147157 if err != nil {
148- return resp , dbsqlerrint .NewRequestError (ctx , "execute statement request error" , err )
149- }
150- if RecordResults {
151- j , _ := json .MarshalIndent (resp , "" , " " )
152- _ = os .WriteFile (fmt .Sprintf ("ExecuteStatement%d.json" , resultIndex ), j , 0600 )
153- // f, _ := os.ReadFile(fmt.Sprintf("ExecuteStatement%d.json", resultIndex))
154- // var resp2 cli_service.TExecuteStatementResp
155- // json.Unmarshal(f, &resp2)
156- resultIndex ++
158+ err = handleClientMethodError (ctx , err )
159+ return resp , err
157160 }
161+
162+ recordResult (ctx , resp )
163+
158164 if resp != nil && resp .OperationHandle != nil {
159165 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (resp .OperationHandle .OperationId .GUID ))
160166 defer log .Duration (msg , start )
@@ -165,54 +171,51 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
165171// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
166172// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
167173func (tsc * ThriftServiceClient ) GetOperationStatus (ctx context.Context , req * cli_service.TGetOperationStatusReq ) (* cli_service.TGetOperationStatusResp , error ) {
168- ctx = context .WithValue (ctx , ClientMethod , getOperationStatus )
174+ ctx = context .WithValue (ctx , ClientMethod , clientMethodGetOperationStatus )
169175 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
170176 defer log .Duration (logger .Track ("GetOperationStatus" ))
171177 resp , err := tsc .TCLIServiceClient .GetOperationStatus (ctx , req )
172178 if err != nil {
173- return resp , dbsqlerrint .NewRequestError (driverctx .NewContextWithQueryId (ctx , SprintGuid (req .OperationHandle .OperationId .GUID )), "databricks: get operation status request error" , err )
174- }
175- if RecordResults {
176- j , _ := json .MarshalIndent (resp , "" , " " )
177- _ = os .WriteFile (fmt .Sprintf ("GetOperationStatus%d.json" , resultIndex ), j , 0600 )
178- resultIndex ++
179+ err = handleClientMethodError (driverctx .NewContextWithQueryId (ctx , SprintGuid (req .OperationHandle .OperationId .GUID )), err )
180+ return resp , err
179181 }
182+
183+ recordResult (ctx , resp )
184+
180185 return resp , CheckStatus (resp )
181186}
182187
183188// CloseOperation is a wrapper around the thrift operation CloseOperation
184189// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
185190func (tsc * ThriftServiceClient ) CloseOperation (ctx context.Context , req * cli_service.TCloseOperationReq ) (* cli_service.TCloseOperationResp , error ) {
186- ctx = context .WithValue (ctx , ClientMethod , closeOperation )
191+ ctx = context .WithValue (ctx , ClientMethod , clientMethodCloseOperation )
187192 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
188193 defer log .Duration (logger .Track ("CloseOperation" ))
189194 resp , err := tsc .TCLIServiceClient .CloseOperation (ctx , req )
190195 if err != nil {
191- return resp , dbsqlerrint .NewRequestError (ctx , "close operation request error" , err )
192- }
193- if RecordResults {
194- j , _ := json .MarshalIndent (resp , "" , " " )
195- _ = os .WriteFile (fmt .Sprintf ("CloseOperation%d.json" , resultIndex ), j , 0600 )
196- resultIndex ++
196+ err = handleClientMethodError (ctx , err )
197+ return resp , err
197198 }
199+
200+ recordResult (ctx , resp )
201+
198202 return resp , CheckStatus (resp )
199203}
200204
201205// CancelOperation is a wrapper around the thrift operation CancelOperation
202206// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
203207func (tsc * ThriftServiceClient ) CancelOperation (ctx context.Context , req * cli_service.TCancelOperationReq ) (* cli_service.TCancelOperationResp , error ) {
204- ctx = context .WithValue (ctx , ClientMethod , cancelOperation )
208+ ctx = context .WithValue (ctx , ClientMethod , clientMethodCancelOperation )
205209 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
206210 defer log .Duration (logger .Track ("CancelOperation" ))
207211 resp , err := tsc .TCLIServiceClient .CancelOperation (ctx , req )
208212 if err != nil {
209- return resp , dbsqlerrint .NewRequestError (ctx , "cancel operation request error" , err )
210- }
211- if RecordResults {
212- j , _ := json .MarshalIndent (resp , "" , " " )
213- _ = os .WriteFile (fmt .Sprintf ("CancelOperation%d.json" , resultIndex ), j , 0600 )
214- resultIndex ++
213+ err = handleClientMethodError (ctx , err )
214+ return resp , err
215215 }
216+
217+ recordResult (ctx , resp )
218+
216219 return resp , CheckStatus (resp )
217220}
218221
@@ -283,6 +286,42 @@ func InitThriftClient(cfg *config.Config, httpclient *http.Client) (*ThriftServi
283286 return tsClient , nil
284287}
285288
289+ // handler function for errors returned by the thrift client methods
290+ func handleClientMethodError (ctx context.Context , err error ) dbsqlerr.DBRequestError {
291+ if err == nil {
292+ return nil
293+ }
294+
295+ // If the passed error indicates an invalid session we inject a bad connection error
296+ // into the error stack. This allows the for retrying with a new connection.
297+ s := err .Error ()
298+ if strings .Contains (s , "Invalid SessionHandle" ) {
299+ err = dbsqlerrint .NewBadConnectionError (err )
300+ }
301+
302+ // the passed error will be wrapped in a DBRequestError
303+ method := getClientMethod (ctx )
304+ msg := clientMethodRequestErrorMsgs [method ]
305+
306+ return dbsqlerrint .NewRequestError (ctx , msg , err )
307+ }
308+
309+ // Extract a clientMethod value from the given Context.
310+ func getClientMethod (ctx context.Context ) clientMethod {
311+ v , _ := ctx .Value (ClientMethod ).(clientMethod )
312+ return v
313+ }
314+
315+ // Write the result
316+ func recordResult (ctx context.Context , resp any ) {
317+ if RecordResults && resp != nil {
318+ method := getClientMethod (ctx )
319+ j , _ := json .MarshalIndent (resp , "" , " " )
320+ _ = os .WriteFile (fmt .Sprintf ("%s%d.json" , method , resultIndex ), j , 0600 )
321+ resultIndex ++
322+ }
323+ }
324+
286325// ThriftResponse represents the thrift rpc response
287326type ThriftResponse interface {
288327 GetStatus () * cli_service.TStatus
@@ -507,7 +546,7 @@ func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, err
507546 return false , ctx .Err ()
508547 }
509548
510- caller , _ := ctx . Value ( ClientMethod ).( clientMethod )
549+ caller := getClientMethod ( ctx )
511550 _ , nonRetryableClientMethod := nonRetryableClientMethods [caller ]
512551
513552 if err != nil {
0 commit comments