|
23 | 23 | import com.google.adk.models.BaseLlmConnection; |
24 | 24 | import com.google.adk.models.LlmRequest; |
25 | 25 | import com.google.adk.models.LlmResponse; |
| 26 | +import com.google.auto.value.AutoValue; |
26 | 27 | import com.google.genai.types.Blob; |
27 | 28 | import com.google.genai.types.Content; |
28 | 29 | import com.google.genai.types.FunctionCall; |
29 | 30 | import com.google.genai.types.FunctionCallingConfigMode; |
30 | 31 | import com.google.genai.types.FunctionDeclaration; |
31 | 32 | import com.google.genai.types.FunctionResponse; |
32 | 33 | import com.google.genai.types.GenerateContentConfig; |
| 34 | +import com.google.genai.types.GenerateContentResponseUsageMetadata; |
33 | 35 | import com.google.genai.types.Part; |
34 | 36 | import com.google.genai.types.Schema; |
35 | 37 | import com.google.genai.types.ToolConfig; |
36 | 38 | import com.google.genai.types.Type; |
37 | | -import dev.langchain4j.Experimental; |
38 | 39 | import dev.langchain4j.agent.tool.ToolExecutionRequest; |
39 | 40 | import dev.langchain4j.agent.tool.ToolSpecification; |
40 | 41 | import dev.langchain4j.data.audio.Audio; |
|
52 | 53 | import dev.langchain4j.data.pdf.PdfFile; |
53 | 54 | import dev.langchain4j.data.video.Video; |
54 | 55 | import dev.langchain4j.exception.UnsupportedFeatureException; |
| 56 | +import dev.langchain4j.model.TokenCountEstimator; |
55 | 57 | import dev.langchain4j.model.chat.ChatModel; |
56 | 58 | import dev.langchain4j.model.chat.StreamingChatModel; |
57 | 59 | import dev.langchain4j.model.chat.request.ChatRequest; |
|
65 | 67 | import dev.langchain4j.model.chat.request.json.JsonStringSchema; |
66 | 68 | import dev.langchain4j.model.chat.response.ChatResponse; |
67 | 69 | import dev.langchain4j.model.chat.response.StreamingChatResponseHandler; |
| 70 | +import dev.langchain4j.model.output.TokenUsage; |
68 | 71 | import io.reactivex.rxjava3.core.BackpressureStrategy; |
69 | 72 | import io.reactivex.rxjava3.core.Flowable; |
70 | 73 | import java.util.ArrayList; |
71 | 74 | import java.util.Base64; |
72 | 75 | import java.util.HashMap; |
73 | 76 | import java.util.List; |
74 | 77 | import java.util.Map; |
75 | | -import java.util.Objects; |
76 | 78 | import java.util.UUID; |
| 79 | +import org.jspecify.annotations.Nullable; |
77 | 80 |
|
/**
 * ADK {@link BaseLlm} adapter backed by LangChain4j {@link ChatModel} and/or
 * {@link StreamingChatModel} implementations. Instances are immutable {@code @AutoValue}s;
 * prefer {@link #builder()} for construction.
 */
@AutoValue
public abstract class LangChain4j extends BaseLlm {

  // Jackson type token reused when deserializing tool-call argument JSON into a Map.
  private static final TypeReference<Map<String, Object>> MAP_TYPE_REFERENCE =
      new TypeReference<>() {};
|
  /**
   * Package-private constructor invoked by the AutoValue-generated subclass.
   *
   * <p>Passes an empty string to {@link BaseLlm} because the real model name is an AutoValue
   * property ({@link #modelName()}) that is not yet available here; {@link #model()} is
   * overridden below to return it instead.
   */
  LangChain4j() {
    super("");
  }

  /** The synchronous chat model; null when only a streaming model was configured. */
  @Nullable
  public abstract ChatModel chatModel();

  /** The streaming chat model; null when only a synchronous model was configured. */
  @Nullable
  public abstract StreamingChatModel streamingChatModel();

  /** Jackson mapper used to (de)serialize tool-call arguments. */
  public abstract ObjectMapper objectMapper();

  /** Name of the underlying model, reported to callers via {@link #model()}. */
  public abstract String modelName();

  /**
   * Optional estimator used to approximate token usage; when set it takes precedence over
   * provider-reported usage. Null disables estimation.
   */
  @Nullable
  public abstract TokenCountEstimator tokenCountEstimator();

  /** Returns the model name, overriding the empty value passed to {@link BaseLlm}. */
  @Override
  public String model() {
    return modelName();
  }

  /** Creates a builder preconfigured with a default {@link ObjectMapper}. */
  public static Builder builder() {
    return new AutoValue_LangChain4j.Builder().objectMapper(new ObjectMapper());
  }
| 112 | + |
  /** Builder for {@link LangChain4j} instances. */
  @AutoValue.Builder
  public abstract static class Builder {
    /** Sets the synchronous chat model (optional). */
    public abstract Builder chatModel(ChatModel chatModel);

    /** Sets the streaming chat model (optional). */
    public abstract Builder streamingChatModel(StreamingChatModel streamingChatModel);

    /** Sets the optional estimator used to approximate token usage (optional). */
    public abstract Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator);

    /** Sets the Jackson mapper; {@link LangChain4j#builder()} supplies a default. */
    public abstract Builder objectMapper(ObjectMapper objectMapper);

    /** Sets the model name (required). */
    public abstract Builder modelName(String modelName);

    /** Builds the immutable instance; throws IllegalStateException on missing required fields. */
    public abstract LangChain4j build();
  }
87 | 127 |
|
  /**
   * Creates an instance from a synchronous chat model, taking the model name from the model's
   * default request parameters.
   *
   * <p>NOTE(review): now that this class is an abstract {@code @AutoValue}, this legacy
   * constructor cannot populate the abstract properties (those live in the generated subclass);
   * prefer {@link #builder()}. Also note the previous revision's null check on the model name
   * was dropped here — confirm that is intended.
   */
  public LangChain4j(ChatModel chatModel) {
    this(chatModel, null, null, chatModel.defaultRequestParameters().modelName(), null);
  }
96 | 131 |
|
  /**
   * Creates an instance from a synchronous chat model with an explicit model name.
   *
   * <p>NOTE(review): legacy constructor; cannot populate the AutoValue properties — prefer
   * {@link #builder()}. The previous revision's {@code Objects.requireNonNull(modelName, ...)}
   * was dropped, so a null name now fails later with a less descriptive error.
   */
  public LangChain4j(ChatModel chatModel, String modelName) {
    this(chatModel, null, null, modelName, null);
  }
103 | 135 |
|
  /**
   * Creates an instance from a streaming chat model, taking the model name from the model's
   * default request parameters.
   *
   * <p>NOTE(review): legacy constructor; cannot populate the AutoValue properties — prefer
   * {@link #builder()}.
   */
  public LangChain4j(StreamingChatModel streamingChatModel) {
    this(
        null,
        streamingChatModel,
        null,
        streamingChatModel.defaultRequestParameters().modelName(),
        null);
  }
114 | 144 |
|
  /**
   * Creates an instance from a streaming chat model with an explicit model name.
   *
   * <p>NOTE(review): legacy constructor; cannot populate the AutoValue properties — prefer
   * {@link #builder()}. The previous revision's null checks were dropped here.
   */
  public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
    this(null, streamingChatModel, null, modelName, null);
  }
122 | 148 |
|
  /**
   * Creates an instance from both a synchronous and a streaming chat model with an explicit
   * model name.
   *
   * <p>NOTE(review): legacy constructor; cannot populate the AutoValue properties — prefer
   * {@link #builder()}. The previous revision's null checks were dropped here.
   */
  public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
    this(chatModel, streamingChatModel, null, modelName, null);
  }
| 152 | + |
| 153 | + private LangChain4j( |
| 154 | + ChatModel chatModel, |
| 155 | + StreamingChatModel streamingChatModel, |
| 156 | + ObjectMapper objectMapper, |
| 157 | + String modelName, |
| 158 | + TokenCountEstimator tokenCountEstimator) { |
| 159 | + this(); |
| 160 | + LangChain4j.builder() |
| 161 | + .chatModel(chatModel) |
| 162 | + .streamingChatModel(streamingChatModel) |
| 163 | + .objectMapper(objectMapper) |
| 164 | + .modelName(modelName) |
| 165 | + .tokenCountEstimator(tokenCountEstimator) |
| 166 | + .build(); |
129 | 167 | } |
130 | 168 |
|
  /**
   * Sends {@code llmRequest} to the configured LangChain4j model.
   *
   * @param llmRequest the ADK request, converted via {@code toChatRequest}
   * @param stream when true, uses the streaming model and emits one {@link LlmResponse} per
   *     partial text chunk plus one per trailing function call; when false, performs a single
   *     blocking call
   * @return a Flowable of responses; signals {@link IllegalStateException} when the required
   *     model kind was not configured
   */
  @Override
  public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stream) {
    if (stream) {
      if (this.streamingChatModel() == null) {
        return Flowable.error(new IllegalStateException("StreamingChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);

      // Bridge the push-style LangChain4j handler onto a Flowable. BUFFER retains chunks if the
      // downstream consumer is slower than the model.
      return Flowable.create(
          emitter -> {
            streamingChatModel()
                .chat(
                    chatRequest,
                    new StreamingChatResponseHandler() {
                      @Override
                      public void onPartialResponse(String s) {
                        // Each streamed text chunk becomes its own LlmResponse.
                        emitter.onNext(
                            LlmResponse.builder()
                                .content(Content.fromParts(Part.fromText(s)))
                                .build());
                      }

                      @Override
                      public void onCompleteResponse(ChatResponse chatResponse) {
                        // Text was already emitted via onPartialResponse; only function calls
                        // carried by the final message still need to be surfaced.
                        if (chatResponse.aiMessage().hasToolExecutionRequests()) {
                          AiMessage aiMessage = chatResponse.aiMessage();
                          toParts(aiMessage).stream()
                              .map(Part::functionCall)
                              .forEach(
                                  functionCall -> {
                                    functionCall.ifPresent(
                                        function -> {
                                          emitter.onNext(
                                              LlmResponse.builder()
                                                  .content(
                                                      Content.fromParts(
                                                          Part.fromFunctionCall(
                                                              function.name().orElse(""),
                                                              function.args().orElse(Map.of()))))
                                                  .build());
                                        });
                                  });
                        }
                        emitter.onComplete();
                      }

                      @Override
                      public void onError(Throwable throwable) {
                        emitter.onError(throwable);
                      }
                    });
          },
          BackpressureStrategy.BUFFER);
    } else {
      if (this.chatModel() == null) {
        return Flowable.error(new IllegalStateException("ChatModel is not configured"));
      }

      ChatRequest chatRequest = toChatRequest(llmRequest);
      ChatResponse chatResponse = chatModel().chat(chatRequest);
      // Pass the request through so toLlmResponse can estimate input token usage.
      LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);

      return Flowable.just(llmResponse);
    }
@@ -413,7 +454,7 @@ private AiMessage toAiMessage(Content content) { |
413 | 454 |
|
  /** Serializes {@code object} to JSON, wrapping Jackson's checked exception as unchecked. */
  private String toJson(Object object) {
    try {
      return objectMapper().writeValueAsString(object);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
@@ -511,11 +552,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) { |
511 | 552 | } |
512 | 553 | } |
513 | 554 |
|
514 | | - private LlmResponse toLlmResponse(ChatResponse chatResponse) { |
| 555 | + private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) { |
515 | 556 | Content content = |
516 | 557 | Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build(); |
517 | 558 |
|
518 | | - return LlmResponse.builder().content(content).build(); |
| 559 | + LlmResponse.Builder builder = LlmResponse.builder().content(content); |
| 560 | + TokenUsage tokenUsage = chatResponse.tokenUsage(); |
| 561 | + if (tokenCountEstimator() != null) { |
| 562 | + try { |
| 563 | + int estimatedInput = |
| 564 | + tokenCountEstimator().estimateTokenCountInMessages(chatRequest.messages()); |
| 565 | + int estimatedOutput = |
| 566 | + tokenCountEstimator().estimateTokenCountInText(chatResponse.aiMessage().text()); |
| 567 | + int estimatedTotal = estimatedInput + estimatedOutput; |
| 568 | + builder.usageMetadata( |
| 569 | + GenerateContentResponseUsageMetadata.builder() |
| 570 | + .promptTokenCount(estimatedInput) |
| 571 | + .candidatesTokenCount(estimatedOutput) |
| 572 | + .totalTokenCount(estimatedTotal) |
| 573 | + .build()); |
| 574 | + } catch (Exception e) { |
| 575 | + e.printStackTrace(); |
| 576 | + } |
| 577 | + } else if (tokenUsage != null) { |
| 578 | + builder.usageMetadata( |
| 579 | + GenerateContentResponseUsageMetadata.builder() |
| 580 | + .promptTokenCount(tokenUsage.inputTokenCount()) |
| 581 | + .candidatesTokenCount(tokenUsage.outputTokenCount()) |
| 582 | + .totalTokenCount(tokenUsage.totalTokenCount()) |
| 583 | + .build()); |
| 584 | + } |
| 585 | + |
| 586 | + return builder.build(); |
519 | 587 | } |
520 | 588 |
|
521 | 589 | private List<Part> toParts(AiMessage aiMessage) { |
@@ -546,7 +614,7 @@ private List<Part> toParts(AiMessage aiMessage) { |
546 | 614 |
|
  /** Parses a tool call's JSON argument string into a generic map. */
  private Map<String, Object> toArgs(ToolExecutionRequest toolExecutionRequest) {
    try {
      return objectMapper().readValue(toolExecutionRequest.arguments(), MAP_TYPE_REFERENCE);
    } catch (JsonProcessingException e) {
      throw new RuntimeException(e);
    }
|
0 commit comments