11import asyncio
22
3+ import pytest
4+
35from aws_lambda_powertools .event_handler import content_types
46from aws_lambda_powertools .event_handler .api_gateway import (
57 ApiGatewayResolver ,
@@ -20,195 +22,46 @@ def _make_app() -> ApiGatewayResolver:
2022 return app
2123
2224
23- class TestAsyncMiddlewareFrameWithAsyncMiddleware :
24- def test_async_middleware_is_awaited (self ):
25- # GIVEN an async middleware and an async next handler
26- app = _make_app ()
27-
28- async def my_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
29- app .append_context (middleware_called = True )
30- return await next_middleware (app )
31-
32- async def next_handler (app : ApiGatewayResolver ):
33- await asyncio .sleep (0 )
34- return Response (200 , content_types .TEXT_HTML , "from handler" )
35-
36- frame = AsyncMiddlewareFrame (current_middleware = my_middleware , next_middleware = next_handler )
37-
38- # WHEN calling the frame
39- result = asyncio .run (frame (app ))
40-
41- # THEN the async middleware is invoked and the chain proceeds
42- assert result .status_code == 200
43- assert result .body == "from handler"
44- assert app .context .get ("middleware_called" ) is True
45-
46- def test_async_middleware_can_short_circuit (self ):
47- # GIVEN an async middleware that returns early without calling next
48- app = _make_app ()
49-
50- async def blocking_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
51- await asyncio .sleep (0 )
52- return Response (403 , content_types .TEXT_PLAIN , "forbidden" )
53-
54- async def next_handler (app : ApiGatewayResolver ):
55- await asyncio .sleep (0 )
56- return Response (200 , content_types .TEXT_HTML , "should not reach" )
57-
58- frame = AsyncMiddlewareFrame (current_middleware = blocking_middleware , next_middleware = next_handler )
59-
60- # WHEN calling the frame
61- result = asyncio .run (frame (app ))
62-
63- # THEN the middleware short-circuits the chain
64- assert result .status_code == 403
65- assert result .body == "forbidden"
66-
67- def test_multiple_async_middlewares_chained (self ):
68- # GIVEN two async middlewares chained together
69- app = _make_app ()
70-
71- async def first_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
72- app .append_context (first = True )
73- return await next_middleware (app )
74-
75- async def second_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
76- app .append_context (second = True )
77- return await next_middleware (app )
78-
79- async def final_handler (app : ApiGatewayResolver ):
80- await asyncio .sleep (0 )
81- return Response (200 , content_types .TEXT_HTML , "done" )
82-
83- # WHEN building a chain: first -> second -> handler
84- inner_frame = AsyncMiddlewareFrame (current_middleware = second_middleware , next_middleware = final_handler )
85- outer_frame = AsyncMiddlewareFrame (current_middleware = first_middleware , next_middleware = inner_frame )
86-
87- result = asyncio .run (outer_frame (app ))
88-
89- # THEN both middlewares run in order
90- assert result .status_code == 200
91- assert app .context .get ("first" ) is True
92- assert app .context .get ("second" ) is True
93-
94-
95- class TestAsyncMiddlewareFrameWithSyncMiddleware :
96- def test_sync_middleware_is_bridged (self ):
97- # GIVEN a sync middleware and an async next handler
98- app = _make_app ()
99-
100- def sync_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
101- app .append_context (sync_called = True )
102- return next_middleware (app )
25+ def test_sync_middleware_raising_before_next_does_not_deadlock ():
26+ # GIVEN a sync middleware that raises before calling next()
27+ # This previously caused a deadlock because middleware_called_next was never set
28+ app = _make_app ()
10329
104- async def next_handler (app : ApiGatewayResolver ):
105- await asyncio .sleep (0 )
106- return Response (200 , content_types .TEXT_HTML , "async handler" )
30+ class AuthError (Exception ):
31+ pass
10732
108- frame = AsyncMiddlewareFrame (current_middleware = sync_middleware , next_middleware = next_handler )
33+ def failing_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
34+ raise AuthError ("denied" )
10935
110- # WHEN calling the frame
111- result = asyncio .run (frame (app ))
36+ async def next_handler (app : ApiGatewayResolver ):
37+ await asyncio .sleep (0 )
38+ return Response (200 , content_types .TEXT_HTML , "should not reach" )
11239
113- # THEN the sync middleware is bridged via wrap_middleware_async
114- assert result .status_code == 200
115- assert result .body == "async handler"
116- assert app .context .get ("sync_called" ) is True
40+ frame = AsyncMiddlewareFrame (current_middleware = failing_middleware , next_middleware = next_handler )
11741
118- def test_sync_middleware_can_short_circuit (self ):
119- # GIVEN a sync middleware that returns early
120- app = _make_app ()
121-
122- def sync_blocking (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
123- return Response (401 , content_types .TEXT_PLAIN , "unauthorized" )
124-
125- async def next_handler (app : ApiGatewayResolver ):
126- await asyncio .sleep (0 )
127- return Response (200 , content_types .TEXT_HTML , "should not reach" )
128-
129- frame = AsyncMiddlewareFrame (current_middleware = sync_blocking , next_middleware = next_handler )
130-
131- # WHEN calling the frame
132- result = asyncio .run (frame (app ))
133-
134- # THEN the sync middleware short-circuits
135- assert result .status_code == 401
136- assert result .body == "unauthorized"
137-
138-
139- class TestAsyncMiddlewareFrameMixedChain :
140- def test_sync_then_async_middleware (self ):
141- # GIVEN a chain with sync middleware followed by async middleware
142- app = _make_app ()
143-
144- def sync_mw (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
145- app .append_context (sync_ran = True )
146- return next_middleware (app )
147-
148- async def async_mw (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
149- app .append_context (async_ran = True )
150- return await next_middleware (app )
151-
152- async def handler (app : ApiGatewayResolver ):
153- await asyncio .sleep (0 )
154- return Response (200 , content_types .TEXT_HTML , "mixed chain" )
155-
156- inner = AsyncMiddlewareFrame (current_middleware = async_mw , next_middleware = handler )
157- outer = AsyncMiddlewareFrame (current_middleware = sync_mw , next_middleware = inner )
158-
159- # WHEN calling the chain
160- result = asyncio .run (outer (app ))
161-
162- # THEN both middlewares execute in order
163- assert result .status_code == 200
164- assert app .context .get ("sync_ran" ) is True
165- assert app .context .get ("async_ran" ) is True
166-
167-
168- class TestAsyncMiddlewareFrameProperties :
169- def test_name_property (self ):
170- # GIVEN a middleware with a known name
171- def my_named_middleware (app , next_mw ):
172- return next_mw (app )
173-
174- def next_handler (app ):
175- return Response (200 , content_types .TEXT_HTML , "ok" )
176-
177- frame = AsyncMiddlewareFrame (current_middleware = my_named_middleware , next_middleware = next_handler )
178-
179- # THEN __name__ returns the current middleware name
180- assert frame .__name__ == "my_named_middleware"
181-
182- def test_str_representation (self ):
183- # GIVEN a frame with named middleware and next handler
184- def auth_middleware (app , next_mw ):
185- return next_mw (app )
186-
187- def logging_middleware (app ):
188- return Response (200 , content_types .TEXT_HTML , "ok" )
42+ # WHEN calling the frame
43+ # THEN the exception propagates without deadlocking
44+ with pytest .raises (AuthError , match = "denied" ):
45+ asyncio .run (frame (app ))
18946
190- frame = AsyncMiddlewareFrame (current_middleware = auth_middleware , next_middleware = logging_middleware )
19147
192- # THEN str() shows the call chain
193- assert str (frame ) == "[auth_middleware] next call chain is auth_middleware -> logging_middleware"
48+ def test_async_middleware_raising_before_next_propagates ():
49+ # GIVEN an async middleware that raises before calling next()
50+ app = _make_app ()
19451
195- def test_pushes_processed_stack_frame (self ):
196- # GIVEN a frame
197- app = _make_app ()
52+ class ValidationError (Exception ):
53+ pass
19854
199- async def my_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
200- return await next_middleware ( app )
55+ async def failing_middleware (app : ApiGatewayResolver , next_middleware : NextMiddleware ):
56+ raise ValidationError ( "invalid request" )
20157
202- async def handler (app : ApiGatewayResolver ):
203- await asyncio .sleep (0 )
204- return Response (200 , content_types .TEXT_HTML , "ok " )
58+ async def next_handler (app : ApiGatewayResolver ):
59+ await asyncio .sleep (0 )
60+ return Response (200 , content_types .TEXT_HTML , "should not reach " )
20561
206- frame = AsyncMiddlewareFrame (current_middleware = my_middleware , next_middleware = handler )
207- app ._reset_processed_stack ()
62+ frame = AsyncMiddlewareFrame (current_middleware = failing_middleware , next_middleware = next_handler )
20863
209- # WHEN calling the frame
64+ # WHEN calling the frame
65+ # THEN the exception propagates
66+ with pytest .raises (ValidationError , match = "invalid request" ):
21067 asyncio .run (frame (app ))
211-
212- # THEN the processed stack frame is recorded for debugging
213- assert len (app .processed_stack_frames ) > 0
214- assert "my_middleware" in app .processed_stack_frames [0 ]
0 commit comments