Skip to content

Commit 0127175

Browse files
Moved default authenticator into internal package.
Fixed error in default authenticator where it tried to register the same http handler multiple times. Added config decorator functions WithDefaultOAUTH, WithClientCredentials Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
1 parent b63e445 commit 0127175

17 files changed

Lines changed: 497 additions & 278 deletions

File tree

auth/oauth/dev/dev.go

Lines changed: 0 additions & 235 deletions
This file was deleted.

auth/oauth/oauth.go

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error)
2020
}
2121

2222
cloud := InferCloudFromHost(hostName)
23-
if cloud == unknown {
23+
24+
if cloud == Unknown {
2425
return oauth2.Endpoint{}, errors.New("unsupported cloud type")
2526
}
2627

27-
if cloud == azure {
28+
if cloud == Azure {
2829
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName)
2930
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName)
3031
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil
@@ -50,7 +51,7 @@ func GetScopes(hostName string, scopes []string) []string {
5051
}
5152

5253
cloudType := InferCloudFromHost(hostName)
53-
if cloudType == azure {
54+
if cloudType == Azure {
5455
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureEnterpriseAppId)
5556
if !hasScope(scopes, userImpersonationScope) {
5657
scopes = append(scopes, userImpersonationScope)
@@ -86,34 +87,47 @@ var databricksAzureDomains []string = []string{
8687

8788
var databricksGCPDomains []string = []string{".gcp.databricks.com"}
8889

89-
type cloudType int
90+
type CloudType int
9091

9192
const (
92-
aws = iota
93-
azure
94-
gcp
95-
unknown
93+
AWS = iota
94+
Azure
95+
GCP
96+
Unknown
9697
)
9798

98-
func InferCloudFromHost(hostname string) cloudType {
99+
func (cl CloudType) String() string {
100+
switch cl {
101+
case AWS:
102+
return "AWS"
103+
case Azure:
104+
return "Azure"
105+
case GCP:
106+
return "GCP"
107+
}
108+
109+
return "Unknown"
110+
}
111+
112+
func InferCloudFromHost(hostname string) CloudType {
99113

100114
for _, d := range databricksAzureDomains {
101115
if strings.Contains(hostname, d) {
102-
return azure
116+
return Azure
103117
}
104118
}
105119

106120
for _, d := range databricksAWSDomains {
107121
if strings.Contains(hostname, d) {
108-
return aws
122+
return AWS
109123
}
110124
}
111125

112126
for _, d := range databricksGCPDomains {
113127
if strings.Contains(hostname, d) {
114-
return gcp
128+
return GCP
115129
}
116130
}
117131

118-
return unknown
132+
return Unknown
119133
}

auth/oauth/u2m/u2m.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ import (
66
"strings"
77

88
"github.com/databricks/databricks-sql-go/auth/oauth"
9-
"github.com/databricks/databricks-sql-go/auth/oauth/pkce"
9+
"github.com/databricks/databricks-sql-go/internal/auth/oauth/pkce"
1010
"golang.org/x/oauth2"
1111
)
1212

1313
func GetConfig(ctx context.Context, hostName, clientID, clientSecret, callbackURL string, scopes []string) (oauth2.Config, error) {
14+
// Add necessary scopes for AWS or Azure
1415
scopes = oauth.GetScopes(hostName, scopes)
1516

17+
// Get the endpoint based on the host name
1618
endpoint, err := oauth.GetEndpoint(ctx, hostName)
1719
if err != nil {
1820
return oauth2.Config{}, fmt.Errorf("could not lookup provider details: %w", err)

connector.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ import (
99
"time"
1010

1111
"github.com/databricks/databricks-sql-go/auth"
12-
"github.com/databricks/databricks-sql-go/auth/pat"
1312
"github.com/databricks/databricks-sql-go/driverctx"
1413
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
14+
"github.com/databricks/databricks-sql-go/internal/auth/oauth/m2m"
15+
"github.com/databricks/databricks-sql-go/internal/auth/pat"
1516
"github.com/databricks/databricks-sql-go/internal/cli_service"
1617
"github.com/databricks/databricks-sql-go/internal/client"
1718
"github.com/databricks/databricks-sql-go/internal/config"
@@ -259,3 +260,12 @@ func WithMaxDownloadThreads(numThreads int) connOption {
259260
c.MaxDownloadThreads = numThreads
260261
}
261262
}
263+
264+
func WithClientCredentials(clientID, clientSecret string) connOption {
265+
return func(c *config.Config) {
266+
if clientID != "" && clientSecret != "" {
267+
authr := m2m.NewClient(clientID, clientSecret, c.Host)
268+
c.Authenticator = authr
269+
}
270+
}
271+
}

connector_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"testing"
66
"time"
77

8-
"github.com/databricks/databricks-sql-go/auth/pat"
8+
"github.com/databricks/databricks-sql-go/internal/auth/pat"
99
"github.com/databricks/databricks-sql-go/internal/config"
1010
"github.com/stretchr/testify/assert"
1111
"github.com/stretchr/testify/require"

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"fmt"
55
"testing"
66

7-
"github.com/databricks/databricks-sql-go/auth/pat"
7+
"github.com/databricks/databricks-sql-go/internal/auth/pat"
88
"github.com/databricks/databricks-sql-go/internal/config"
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"

errors/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const (
2222
ErrInvalidURL = "invalid URL"
2323

2424
ErrNoAuthenticationMethod = "no authentication method set"
25+
ErrNoDefaultAuthenticator = "unable to create default authenticator"
2526
ErrInvalidDSNFormat = "invalid DSN: invalid format"
2627
ErrInvalidDSNPort = "invalid DSN: invalid DSN port"
2728
ErrInvalidDSNPATIsEmpty = "invalid DSN: empty token"

0 commit comments

Comments
 (0)