Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit 6e21ab9

Browse files
authored
Merge branch 'master' into custom_sessiondb
2 parents c59ea9e + 41b140b commit 6e21ab9

23 files changed

Lines changed: 575 additions & 329 deletions

src/oidcendpoint/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
except ImportError:
77
import random as rnd
88

9-
__version__ = "0.10.1"
9+
__version__ = "0.11.0"
1010

1111

1212
DEF_SIGN_ALG = {

src/oidcendpoint/cookie.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
__author__ = "Roland Hedberg"
2626

27-
logger = logging.getLogger(__name__)
27+
LOGGER = logging.getLogger(__name__)
2828

2929
CORS_HEADERS = [
3030
("Access-Control-Allow-Origin", "*"),
@@ -466,12 +466,15 @@ def append_cookie(
466466

467467
def compute_session_state(opbs, salt, client_id, redirect_uri):
468468
"""
469+
Computes a session state value.
470+
This value is later used during session management to check whether
471+
the log in state has changed.
469472
470-
:param opbs:
473+
:param opbs: Cookie value
471474
:param salt:
472475
:param client_id:
473476
:param redirect_uri:
474-
:return:
477+
:return: Session state value
475478
"""
476479
parsed_uri = urlparse(redirect_uri)
477480
rp_origin_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_uri)

src/oidcendpoint/endpoint.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
2+
from functools import cmp_to_key
23
from urllib.parse import urlparse
34

5+
from cryptojwt import jwe
6+
from cryptojwt.jws.jws import SIGNER_ALGS
47
from oidcmsg.exception import MissingRequiredAttribute
58
from oidcmsg.exception import MissingRequiredValue
69
from oidcmsg.message import Message
@@ -51,6 +54,72 @@ def fragment_encoding(return_type):
5154
return True
5255

5356

57+
ALG_SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4}
58+
59+
60+
def sort_sign_alg(alg1, alg2):
61+
if ALG_SORT_ORDER[alg1[0:2]] < ALG_SORT_ORDER[alg2[0:2]]:
62+
return -1
63+
64+
if ALG_SORT_ORDER[alg1[0:2]] > ALG_SORT_ORDER[alg2[0:2]]:
65+
return 1
66+
67+
if alg1 < alg2:
68+
return -1
69+
70+
if alg1 > alg2:
71+
return 1
72+
73+
return 0
74+
75+
76+
def assign_algorithms(typ):
77+
if typ == "signing_alg":
78+
# Pick supported signing algorithms from crypto library
79+
# Sort order RS, ES, HS, PS
80+
sign_algs = list(SIGNER_ALGS.keys())
81+
return sorted(sign_algs, key=cmp_to_key(sort_sign_alg))
82+
elif typ == "encryption_alg":
83+
return jwe.SUPPORTED["alg"]
84+
elif typ == "encryption_enc":
85+
return jwe.SUPPORTED["enc"]
86+
87+
88+
def construct_provider_info(default_capabilities, **kwargs):
89+
if default_capabilities is not None:
90+
provider_info = {}
91+
for attr, default_val in default_capabilities.items():
92+
try:
93+
_proposal = kwargs[attr]
94+
except KeyError:
95+
if default_val is not None:
96+
provider_info[attr] = default_val
97+
elif "signing_alg_values_supported" in attr:
98+
provider_info[attr] = assign_algorithms("signing_alg")
99+
elif "encryption_alg_values_supported" in attr:
100+
provider_info[attr] = assign_algorithms("encryption_alg")
101+
elif "encryption_enc_values_supported" in attr:
102+
provider_info[attr] = assign_algorithms("encryption_enc")
103+
else:
104+
_permitted = None
105+
106+
if "signing_alg_values_supported" in attr:
107+
_permitted = set(assign_algorithms("signing_alg"))
108+
elif "encryption_alg_values_supported" in attr:
109+
_permitted = set(assign_algorithms("encryption_alg"))
110+
elif "encryption_enc_values_supported" in attr:
111+
_permitted = set(assign_algorithms("encryption_enc"))
112+
113+
if _permitted and not _permitted.issuperset(set(_proposal)):
114+
raise ValueError(
115+
"Proposed set of values outside set of permitted ({})".__format__(attr))
116+
117+
provider_info[attr] = _proposal
118+
return provider_info
119+
else:
120+
return None
121+
122+
54123
class Endpoint(object):
55124
request_cls = Message
56125
response_cls = Message
@@ -63,6 +132,7 @@ class Endpoint(object):
63132
response_format = "json"
64133
response_placement = "body"
65134
client_authn_method = ""
135+
default_capabilities = None
66136

67137
def __init__(self, endpoint_context, **kwargs):
68138
self.endpoint_context = endpoint_context
@@ -71,7 +141,14 @@ def __init__(self, endpoint_context, **kwargs):
71141
self.post_parse_request = []
72142
self.kwargs = kwargs
73143
self.full_path = ""
74-
self.provider_info = None
144+
145+
if "client_authn_method" in kwargs:
146+
self.client_authn_method = kwargs["client_authn_method"]
147+
elif self.default_capabilities is not None:
148+
if "client_authn_method" in self.default_capabilities:
149+
self.client_authn_method = self.default_capabilities["client_authn_method"]
150+
151+
self.provider_info = construct_provider_info(self.default_capabilities, **kwargs)
75152

76153
def parse_request(self, request, auth=None, **kwargs):
77154
"""

src/oidcendpoint/endpoint_context.py

Lines changed: 20 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import copy
21
import logging
32
import os
4-
from functools import cmp_to_key
53

64
import requests
7-
from cryptojwt import jwe
8-
from cryptojwt.jws.jws import SIGNER_ALGS
95
from cryptojwt.key_jar import KeyJar
106
from cryptojwt.key_jar import init_key_jar
117
from jinja2 import Environment
@@ -16,7 +12,6 @@
1612
from oidcendpoint import rndstr
1713
from oidcendpoint import util
1814
from oidcendpoint.client_authn import CLIENT_AUTHN_METHOD
19-
from oidcendpoint.exception import ConfigurationError
2015
from oidcendpoint.id_token import IDToken
2116
from oidcendpoint.session import create_session_db
2217
from oidcendpoint.sso_db import SSODb
@@ -28,55 +23,6 @@
2823

2924
logger = logging.getLogger(__name__)
3025

31-
CAPABILITIES = {
32-
"response_types_supported": [
33-
"code",
34-
"token",
35-
"id_token",
36-
"code token",
37-
"code id_token",
38-
"id_token token",
39-
"code id_token token",
40-
"none",
41-
],
42-
"token_endpoint_auth_methods_supported": [
43-
"client_secret_post",
44-
"client_secret_basic",
45-
"client_secret_jwt",
46-
"private_key_jwt",
47-
],
48-
"response_modes_supported": ["query", "fragment", "form_post"],
49-
"subject_types_supported": ["public", "pairwise"],
50-
"grant_types_supported": [
51-
"authorization_code",
52-
"implicit",
53-
"urn:ietf:params:oauth:grant-type:jwt-bearer",
54-
"refresh_token",
55-
],
56-
"claim_types_supported": ["normal", "aggregated", "distributed"],
57-
"claims_parameter_supported": True,
58-
"request_parameter_supported": True,
59-
"request_uri_parameter_supported": True,
60-
}
61-
62-
SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4}
63-
64-
65-
def sort_sign_alg(alg1, alg2):
66-
if SORT_ORDER[alg1[0:2]] < SORT_ORDER[alg2[0:2]]:
67-
return -1
68-
69-
if SORT_ORDER[alg1[0:2]] > SORT_ORDER[alg2[0:2]]:
70-
return 1
71-
72-
if alg1 < alg2:
73-
return -1
74-
75-
if alg1 > alg2:
76-
return 1
77-
78-
return 0
79-
8026

8127
def add_path(url, path):
8228
if url.endswith("/"):
@@ -388,9 +334,9 @@ def do_endpoints(self, conf):
388334
if endpoint in ["webfinger", "provider_info"]:
389335
continue
390336

391-
_cap[endpoint_instance.endpoint_name] = "{}".format(
392-
endpoint_instance.endpoint_path
393-
)
337+
if endpoint_instance.endpoint_name:
338+
_cap[endpoint_instance.endpoint_name] = endpoint_instance.full_path
339+
394340
return _cap
395341

396342
def do_authz(self, conf):
@@ -401,90 +347,11 @@ def do_authz(self, conf):
401347
else:
402348
self.authz = init_service(authz_spec, self)
403349

404-
def package_capabilities(self):
405-
_provider_info = copy.deepcopy(CAPABILITIES)
406-
_provider_info["issuer"] = self.issuer
407-
_provider_info["version"] = "3.0"
408-
409-
_claims = []
410-
for _cl in self.scope2claims.values():
411-
_claims.extend(_cl)
412-
_provider_info["claims_supported"] = list(set(_claims))
413-
414-
_scopes = list(self.scope2claims.keys())
415-
_provider_info["scopes_supported"] = _scopes
416-
417-
# Sort order RS, ES, HS, PS
418-
sign_algs = list(SIGNER_ALGS.keys())
419-
sign_algs = sorted(sign_algs, key=cmp_to_key(sort_sign_alg))
420-
421-
for typ in ["userinfo", "id_token", "request_object"]:
422-
_provider_info["%s_signing_alg_values_supported" % typ] = sign_algs
423-
424-
# Remove 'none' for token_endpoint_auth_signing_alg_values_supported
425-
# since it is not allowed
426-
sign_algs = sign_algs[:]
427-
sign_algs.remove("none")
428-
_provider_info["token_endpoint_auth_signing_alg_values_supported"] = sign_algs
429-
430-
algs = jwe.SUPPORTED["alg"]
431-
for typ in ["userinfo", "id_token", "request_object"]:
432-
_provider_info["%s_encryption_alg_values_supported" % typ] = algs
433-
434-
encs = jwe.SUPPORTED["enc"]
435-
for typ in ["userinfo", "id_token", "request_object"]:
436-
_provider_info["%s_encryption_enc_values_supported" % typ] = encs
437-
438-
# acr_values
439-
if self.authn_broker:
440-
acr_values = self.authn_broker.get_acr_value_string()
441-
if acr_values is not None:
442-
_provider_info["acr_values_supported"] = acr_values
443-
444-
return _provider_info
445-
446-
@staticmethod
447-
def get_wanted(key, val):
448-
if isinstance(val, str):
449-
_wanted = {val}
450-
else:
451-
try:
452-
_wanted = set(val)
453-
except TypeError:
454-
if key == "response_types_supported":
455-
_wanted = set()
456-
for _v in val:
457-
_v.sort()
458-
_wanted.add(" ".join(_v))
459-
else:
460-
raise
461-
else:
462-
_wanted = set()
463-
for _v in val:
464-
_vals = _v.split(" ")
465-
_vals.sort()
466-
_wanted.add(" ".join(_vals))
467-
return _wanted
468-
469-
def verify_allowed(self, allowed, key, val, pinfo, not_supported):
470-
if isinstance(allowed, bool):
471-
if allowed is False:
472-
if val is True:
473-
not_supported[key] = True
474-
else:
475-
pinfo[key] = val
476-
elif isinstance(allowed, str):
477-
if val != allowed:
478-
not_supported[key] = val
479-
elif isinstance(allowed, list):
480-
_wanted = self.get_wanted(key, val)
481-
482-
_allowed = set(allowed)
483-
484-
if (_wanted & _allowed) == _wanted:
485-
pinfo[key] = list(_wanted)
486-
else:
487-
not_supported[key] = list(_wanted - _allowed)
350+
def claims_supported(self):
351+
_claims = set()
352+
for scope, claims in self.scope2claims.items():
353+
_claims.update(set(claims))
354+
return list(_claims)
488355

489356
def create_providerinfo(self, capabilities):
490357
"""
@@ -494,28 +361,20 @@ def create_providerinfo(self, capabilities):
494361
:return:
495362
"""
496363

497-
_pinfo = self.package_capabilities()
498-
not_supported = {}
499-
for key, val in capabilities.items():
500-
try:
501-
allowed = _pinfo[key]
502-
except KeyError:
503-
_pinfo[key] = val
504-
else:
505-
self.verify_allowed(allowed, key, val, _pinfo, not_supported)
364+
_provider_info = capabilities
365+
_provider_info["issuer"] = self.issuer
366+
_provider_info["version"] = "3.0"
506367

507-
if not_supported:
508-
_msg = "Server doesn't support the following features: {}".format(
509-
not_supported
510-
)
511-
logger.error(_msg)
512-
raise ConfigurationError(_msg)
368+
# acr_values
369+
if self.authn_broker:
370+
acr_values = self.authn_broker.get_acr_values()
371+
if acr_values is not None:
372+
_provider_info["acr_values_supported"] = acr_values
513373

514374
if self.jwks_uri and self.keyjar:
515-
_pinfo["jwks_uri"] = self.jwks_uri
375+
_provider_info["jwks_uri"] = self.jwks_uri
516376

517-
for name, instance in self.endpoint.items():
518-
if name not in ["webfinger", "provider_info"]:
519-
_pinfo[instance.endpoint_name] = instance.full_path
377+
_provider_info.update(self.idtoken.provider_info)
378+
_provider_info['claims_supported'] = self.claims_supported()
520379

521-
return _pinfo
380+
return _provider_info

0 commit comments

Comments
 (0)