|
17 | 17 | from typing_extensions import override |
18 | 18 |
|
19 | 19 | from aws_lambda_powertools.event_handler import content_types |
| 20 | +from aws_lambda_powertools.event_handler.exception_handling import ExceptionHandlerManager |
20 | 21 | from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError |
21 | 22 | from aws_lambda_powertools.event_handler.openapi.config import OpenAPIConfig |
22 | 23 | from aws_lambda_powertools.event_handler.openapi.constants import ( |
@@ -1576,6 +1577,7 @@ def __init__( |
1576 | 1577 | self.processed_stack_frames = [] |
1577 | 1578 | self._response_builder_class = ResponseBuilder[BaseProxyEvent] |
1578 | 1579 | self.openapi_config = OpenAPIConfig() # starting an empty dataclass |
| 1580 | + self.exception_handler_manager = ExceptionHandlerManager() |
1579 | 1581 | self._has_response_validation_error = response_validation_error_http_code is not None |
1580 | 1582 | self._response_validation_error_http_code = self._validate_response_validation_error_http_code( |
1581 | 1583 | response_validation_error_http_code, |
@@ -2498,7 +2500,7 @@ def not_found_handler(): |
2498 | 2500 | return Response(status_code=204, content_type=None, headers=_headers, body="") |
2499 | 2501 |
|
2500 | 2502 | # Customer registered 404 route? Call it. |
2501 | | - custom_not_found_handler = self._lookup_exception_handler(NotFoundError) |
| 2503 | + custom_not_found_handler = self.exception_handler_manager.lookup_exception_handler(NotFoundError) |
2502 | 2504 | if custom_not_found_handler: |
2503 | 2505 | return custom_not_found_handler(NotFoundError()) |
2504 | 2506 |
|
@@ -2571,26 +2573,10 @@ def not_found(self, func: Callable | None = None): |
2571 | 2573 | return self.exception_handler(NotFoundError)(func) |
2572 | 2574 |
|
2573 | 2575 | def exception_handler(self, exc_class: type[Exception] | list[type[Exception]]): |
2574 | | - def register_exception_handler(func: Callable): |
2575 | | - if isinstance(exc_class, list): # pragma: no cover |
2576 | | - for exp in exc_class: |
2577 | | - self._exception_handlers[exp] = func |
2578 | | - else: |
2579 | | - self._exception_handlers[exc_class] = func |
2580 | | - return func |
2581 | | - |
2582 | | - return register_exception_handler |
2583 | | - |
2584 | | - def _lookup_exception_handler(self, exp_type: type) -> Callable | None: |
2585 | | - # Use "Method Resolution Order" to allow for matching against a base class |
2586 | | - # of an exception |
2587 | | - for cls in exp_type.__mro__: |
2588 | | - if cls in self._exception_handlers: |
2589 | | - return self._exception_handlers[cls] |
2590 | | - return None |
| 2576 | + return self.exception_handler_manager.exception_handler(exc_class=exc_class) |
2591 | 2577 |
|
2592 | 2578 | def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuilder | None: |
2593 | | - handler = self._lookup_exception_handler(type(exp)) |
| 2579 | + handler = self.exception_handler_manager.lookup_exception_handler(type(exp)) |
2594 | 2580 | if handler: |
2595 | 2581 | try: |
2596 | 2582 | return self._response_builder_class(response=handler(exp), serializer=self._serializer, route=route) |
@@ -2686,7 +2672,7 @@ def include_router(self, router: Router, prefix: str | None = None) -> None: |
2686 | 2672 | self._router_middlewares = self._router_middlewares + router._router_middlewares |
2687 | 2673 |
|
2688 | 2674 | logger.debug("Appending Router exception_handler into App exception_handler.") |
2689 | | - self._exception_handlers.update(router._exception_handlers) |
| 2675 | + self.exception_handler_manager.update_exception_handlers(router._exception_handlers) |
2690 | 2676 |
|
2691 | 2677 | # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) |
2692 | 2678 | router.context = self.context |
|
0 commit comments