Skip to content

Commit 835d0fd

Browse files
fix(event_handler): avoid deadlock in async resolver
1 parent f83e141 commit 835d0fd

2 files changed

Lines changed: 40 additions & 180 deletions

File tree

aws_lambda_powertools/event_handler/middlewares/async_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,20 @@ def run_middleware() -> None:
8686
middleware_result_holder.append(result)
8787
except Exception as e:
8888
middleware_error_holder.append(e)
89+
finally:
90+
middleware_called_next.set()
8991

9092
thread = threading.Thread(target=run_middleware, daemon=True)
9193
thread.start()
9294

93-
# Wait for the middleware to call next()
95+
# Wait for the middleware to call next() or raise
9496
await middleware_called_next.wait()
9597

98+
# If middleware raised before calling next, propagate immediately
99+
if not next_app_holder:
100+
thread.join()
101+
raise middleware_error_holder[0]
102+
96103
# Resolve the async next_handler on the event-loop
97104
real_response = await next_handler(next_app_holder[0])
98105
real_response_holder.append(real_response)
Lines changed: 32 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22

3+
import pytest
4+
35
from aws_lambda_powertools.event_handler import content_types
46
from aws_lambda_powertools.event_handler.api_gateway import (
57
ApiGatewayResolver,
@@ -20,195 +22,46 @@ def _make_app() -> ApiGatewayResolver:
2022
return app
2123

2224

23-
class TestAsyncMiddlewareFrameWithAsyncMiddleware:
24-
def test_async_middleware_is_awaited(self):
25-
# GIVEN an async middleware and an async next handler
26-
app = _make_app()
27-
28-
async def my_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
29-
app.append_context(middleware_called=True)
30-
return await next_middleware(app)
31-
32-
async def next_handler(app: ApiGatewayResolver):
33-
await asyncio.sleep(0)
34-
return Response(200, content_types.TEXT_HTML, "from handler")
35-
36-
frame = AsyncMiddlewareFrame(current_middleware=my_middleware, next_middleware=next_handler)
37-
38-
# WHEN calling the frame
39-
result = asyncio.run(frame(app))
40-
41-
# THEN the async middleware is invoked and the chain proceeds
42-
assert result.status_code == 200
43-
assert result.body == "from handler"
44-
assert app.context.get("middleware_called") is True
45-
46-
def test_async_middleware_can_short_circuit(self):
47-
# GIVEN an async middleware that returns early without calling next
48-
app = _make_app()
49-
50-
async def blocking_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
51-
await asyncio.sleep(0)
52-
return Response(403, content_types.TEXT_PLAIN, "forbidden")
53-
54-
async def next_handler(app: ApiGatewayResolver):
55-
await asyncio.sleep(0)
56-
return Response(200, content_types.TEXT_HTML, "should not reach")
57-
58-
frame = AsyncMiddlewareFrame(current_middleware=blocking_middleware, next_middleware=next_handler)
59-
60-
# WHEN calling the frame
61-
result = asyncio.run(frame(app))
62-
63-
# THEN the middleware short-circuits the chain
64-
assert result.status_code == 403
65-
assert result.body == "forbidden"
66-
67-
def test_multiple_async_middlewares_chained(self):
68-
# GIVEN two async middlewares chained together
69-
app = _make_app()
70-
71-
async def first_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
72-
app.append_context(first=True)
73-
return await next_middleware(app)
74-
75-
async def second_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
76-
app.append_context(second=True)
77-
return await next_middleware(app)
78-
79-
async def final_handler(app: ApiGatewayResolver):
80-
await asyncio.sleep(0)
81-
return Response(200, content_types.TEXT_HTML, "done")
82-
83-
# WHEN building a chain: first -> second -> handler
84-
inner_frame = AsyncMiddlewareFrame(current_middleware=second_middleware, next_middleware=final_handler)
85-
outer_frame = AsyncMiddlewareFrame(current_middleware=first_middleware, next_middleware=inner_frame)
86-
87-
result = asyncio.run(outer_frame(app))
88-
89-
# THEN both middlewares run in order
90-
assert result.status_code == 200
91-
assert app.context.get("first") is True
92-
assert app.context.get("second") is True
93-
94-
95-
class TestAsyncMiddlewareFrameWithSyncMiddleware:
96-
def test_sync_middleware_is_bridged(self):
97-
# GIVEN a sync middleware and an async next handler
98-
app = _make_app()
99-
100-
def sync_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
101-
app.append_context(sync_called=True)
102-
return next_middleware(app)
25+
def test_sync_middleware_raising_before_next_does_not_deadlock():
26+
# GIVEN a sync middleware that raises before calling next()
27+
# This previously caused a deadlock because middleware_called_next was never set
28+
app = _make_app()
10329

104-
async def next_handler(app: ApiGatewayResolver):
105-
await asyncio.sleep(0)
106-
return Response(200, content_types.TEXT_HTML, "async handler")
30+
class AuthError(Exception):
31+
pass
10732

108-
frame = AsyncMiddlewareFrame(current_middleware=sync_middleware, next_middleware=next_handler)
33+
def failing_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
34+
raise AuthError("denied")
10935

110-
# WHEN calling the frame
111-
result = asyncio.run(frame(app))
36+
async def next_handler(app: ApiGatewayResolver):
37+
await asyncio.sleep(0)
38+
return Response(200, content_types.TEXT_HTML, "should not reach")
11239

113-
# THEN the sync middleware is bridged via wrap_middleware_async
114-
assert result.status_code == 200
115-
assert result.body == "async handler"
116-
assert app.context.get("sync_called") is True
40+
frame = AsyncMiddlewareFrame(current_middleware=failing_middleware, next_middleware=next_handler)
11741

118-
def test_sync_middleware_can_short_circuit(self):
119-
# GIVEN a sync middleware that returns early
120-
app = _make_app()
121-
122-
def sync_blocking(app: ApiGatewayResolver, next_middleware: NextMiddleware):
123-
return Response(401, content_types.TEXT_PLAIN, "unauthorized")
124-
125-
async def next_handler(app: ApiGatewayResolver):
126-
await asyncio.sleep(0)
127-
return Response(200, content_types.TEXT_HTML, "should not reach")
128-
129-
frame = AsyncMiddlewareFrame(current_middleware=sync_blocking, next_middleware=next_handler)
130-
131-
# WHEN calling the frame
132-
result = asyncio.run(frame(app))
133-
134-
# THEN the sync middleware short-circuits
135-
assert result.status_code == 401
136-
assert result.body == "unauthorized"
137-
138-
139-
class TestAsyncMiddlewareFrameMixedChain:
140-
def test_sync_then_async_middleware(self):
141-
# GIVEN a chain with sync middleware followed by async middleware
142-
app = _make_app()
143-
144-
def sync_mw(app: ApiGatewayResolver, next_middleware: NextMiddleware):
145-
app.append_context(sync_ran=True)
146-
return next_middleware(app)
147-
148-
async def async_mw(app: ApiGatewayResolver, next_middleware: NextMiddleware):
149-
app.append_context(async_ran=True)
150-
return await next_middleware(app)
151-
152-
async def handler(app: ApiGatewayResolver):
153-
await asyncio.sleep(0)
154-
return Response(200, content_types.TEXT_HTML, "mixed chain")
155-
156-
inner = AsyncMiddlewareFrame(current_middleware=async_mw, next_middleware=handler)
157-
outer = AsyncMiddlewareFrame(current_middleware=sync_mw, next_middleware=inner)
158-
159-
# WHEN calling the chain
160-
result = asyncio.run(outer(app))
161-
162-
# THEN both middlewares execute in order
163-
assert result.status_code == 200
164-
assert app.context.get("sync_ran") is True
165-
assert app.context.get("async_ran") is True
166-
167-
168-
class TestAsyncMiddlewareFrameProperties:
169-
def test_name_property(self):
170-
# GIVEN a middleware with a known name
171-
def my_named_middleware(app, next_mw):
172-
return next_mw(app)
173-
174-
def next_handler(app):
175-
return Response(200, content_types.TEXT_HTML, "ok")
176-
177-
frame = AsyncMiddlewareFrame(current_middleware=my_named_middleware, next_middleware=next_handler)
178-
179-
# THEN __name__ returns the current middleware name
180-
assert frame.__name__ == "my_named_middleware"
181-
182-
def test_str_representation(self):
183-
# GIVEN a frame with named middleware and next handler
184-
def auth_middleware(app, next_mw):
185-
return next_mw(app)
186-
187-
def logging_middleware(app):
188-
return Response(200, content_types.TEXT_HTML, "ok")
42+
# WHEN calling the frame
43+
# THEN the exception propagates without deadlocking
44+
with pytest.raises(AuthError, match="denied"):
45+
asyncio.run(frame(app))
18946

190-
frame = AsyncMiddlewareFrame(current_middleware=auth_middleware, next_middleware=logging_middleware)
19147

192-
# THEN str() shows the call chain
193-
assert str(frame) == "[auth_middleware] next call chain is auth_middleware -> logging_middleware"
48+
def test_async_middleware_raising_before_next_propagates():
49+
# GIVEN an async middleware that raises before calling next()
50+
app = _make_app()
19451

195-
def test_pushes_processed_stack_frame(self):
196-
# GIVEN a frame
197-
app = _make_app()
52+
class ValidationError(Exception):
53+
pass
19854

199-
async def my_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
200-
return await next_middleware(app)
55+
async def failing_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
56+
raise ValidationError("invalid request")
20157

202-
async def handler(app: ApiGatewayResolver):
203-
await asyncio.sleep(0)
204-
return Response(200, content_types.TEXT_HTML, "ok")
58+
async def next_handler(app: ApiGatewayResolver):
59+
await asyncio.sleep(0)
60+
return Response(200, content_types.TEXT_HTML, "should not reach")
20561

206-
frame = AsyncMiddlewareFrame(current_middleware=my_middleware, next_middleware=handler)
207-
app._reset_processed_stack()
62+
frame = AsyncMiddlewareFrame(current_middleware=failing_middleware, next_middleware=next_handler)
20863

209-
# WHEN calling the frame
64+
# WHEN calling the frame
65+
# THEN the exception propagates
66+
with pytest.raises(ValidationError, match="invalid request"):
21067
asyncio.run(frame(app))
211-
212-
# THEN the processed stack frame is recorded for debugging
213-
assert len(app.processed_stack_frames) > 0
214-
assert "my_middleware" in app.processed_stack_frames[0]

0 commit comments

Comments
 (0)