1111from .exceptions import InvalidRefreshToken
1212from .exceptions import InvalidScope
1313from .exceptions import InvalidSubjectIdentifier
14+ from .storage import StatelessWrapper
1415from .util import requested_scope_is_allowed
1516
1617logger = logging .getLogger (__name__ )
@@ -24,13 +25,15 @@ def rand_str():
2425
2526class AuthorizationState (object ):
2627 KEY_AUTHORIZATION_REQUEST = 'auth_req'
28+ KEY_USER_INFO = 'user_info'
29+ KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'
2730
2831 def __init__ (self , subject_identifier_factory , authorization_code_db = None , access_token_db = None ,
2932 refresh_token_db = None , subject_identifier_db = None , * ,
3033 authorization_code_lifetime = 600 , access_token_lifetime = 3600 , refresh_token_lifetime = None ,
3134 refresh_token_threshold = None ):
3235 # type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
33- # Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
36+ # Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
3437 """
3538 :param subject_identifier_factory: callable to use when construction subject identifiers
3639 :param authorization_code_db: database for storing authorization codes, defaults to in-memory
@@ -77,10 +80,18 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
7780 """
7881 Mapping of user id's to subject identifiers.
7982 """
80- self .subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
83+ if isinstance (self .authorization_codes , StatelessWrapper ) or \
84+ isinstance (self .access_tokens , StatelessWrapper ) or isinstance (
85+ self .refresh_tokens , StatelessWrapper ):
86+ self .stateless = True
87+ self .subject_identifiers = {}
88+ else :
89+ self .stateless = False
90+ self .subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
8191
82- def create_authorization_code (self , authorization_request , subject_identifier , scope = None ):
83- # type: (AuthorizationRequest, str, Optional[List[str]]) -> str
92+ def create_authorization_code (self , authorization_request , subject_identifier , scope = None , user_info = None ,
93+ extra_id_token_claims = None ):
94+ # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str
8495 """
8596 Creates an authorization code bound to the authorization request and the authenticated user identified
8697 by the subject identifier.
@@ -92,21 +103,29 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
92103 scope = ' ' .join (scope or authorization_request ['scope' ])
93104 logger .debug ('creating authz code for scope=%s' , scope )
94105
95- authorization_code = rand_str ()
96106 authz_info = {
97107 'used' : False ,
98108 'exp' : int (time .time ()) + self .authorization_code_lifetime ,
99109 'sub' : subject_identifier ,
100110 'granted_scope' : scope ,
101111 self .KEY_AUTHORIZATION_REQUEST : authorization_request .to_dict ()
102112 }
103- self .authorization_codes [authorization_code ] = authz_info
113+
114+ if isinstance (self .authorization_codes , StatelessWrapper ):
115+ if user_info :
116+ authz_info [self .KEY_USER_INFO ] = user_info
117+ authz_info [self .KEY_EXTRA_ID_TOKEN_CLAIMS ] = extra_id_token_claims or {}
118+ authorization_code = self .authorization_codes .pack (authz_info )
119+ else :
120+ authorization_code = rand_str ()
121+ self .authorization_codes [authorization_code ] = authz_info
122+
104123 logger .debug ('new authz_code=%s to client_id=%s for sub=%s valid_until=%s' , authorization_code ,
105124 authorization_request ['client_id' ], subject_identifier , authz_info ['exp' ])
106125 return authorization_code
107126
108- def create_access_token (self , authorization_request , subject_identifier , scope = None ):
109- # type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
127+ def create_access_token (self , authorization_request , subject_identifier , scope = None , user_info = None ):
128+ # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict] ) -> se_leg_op.access_token.AccessToken
110129 """
111130 Creates an access token bound to the authentication request and the authenticated user identified by the
112131 subject identifier.
@@ -116,15 +135,15 @@ def create_access_token(self, authorization_request, subject_identifier, scope=N
116135
117136 scope = scope or authorization_request ['scope' ]
118137
119- return self ._create_access_token (subject_identifier , authorization_request .to_dict (), ' ' .join (scope ))
138+ return self ._create_access_token (subject_identifier , authorization_request .to_dict (), ' ' .join (scope ),
139+ user_info = user_info )
120140
121- def _create_access_token (self , subject_identifier , auth_req , granted_scope , current_scope = None ):
122- # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str]) -> se_leg_op.access_token.AccessToken
141+ def _create_access_token (self , subject_identifier , auth_req , granted_scope , current_scope = None ,
142+ user_info = None ):
143+ # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken
123144 """
124145 Creates an access token bound to the subject identifier, client id and requested scope.
125146 """
126- access_token = AccessToken (rand_str (), self .access_token_lifetime )
127-
128147 scope = current_scope or granted_scope
129148 logger .debug ('creating access token for scope=%s' , scope )
130149
@@ -136,13 +155,21 @@ def _create_access_token(self, subject_identifier, auth_req, granted_scope, curr
136155 'aud' : [auth_req ['client_id' ]],
137156 'scope' : scope ,
138157 'granted_scope' : granted_scope ,
139- 'token_type' : access_token .BEARER_TOKEN_TYPE ,
158+ 'token_type' : AccessToken .BEARER_TOKEN_TYPE ,
140159 self .KEY_AUTHORIZATION_REQUEST : auth_req
141160 }
142- self .access_tokens [access_token .value ] = authz_info
161+
162+ if isinstance (self .access_tokens , StatelessWrapper ):
163+ if user_info :
164+ authz_info [self .KEY_USER_INFO ] = user_info
165+ access_token_val = self .access_tokens .pack (authz_info )
166+ else :
167+ access_token_val = rand_str ()
168+ self .access_tokens [access_token_val ] = authz_info
143169
144170 logger .debug ('new access_token=%s to client_id=%s for sub=%s valid_until=%s' ,
145- access_token .value , auth_req ['client_id' ], subject_identifier , authz_info ['exp' ])
171+ access_token_val , auth_req ['client_id' ], subject_identifier , authz_info ['exp' ])
172+ access_token = AccessToken (access_token_val , self .access_token_lifetime )
146173 return access_token
147174
148175 def exchange_code_for_token (self , authorization_code ):
@@ -165,7 +192,8 @@ def exchange_code_for_token(self, authorization_code):
165192 authz_info ['used' ] = True
166193
167194 access_token = self ._create_access_token (authz_info ['sub' ], authz_info [self .KEY_AUTHORIZATION_REQUEST ],
168- authz_info ['granted_scope' ])
195+ authz_info ['granted_scope' ],
196+ user_info = authz_info .get (self .KEY_USER_INFO ))
169197
170198 logger .debug ('authz_code=%s exchanged to access_token=%s' , authorization_code , access_token .value )
171199 return access_token
@@ -199,9 +227,13 @@ def create_refresh_token(self, access_token_value):
199227 logger .debug ('no refresh token issued for for access_token=%s' , access_token_value )
200228 return None
201229
202- refresh_token = rand_str ()
203230 authz_info = {'access_token' : access_token_value , 'exp' : int (time .time ()) + self .refresh_token_lifetime }
204- self .refresh_tokens [refresh_token ] = authz_info
231+
232+ if isinstance (self .refresh_tokens , StatelessWrapper ):
233+ refresh_token = self .refresh_tokens .pack (authz_info )
234+ else :
235+ refresh_token = rand_str ()
236+ self .refresh_tokens [refresh_token ] = authz_info
205237
206238 logger .debug ('issued refresh_token=%s expiring=%d for access_token=%s' , refresh_token , authz_info ['exp' ],
207239 access_token_value )
@@ -235,7 +267,8 @@ def use_refresh_token(self, refresh_token, scope=None):
235267 scope = authz_info ['granted_scope' ]
236268
237269 new_access_token = self ._create_access_token (authz_info ['sub' ], authz_info [self .KEY_AUTHORIZATION_REQUEST ],
238- authz_info ['granted_scope' ], scope )
270+ authz_info ['granted_scope' ], scope ,
271+ user_info = authz_info .get (self .KEY_USER_INFO ))
239272
240273 new_refresh_token = None
241274 if self .refresh_token_threshold \
@@ -314,6 +347,27 @@ def get_user_id_for_subject_identifier(self, subject_identifier):
314347
315348 raise InvalidSubjectIdentifier ('{} unknown' .format (subject_identifier ))
316349
350+ def get_user_info_for_code (self , authorization_code ):
351+ # type: (str) -> dict
352+ if authorization_code not in self .authorization_codes :
353+ raise InvalidAuthorizationCode ('{} unknown' .format (authorization_code ))
354+
355+ return self .authorization_codes [authorization_code ].get (self .KEY_USER_INFO )
356+
357+ def get_extra_io_token_claims_for_code (self , authorization_code ):
358+ # type: (str) -> dict
359+ if authorization_code not in self .authorization_codes :
360+ raise InvalidAuthorizationCode ('{} unknown' .format (authorization_code ))
361+
362+ return self .authorization_codes [authorization_code ].get (self .KEY_EXTRA_ID_TOKEN_CLAIMS )
363+
364+ def get_user_info_for_access_token (self , access_token ):
365+ # type: (str) -> dict
366+ if access_token not in self .access_tokens :
367+ raise InvalidAccessToken ('{} unknown' .format (access_token ))
368+
369+ return self .access_tokens [access_token ].get (self .KEY_USER_INFO )
370+
317371 def get_authorization_request_for_code (self , authorization_code ):
318372 # type: (str) -> AuthorizationRequest
319373 if authorization_code not in self .authorization_codes :
0 commit comments