diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 7d13696c96..b6f90a1f5a 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -833,7 +833,7 @@ async def _content_to_message_param( for part in content.parts: if part.function_call: tool_call_id = part.function_call.id or "" - tool_call_dict: ChatCompletionAssistantToolCall = { + tool_call_dict: Dict[str, Any] = { "type": "function", "id": tool_call_id, "function": { @@ -1688,13 +1688,33 @@ def _message_to_generate_content_response( if tool_calls: for tool_call in tool_calls: - if tool_call.type == "function": + if ( + isinstance(tool_call, dict) and tool_call.get("type") == "function" + ) or (hasattr(tool_call, "type") and tool_call.type == "function"): thought_signature = _extract_thought_signature_from_tool_call(tool_call) + + func_name = "" + func_args_str = "{}" + call_id = "" + + if isinstance(tool_call, dict): + call_id = tool_call.get("id", "") + func = tool_call.get("function") + if isinstance(func, dict): + func_name = func.get("name", "") + func_args_str = func.get("arguments", "{}") + else: + call_id = getattr(tool_call, "id", "") + func = getattr(tool_call, "function", None) + if func: + func_name = getattr(func, "name", "") + func_args_str = getattr(func, "arguments", "{}") + part = types.Part.from_function_call( - name=tool_call.function.name, - args=json.loads(tool_call.function.arguments or "{}"), + name=func_name, + args=json.loads(func_args_str or "{}"), ) - part.function_call.id = tool_call.id + part.function_call.id = call_id if thought_signature: part.thought_signature = thought_signature parts.append(part) @@ -2235,6 +2255,7 @@ async def generate_content_async( aggregated_llm_response_with_tool_call = None usage_metadata = None fallback_index = 0 + last_model_version = effective_model def _finalize_tool_call_response( *, model_version: str, finish_reason: str @@ -2249,17 +2270,15 @@ def _finalize_tool_call_response( except json.JSONDecodeError: has_incomplete_tool_call_args = True continue - tool_calls.append( - ChatCompletionMessageToolCall( - type="function", - id=func_data["id"], - function=Function( - name=func_data["name"], - arguments=func_data["args"], - index=index, - ), - ) - ) + tool_calls.append({ + "type": "function", + "id": func_data["id"], + "function": { + "name": func_data["name"], + "arguments": func_data["args"], + "index": index, + }, + }) if has_incomplete_tool_call_args: return LlmResponse( @@ -2319,6 +2338,7 @@ def _reset_stream_buffers() -> None: function_calls.clear() async for part in await self.llm_client.acompletion(**completion_args): + last_model_version = getattr(part, "model", effective_model) for chunk, finish_reason in _model_response_to_chunk(part): if isinstance(chunk, FunctionChunk): index = chunk.index or fallback_index @@ -2349,7 +2369,7 @@ def _reset_stream_buffers() -> None: content=chunk.text, ), is_partial=True, - model_version=part.model, + model_version=last_model_version, ) elif isinstance(chunk, ReasoningChunk): if chunk.parts: @@ -2357,7 +2377,7 @@ def _reset_stream_buffers() -> None: yield LlmResponse( content=types.Content(role="model", parts=list(chunk.parts)), partial=True, - model_version=part.model, + model_version=last_model_version, ) elif isinstance(chunk, UsageMetadataChunk): usage_metadata = types.GenerateContentResponseUsageMetadata( @@ -2380,7 +2400,7 @@ def _reset_stream_buffers() -> None: ): aggregated_llm_response_with_tool_call = ( _finalize_tool_call_response( - model_version=part.model, + model_version=last_model_version, finish_reason=finish_reason, ) ) @@ -2394,21 +2414,21 @@ def _reset_stream_buffers() -> None: ) ): aggregated_llm_response = _finalize_text_response( - model_version=part.model, + model_version=last_model_version, finish_reason=finish_reason, ) _reset_stream_buffers() if function_calls and not aggregated_llm_response_with_tool_call: aggregated_llm_response_with_tool_call = _finalize_tool_call_response( - model_version=part.model, + model_version=last_model_version, finish_reason="tool_calls", ) _reset_stream_buffers() if (text or reasoning_parts) and not aggregated_llm_response: aggregated_llm_response = _finalize_text_response( - model_version=part.model, + model_version=last_model_version, finish_reason="stop", ) _reset_stream_buffers()