Skip to content

Commit 2b4bd40

Browse files
Ali Haiderc00kiemon5ter
authored andcommitted
Support stateless code flow.
If stateless flag is set to true, nothing is stored in the storage. The code and access token is embedded with the user_info which is required later for the evaluation of required claims. Signed-off-by: Ivan Kanakarakis <ivan.kanak@gmail.com>
1 parent b2ac7fe commit 2b4bd40

8 files changed

Lines changed: 1131 additions & 56 deletions

File tree

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
description='OpenID Connect Provider (OP) library in Python.',
1313
install_requires=[
1414
'oic >= 1.2.1',
15+
'pycryptodomex',
1516
],
1617
extras_require={
1718
'mongo': 'pymongo',

src/pyop/authz_state.py

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .exceptions import InvalidRefreshToken
1212
from .exceptions import InvalidScope
1313
from .exceptions import InvalidSubjectIdentifier
14+
from .storage import StatelessWrapper
1415
from .util import requested_scope_is_allowed
1516

1617
logger = logging.getLogger(__name__)
@@ -24,13 +25,15 @@ def rand_str():
2425

2526
class AuthorizationState(object):
2627
KEY_AUTHORIZATION_REQUEST = 'auth_req'
28+
KEY_USER_INFO = 'user_info'
29+
KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'
2730

2831
def __init__(self, subject_identifier_factory, authorization_code_db=None, access_token_db=None,
2932
refresh_token_db=None, subject_identifier_db=None, *,
3033
authorization_code_lifetime=600, access_token_lifetime=3600, refresh_token_lifetime=None,
3134
refresh_token_threshold=None):
3235
# type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
33-
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
36+
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
3437
"""
3538
:param subject_identifier_factory: callable to use when construction subject identifiers
3639
:param authorization_code_db: database for storing authorization codes, defaults to in-memory
@@ -77,10 +80,18 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
7780
"""
7881
Mapping of user id's to subject identifiers.
7982
"""
80-
self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
83+
if isinstance(self.authorization_codes, StatelessWrapper) or \
84+
isinstance(self.access_tokens, StatelessWrapper) or isinstance(
85+
self.refresh_tokens, StatelessWrapper):
86+
self.stateless = True
87+
self.subject_identifiers = {}
88+
else:
89+
self.stateless = False
90+
self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
8191

82-
def create_authorization_code(self, authorization_request, subject_identifier, scope=None):
83-
# type: (AuthorizationRequest, str, Optional[List[str]]) -> str
92+
def create_authorization_code(self, authorization_request, subject_identifier, scope=None, user_info=None,
93+
extra_id_token_claims=None):
94+
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str
8495
"""
8596
Creates an authorization code bound to the authorization request and the authenticated user identified
8697
by the subject identifier.
@@ -92,21 +103,29 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
92103
scope = ' '.join(scope or authorization_request['scope'])
93104
logger.debug('creating authz code for scope=%s', scope)
94105

95-
authorization_code = rand_str()
96106
authz_info = {
97107
'used': False,
98108
'exp': int(time.time()) + self.authorization_code_lifetime,
99109
'sub': subject_identifier,
100110
'granted_scope': scope,
101111
self.KEY_AUTHORIZATION_REQUEST: authorization_request.to_dict()
102112
}
103-
self.authorization_codes[authorization_code] = authz_info
113+
114+
if isinstance(self.authorization_codes, StatelessWrapper):
115+
if user_info:
116+
authz_info[self.KEY_USER_INFO] = user_info
117+
authz_info[self.KEY_EXTRA_ID_TOKEN_CLAIMS] = extra_id_token_claims or {}
118+
authorization_code = self.authorization_codes.pack(authz_info)
119+
else:
120+
authorization_code = rand_str()
121+
self.authorization_codes[authorization_code] = authz_info
122+
104123
logger.debug('new authz_code=%s to client_id=%s for sub=%s valid_until=%s', authorization_code,
105124
authorization_request['client_id'], subject_identifier, authz_info['exp'])
106125
return authorization_code
107126

108-
def create_access_token(self, authorization_request, subject_identifier, scope=None):
109-
# type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
127+
def create_access_token(self, authorization_request, subject_identifier, scope=None, user_info=None):
128+
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict]) -> se_leg_op.access_token.AccessToken
110129
"""
111130
Creates an access token bound to the authentication request and the authenticated user identified by the
112131
subject identifier.
@@ -116,15 +135,15 @@ def create_access_token(self, authorization_request, subject_identifier, scope=N
116135

117136
scope = scope or authorization_request['scope']
118137

119-
return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope))
138+
return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope),
139+
user_info=user_info)
120140

121-
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None):
122-
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str]) -> se_leg_op.access_token.AccessToken
141+
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None,
142+
user_info=None):
143+
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken
123144
"""
124145
Creates an access token bound to the subject identifier, client id and requested scope.
125146
"""
126-
access_token = AccessToken(rand_str(), self.access_token_lifetime)
127-
128147
scope = current_scope or granted_scope
129148
logger.debug('creating access token for scope=%s', scope)
130149

@@ -136,13 +155,21 @@ def _create_access_token(self, subject_identifier, auth_req, granted_scope, curr
136155
'aud': [auth_req['client_id']],
137156
'scope': scope,
138157
'granted_scope': granted_scope,
139-
'token_type': access_token.BEARER_TOKEN_TYPE,
158+
'token_type': AccessToken.BEARER_TOKEN_TYPE,
140159
self.KEY_AUTHORIZATION_REQUEST: auth_req
141160
}
142-
self.access_tokens[access_token.value] = authz_info
161+
162+
if isinstance(self.access_tokens, StatelessWrapper):
163+
if user_info:
164+
authz_info[self.KEY_USER_INFO] = user_info
165+
access_token_val = self.access_tokens.pack(authz_info)
166+
else:
167+
access_token_val = rand_str()
168+
self.access_tokens[access_token_val] = authz_info
143169

144170
logger.debug('new access_token=%s to client_id=%s for sub=%s valid_until=%s',
145-
access_token.value, auth_req['client_id'], subject_identifier, authz_info['exp'])
171+
access_token_val, auth_req['client_id'], subject_identifier, authz_info['exp'])
172+
access_token = AccessToken(access_token_val, self.access_token_lifetime)
146173
return access_token
147174

148175
def exchange_code_for_token(self, authorization_code):
@@ -165,7 +192,8 @@ def exchange_code_for_token(self, authorization_code):
165192
authz_info['used'] = True
166193

167194
access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
168-
authz_info['granted_scope'])
195+
authz_info['granted_scope'],
196+
user_info=authz_info.get(self.KEY_USER_INFO))
169197

170198
logger.debug('authz_code=%s exchanged to access_token=%s', authorization_code, access_token.value)
171199
return access_token
@@ -199,9 +227,13 @@ def create_refresh_token(self, access_token_value):
199227
logger.debug('no refresh token issued for for access_token=%s', access_token_value)
200228
return None
201229

202-
refresh_token = rand_str()
203230
authz_info = {'access_token': access_token_value, 'exp': int(time.time()) + self.refresh_token_lifetime}
204-
self.refresh_tokens[refresh_token] = authz_info
231+
232+
if isinstance(self.refresh_tokens, StatelessWrapper):
233+
refresh_token = self.refresh_tokens.pack(authz_info)
234+
else:
235+
refresh_token = rand_str()
236+
self.refresh_tokens[refresh_token] = authz_info
205237

206238
logger.debug('issued refresh_token=%s expiring=%d for access_token=%s', refresh_token, authz_info['exp'],
207239
access_token_value)
@@ -235,7 +267,8 @@ def use_refresh_token(self, refresh_token, scope=None):
235267
scope = authz_info['granted_scope']
236268

237269
new_access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
238-
authz_info['granted_scope'], scope)
270+
authz_info['granted_scope'], scope,
271+
user_info=authz_info.get(self.KEY_USER_INFO))
239272

240273
new_refresh_token = None
241274
if self.refresh_token_threshold \
@@ -314,6 +347,27 @@ def get_user_id_for_subject_identifier(self, subject_identifier):
314347

315348
raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))
316349

350+
def get_user_info_for_code(self, authorization_code):
351+
# type: (str) -> dict
352+
if authorization_code not in self.authorization_codes:
353+
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
354+
355+
return self.authorization_codes[authorization_code].get(self.KEY_USER_INFO)
356+
357+
def get_extra_io_token_claims_for_code(self, authorization_code):
358+
# type: (str) -> dict
359+
if authorization_code not in self.authorization_codes:
360+
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
361+
362+
return self.authorization_codes[authorization_code].get(self.KEY_EXTRA_ID_TOKEN_CLAIMS)
363+
364+
def get_user_info_for_access_token(self, access_token):
365+
# type: (str) -> dict
366+
if access_token not in self.access_tokens:
367+
raise InvalidAccessToken('{} unknown'.format(access_token))
368+
369+
return self.access_tokens[access_token].get(self.KEY_USER_INFO)
370+
317371
def get_authorization_request_for_code(self, authorization_code):
318372
# type: (str) -> AuthorizationRequest
319373
if authorization_code not in self.authorization_codes:

0 commit comments

Comments
 (0)