Skip to content

Commit 34599c4

Browse files
Initial oauth implementation (#122)
Oauth implementation. Added m2m authenticator. Added basic u2m authenticator. Implementation is broken out into public functions that clients can use to implement their own authenticator. Allow specifying auth type and parameters in dsn.
2 parents 7a177c9 + 74df4bf commit 34599c4

17 files changed

Lines changed: 1388 additions & 275 deletions

File tree

auth/auth.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,40 @@
11
package auth
22

3-
import "net/http"
3+
import (
4+
"net/http"
5+
"strings"
6+
)
47

58
type Authenticator interface {
69
Authenticate(*http.Request) error
710
}
11+
12+
type AuthType int
13+
14+
const (
15+
AuthTypeUnknown AuthType = iota
16+
AuthTypePat
17+
AuthTypeOauthU2M
18+
AuthTypeOauthM2M
19+
)
20+
21+
var authTypeNames []string = []string{"Unknown", "Pat", "OauthU2M", "OauthM2M"}
22+
23+
func (at AuthType) String() string {
24+
if at >= 0 && int(at) < len(authTypeNames) {
25+
return authTypeNames[at]
26+
}
27+
28+
return authTypeNames[0]
29+
}
30+
31+
func ParseAuthType(typeString string) AuthType {
32+
typeString = strings.ToLower(typeString)
33+
for i, n := range authTypeNames {
34+
if strings.ToLower(n) == typeString {
35+
return AuthType(i)
36+
}
37+
}
38+
39+
return AuthTypeUnknown
40+
}

auth/oauth/m2m/m2m.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package m2m
2+
3+
// clientid e92aa085-4875-42fe-ad75-ba38fb3c9706
4+
// secretid vUdzecmn4aUi2jRDamaBOy3qThu9LSgeV_BW4UnQ
5+
6+
import (
7+
"context"
8+
"fmt"
9+
"net/http"
10+
"sync"
11+
12+
"github.com/databricks/databricks-sql-go/auth"
13+
"github.com/databricks/databricks-sql-go/auth/oauth"
14+
"github.com/rs/zerolog/log"
15+
"golang.org/x/oauth2"
16+
"golang.org/x/oauth2/clientcredentials"
17+
)
18+
19+
func NewAuthenticator(clientID, clientSecret, hostName string) auth.Authenticator {
20+
scopes := oauth.GetScopes(hostName, []string{})
21+
return &authClient{
22+
clientID: clientID,
23+
clientSecret: clientSecret,
24+
hostName: hostName,
25+
scopes: scopes,
26+
}
27+
}
28+
29+
type authClient struct {
30+
clientID string
31+
clientSecret string
32+
hostName string
33+
scopes []string
34+
tokenSource oauth2.TokenSource
35+
mx sync.Mutex
36+
}
37+
38+
// Auth will start the OAuth Authorization Flow to authenticate the cli client
39+
// using the users credentials in the browser. Compatible with SSO.
40+
func (c *authClient) Authenticate(r *http.Request) error {
41+
c.mx.Lock()
42+
defer c.mx.Unlock()
43+
if c.tokenSource != nil {
44+
token, err := c.tokenSource.Token()
45+
if err != nil {
46+
return err
47+
}
48+
token.SetAuthHeader(r)
49+
return nil
50+
}
51+
52+
config, err := GetConfig(context.Background(), c.hostName, c.clientID, c.clientSecret, c.scopes)
53+
if err != nil {
54+
return fmt.Errorf("unable to generate clientCredentials.Config: %w", err)
55+
}
56+
57+
c.tokenSource = GetTokenSource(config)
58+
token, err := c.tokenSource.Token()
59+
log.Info().Msgf("token fetched successfully")
60+
if err != nil {
61+
log.Err(err).Msg("failed to get token")
62+
63+
return err
64+
}
65+
token.SetAuthHeader(r)
66+
67+
return nil
68+
69+
}
70+
71+
func GetTokenSource(config clientcredentials.Config) oauth2.TokenSource {
72+
tokenSource := config.TokenSource(context.Background())
73+
return tokenSource
74+
}
75+
76+
func GetConfig(ctx context.Context, issuerURL, clientID, clientSecret string, scopes []string) (clientcredentials.Config, error) {
77+
// Get the endpoint based on the host name
78+
endpoint, err := oauth.GetEndpoint(ctx, issuerURL)
79+
if err != nil {
80+
return clientcredentials.Config{}, fmt.Errorf("could not lookup provider details: %w", err)
81+
}
82+
83+
config := clientcredentials.Config{
84+
ClientID: clientID,
85+
ClientSecret: clientSecret,
86+
TokenURL: endpoint.TokenURL,
87+
Scopes: scopes,
88+
}
89+
90+
return config, nil
91+
}

auth/oauth/oauth.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package oauth
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"strings"
8+
9+
"github.com/coreos/go-oidc/v3/oidc"
10+
"golang.org/x/oauth2"
11+
)
12+
13+
const (
14+
azureTenantId = "4a67d088-db5c-48f1-9ff2-0aace800ae68"
15+
)
16+
17+
func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) {
18+
if ctx == nil {
19+
ctx = context.Background()
20+
}
21+
22+
cloud := InferCloudFromHost(hostName)
23+
24+
if cloud == Unknown {
25+
return oauth2.Endpoint{}, errors.New("unsupported cloud type")
26+
}
27+
28+
if cloud == Azure {
29+
authURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/authorize", hostName)
30+
tokenURL := fmt.Sprintf("https://%s/oidc/oauth2/v2.0/token", hostName)
31+
return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil
32+
}
33+
34+
issuerURL := fmt.Sprintf("https://%s/oidc", hostName)
35+
ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL)
36+
provider, err := oidc.NewProvider(ctx, issuerURL)
37+
if err != nil {
38+
return oauth2.Endpoint{}, err
39+
}
40+
41+
endpoint := provider.Endpoint()
42+
43+
return endpoint, err
44+
}
45+
46+
func GetScopes(hostName string, scopes []string) []string {
47+
for _, s := range []string{oidc.ScopeOfflineAccess} {
48+
if !hasScope(scopes, s) {
49+
scopes = append(scopes, s)
50+
}
51+
}
52+
53+
cloudType := InferCloudFromHost(hostName)
54+
if cloudType == Azure {
55+
userImpersonationScope := fmt.Sprintf("%s/user_impersonation", azureTenantId)
56+
if !hasScope(scopes, userImpersonationScope) {
57+
scopes = append(scopes, userImpersonationScope)
58+
}
59+
} else {
60+
if !hasScope(scopes, "sql") {
61+
scopes = append(scopes, "sql")
62+
}
63+
}
64+
65+
return scopes
66+
}
67+
68+
func hasScope(scopes []string, scope string) bool {
69+
for _, s := range scopes {
70+
if s == scope {
71+
return true
72+
}
73+
}
74+
return false
75+
}
76+
77+
var databricksAWSDomains []string = []string{
78+
".cloud.databricks.com",
79+
".dev.databricks.com",
80+
}
81+
82+
var databricksAzureDomains []string = []string{
83+
".azuredatabricks.net",
84+
".databricks.azure.cn",
85+
".databricks.azure.us",
86+
}
87+
88+
type CloudType int
89+
90+
const (
91+
AWS = iota
92+
Azure
93+
Unknown
94+
)
95+
96+
func (cl CloudType) String() string {
97+
switch cl {
98+
case AWS:
99+
return "AWS"
100+
case Azure:
101+
return "Azure"
102+
}
103+
104+
return "Unknown"
105+
}
106+
107+
func InferCloudFromHost(hostname string) CloudType {
108+
109+
for _, d := range databricksAzureDomains {
110+
if strings.Contains(hostname, d) {
111+
return Azure
112+
}
113+
}
114+
115+
for _, d := range databricksAWSDomains {
116+
if strings.Contains(hostname, d) {
117+
return AWS
118+
}
119+
}
120+
121+
return Unknown
122+
}

auth/oauth/pkce/pkce.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package pkce
2+
3+
import (
4+
"crypto/rand"
5+
"crypto/sha256"
6+
"encoding/base64"
7+
"encoding/hex"
8+
"fmt"
9+
"io"
10+
11+
"golang.org/x/oauth2"
12+
)
13+
14+
// Generate generates a new random PKCE code.
15+
func Generate() (Code, error) { return generate(rand.Reader) }
16+
17+
func generate(rand io.Reader) (Code, error) {
18+
// From https://tools.ietf.org/html/rfc7636#section-4.1:
19+
// code_verifier = high-entropy cryptographic random STRING using the
20+
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
21+
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters
22+
// and a maximum length of 128 characters.
23+
var buf [32]byte
24+
if _, err := io.ReadFull(rand, buf[:]); err != nil {
25+
return "", fmt.Errorf("could not generate PKCE code: %w", err)
26+
}
27+
return Code(hex.EncodeToString(buf[:])), nil
28+
}
29+
30+
// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE).
31+
type Code string
32+
33+
// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge.
34+
func (p *Code) Challenge() oauth2.AuthCodeOption {
35+
b := sha256.Sum256([]byte(*p))
36+
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:]))
37+
}
38+
39+
// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method.
40+
func (p *Code) Method() oauth2.AuthCodeOption {
41+
return oauth2.SetAuthURLParam("code_challenge_method", "S256")
42+
}
43+
44+
// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier.
45+
func (p *Code) Verifier() oauth2.AuthCodeOption {
46+
return oauth2.SetAuthURLParam("code_verifier", string(*p))
47+
}

0 commit comments

Comments
 (0)