|
1 | 1 | import pytest |
2 | 2 |
|
3 | | -from strands.hooks import BeforeToolCallEvent |
| 3 | +from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent |
4 | 4 | from strands.interrupt import Interrupt |
5 | 5 | from strands.tools.executors import ConcurrentToolExecutor |
6 | 6 | from strands.tools.structured_output._structured_output_context import StructuredOutputContext |
@@ -76,3 +76,30 @@ def interrupt_callback(event): |
76 | 76 | tru_results = tool_results |
77 | 77 | exp_results = [exp_events[1].tool_result] |
78 | 78 | assert tru_results == exp_results |
| 79 | + |
| 80 | + |
| 81 | +@pytest.mark.asyncio |
| 82 | +async def test_concurrent_executor_reraises_exceptions( |
| 83 | + executor, agent, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context, alist |
| 84 | +): |
| 85 | + """Test that hook re-raised exceptions propagate and cancel remaining tasks.""" |
| 86 | + |
| 87 | + def reraise_callback(event): |
| 88 | + if event.exception is not None: |
| 89 | + raise event.exception |
| 90 | + |
| 91 | + agent.hooks.add_callback(AfterToolCallEvent, reraise_callback) |
| 92 | + |
| 93 | + tool_uses = [ |
| 94 | + {"name": "exception_tool", "toolUseId": "1", "input": {}}, |
| 95 | + {"name": "slow_tool", "toolUseId": "2", "input": {}}, |
| 96 | + ] |
| 97 | + |
| 98 | + stream = executor._execute( |
| 99 | + agent, tool_uses, tool_results, cycle_trace, cycle_span, invocation_state, structured_output_context |
| 100 | + ) |
| 101 | + |
| 102 | + with pytest.raises(RuntimeError, match="Tool error"): |
| 103 | + await alist(stream) |
| 104 | + |
| 105 | + assert tool_results == [] |
0 commit comments