Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit 006b11d

Browse files
committed
Defining capabilities not anymore done in one central place.
Most of it is done per endpoint. Defined default capabilities per endpoint. Some capabilities are computed at start time since they are configuration dependent. Refactored. Bumped version.
1 parent 2a02313 commit 006b11d

22 files changed

Lines changed: 531 additions & 319 deletions

src/oidcendpoint/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
except ImportError:
77
import random as rnd
88

9-
__version__ = "0.10.1"
9+
__version__ = "0.11.0"
1010

1111

1212
DEF_SIGN_ALG = {

src/oidcendpoint/cookie.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
__author__ = "Roland Hedberg"
2626

27-
logger = logging.getLogger(__name__)
27+
LOGGER = logging.getLogger(__name__)
2828

2929
CORS_HEADERS = [
3030
("Access-Control-Allow-Origin", "*"),
@@ -466,12 +466,15 @@ def append_cookie(
466466

467467
def compute_session_state(opbs, salt, client_id, redirect_uri):
468468
"""
469+
Computes a session state value.
470+
This value is later used during session management to check whether
471+
the log in state has changed.
469472
470-
:param opbs:
473+
:param opbs: Cookie value
471474
:param salt:
472475
:param client_id:
473476
:param redirect_uri:
474-
:return:
477+
:return: Session state value
475478
"""
476479
parsed_uri = urlparse(redirect_uri)
477480
rp_origin_url = "{uri.scheme}://{uri.netloc}".format(uri=parsed_uri)

src/oidcendpoint/endpoint.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
2+
from functools import cmp_to_key
23
from urllib.parse import urlparse
34

5+
from cryptojwt import jwe
6+
from cryptojwt.jws.jws import SIGNER_ALGS
47
from oidcmsg.exception import MissingRequiredAttribute
58
from oidcmsg.exception import MissingRequiredValue
69
from oidcmsg.message import Message
@@ -51,6 +54,72 @@ def fragment_encoding(return_type):
5154
return True
5255

5356

57+
ALG_SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4}
58+
59+
60+
def sort_sign_alg(alg1, alg2):
61+
if ALG_SORT_ORDER[alg1[0:2]] < ALG_SORT_ORDER[alg2[0:2]]:
62+
return -1
63+
64+
if ALG_SORT_ORDER[alg1[0:2]] > ALG_SORT_ORDER[alg2[0:2]]:
65+
return 1
66+
67+
if alg1 < alg2:
68+
return -1
69+
70+
if alg1 > alg2:
71+
return 1
72+
73+
return 0
74+
75+
76+
def assign_algorithms(typ):
77+
if typ == "signing_alg":
78+
# Pick supported signing algorithms from crypto library
79+
# Sort order RS, ES, HS, PS
80+
sign_algs = list(SIGNER_ALGS.keys())
81+
return sorted(sign_algs, key=cmp_to_key(sort_sign_alg))
82+
elif typ == "encryption_alg":
83+
return jwe.SUPPORTED["alg"]
84+
elif typ == "encryption_enc":
85+
return jwe.SUPPORTED["enc"]
86+
87+
88+
def construct_provider_info(default_capabilities, **kwargs):
89+
if default_capabilities is not None:
90+
provider_info = {}
91+
for attr, default_val in default_capabilities.items():
92+
try:
93+
_proposal = kwargs[attr]
94+
except KeyError:
95+
if default_val is not None:
96+
provider_info[attr] = default_val
97+
elif "signing_alg_values_supported" in attr:
98+
provider_info[attr] = assign_algorithms("signing_alg")
99+
elif "encryption_alg_values_supported" in attr:
100+
provider_info[attr] = assign_algorithms("encryption_alg")
101+
elif "encryption_enc_values_supported" in attr:
102+
provider_info[attr] = assign_algorithms("encryption_enc")
103+
else:
104+
_permitted = None
105+
106+
if "signing_alg_values_supported" in attr:
107+
_permitted = set(assign_algorithms("signing_alg"))
108+
elif "encryption_alg_values_supported" in attr:
109+
_permitted = set(assign_algorithms("encryption_alg"))
110+
elif "encryption_enc_values_supported" in attr:
111+
_permitted = set(assign_algorithms("encryption_enc"))
112+
113+
if _permitted and not _permitted.issuperset(set(_proposal)):
114+
raise ValueError(
115+
"Proposed set of values outside set of permitted ({})".__format__(attr))
116+
117+
provider_info[attr] = _proposal
118+
return provider_info
119+
else:
120+
return None
121+
122+
54123
class Endpoint(object):
55124
request_cls = Message
56125
response_cls = Message
@@ -63,6 +132,7 @@ class Endpoint(object):
63132
response_format = "json"
64133
response_placement = "body"
65134
client_authn_method = ""
135+
default_capabilities = None
66136

67137
def __init__(self, endpoint_context, **kwargs):
68138
self.endpoint_context = endpoint_context
@@ -71,7 +141,14 @@ def __init__(self, endpoint_context, **kwargs):
71141
self.post_parse_request = []
72142
self.kwargs = kwargs
73143
self.full_path = ""
74-
self.provider_info = None
144+
145+
if "client_authn_method" in kwargs:
146+
self.client_authn_method = kwargs["client_authn_method"]
147+
elif self.default_capabilities is not None:
148+
if "client_authn_method" in self.default_capabilities:
149+
self.client_authn_method = self.default_capabilities["client_authn_method"]
150+
151+
self.provider_info = construct_provider_info(self.default_capabilities, **kwargs)
75152

76153
def parse_request(self, request, auth=None, **kwargs):
77154
"""

src/oidcendpoint/endpoint_context.py

Lines changed: 20 additions & 161 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
import copy
21
import logging
32
import os
4-
from functools import cmp_to_key
53

64
import requests
7-
from cryptojwt import jwe
8-
from cryptojwt.jws.jws import SIGNER_ALGS
95
from cryptojwt.key_jar import KeyJar
106
from cryptojwt.key_jar import init_key_jar
117
from jinja2 import Environment
@@ -16,7 +12,6 @@
1612
from oidcendpoint import rndstr
1713
from oidcendpoint import util
1814
from oidcendpoint.client_authn import CLIENT_AUTHN_METHOD
19-
from oidcendpoint.exception import ConfigurationError
2015
from oidcendpoint.id_token import IDToken
2116
from oidcendpoint.session import create_session_db
2217
from oidcendpoint.sso_db import SSODb
@@ -28,55 +23,6 @@
2823

2924
LOGGER = logging.getLogger(__name__)
3025

31-
CAPABILITIES = {
32-
"response_types_supported": [
33-
"code",
34-
"token",
35-
"id_token",
36-
"code token",
37-
"code id_token",
38-
"id_token token",
39-
"code id_token token",
40-
"none",
41-
],
42-
"token_endpoint_auth_methods_supported": [
43-
"client_secret_post",
44-
"client_secret_basic",
45-
"client_secret_jwt",
46-
"private_key_jwt",
47-
],
48-
"response_modes_supported": ["query", "fragment", "form_post"],
49-
"subject_types_supported": ["public", "pairwise"],
50-
"grant_types_supported": [
51-
"authorization_code",
52-
"implicit",
53-
"urn:ietf:params:oauth:grant-type:jwt-bearer",
54-
"refresh_token",
55-
],
56-
"claim_types_supported": ["normal", "aggregated", "distributed"],
57-
"claims_parameter_supported": True,
58-
"request_parameter_supported": True,
59-
"request_uri_parameter_supported": True,
60-
}
61-
62-
SORT_ORDER = {"RS": 0, "ES": 1, "HS": 2, "PS": 3, "no": 4}
63-
64-
65-
def sort_sign_alg(alg1, alg2):
66-
if SORT_ORDER[alg1[0:2]] < SORT_ORDER[alg2[0:2]]:
67-
return -1
68-
69-
if SORT_ORDER[alg1[0:2]] > SORT_ORDER[alg2[0:2]]:
70-
return 1
71-
72-
if alg1 < alg2:
73-
return -1
74-
75-
if alg1 > alg2:
76-
return 1
77-
78-
return 0
79-
8026

8127
def add_path(url, path):
8228
if url.endswith("/"):
@@ -375,9 +321,9 @@ def do_endpoints(self, conf):
375321
if endpoint in ["webfinger", "provider_info"]:
376322
continue
377323

378-
_cap[endpoint_instance.endpoint_name] = "{}".format(
379-
endpoint_instance.endpoint_path
380-
)
324+
if endpoint_instance.endpoint_name:
325+
_cap[endpoint_instance.endpoint_name] = endpoint_instance.full_path
326+
381327
return _cap
382328

383329
def do_authz(self, conf):
@@ -388,90 +334,11 @@ def do_authz(self, conf):
388334
else:
389335
self.authz = init_service(authz_spec, self)
390336

391-
def package_capabilities(self):
392-
_provider_info = copy.deepcopy(CAPABILITIES)
393-
_provider_info["issuer"] = self.issuer
394-
_provider_info["version"] = "3.0"
395-
396-
_claims = []
397-
for _cl in self.scope2claims.values():
398-
_claims.extend(_cl)
399-
_provider_info["claims_supported"] = list(set(_claims))
400-
401-
_scopes = list(self.scope2claims.keys())
402-
_provider_info["scopes_supported"] = _scopes
403-
404-
# Sort order RS, ES, HS, PS
405-
sign_algs = list(SIGNER_ALGS.keys())
406-
sign_algs = sorted(sign_algs, key=cmp_to_key(sort_sign_alg))
407-
408-
for typ in ["userinfo", "id_token", "request_object"]:
409-
_provider_info["%s_signing_alg_values_supported" % typ] = sign_algs
410-
411-
# Remove 'none' for token_endpoint_auth_signing_alg_values_supported
412-
# since it is not allowed
413-
sign_algs = sign_algs[:]
414-
sign_algs.remove("none")
415-
_provider_info["token_endpoint_auth_signing_alg_values_supported"] = sign_algs
416-
417-
algs = jwe.SUPPORTED["alg"]
418-
for typ in ["userinfo", "id_token", "request_object"]:
419-
_provider_info["%s_encryption_alg_values_supported" % typ] = algs
420-
421-
encs = jwe.SUPPORTED["enc"]
422-
for typ in ["userinfo", "id_token", "request_object"]:
423-
_provider_info["%s_encryption_enc_values_supported" % typ] = encs
424-
425-
# acr_values
426-
if self.authn_broker:
427-
acr_values = self.authn_broker.get_acr_value_string()
428-
if acr_values is not None:
429-
_provider_info["acr_values_supported"] = acr_values
430-
431-
return _provider_info
432-
433-
@staticmethod
434-
def get_wanted(key, val):
435-
if isinstance(val, str):
436-
_wanted = {val}
437-
else:
438-
try:
439-
_wanted = set(val)
440-
except TypeError:
441-
if key == "response_types_supported":
442-
_wanted = set()
443-
for _v in val:
444-
_v.sort()
445-
_wanted.add(" ".join(_v))
446-
else:
447-
raise
448-
else:
449-
_wanted = set()
450-
for _v in val:
451-
_vals = _v.split(" ")
452-
_vals.sort()
453-
_wanted.add(" ".join(_vals))
454-
return _wanted
455-
456-
def verify_allowed(self, allowed, key, val, pinfo, not_supported):
457-
if isinstance(allowed, bool):
458-
if allowed is False:
459-
if val is True:
460-
not_supported[key] = True
461-
else:
462-
pinfo[key] = val
463-
elif isinstance(allowed, str):
464-
if val != allowed:
465-
not_supported[key] = val
466-
elif isinstance(allowed, list):
467-
_wanted = self.get_wanted(key, val)
468-
469-
_allowed = set(allowed)
470-
471-
if (_wanted & _allowed) == _wanted:
472-
pinfo[key] = list(_wanted)
473-
else:
474-
not_supported[key] = list(_wanted - _allowed)
337+
def claims_supported(self):
338+
_claims = set()
339+
for scope, claims in self.scope2claims.items():
340+
_claims.update(set(claims))
341+
return list(_claims)
475342

476343
def create_providerinfo(self, capabilities):
477344
"""
@@ -481,28 +348,20 @@ def create_providerinfo(self, capabilities):
481348
:return:
482349
"""
483350

484-
_pinfo = self.package_capabilities()
485-
not_supported = {}
486-
for key, val in capabilities.items():
487-
try:
488-
allowed = _pinfo[key]
489-
except KeyError:
490-
_pinfo[key] = val
491-
else:
492-
self.verify_allowed(allowed, key, val, _pinfo, not_supported)
351+
_provider_info = capabilities
352+
_provider_info["issuer"] = self.issuer
353+
_provider_info["version"] = "3.0"
493354

494-
if not_supported:
495-
_msg = "Server doesn't support the following features: {}".format(
496-
not_supported
497-
)
498-
LOGGER.error(_msg)
499-
raise ConfigurationError(_msg)
355+
# acr_values
356+
if self.authn_broker:
357+
acr_values = self.authn_broker.get_acr_values()
358+
if acr_values is not None:
359+
_provider_info["acr_values_supported"] = acr_values
500360

501361
if self.jwks_uri and self.keyjar:
502-
_pinfo["jwks_uri"] = self.jwks_uri
362+
_provider_info["jwks_uri"] = self.jwks_uri
503363

504-
for name, instance in self.endpoint.items():
505-
if name not in ["webfinger", "provider_info"]:
506-
_pinfo[instance.endpoint_name] = instance.full_path
364+
_provider_info.update(self.idtoken.provider_info)
365+
_provider_info['claims_supported'] = self.claims_supported()
507366

508-
return _pinfo
367+
return _provider_info

0 commit comments

Comments
 (0)