@@ -2,15 +2,18 @@ package client
22
33import (
44 "context"
5+ "crypto/x509"
56 "encoding/json"
67 "fmt"
78 "log"
9+ "math"
810 "net"
911 "net/http"
1012 "net/http/httptrace"
1113 "net/url"
1214 "os"
1315 "regexp"
16+ "strconv"
1417 "time"
1518
1619 dbsqlerr "github.com/databricks/databricks-sql-go/errors"
@@ -34,9 +37,31 @@ type ThriftServiceClient struct {
3437 * cli_service.TCLIServiceClient
3538}
3639
40+ type contextKey int
41+
42+ const (
43+ ClientMethod contextKey = iota
44+ )
45+
46+ type clientMethod int
47+
48+ const (
49+ openSession clientMethod = iota
50+ closeSession
51+ fetchResults
52+ getResultSetMetadata
53+ executeStatement
54+ getOperationStatus
55+ closeOperation
56+ cancelOperation
57+ )
58+
59+ var nonRetryableClientMethods map [clientMethod ]any = map [clientMethod ]any {executeStatement : struct {}{}}
60+
3761// OpenSession is a wrapper around the thrift operation OpenSession
3862// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
3963func (tsc * ThriftServiceClient ) OpenSession (ctx context.Context , req * cli_service.TOpenSessionReq ) (* cli_service.TOpenSessionResp , error ) {
64+ ctx = context .WithValue (ctx , ClientMethod , openSession )
4065 msg , start := logger .Track ("OpenSession" )
4166 resp , err := tsc .TCLIServiceClient .OpenSession (ctx , req )
4267 if err != nil {
@@ -55,6 +80,7 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
5580// CloseSession is a wrapper around the thrift operation CloseSession
5681// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
5782func (tsc * ThriftServiceClient ) CloseSession (ctx context.Context , req * cli_service.TCloseSessionReq ) (* cli_service.TCloseSessionResp , error ) {
83+ ctx = context .WithValue (ctx , ClientMethod , closeSession )
5884 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), "" )
5985 defer log .Duration (logger .Track ("CloseSession" ))
6086 resp , err := tsc .TCLIServiceClient .CloseSession (ctx , req )
@@ -72,6 +98,7 @@ func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_servi
7298// FetchResults is a wrapper around the thrift operation FetchResults
7399// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
74100func (tsc * ThriftServiceClient ) FetchResults (ctx context.Context , req * cli_service.TFetchResultsReq ) (* cli_service.TFetchResultsResp , error ) {
101+ ctx = context .WithValue (ctx , ClientMethod , fetchResults )
75102 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
76103 defer log .Duration (logger .Track ("FetchResults" ))
77104 resp , err := tsc .TCLIServiceClient .FetchResults (ctx , req )
@@ -89,6 +116,7 @@ func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_servi
89116// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
90117// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
91118func (tsc * ThriftServiceClient ) GetResultSetMetadata (ctx context.Context , req * cli_service.TGetResultSetMetadataReq ) (* cli_service.TGetResultSetMetadataResp , error ) {
119+ ctx = context .WithValue (ctx , ClientMethod , getResultSetMetadata )
92120 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
93121 defer log .Duration (logger .Track ("GetResultSetMetadata" ))
94122 resp , err := tsc .TCLIServiceClient .GetResultSetMetadata (ctx , req )
@@ -106,7 +134,10 @@ func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *c
106134// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
107135// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
108136func (tsc * ThriftServiceClient ) ExecuteStatement (ctx context.Context , req * cli_service.TExecuteStatementReq ) (* cli_service.TExecuteStatementResp , error ) {
137+ ctx = context .WithValue (ctx , ClientMethod , executeStatement )
109138 msg , start := logger .Track ("ExecuteStatement" )
139+
140+ // We use context.Background to fix a problem where on context done the query would not be cancelled.
110141 resp , err := tsc .TCLIServiceClient .ExecuteStatement (context .Background (), req )
111142 if err != nil {
112143 return resp , dbsqlerrint .NewRequestError (ctx , "execute statement request error" , err )
@@ -129,6 +160,7 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
129160// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
130161// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
131162func (tsc * ThriftServiceClient ) GetOperationStatus (ctx context.Context , req * cli_service.TGetOperationStatusReq ) (* cli_service.TGetOperationStatusResp , error ) {
163+ ctx = context .WithValue (ctx , ClientMethod , getOperationStatus )
132164 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
133165 defer log .Duration (logger .Track ("GetOperationStatus" ))
134166 resp , err := tsc .TCLIServiceClient .GetOperationStatus (ctx , req )
@@ -146,6 +178,7 @@ func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli
146178// CloseOperation is a wrapper around the thrift operation CloseOperation
147179// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
148180func (tsc * ThriftServiceClient ) CloseOperation (ctx context.Context , req * cli_service.TCloseOperationReq ) (* cli_service.TCloseOperationResp , error ) {
181+ ctx = context .WithValue (ctx , ClientMethod , closeOperation )
149182 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
150183 defer log .Duration (logger .Track ("CloseOperation" ))
151184 resp , err := tsc .TCLIServiceClient .CloseOperation (ctx , req )
@@ -163,6 +196,7 @@ func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_ser
163196// CancelOperation is a wrapper around the thrift operation CancelOperation
164197// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
165198func (tsc * ThriftServiceClient ) CancelOperation (ctx context.Context , req * cli_service.TCancelOperationReq ) (* cli_service.TCancelOperationResp , error ) {
199+ ctx = context .WithValue (ctx , ClientMethod , cancelOperation )
166200 log := logger .WithContext (driverctx .ConnIdFromContext (ctx ), driverctx .CorrelationIdFromContext (ctx ), SprintGuid (req .OperationHandle .OperationId .GUID ))
167201 defer log .Duration (logger .Track ("CancelOperation" ))
168202 resp , err := tsc .TCLIServiceClient .CancelOperation (ctx , req )
@@ -277,15 +311,11 @@ func SprintGuid(bts []byte) string {
277311 return fmt .Sprintf ("%x" , bts )
278312}
279313
280- var retryableStatusCode = [] int {http .StatusTooManyRequests , http .StatusServiceUnavailable }
314+ var retryableStatusCodes = map [ int ] any {http .StatusTooManyRequests : struct {}{} , http .StatusServiceUnavailable : struct {}{} }
281315
282- func isRetryable (statusCode int ) bool {
283- for _ , c := range retryableStatusCode {
284- if c == statusCode {
285- return true
286- }
287- }
288- return false
316+ func isRetryableServerResponse (resp * http.Response ) bool {
317+ _ , ok := retryableStatusCodes [resp .StatusCode ]
318+ return ok
289319}
290320
291321type Transport struct {
@@ -324,31 +354,8 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
324354 // req.Body is assumed to be closed by the base RoundTripper.
325355 reqBodyClosed = true
326356 resp , err := t .Base .RoundTrip (req2 )
327- if err != nil {
328- return nil , err
329- }
330- if resp .StatusCode != http .StatusOK {
331- reason := resp .Header .Get ("X-Databricks-Reason-Phrase" )
332- terrmsg := resp .Header .Get ("X-Thriftserver-Error-Message" )
333- if isRetryable (resp .StatusCode ) {
334- if terrmsg != "" {
335- logger .Warn ().Msg (terrmsg )
336- }
337- return resp , nil
338- }
339357
340- if reason != "" {
341- logger .Err (fmt .Errorf (reason )).Msg ("non retryable error" )
342- return nil , errors .New (reason )
343- }
344- if terrmsg != "" {
345- logger .Err (fmt .Errorf (terrmsg )).Msg ("non retryable error" )
346- return nil , errors .New (terrmsg )
347- }
348- return nil , errors .New (resp .Status )
349- }
350-
351- return resp , nil
358+ return resp , err
352359}
353360
354361func RetryableClient (cfg * config.Config ) * http.Client {
@@ -361,7 +368,7 @@ func RetryableClient(cfg *config.Config) *http.Client {
361368 RetryMax : cfg .RetryMax ,
362369 ErrorHandler : errorHandler ,
363370 CheckRetry : RetryPolicy ,
364- Backoff : retryablehttp . DefaultBackoff ,
371+ Backoff : backoff ,
365372 }
366373 return retryableClient .StandardClient ()
367374}
@@ -431,59 +438,120 @@ func (l *leveledLogger) Warn(msg string, keysAndValues ...interface{}) {
431438
432439func errorHandler (resp * http.Response , err error , numTries int ) (* http.Response , error ) {
433440 var werr error
441+ msg := fmt .Sprintf ("request error after %d attempt(s)" , numTries )
434442 if err == nil {
435- err = errors .New (fmt .Sprintf ("request error after %d attempt(s)" , numTries ))
443+ werr = errors .New (msg )
444+ } else {
445+ werr = errors .Wrap (err , msg )
436446 }
437447
438448 if resp != nil {
439- var orgid , reason , terrmsg , errmsg , retryAfter string
440- // TODO @mattdeekay: convert these to specific error types
441449 if resp .Header != nil {
442- orgid = resp .Header .Get ("X-Databricks-Org-Id" )
443- reason = resp .Header .Get ("X-Databricks-Reason-Phrase" ) // TODO note: shown on notebook
444- terrmsg = resp .Header .Get ("X-Thriftserver-Error-Message" )
445- errmsg = resp .Header .Get ("x-databricks-error-or-redirect-message" )
446- retryAfter = resp .Header .Get ("Retry-After" )
447- // TODO note: need to see if there's other headers
448- }
449- msg := fmt .Sprintf ("orgId: %s, reason: %s, thriftErr: %s, err: %s" , orgid , reason , terrmsg , errmsg )
450+ reason := resp .Header .Get ("X-Databricks-Reason-Phrase" )
451+ terrmsg := resp .Header .Get ("X-Thriftserver-Error-Message" )
450452
451- if isRetryable (resp .StatusCode ) {
452- err = dbsqlerrint .NewRetryableError (err , retryAfter )
453+ if reason != "" {
454+ werr = dbsqlerrint .WrapErr (werr , reason )
455+ } else if terrmsg != "" {
456+ werr = dbsqlerrint .WrapErr (werr , terrmsg )
457+ }
453458 }
454459
455- werr = dbsqlerrint .WrapErr (err , msg )
456- } else {
457- werr = err
460+ logger .Err (werr ).Msg (resp .Status )
458461 }
459462
460463 return resp , werr
461464}
462465
463- func RetryPolicy (ctx context.Context , resp * http.Response , err error ) (bool , error ) {
464- var lostConn = regexp .MustCompile (`EOF` )
466+ var (
467+ // A regular expression to match the error returned by net/http when the
468+ // configured number of redirects is exhausted. This error isn't typed
469+ // specifically so we resort to matching on the error string.
470+ redirectsErrorRe = regexp .MustCompile (`stopped after \d+ redirects\z` )
471+
472+ // A regular expression to match the error returned by net/http when the
473+ // scheme specified in the URL is invalid. This error isn't typed
474+ // specifically so we resort to matching on the error string.
475+ schemeErrorRe = regexp .MustCompile (`unsupported protocol scheme` )
465476
477+ // A regular expression to match the error returned by net/http when the
478+ // TLS certificate is not trusted. This error isn't typed
479+ // specifically so we resort to matching on the error string.
480+ notTrustedErrorRe = regexp .MustCompile (`certificate is not trusted` )
481+
482+ errorRes = []* regexp.Regexp {redirectsErrorRe , schemeErrorRe , notTrustedErrorRe }
483+ )
484+
485+ func RetryPolicy (ctx context.Context , resp * http.Response , err error ) (bool , error ) {
466486 // do not retry on context.Canceled or context.DeadlineExceeded
467487 if ctx .Err () != nil {
468488 return false , ctx .Err ()
469489 }
470490
471491 if err != nil {
472492 if v , ok := err .(* url.Error ); ok {
473- if lostConn .MatchString (v .Error ()) {
474- return true , v
493+ s := v .Error ()
494+ for _ , re := range errorRes {
495+ if re .MatchString (s ) {
496+ return false , v
497+ }
498+ }
499+
500+ if _ , ok := v .Err .(x509.UnknownAuthorityError ); ok {
501+ return false , v
475502 }
476503 }
477- return false , nil
504+
505+ // The error is likely recoverable so retry.
506+ return true , nil
507+ }
508+
509+ var checkErr error
510+ if resp .StatusCode != http .StatusOK {
511+ checkErr = fmt .Errorf ("unexpected HTTP status %s" , resp .Status )
478512 }
479513
480514 // 429 Too Many Requests or 503 service unavailable is recoverable. Sometimes the server puts
481515 // a Retry-After response header to indicate when the server is
482516 // available to start processing request from client.
483- if isRetryable (resp .StatusCode ) {
484- return true , nil
517+ if isRetryableServerResponse (resp ) {
518+ var retryAfter string
519+ if resp .Header != nil {
520+ retryAfter = resp .Header .Get ("Retry-After" )
521+ }
522+
523+ return true , dbsqlerrint .NewRetryableError (checkErr , retryAfter )
485524 }
486525
487- return false , nil
526+ if resp .StatusCode == 0 || (resp .StatusCode >= 500 && resp .StatusCode != http .StatusNotImplemented ) {
527+ callerAny := ctx .Value (ClientMethod )
528+ if caller , ok := callerAny .(clientMethod ); ok {
529+ if _ , noRetry := nonRetryableClientMethods [caller ]; ! noRetry {
530+ return true , checkErr
531+ }
532+ }
533+ }
488534
535+ // checkErr will be non-nil if the response code was not StatusOK.
536+ // Returning it here ensures that the error handler will be called.
537+ return false , checkErr
538+ }
539+
540+ func backoff (min , max time.Duration , attemptNum int , resp * http.Response ) time.Duration {
541+ // honour the Retry-After header
542+ if resp != nil && resp .Header != nil {
543+ if s , ok := resp .Header ["Retry-After" ]; ok {
544+ if sleep , err := strconv .ParseInt (s [0 ], 10 , 64 ); err == nil {
545+ return time .Second * time .Duration (sleep )
546+ }
547+ }
548+ }
549+
550+ // exponential backoff
551+ mult := math .Pow (2 , float64 (attemptNum )) * float64 (min )
552+ sleep := time .Duration (mult )
553+ if float64 (sleep ) != mult || sleep > max {
554+ sleep = max
555+ }
556+ return sleep
489557}
0 commit comments