1- import copy
21import logging
32import os
4- from functools import cmp_to_key
53
64import requests
7- from cryptojwt import jwe
8- from cryptojwt .jws .jws import SIGNER_ALGS
95from cryptojwt .key_jar import KeyJar
106from cryptojwt .key_jar import init_key_jar
117from jinja2 import Environment
1612from oidcendpoint import rndstr
1713from oidcendpoint import util
1814from oidcendpoint .client_authn import CLIENT_AUTHN_METHOD
19- from oidcendpoint .exception import ConfigurationError
2015from oidcendpoint .id_token import IDToken
2116from oidcendpoint .session import create_session_db
2217from oidcendpoint .sso_db import SSODb
2823
2924LOGGER = logging .getLogger (__name__ )
3025
31- CAPABILITIES = {
32- "response_types_supported" : [
33- "code" ,
34- "token" ,
35- "id_token" ,
36- "code token" ,
37- "code id_token" ,
38- "id_token token" ,
39- "code id_token token" ,
40- "none" ,
41- ],
42- "token_endpoint_auth_methods_supported" : [
43- "client_secret_post" ,
44- "client_secret_basic" ,
45- "client_secret_jwt" ,
46- "private_key_jwt" ,
47- ],
48- "response_modes_supported" : ["query" , "fragment" , "form_post" ],
49- "subject_types_supported" : ["public" , "pairwise" ],
50- "grant_types_supported" : [
51- "authorization_code" ,
52- "implicit" ,
53- "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
54- "refresh_token" ,
55- ],
56- "claim_types_supported" : ["normal" , "aggregated" , "distributed" ],
57- "claims_parameter_supported" : True ,
58- "request_parameter_supported" : True ,
59- "request_uri_parameter_supported" : True ,
60- }
61-
62- SORT_ORDER = {"RS" : 0 , "ES" : 1 , "HS" : 2 , "PS" : 3 , "no" : 4 }
63-
64-
65- def sort_sign_alg (alg1 , alg2 ):
66- if SORT_ORDER [alg1 [0 :2 ]] < SORT_ORDER [alg2 [0 :2 ]]:
67- return - 1
68-
69- if SORT_ORDER [alg1 [0 :2 ]] > SORT_ORDER [alg2 [0 :2 ]]:
70- return 1
71-
72- if alg1 < alg2 :
73- return - 1
74-
75- if alg1 > alg2 :
76- return 1
77-
78- return 0
79-
8026
8127def add_path (url , path ):
8228 if url .endswith ("/" ):
@@ -375,9 +321,9 @@ def do_endpoints(self, conf):
375321 if endpoint in ["webfinger" , "provider_info" ]:
376322 continue
377323
378- _cap [ endpoint_instance .endpoint_name ] = "{}" . format (
379- endpoint_instance .endpoint_path
380- )
324+ if endpoint_instance .endpoint_name :
325+ _cap [ endpoint_instance .endpoint_name ] = endpoint_instance . full_path
326+
381327 return _cap
382328
383329 def do_authz (self , conf ):
@@ -388,90 +334,11 @@ def do_authz(self, conf):
388334 else :
389335 self .authz = init_service (authz_spec , self )
390336
391- def package_capabilities (self ):
392- _provider_info = copy .deepcopy (CAPABILITIES )
393- _provider_info ["issuer" ] = self .issuer
394- _provider_info ["version" ] = "3.0"
395-
396- _claims = []
397- for _cl in self .scope2claims .values ():
398- _claims .extend (_cl )
399- _provider_info ["claims_supported" ] = list (set (_claims ))
400-
401- _scopes = list (self .scope2claims .keys ())
402- _provider_info ["scopes_supported" ] = _scopes
403-
404- # Sort order RS, ES, HS, PS
405- sign_algs = list (SIGNER_ALGS .keys ())
406- sign_algs = sorted (sign_algs , key = cmp_to_key (sort_sign_alg ))
407-
408- for typ in ["userinfo" , "id_token" , "request_object" ]:
409- _provider_info ["%s_signing_alg_values_supported" % typ ] = sign_algs
410-
411- # Remove 'none' for token_endpoint_auth_signing_alg_values_supported
412- # since it is not allowed
413- sign_algs = sign_algs [:]
414- sign_algs .remove ("none" )
415- _provider_info ["token_endpoint_auth_signing_alg_values_supported" ] = sign_algs
416-
417- algs = jwe .SUPPORTED ["alg" ]
418- for typ in ["userinfo" , "id_token" , "request_object" ]:
419- _provider_info ["%s_encryption_alg_values_supported" % typ ] = algs
420-
421- encs = jwe .SUPPORTED ["enc" ]
422- for typ in ["userinfo" , "id_token" , "request_object" ]:
423- _provider_info ["%s_encryption_enc_values_supported" % typ ] = encs
424-
425- # acr_values
426- if self .authn_broker :
427- acr_values = self .authn_broker .get_acr_value_string ()
428- if acr_values is not None :
429- _provider_info ["acr_values_supported" ] = acr_values
430-
431- return _provider_info
432-
433- @staticmethod
434- def get_wanted (key , val ):
435- if isinstance (val , str ):
436- _wanted = {val }
437- else :
438- try :
439- _wanted = set (val )
440- except TypeError :
441- if key == "response_types_supported" :
442- _wanted = set ()
443- for _v in val :
444- _v .sort ()
445- _wanted .add (" " .join (_v ))
446- else :
447- raise
448- else :
449- _wanted = set ()
450- for _v in val :
451- _vals = _v .split (" " )
452- _vals .sort ()
453- _wanted .add (" " .join (_vals ))
454- return _wanted
455-
456- def verify_allowed (self , allowed , key , val , pinfo , not_supported ):
457- if isinstance (allowed , bool ):
458- if allowed is False :
459- if val is True :
460- not_supported [key ] = True
461- else :
462- pinfo [key ] = val
463- elif isinstance (allowed , str ):
464- if val != allowed :
465- not_supported [key ] = val
466- elif isinstance (allowed , list ):
467- _wanted = self .get_wanted (key , val )
468-
469- _allowed = set (allowed )
470-
471- if (_wanted & _allowed ) == _wanted :
472- pinfo [key ] = list (_wanted )
473- else :
474- not_supported [key ] = list (_wanted - _allowed )
337+ def claims_supported (self ):
338+ _claims = set ()
339+ for scope , claims in self .scope2claims .items ():
340+ _claims .update (set (claims ))
341+ return list (_claims )
475342
476343 def create_providerinfo (self , capabilities ):
477344 """
@@ -481,28 +348,20 @@ def create_providerinfo(self, capabilities):
481348 :return:
482349 """
483350
484- _pinfo = self .package_capabilities ()
485- not_supported = {}
486- for key , val in capabilities .items ():
487- try :
488- allowed = _pinfo [key ]
489- except KeyError :
490- _pinfo [key ] = val
491- else :
492- self .verify_allowed (allowed , key , val , _pinfo , not_supported )
351+ _provider_info = capabilities
352+ _provider_info ["issuer" ] = self .issuer
353+ _provider_info ["version" ] = "3.0"
493354
494- if not_supported :
495- _msg = "Server doesn't support the following features: {}" .format (
496- not_supported
497- )
498- LOGGER .error (_msg )
499- raise ConfigurationError (_msg )
355+ # acr_values
356+ if self .authn_broker :
357+ acr_values = self .authn_broker .get_acr_values ()
358+ if acr_values is not None :
359+ _provider_info ["acr_values_supported" ] = acr_values
500360
501361 if self .jwks_uri and self .keyjar :
502- _pinfo ["jwks_uri" ] = self .jwks_uri
362+ _provider_info ["jwks_uri" ] = self .jwks_uri
503363
504- for name , instance in self .endpoint .items ():
505- if name not in ["webfinger" , "provider_info" ]:
506- _pinfo [instance .endpoint_name ] = instance .full_path
364+ _provider_info .update (self .idtoken .provider_info )
365+ _provider_info ['claims_supported' ] = self .claims_supported ()
507366
508- return _pinfo
367+ return _provider_info
0 commit comments