@@ -530,12 +530,30 @@ def test_handle_content_block_stop(state, exp_updated_state):
530530def test_handle_message_stop ():
531531 event : MessageStopEvent = {"stopReason" : "end_turn" }
532532
533- tru_reason = strands .event_loop .streaming .handle_message_stop (event )
533+ tru_reason = strands .event_loop .streaming .handle_message_stop (event , [] )
534534 exp_reason = "end_turn"
535535
536536 assert tru_reason == exp_reason
537537
538538
539+ def test_handle_message_stop_overrides_end_turn_when_tool_use_present ():
540+ event : MessageStopEvent = {"stopReason" : "end_turn" }
541+ content = [{"toolUse" : {"toolUseId" : "t1" , "name" : "myTool" , "input" : {}}}]
542+
543+ tru_reason = strands .event_loop .streaming .handle_message_stop (event , content )
544+
545+ assert tru_reason == "tool_use"
546+
547+
548+ def test_handle_message_stop_keeps_tool_use_unchanged ():
549+ event : MessageStopEvent = {"stopReason" : "tool_use" }
550+ content = [{"toolUse" : {"toolUseId" : "t1" , "name" : "myTool" , "input" : {}}}]
551+
552+ tru_reason = strands .event_loop .streaming .handle_message_stop (event , content )
553+
554+ assert tru_reason == "tool_use"
555+
556+
539557def test_extract_usage_metrics ():
540558 event = {
541559 "usage" : {"inputTokens" : 0 , "outputTokens" : 0 , "totalTokens" : 0 },
@@ -1334,3 +1352,68 @@ async def test_stream_messages_normalizes_messages(agenerator, alist):
13341352 {"content" : [{"toolUse" : {"name" : "INVALID_TOOL_NAME" }}], "role" : "assistant" },
13351353 {"content" : [{"toolUse" : {"name" : "INVALID_TOOL_NAME" }}], "role" : "assistant" },
13361354 ]
1355+
1356+
1357+ @pytest .mark .asyncio
1358+ async def test_process_stream_overrides_end_turn_when_tool_use_present (agenerator , alist ):
1359+ response = [
1360+ {"messageStart" : {"role" : "assistant" }},
1361+ {"contentBlockStart" : {"contentBlockIndex" : 0 , "start" : {"toolUse" : {"toolUseId" : "t1" , "name" : "myTool" }}}},
1362+ {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : '{"key": "val"}' }}, "contentBlockIndex" : 0 }},
1363+ {"contentBlockStop" : {"contentBlockIndex" : 0 }},
1364+ {"messageStop" : {"stopReason" : "end_turn" }},
1365+ {
1366+ "metadata" : {
1367+ "usage" : {"inputTokens" : 10 , "outputTokens" : 20 , "totalTokens" : 30 },
1368+ "metrics" : {"latencyMs" : 100 },
1369+ }
1370+ },
1371+ ]
1372+
1373+ stream = strands .event_loop .streaming .process_stream (agenerator (response ))
1374+ last_event = cast (ModelStopReason , (await alist (stream ))[- 1 ])
1375+
1376+ assert last_event ["stop" ][0 ] == "tool_use"
1377+
1378+
1379+ @pytest .mark .asyncio
1380+ async def test_process_stream_keeps_end_turn_when_no_tool_use (agenerator , alist ):
1381+ response = [
1382+ {"messageStart" : {"role" : "assistant" }},
1383+ {"contentBlockDelta" : {"delta" : {"text" : "Hello!" }, "contentBlockIndex" : 0 }},
1384+ {"contentBlockStop" : {"contentBlockIndex" : 0 }},
1385+ {"messageStop" : {"stopReason" : "end_turn" }},
1386+ {
1387+ "metadata" : {
1388+ "usage" : {"inputTokens" : 10 , "outputTokens" : 20 , "totalTokens" : 30 },
1389+ "metrics" : {"latencyMs" : 100 },
1390+ }
1391+ },
1392+ ]
1393+
1394+ stream = strands .event_loop .streaming .process_stream (agenerator (response ))
1395+ last_event = cast (ModelStopReason , (await alist (stream ))[- 1 ])
1396+
1397+ assert last_event ["stop" ][0 ] == "end_turn"
1398+
1399+
1400+ @pytest .mark .asyncio
1401+ async def test_process_stream_keeps_tool_use_stop_reason_unchanged (agenerator , alist ):
1402+ response = [
1403+ {"messageStart" : {"role" : "assistant" }},
1404+ {"contentBlockStart" : {"contentBlockIndex" : 0 , "start" : {"toolUse" : {"toolUseId" : "t1" , "name" : "myTool" }}}},
1405+ {"contentBlockDelta" : {"delta" : {"toolUse" : {"input" : "{}" }}, "contentBlockIndex" : 0 }},
1406+ {"contentBlockStop" : {"contentBlockIndex" : 0 }},
1407+ {"messageStop" : {"stopReason" : "tool_use" }},
1408+ {
1409+ "metadata" : {
1410+ "usage" : {"inputTokens" : 10 , "outputTokens" : 20 , "totalTokens" : 30 },
1411+ "metrics" : {"latencyMs" : 100 },
1412+ }
1413+ },
1414+ ]
1415+
1416+ stream = strands .event_loop .streaming .process_stream (agenerator (response ))
1417+ last_event = cast (ModelStopReason , (await alist (stream ))[- 1 ])
1418+
1419+ assert last_event ["stop" ][0 ] == "tool_use"
0 commit comments