|
| 1 | +import logging |
| 2 | + |
| 3 | +from cryptojwt.jwe.exception import JWEException |
| 4 | +from cryptojwt.jws.exception import NoSuitableSigningKeys |
| 5 | +from oidcmsg import oidc |
| 6 | +from oidcmsg.exception import MissingRequiredAttribute |
| 7 | +from oidcmsg.exception import MissingRequiredValue |
| 8 | +from oidcmsg.oauth2 import ResponseMessage |
| 9 | +from oidcmsg.oidc import AccessTokenRequest |
| 10 | +from oidcmsg.oidc import AccessTokenResponse |
| 11 | +from oidcmsg.oidc import RefreshAccessTokenRequest |
| 12 | +from oidcmsg.oidc import TokenErrorResponse |
| 13 | + |
| 14 | +from oidcendpoint import sanitize |
| 15 | +from oidcendpoint.cookie import new_cookie |
| 16 | +from oidcendpoint.endpoint import Endpoint |
| 17 | +from oidcendpoint.token_handler import AccessCodeUsed |
| 18 | +from oidcendpoint.token_handler import ExpiredToken |
| 19 | +from oidcendpoint.userinfo import by_schema |
| 20 | + |
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
| 23 | + |
| 24 | +class TokenCoop(Endpoint): |
| 25 | + request_cls = oidc.Message |
| 26 | + response_cls = oidc.AccessTokenResponse |
| 27 | + error_cls = TokenErrorResponse |
| 28 | + request_format = "json" |
| 29 | + request_placement = "body" |
| 30 | + response_format = "json" |
| 31 | + response_placement = "body" |
| 32 | + endpoint_name = "token_endpoint" |
| 33 | + name = "token" |
| 34 | + default_capabilities = {"token_endpoint_auth_signing_alg_values_supported": None} |
| 35 | + |
| 36 | + def __init__(self, endpoint_context, **kwargs): |
| 37 | + Endpoint.__init__(self, endpoint_context, **kwargs) |
| 38 | + self.post_parse_request.append(self._post_parse_request) |
| 39 | + if "client_authn_method" in kwargs: |
| 40 | + self.endpoint_info["token_endpoint_auth_methods_supported"] = kwargs[ |
| 41 | + "client_authn_method" |
| 42 | + ] |
| 43 | + |
| 44 | + def _refresh_access_token(self, req, **kwargs): |
| 45 | + _sdb = self.endpoint_context.sdb |
| 46 | + |
| 47 | + rtoken = req["refresh_token"] |
| 48 | + try: |
| 49 | + _info = _sdb.refresh_token(rtoken) |
| 50 | + except ExpiredToken: |
| 51 | + return self.error_cls( |
| 52 | + error="invalid_request", error_description="Refresh token is expired" |
| 53 | + ) |
| 54 | + |
| 55 | + return by_schema(AccessTokenResponse, **_info) |
| 56 | + |
| 57 | + def _access_token(self, req, **kwargs): |
| 58 | + _context = self.endpoint_context |
| 59 | + _sdb = _context.sdb |
| 60 | + _log_debug = logger.debug |
| 61 | + |
| 62 | + try: |
| 63 | + _access_code = req["code"].replace(" ", "+") |
| 64 | + except KeyError: # Missing code parameter - absolutely fatal |
| 65 | + return self.error_cls( |
| 66 | + error="invalid_request", error_description="Missing code" |
| 67 | + ) |
| 68 | + |
| 69 | + # Session might not exist or _access_code malformed |
| 70 | + try: |
| 71 | + _info = _sdb[_access_code] |
| 72 | + except KeyError: |
| 73 | + return self.error_cls( |
| 74 | + error="invalid_request", error_description="Code is invalid" |
| 75 | + ) |
| 76 | + |
| 77 | + _authn_req = _info["authn_req"] |
| 78 | + |
| 79 | + # assert that the code is valid |
| 80 | + if _context.sdb.is_session_revoked(_access_code): |
| 81 | + return self.error_cls( |
| 82 | + error="invalid_request", error_description="Session is revoked" |
| 83 | + ) |
| 84 | + |
| 85 | + # If redirect_uri was in the initial authorization request |
| 86 | + # verify that the one given here is the correct one. |
| 87 | + if "redirect_uri" in _authn_req: |
| 88 | + if req["redirect_uri"] != _authn_req["redirect_uri"]: |
| 89 | + return self.error_cls( |
| 90 | + error="invalid_request", error_description="redirect_uri mismatch" |
| 91 | + ) |
| 92 | + |
| 93 | + _log_debug("All checks OK") |
| 94 | + |
| 95 | + issue_refresh = False |
| 96 | + if "issue_refresh" in kwargs: |
| 97 | + issue_refresh = kwargs["issue_refresh"] |
| 98 | + |
| 99 | + # offline_access the default if nothing is specified |
| 100 | + permissions = _info.get("permission", ["offline_access"]) |
| 101 | + |
| 102 | + if "offline_access" in _authn_req["scope"] and "offline_access" in permissions: |
| 103 | + issue_refresh = True |
| 104 | + |
| 105 | + try: |
| 106 | + _info = _sdb.upgrade_to_token(_access_code, issue_refresh=issue_refresh) |
| 107 | + except AccessCodeUsed as err: |
| 108 | + logger.error("%s" % err) |
| 109 | + # Should revoke the token issued to this access code |
| 110 | + _sdb.revoke_all_tokens(_access_code) |
| 111 | + return self.error_cls( |
| 112 | + error="access_denied", error_description="Access Code already used" |
| 113 | + ) |
| 114 | + |
| 115 | + if "openid" in _authn_req["scope"]: |
| 116 | + try: |
| 117 | + _idtoken = _context.idtoken.make(req, _info, _authn_req) |
| 118 | + except (JWEException, NoSuitableSigningKeys) as err: |
| 119 | + logger.warning(str(err)) |
| 120 | + resp = TokenErrorResponse( |
| 121 | + error="invalid_request", |
| 122 | + error_description="Could not sign/encrypt id_token", |
| 123 | + ) |
| 124 | + return resp |
| 125 | + |
| 126 | + _sdb.update_by_token(_access_code, id_token=_idtoken) |
| 127 | + _info = _sdb[_access_code] |
| 128 | + |
| 129 | + return by_schema(AccessTokenResponse, **_info) |
| 130 | + |
| 131 | + def get_client_id_from_token(self, endpoint_context, token, request=None): |
| 132 | + sinfo = endpoint_context.sdb[token] |
| 133 | + return sinfo["authn_req"]["client_id"] |
| 134 | + |
| 135 | + def _access_token_post_parse_request(self, request, client_id="", **kwargs): |
| 136 | + """ |
| 137 | + This is where clients come to get their access tokens |
| 138 | +
|
| 139 | + :param request: The request |
| 140 | + :param authn: Authentication info, comes from HTTP header |
| 141 | + :returns: |
| 142 | + """ |
| 143 | + |
| 144 | + request = AccessTokenRequest(**request.to_dict()) |
| 145 | + |
| 146 | + if "state" in request: |
| 147 | + try: |
| 148 | + sinfo = self.endpoint_context.sdb[request["code"]] |
| 149 | + except KeyError: |
| 150 | + logger.error("Code not present in SessionDB") |
| 151 | + return self.error_cls(error="unauthorized_client") |
| 152 | + else: |
| 153 | + state = sinfo["authn_req"]["state"] |
| 154 | + |
| 155 | + if state != request["state"]: |
| 156 | + logger.error("State value mismatch") |
| 157 | + return self.error_cls(error="unauthorized_client") |
| 158 | + |
| 159 | + if "client_id" not in request: # Optional for access token request |
| 160 | + request["client_id"] = client_id |
| 161 | + |
| 162 | + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) |
| 163 | + |
| 164 | + return request |
| 165 | + |
| 166 | + def _refresh_token_post_parse_request(self, request, client_id="", **kwargs): |
| 167 | + """ |
| 168 | + This is where clients come to refresh their access tokens |
| 169 | +
|
| 170 | + :param request: The request |
| 171 | + :param authn: Authentication info, comes from HTTP header |
| 172 | + :returns: |
| 173 | + """ |
| 174 | + |
| 175 | + request = RefreshAccessTokenRequest(**request.to_dict()) |
| 176 | + |
| 177 | + # verify that the request message is correct |
| 178 | + try: |
| 179 | + request.verify(keyjar=self.endpoint_context.keyjar) |
| 180 | + except (MissingRequiredAttribute, ValueError, MissingRequiredValue) as err: |
| 181 | + return self.error_cls(error="invalid_request", error_description="%s" % err) |
| 182 | + |
| 183 | + try: |
| 184 | + keyjar = self.endpoint_context.keyjar |
| 185 | + except AttributeError: |
| 186 | + keyjar = "" |
| 187 | + |
| 188 | + request.verify(keyjar=keyjar, opponent_id=client_id) |
| 189 | + |
| 190 | + if "client_id" not in request: # Optional for refresh access token request |
| 191 | + request["client_id"] = client_id |
| 192 | + |
| 193 | + logger.debug("%s: %s" % (request.__class__.__name__, sanitize(request))) |
| 194 | + |
| 195 | + return request |
| 196 | + |
| 197 | + def _post_parse_request(self, request, client_id="", **kwargs): |
| 198 | + if request["grant_type"] == "authorization_code": |
| 199 | + return self._access_token_post_parse_request(request, client_id, **kwargs) |
| 200 | + else: # request["grant_type"] == "refresh_token": |
| 201 | + return self._refresh_token_post_parse_request(request, client_id, **kwargs) |
| 202 | + |
| 203 | + def process_request(self, request=None, **kwargs): |
| 204 | + """ |
| 205 | +
|
| 206 | + :param request: |
| 207 | + :param kwargs: |
| 208 | + :return: Dictionary with response information |
| 209 | + """ |
| 210 | + try: |
| 211 | + if request["grant_type"] == "authorization_code": |
| 212 | + logger.debug("Access Token Request") |
| 213 | + response_args = self._access_token(request, **kwargs) |
| 214 | + elif request["grant_type"] == "refresh_token": |
| 215 | + logger.debug("Refresh Access Token Request") |
| 216 | + response_args = self._refresh_access_token(request, **kwargs) |
| 217 | + else: |
| 218 | + return self.error_cls( |
| 219 | + error="invalid_request", error_description="Wrong grant_type" |
| 220 | + ) |
| 221 | + except JWEException as err: |
| 222 | + return self.error_cls(error="invalid_request", error_description="%s" % err) |
| 223 | + |
| 224 | + if isinstance(response_args, ResponseMessage): |
| 225 | + return response_args |
| 226 | + |
| 227 | + if request["grant_type"] == "authorization_code": |
| 228 | + _token = request["code"].replace(" ", "+") |
| 229 | + else: |
| 230 | + _token = request["refresh_token"].replace(" ", "+") |
| 231 | + |
| 232 | + _cookie = new_cookie( |
| 233 | + self.endpoint_context, sub=self.endpoint_context.sdb[_token]["sub"] |
| 234 | + ) |
| 235 | + |
| 236 | + _headers = [("Content-type", "application/json")] |
| 237 | + resp = {"response_args": response_args, "http_headers": _headers} |
| 238 | + if _cookie: |
| 239 | + resp["cookie"] = _cookie |
| 240 | + return resp |
0 commit comments