1- import copy
21import logging
32import os
4- from functools import cmp_to_key
53
64import requests
7- from cryptojwt import jwe
8- from cryptojwt .jws .jws import SIGNER_ALGS
95from cryptojwt .key_jar import KeyJar
106from cryptojwt .key_jar import init_key_jar
117from jinja2 import Environment
1612from oidcendpoint import rndstr
1713from oidcendpoint import util
1814from oidcendpoint .client_authn import CLIENT_AUTHN_METHOD
19- from oidcendpoint .exception import ConfigurationError
2015from oidcendpoint .id_token import IDToken
2116from oidcendpoint .session import create_session_db
2217from oidcendpoint .sso_db import SSODb
2823
2924logger = logging .getLogger (__name__ )
3025
31- CAPABILITIES = {
32- "response_types_supported" : [
33- "code" ,
34- "token" ,
35- "id_token" ,
36- "code token" ,
37- "code id_token" ,
38- "id_token token" ,
39- "code id_token token" ,
40- "none" ,
41- ],
42- "token_endpoint_auth_methods_supported" : [
43- "client_secret_post" ,
44- "client_secret_basic" ,
45- "client_secret_jwt" ,
46- "private_key_jwt" ,
47- ],
48- "response_modes_supported" : ["query" , "fragment" , "form_post" ],
49- "subject_types_supported" : ["public" , "pairwise" ],
50- "grant_types_supported" : [
51- "authorization_code" ,
52- "implicit" ,
53- "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
54- "refresh_token" ,
55- ],
56- "claim_types_supported" : ["normal" , "aggregated" , "distributed" ],
57- "claims_parameter_supported" : True ,
58- "request_parameter_supported" : True ,
59- "request_uri_parameter_supported" : True ,
60- }
61-
62- SORT_ORDER = {"RS" : 0 , "ES" : 1 , "HS" : 2 , "PS" : 3 , "no" : 4 }
63-
64-
65- def sort_sign_alg (alg1 , alg2 ):
66- if SORT_ORDER [alg1 [0 :2 ]] < SORT_ORDER [alg2 [0 :2 ]]:
67- return - 1
68-
69- if SORT_ORDER [alg1 [0 :2 ]] > SORT_ORDER [alg2 [0 :2 ]]:
70- return 1
71-
72- if alg1 < alg2 :
73- return - 1
74-
75- if alg1 > alg2 :
76- return 1
77-
78- return 0
79-
8026
8127def add_path (url , path ):
8228 if url .endswith ("/" ):
@@ -388,9 +334,9 @@ def do_endpoints(self, conf):
388334 if endpoint in ["webfinger" , "provider_info" ]:
389335 continue
390336
391- _cap [ endpoint_instance .endpoint_name ] = "{}" . format (
392- endpoint_instance .endpoint_path
393- )
337+ if endpoint_instance .endpoint_name :
338+ _cap [ endpoint_instance .endpoint_name ] = endpoint_instance . full_path
339+
394340 return _cap
395341
396342 def do_authz (self , conf ):
@@ -401,90 +347,11 @@ def do_authz(self, conf):
401347 else :
402348 self .authz = init_service (authz_spec , self )
403349
404- def package_capabilities (self ):
405- _provider_info = copy .deepcopy (CAPABILITIES )
406- _provider_info ["issuer" ] = self .issuer
407- _provider_info ["version" ] = "3.0"
408-
409- _claims = []
410- for _cl in self .scope2claims .values ():
411- _claims .extend (_cl )
412- _provider_info ["claims_supported" ] = list (set (_claims ))
413-
414- _scopes = list (self .scope2claims .keys ())
415- _provider_info ["scopes_supported" ] = _scopes
416-
417- # Sort order RS, ES, HS, PS
418- sign_algs = list (SIGNER_ALGS .keys ())
419- sign_algs = sorted (sign_algs , key = cmp_to_key (sort_sign_alg ))
420-
421- for typ in ["userinfo" , "id_token" , "request_object" ]:
422- _provider_info ["%s_signing_alg_values_supported" % typ ] = sign_algs
423-
424- # Remove 'none' for token_endpoint_auth_signing_alg_values_supported
425- # since it is not allowed
426- sign_algs = sign_algs [:]
427- sign_algs .remove ("none" )
428- _provider_info ["token_endpoint_auth_signing_alg_values_supported" ] = sign_algs
429-
430- algs = jwe .SUPPORTED ["alg" ]
431- for typ in ["userinfo" , "id_token" , "request_object" ]:
432- _provider_info ["%s_encryption_alg_values_supported" % typ ] = algs
433-
434- encs = jwe .SUPPORTED ["enc" ]
435- for typ in ["userinfo" , "id_token" , "request_object" ]:
436- _provider_info ["%s_encryption_enc_values_supported" % typ ] = encs
437-
438- # acr_values
439- if self .authn_broker :
440- acr_values = self .authn_broker .get_acr_value_string ()
441- if acr_values is not None :
442- _provider_info ["acr_values_supported" ] = acr_values
443-
444- return _provider_info
445-
446- @staticmethod
447- def get_wanted (key , val ):
448- if isinstance (val , str ):
449- _wanted = {val }
450- else :
451- try :
452- _wanted = set (val )
453- except TypeError :
454- if key == "response_types_supported" :
455- _wanted = set ()
456- for _v in val :
457- _v .sort ()
458- _wanted .add (" " .join (_v ))
459- else :
460- raise
461- else :
462- _wanted = set ()
463- for _v in val :
464- _vals = _v .split (" " )
465- _vals .sort ()
466- _wanted .add (" " .join (_vals ))
467- return _wanted
468-
469- def verify_allowed (self , allowed , key , val , pinfo , not_supported ):
470- if isinstance (allowed , bool ):
471- if allowed is False :
472- if val is True :
473- not_supported [key ] = True
474- else :
475- pinfo [key ] = val
476- elif isinstance (allowed , str ):
477- if val != allowed :
478- not_supported [key ] = val
479- elif isinstance (allowed , list ):
480- _wanted = self .get_wanted (key , val )
481-
482- _allowed = set (allowed )
483-
484- if (_wanted & _allowed ) == _wanted :
485- pinfo [key ] = list (_wanted )
486- else :
487- not_supported [key ] = list (_wanted - _allowed )
350+ def claims_supported (self ):
351+ _claims = set ()
352+ for scope , claims in self .scope2claims .items ():
353+ _claims .update (set (claims ))
354+ return list (_claims )
488355
489356 def create_providerinfo (self , capabilities ):
490357 """
@@ -494,28 +361,20 @@ def create_providerinfo(self, capabilities):
494361 :return:
495362 """
496363
497- _pinfo = self .package_capabilities ()
498- not_supported = {}
499- for key , val in capabilities .items ():
500- try :
501- allowed = _pinfo [key ]
502- except KeyError :
503- _pinfo [key ] = val
504- else :
505- self .verify_allowed (allowed , key , val , _pinfo , not_supported )
364+ _provider_info = capabilities
365+ _provider_info ["issuer" ] = self .issuer
366+ _provider_info ["version" ] = "3.0"
506367
507- if not_supported :
508- _msg = "Server doesn't support the following features: {}" .format (
509- not_supported
510- )
511- logger .error (_msg )
512- raise ConfigurationError (_msg )
368+ # acr_values
369+ if self .authn_broker :
370+ acr_values = self .authn_broker .get_acr_values ()
371+ if acr_values is not None :
372+ _provider_info ["acr_values_supported" ] = acr_values
513373
514374 if self .jwks_uri and self .keyjar :
515- _pinfo ["jwks_uri" ] = self .jwks_uri
375+ _provider_info ["jwks_uri" ] = self .jwks_uri
516376
517- for name , instance in self .endpoint .items ():
518- if name not in ["webfinger" , "provider_info" ]:
519- _pinfo [instance .endpoint_name ] = instance .full_path
377+ _provider_info .update (self .idtoken .provider_info )
378+ _provider_info ['claims_supported' ] = self .claims_supported ()
520379
521- return _pinfo
380+ return _provider_info
0 commit comments