Skip to content

Commit 89a0c7a

Browse files
rustyconoverclaude
andcommitted
Support multiple JWT issuers in jwt_authenticate
Accept a tuple of issuer strings for multi-tenant setups (e.g. Microsoft Entra with multiple tenants). A token matching any issuer is accepted. First issuer used for OIDC discovery when jwks_uri is not provided. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 73343e6 commit 89a0c7a

3 files changed

Lines changed: 78 additions & 9 deletions

File tree

tests/test_oauth.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,13 @@ def _mint_jwt(
8585
def _make_local_auth(
8686
public_key: dict[str, object],
8787
*,
88-
issuer: str = "https://auth.example.com",
88+
issuer: str | tuple[str, ...] = "https://auth.example.com",
8989
audience: str | tuple[str, ...] = "https://api.example.com/vgi",
9090
principal_claim: str = "sub",
9191
domain: str = "jwt",
9292
) -> Callable[[falcon.Request], AuthContext]:
9393
"""Create a local JWT authenticate callback (no JWKS endpoint needed)."""
94+
issuers = (issuer,) if isinstance(issuer, str) else issuer
9495
audiences = (audience,) if isinstance(audience, str) else audience
9596

9697
def authenticate(req: falcon.Request) -> AuthContext:
@@ -103,7 +104,7 @@ def authenticate(req: falcon.Request) -> AuthContext:
103104
raw_token,
104105
public_key,
105106
claims_options={
106-
"iss": {"essential": True, "value": issuer},
107+
"iss": {"essential": True, "values": list(issuers)},
107108
"aud": {"essential": True, "values": list(audiences)},
108109
},
109110
)
@@ -816,6 +817,37 @@ def test_wrong_issuer_raises_value_error(self) -> None:
816817
with pytest.raises(ValueError):
817818
auth_fn(req)
818819

820+
def test_multiple_issuers_first_matches(self) -> None:
821+
"""A JWT matching the first of multiple issuers is accepted."""
822+
priv, pub = _make_rsa_key()
823+
iss1 = "https://auth.example.com"
824+
iss2 = "https://auth2.example.com"
825+
token = _mint_jwt(priv, iss=iss1)
826+
auth_fn = _make_local_auth(pub, issuer=(iss1, iss2))
827+
req = falcon.testing.helpers.create_req(headers={"Authorization": f"Bearer {token}"})
828+
auth = auth_fn(req)
829+
assert auth.authenticated is True
830+
831+
def test_multiple_issuers_second_matches(self) -> None:
832+
"""A JWT matching the second of multiple issuers is accepted."""
833+
priv, pub = _make_rsa_key()
834+
iss1 = "https://auth.example.com"
835+
iss2 = "https://auth2.example.com"
836+
token = _mint_jwt(priv, iss=iss2)
837+
auth_fn = _make_local_auth(pub, issuer=(iss1, iss2))
838+
req = falcon.testing.helpers.create_req(headers={"Authorization": f"Bearer {token}"})
839+
auth = auth_fn(req)
840+
assert auth.authenticated is True
841+
842+
def test_multiple_issuers_none_match(self) -> None:
843+
"""A JWT with an unrecognized issuer is rejected."""
844+
priv, pub = _make_rsa_key()
845+
token = _mint_jwt(priv, iss="https://wrong.example.com")
846+
auth_fn = _make_local_auth(pub, issuer=("https://a.example.com", "https://b.example.com"))
847+
req = falcon.testing.helpers.create_req(headers={"Authorization": f"Bearer {token}"})
848+
with pytest.raises(ValueError):
849+
auth_fn(req)
850+
819851
def test_missing_bearer_raises_value_error(self) -> None:
820852
"""Missing Authorization header raises ValueError."""
821853
_priv, pub = _make_rsa_key()
@@ -909,3 +941,30 @@ def test_empty_audience_tuple_raises(self) -> None:
909941
audience=(),
910942
jwks_uri="https://auth.example.com/.well-known/jwks.json",
911943
)
944+
945+
def test_empty_issuer_tuple_raises(self) -> None:
946+
"""Passing an empty issuer tuple raises ValueError eagerly."""
947+
with pytest.raises(ValueError, match="issuer must not be empty"):
948+
jwt_authenticate(
949+
issuer=(),
950+
audience="https://api.example.com/vgi",
951+
jwks_uri="https://auth.example.com/.well-known/jwks.json",
952+
)
953+
954+
def test_jwt_authenticate_factory_multiple_issuers(self) -> None:
955+
"""jwt_authenticate() accepts a tuple of issuers."""
956+
auth_fn = jwt_authenticate(
957+
issuer=("https://auth.example.com", "https://auth2.example.com"),
958+
audience="https://api.example.com/vgi",
959+
jwks_uri="https://auth.example.com/.well-known/jwks.json",
960+
)
961+
assert callable(auth_fn)
962+
963+
def test_single_issuer_string_still_works(self) -> None:
964+
"""Passing issuer as a plain string still works (backwards compat)."""
965+
priv, pub = _make_rsa_key()
966+
token = _mint_jwt(priv)
967+
auth_fn = _make_local_auth(pub, issuer="https://auth.example.com")
968+
req = falcon.testing.helpers.create_req(headers={"Authorization": f"Bearer {token}"})
969+
auth = auth_fn(req)
970+
assert auth.authenticated is True

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vgi_rpc/http/_oauth_jwt.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
def jwt_authenticate(
3535
*,
36-
issuer: str,
36+
issuer: str | tuple[str, ...],
3737
audience: str | tuple[str, ...],
3838
jwks_uri: str | None = None,
3939
claims_options: Mapping[str, Any] | None = None,
@@ -47,7 +47,11 @@ def jwt_authenticate(
4747
Keys are cached in-process with automatic refresh on unknown ``kid``.
4848
4949
Args:
50-
issuer: Expected ``iss`` claim in the JWT.
50+
issuer: Expected ``iss`` claim(s) in the JWT. A single string
51+
or a tuple of strings. When multiple issuers are given, a
52+
token matching **any** of them is accepted. When multiple
53+
issuers are provided and ``jwks_uri`` is not set, the first
54+
issuer is used for OIDC discovery.
5155
audience: Expected ``aud`` claim(s) in the JWT. A single string
5256
or a tuple of strings. When multiple audiences are given, a
5357
token matching **any** of them is accepted.
@@ -71,6 +75,12 @@ def jwt_authenticate(
7175
# Authlib's KeySet type is not exported in public stubs; alias to contain Any.
7276
type _KeySet = Any
7377

78+
issuers = (issuer,) if isinstance(issuer, str) else issuer
79+
if not issuers:
80+
raise ValueError("issuer must not be empty")
81+
# Use the first issuer for OIDC discovery when jwks_uri is not provided.
82+
discovery_issuer = issuers[0]
83+
7484
resolved_jwks_uri = jwks_uri
7585
lock = threading.Lock()
7686
key_set: _KeySet = None
@@ -79,7 +89,7 @@ def _fetch_jwks() -> _KeySet:
7989
nonlocal resolved_jwks_uri, key_set
8090
if resolved_jwks_uri is None:
8191
with httpx.Client() as client:
82-
oidc_resp = client.get(f"{issuer.rstrip('/')}/.well-known/openid-configuration")
92+
oidc_resp = client.get(f"{discovery_issuer.rstrip('/')}/.well-known/openid-configuration")
8393
oidc_resp.raise_for_status()
8494
resolved_jwks_uri = oidc_resp.json()["jwks_uri"]
8595

@@ -102,7 +112,7 @@ def _get_key_set(force_refresh: bool = False) -> _KeySet:
102112
raise ValueError("audience must not be empty")
103113

104114
base_claims_options: dict[str, Any] = {
105-
"iss": {"essential": True, "value": issuer},
115+
"iss": {"essential": True, "values": list(issuers)},
106116
"aud": {"essential": True, "values": list(audiences)},
107117
}
108118
if claims_options:
@@ -117,12 +127,12 @@ def _log_claim_mismatch(token: str, keys: _KeySet, exc: JoseError) -> None:
117127
token_iss = raw_claims.get("iss")
118128
logger.warning(
119129
"JWT validation failed: %s\n"
120-
" expected iss: %s\n"
130+
" expected iss (any of): %s\n"
121131
" token iss: %s\n"
122132
" expected aud (any of): %s\n"
123133
" token aud: %s",
124134
exc,
125-
issuer,
135+
list(issuers),
126136
token_iss,
127137
list(audiences),
128138
token_aud,

0 commit comments

Comments
 (0)