|
| 1 | +import time |
| 2 | +from typing import Optional |
| 3 | +import requests |
| 4 | + |
| 5 | +class OAuth2Client: |
| 6 | + """OAuth 2.0 client for handling client credentials flow with token caching""" |
| 7 | + |
| 8 | + def __init__(self, client_id: str, client_secret: str, access_token_uri: str): |
| 9 | + self.client_id = client_id |
| 10 | + self.client_secret = client_secret |
| 11 | + self.access_token_uri = access_token_uri |
| 12 | + self.access_token: Optional[str] = None |
| 13 | + self.token_expires_at: Optional[float] = None |
| 14 | + |
| 15 | + def _is_token_expired(self, buffer_seconds: int = 30) -> bool: |
| 16 | + """Check if the current token is expired or will expire soon""" |
| 17 | + if not self.access_token or not self.token_expires_at: |
| 18 | + return True |
| 19 | + |
| 20 | + # Add buffer time to refresh token before it actually expires |
| 21 | + return time.time() >= (self.token_expires_at - buffer_seconds) |
| 22 | + |
| 23 | + def _request_new_token(self) -> str: |
| 24 | + """Request a new access token from the OAuth server""" |
| 25 | + # Prepare token request |
| 26 | + token_data = { |
| 27 | + 'grant_type': 'client_credentials', |
| 28 | + 'client_id': self.client_id, |
| 29 | + 'client_secret': self.client_secret |
| 30 | + } |
| 31 | + |
| 32 | + headers = { |
| 33 | + 'Content-Type': 'application/x-www-form-urlencoded' |
| 34 | + } |
| 35 | + |
| 36 | + try: |
| 37 | + response = requests.post( |
| 38 | + self.access_token_uri, |
| 39 | + data=token_data, |
| 40 | + headers=headers, |
| 41 | + timeout=30 |
| 42 | + ) |
| 43 | + response.raise_for_status() |
| 44 | + |
| 45 | + token_response = response.json() |
| 46 | + self.access_token = token_response.get('access_token') |
| 47 | + |
| 48 | + # Calculate expiration time |
| 49 | + expires_in = token_response.get('expires_in', 3600) # Default to 1 hour |
| 50 | + self.token_expires_at = time.time() + expires_in |
| 51 | + |
| 52 | + if not self.access_token: |
| 53 | + raise ValueError("No access token received from OAuth server") |
| 54 | + |
| 55 | + return self.access_token |
| 56 | + except requests.exceptions.RequestException as e: |
| 57 | + raise ValueError(f"Failed to get access token: {e}") from e |
| 58 | + |
| 59 | + def get_access_token(self) -> str: |
| 60 | + """Get OAuth 2.0 access token using client credentials flow""" |
| 61 | + # Check if we have a valid cached token |
| 62 | + if not self._is_token_expired(): |
| 63 | + return self.access_token |
| 64 | + # Token is expired or doesn't exist, get a new one |
| 65 | + return self._request_new_token() |
0 commit comments