Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit c59ea9e

Browse files
committed
custom session db
1 parent 96cb50f commit c59ea9e

2 files changed

Lines changed: 38 additions & 27 deletions

File tree

src/oidcendpoint/endpoint_context.py

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from oidcendpoint.util import build_endpoints
2727
from oidcendpoint.util import importer
2828

29-
LOGGER = logging.getLogger(__name__)
29+
logger = logging.getLogger(__name__)
3030

3131
CAPABILITIES = {
3232
"response_types_supported": [
@@ -150,6 +150,7 @@ def __init__(
150150
keyjar=None,
151151
client_db=None,
152152
session_db=None,
153+
sso_db=None,
153154
cwd="",
154155
cookie_dealer=None,
155156
httpc=None,
@@ -188,7 +189,13 @@ def __init__(
188189
# arguments for endpoints add-ons
189190
self.args = {}
190191

192+
# session db
191193
self._sub_func = None
194+
self.sdb = session_db
195+
if not self.sdb:
196+
self.set_session_db(conf, sso_db)
197+
#
198+
192199
self.scope2claims = SCOPE2CLAIMS
193200

194201
if cookie_name:
@@ -251,11 +258,6 @@ def __init__(
251258
if _func:
252259
_func(self.conf)
253260

254-
if session_db:
255-
self.sdb = session_db
256-
else:
257-
self.do_session_db(conf)
258-
259261
_cap = self.do_endpoints(conf)
260262

261263
for item in ["userinfo", "login_hint_lookup", "login_hint2acrs",
@@ -275,6 +277,14 @@ def __init__(
275277
# client registration access tokens
276278
self.registration_access_token = {}
277279

280+
def set_session_db(self, conf, sso_db=None):
281+
# this populate self.sdb
282+
sso_db = sso_db if sso_db else SSODb()
283+
self.do_session_db(conf, sso_db)
284+
# this append useinfo db to the session db
285+
self.do_userinfo(conf)
286+
logger.debug('Session DB: {}'.format(self.sdb.__dict__))
287+
278288
def do_add_on(self, conf):
279289
if 'add_on' in conf:
280290
for spec in conf["add_on"].values():
@@ -309,8 +319,10 @@ def do_userinfo(self, conf):
309319
except KeyError:
310320
pass
311321
else:
312-
self.userinfo = init_user_info(_conf, self.cwd)
313-
self.sdb.userinfo = self.userinfo
322+
if self.sdb:
323+
self.userinfo = init_user_info(_conf, self.cwd)
324+
self.sdb.userinfo = self.userinfo
325+
314326

315327
def do_id_token(self, conf):
316328
try:
@@ -338,25 +350,23 @@ def do_cookie_dealer(self, conf):
338350
self.cookie_dealer = init_service(_conf)
339351

340352
def do_sub_func(self, conf):
341-
try:
342-
_conf = conf["sub_func"]
343-
except KeyError:
344-
self._sub_func = None
345-
else:
346-
self._sub_func = {}
347-
for key, args in _conf.items():
348-
if "class" in args:
349-
self._sub_func[key] = init_service(args)
350-
elif "function" in args:
351-
if isinstance(args["function"], str):
352-
self._sub_func[key] = util.importer(args["function"])
353-
else:
354-
self._sub_func[key] = args["function"]
355-
356-
def do_session_db(self, conf):
353+
_conf = conf.get("sub_func", {})
354+
self._sub_func = {}
355+
for key, args in _conf.items():
356+
if "class" in args:
357+
self._sub_func[key] = init_service(args)
358+
elif "function" in args:
359+
if isinstance(args["function"], str):
360+
self._sub_func[key] = util.importer(args["function"])
361+
else:
362+
self._sub_func[key] = args["function"]
363+
364+
def do_session_db(self, conf, sso_db):
357365
th_args = get_token_handlers(conf)
358366
self.sdb = create_session_db(
359-
self, th_args, db=None, sso_db=SSODb(), sub_func=self._sub_func
367+
self, th_args, db=None,
368+
sso_db=sso_db,
369+
sub_func=self._sub_func
360370
)
361371

362372
def do_endpoints(self, conf):
@@ -498,7 +508,7 @@ def create_providerinfo(self, capabilities):
498508
_msg = "Server doesn't support the following features: {}".format(
499509
not_supported
500510
)
501-
LOGGER.error(_msg)
511+
logger.error(_msg)
502512
raise ConfigurationError(_msg)
503513

504514
if self.jwks_uri and self.keyjar:

src/oidcendpoint/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,8 @@ def get_authentication_event(self, sid):
565565
raise ValueError("No Authn event info")
566566

567567

568-
def create_session_db(ec, token_handler_args, db=None, sso_db=SSODb(), sub_func=None):
568+
def create_session_db(ec, token_handler_args, db=None,
569+
sso_db=SSODb(), sub_func=None):
569570
_token_handler = token_handler.factory(ec, **token_handler_args)
570571

571572
if not db:

0 commit comments

Comments
 (0)