Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit 088946d

Browse files
committed
lukewarm refactoring of client_autn.py
adding __contains__ in MemoryDB to allow: `i in endpoint.cdb -> Bool` and avoid access try / except on the search for an entity
1 parent dff94e4 commit 088946d

3 files changed

Lines changed: 57 additions & 63 deletions

File tree

src/oidcendpoint/client_authn.py

Lines changed: 42 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,13 @@ def verify(self, request, **kwargs):
136136
logger.info("%s" % sanitize(err))
137137
raise AuthnFailure("Could not verify client_assertion.")
138138

139-
try:
140-
logger.debug("authntoken: %s" % sanitize(ca_jwt.to_dict()))
141-
except AttributeError:
142-
logger.debug("authntoken: %s" % sanitize(ca_jwt))
139+
authtoken = sanitize(ca_jwt)
140+
if hasattr(ca_jwt, 'to_dict') and callable(ca_jwt, 'to_dict'):
141+
authtoken = sanitize(ca_jwt.to_dict())
142+
logger.debug("authntoken: {}".format(authtoken))
143143

144144
request[verified_claim_name("client_assertion")] = ca_jwt
145-
146-
try:
147-
client_id = kwargs["client_id"]
148-
except KeyError:
149-
client_id = ca_jwt["iss"]
145+
client_id = kwargs.get("client_id") or ca_jwt["iss"]
150146

151147
# I should be among the audience
152148
# could be either my issuer id or the token endpoint
@@ -244,51 +240,45 @@ def verify_client(
244240
else:
245241
raise UnknownOrNoAuthnMethod(authorization_info)
246242

247-
try:
248-
client_id = auth_info["client_id"]
249-
except KeyError:
250-
try:
251-
_token = auth_info["token"]
252-
except KeyError:
243+
client_id = auth_info.get("client_id")
244+
_token = auth_info.get("token")
245+
246+
if client_id:
247+
248+
if not client_id in endpoint_context.cdb:
249+
raise ValueError("Unknown Client ID")
250+
251+
_cinfo = endpoint_context.cdb[client_id]
252+
if isinstance(_cinfo, str):
253+
if not _cinfo in endpoint_context.cdb:
254+
raise ValueError("Unknown Client ID")
255+
256+
if not valid_client_info(_cinfo):
257+
logger.warning("Client registration has timed out")
258+
raise ValueError("Not valid client")
259+
260+
# store what authn method was used
261+
if auth_info.get("method"):
262+
if endpoint_context.cdb[client_id].get("auth_method") and \
263+
request.__class__.__name__ in endpoint_context.cdb[client_id]["auth_method"]:
264+
endpoint_context.cdb[client_id]["auth_method"][
265+
request.__class__.__name__
266+
] = auth_info["method"]
267+
else:
268+
endpoint_context.cdb[client_id]["auth_method"] = {
269+
request.__class__.__name__: auth_info["method"]
270+
}
271+
272+
elif not client_id and get_client_id_from_token:
273+
if not _token:
253274
logger.warning("No token")
254-
else:
255-
if get_client_id_from_token:
256-
try:
257-
_id = get_client_id_from_token(endpoint_context, _token, request)
258-
except KeyError:
259-
raise ValueError("Unknown token")
260-
261-
if _id:
262-
auth_info["client_id"] = _id
263-
else:
275+
raise ValueError("No token")
276+
264277
try:
265-
_cinfo = endpoint_context.cdb[client_id]
278+
# get_client_id_from_token is a callback... Do not abuse for code readability.
279+
auth_info["client_id"] = get_client_id_from_token(endpoint_context,
280+
_token, request)
266281
except KeyError:
267-
raise ValueError("Unknown Client ID")
268-
else:
269-
if isinstance(_cinfo, str):
270-
try:
271-
_cinfo = endpoint_context.cdb[_cinfo]
272-
except KeyError:
273-
raise ValueError("Unknown Client ID")
274-
275-
try:
276-
valid_client_info(_cinfo)
277-
except KeyError:
278-
logger.warning("Client registration has timed out")
279-
raise ValueError("Not valid client")
280-
else:
281-
# store what authn method was used
282-
try:
283-
endpoint_context.cdb[client_id]["auth_method"][
284-
request.__class__.__name__
285-
] = auth_info["method"]
286-
except KeyError:
287-
try:
288-
endpoint_context.cdb[client_id]["auth_method"] = {
289-
request.__class__.__name__: auth_info["method"]
290-
}
291-
except KeyError:
292-
pass
282+
raise ValueError("Unknown token")
293283

294284
return auth_info

src/oidcendpoint/in_memory_db.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@ class InMemoryDataBase(object):
22
def __init__(self):
33
self.db = {}
44

5+
def __contains__(self, key):
6+
if self.db.get(key):
7+
return 1
8+
59
def set(self, key, value):
610
self.db[key] = value
711

812
def get(self, key):
9-
try:
10-
return self.db[key]
11-
except KeyError:
12-
return None
13+
return self.db.get(key, None)
1314

1415
def delete(self, key):
15-
del self.db[key]
16+
if self.db.get(key):
17+
del self.db[key]
1618

1719
def keys(self):
1820
return self.db.keys()

src/oidcendpoint/session.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from oidcendpoint import token_handler
1414
from oidcendpoint.authn_event import AuthnEvent
1515
from oidcendpoint.in_memory_db import InMemoryDataBase
16-
from oidcendpoint.sso_db import SSODb
16+
from oidcendpoint.sso_db import (SSODb, KEY_FORMAT)
1717
from oidcendpoint.token_handler import AccessCodeUsed
1818
from oidcendpoint.token_handler import ExpiredToken
1919
from oidcendpoint.token_handler import UnknownToken
@@ -119,7 +119,7 @@ def dict_match(a, b):
119119

120120
class SessionDB(object):
121121
def __init__(self, db, handler, sso_db, userinfo=None, sub_func=None):
122-
# db must implement the InMemoryStateDataBase interface
122+
# db must implement the InMemoryDataBase interface
123123
self._db = db
124124
self.handler = handler
125125
self.sso_db = sso_db
@@ -224,13 +224,14 @@ def update_by_token(self, token, **kwargs):
224224
return self.update(_sid, **kwargs)
225225

226226
def map_kv2sid(self, key, value, sid):
227-
self._db.set("__{}__{}__".format(key, value), sid)
227+
""" KEY_FORMAT = "__{}__{}" """
228+
self._db.set(KEY_FORMAT.format(key, value), sid)
228229

229230
def delete_kv2sid(self, key, value):
230-
self._db.delete("__{}__{}__".format(key, value))
231+
self._db.delete(KEY_FORMAT.format(key, value))
231232

232233
def get_sid_by_kv(self, key, value):
233-
return self._db.get("__{}__{}__".format(key, value))
234+
return self._db.get(KEY_FORMAT.format(key, value))
234235

235236
def get_token(self, sid):
236237
_sess_info = self[sid]
@@ -241,7 +242,8 @@ def get_token(self, sid):
241242
return _sess_info["access_token"]
242243

243244
def do_sub(
244-
self, sid, uid, client_salt, sector_id="", subject_type="public", user_salt=""
245+
self, sid, uid, client_salt, sector_id="",
246+
subject_type="public", user_salt=""
245247
):
246248
"""
247249
Create and store a subject identifier

0 commit comments

Comments
 (0)