Skip to content

Commit 2661951

Browse files
Updated retry behaviour (#125)
- retry on 500 status codes if the server operation is idempotent - retry on request errors except: too many redirects, invalid protocol scheme, TLS cert verification failure Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
2 parents 8103ae1 + 565dc2c commit 2661951

4 files changed

Lines changed: 399 additions & 59 deletions

File tree

internal/client/client.go

Lines changed: 125 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@ package client
22

33
import (
44
"context"
5+
"crypto/x509"
56
"encoding/json"
67
"fmt"
78
"log"
9+
"math"
810
"net"
911
"net/http"
1012
"net/http/httptrace"
1113
"net/url"
1214
"os"
1315
"regexp"
16+
"strconv"
1417
"time"
1518

1619
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
@@ -34,9 +37,31 @@ type ThriftServiceClient struct {
3437
*cli_service.TCLIServiceClient
3538
}
3639

40+
type contextKey int
41+
42+
const (
43+
ClientMethod contextKey = iota
44+
)
45+
46+
type clientMethod int
47+
48+
const (
49+
openSession clientMethod = iota
50+
closeSession
51+
fetchResults
52+
getResultSetMetadata
53+
executeStatement
54+
getOperationStatus
55+
closeOperation
56+
cancelOperation
57+
)
58+
59+
var nonRetryableClientMethods map[clientMethod]any = map[clientMethod]any{executeStatement: struct{}{}}
60+
3761
// OpenSession is a wrapper around the thrift operation OpenSession
3862
// If RecordResults is true, the results will be marshalled to JSON format and written to OpenSession<index>.json
3963
func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
64+
ctx = context.WithValue(ctx, ClientMethod, openSession)
4065
msg, start := logger.Track("OpenSession")
4166
resp, err := tsc.TCLIServiceClient.OpenSession(ctx, req)
4267
if err != nil {
@@ -55,6 +80,7 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
5580
// CloseSession is a wrapper around the thrift operation CloseSession
5681
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseSession<index>.json
5782
func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (*cli_service.TCloseSessionResp, error) {
83+
ctx = context.WithValue(ctx, ClientMethod, closeSession)
5884
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), "")
5985
defer log.Duration(logger.Track("CloseSession"))
6086
resp, err := tsc.TCLIServiceClient.CloseSession(ctx, req)
@@ -72,6 +98,7 @@ func (tsc *ThriftServiceClient) CloseSession(ctx context.Context, req *cli_servi
7298
// FetchResults is a wrapper around the thrift operation FetchResults
7399
// If RecordResults is true, the results will be marshalled to JSON format and written to FetchResults<index>.json
74100
func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
101+
ctx = context.WithValue(ctx, ClientMethod, fetchResults)
75102
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
76103
defer log.Duration(logger.Track("FetchResults"))
77104
resp, err := tsc.TCLIServiceClient.FetchResults(ctx, req)
@@ -89,6 +116,7 @@ func (tsc *ThriftServiceClient) FetchResults(ctx context.Context, req *cli_servi
89116
// GetResultSetMetadata is a wrapper around the thrift operation GetResultSetMetadata
90117
// If RecordResults is true, the results will be marshalled to JSON format and written to GetResultSetMetadata<index>.json
91118
func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *cli_service.TGetResultSetMetadataReq) (*cli_service.TGetResultSetMetadataResp, error) {
119+
ctx = context.WithValue(ctx, ClientMethod, getResultSetMetadata)
92120
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
93121
defer log.Duration(logger.Track("GetResultSetMetadata"))
94122
resp, err := tsc.TCLIServiceClient.GetResultSetMetadata(ctx, req)
@@ -106,7 +134,10 @@ func (tsc *ThriftServiceClient) GetResultSetMetadata(ctx context.Context, req *c
106134
// ExecuteStatement is a wrapper around the thrift operation ExecuteStatement
107135
// If RecordResults is true, the results will be marshalled to JSON format and written to ExecuteStatement<index>.json
108136
func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
137+
ctx = context.WithValue(ctx, ClientMethod, executeStatement)
109138
msg, start := logger.Track("ExecuteStatement")
139+
140+
// We use context.Background to fix a problem where on context done the query would not be cancelled.
110141
resp, err := tsc.TCLIServiceClient.ExecuteStatement(context.Background(), req)
111142
if err != nil {
112143
return resp, dbsqlerrint.NewRequestError(ctx, "execute statement request error", err)
@@ -129,6 +160,7 @@ func (tsc *ThriftServiceClient) ExecuteStatement(ctx context.Context, req *cli_s
129160
// GetOperationStatus is a wrapper around the thrift operation GetOperationStatus
130161
// If RecordResults is true, the results will be marshalled to JSON format and written to GetOperationStatus<index>.json
131162
func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli_service.TGetOperationStatusReq) (*cli_service.TGetOperationStatusResp, error) {
163+
ctx = context.WithValue(ctx, ClientMethod, getOperationStatus)
132164
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
133165
defer log.Duration(logger.Track("GetOperationStatus"))
134166
resp, err := tsc.TCLIServiceClient.GetOperationStatus(ctx, req)
@@ -146,6 +178,7 @@ func (tsc *ThriftServiceClient) GetOperationStatus(ctx context.Context, req *cli
146178
// CloseOperation is a wrapper around the thrift operation CloseOperation
147179
// If RecordResults is true, the results will be marshalled to JSON format and written to CloseOperation<index>.json
148180
func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_service.TCloseOperationReq) (*cli_service.TCloseOperationResp, error) {
181+
ctx = context.WithValue(ctx, ClientMethod, closeOperation)
149182
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
150183
defer log.Duration(logger.Track("CloseOperation"))
151184
resp, err := tsc.TCLIServiceClient.CloseOperation(ctx, req)
@@ -163,6 +196,7 @@ func (tsc *ThriftServiceClient) CloseOperation(ctx context.Context, req *cli_ser
163196
// CancelOperation is a wrapper around the thrift operation CancelOperation
164197
// If RecordResults is true, the results will be marshalled to JSON format and written to CancelOperation<index>.json
165198
func (tsc *ThriftServiceClient) CancelOperation(ctx context.Context, req *cli_service.TCancelOperationReq) (*cli_service.TCancelOperationResp, error) {
199+
ctx = context.WithValue(ctx, ClientMethod, cancelOperation)
166200
log := logger.WithContext(driverctx.ConnIdFromContext(ctx), driverctx.CorrelationIdFromContext(ctx), SprintGuid(req.OperationHandle.OperationId.GUID))
167201
defer log.Duration(logger.Track("CancelOperation"))
168202
resp, err := tsc.TCLIServiceClient.CancelOperation(ctx, req)
@@ -277,15 +311,11 @@ func SprintGuid(bts []byte) string {
277311
return fmt.Sprintf("%x", bts)
278312
}
279313

280-
var retryableStatusCode = []int{http.StatusTooManyRequests, http.StatusServiceUnavailable}
314+
var retryableStatusCodes = map[int]any{http.StatusTooManyRequests: struct{}{}, http.StatusServiceUnavailable: struct{}{}}
281315

282-
func isRetryable(statusCode int) bool {
283-
for _, c := range retryableStatusCode {
284-
if c == statusCode {
285-
return true
286-
}
287-
}
288-
return false
316+
func isRetryableServerResponse(resp *http.Response) bool {
317+
_, ok := retryableStatusCodes[resp.StatusCode]
318+
return ok
289319
}
290320

291321
type Transport struct {
@@ -324,31 +354,8 @@ func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
324354
// req.Body is assumed to be closed by the base RoundTripper.
325355
reqBodyClosed = true
326356
resp, err := t.Base.RoundTrip(req2)
327-
if err != nil {
328-
return nil, err
329-
}
330-
if resp.StatusCode != http.StatusOK {
331-
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
332-
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
333-
if isRetryable(resp.StatusCode) {
334-
if terrmsg != "" {
335-
logger.Warn().Msg(terrmsg)
336-
}
337-
return resp, nil
338-
}
339357

340-
if reason != "" {
341-
logger.Err(fmt.Errorf(reason)).Msg("non retryable error")
342-
return nil, errors.New(reason)
343-
}
344-
if terrmsg != "" {
345-
logger.Err(fmt.Errorf(terrmsg)).Msg("non retryable error")
346-
return nil, errors.New(terrmsg)
347-
}
348-
return nil, errors.New(resp.Status)
349-
}
350-
351-
return resp, nil
358+
return resp, err
352359
}
353360

354361
func RetryableClient(cfg *config.Config) *http.Client {
@@ -361,7 +368,7 @@ func RetryableClient(cfg *config.Config) *http.Client {
361368
RetryMax: cfg.RetryMax,
362369
ErrorHandler: errorHandler,
363370
CheckRetry: RetryPolicy,
364-
Backoff: retryablehttp.DefaultBackoff,
371+
Backoff: backoff,
365372
}
366373
return retryableClient.StandardClient()
367374
}
@@ -431,59 +438,120 @@ func (l *leveledLogger) Warn(msg string, keysAndValues ...interface{}) {
431438

432439
func errorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
433440
var werr error
441+
msg := fmt.Sprintf("request error after %d attempt(s)", numTries)
434442
if err == nil {
435-
err = errors.New(fmt.Sprintf("request error after %d attempt(s)", numTries))
443+
werr = errors.New(msg)
444+
} else {
445+
werr = errors.Wrap(err, msg)
436446
}
437447

438448
if resp != nil {
439-
var orgid, reason, terrmsg, errmsg, retryAfter string
440-
// TODO @mattdeekay: convert these to specific error types
441449
if resp.Header != nil {
442-
orgid = resp.Header.Get("X-Databricks-Org-Id")
443-
reason = resp.Header.Get("X-Databricks-Reason-Phrase") // TODO note: shown on notebook
444-
terrmsg = resp.Header.Get("X-Thriftserver-Error-Message")
445-
errmsg = resp.Header.Get("x-databricks-error-or-redirect-message")
446-
retryAfter = resp.Header.Get("Retry-After")
447-
// TODO note: need to see if there's other headers
448-
}
449-
msg := fmt.Sprintf("orgId: %s, reason: %s, thriftErr: %s, err: %s", orgid, reason, terrmsg, errmsg)
450+
reason := resp.Header.Get("X-Databricks-Reason-Phrase")
451+
terrmsg := resp.Header.Get("X-Thriftserver-Error-Message")
450452

451-
if isRetryable(resp.StatusCode) {
452-
err = dbsqlerrint.NewRetryableError(err, retryAfter)
453+
if reason != "" {
454+
werr = dbsqlerrint.WrapErr(werr, reason)
455+
} else if terrmsg != "" {
456+
werr = dbsqlerrint.WrapErr(werr, terrmsg)
457+
}
453458
}
454459

455-
werr = dbsqlerrint.WrapErr(err, msg)
456-
} else {
457-
werr = err
460+
logger.Err(werr).Msg(resp.Status)
458461
}
459462

460463
return resp, werr
461464
}
462465

463-
func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
464-
var lostConn = regexp.MustCompile(`EOF`)
466+
var (
467+
// A regular expression to match the error returned by net/http when the
468+
// configured number of redirects is exhausted. This error isn't typed
469+
// specifically so we resort to matching on the error string.
470+
redirectsErrorRe = regexp.MustCompile(`stopped after \d+ redirects\z`)
471+
472+
// A regular expression to match the error returned by net/http when the
473+
// scheme specified in the URL is invalid. This error isn't typed
474+
// specifically so we resort to matching on the error string.
475+
schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`)
465476

477+
// A regular expression to match the error returned by net/http when the
478+
// TLS certificate is not trusted. This error isn't typed
479+
// specifically so we resort to matching on the error string.
480+
notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`)
481+
482+
errorRes = []*regexp.Regexp{redirectsErrorRe, schemeErrorRe, notTrustedErrorRe}
483+
)
484+
485+
func RetryPolicy(ctx context.Context, resp *http.Response, err error) (bool, error) {
466486
// do not retry on context.Canceled or context.DeadlineExceeded
467487
if ctx.Err() != nil {
468488
return false, ctx.Err()
469489
}
470490

471491
if err != nil {
472492
if v, ok := err.(*url.Error); ok {
473-
if lostConn.MatchString(v.Error()) {
474-
return true, v
493+
s := v.Error()
494+
for _, re := range errorRes {
495+
if re.MatchString(s) {
496+
return false, v
497+
}
498+
}
499+
500+
if _, ok := v.Err.(x509.UnknownAuthorityError); ok {
501+
return false, v
475502
}
476503
}
477-
return false, nil
504+
505+
// The error is likely recoverable so retry.
506+
return true, nil
507+
}
508+
509+
var checkErr error
510+
if resp.StatusCode != http.StatusOK {
511+
checkErr = fmt.Errorf("unexpected HTTP status %s", resp.Status)
478512
}
479513

480514
// 429 Too Many Requests or 503 service unavailable is recoverable. Sometimes the server puts
481515
// a Retry-After response header to indicate when the server is
482516
// available to start processing request from client.
483-
if isRetryable(resp.StatusCode) {
484-
return true, nil
517+
if isRetryableServerResponse(resp) {
518+
var retryAfter string
519+
if resp.Header != nil {
520+
retryAfter = resp.Header.Get("Retry-After")
521+
}
522+
523+
return true, dbsqlerrint.NewRetryableError(checkErr, retryAfter)
485524
}
486525

487-
return false, nil
526+
if resp.StatusCode == 0 || (resp.StatusCode >= 500 && resp.StatusCode != http.StatusNotImplemented) {
527+
callerAny := ctx.Value(ClientMethod)
528+
if caller, ok := callerAny.(clientMethod); ok {
529+
if _, noRetry := nonRetryableClientMethods[caller]; !noRetry {
530+
return true, checkErr
531+
}
532+
}
533+
}
488534

535+
// checkErr will be non-nil if the response code was not StatusOK.
536+
// Returning it here ensures that the error handler will be called.
537+
return false, checkErr
538+
}
539+
540+
func backoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration {
541+
// honour the Retry-After header
542+
if resp != nil && resp.Header != nil {
543+
if s, ok := resp.Header["Retry-After"]; ok {
544+
if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil {
545+
return time.Second * time.Duration(sleep)
546+
}
547+
}
548+
}
549+
550+
// exponential backoff
551+
mult := math.Pow(2, float64(attemptNum)) * float64(min)
552+
sleep := time.Duration(mult)
553+
if float64(sleep) != mult || sleep > max {
554+
sleep = max
555+
}
556+
return sleep
489557
}

0 commit comments

Comments
 (0)