Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions src/google/adk/models/lite_llm.py
Original file line number | Diff line number | Diff line change
@@ -833,7 +833,7 @@ async def _content_to_message_param(
for part in content.parts:
if part.function_call:
tool_call_id = part.function_call.id or ""
tool_call_dict: ChatCompletionAssistantToolCall = {
tool_call_dict: Dict[str, Any] = {
"type": "function",
"id": tool_call_id,
"function": {
@@ -1688,13 +1688,33 @@ def _message_to_generate_content_response(

if tool_calls:
for tool_call in tool_calls:
if tool_call.type == "function":
if (
isinstance(tool_call, dict) and tool_call.get("type") == "function"
) or (hasattr(tool_call, "type") and tool_call.type == "function"):
thought_signature = _extract_thought_signature_from_tool_call(tool_call)

func_name = ""
func_args_str = "{}"
call_id = ""

if isinstance(tool_call, dict):
call_id = tool_call.get("id", "")
func = tool_call.get("function")
if isinstance(func, dict):
func_name = func.get("name", "")
func_args_str = func.get("arguments", "{}")
else:
call_id = getattr(tool_call, "id", "")
func = getattr(tool_call, "function", None)
if func:
func_name = getattr(func, "name", "")
func_args_str = getattr(func, "arguments", "{}")

part = types.Part.from_function_call(
name=tool_call.function.name,
args=json.loads(tool_call.function.arguments or "{}"),
name=func_name,
args=json.loads(func_args_str or "{}"),
)
part.function_call.id = tool_call.id
part.function_call.id = call_id
if thought_signature:
part.thought_signature = thought_signature
parts.append(part)
@@ -2235,6 +2255,7 @@ async def generate_content_async(
aggregated_llm_response_with_tool_call = None
usage_metadata = None
fallback_index = 0
last_model_version = effective_model

def _finalize_tool_call_response(
*, model_version: str, finish_reason: str
@@ -2249,17 +2270,15 @@ def _finalize_tool_call_response(
except json.JSONDecodeError:
has_incomplete_tool_call_args = True
continue
tool_calls.append(
ChatCompletionMessageToolCall(
type="function",
id=func_data["id"],
function=Function(
name=func_data["name"],
arguments=func_data["args"],
index=index,
),
)
)
tool_calls.append({
"type": "function",
"id": func_data["id"],
"function": {
"name": func_data["name"],
"arguments": func_data["args"],
"index": index,
},
})

if has_incomplete_tool_call_args:
return LlmResponse(
@@ -2319,6 +2338,7 @@ def _reset_stream_buffers() -> None:
function_calls.clear()

async for part in await self.llm_client.acompletion(**completion_args):
last_model_version = getattr(part, "model", effective_model)
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
index = chunk.index or fallback_index
@@ -2349,15 +2369,15 @@ def _reset_stream_buffers() -> None:
content=chunk.text,
),
is_partial=True,
model_version=part.model,
model_version=last_model_version,
)
elif isinstance(chunk, ReasoningChunk):
if chunk.parts:
reasoning_parts.extend(chunk.parts)
yield LlmResponse(
content=types.Content(role="model", parts=list(chunk.parts)),
partial=True,
model_version=part.model,
model_version=last_model_version,
)
elif isinstance(chunk, UsageMetadataChunk):
usage_metadata = types.GenerateContentResponseUsageMetadata(
@@ -2380,7 +2400,7 @@ def _reset_stream_buffers() -> None:
):
aggregated_llm_response_with_tool_call = (
_finalize_tool_call_response(
model_version=part.model,
model_version=last_model_version,
finish_reason=finish_reason,
)
)
@@ -2394,21 +2414,21 @@ def _reset_stream_buffers() -> None:
)
):
aggregated_llm_response = _finalize_text_response(
model_version=part.model,
model_version=last_model_version,
finish_reason=finish_reason,
)
_reset_stream_buffers()

if function_calls and not aggregated_llm_response_with_tool_call:
aggregated_llm_response_with_tool_call = _finalize_tool_call_response(
model_version=part.model,
model_version=last_model_version,
finish_reason="tool_calls",
)
_reset_stream_buffers()

if (text or reasoning_parts) and not aggregated_llm_response:
aggregated_llm_response = _finalize_text_response(
model_version=part.model,
model_version=last_model_version,
finish_reason="stop",
)
_reset_stream_buffers()
Expand Down