Skip to content

Commit d88104b

Browse files
authored
add support for Temporal PayloadCodec (#328)
1 parent b739289 commit d88104b

7 files changed

Lines changed: 304 additions & 22 deletions

File tree

src/agentex/lib/core/clients/temporal/temporal_client.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from temporalio.client import Client, WorkflowExecutionStatus
88
from temporalio.common import RetryPolicy as TemporalRetryPolicy, WorkflowIDReusePolicy
99
from temporalio.service import RPCError, RPCStatusCode
10+
from temporalio.converter import PayloadCodec
1011

1112
from agentex.lib.utils.logging import make_logger
1213
from agentex.lib.utils.model_utils import BaseModel
@@ -76,9 +77,12 @@
7677

7778

7879
class TemporalClient:
79-
def __init__(self, temporal_client: Client | None = None, plugins: list[Any] = []):
80+
def __init__(
81+
self, temporal_client: Client | None = None, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None
82+
):
8083
self._client: Client | None = temporal_client
8184
self._plugins = plugins
85+
self._payload_codec = payload_codec
8286

8387
@property
8488
def client(self) -> Client:
@@ -88,7 +92,7 @@ def client(self) -> Client:
8892
return self._client
8993

9094
@classmethod
91-
async def create(cls, temporal_address: str, plugins: list[Any] = []):
95+
async def create(cls, temporal_address: str, plugins: list[Any] = [], payload_codec: PayloadCodec | None = None):
9296
if temporal_address in [
9397
"false",
9498
"False",
@@ -101,8 +105,8 @@ async def create(cls, temporal_address: str, plugins: list[Any] = []):
101105
]:
102106
_client = None
103107
else:
104-
_client = await get_temporal_client(temporal_address, plugins=plugins)
105-
return cls(_client, plugins)
108+
_client = await get_temporal_client(temporal_address, plugins=plugins, payload_codec=payload_codec)
109+
return cls(_client, plugins, payload_codec)
106110

107111
async def setup(self, temporal_address: str):
108112
self._client = await self._get_temporal_client(temporal_address=temporal_address)
@@ -120,7 +124,7 @@ async def _get_temporal_client(self, temporal_address: str) -> Client | None:
120124
]:
121125
return None
122126
else:
123-
return await get_temporal_client(temporal_address, plugins=self._plugins)
127+
return await get_temporal_client(temporal_address, plugins=self._plugins, payload_codec=self._payload_codec)
124128

125129
async def start_workflow(
126130
self,

src/agentex/lib/core/clients/temporal/utils.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
from typing import Any
45

56
from temporalio.client import Client, Plugin as ClientPlugin
67
from temporalio.worker import Interceptor
78
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
9+
from temporalio.converter import PayloadCodec
810
from temporalio.contrib.pydantic import pydantic_data_converter
911

1012
# class DateTimeJSONEncoder(AdvancedJSONEncoder):
@@ -79,14 +81,20 @@ def validate_worker_interceptors(interceptors: list[Any]) -> None:
7981
)
8082

8183

82-
async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list[Any] = []) -> Client:
84+
async def get_temporal_client(
85+
temporal_address: str,
86+
metrics_url: str | None = None,
87+
plugins: list[Any] = [],
88+
payload_codec: PayloadCodec | None = None,
89+
) -> Client:
8390
"""
8491
Create a Temporal client with plugin integration.
8592
8693
Args:
8794
temporal_address: Temporal server address
8895
metrics_url: Optional metrics endpoint URL
8996
plugins: List of Temporal plugins to include
97+
payload_codec: Optional payload codec for encoding/decoding payloads (e.g. encryption, compression)
9098
9199
Returns:
92100
Configured Temporal client
@@ -98,18 +106,26 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N
98106
# Check if OpenAI plugin is present - it needs to configure its own data converter
99107
# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
100108
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
101-
has_openai_plugin = any(
102-
isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])
103-
)
104109

105-
# Only set data_converter if OpenAI plugin is not present
110+
has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))
111+
112+
if has_openai_plugin and payload_codec is not None:
113+
raise ValueError(
114+
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
115+
"installs its own data converter and the codec would be silently ignored, "
116+
"leaving payloads unencoded. Remove one or the other."
117+
)
118+
106119
connect_kwargs = {
107120
"target_host": temporal_address,
108121
"plugins": plugins,
109122
}
110123

111124
if not has_openai_plugin:
112-
connect_kwargs["data_converter"] = pydantic_data_converter
125+
data_converter = pydantic_data_converter
126+
if payload_codec:
127+
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
128+
connect_kwargs["data_converter"] = data_converter
113129

114130
if not metrics_url:
115131
client = await Client.connect(**connect_kwargs)

src/agentex/lib/core/temporal/workers/worker.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from temporalio.runtime import Runtime, TelemetryConfig, OpenTelemetryConfig
2020
from temporalio.converter import (
21+
PayloadCodec,
2122
DataConverter,
2223
JSONTypeConverter,
2324
AdvancedJSONEncoder,
@@ -89,16 +90,27 @@ def _validate_interceptors(interceptors: list) -> None:
8990
)
9091

9192

92-
async def get_temporal_client(temporal_address: str, metrics_url: str | None = None, plugins: list = []) -> Client:
93+
async def get_temporal_client(
94+
temporal_address: str,
95+
metrics_url: str | None = None,
96+
plugins: list = [],
97+
payload_codec: PayloadCodec | None = None,
98+
) -> Client:
9399
if plugins != []: # We don't need to validate the plugins if they are empty
94100
_validate_plugins(plugins)
95101

96102
# Check if OpenAI plugin is present - it needs to configure its own data converter
97103
# Lazy import to avoid pulling in opentelemetry.sdk for non-Temporal agents
98104
from temporalio.contrib.openai_agents import OpenAIAgentsPlugin
99-
has_openai_plugin = any(
100-
isinstance(p, OpenAIAgentsPlugin) for p in (plugins or [])
101-
)
105+
106+
has_openai_plugin = any(isinstance(p, OpenAIAgentsPlugin) for p in (plugins or []))
107+
108+
if has_openai_plugin and payload_codec is not None:
109+
raise ValueError(
110+
"payload_codec is not supported alongside OpenAIAgentsPlugin: the plugin "
111+
"installs its own data converter and the codec would be silently ignored, "
112+
"leaving payloads unencoded. Remove one or the other."
113+
)
102114

103115
# Build connection kwargs
104116
connect_kwargs = {
@@ -108,7 +120,10 @@ async def get_temporal_client(temporal_address: str, metrics_url: str | None = N
108120

109121
# Only set data_converter if OpenAI plugin is not present
110122
if not has_openai_plugin:
111-
connect_kwargs["data_converter"] = custom_data_converter
123+
data_converter = custom_data_converter
124+
if payload_codec:
125+
data_converter = dataclasses.replace(data_converter, payload_codec=payload_codec)
126+
connect_kwargs["data_converter"] = data_converter
112127

113128
if not metrics_url:
114129
client = await Client.connect(**connect_kwargs)
@@ -129,17 +144,21 @@ def __init__(
129144
plugins: list = [],
130145
interceptors: list = [],
131146
metrics_url: str | None = None,
147+
payload_codec: PayloadCodec | None = None,
132148
):
133149
self.task_queue = task_queue
134150
self.activity_handles = []
135151
self.max_workers = max_workers
136152
self.max_concurrent_activities = max_concurrent_activities
137153
self.health_check_server_running = False
138154
self.healthy = False
139-
self.health_check_port = health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT
155+
self.health_check_port = (
156+
health_check_port if health_check_port is not None else EnvironmentVariables.refresh().HEALTH_CHECK_PORT
157+
)
140158
self.plugins = plugins
141159
self.interceptors = interceptors
142160
self.metrics_url = metrics_url
161+
self.payload_codec = payload_codec
143162

144163
@overload
145164
async def run(
@@ -175,6 +194,7 @@ async def run(
175194
temporal_address=os.environ.get("TEMPORAL_ADDRESS", "localhost:7233"),
176195
plugins=self.plugins,
177196
metrics_url=self.metrics_url,
197+
payload_codec=self.payload_codec,
178198
)
179199

180200
# Enable debug mode if AgentEx debug is enabled (disables deadlock detection)

src/agentex/lib/sdk/fastacp/fastacp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class FastACP:
3434
Supports three main ACP types:
3535
- "sync": Simple synchronous ACP implementation
3636
- "async": Advanced ACP with sub-types "base" or "temporal" (requires config)
37-
- "agentic": (Deprecated, use "async") Identical to "async"
37+
- "agentic": (Deprecated, use "async") Identical to "async"
3838
"""
3939

4040
@staticmethod
@@ -63,6 +63,8 @@ def create_async_acp(config: AsyncACPConfig, **kwargs) -> BaseACPServer:
6363
temporal_config["plugins"] = config.plugins # type: ignore[attr-defined]
6464
if hasattr(config, "interceptors"):
6565
temporal_config["interceptors"] = config.interceptors # type: ignore[attr-defined]
66+
if hasattr(config, "payload_codec"):
67+
temporal_config["payload_codec"] = config.payload_codec # type: ignore[attr-defined]
6668
return implementation_class.create(**temporal_config)
6769
else:
6870
return implementation_class.create(**kwargs)

src/agentex/lib/sdk/fastacp/impl/temporal_acp.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from contextlib import asynccontextmanager
55

66
from fastapi import FastAPI
7+
from temporalio.converter import PayloadCodec
78

89
from agentex.lib.types.acp import (
910
SendEventParams,
@@ -31,20 +32,30 @@ def __init__(
3132
temporal_task_service: TemporalTaskService | None = None,
3233
plugins: list[Any] | None = None,
3334
interceptors: list[Any] | None = None,
35+
payload_codec: PayloadCodec | None = None,
3436
):
3537
super().__init__()
3638
self._temporal_task_service = temporal_task_service
3739
self._temporal_address = temporal_address
3840
self._plugins = plugins or []
3941
self._interceptors = interceptors or []
42+
self._payload_codec = payload_codec
4043

4144
@classmethod
4245
@override
43-
def create(cls, temporal_address: str, plugins: list[Any] | None = None, interceptors: list[Any] | None = None) -> "TemporalACP":
46+
def create(
47+
cls,
48+
temporal_address: str,
49+
plugins: list[Any] | None = None,
50+
interceptors: list[Any] | None = None,
51+
payload_codec: PayloadCodec | None = None,
52+
) -> "TemporalACP":
4453
logger.info("Initializing TemporalACP instance")
4554

4655
# Create instance without temporal client initially
47-
temporal_acp = cls(temporal_address=temporal_address, plugins=plugins, interceptors=interceptors)
56+
temporal_acp = cls(
57+
temporal_address=temporal_address, plugins=plugins, interceptors=interceptors, payload_codec=payload_codec
58+
)
4859
temporal_acp._setup_handlers()
4960
logger.info("TemporalACP instance initialized now")
5061
return temporal_acp
@@ -60,7 +71,7 @@ async def lifespan(app: FastAPI):
6071
if self._temporal_task_service is None:
6172
env_vars = EnvironmentVariables.refresh()
6273
temporal_client = await TemporalClient.create(
63-
temporal_address=self._temporal_address, plugins=self._plugins
74+
temporal_address=self._temporal_address, plugins=self._plugins, payload_codec=self._payload_codec
6475
)
6576
self._temporal_task_service = TemporalTaskService(
6677
temporal_client=temporal_client,

src/agentex/lib/types/fastacp.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ class AsyncACPConfig(BaseACPConfig):
3939

4040
type: Literal["temporal", "base"] = Field(..., frozen=True)
4141

42+
4243
AgenticACPConfig = AsyncACPConfig
4344

45+
4446
class TemporalACPConfig(AsyncACPConfig):
4547
"""
4648
Configuration for TemporalACP implementation
@@ -50,12 +52,18 @@ class TemporalACPConfig(AsyncACPConfig):
5052
temporal_address: The address of the temporal server
5153
plugins: List of Temporal client plugins
5254
interceptors: List of Temporal worker interceptors
55+
payload_codec: Optional ``temporalio.converter.PayloadCodec`` for
56+
encoding/decoding payloads (e.g. encryption, compression). NOTE:
57+
this only configures the ACP (client) side. The worker side must
58+
be configured separately via ``AgentexWorker(payload_codec=...)``
59+
with the SAME codec, or decode will fail at runtime.
5360
"""
5461

5562
type: Literal["temporal"] = Field(default="temporal", frozen=True)
5663
temporal_address: str = Field(default="temporal-frontend.temporal.svc.cluster.local:7233", frozen=True)
5764
plugins: list[Any] = Field(default=[], frozen=True)
5865
interceptors: list[Any] = Field(default=[], frozen=True)
66+
payload_codec: Any = Field(default=None, frozen=True)
5967

6068
@field_validator("plugins")
6169
@classmethod
@@ -81,4 +89,5 @@ class AsyncBaseACPConfig(AsyncACPConfig):
8189

8290
type: Literal["base"] = Field(default="base", frozen=True)
8391

84-
AgenticBaseACPConfig = AsyncBaseACPConfig
92+
93+
AgenticBaseACPConfig = AsyncBaseACPConfig

0 commit comments

Comments
 (0)