77from oidcmsg .message import OPTIONAL_LIST_OF_STRINGS
88from oidcmsg .message import SINGLE_OPTIONAL_STRING
99from oidcmsg .message import SINGLE_REQUIRED_STRING
10+ from oidcmsg .message import SINGLE_OPTIONAL_JSON
1011from oidcmsg .message import msg_ser
1112from oidcmsg .oidc import AuthorizationRequest
1213
@@ -89,8 +90,27 @@ class SessionInfo(Message):
8990 "client_id" : SINGLE_REQUIRED_STRING ,
9091 "authn_event" : SINGLE_REQUIRED_AUTHN_EVENT ,
9192 "si_redirects" : OPTIONAL_LIST_OF_STRINGS ,
93+ "black_list" : SINGLE_OPTIONAL_JSON ,
9294 }
9395
96+ def __init__ (self , * args , ** kwargs ):
97+ super (SessionInfo , self ).__init__ (* args , ** kwargs )
98+ self ["black_list" ] = {}
99+
100+ def is_black_listed (self , typ , token ):
101+ # If session is revoked
102+ if "revoked" in self :
103+ return True
104+
105+ return typ in self ["black_list" ] and token in self ["black_list" ][typ ]
106+
107+ def black_list (self , typ ):
108+ if typ in self :
109+ if typ in self ["black_list" ]:
110+ self ["black_list" ][typ ].append (self [typ ])
111+ else :
112+ self ["black_list" ][typ ] = [self [typ ]]
113+
94114
95115def pairwise_id (uid , sector_identifier , salt , ** kwargs ):
96116 return hashlib .sha256 (
@@ -268,9 +288,9 @@ def do_sub(
268288
269289 return sub
270290
271- def is_valid (self , item ):
291+ def is_valid (self , typ , item ):
272292 try :
273- return not self . handler . is_black_listed (item )
293+ return not self [ item ]. is_black_listed (typ , item )
274294 except KeyError :
275295 return False
276296
@@ -295,9 +315,8 @@ def replace_token(self, sid, sinfo, token_type):
295315
296316 if token_type in self .handler :
297317 refresh_token = self .handler [token_type ](sid , sinfo = sinfo )
298- # blacklist the old is there is one
299- if sinfo .get (token_type ):
300- self .handler [token_type ].black_list (sinfo [token_type ])
318+ # blacklist the old
319+ sinfo .black_list (token_type )
301320
302321 sinfo [token_type ] = refresh_token
303322 return sinfo
@@ -336,22 +355,20 @@ def upgrade_to_token(
336355 _tinfo = self .handler ["code" ].info (grant )
337356
338357 session_info = self [_tinfo ["sid" ]]
358+ key = _tinfo ["sid" ]
339359
340- if self . handler [ "code" ]. is_black_listed ( grant ):
360+ if session_info . is_black_listed ( "code" , grant ):
341361 # invalidate the released access token and refresh token
342362 for item in ["access_token" , "refresh_token" ]:
343- try :
344- self .handler [item ].black_list (session_info [item ])
345- except KeyError :
346- pass
363+ session_info .black_list (item )
364+ self [key ] = session_info
347365 raise AccessCodeUsed (grant )
348366
349367 # mint a new access token
350368 _at = self ._make_at (_tinfo ["sid" ], session_info )
351369
352370 # make sure the code can't be used again
353- self .handler ["code" ].black_list (grant )
354- key = _tinfo ["sid" ]
371+ session_info .black_list ("code" )
355372 else :
356373 session_info = self [key ]
357374 _at = self ._make_at (key , session_info )
@@ -392,11 +409,11 @@ def refresh_token(self, token, new_refresh=False):
392409 except KeyError :
393410 return False
394411
395- if is_expired (int (_tinfo ["exp" ])) or _tinfo ["black_listed" ]:
396- raise ExpiredToken ()
397-
398412 _sid = _tinfo ["sid" ]
399413 session_info = self [_sid ]
414+ if is_expired (int (_tinfo ["exp" ])) or \
415+ session_info .is_black_listed ("refresh_token" , token ):
416+ raise ExpiredToken ()
400417
401418 session_info = self .replace_token (_sid , session_info , "access_token" )
402419
@@ -420,11 +437,11 @@ def is_token_valid(self, token):
420437 except KeyError :
421438 return False
422439
423- if is_expired (int (_tinfo ["exp" ])) or _tinfo ["black_listed" ]:
424- return False
425-
426440 # Dependent on what state the session is in.
427441 session_info = self [_tinfo ["sid" ]]
442+ if is_expired (int (_tinfo ["exp" ])) or \
443+ session_info .is_black_listed ("access_token" , token ):
444+ return False
428445
429446 if session_info ["oauth_state" ] == "authz" :
430447 if _tinfo ["handler" ] != self .handler ["code" ]:
@@ -435,22 +452,24 @@ def is_token_valid(self, token):
435452
436453 return True
437454
438- def revoke_token (self , token , token_type = "" ):
455+ def revoke_token (self , sid , token_type ):
439456 """
440- Revokes access token
457+ Revokes token
441458
442- :param token: access token
459+ :param sid: session id
460+ :param token_type: token type, one of "code", "access_token" or
461+ "refresh_token"
443462 """
444- if token_type :
445- self .handler [token_type ].black_list (token )
446- else :
447- self .handler .black_list (token )
463+ _sinfo = self [sid ]
464+ _sinfo .black_list (token_type )
465+ self [sid ] = _sinfo
448466
449467 def revoke_all_tokens (self , token ):
450- _sinfo = self [token ]
468+ sid = self .handler .sid (token )
469+ _sinfo = self [sid ]
451470 for typ in self .handler .keys ():
452- if _sinfo .get (typ ):
453- self . revoke_token ( _sinfo [ typ ], typ )
471+ _sinfo .black_list (typ )
472+ self [ sid ] = _sinfo
454473
455474 def revoke_session (self , sid = "" , token = "" ):
456475 """
@@ -465,11 +484,9 @@ def revoke_session(self, sid="", token=""):
465484 else :
466485 raise ValueError ('Need one of "sid" or "token"' )
467486
468- for typ in ["access_token" , "refresh_token" , "code" ]:
469- try :
470- self .revoke_token (self [sid ][typ ], typ )
471- except KeyError : # If no such token has been issued
472- pass
487+ _sinfo = self [sid ]
488+ for typ in self .handler .keys ():
489+ _sinfo .black_list (typ )
473490
474491 self .update (sid , revoked = True )
475492
0 commit comments