Skip to content

Commit ab1533d

Browse files
committed
feat(platform): support external token providers and simplify caching
Add a `token_provider` parameter to `Client.__init__()` so callers can supply their own bearer-token callable (e.g. for M2M / service-account flows) instead of relying on the built-in OAuth device-code flow. The operation cache needs a token to build per-user cache keys. Three approaches were explored: 1. **CachedApiMixin with Protocol** (explored, discarded) — A `_HasTokenProvider` Protocol plus `CachedApiMixin` base class that provided a `_cached()` convenience method. Saved one kwarg per call site but added multiple-inheritance, a Protocol, and competing `_api` annotations on every resource class. Over-engineering for what amounts to avoiding `token_provider=self._api.token_provider`. 2. **`TokenProvider` type alias** (explored, discarded) — `TokenProvider = Callable[[], str]` exported as public API. Added no value over the raw `Callable[[], str]` since every parameter is already named `token_provider`. Removed to avoid unnecessary imports and indirection. 3. **`_AuthenticatedApi` subclass + explicit `token_provider` at each call site** (chosen) — A thin `_AuthenticatedApi(PublicApi)` subclass in `_api.py` lifts `token_provider` from the deeply-nested codegen `Configuration` to a top-level attribute. Each cached method passes `token_provider=self._api.token_provider` to `@cached_operation(...)`. One kwarg of boilerplate per call site, but no new types, no inheritance, and trivially greppable. `_api.py` exists solely to break the circular import between `_client.py` and the resource modules.
1 parent ba74ebe commit ab1533d

11 files changed

Lines changed: 308 additions & 132 deletions

File tree

src/aignostics/platform/_api.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""Authenticated API wrapper and configuration.
2+
3+
This module defines the thin API subclass and configuration that lift
4+
``token_provider`` to a first-class attribute. Kept separate from ``_client``
5+
so that resource modules can import these types directly without circular
6+
dependencies.
7+
"""
8+
9+
from collections.abc import Callable
10+
11+
from aignx.codegen.api.public_api import PublicApi
12+
from aignx.codegen.api_client import ApiClient
13+
from aignx.codegen.configuration import AuthSettings, Configuration
14+
15+
16+
class _OAuth2TokenProviderConfiguration(Configuration):
17+
"""Overwrites the original Configuration to call a function to obtain a refresh token.
18+
19+
The base class does not support callbacks. This is necessary for integrations where
20+
tokens may expire or need to be refreshed automatically.
21+
"""
22+
23+
def __init__(
24+
self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None
25+
) -> None:
26+
super().__init__(host=host, ssl_ca_cert=ssl_ca_cert)
27+
self.token_provider = token_provider
28+
29+
def auth_settings(self) -> AuthSettings:
30+
token = self.token_provider() if self.token_provider else None
31+
if not token:
32+
return {}
33+
return {
34+
"OAuth2AuthorizationCodeBearer": {
35+
"type": "oauth2",
36+
"in": "header",
37+
"key": "Authorization",
38+
"value": f"Bearer {token}",
39+
}
40+
}
41+
42+
43+
class _AuthenticatedApi(PublicApi):
44+
"""Thin wrapper around the generated :class:`PublicApi`.
45+
46+
Lifts ``token_provider`` from the deeply-nested ``Configuration`` to a
47+
top-level attribute, making it accessible without traversing codegen internals.
48+
"""
49+
50+
token_provider: Callable[[], str] | None
51+
52+
def __init__(self, api_client: ApiClient, token_provider: Callable[[], str] | None = None) -> None:
53+
super().__init__(api_client)
54+
self.token_provider = token_provider

src/aignostics/platform/_client.py

Lines changed: 53 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
from urllib.request import getproxies
55

66
import semver
7-
from aignx.codegen.api.public_api import PublicApi
87
from aignx.codegen.api_client import ApiClient
9-
from aignx.codegen.configuration import AuthSettings, Configuration
108
from aignx.codegen.exceptions import NotFoundException, ServiceException
119
from aignx.codegen.models import ApplicationReadResponse as Application
1210
from aignx.codegen.models import MeReadResponse as Me
@@ -22,6 +20,7 @@
2220
from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError
2321
from urllib3.exceptions import TimeoutError as Urllib3TimeoutError
2422

23+
from aignostics.platform._api import _AuthenticatedApi, _OAuth2TokenProviderConfiguration
2524
from aignostics.platform._authentication import get_token
2625
from aignostics.platform._operation_cache import cached_operation
2726
from aignostics.platform.resources.applications import Applications, Versions
@@ -59,34 +58,6 @@ def _log_retry_attempt(retry_state: RetryCallState) -> None:
5958
)
6059

6160

62-
class _OAuth2TokenProviderConfiguration(Configuration):
63-
"""
64-
Overwrites the original Configuration to call a function to obtain a refresh token.
65-
66-
The base class does not support callbacks. This is necessary for integrations where
67-
tokens may expire or need to be refreshed automatically.
68-
"""
69-
70-
def __init__(
71-
self, host: str, ssl_ca_cert: str | None = None, token_provider: Callable[[], str] | None = None
72-
) -> None:
73-
super().__init__(host=host, ssl_ca_cert=ssl_ca_cert)
74-
self.token_provider = token_provider
75-
76-
def auth_settings(self) -> AuthSettings:
77-
token = self.token_provider() if self.token_provider else None
78-
if not token:
79-
return {}
80-
return {
81-
"OAuth2AuthorizationCodeBearer": {
82-
"type": "oauth2",
83-
"in": "header",
84-
"key": "Authorization",
85-
"value": f"Bearer {token}",
86-
}
87-
}
88-
89-
9061
class Client:
9162
"""Main client for interacting with the Aignostics Platform API.
9263
@@ -96,25 +67,41 @@ class Client:
9667
- Caches operation results for specific operations.
9768
"""
9869

99-
_api_client_cached: ClassVar[PublicApi | None] = None
100-
_api_client_uncached: ClassVar[PublicApi | None] = None
70+
_api_client_cached: ClassVar[_AuthenticatedApi | None] = None
71+
_api_client_uncached: ClassVar[_AuthenticatedApi | None] = None
72+
_api_client_external: ClassVar[dict[int, _AuthenticatedApi]] = {}
10173

74+
_api: _AuthenticatedApi
10275
applications: Applications
10376
versions: Versions
10477
runs: Runs
10578

106-
def __init__(self, cache_token: bool = True) -> None:
79+
def __init__(self, cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> None:
10780
"""Initializes a client instance with authenticated API access.
10881
10982
Args:
110-
cache_token (bool): If True, caches the authentication token.
111-
Defaults to True.
83+
cache_token: If True, caches the authentication token. Defaults to True.
84+
token_provider: Optional external token provider callable. When provided,
85+
bypasses internal OAuth authentication entirely. The callable must
86+
return a valid bearer token string.
87+
88+
Raises:
89+
ValueError: If both ``token_provider`` and ``cache_token=False`` are specified,
90+
since token caching is irrelevant when using an external provider.
11291
11392
Sets up resource accessors for applications, versions, and runs.
11493
"""
94+
if token_provider is not None and not cache_token:
95+
msg = (
96+
"Cannot set cache_token=False with an external token_provider. "
97+
"Token caching is managed internally when using the default OAuth flow. "
98+
"When providing an external token_provider, omit cache_token or use the default."
99+
)
100+
raise ValueError(msg)
101+
115102
try:
116-
logger.trace("Initializing client with cache_token={}", cache_token)
117-
self._api = Client.get_api_client(cache_token=cache_token)
103+
logger.trace("Initializing client with cache_token={}, token_provider={}", cache_token, token_provider)
104+
self._api = Client.get_api_client(cache_token=cache_token, token_provider=token_provider)
118105
self.applications: Applications = Applications(self._api)
119106
self.runs: Runs = Runs(self._api)
120107
self.versions: Versions = Versions(self._api)
@@ -143,7 +130,7 @@ def me(self, nocache: bool = False) -> Me:
143130
aignx.codegen.exceptions.ApiException: If the API call fails.
144131
"""
145132

146-
@cached_operation(ttl=settings().me_cache_ttl, use_token=True)
133+
@cached_operation(ttl=settings().me_cache_ttl, token_provider=self._api.token_provider)
147134
def me_with_retry() -> Me:
148135
return Retrying( # We are not using Tenacity annotations as settings can change at runtime
149136
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -177,7 +164,7 @@ def application(self, application_id: str, nocache: bool = False) -> Application
177164
Application: The application object.
178165
"""
179166

180-
@cached_operation(ttl=settings().application_cache_ttl, use_token=True)
167+
@cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider)
181168
def application_with_retry(application_id: str) -> Application:
182169
return Retrying(
183170
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -234,7 +221,7 @@ def application_version(
234221
raise ValueError(message)
235222

236223
# Make the API call with retry logic and caching
237-
@cached_operation(ttl=settings().application_version_cache_ttl, use_token=True)
224+
@cached_operation(ttl=settings().application_version_cache_ttl, token_provider=self._api.token_provider)
238225
def application_version_with_retry(application_id: str, version: str) -> ApplicationVersion:
239226
return Retrying(
240227
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -268,44 +255,53 @@ def run(self, run_id: str) -> Run:
268255
return Run(self._api, run_id)
269256

270257
@staticmethod
271-
def get_api_client(cache_token: bool = True) -> PublicApi:
258+
def get_api_client(cache_token: bool = True, token_provider: Callable[[], str] | None = None) -> _AuthenticatedApi:
272259
"""Create and configure an authenticated API client.
273260
274261
API client instances are shared across all Client instances for efficient connection reuse.
275-
Two separate instances are maintained: one for cached tokens and one for uncached tokens.
262+
Three pools are maintained: cached-token, uncached-token, and external-provider (keyed by
263+
provider identity).
276264
277265
Args:
278-
cache_token (bool): If True, caches the authentication token.
279-
Defaults to True.
266+
cache_token: If True, caches the authentication token. Defaults to True.
267+
token_provider: Optional external token provider. When provided, bypasses
268+
internal OAuth and uses this callable to obtain bearer tokens.
280269
281270
Returns:
282-
PublicApi: Configured API client with authentication token.
271+
_AuthenticatedApi: Configured API client with authentication token.
283272
284273
Raises:
285274
RuntimeError: If authentication fails.
286275
"""
287-
# Return cached instance if available
288-
if cache_token and Client._api_client_cached is not None:
276+
# Check singleton caches first
277+
if token_provider is not None:
278+
provider_key = id(token_provider)
279+
if provider_key in Client._api_client_external:
280+
return Client._api_client_external[provider_key]
281+
elif cache_token and Client._api_client_cached is not None:
289282
return Client._api_client_cached
290-
if not cache_token and Client._api_client_uncached is not None:
283+
elif not cache_token and Client._api_client_uncached is not None:
291284
return Client._api_client_uncached
292285

293-
def token_provider() -> str:
294-
return get_token(use_cache=cache_token)
286+
# Resolve the effective token provider
287+
effective_provider: Callable[[], str] = (
288+
token_provider if token_provider is not None else (lambda: get_token(use_cache=cache_token))
289+
)
295290

291+
# Build the API client
296292
ca_file = os.getenv("REQUESTS_CA_BUNDLE") # point to .cer file of proxy if defined
297293
config = _OAuth2TokenProviderConfiguration(
298-
host=settings().api_root, ssl_ca_cert=ca_file, token_provider=token_provider
294+
host=settings().api_root, ssl_ca_cert=ca_file, token_provider=effective_provider
299295
)
300296
config.proxy = getproxies().get("https") # use system proxy
301-
client = ApiClient(
302-
config,
303-
)
297+
client = ApiClient(config)
304298
client.user_agent = user_agent()
305-
api_client = PublicApi(client)
299+
api_client = _AuthenticatedApi(client, effective_provider)
306300

307-
# Cache the instance
308-
if cache_token:
301+
# Store in the appropriate singleton cache
302+
if token_provider is not None:
303+
Client._api_client_external[provider_key] = api_client
304+
elif cache_token:
309305
Client._api_client_cached = api_client
310306
else:
311307
Client._api_client_uncached = api_client

src/aignostics/platform/_operation_cache.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
- Supports selective cache clearing by function
1212
"""
1313

14+
from __future__ import annotations
15+
1416
import hashlib
1517
import time
1618
import typing as t
17-
from collections.abc import Callable
18-
from typing import Any, ParamSpec, TypeVar
19+
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
1920

20-
from ._authentication import get_token
21+
if TYPE_CHECKING:
22+
from collections.abc import Callable
2123

2224
# Cache storage for operation results
2325
_operation_cache: dict[str, tuple[Any, float]] = {}
@@ -92,16 +94,19 @@ def cache_key_with_token(token: str, func_qualified_name: str, *args: object, **
9294

9395

9496
def cached_operation(
95-
ttl: int, *, use_token: bool = True, instance_attrs: tuple[str, ...] | None = None
97+
ttl: int,
98+
*,
99+
token_provider: Callable[[], str] | None = None,
100+
instance_attrs: tuple[str, ...] | None = None,
96101
) -> Callable[[Callable[P, T]], Callable[P, T]]:
97102
"""Caches the result of a function call for a specified time-to-live (TTL).
98103
99104
Args:
100105
ttl (int): Time-to-live for the cache in seconds.
101-
use_token (bool): If True, includes the authentication token in the cache key.
102-
This is useful for Client methods that should cache per-user.
103-
When use_token is True and no instance_attrs are specified, the 'self'
104-
argument is excluded from the cache key to enable cache sharing across instances.
106+
token_provider (Callable[[], str] | None): A callable returning the current
107+
authentication token string. When provided, the token is included in the
108+
cache key for per-user isolation. Pass ``None`` to omit the token from
109+
the cache key.
105110
instance_attrs (tuple[str, ...] | None): Instance attributes to include in the cache key.
106111
This is useful for instance methods where caching should be per-instance based on
107112
specific attributes (e.g., 'run_id' for Run.details()).
@@ -132,8 +137,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
132137
instance_values = tuple(getattr(instance, attr) for attr in instance_attrs)
133138
cache_args = instance_values + args[1:]
134139

135-
if use_token:
136-
key = cache_key_with_token(get_token(True), func_qualified_name, *cache_args, **kwargs)
140+
if token_provider is not None:
141+
token = token_provider()
142+
key = cache_key_with_token(token, func_qualified_name, *cache_args, **kwargs)
137143
else:
138144
key = cache_key(func_qualified_name, *cache_args, **kwargs)
139145

src/aignostics/platform/resources/applications.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from operator import itemgetter
1010

1111
import semver
12-
from aignx.codegen.api.public_api import PublicApi
1312
from aignx.codegen.exceptions import NotFoundException, ServiceException
1413
from aignx.codegen.models import ApplicationReadResponse as Application
1514
from aignx.codegen.models import ApplicationReadShortResponse as ApplicationSummary
@@ -26,6 +25,7 @@
2625
from urllib3.exceptions import IncompleteRead, PoolError, ProtocolError, ProxyError
2726
from urllib3.exceptions import TimeoutError as Urllib3TimeoutError
2827

28+
from aignostics.platform._api import _AuthenticatedApi
2929
from aignostics.platform._operation_cache import cached_operation
3030
from aignostics.platform._settings import settings
3131
from aignostics.platform.resources.utils import paginate
@@ -66,11 +66,11 @@ class Versions:
6666
Provides operations to list and retrieve application versions.
6767
"""
6868

69-
def __init__(self, api: PublicApi) -> None:
69+
def __init__(self, api: _AuthenticatedApi) -> None:
7070
"""Initializes the Versions resource with the API platform.
7171
7272
Args:
73-
api (PublicApi): The configured API platform.
73+
api (_AuthenticatedApi): The configured API platform.
7474
"""
7575
self._api = api
7676

@@ -92,7 +92,7 @@ def list(self, application: Application | str, nocache: bool = False) -> builtin
9292
"""
9393
application_id = application.application_id if isinstance(application, Application) else application
9494

95-
@cached_operation(ttl=settings().application_cache_ttl, use_token=True)
95+
@cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider)
9696
def list_with_retry(app_id: str) -> Application:
9797
return Retrying(
9898
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -149,7 +149,7 @@ def details(
149149
raise ValueError(message)
150150

151151
# Make the API call with retry logic and caching
152-
@cached_operation(ttl=settings().application_version_cache_ttl, use_token=True)
152+
@cached_operation(ttl=settings().application_version_cache_ttl, token_provider=self._api.token_provider)
153153
def details_with_retry(app_id: str, app_version: str) -> ApplicationVersion:
154154
return Retrying(
155155
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -230,11 +230,11 @@ class Applications:
230230
Provides operations to list applications and access version resources.
231231
"""
232232

233-
def __init__(self, api: PublicApi) -> None:
233+
def __init__(self, api: _AuthenticatedApi) -> None:
234234
"""Initializes the Applications resource with the API platform.
235235
236236
Args:
237-
api (PublicApi): The configured API platform.
237+
api (_AuthenticatedApi): The configured API platform.
238238
"""
239239
self._api = api
240240
self.versions: Versions = Versions(self._api)
@@ -257,7 +257,7 @@ def details(self, application_id: str, nocache: bool = False) -> Application:
257257
aignx.codegen.exceptions.ApiException: If the API call fails.
258258
"""
259259

260-
@cached_operation(ttl=settings().application_cache_ttl, use_token=True)
260+
@cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider)
261261
def details_with_retry(application_id: str) -> Application:
262262
return Retrying(
263263
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),
@@ -293,7 +293,7 @@ def list(self, nocache: bool = False) -> t.Iterator[ApplicationSummary]:
293293

294294
# Create a wrapper function that applies retry logic and caching to each API call
295295
# Caching at this level ensures having a fresh iterator on cache hits
296-
@cached_operation(ttl=settings().application_cache_ttl, use_token=True)
296+
@cached_operation(ttl=settings().application_cache_ttl, token_provider=self._api.token_provider)
297297
def list_with_retry(**kwargs: object) -> builtins.list[ApplicationSummary]:
298298
return Retrying(
299299
retry=retry_if_exception_type(exception_types=RETRYABLE_EXCEPTIONS),

0 commit comments

Comments
 (0)