Skip to content

Commit 82a57b6

Browse files
declan-scalestainless-app[bot]
authored andcommitted
Add task_id to span creation (#329)
1 parent e37f5d0 commit 82a57b6

8 files changed

Lines changed: 375 additions & 5 deletions

File tree

src/agentex/lib/adk/_modules/tracing.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ def _tracing_service(self) -> TracingService:
6464
# Re-create the underlying httpx client when the event loop changes
6565
# (e.g. between HTTP requests in a sync ASGI server) to avoid
6666
# "Event loop is closed" / "bound to a different event loop" errors.
67-
if self._tracing_service_lazy is None or (
68-
loop_id is not None and loop_id != self._bound_loop_id
69-
):
67+
if self._tracing_service_lazy is None or (loop_id is not None and loop_id != self._bound_loop_id):
7068
import httpx
7169

7270
# Disable keepalive so each span HTTP call gets a fresh TCP
@@ -93,6 +91,7 @@ async def span(
9391
input: list[Any] | dict[str, Any] | BaseModel | None = None,
9492
data: list[Any] | dict[str, Any] | BaseModel | None = None,
9593
parent_id: str | None = None,
94+
task_id: str | None = None,
9695
start_to_close_timeout: timedelta = timedelta(seconds=5),
9796
heartbeat_timeout: timedelta = timedelta(seconds=5),
9897
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
@@ -109,6 +108,7 @@ async def span(
109108
input (Union[List, Dict, BaseModel]): The input for the span.
110109
parent_id (Optional[str]): The parent span ID for the span.
111110
data (Optional[Union[List, Dict, BaseModel]]): The data for the span.
111+
task_id (Optional[str]): The task ID this span belongs to.
112112
start_to_close_timeout (timedelta): The start to close timeout for the span.
113113
heartbeat_timeout (timedelta): The heartbeat timeout for the span.
114114
retry_policy (RetryPolicy): The retry policy for the span.
@@ -126,6 +126,7 @@ async def span(
126126
input=input,
127127
parent_id=parent_id,
128128
data=data,
129+
task_id=task_id,
129130
start_to_close_timeout=start_to_close_timeout,
130131
heartbeat_timeout=heartbeat_timeout,
131132
retry_policy=retry_policy,
@@ -149,6 +150,7 @@ async def start_span(
149150
input: list[Any] | dict[str, Any] | BaseModel | None = None,
150151
parent_id: str | None = None,
151152
data: list[Any] | dict[str, Any] | BaseModel | None = None,
153+
task_id: str | None = None,
152154
start_to_close_timeout: timedelta = timedelta(seconds=5),
153155
heartbeat_timeout: timedelta = timedelta(seconds=1),
154156
retry_policy: RetryPolicy = DEFAULT_RETRY_POLICY,
@@ -162,6 +164,7 @@ async def start_span(
162164
input (Union[List, Dict, BaseModel]): The input for the span.
163165
parent_id (Optional[str]): The parent span ID for the span.
164166
data (Optional[Union[List, Dict, BaseModel]]): The data for the span.
167+
task_id (Optional[str]): The task ID this span belongs to.
165168
start_to_close_timeout (timedelta): The start to close timeout for the span.
166169
heartbeat_timeout (timedelta): The heartbeat timeout for the span.
167170
retry_policy (RetryPolicy): The retry policy for the span.
@@ -175,6 +178,7 @@ async def start_span(
175178
name=name,
176179
input=input,
177180
data=data,
181+
task_id=task_id,
178182
)
179183
if in_temporal_workflow():
180184
return await ActivityHelpers.execute_activity(
@@ -192,6 +196,7 @@ async def start_span(
192196
input=input,
193197
parent_id=parent_id,
194198
data=data,
199+
task_id=task_id,
195200
)
196201

197202
async def end_span(

src/agentex/lib/core/services/adk/tracing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ async def start_span(
2222
parent_id: str | None = None,
2323
input: list[Any] | dict[str, Any] | BaseModel | None = None,
2424
data: list[Any] | dict[str, Any] | BaseModel | None = None,
25+
task_id: str | None = None,
2526
) -> Span | None:
2627
trace = self._tracer.trace(trace_id)
2728
span = await trace.start_span(
2829
name=name,
2930
parent_id=parent_id,
3031
input=input or {},
3132
data=data,
33+
task_id=task_id,
3234
)
3335
heartbeat_if_in_workflow("start span")
3436
return span

src/agentex/lib/core/temporal/activities/adk/tracing_activities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class StartSpanParams(BaseModel):
2424
name: str
2525
input: list[Any] | dict[str, Any] | BaseModel | None = None
2626
data: list[Any] | dict[str, Any] | BaseModel | None = None
27+
task_id: str | None = None
2728

2829

2930
class EndSpanParams(BaseModel):
@@ -47,6 +48,7 @@ async def start_span(self, params: StartSpanParams) -> Span | None:
4748
name=params.name,
4849
input=params.input,
4950
data=params.data,
51+
task_id=params.task_id,
5052
)
5153

5254
@activity.defn(name=TracingActivityName.END_SPAN)

src/agentex/lib/core/tracing/trace.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def start_span(
5454
parent_id: str | None = None,
5555
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
5656
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
57+
task_id: str | None = None,
5758
) -> Span:
5859
"""
5960
Start a new span and register it with the API.
@@ -63,6 +64,7 @@ def start_span(
6364
parent_id: Optional parent span ID.
6465
input: Optional input data for the span.
6566
data: Optional additional data for the span.
67+
task_id: Optional ID of the task this span belongs to.
6668
6769
Returns:
6870
The newly created span.
@@ -86,6 +88,7 @@ def start_span(
8688
start_time=start_time,
8789
input=serialized_input,
8890
data=serialized_data,
91+
task_id=task_id,
8992
)
9093

9194
for processor in self.processors:
@@ -150,6 +153,7 @@ def span(
150153
parent_id: str | None = None,
151154
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
152155
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
156+
task_id: str | None = None,
153157
):
154158
"""
155159
Context manager for spans.
@@ -158,7 +162,7 @@ def span(
158162
if not self.trace_id:
159163
yield None
160164
return
161-
span = self.start_span(name, parent_id, input, data)
165+
span = self.start_span(name, parent_id, input, data, task_id=task_id)
162166
try:
163167
yield span
164168
finally:
@@ -198,6 +202,7 @@ async def start_span(
198202
parent_id: str | None = None,
199203
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
200204
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
205+
task_id: str | None = None,
201206
) -> Span:
202207
"""
203208
Start a new span and register it with the API.
@@ -207,6 +212,7 @@ async def start_span(
207212
parent_id: Optional parent span ID.
208213
input: Optional input data for the span.
209214
data: Optional additional data for the span.
215+
task_id: Optional ID of the task this span belongs to.
210216
211217
Returns:
212218
The newly created span.
@@ -229,6 +235,7 @@ async def start_span(
229235
start_time=start_time,
230236
input=serialized_input,
231237
data=serialized_data,
238+
task_id=task_id,
232239
)
233240

234241
if self.processors:
@@ -293,6 +300,7 @@ async def span(
293300
parent_id: str | None = None,
294301
input: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
295302
data: dict[str, Any] | list[dict[str, Any]] | BaseModel | None = None,
303+
task_id: str | None = None,
296304
) -> AsyncGenerator[Span | None, None]:
297305
"""
298306
Context manager for spans.
@@ -302,14 +310,15 @@ async def span(
302310
parent_id: Optional parent span ID.
303311
input: Optional input data for the span.
304312
data: Optional additional data for the span.
313+
task_id: Optional ID of the task this span belongs to.
305314
306315
Yields:
307316
The span object.
308317
"""
309318
if not self.trace_id:
310319
yield None
311320
return
312-
span = await self.start_span(name, parent_id, input, data)
321+
span = await self.start_span(name, parent_id, input, data, task_id=task_id)
313322
try:
314323
yield span
315324
finally:
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from __future__ import annotations
2+
3+
from datetime import datetime, timezone
4+
from unittest.mock import AsyncMock
5+
6+
from temporalio.testing import ActivityEnvironment
7+
8+
from agentex.types.span import Span
9+
10+
11+
def _make_span(**overrides) -> Span:
12+
defaults = {
13+
"id": "span-123",
14+
"name": "test-span",
15+
"start_time": datetime(2026, 1, 1, tzinfo=timezone.utc),
16+
"trace_id": "trace-123",
17+
}
18+
defaults.update(overrides)
19+
return Span(**defaults)
20+
21+
22+
def _make_tracing_activities():
23+
from agentex.lib.core.services.adk.tracing import TracingService
24+
from agentex.lib.core.temporal.activities.adk.tracing_activities import TracingActivities
25+
26+
mock_service = AsyncMock(spec=TracingService)
27+
activities = TracingActivities(tracing_service=mock_service)
28+
env = ActivityEnvironment()
29+
return mock_service, activities, env
30+
31+
32+
class TestStartSpanActivity:
33+
async def test_start_span_with_task_id(self):
34+
from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams
35+
36+
mock_service, activities, env = _make_tracing_activities()
37+
expected = _make_span(task_id="task-abc")
38+
mock_service.start_span.return_value = expected
39+
40+
params = StartSpanParams(
41+
trace_id="trace-123",
42+
name="test-span",
43+
task_id="task-abc",
44+
)
45+
result = await env.run(activities.start_span, params)
46+
47+
assert result == expected
48+
assert result.task_id == "task-abc"
49+
mock_service.start_span.assert_called_once_with(
50+
trace_id="trace-123",
51+
parent_id=None,
52+
name="test-span",
53+
input=None,
54+
data=None,
55+
task_id="task-abc",
56+
)
57+
58+
async def test_start_span_without_task_id(self):
59+
from agentex.lib.core.temporal.activities.adk.tracing_activities import StartSpanParams
60+
61+
mock_service, activities, env = _make_tracing_activities()
62+
expected = _make_span()
63+
mock_service.start_span.return_value = expected
64+
65+
params = StartSpanParams(trace_id="trace-123", name="test-span")
66+
result = await env.run(activities.start_span, params)
67+
68+
assert result == expected
69+
mock_service.start_span.assert_called_once_with(
70+
trace_id="trace-123",
71+
parent_id=None,
72+
name="test-span",
73+
input=None,
74+
data=None,
75+
task_id=None,
76+
)
77+
78+
79+
class TestEndSpanActivity:
80+
async def test_end_span_preserves_task_id(self):
81+
from agentex.lib.core.temporal.activities.adk.tracing_activities import EndSpanParams
82+
83+
mock_service, activities, env = _make_tracing_activities()
84+
span = _make_span(task_id="task-abc")
85+
expected = _make_span(
86+
task_id="task-abc",
87+
end_time=datetime(2026, 1, 1, tzinfo=timezone.utc),
88+
)
89+
mock_service.end_span.return_value = expected
90+
91+
params = EndSpanParams(trace_id="trace-123", span=span)
92+
result = await env.run(activities.end_span, params)
93+
94+
assert result == expected
95+
assert result.task_id == "task-abc"
96+
mock_service.end_span.assert_called_once_with(trace_id="trace-123", span=span)

0 commit comments

Comments
 (0)