11import asyncio
22
3+ import pytest
4+
35from aws_lambda_powertools .event_handler import content_types
46from aws_lambda_powertools .event_handler .api_gateway import (
57 ApiGatewayResolver ,
68 ProxyEventType ,
79 Response ,
810)
911from aws_lambda_powertools .event_handler .middlewares import NextMiddleware
10- from aws_lambda_powertools .event_handler .middlewares .async_utils import AsyncMiddlewareFrame
12+ from aws_lambda_powertools .event_handler .middlewares .async_utils import AsyncMiddlewareFrame , wrap_middleware_async
1113from tests .functional .utils import load_event
1214
1315API_REST_EVENT = load_event ("apiGatewayProxyEvent.json" )
@@ -20,195 +22,68 @@ def _make_app() -> ApiGatewayResolver:
2022 return app
2123
2224
23- class TestAsyncMiddlewareFrameWithAsyncMiddleware :
24- def test_async_middleware_is_awaited (self ):
25- # GIVEN an async middleware and an async next handler
26- app = _make_app ()
27-
28- async def my_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
29- app .append_context (middleware_called = True )
30- return await next_middleware (app )
31-
32- async def next_handler (app : ApiGatewayResolver ):
33- await asyncio .sleep (0 )
34- return Response (200 , content_types .TEXT_HTML , "from handler" )
35-
36- frame = AsyncMiddlewareFrame (current_middleware = my_middleware , next_middleware = next_handler )
37-
38- # WHEN calling the frame
39- result = asyncio .run (frame (app ))
40-
41- # THEN the async middleware is invoked and the chain proceeds
42- assert result .status_code == 200
43- assert result .body == "from handler"
44- assert app .context .get ("middleware_called" ) is True
45-
46- def test_async_middleware_can_short_circuit (self ):
47- # GIVEN an async middleware that returns early without calling next
48- app = _make_app ()
49-
50- async def blocking_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
51- await asyncio .sleep (0 )
52- return Response (403 , content_types .TEXT_PLAIN , "forbidden" )
53-
54- async def next_handler (app : ApiGatewayResolver ):
55- await asyncio .sleep (0 )
56- return Response (200 , content_types .TEXT_HTML , "should not reach" )
57-
58- frame = AsyncMiddlewareFrame (current_middleware = blocking_middleware , next_middleware = next_handler )
59-
60- # WHEN calling the frame
61- result = asyncio .run (frame (app ))
62-
63- # THEN the middleware short-circuits the chain
64- assert result .status_code == 403
65- assert result .body == "forbidden"
66-
67- def test_multiple_async_middlewares_chained (self ):
68- # GIVEN two async middlewares chained together
69- app = _make_app ()
70-
71- async def first_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
72- app .append_context (first = True )
73- return await next_middleware (app )
74-
75- async def second_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
76- app .append_context (second = True )
77- return await next_middleware (app )
78-
79- async def final_handler (app : ApiGatewayResolver ):
80- await asyncio .sleep (0 )
81- return Response (200 , content_types .TEXT_HTML , "done" )
82-
83- # WHEN building a chain: first -> second -> handler
84- inner_frame = AsyncMiddlewareFrame (current_middleware = second_middleware , next_middleware = final_handler )
85- outer_frame = AsyncMiddlewareFrame (current_middleware = first_middleware , next_middleware = inner_frame )
86-
87- result = asyncio .run (outer_frame (app ))
88-
89- # THEN both middlewares run in order
90- assert result .status_code == 200
91- assert app .context .get ("first" ) is True
92- assert app .context .get ("second" ) is True
25+ def test_sync_middleware_raising_before_next_does_not_deadlock ():
26+ # GIVEN a sync middleware that raises before calling next()
27+ # This previously caused a deadlock because middleware_called_next was never set
28+ app = _make_app ()
9329
30+ class AuthError (Exception ):
31+ pass
9432
95- class TestAsyncMiddlewareFrameWithSyncMiddleware :
96- def test_sync_middleware_is_bridged (self ):
97- # GIVEN a sync middleware and an async next handler
98- app = _make_app ()
33+ def failing_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
34+ raise AuthError ("denied" )
9935
100- def sync_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
101- app . append_context ( sync_called = True )
102- return next_middleware ( app )
36+ async def next_handler (app : ApiGatewayResolver ):
37+ await asyncio . sleep ( 0 )
38+ return Response ( 200 , content_types . TEXT_HTML , "should not reach" )
10339
104- async def next_handler (app : ApiGatewayResolver ):
105- await asyncio .sleep (0 )
106- return Response (200 , content_types .TEXT_HTML , "async handler" )
40+ frame = AsyncMiddlewareFrame (current_middleware = failing_middleware , next_middleware = next_handler )
10741
108- frame = AsyncMiddlewareFrame (current_middleware = sync_middleware , next_middleware = next_handler )
109-
110- # WHEN calling the frame
111- result = asyncio .run (frame (app ))
112-
113- # THEN the sync middleware is bridged via wrap_middleware_async
114- assert result .status_code == 200
115- assert result .body == "async handler"
116- assert app .context .get ("sync_called" ) is True
117-
118- def test_sync_middleware_can_short_circuit (self ):
119- # GIVEN a sync middleware that returns early
120- app = _make_app ()
121-
122- def sync_blocking (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
123- return Response (401 , content_types .TEXT_PLAIN , "unauthorized" )
124-
125- async def next_handler (app : ApiGatewayResolver ):
126- await asyncio .sleep (0 )
127- return Response (200 , content_types .TEXT_HTML , "should not reach" )
128-
129- frame = AsyncMiddlewareFrame (current_middleware = sync_blocking , next_middleware = next_handler )
130-
131- # WHEN calling the frame
132- result = asyncio .run (frame (app ))
133-
134- # THEN the sync middleware short-circuits
135- assert result .status_code == 401
136- assert result .body == "unauthorized"
137-
138-
139- class TestAsyncMiddlewareFrameMixedChain :
140- def test_sync_then_async_middleware (self ):
141- # GIVEN a chain with sync middleware followed by async middleware
142- app = _make_app ()
143-
144- def sync_mw (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
145- app .append_context (sync_ran = True )
146- return next_middleware (app )
147-
148- async def async_mw (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
149- app .append_context (async_ran = True )
150- return await next_middleware (app )
151-
152- async def handler (app : ApiGatewayResolver ):
153- await asyncio .sleep (0 )
154- return Response (200 , content_types .TEXT_HTML , "mixed chain" )
155-
156- inner = AsyncMiddlewareFrame (current_middleware = async_mw , next_middleware = handler )
157- outer = AsyncMiddlewareFrame (current_middleware = sync_mw , next_middleware = inner )
158-
159- # WHEN calling the chain
160- result = asyncio .run (outer (app ))
161-
162- # THEN both middlewares execute in order
163- assert result .status_code == 200
164- assert app .context .get ("sync_ran" ) is True
165- assert app .context .get ("async_ran" ) is True
42+ # WHEN calling the frame
43+ # THEN the exception propagates without deadlocking
44+ with pytest .raises (AuthError , match = "denied" ):
45+ asyncio .run (frame (app ))
16646
16747
168- class TestAsyncMiddlewareFrameProperties :
169- def test_name_property (self ):
170- # GIVEN a middleware with a known name
171- def my_named_middleware (app , next_mw ):
172- return next_mw (app )
48+ def test_wrap_middleware_async_sync_raising_before_next_does_not_deadlock ():
49+ # GIVEN a sync middleware that raises before calling next(), using wrap_middleware_async
50+ # This exercises _run_sync_middleware_in_thread directly
51+ app = _make_app ()
17352
174- def next_handler ( app ):
175- return Response ( 200 , content_types . TEXT_HTML , "ok" )
53+ class AuthError ( Exception ):
54+ pass
17655
177- frame = AsyncMiddlewareFrame (current_middleware = my_named_middleware , next_middleware = next_handler )
56+ def failing_middleware (app , next_middleware ):
57+ raise AuthError ("denied" )
17858
179- # THEN __name__ returns the current middleware name
180- assert frame . __name__ == "my_named_middleware"
59+ async def next_handler ( app ):
60+ return Response ( 200 , content_types . TEXT_HTML , "should not reach" )
18161
182- def test_str_representation (self ):
183- # GIVEN a frame with named middleware and next handler
184- def auth_middleware (app , next_mw ):
185- return next_mw (app )
62+ wrapped = wrap_middleware_async (failing_middleware , next_handler )
18663
187- def logging_middleware (app ):
188- return Response (200 , content_types .TEXT_HTML , "ok" )
64+ # WHEN calling the wrapped middleware
65+ # THEN the exception propagates without deadlocking
66+ with pytest .raises (AuthError , match = "denied" ):
67+ asyncio .run (wrapped (app ))
18968
190- frame = AsyncMiddlewareFrame (current_middleware = auth_middleware , next_middleware = logging_middleware )
19169
192- # THEN str() shows the call chain
193- assert str (frame ) == "[auth_middleware] next call chain is auth_middleware -> logging_middleware"
70+ def test_async_middleware_raising_before_next_propagates ():
71+ # GIVEN an async middleware that raises before calling next()
72+ app = _make_app ()
19473
195- def test_pushes_processed_stack_frame (self ):
196- # GIVEN a frame
197- app = _make_app ()
74+ class ValidationError (Exception ):
75+ pass
19876
199- async def my_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
200- return await next_middleware ( app )
77+ async def failing_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
78+ raise ValidationError ( "invalid request" )
20179
202- async def handler (app : ApiGatewayResolver ):
203- await asyncio .sleep (0 )
204- return Response (200 , content_types .TEXT_HTML , "ok " )
80+ async def next_handler (app : ApiGatewayResolver ):
81+ await asyncio .sleep (0 )
82+ return Response (200 , content_types .TEXT_HTML , "should not reach " )
20583
206- frame = AsyncMiddlewareFrame (current_middleware = my_middleware , next_middleware = handler )
207- app ._reset_processed_stack ()
84+ frame = AsyncMiddlewareFrame (current_middleware = failing_middleware , next_middleware = next_handler )
20885
209- # WHEN calling the frame
86+ # WHEN calling the frame
87+ # THEN the exception propagates
88+ with pytest .raises (ValidationError , match = "invalid request" ):
21089 asyncio .run (frame (app ))
211-
212- # THEN the processed stack frame is recorded for debugging
213- assert len (app .processed_stack_frames ) > 0
214- assert "my_middleware" in app .processed_stack_frames [0 ]
0 commit comments