Skip to content

Commit 9e82d18

Browse files
Add best-effort --force-refresh support for databricks-cli auth
When the SDK's cached CLI token is stale, try `databricks auth token --force-refresh` to get a freshly minted token from the IdP. If the installed CLI is too old to recognise the flag, fall back to regular `auth token` and remember the capability for future refreshes. Centralise unknown-flag detection in CliTokenSource._exec_cli_command() via UnsupportedCliFlagError so the same classifier is reused by both the legacy --profile fallback and the new --force-refresh downgrade path in DatabricksCliTokenSource. See: databricks/cli#4767
1 parent 34d6184 commit 9e82d18

7 files changed

Lines changed: 173 additions & 13 deletions

File tree

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Release v0.102.0
44

55
### New Features and Improvements
6+
* Pass `--force-refresh` to the Databricks CLI `auth token` command so the SDK always receives a freshly minted token instead of a potentially stale cached one. Falls back gracefully on older CLIs that do not support the flag.
67

78
### Security
89

databricks/sdk/credentials_provider.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import pathlib
99
import platform
10+
import re
1011
import subprocess
1112
import sys
1213
import threading
@@ -650,6 +651,8 @@ def refreshed_headers() -> Dict[str, str]:
650651

651652

652653
class CliTokenSource(oauth.Refreshable):
654+
_UNKNOWN_FLAG_RE = re.compile(r"unknown flag: (--[a-z-]+)")
655+
653656
def __init__(
654657
self,
655658
cmd: List[str],
@@ -699,6 +702,12 @@ def _exec_cli_command(self, cmd: List[str]) -> oauth.Token:
699702
message = "\n".join(filter(None, [stdout, stderr]))
700703
raise IOError(f"cannot get access token: {message}") from e
701704

705+
@staticmethod
706+
def _get_unsupported_flag(error: IOError) -> Optional[str]:
707+
"""Extract the flag name if the error is an 'unknown flag' CLI rejection."""
708+
match = CliTokenSource._UNKNOWN_FLAG_RE.search(str(error))
709+
return match.group(1) if match else None
710+
702711
def refresh(self) -> oauth.Token:
703712
try:
704713
return self._exec_cli_command(self._cmd)
@@ -900,15 +909,14 @@ def __init__(self, cfg: "Config"):
900909

901910
fallback_cmd = None
902911
if cfg.profile:
903-
# When profile is set, use --profile as the primary command.
904-
# The profile contains the full config (host, account_id, etc.).
905912
args = ["auth", "token", "--profile", cfg.profile]
906-
# Build a --host fallback for older CLIs that don't support --profile.
907913
if cfg.host:
908914
fallback_cmd = [cli_path, *self.__class__._build_host_args(cfg)]
909915
else:
910916
args = self.__class__._build_host_args(cfg)
911917

918+
self._force_cmd = [cli_path, *args, "--force-refresh"]
919+
912920
# get_scopes() defaults to ["all-apis"] when nothing is configured, which would
913921
# cause false-positive mismatches against every token that wasn't issued with
914922
# exactly ["all-apis"]. Only validate when scopes are explicitly set (either
@@ -925,13 +933,21 @@ def __init__(self, cfg: "Config"):
925933
fallback_cmd=fallback_cmd,
926934
)
927935

936+
_KNOWN_CLI_FLAGS = {"--force-refresh", "--profile"}
937+
928938
def refresh(self) -> oauth.Token:
929-
# The scope validation lives in refresh() because this is the only method that
930-
# produces new tokens (see Refreshable._token assignments). By overriding here,
931-
# every token is validated — both at initial auth and on every subsequent refresh
932-
# when the cached token expires. This catches cases where a user re-authenticates
933-
# mid-session with different scopes.
934-
token = super().refresh()
939+
try:
940+
token = self._exec_cli_command(self._force_cmd)
941+
except IOError as e:
942+
flag = self._get_unsupported_flag(e)
943+
if flag in self._KNOWN_CLI_FLAGS:
944+
logger.warning(
945+
"Databricks CLI does not support %s. " "Please upgrade your CLI to the latest version.",
946+
flag,
947+
)
948+
token = super().refresh()
949+
else:
950+
raise
935951
if self._requested_scopes:
936952
self._validate_token_scopes(token)
937953
return token

pyrefly.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
project-includes = [
2+
"**/*.py"
3+
]
4+
5+
project-excludes = []
6+
7+
search-path = []
8+
9+
disable-search-path-heuristics = true
10+
ignore-missing-imports = ["*"]
11+
ignore-errors-in-generated-code = true

tests/__init__.pyc

145 Bytes
Binary file not shown.

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
117117

118118
def test_config_deep_copy(monkeypatch, mocker, tmp_path):
119119
mocker.patch(
120-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
120+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
121121
return_value=oauth.Token(
122122
access_token="token",
123123
token_type="Bearer",

tests/test_core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def test_databricks_cli_credential_provider_installed_legacy(config, monkeypatch
166166

167167
def test_databricks_cli_credential_provider_installed_new(config, monkeypatch, tmp_path, mocker):
168168
get_mock = mocker.patch(
169-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
169+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
170170
return_value=Token(
171171
access_token="token",
172172
token_type="Bearer",
@@ -222,7 +222,7 @@ def test_databricks_cli_scope_validation(
222222
config, monkeypatch, tmp_path, mocker, token_claims, configured_scopes, auth_type, expect
223223
):
224224
mocker.patch(
225-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
225+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
226226
return_value=Token(access_token=_make_jwt(token_claims), token_type="Bearer", expiry=datetime(2023, 5, 22)),
227227
)
228228
write_large_dummy_executable(tmp_path)
@@ -244,7 +244,7 @@ def test_databricks_cli_scope_validation(
244244

245245
def test_databricks_cli_scope_validation_error_message(config, monkeypatch, tmp_path, mocker):
246246
mocker.patch(
247-
"databricks.sdk.credentials_provider.CliTokenSource.refresh",
247+
"databricks.sdk.credentials_provider.CliTokenSource._exec_cli_command",
248248
return_value=Token(
249249
access_token=_make_jwt({"scope": "all-apis"}), token_type="Bearer", expiry=datetime(2023, 5, 22)
250250
),

tests/test_credentials_provider.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,138 @@ def test_no_fallback_when_fallback_cmd_not_set(self, mocker):
471471
assert mock_run.call_count == 1
472472

473473

474+
class TestDatabricksCliForceRefresh:
475+
"""Tests for --force-refresh support in DatabricksCliTokenSource."""
476+
477+
@staticmethod
478+
def _make_process_error(stderr: str, stdout: str = ""):
479+
import subprocess
480+
481+
err = subprocess.CalledProcessError(1, ["databricks"])
482+
err.stdout = stdout.encode()
483+
err.stderr = stderr.encode()
484+
return err
485+
486+
@staticmethod
487+
def _make_token_source(
488+
*,
489+
profile=None,
490+
host="https://workspace.databricks.com",
491+
cli_path="/path/to/databricks",
492+
):
493+
"""Build a DatabricksCliTokenSource by mocking only the executable check."""
494+
mock_cfg = Mock()
495+
mock_cfg.profile = profile
496+
mock_cfg.host = host
497+
mock_cfg.databricks_cli_path = cli_path
498+
mock_cfg.disable_async_token_refresh = True
499+
mock_cfg.scopes = None
500+
mock_cfg.get_scopes = Mock(return_value=["all-apis"])
501+
mock_cfg.client_type = ClientType.WORKSPACE
502+
mock_cfg.account_id = None
503+
return credentials_provider.DatabricksCliTokenSource(mock_cfg)
504+
505+
def _valid_response_json(self, access_token="fresh-token"):
506+
import json
507+
508+
expiry = (datetime.now() + timedelta(hours=1)).strftime("%Y-%m-%dT%H:%M:%S")
509+
return json.dumps({"access_token": access_token, "token_type": "Bearer", "expiry": expiry})
510+
511+
def test_force_refresh_always_tried_first(self, mocker):
512+
"""refresh() always tries --force-refresh first."""
513+
ts = self._make_token_source()
514+
515+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
516+
mock_run.return_value = Mock(stdout=self._valid_response_json("refreshed").encode())
517+
518+
token = ts.refresh()
519+
assert token.access_token == "refreshed"
520+
assert mock_run.call_count == 1
521+
522+
cmd = mock_run.call_args[0][0]
523+
assert "--force-refresh" in cmd
524+
525+
def test_force_refresh_fallback_when_unsupported(self, mocker):
526+
"""Old CLI without --force-refresh: falls back to cmd without --force-refresh."""
527+
ts = self._make_token_source()
528+
529+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
530+
mock_run.side_effect = [
531+
self._make_process_error("Error: unknown flag: --force-refresh"),
532+
Mock(stdout=self._valid_response_json("fallback").encode()),
533+
]
534+
535+
token = ts.refresh()
536+
assert token.access_token == "fallback"
537+
assert mock_run.call_count == 2
538+
539+
first_cmd = mock_run.call_args_list[0][0][0]
540+
second_cmd = mock_run.call_args_list[1][0][0]
541+
assert "--force-refresh" in first_cmd
542+
assert "--force-refresh" not in second_cmd
543+
544+
def test_profile_fallback_when_unsupported(self, mocker):
545+
"""Old CLI without --profile: force_cmd fails, fallback retries with --host."""
546+
ts = self._make_token_source(profile="my-profile")
547+
548+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
549+
mock_run.side_effect = [
550+
# force_cmd: --profile + --force-refresh → unknown --profile
551+
self._make_process_error("Error: unknown flag: --profile"),
552+
# _refresh_without_force cmd: --profile → unknown --profile
553+
self._make_process_error("Error: unknown flag: --profile"),
554+
# _refresh_without_force fallback_cmd: --host → success
555+
Mock(stdout=self._valid_response_json("host-token").encode()),
556+
]
557+
558+
token = ts.refresh()
559+
assert token.access_token == "host-token"
560+
assert mock_run.call_count == 3
561+
assert "--host" in mock_run.call_args_list[2][0][0]
562+
563+
def test_two_step_downgrade_both_flags_unsupported(self, mocker):
564+
"""CLI supports neither --force-refresh nor --profile: force_cmd fails, then full fallback."""
565+
ts = self._make_token_source(profile="my-profile")
566+
567+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
568+
mock_run.side_effect = [
569+
# 1st: force_cmd (--profile + --force-refresh) → unknown --force-refresh
570+
self._make_process_error("Error: unknown flag: --force-refresh"),
571+
# 2nd: _refresh_without_force cmd (--profile) → unknown --profile
572+
self._make_process_error("Error: unknown flag: --profile"),
573+
# 3rd: _refresh_without_force fallback_cmd (--host) → success
574+
Mock(stdout=self._valid_response_json("plain").encode()),
575+
]
576+
577+
token = ts.refresh()
578+
assert token.access_token == "plain"
579+
assert mock_run.call_count == 3
580+
581+
cmds = [call[0][0] for call in mock_run.call_args_list]
582+
assert "--force-refresh" in cmds[0] and "--profile" in cmds[0]
583+
assert "--force-refresh" not in cmds[1] and "--profile" in cmds[1]
584+
assert "--host" in cmds[2]
585+
586+
def test_real_auth_error_does_not_trigger_fallback(self, mocker):
587+
"""Real auth failures (not unknown-flag) surface immediately."""
588+
ts = self._make_token_source()
589+
590+
mock_run = mocker.patch("databricks.sdk.credentials_provider._run_subprocess")
591+
mock_run.side_effect = self._make_process_error("cache: databricks OAuth is not configured for this host")
592+
593+
with pytest.raises(IOError) as exc_info:
594+
ts.refresh()
595+
assert "databricks OAuth is not configured" in str(exc_info.value)
596+
assert mock_run.call_count == 1
597+
598+
def test_get_unsupported_flag_extracts_flag(self):
599+
"""The classifier correctly parses the flag name from CLI error output."""
600+
get = credentials_provider.CliTokenSource._get_unsupported_flag
601+
assert get(IOError("Error: unknown flag: --force-refresh")) == "--force-refresh"
602+
assert get(IOError("Error: unknown flag: --profile")) == "--profile"
603+
assert get(IOError("some other error")) is None
604+
605+
474606
# Tests for cloud-agnostic hosts and removed cloud checks
475607
class TestCloudAgnosticHosts:
476608
"""Tests that credential providers work with cloud-agnostic hosts after removing is_azure/is_gcp checks."""

0 commit comments

Comments
 (0)