Skip to content

Commit b63e445

Browse files
Initial oauth implementation
Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
1 parent 7e079fd commit b63e445

11 files changed

Lines changed: 884 additions & 4 deletions

File tree

auth/oauth/dev/dev.go

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
package dev
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"encoding/base64"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"net"
11+
"net/http"
12+
"net/url"
13+
"os"
14+
"os/exec"
15+
"os/signal"
16+
"runtime"
17+
"strings"
18+
"sync"
19+
"time"
20+
21+
"github.com/databricks/databricks-sql-go/auth"
22+
"github.com/databricks/databricks-sql-go/auth/oauth/u2m"
23+
"github.com/rs/zerolog/log"
24+
"golang.org/x/oauth2"
25+
)
26+
27+
const (
28+
LISTEN_ADDR = "localhost:8020"
29+
)
30+
31+
func NewDevAuthenticator(clientID, hostName string, scopes ...string) auth.Authenticator {
32+
return &devAuthenticator{
33+
clientID: clientID,
34+
hostName: hostName,
35+
scopes: scopes,
36+
}
37+
}
38+
39+
type devAuthenticator struct {
40+
clientID string
41+
hostName string
42+
scopes []string
43+
tokenSource oauth2.TokenSource
44+
mx sync.Mutex
45+
}
46+
47+
// Auth will start the OAuth Authorization Flow to authenticate the cli client
48+
// using the users credentials in the browser. Compatible with SSO.
49+
func (c *devAuthenticator) Authenticate(r *http.Request) error {
50+
c.mx.Lock()
51+
defer c.mx.Unlock()
52+
if c.tokenSource != nil {
53+
token, err := c.tokenSource.Token()
54+
if err == nil {
55+
token.SetAuthHeader(r)
56+
return nil
57+
} else if !strings.Contains(err.Error(), "invalid_grant") {
58+
return err
59+
}
60+
61+
token.SetAuthHeader(r)
62+
return nil
63+
}
64+
65+
config, err := u2m.GetConfig(context.Background(), c.hostName, c.clientID, "", LISTEN_ADDR, c.scopes)
66+
if err != nil {
67+
return fmt.Errorf("unable to generate oauth2.Config: %w", err)
68+
}
69+
70+
tokenSource, err := GetTokenSource(context.Background(), config, 0)
71+
if err != nil {
72+
return fmt.Errorf("unable to get token source: %w", err)
73+
}
74+
75+
c.tokenSource = tokenSource
76+
77+
token, err := tokenSource.Token()
78+
if err != nil {
79+
return fmt.Errorf("unable to get token source: %w", err)
80+
}
81+
82+
token.SetAuthHeader(r)
83+
84+
return nil
85+
}
86+
87+
type authResponse struct {
88+
err string
89+
details string
90+
state string
91+
code string
92+
}
93+
94+
func GetTokenSource(ctx context.Context, config oauth2.Config, timeout time.Duration) (oauth2.TokenSource, error) {
95+
if timeout == 0 {
96+
timeout = 2 * time.Minute
97+
}
98+
99+
state, err := randString(16)
100+
if err != nil {
101+
err = fmt.Errorf("unable to generate random number: %w", err)
102+
return nil, err
103+
}
104+
105+
challenge, challengeMethod, verifier, err := u2m.GetAuthCodeOptions()
106+
if err != nil {
107+
return nil, err
108+
}
109+
110+
loginURL := u2m.GetLoginURL(config, state, challenge, challengeMethod)
111+
112+
// handle ctrl-c while waiting for the callback
113+
sigintCh := make(chan os.Signal, 1)
114+
signal.Notify(sigintCh, os.Interrupt)
115+
// receive auth callback response
116+
authDoneCh := make(chan authResponse)
117+
118+
u, _ := url.Parse(config.RedirectURL)
119+
if u.Path == "" {
120+
u.Path = "/"
121+
}
122+
123+
http.HandleFunc(u.Path, handlerFunc(authDoneCh, state))
124+
125+
log.Info().Msgf("listening on %s://%s/", u.Scheme, u.Host)
126+
listener, err := net.Listen("tcp", u.Host)
127+
if err != nil {
128+
return nil, err
129+
}
130+
defer listener.Close()
131+
132+
srv := &http.Server{
133+
ReadHeaderTimeout: 3 * time.Second,
134+
WriteTimeout: 30 * time.Second,
135+
}
136+
137+
defer srv.Close()
138+
139+
// Start local server to wait for callback
140+
go func() {
141+
err := srv.Serve(listener)
142+
143+
// in case port is in use
144+
if err != nil && err != http.ErrServerClosed {
145+
authDoneCh <- authResponse{err: err.Error()}
146+
}
147+
}()
148+
149+
fmt.Printf("\nOpen URL in Browser to Continue: %s\n\n", loginURL)
150+
err = openbrowser(loginURL)
151+
if err != nil {
152+
fmt.Println("Unable to open browser automatically. Please open manually: ", loginURL)
153+
}
154+
155+
// Wait for callback to be received, Wait for either the callback to finish, SIGINT to be received or up to 2 minutes
156+
select {
157+
case authResponse := <-authDoneCh:
158+
if authResponse.err != "" {
159+
return nil, fmt.Errorf("identity provider error: %s: %s", authResponse.err, authResponse.details)
160+
}
161+
token, err := config.Exchange(ctx, authResponse.code, verifier)
162+
if err != nil {
163+
return nil, fmt.Errorf("failed to exchange token: %w", err)
164+
}
165+
166+
return config.TokenSource(ctx, token), nil
167+
168+
case <-sigintCh:
169+
return nil, errors.New("interrupted while waiting for auth callback")
170+
171+
case <-time.After(timeout):
172+
return nil, errors.New("timed out waiting for response from provider")
173+
}
174+
}
175+
176+
func handlerFunc(authDoneCh chan authResponse, state string) func(http.ResponseWriter, *http.Request) {
177+
return func(w http.ResponseWriter, r *http.Request) {
178+
resp := authResponse{
179+
err: r.URL.Query().Get("error"),
180+
details: r.URL.Query().Get("error_description"),
181+
state: r.URL.Query().Get("state"),
182+
code: r.URL.Query().Get("code"),
183+
}
184+
185+
// Send the response back to the to cli
186+
defer func() { authDoneCh <- resp }()
187+
188+
// Do some checking of the response here to show more relevant content
189+
if resp.err != "" {
190+
w.WriteHeader(http.StatusBadRequest)
191+
_, err := w.Write([]byte(errorHTML("Identity Provider returned an error: " + resp.err)))
192+
if err != nil {
193+
log.Error().Err(err).Msg("unable to write error response")
194+
}
195+
return
196+
}
197+
if resp.state != state {
198+
w.WriteHeader(http.StatusBadRequest)
199+
_, err := w.Write([]byte(errorHTML("Authentication state received did not match original request. Please try to login again.")))
200+
if err != nil {
201+
log.Error().Err(err).Msg("unable to write error response")
202+
}
203+
return
204+
}
205+
206+
_, err := w.Write([]byte(infoHTML("CLI Login Success", "You may close this window anytime now and go back to terminal")))
207+
if err != nil {
208+
log.Error().Err(err).Msg("unable to write success response")
209+
}
210+
}
211+
}
212+
213+
func randString(nByte int) (string, error) {
214+
b := make([]byte, nByte)
215+
if _, err := io.ReadFull(rand.Reader, b); err != nil {
216+
return "", err
217+
}
218+
return base64.RawURLEncoding.EncodeToString(b), nil
219+
}
220+
221+
func openbrowser(url string) error {
222+
var err error
223+
224+
switch runtime.GOOS {
225+
case "linux":
226+
err = exec.Command("xdg-open", url).Start()
227+
case "windows":
228+
err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
229+
case "darwin":
230+
err = exec.Command("open", url).Start()
231+
default:
232+
err = fmt.Errorf("unsupported platform")
233+
}
234+
return err
235+
}

auth/oauth/dev/html_template.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dev
2+
3+
import (
4+
"bytes"
5+
_ "embed"
6+
"html/template"
7+
)
8+
9+
type SimplePage struct {
10+
Title string
11+
Heading string
12+
Content string
13+
Action ActionLink
14+
Code string
15+
}
16+
17+
type ActionLink struct {
18+
Label string
19+
Link string
20+
}
21+
22+
var (
23+
//go:embed templates/simple.html
24+
simpleHtmlPage string
25+
)
26+
27+
func renderHTML(data SimplePage) (string, error) {
28+
var out bytes.Buffer
29+
tmpl, err := template.New("name").Parse(simpleHtmlPage)
30+
if err != nil {
31+
return "", err
32+
}
33+
err = tmpl.Execute(&out, data)
34+
return out.String(), err
35+
}
36+
37+
func infoHTML(title, content string) string {
38+
data := SimplePage{
39+
Title: "Authentication Success",
40+
Heading: title,
41+
Content: content,
42+
}
43+
out, _ := renderHTML(data)
44+
return out
45+
}
46+
47+
func errorHTML(msg string) string {
48+
data := SimplePage{
49+
Title: "Authentication Error",
50+
Heading: "Ooops!",
51+
Content: "Sorry, Databricks could not authenticate to your account due to some server errors. Please try it later.",
52+
Code: msg,
53+
}
54+
out, _ := renderHTML(data)
55+
return out
56+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
<!DOCTYPE html SYSTEM "http://www.thymeleaf.org/dtd/xhtml1-strict-thymeleaf-4.dtd">
2+
3+
<html xmlns="http://www.w3.org/1999/xhtml" xmlns:th="http://www.thymeleaf.org">
4+
5+
<head>
6+
<title>{{ .Title }}</title>
7+
<link rel="preconnect" href="https://fonts.gstatic.com" />
8+
<link href="https://fonts.googleapis.com/css2?family=IBM+Plex+Sans:ital,wght@0,400;0,700;1,400&display=swap"
9+
rel="stylesheet" />
10+
11+
<style>
12+
html,
13+
body {
14+
height: 100%;
15+
}
16+
17+
body {
18+
font-family: "IBM Plex Sans";
19+
font-style: normal;
20+
font-size: 14px;
21+
margin: 0;
22+
padding: 0;
23+
height: 100%;
24+
width: 100%;
25+
background: #f5f6f6;
26+
align-items: center;
27+
}
28+
29+
.root-container {
30+
display: flex;
31+
height: 100%;
32+
align-items: center;
33+
justify-content: center;
34+
}
35+
36+
.info-container {
37+
width: 320px;
38+
box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1),
39+
0px 8px 25px rgba(0, 0, 0, 0.1);
40+
border-radius: 8px;
41+
display: flex;
42+
flex-direction: column;
43+
padding: 48px;
44+
background: #fff;
45+
justify-content: center;
46+
align-items: center;
47+
text-align: center;
48+
gap: 24px;
49+
}
50+
51+
.logo {
52+
display: "block";
53+
max-width: 140px;
54+
max-height: 40px;
55+
}
56+
57+
.title {
58+
font-weight: 600;
59+
font-size: 24px;
60+
line-height: 28px;
61+
}
62+
63+
.content {
64+
width: 300px;
65+
font-size: 14px;
66+
}
67+
68+
.button {
69+
display: flex;
70+
background: #191519;
71+
align-items: center;
72+
justify-content: center;
73+
height: 40px;
74+
width: 300px;
75+
border-radius: 4px;
76+
text-align: center;
77+
text-decoration: none;
78+
color: #ffffff !important;
79+
}
80+
</style>
81+
</head>
82+
83+
<body>
84+
<div class="root-container">
85+
<div class="info-container">
86+
<img class="logo"
87+
src="https://www.databricks.com/wp-content/uploads/2022/06/db-nav-logo-stacked-white-desktop.svg" />
88+
<div class="title">{{ .Heading }}</div>
89+
<div class="content">{{ .Content }}</div>
90+
<!-- {{ if .Action.Link }} -->
91+
<a class="button" target="_blank" href="{{ .Action.Link }}">{{ .Action.Label }}</a>
92+
<!--{{ end }} -->
93+
<!-- {{ if .Code }} -->
94+
<code>{{ .Code }}</code>
95+
<!--{{ end }} -->
96+
</div>
97+
</div>
98+
</body>
99+
100+
</html>

0 commit comments

Comments
 (0)