Skip to content

Commit a862a06

Browse files
levilentz authored and declan-scale committed
fix(adk): fix to queue drain (#327)
Co-authored-by: Declan Brady <declan.brady@scale.com>
1 parent 710c63f commit a862a06

6 files changed

Lines changed: 517 additions & 15 deletions

File tree

src/agentex/lib/core/tracing/processors/agentex_tracing_processor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def __init__(self, config: AgentexTracingProcessorConfig): # noqa: ARG002
7979
),
8080
)
8181

82+
# TODO(AGX1-199): Add batch create/update endpoints to Agentex API and use
83+
# them here instead of one HTTP call per span.
84+
# https://linear.app/scale-epd/issue/AGX1-199/add-agentex-batch-endpoint-for-traces
8285
@override
8386
async def on_span_start(self, span: Span) -> None:
8487
await self.client.spans.create(

src/agentex/lib/core/tracing/processors/sgp_tracing_processor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ async def on_span_start(self, span: Span) -> None:
141141
if self.disabled:
142142
logger.warning("SGP is disabled, skipping span upsert")
143143
return
144+
# TODO(AGX1-198): Batch multiple spans into a single upsert_batch call
145+
# instead of one span per HTTP request.
146+
# https://linear.app/scale-epd/issue/AGX1-198/actually-use-sgp-batching-for-spans
144147
await self.sgp_async_client.spans.upsert_batch( # type: ignore[union-attr]
145148
items=[sgp_span.to_request_params()]
146149
)
@@ -155,6 +158,7 @@ async def on_span_end(self, span: Span) -> None:
155158
return
156159

157160
self._add_source_to_span(span)
161+
sgp_span.input = span.input # type: ignore[assignment]
158162
sgp_span.output = span.output # type: ignore[assignment]
159163
sgp_span.metadata = span.data # type: ignore[assignment]
160164
sgp_span.end_time = span.end_time.isoformat() # type: ignore[union-attr]

src/agentex/lib/core/tracing/span_queue.py

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
logger = make_logger(__name__)
1414

15+
_DEFAULT_BATCH_SIZE = 50
16+
1517

1618
class SpanEventType(str, Enum):
1719
START = "start"
@@ -28,15 +30,18 @@ class _SpanQueueItem:
2830
class AsyncSpanQueue:
2931
"""Background FIFO queue for async span processing.
3032
31-
Span events are enqueued synchronously (non-blocking) and processed
32-
sequentially by a background drain task. This keeps tracing HTTP calls
33-
off the critical request path while preserving start-before-end ordering.
33+
Span events are enqueued synchronously (non-blocking) and drained by a
34+
background task. Items are processed in batches: all START events in a
35+
batch are flushed concurrently, then all END events, so that per-span
36+
start-before-end ordering is preserved while HTTP calls for independent
37+
spans execute in parallel.
3438
"""
3539

36-
def __init__(self) -> None:
40+
def __init__(self, batch_size: int = _DEFAULT_BATCH_SIZE) -> None:
3741
self._queue: asyncio.Queue[_SpanQueueItem] = asyncio.Queue()
3842
self._drain_task: asyncio.Task[None] | None = None
3943
self._stopping = False
44+
self._batch_size = batch_size
4045

4146
def enqueue(
4247
self,
@@ -54,9 +59,45 @@ def _ensure_drain_running(self) -> None:
5459
if self._drain_task is None or self._drain_task.done():
5560
self._drain_task = asyncio.create_task(self._drain_loop())
5661

62+
# ------------------------------------------------------------------
63+
# Drain loop
64+
# ------------------------------------------------------------------
65+
5766
async def _drain_loop(self) -> None:
5867
while True:
59-
item = await self._queue.get()
68+
# Block until at least one item is available.
69+
first = await self._queue.get()
70+
batch: list[_SpanQueueItem] = [first]
71+
72+
# Opportunistically grab more ready items (non-blocking).
73+
while len(batch) < self._batch_size:
74+
try:
75+
batch.append(self._queue.get_nowait())
76+
except asyncio.QueueEmpty:
77+
break
78+
79+
try:
80+
# Separate START and END events. Processing all STARTs before
81+
# ENDs ensures that on_span_start completes before on_span_end
82+
# for any span whose both events land in the same batch.
83+
starts = [i for i in batch if i.event_type == SpanEventType.START]
84+
ends = [i for i in batch if i.event_type == SpanEventType.END]
85+
86+
if starts:
87+
await self._process_items(starts)
88+
if ends:
89+
await self._process_items(ends)
90+
finally:
91+
for _ in batch:
92+
self._queue.task_done()
93+
# Release span data for GC.
94+
batch.clear()
95+
96+
@staticmethod
97+
async def _process_items(items: list[_SpanQueueItem]) -> None:
98+
"""Process a list of span events concurrently."""
99+
100+
async def _handle(item: _SpanQueueItem) -> None:
60101
try:
61102
if item.event_type == SpanEventType.START:
62103
coros = [p.on_span_start(item.span) for p in item.processors]
@@ -72,9 +113,15 @@ async def _drain_loop(self) -> None:
72113
exc_info=result,
73114
)
74115
except Exception:
75-
logger.exception("Unexpected error in span queue drain loop for span %s", item.span.id)
76-
finally:
77-
self._queue.task_done()
116+
logger.exception(
117+
"Unexpected error in span queue for span %s", item.span.id
118+
)
119+
120+
await asyncio.gather(*[_handle(item) for item in items])
121+
122+
# ------------------------------------------------------------------
123+
# Shutdown
124+
# ------------------------------------------------------------------
78125

79126
async def shutdown(self, timeout: float = 30.0) -> None:
80127
self._stopping = True

tests/lib/core/tracing/processors/test_sgp_tracing_processor.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,29 @@ async def test_span_end_for_unknown_span_is_noop(self):
162162
await processor.on_span_end(span)
163163

164164
assert len(processor._spans) == 0
165+
166+
async def test_sgp_span_input_updated_on_end(self):
167+
"""on_span_end should update sgp_span.input from the incoming span."""
168+
processor, _ = self._make_processor()
169+
170+
with patch(f"{MODULE}.create_span", side_effect=lambda **kw: _make_mock_sgp_span()):
171+
span = _make_span()
172+
span.input = {"messages": [{"role": "user", "content": "hello"}]}
173+
await processor.on_span_start(span)
174+
175+
assert len(processor._spans) == 1
176+
177+
# Simulate modified input at end time
178+
updated_input: dict[str, object] = {"messages": [
179+
{"role": "user", "content": "hello"},
180+
{"role": "assistant", "content": "hi"},
181+
]}
182+
span.input = updated_input
183+
span.output = {"response": "hi"}
184+
span.end_time = datetime.now(UTC)
185+
await processor.on_span_end(span)
186+
187+
# Span should be removed after end
188+
assert len(processor._spans) == 0
189+
# The end upsert should have been called
190+
assert processor.sgp_async_client.spans.upsert_batch.call_count == 2 # start + end

tests/lib/core/tracing/test_span_queue.py

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import uuid
55
import asyncio
6+
from typing import cast
67
from datetime import UTC, datetime
78
from unittest.mock import AsyncMock, MagicMock, patch
89

@@ -52,7 +53,8 @@ async def slow_start(span: Span) -> None:
5253

5354

5455
class TestAsyncSpanQueueOrdering:
55-
async def test_fifo_ordering_preserved(self):
56+
async def test_per_span_start_before_end(self):
57+
"""START always completes before END for the same span, even with batching."""
5658
call_log: list[tuple[str, str]] = []
5759

5860
async def record_start(span: Span) -> None:
@@ -77,12 +79,19 @@ async def record_end(span: Span) -> None:
7779

7880
await queue.shutdown()
7981

80-
assert call_log == [
81-
("start", "span-a"),
82-
("end", "span-a"),
83-
("start", "span-b"),
84-
("end", "span-b"),
85-
]
82+
# All 4 events should fire
83+
assert len(call_log) == 4
84+
85+
# Per-span invariant: START before END
86+
for span_id in ("span-a", "span-b"):
87+
start_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "start" and sid == span_id)
88+
end_idx = next(i for i, (ev, sid) in enumerate(call_log) if ev == "end" and sid == span_id)
89+
assert start_idx < end_idx, f"START should come before END for {span_id}"
90+
91+
# All STARTs before all ENDs within a batch
92+
start_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "start"]
93+
end_indices = [i for i, (ev, _) in enumerate(call_log) if ev == "end"]
94+
assert max(start_indices) < min(end_indices), "All STARTs should complete before any END"
8695

8796

8897
class TestAsyncSpanQueueErrorHandling:
@@ -154,6 +163,61 @@ async def test_enqueue_after_shutdown_is_dropped(self):
154163
proc.on_span_start.assert_not_called()
155164

156165

166+
class TestAsyncSpanQueueBatchConcurrency:
167+
async def test_batch_processes_multiple_items_concurrently(self):
168+
"""Items in the same batch should run concurrently, not serially."""
169+
concurrency = 0
170+
max_concurrency = 0
171+
lock = asyncio.Lock()
172+
173+
async def slow_start(span: Span) -> None:
174+
nonlocal concurrency, max_concurrency
175+
async with lock:
176+
concurrency += 1
177+
max_concurrency = max(max_concurrency, concurrency)
178+
await asyncio.sleep(0.05)
179+
async with lock:
180+
concurrency -= 1
181+
182+
proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
183+
queue = AsyncSpanQueue()
184+
185+
# Enqueue 10 START events before the drain loop runs — they should
186+
# all land in the same batch and be processed concurrently.
187+
for i in range(10):
188+
queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc])
189+
190+
await queue.shutdown()
191+
192+
assert max_concurrency > 1, (
193+
f"Expected concurrent processing, but max concurrency was {max_concurrency}"
194+
)
195+
196+
async def test_batch_faster_than_serial(self):
197+
"""Batched drain should be significantly faster than serial for slow processors."""
198+
n_items = 10
199+
per_item_delay = 0.05 # 50ms per processor call
200+
201+
async def slow_start(span: Span) -> None:
202+
await asyncio.sleep(per_item_delay)
203+
204+
proc = _make_processor(on_span_start=AsyncMock(side_effect=slow_start))
205+
queue = AsyncSpanQueue()
206+
207+
for i in range(n_items):
208+
queue.enqueue(SpanEventType.START, _make_span(f"span-{i}"), [proc])
209+
210+
start = time.monotonic()
211+
await queue.shutdown()
212+
elapsed = time.monotonic() - start
213+
214+
serial_time = n_items * per_item_delay
215+
assert elapsed < serial_time * 0.5, (
216+
f"Batch drain took {elapsed:.3f}s — serial would be {serial_time:.3f}s. "
217+
f"Expected at least 2x speedup from concurrency."
218+
)
219+
220+
157221
class TestAsyncSpanQueueIntegration:
158222
async def test_integration_with_async_trace(self):
159223
call_log: list[tuple[str, str]] = []
@@ -196,3 +260,55 @@ async def record_end(span: Span) -> None:
196260
assert call_log[1][0] == "end"
197261
# Same span ID for both events
198262
assert call_log[0][1] == call_log[1][1]
263+
264+
async def test_end_event_preserves_modified_input(self):
265+
"""END event should carry span.input so modifications after start are preserved."""
266+
start_spans: list[Span] = []
267+
end_spans: list[Span] = []
268+
269+
async def capture_start(span: Span) -> None:
270+
start_spans.append(span)
271+
272+
async def capture_end(span: Span) -> None:
273+
end_spans.append(span)
274+
275+
proc = _make_processor(
276+
on_span_start=AsyncMock(side_effect=capture_start),
277+
on_span_end=AsyncMock(side_effect=capture_end),
278+
)
279+
queue = AsyncSpanQueue()
280+
281+
from agentex.lib.core.tracing.trace import AsyncTrace
282+
283+
mock_client = MagicMock()
284+
trace = AsyncTrace(
285+
processors=[proc],
286+
client=mock_client,
287+
trace_id="test-trace",
288+
span_queue=queue,
289+
)
290+
291+
initial_input: dict[str, object] = {"messages": [{"role": "user", "content": "hello"}]}
292+
async with trace.span("llm-call", input=initial_input) as span:
293+
# Simulate modifying input after start (e.g. chatbot appending messages)
294+
messages = cast(list[dict[str, str]], cast(dict[str, object], span.input)["messages"])
295+
messages.append({"role": "assistant", "content": "hi there"})
296+
messages.append({"role": "user", "content": "how are you?"})
297+
span.output = cast(dict[str, object], {"response": "I'm good!"})
298+
299+
await queue.shutdown()
300+
301+
assert len(start_spans) == 1
302+
assert len(end_spans) == 1
303+
304+
# START should carry the original input (serialized at start time)
305+
assert start_spans[0].input is not None
306+
assert len(cast(dict[str, list[object]], start_spans[0].input)["messages"]) == 1 # only the original message
307+
308+
# END should carry the modified input (re-serialized at end time)
309+
assert end_spans[0].input is not None
310+
assert len(cast(dict[str, list[object]], end_spans[0].input)["messages"]) == 3 # all three messages
311+
312+
# END should still carry output and end_time
313+
assert end_spans[0].output is not None
314+
assert end_spans[0].end_time is not None

0 commit comments

Comments (0)