Skip to content

Commit d8918fd

Browse files
Add temporal-spring-ai module for Spring AI integration
Adds a new module that integrates Spring AI with Temporal workflows, enabling durable AI model calls, vector store operations, embeddings, and MCP tool execution as Temporal activities. Key components: - ActivityChatModel: ChatModel implementation backed by activities - TemporalChatClient: Temporal-aware ChatClient with tool detection - SpringAiPlugin: Auto-registers Spring AI activities with workers - Tool system: @DeterministicTool, @SideEffectTool, activity-backed tools - MCP integration: ActivityMcpClient for durable MCP tool calls Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4557ac8 commit d8918fd

34 files changed

Lines changed: 3455 additions & 0 deletions

settings.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include 'temporal-testing'
66
include 'temporal-test-server'
77
include 'temporal-opentracing'
88
include 'temporal-kotlin'
9+
include 'temporal-spring-ai'
910
include 'temporal-spring-boot-autoconfigure'
1011
include 'temporal-spring-boot-starter'
1112
include 'temporal-remote-data-encoder'

temporal-bom/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies {
1212
api project(':temporal-sdk')
1313
api project(':temporal-serviceclient')
1414
api project(':temporal-shaded')
15+
api project(':temporal-spring-ai')
1516
api project(':temporal-spring-boot-autoconfigure')
1617
api project(':temporal-spring-boot-starter')
1718
api project(':temporal-test-server')

temporal-spring-ai/build.gradle

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
description = '''Temporal Java SDK Spring AI Plugin'''
2+
3+
ext {
4+
springAiVersion = '1.1.0'
5+
// Spring AI requires Spring Boot 3.x / Java 17+
6+
springBootVersionForSpringAi = "$springBoot3Version"
7+
}
8+
9+
// Spring AI requires Java 17+, override the default Java 8 target from java.gradle
10+
java {
11+
sourceCompatibility = JavaVersion.VERSION_17
12+
targetCompatibility = JavaVersion.VERSION_17
13+
}
14+
15+
compileJava {
16+
options.compilerArgs.removeAll(['--release', '8'])
17+
options.compilerArgs.addAll(['--release', '17'])
18+
}
19+
20+
compileTestJava {
21+
options.compilerArgs.removeAll(['--release', '8'])
22+
options.compilerArgs.addAll(['--release', '17'])
23+
}
24+
25+
dependencies {
26+
api(platform("org.springframework.boot:spring-boot-dependencies:$springBootVersionForSpringAi"))
27+
api(platform("org.springframework.ai:spring-ai-bom:$springAiVersion"))
28+
29+
// this module shouldn't carry temporal-sdk with it, especially for situations when users may be using a shaded artifact
30+
compileOnly project(':temporal-sdk')
31+
compileOnly project(':temporal-spring-boot-autoconfigure')
32+
33+
api 'org.springframework.boot:spring-boot-autoconfigure'
34+
api 'org.springframework.ai:spring-ai-client-chat'
35+
36+
implementation 'org.springframework.boot:spring-boot-starter'
37+
38+
// Optional: Vector store support
39+
compileOnly 'org.springframework.ai:spring-ai-rag'
40+
41+
// Optional: MCP (Model Context Protocol) support
42+
compileOnly 'org.springframework.ai:spring-ai-mcp'
43+
44+
testImplementation project(':temporal-sdk')
45+
testImplementation project(':temporal-testing')
46+
testImplementation "org.mockito:mockito-core:${mockitoVersion}"
47+
testImplementation 'org.springframework.boot:spring-boot-starter-test'
48+
49+
testRuntimeOnly group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}"
50+
}
51+
52+
tasks.test {
53+
useJUnitPlatform()
54+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package io.temporal.springai.activity;
2+
3+
import io.temporal.activity.ActivityInterface;
4+
import io.temporal.activity.ActivityMethod;
5+
import io.temporal.springai.model.ChatModelTypes;
6+
7+
/**
8+
* Temporal activity interface for calling Spring AI chat models.
9+
*
10+
* <p>This activity wraps a Spring AI {@link org.springframework.ai.chat.model.ChatModel} and makes
11+
* it callable from within Temporal workflows. The activity handles serialization of prompts and
12+
* responses, enabling durable AI conversations with automatic retries and timeout handling.
13+
*/
14+
@ActivityInterface
15+
public interface ChatModelActivity {
16+
17+
/**
18+
* Calls the chat model with the given input.
19+
*
20+
* @param input the chat model input containing messages, options, and tool definitions
21+
* @return the chat model output containing generated responses and metadata
22+
*/
23+
@ActivityMethod
24+
ChatModelTypes.ChatModelActivityOutput callChatModel(ChatModelTypes.ChatModelActivityInput input);
25+
}
Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
package io.temporal.springai.activity;
2+
3+
import io.temporal.springai.model.ChatModelTypes;
4+
import io.temporal.springai.model.ChatModelTypes.Message;
5+
import java.net.URI;
6+
import java.net.URISyntaxException;
7+
import java.util.List;
8+
import java.util.Map;
9+
import java.util.stream.Collectors;
10+
import org.springframework.ai.chat.messages.*;
11+
import org.springframework.ai.chat.model.ChatModel;
12+
import org.springframework.ai.chat.model.ChatResponse;
13+
import org.springframework.ai.chat.prompt.Prompt;
14+
import org.springframework.ai.content.Media;
15+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
16+
import org.springframework.ai.tool.ToolCallback;
17+
import org.springframework.ai.tool.definition.ToolDefinition;
18+
import org.springframework.core.io.ByteArrayResource;
19+
import org.springframework.util.CollectionUtils;
20+
import org.springframework.util.MimeType;
21+
22+
/**
23+
* Implementation of {@link ChatModelActivity} that delegates to a Spring AI {@link ChatModel}.
24+
*
25+
* <p>This implementation handles the conversion between Temporal-serializable types ({@link
26+
* ChatModelTypes}) and Spring AI types.
27+
*
28+
* <p>Supports multiple chat models. The model to use is determined by the {@code modelName} field
29+
* in the input. If no model name is specified, the default model is used.
30+
*/
31+
public class ChatModelActivityImpl implements ChatModelActivity {
32+
33+
private final Map<String, ChatModel> chatModels;
34+
private final String defaultModelName;
35+
36+
/**
37+
* Creates an activity implementation with a single chat model.
38+
*
39+
* @param chatModel the chat model to use
40+
*/
41+
public ChatModelActivityImpl(ChatModel chatModel) {
42+
this.chatModels = Map.of("default", chatModel);
43+
this.defaultModelName = "default";
44+
}
45+
46+
/**
47+
* Creates an activity implementation with multiple chat models.
48+
*
49+
* @param chatModels map of model names to chat models
50+
* @param defaultModelName the name of the default model to use when none is specified
51+
*/
52+
public ChatModelActivityImpl(Map<String, ChatModel> chatModels, String defaultModelName) {
53+
this.chatModels = chatModels;
54+
this.defaultModelName = defaultModelName;
55+
}
56+
57+
@Override
58+
public ChatModelTypes.ChatModelActivityOutput callChatModel(
59+
ChatModelTypes.ChatModelActivityInput input) {
60+
ChatModel chatModel = resolveChatModel(input.modelName());
61+
Prompt prompt = createPrompt(input);
62+
ChatResponse response = chatModel.call(prompt);
63+
return toOutput(response);
64+
}
65+
66+
private ChatModel resolveChatModel(String modelName) {
67+
String name = (modelName != null && !modelName.isEmpty()) ? modelName : defaultModelName;
68+
ChatModel model = chatModels.get(name);
69+
if (model == null) {
70+
throw new IllegalArgumentException(
71+
"No chat model with name '" + name + "'. Available models: " + chatModels.keySet());
72+
}
73+
return model;
74+
}
75+
76+
private Prompt createPrompt(ChatModelTypes.ChatModelActivityInput input) {
77+
List<org.springframework.ai.chat.messages.Message> messages =
78+
input.messages().stream().map(this::toSpringMessage).collect(Collectors.toList());
79+
80+
ToolCallingChatOptions.Builder optionsBuilder =
81+
ToolCallingChatOptions.builder()
82+
.internalToolExecutionEnabled(false); // Let workflow handle tool execution
83+
84+
if (input.modelOptions() != null) {
85+
ChatModelTypes.ModelOptions opts = input.modelOptions();
86+
if (opts.model() != null) optionsBuilder.model(opts.model());
87+
if (opts.temperature() != null) optionsBuilder.temperature(opts.temperature());
88+
if (opts.maxTokens() != null) optionsBuilder.maxTokens(opts.maxTokens());
89+
if (opts.topP() != null) optionsBuilder.topP(opts.topP());
90+
if (opts.topK() != null) optionsBuilder.topK(opts.topK());
91+
if (opts.frequencyPenalty() != null) optionsBuilder.frequencyPenalty(opts.frequencyPenalty());
92+
if (opts.presencePenalty() != null) optionsBuilder.presencePenalty(opts.presencePenalty());
93+
if (opts.stopSequences() != null) optionsBuilder.stopSequences(opts.stopSequences());
94+
}
95+
96+
// Add tool callbacks (stubs that provide definitions but won't be executed
97+
// since internalToolExecutionEnabled is false)
98+
if (!CollectionUtils.isEmpty(input.tools())) {
99+
List<ToolCallback> toolCallbacks =
100+
input.tools().stream()
101+
.map(
102+
tool ->
103+
createStubToolCallback(
104+
tool.function().name(),
105+
tool.function().description(),
106+
tool.function().jsonSchema()))
107+
.collect(Collectors.toList());
108+
optionsBuilder.toolCallbacks(toolCallbacks);
109+
}
110+
111+
ToolCallingChatOptions chatOptions = optionsBuilder.build();
112+
113+
return Prompt.builder().messages(messages).chatOptions(chatOptions).build();
114+
}
115+
116+
private org.springframework.ai.chat.messages.Message toSpringMessage(Message message) {
117+
return switch (message.role()) {
118+
case SYSTEM -> new SystemMessage((String) message.rawContent());
119+
case USER -> {
120+
UserMessage.Builder builder = UserMessage.builder().text((String) message.rawContent());
121+
if (!CollectionUtils.isEmpty(message.mediaContents())) {
122+
builder.media(
123+
message.mediaContents().stream().map(this::toMedia).collect(Collectors.toList()));
124+
}
125+
yield builder.build();
126+
}
127+
case ASSISTANT ->
128+
AssistantMessage.builder()
129+
.content((String) message.rawContent())
130+
.properties(Map.of())
131+
.toolCalls(
132+
message.toolCalls() != null
133+
? message.toolCalls().stream()
134+
.map(
135+
tc ->
136+
new AssistantMessage.ToolCall(
137+
tc.id(),
138+
tc.type(),
139+
tc.function().name(),
140+
tc.function().arguments()))
141+
.collect(Collectors.toList())
142+
: List.of())
143+
.media(
144+
message.mediaContents() != null
145+
? message.mediaContents().stream()
146+
.map(this::toMedia)
147+
.collect(Collectors.toList())
148+
: List.of())
149+
.build();
150+
case TOOL ->
151+
ToolResponseMessage.builder()
152+
.responses(
153+
List.of(
154+
new ToolResponseMessage.ToolResponse(
155+
message.toolCallId(), message.name(), (String) message.rawContent())))
156+
.build();
157+
};
158+
}
159+
160+
private Media toMedia(ChatModelTypes.MediaContent mediaContent) {
161+
MimeType mimeType = MimeType.valueOf(mediaContent.mimeType());
162+
if (mediaContent.uri() != null) {
163+
try {
164+
return new Media(mimeType, new URI(mediaContent.uri()));
165+
} catch (URISyntaxException e) {
166+
throw new RuntimeException("Invalid media URI: " + mediaContent.uri(), e);
167+
}
168+
} else if (mediaContent.data() != null) {
169+
return new Media(mimeType, new ByteArrayResource(mediaContent.data()));
170+
}
171+
throw new IllegalArgumentException("Media content must have either uri or data");
172+
}
173+
174+
private ChatModelTypes.ChatModelActivityOutput toOutput(ChatResponse response) {
175+
List<ChatModelTypes.ChatModelActivityOutput.Generation> generations =
176+
response.getResults().stream()
177+
.map(
178+
gen ->
179+
new ChatModelTypes.ChatModelActivityOutput.Generation(
180+
fromAssistantMessage(gen.getOutput())))
181+
.collect(Collectors.toList());
182+
183+
ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata metadata = null;
184+
if (response.getMetadata() != null) {
185+
var rateLimit = response.getMetadata().getRateLimit();
186+
var usage = response.getMetadata().getUsage();
187+
188+
metadata =
189+
new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata(
190+
response.getMetadata().getModel(),
191+
rateLimit != null
192+
? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.RateLimit(
193+
rateLimit.getRequestsLimit(),
194+
rateLimit.getRequestsRemaining(),
195+
rateLimit.getRequestsReset(),
196+
rateLimit.getTokensLimit(),
197+
rateLimit.getTokensRemaining(),
198+
rateLimit.getTokensReset())
199+
: null,
200+
usage != null
201+
? new ChatModelTypes.ChatModelActivityOutput.ChatResponseMetadata.Usage(
202+
usage.getPromptTokens() != null ? usage.getPromptTokens().intValue() : null,
203+
usage.getCompletionTokens() != null
204+
? usage.getCompletionTokens().intValue()
205+
: null,
206+
usage.getTotalTokens() != null ? usage.getTotalTokens().intValue() : null)
207+
: null);
208+
}
209+
210+
return new ChatModelTypes.ChatModelActivityOutput(generations, metadata);
211+
}
212+
213+
private Message fromAssistantMessage(AssistantMessage assistantMessage) {
214+
List<Message.ToolCall> toolCalls = null;
215+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
216+
toolCalls =
217+
assistantMessage.getToolCalls().stream()
218+
.map(
219+
tc ->
220+
new Message.ToolCall(
221+
tc.id(),
222+
tc.type(),
223+
new Message.ChatCompletionFunction(tc.name(), tc.arguments())))
224+
.collect(Collectors.toList());
225+
}
226+
227+
List<ChatModelTypes.MediaContent> mediaContents = null;
228+
if (!CollectionUtils.isEmpty(assistantMessage.getMedia())) {
229+
mediaContents =
230+
assistantMessage.getMedia().stream().map(this::fromMedia).collect(Collectors.toList());
231+
}
232+
233+
return new Message(
234+
assistantMessage.getText(), Message.Role.ASSISTANT, null, null, toolCalls, mediaContents);
235+
}
236+
237+
private ChatModelTypes.MediaContent fromMedia(Media media) {
238+
String mimeType = media.getMimeType().toString();
239+
if (media.getData() instanceof String uri) {
240+
return new ChatModelTypes.MediaContent(mimeType, uri);
241+
} else if (media.getData() instanceof byte[] data) {
242+
return new ChatModelTypes.MediaContent(mimeType, data);
243+
}
244+
throw new IllegalArgumentException(
245+
"Unsupported media data type: " + media.getData().getClass());
246+
}
247+
248+
/**
249+
* Creates a stub ToolCallback that provides a tool definition but throws if called. This is used
250+
* because Spring AI's ChatModel API requires ToolCallbacks, but we only need to inform the model
251+
* about available tools - actual execution happens in the workflow (since
252+
* internalToolExecutionEnabled is false).
253+
*/
254+
private ToolCallback createStubToolCallback(String name, String description, String inputSchema) {
255+
ToolDefinition toolDefinition =
256+
ToolDefinition.builder()
257+
.name(name)
258+
.description(description)
259+
.inputSchema(inputSchema)
260+
.build();
261+
262+
return new ToolCallback() {
263+
@Override
264+
public ToolDefinition getToolDefinition() {
265+
return toolDefinition;
266+
}
267+
268+
@Override
269+
public String call(String toolInput) {
270+
throw new UnsupportedOperationException(
271+
"Tool execution should be handled by the workflow, not the activity. "
272+
+ "Ensure internalToolExecutionEnabled is set to false.");
273+
}
274+
};
275+
}
276+
}

0 commit comments

Comments
 (0)