Commit 8e82838

EItanyaGWeale authored and committed
fix: Refactor Anthropic integration to support both direct API and Vertex AI
This change introduces an `AnthropicLlm` base class for direct Anthropic API calls using `AsyncAnthropic`. The existing `Claude` class now inherits from `AnthropicLlm` and is specialized to use `AsyncAnthropicVertex` for models hosted on Vertex AI. The `messages.create` call is now properly awaited.

Merge: google#2904
Co-authored-by: George Weale <gweale@google.com>
PiperOrigin-RevId: 838851026
1 parent 7edd7ea commit 8e82838
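
For orientation, here is a minimal usage sketch (not part of the diff) of how the two classes relate after this change. The import path and default model strings come from the diff and tests below; the `ANTHROPIC_API_KEY` pickup is standard Anthropic SDK behavior, not something this commit changes.

    from google.adk.models.anthropic_llm import AnthropicLlm, Claude

    # Direct Anthropic API: AsyncAnthropic() reads ANTHROPIC_API_KEY from the
    # environment (standard Anthropic SDK behavior).
    direct_llm = AnthropicLlm(model="claude-sonnet-4-20250514")

    # Vertex AI-hosted Claude: the client raises ValueError unless
    # GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION are set (see diff below).
    vertex_llm = Claude(model="claude-3-5-sonnet-v2@20241022")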

2 files changed

Lines changed: 57 additions & 8 deletions

src/google/adk/models/anthropic_llm.py

Lines changed: 25 additions & 8 deletions
@@ -28,7 +28,8 @@
 from typing import TYPE_CHECKING
 from typing import Union
 
-from anthropic import AnthropicVertex
+from anthropic import AsyncAnthropic
+from anthropic import AsyncAnthropicVertex
 from anthropic import NOT_GIVEN
 from anthropic import types as anthropic_types
 from google.genai import types
@@ -41,7 +42,7 @@
 if TYPE_CHECKING:
   from .llm_request import LlmRequest
 
-__all__ = ["Claude"]
+__all__ = ["AnthropicLlm", "Claude"]
 
 logger = logging.getLogger("google_adk." + __name__)
 
@@ -264,15 +265,15 @@ def function_declaration_to_tool_param(
   )
 
 
-class Claude(BaseLlm):
-  """Integration with Claude models served from Vertex AI.
+class AnthropicLlm(BaseLlm):
+  """Integration with Claude models via the Anthropic API.
 
   Attributes:
     model: The name of the Claude model.
     max_tokens: The maximum number of tokens to generate.
   """
 
-  model: str = "claude-3-5-sonnet-v2@20241022"
+  model: str = "claude-sonnet-4-20250514"
   max_tokens: int = 8192
 
   @classmethod
@@ -304,7 +305,7 @@ async def generate_content_async(
         else NOT_GIVEN
     )
     # TODO(b/421255973): Enable streaming for anthropic models.
-    message = self._anthropic_client.messages.create(
+    message = await self._anthropic_client.messages.create(
         model=llm_request.model,
         system=llm_request.config.system_instruction,
         messages=messages,
@@ -315,7 +316,23 @@
     yield message_to_generate_content_response(message)
 
   @cached_property
-  def _anthropic_client(self) -> AnthropicVertex:
+  def _anthropic_client(self) -> AsyncAnthropic:
+    return AsyncAnthropic()
+
+
+class Claude(AnthropicLlm):
+  """Integration with Claude models served from Vertex AI.
+
+  Attributes:
+    model: The name of the Claude model.
+    max_tokens: The maximum number of tokens to generate.
+  """
+
+  model: str = "claude-3-5-sonnet-v2@20241022"
+
+  @cached_property
+  @override
+  def _anthropic_client(self) -> AsyncAnthropicVertex:
     if (
         "GOOGLE_CLOUD_PROJECT" not in os.environ
         or "GOOGLE_CLOUD_LOCATION" not in os.environ
@@ -325,7 +342,7 @@ def _anthropic_client(self) -> AnthropicVertex:
           " Anthropic on Vertex."
       )
 
-    return AnthropicVertex(
+    return AsyncAnthropicVertex(
         project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
         region=os.environ["GOOGLE_CLOUD_LOCATION"],
     )
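
The `await` added in `generate_content_async` is load-bearing: with an async client, calling `messages.create(...)` without `await` returns a coroutine object rather than a `Message`, so `message_to_generate_content_response` would have received the wrong type. A self-contained illustration of the failure mode (toy names, not ADK code):

    import asyncio

    async def create():  # stands in for AsyncAnthropic().messages.create
      return "message"

    async def main():
      not_awaited = create()    # a coroutine object, not the result
      print(type(not_awaited))  # <class 'coroutine'>
      print(await not_awaited)  # "message" -- the actual result

    asyncio.run(main())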

tests/unittests/models/test_anthropic_llm.py

Lines changed: 32 additions & 0 deletions
@@ -19,6 +19,7 @@
 from anthropic import types as anthropic_types
 from google.adk import version as adk_version
 from google.adk.models import anthropic_llm
+from google.adk.models.anthropic_llm import AnthropicLlm
 from google.adk.models.anthropic_llm import Claude
 from google.adk.models.anthropic_llm import content_to_message_param
 from google.adk.models.anthropic_llm import function_declaration_to_tool_param
@@ -359,6 +360,37 @@ async def mock_coro():
   assert responses[0].content.parts[0].text == "Hello, how can I help you?"
 
 
+@pytest.mark.asyncio
+async def test_anthropic_llm_generate_content_async(
+    llm_request, generate_content_response, generate_llm_response
+):
+  anthropic_llm_instance = AnthropicLlm(model="claude-sonnet-4-20250514")
+  with mock.patch.object(
+      anthropic_llm_instance, "_anthropic_client"
+  ) as mock_client:
+    with mock.patch.object(
+        anthropic_llm,
+        "message_to_generate_content_response",
+        return_value=generate_llm_response,
+    ):
+      # Create a mock coroutine that returns the generate_content_response.
+      async def mock_coro():
+        return generate_content_response
+
+      # Assign the coroutine to the mocked method
+      mock_client.messages.create.return_value = mock_coro()
+
+      responses = [
+          resp
+          async for resp in anthropic_llm_instance.generate_content_async(
+              llm_request, stream=False
+          )
+      ]
+      assert len(responses) == 1
+      assert isinstance(responses[0], LlmResponse)
+      assert responses[0].content.parts[0].text == "Hello, how can I help you?"
+
+
 @pytest.mark.asyncio
 async def test_generate_content_async_with_max_tokens(
     llm_request, generate_content_response, generate_llm_response