33import time
44import uuid
55import asyncio
6+ from typing import cast
67from datetime import UTC , datetime
78from unittest .mock import AsyncMock , MagicMock , patch
89
@@ -52,7 +53,8 @@ async def slow_start(span: Span) -> None:
5253
5354
5455class TestAsyncSpanQueueOrdering :
55- async def test_fifo_ordering_preserved (self ):
56+ async def test_per_span_start_before_end (self ):
57+ """START always completes before END for the same span, even with batching."""
5658 call_log : list [tuple [str , str ]] = []
5759
5860 async def record_start (span : Span ) -> None :
@@ -77,12 +79,19 @@ async def record_end(span: Span) -> None:
7779
7880 await queue .shutdown ()
7981
80- assert call_log == [
81- ("start" , "span-a" ),
82- ("end" , "span-a" ),
83- ("start" , "span-b" ),
84- ("end" , "span-b" ),
85- ]
82+ # All 4 events should fire
83+ assert len (call_log ) == 4
84+
85+ # Per-span invariant: START before END
86+ for span_id in ("span-a" , "span-b" ):
87+ start_idx = next (i for i , (ev , sid ) in enumerate (call_log ) if ev == "start" and sid == span_id )
88+ end_idx = next (i for i , (ev , sid ) in enumerate (call_log ) if ev == "end" and sid == span_id )
89+ assert start_idx < end_idx , f"START should come before END for { span_id } "
90+
91+ # All STARTs before all ENDs within a batch
92+ start_indices = [i for i , (ev , _ ) in enumerate (call_log ) if ev == "start" ]
93+ end_indices = [i for i , (ev , _ ) in enumerate (call_log ) if ev == "end" ]
94+ assert max (start_indices ) < min (end_indices ), "All STARTs should complete before any END"
8695
8796
8897class TestAsyncSpanQueueErrorHandling :
@@ -154,6 +163,61 @@ async def test_enqueue_after_shutdown_is_dropped(self):
154163 proc .on_span_start .assert_not_called ()
155164
156165
class TestAsyncSpanQueueBatchConcurrency:
    async def test_batch_processes_multiple_items_concurrently(self):
        """Items in the same batch should run concurrently, not serially."""
        # Track how many slow_start calls are in flight at once; the peak
        # tells us whether the queue overlapped the processor calls.
        active = 0
        peak = 0
        counter_lock = asyncio.Lock()

        async def slow_start(span: Span) -> None:
            nonlocal active, peak
            async with counter_lock:
                active += 1
                if active > peak:
                    peak = active
            await asyncio.sleep(0.05)
            async with counter_lock:
                active -= 1

        processor = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
        queue = AsyncSpanQueue()

        # Enqueue 10 START events before the drain loop runs — they should
        # all land in the same batch and be processed concurrently.
        for idx in range(10):
            queue.enqueue(SpanEventType.START, _make_span(f"span-{idx}"), [processor])

        await queue.shutdown()

        max_concurrency = peak
        assert max_concurrency > 1, (
            f"Expected concurrent processing, but max concurrency was {max_concurrency}"
        )

    async def test_batch_faster_than_serial(self):
        """Batched drain should be significantly faster than serial for slow processors."""
        item_count = 10
        delay_per_item = 0.05  # 50ms per processor call

        async def slow_start(span: Span) -> None:
            await asyncio.sleep(delay_per_item)

        processor = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
        queue = AsyncSpanQueue()

        for idx in range(item_count):
            queue.enqueue(SpanEventType.START, _make_span(f"span-{idx}"), [processor])

        began = time.monotonic()
        await queue.shutdown()
        elapsed = time.monotonic() - began

        # A fully serial drain would take item_count * delay_per_item; require
        # at least a 2x speedup as evidence of concurrent processing.
        serial_time = item_count * delay_per_item
        assert elapsed < serial_time * 0.5, (
            f"Batch drain took {elapsed:.3f}s — serial would be {serial_time:.3f}s. "
            f"Expected at least 2x speedup from concurrency."
        )
220+
157221class TestAsyncSpanQueueIntegration :
158222 async def test_integration_with_async_trace (self ):
159223 call_log : list [tuple [str , str ]] = []
@@ -196,3 +260,55 @@ async def record_end(span: Span) -> None:
196260 assert call_log [1 ][0 ] == "end"
197261 # Same span ID for both events
198262 assert call_log [0 ][1 ] == call_log [1 ][1 ]
263+
264+ async def test_end_event_preserves_modified_input (self ):
265+ """END event should carry span.input so modifications after start are preserved."""
266+ start_spans : list [Span ] = []
267+ end_spans : list [Span ] = []
268+
269+ async def capture_start (span : Span ) -> None :
270+ start_spans .append (span )
271+
272+ async def capture_end (span : Span ) -> None :
273+ end_spans .append (span )
274+
275+ proc = _make_processor (
276+ on_span_start = AsyncMock (side_effect = capture_start ),
277+ on_span_end = AsyncMock (side_effect = capture_end ),
278+ )
279+ queue = AsyncSpanQueue ()
280+
281+ from agentex .lib .core .tracing .trace import AsyncTrace
282+
283+ mock_client = MagicMock ()
284+ trace = AsyncTrace (
285+ processors = [proc ],
286+ client = mock_client ,
287+ trace_id = "test-trace" ,
288+ span_queue = queue ,
289+ )
290+
291+ initial_input : dict [str , object ] = {"messages" : [{"role" : "user" , "content" : "hello" }]}
292+ async with trace .span ("llm-call" , input = initial_input ) as span :
293+ # Simulate modifying input after start (e.g. chatbot appending messages)
294+ messages = cast (list [dict [str , str ]], cast (dict [str , object ], span .input )["messages" ])
295+ messages .append ({"role" : "assistant" , "content" : "hi there" })
296+ messages .append ({"role" : "user" , "content" : "how are you?" })
297+ span .output = cast (dict [str , object ], {"response" : "I'm good!" })
298+
299+ await queue .shutdown ()
300+
301+ assert len (start_spans ) == 1
302+ assert len (end_spans ) == 1
303+
304+ # START should carry the original input (serialized at start time)
305+ assert start_spans [0 ].input is not None
306+ assert len (cast (dict [str , list [object ]], start_spans [0 ].input )["messages" ]) == 1 # only the original message
307+
308+ # END should carry the modified input (re-serialized at end time)
309+ assert end_spans [0 ].input is not None
310+ assert len (cast (dict [str , list [object ]], end_spans [0 ].input )["messages" ]) == 3 # all three messages
311+
312+ # END should still carry output and end_time
313+ assert end_spans [0 ].output is not None
314+ assert end_spans [0 ].end_time is not None
0 commit comments