|
1 | | -from __future__ import annotations |
2 | | - |
3 | 1 | import asyncio |
| 2 | +import re |
4 | 3 | from typing import cast |
5 | 4 |
|
6 | 5 | import pytest |
| 6 | +from typing_extensions import Annotated |
7 | 7 |
|
8 | 8 | from aws_lambda_powertools.event_handler import content_types |
9 | 9 | from aws_lambda_powertools.event_handler.api_gateway import ( |
|
13 | 13 | BaseRouter, |
14 | 14 | ProxyEventType, |
15 | 15 | Response, |
| 16 | + Route, |
16 | 17 | ) |
| 18 | +from aws_lambda_powertools.event_handler.depends import Depends |
17 | 19 | from aws_lambda_powertools.event_handler.middlewares.async_utils import _registered_api_adapter_async |
| 20 | +from aws_lambda_powertools.event_handler.request import Request |
18 | 21 | from tests.functional.utils import load_event |
19 | 22 |
|
20 | 23 | API_REST_EVENT = load_event("apiGatewayProxyEvent.json") |
@@ -212,3 +215,121 @@ async def get_lambda(): |
212 | 215 | # THEN the adapter skips request injection and dependency resolution |
213 | 216 | assert result.status_code == 200 |
214 | 217 | assert result.body == "no route" |
| 218 | + |
| 219 | + |
| 220 | +def test_adapter_injects_request_param(): |
| 221 | + # GIVEN an async handler that declares a Request parameter |
| 222 | + app = APIGatewayHttpResolver() |
| 223 | + |
| 224 | + async def get_lambda(request: Request): |
| 225 | + return Response(200, content_types.TEXT_HTML, request.method) |
| 226 | + |
| 227 | + # WHEN a Route is present in context with request_param_name not yet checked |
| 228 | + _setup_resolver_context(app, API_RESTV2_EVENT) |
| 229 | + route = Route( |
| 230 | + method="GET", |
| 231 | + path="/my/path", |
| 232 | + rule=re.compile(r"^/my/path$"), |
| 233 | + func=get_lambda, |
| 234 | + cors=False, |
| 235 | + compress=False, |
| 236 | + ) |
| 237 | + app.append_context(_route=route, _route_args={}) |
| 238 | + |
| 239 | + result = asyncio.run( |
| 240 | + _registered_api_adapter_async(app, get_lambda), |
| 241 | + ) |
| 242 | + |
| 243 | + # THEN the Request object is injected and request_param_name is cached |
| 244 | + assert result.status_code == 200 |
| 245 | + assert route.request_param_name_checked is True |
| 246 | + assert route.request_param_name == "request" |
| 247 | + |
| 248 | + |
| 249 | +def test_adapter_uses_cached_request_param_name(): |
| 250 | + # GIVEN a Route where request_param_name was already resolved |
| 251 | + app = APIGatewayHttpResolver() |
| 252 | + |
| 253 | + async def get_lambda(req: Request): |
| 254 | + return Response(200, content_types.TEXT_HTML, req.method) |
| 255 | + |
| 256 | + _setup_resolver_context(app, API_RESTV2_EVENT) |
| 257 | + route = Route( |
| 258 | + method="GET", |
| 259 | + path="/my/path", |
| 260 | + rule=re.compile(r"^/my/path$"), |
| 261 | + func=get_lambda, |
| 262 | + cors=False, |
| 263 | + compress=False, |
| 264 | + ) |
| 265 | + route.request_param_name = "req" |
| 266 | + route.request_param_name_checked = True |
| 267 | + app.append_context(_route=route, _route_args={}) |
| 268 | + |
| 269 | + # WHEN calling the adapter a second time (cache hit) |
| 270 | + result = asyncio.run( |
| 271 | + _registered_api_adapter_async(app, get_lambda), |
| 272 | + ) |
| 273 | + |
| 274 | + # THEN it still injects the Request using the cached param name |
| 275 | + assert result.status_code == 200 |
| 276 | + |
| 277 | + |
| 278 | +def test_adapter_resolves_dependencies(): |
| 279 | + # GIVEN an async handler with Depends() parameters |
| 280 | + app = APIGatewayHttpResolver() |
| 281 | + |
| 282 | + def get_greeting() -> str: |
| 283 | + return "hello" |
| 284 | + |
| 285 | + async def get_lambda(greeting: Annotated[str, Depends(get_greeting)]): |
| 286 | + return {"greeting": greeting} |
| 287 | + |
| 288 | + _setup_resolver_context(app, API_RESTV2_EVENT) |
| 289 | + route = Route( |
| 290 | + method="GET", |
| 291 | + path="/my/path", |
| 292 | + rule=re.compile(r"^/my/path$"), |
| 293 | + func=get_lambda, |
| 294 | + cors=False, |
| 295 | + compress=False, |
| 296 | + ) |
| 297 | + app.append_context(_route=route, _route_args={}) |
| 298 | + |
| 299 | + # WHEN calling the adapter |
| 300 | + result = asyncio.run( |
| 301 | + _registered_api_adapter_async(app, get_lambda), |
| 302 | + ) |
| 303 | + |
| 304 | + # THEN dependencies are resolved and injected |
| 305 | + assert result.status_code == 200 |
| 306 | + |
| 307 | + |
| 308 | +def test_adapter_resolves_dependencies_with_sync_handler(): |
| 309 | + # GIVEN a sync handler with Depends() parameters |
| 310 | + app = APIGatewayHttpResolver() |
| 311 | + |
| 312 | + def get_greeting() -> str: |
| 313 | + return "hello" |
| 314 | + |
| 315 | + def get_lambda(greeting: Annotated[str, Depends(get_greeting)]): |
| 316 | + return {"greeting": greeting} |
| 317 | + |
| 318 | + _setup_resolver_context(app, API_RESTV2_EVENT) |
| 319 | + route = Route( |
| 320 | + method="GET", |
| 321 | + path="/my/path", |
| 322 | + rule=re.compile(r"^/my/path$"), |
| 323 | + func=get_lambda, |
| 324 | + cors=False, |
| 325 | + compress=False, |
| 326 | + ) |
| 327 | + app.append_context(_route=route, _route_args={}) |
| 328 | + |
| 329 | + # WHEN calling the adapter with a sync handler that has dependencies |
| 330 | + result = asyncio.run( |
| 331 | + _registered_api_adapter_async(app, get_lambda), |
| 332 | + ) |
| 333 | + |
| 334 | + # THEN dependencies are resolved and injected for sync handler too |
| 335 | + assert result.status_code == 200 |
0 commit comments