Skip to content

Commit f4644f7

Browse files
fix(event_handler): prevent deadlock when async middleware raises before calling next() (#8196)
* fix(event_handler): avoid deadlock in async resolver * fix(event_handler): avoid deadlock in async resolver
1 parent 53678cf commit f4644f7

2 files changed

Lines changed: 56 additions & 174 deletions

File tree

aws_lambda_powertools/event_handler/middlewares/async_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,20 @@ def run_middleware() -> None:
8686
middleware_result_holder.append(result)
8787
except Exception as e:
8888
middleware_error_holder.append(e)
89+
finally:
90+
middleware_called_next.set()
8991

9092
thread = threading.Thread(target=run_middleware, daemon=True)
9193
thread.start()
9294

93-
# Wait for the middleware to call next()
95+
# Wait for the middleware to call next() or raise
9496
await middleware_called_next.wait()
9597

98+
# If middleware raised before calling next, propagate immediately
99+
if not next_app_holder:
100+
thread.join()
101+
raise middleware_error_holder[0]
102+
96103
# Resolve the async next_handler on the event-loop
97104
real_response = await next_handler(next_app_holder[0])
98105
real_response_holder.append(real_response)
Lines changed: 48 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import asyncio
22

3+
import pytest
4+
35
from aws_lambda_powertools.event_handler import content_types
46
from aws_lambda_powertools.event_handler.api_gateway import (
57
ApiGatewayResolver,
68
ProxyEventType,
79
Response,
810
)
911
from aws_lambda_powertools.event_handler.middlewares import NextMiddleware
10-
from aws_lambda_powertools.event_handler.middlewares.async_utils import AsyncMiddlewareFrame
12+
from aws_lambda_powertools.event_handler.middlewares.async_utils import AsyncMiddlewareFrame, wrap_middleware_async
1113
from tests.functional.utils import load_event
1214

1315
API_REST_EVENT = load_event("apiGatewayProxyEvent.json")
@@ -20,195 +22,68 @@ def _make_app() -> ApiGatewayResolver:
2022
return app
2123

2224

23-
class TestAsyncMiddlewareFrameWithAsyncMiddleware:
24-
def test_async_middleware_is_awaited(self):
25-
# GIVEN an async middleware and an async next handler
26-
app = _make_app()
27-
28-
async def my_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
29-
app.append_context(middleware_called=True)
30-
return await next_middleware(app)
31-
32-
async def next_handler(app: ApiGatewayResolver):
33-
await asyncio.sleep(0)
34-
return Response(200, content_types.TEXT_HTML, "from handler")
35-
36-
frame = AsyncMiddlewareFrame(current_middleware=my_middleware, next_middleware=next_handler)
37-
38-
# WHEN calling the frame
39-
result = asyncio.run(frame(app))
40-
41-
# THEN the async middleware is invoked and the chain proceeds
42-
assert result.status_code == 200
43-
assert result.body == "from handler"
44-
assert app.context.get("middleware_called") is True
45-
46-
def test_async_middleware_can_short_circuit(self):
47-
# GIVEN an async middleware that returns early without calling next
48-
app = _make_app()
49-
50-
async def blocking_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
51-
await asyncio.sleep(0)
52-
return Response(403, content_types.TEXT_PLAIN, "forbidden")
53-
54-
async def next_handler(app: ApiGatewayResolver):
55-
await asyncio.sleep(0)
56-
return Response(200, content_types.TEXT_HTML, "should not reach")
57-
58-
frame = AsyncMiddlewareFrame(current_middleware=blocking_middleware, next_middleware=next_handler)
59-
60-
# WHEN calling the frame
61-
result = asyncio.run(frame(app))
62-
63-
# THEN the middleware short-circuits the chain
64-
assert result.status_code == 403
65-
assert result.body == "forbidden"
66-
67-
def test_multiple_async_middlewares_chained(self):
68-
# GIVEN two async middlewares chained together
69-
app = _make_app()
70-
71-
async def first_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
72-
app.append_context(first=True)
73-
return await next_middleware(app)
74-
75-
async def second_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
76-
app.append_context(second=True)
77-
return await next_middleware(app)
78-
79-
async def final_handler(app: ApiGatewayResolver):
80-
await asyncio.sleep(0)
81-
return Response(200, content_types.TEXT_HTML, "done")
82-
83-
# WHEN building a chain: first -> second -> handler
84-
inner_frame = AsyncMiddlewareFrame(current_middleware=second_middleware, next_middleware=final_handler)
85-
outer_frame = AsyncMiddlewareFrame(current_middleware=first_middleware, next_middleware=inner_frame)
86-
87-
result = asyncio.run(outer_frame(app))
88-
89-
# THEN both middlewares run in order
90-
assert result.status_code == 200
91-
assert app.context.get("first") is True
92-
assert app.context.get("second") is True
25+
def test_sync_middleware_raising_before_next_does_not_deadlock():
26+
# GIVEN a sync middleware that raises before calling next()
27+
# This previously caused a deadlock because middleware_called_next was never set
28+
app = _make_app()
9329

30+
class AuthError(Exception):
31+
pass
9432

95-
class TestAsyncMiddlewareFrameWithSyncMiddleware:
96-
def test_sync_middleware_is_bridged(self):
97-
# GIVEN a sync middleware and an async next handler
98-
app = _make_app()
33+
def failing_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
34+
raise AuthError("denied")
9935

100-
def sync_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
101-
app.append_context(sync_called=True)
102-
return next_middleware(app)
36+
async def next_handler(app: ApiGatewayResolver):
37+
await asyncio.sleep(0)
38+
return Response(200, content_types.TEXT_HTML, "should not reach")
10339

104-
async def next_handler(app: ApiGatewayResolver):
105-
await asyncio.sleep(0)
106-
return Response(200, content_types.TEXT_HTML, "async handler")
40+
frame = AsyncMiddlewareFrame(current_middleware=failing_middleware, next_middleware=next_handler)
10741

108-
frame = AsyncMiddlewareFrame(current_middleware=sync_middleware, next_middleware=next_handler)
109-
110-
# WHEN calling the frame
111-
result = asyncio.run(frame(app))
112-
113-
# THEN the sync middleware is bridged via wrap_middleware_async
114-
assert result.status_code == 200
115-
assert result.body == "async handler"
116-
assert app.context.get("sync_called") is True
117-
118-
def test_sync_middleware_can_short_circuit(self):
119-
# GIVEN a sync middleware that returns early
120-
app = _make_app()
121-
122-
def sync_blocking(app: ApiGatewayResolver, next_middleware: NextMiddleware):
123-
return Response(401, content_types.TEXT_PLAIN, "unauthorized")
124-
125-
async def next_handler(app: ApiGatewayResolver):
126-
await asyncio.sleep(0)
127-
return Response(200, content_types.TEXT_HTML, "should not reach")
128-
129-
frame = AsyncMiddlewareFrame(current_middleware=sync_blocking, next_middleware=next_handler)
130-
131-
# WHEN calling the frame
132-
result = asyncio.run(frame(app))
133-
134-
# THEN the sync middleware short-circuits
135-
assert result.status_code == 401
136-
assert result.body == "unauthorized"
137-
138-
139-
class TestAsyncMiddlewareFrameMixedChain:
140-
def test_sync_then_async_middleware(self):
141-
# GIVEN a chain with sync middleware followed by async middleware
142-
app = _make_app()
143-
144-
def sync_mw(app: ApiGatewayResolver, next_middleware: NextMiddleware):
145-
app.append_context(sync_ran=True)
146-
return next_middleware(app)
147-
148-
async def async_mw(app: ApiGatewayResolver, next_middleware: NextMiddleware):
149-
app.append_context(async_ran=True)
150-
return await next_middleware(app)
151-
152-
async def handler(app: ApiGatewayResolver):
153-
await asyncio.sleep(0)
154-
return Response(200, content_types.TEXT_HTML, "mixed chain")
155-
156-
inner = AsyncMiddlewareFrame(current_middleware=async_mw, next_middleware=handler)
157-
outer = AsyncMiddlewareFrame(current_middleware=sync_mw, next_middleware=inner)
158-
159-
# WHEN calling the chain
160-
result = asyncio.run(outer(app))
161-
162-
# THEN both middlewares execute in order
163-
assert result.status_code == 200
164-
assert app.context.get("sync_ran") is True
165-
assert app.context.get("async_ran") is True
42+
# WHEN calling the frame
43+
# THEN the exception propagates without deadlocking
44+
with pytest.raises(AuthError, match="denied"):
45+
asyncio.run(frame(app))
16646

16747

168-
class TestAsyncMiddlewareFrameProperties:
169-
def test_name_property(self):
170-
# GIVEN a middleware with a known name
171-
def my_named_middleware(app, next_mw):
172-
return next_mw(app)
48+
def test_wrap_middleware_async_sync_raising_before_next_does_not_deadlock():
49+
# GIVEN a sync middleware that raises before calling next(), using wrap_middleware_async
50+
# This exercises _run_sync_middleware_in_thread directly
51+
app = _make_app()
17352

174-
def next_handler(app):
175-
return Response(200, content_types.TEXT_HTML, "ok")
53+
class AuthError(Exception):
54+
pass
17655

177-
frame = AsyncMiddlewareFrame(current_middleware=my_named_middleware, next_middleware=next_handler)
56+
def failing_middleware(app, next_middleware):
57+
raise AuthError("denied")
17858

179-
# THEN __name__ returns the current middleware name
180-
assert frame.__name__ == "my_named_middleware"
59+
async def next_handler(app):
60+
return Response(200, content_types.TEXT_HTML, "should not reach")
18161

182-
def test_str_representation(self):
183-
# GIVEN a frame with named middleware and next handler
184-
def auth_middleware(app, next_mw):
185-
return next_mw(app)
62+
wrapped = wrap_middleware_async(failing_middleware, next_handler)
18663

187-
def logging_middleware(app):
188-
return Response(200, content_types.TEXT_HTML, "ok")
64+
# WHEN calling the wrapped middleware
65+
# THEN the exception propagates without deadlocking
66+
with pytest.raises(AuthError, match="denied"):
67+
asyncio.run(wrapped(app))
18968

190-
frame = AsyncMiddlewareFrame(current_middleware=auth_middleware, next_middleware=logging_middleware)
19169

192-
# THEN str() shows the call chain
193-
assert str(frame) == "[auth_middleware] next call chain is auth_middleware -> logging_middleware"
70+
def test_async_middleware_raising_before_next_propagates():
71+
# GIVEN an async middleware that raises before calling next()
72+
app = _make_app()
19473

195-
def test_pushes_processed_stack_frame(self):
196-
# GIVEN a frame
197-
app = _make_app()
74+
class ValidationError(Exception):
75+
pass
19876

199-
async def my_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
200-
return await next_middleware(app)
77+
async def failing_middleware(app: ApiGatewayResolver, next_middleware: NextMiddleware):
78+
raise ValidationError("invalid request")
20179

202-
async def handler(app: ApiGatewayResolver):
203-
await asyncio.sleep(0)
204-
return Response(200, content_types.TEXT_HTML, "ok")
80+
async def next_handler(app: ApiGatewayResolver):
81+
await asyncio.sleep(0)
82+
return Response(200, content_types.TEXT_HTML, "should not reach")
20583

206-
frame = AsyncMiddlewareFrame(current_middleware=my_middleware, next_middleware=handler)
207-
app._reset_processed_stack()
84+
frame = AsyncMiddlewareFrame(current_middleware=failing_middleware, next_middleware=next_handler)
20885

209-
# WHEN calling the frame
86+
# WHEN calling the frame
87+
# THEN the exception propagates
88+
with pytest.raises(ValidationError, match="invalid request"):
21089
asyncio.run(frame(app))
211-
212-
# THEN the processed stack frame is recorded for debugging
213-
assert len(app.processed_stack_frames) > 0
214-
assert "my_middleware" in app.processed_stack_frames[0]

0 commit comments

Comments
 (0)