Skip to content

Commit db56e23

Browse files
committed
feat(tools): support per-parameter descriptions via Annotated[T, Field(description=...)]
This change implements Option A from issue #4552, adding support for per-parameter descriptions in FunctionTool using the Annotated type hint with pydantic.Field(description=...). Changes: - Add _extract_field_info_from_annotated() to extract FieldInfo from Annotated - Add _extract_base_type_from_annotated() to unwrap base types from Annotated - Update _get_fields_dict() to use descriptions from Annotated[T, Field(...)] - Add comprehensive tests for the new functionality This enables developers to provide contextual guidance for LLM parameter selection without embedding all information in the tool docstring: from typing import Annotated from pydantic import Field async def create_task( repository: Annotated[str, Field( description='Repository URL. MUST be obtained from get_repository_info.' )], ) -> dict: ... Closes #4552
1 parent 4b677e7 commit db56e23

2 files changed

Lines changed: 481 additions & 90 deletions

File tree

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 146 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import inspect
1919
from types import FunctionType
2020
import typing
21+
from typing import Annotated
2122
from typing import Any
2223
from typing import Callable
2324
from typing import Dict
@@ -31,6 +32,7 @@
3132
from pydantic import BaseModel
3233
from pydantic import create_model
3334
from pydantic import fields as pydantic_fields
35+
from pydantic.fields import FieldInfo
3436

3537
from . import _function_parameter_parse_util
3638
from . import _function_tool_declarations
@@ -39,62 +41,116 @@
3941
from ..utils.variant_utils import GoogleLLMVariant
4042

4143
_py_type_2_schema_type = {
42-
'str': types.Type.STRING,
43-
'int': types.Type.INTEGER,
44-
'float': types.Type.NUMBER,
45-
'bool': types.Type.BOOLEAN,
46-
'string': types.Type.STRING,
47-
'integer': types.Type.INTEGER,
48-
'number': types.Type.NUMBER,
49-
'boolean': types.Type.BOOLEAN,
50-
'list': types.Type.ARRAY,
51-
'array': types.Type.ARRAY,
52-
'tuple': types.Type.ARRAY,
53-
'object': types.Type.OBJECT,
54-
'Dict': types.Type.OBJECT,
55-
'List': types.Type.ARRAY,
56-
'Tuple': types.Type.ARRAY,
57-
'Any': types.Type.TYPE_UNSPECIFIED,
44+
"str": types.Type.STRING,
45+
"int": types.Type.INTEGER,
46+
"float": types.Type.NUMBER,
47+
"bool": types.Type.BOOLEAN,
48+
"string": types.Type.STRING,
49+
"integer": types.Type.INTEGER,
50+
"number": types.Type.NUMBER,
51+
"boolean": types.Type.BOOLEAN,
52+
"list": types.Type.ARRAY,
53+
"array": types.Type.ARRAY,
54+
"tuple": types.Type.ARRAY,
55+
"object": types.Type.OBJECT,
56+
"Dict": types.Type.OBJECT,
57+
"List": types.Type.ARRAY,
58+
"Tuple": types.Type.ARRAY,
59+
"Any": types.Type.TYPE_UNSPECIFIED,
5860
}
5961

6062

63+
def _extract_field_info_from_annotated(
64+
annotation: Any,
65+
) -> Optional[FieldInfo]:
66+
"""Extract pydantic FieldInfo from Annotated[T, Field(...)] if present.
67+
68+
Args:
69+
annotation: The type annotation to inspect.
70+
71+
Returns:
72+
The FieldInfo instance if found in Annotated metadata, None otherwise.
73+
"""
74+
if get_origin(annotation) is Annotated:
75+
for metadata in get_args(annotation)[1:]:
76+
if isinstance(metadata, FieldInfo):
77+
return metadata
78+
return None
79+
80+
81+
def _extract_base_type_from_annotated(annotation: Any) -> Any:
82+
"""Extract the base type from Annotated[T, ...].
83+
84+
Args:
85+
annotation: The type annotation to unwrap.
86+
87+
Returns:
88+
The base type T if annotation is Annotated[T, ...], otherwise the original
89+
annotation.
90+
"""
91+
if get_origin(annotation) is Annotated:
92+
return get_args(annotation)[0]
93+
return annotation
94+
95+
6196
def _get_fields_dict(func: Callable) -> Dict:
97+
"""Build a dictionary of field definitions for Pydantic model creation.
98+
99+
This function extracts parameter information from a callable and creates
100+
field definitions compatible with Pydantic's create_model. It supports
101+
parameter descriptions via Annotated[T, Field(description=...)] syntax.
102+
103+
Args:
104+
func: The callable to extract parameters from.
105+
106+
Returns:
107+
A dictionary mapping parameter names to (type, FieldInfo) tuples.
108+
"""
62109
param_signature = dict(inspect.signature(func).parameters)
63-
fields_dict = {
64-
name: (
65-
# 1. We infer the argument type here: use Any rather than None so
66-
# it will not try to auto-infer the type based on the default value.
67-
(
68-
param.annotation
69-
if param.annotation != inspect.Parameter.empty
70-
else Any
71-
),
72-
pydantic.Field(
73-
# 2. We do not support default values for now.
74-
default=(
75-
param.default
76-
if param.default != inspect.Parameter.empty
77-
# ! Need to use Undefined instead of None
78-
else pydantic_fields.PydanticUndefined
79-
),
80-
# 3. Do not support parameter description for now.
81-
description=None,
82-
),
83-
)
84-
for name, param in param_signature.items()
85-
# We do not support *args or **kwargs
86-
if param.kind
87-
in (
88-
inspect.Parameter.POSITIONAL_OR_KEYWORD,
89-
inspect.Parameter.KEYWORD_ONLY,
90-
inspect.Parameter.POSITIONAL_ONLY,
91-
)
92-
}
110+
fields_dict = {}
111+
112+
for name, param in param_signature.items():
113+
# We do not support *args or **kwargs
114+
if param.kind not in (
115+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
116+
inspect.Parameter.KEYWORD_ONLY,
117+
inspect.Parameter.POSITIONAL_ONLY,
118+
):
119+
continue
120+
121+
annotation = (
122+
param.annotation if param.annotation != inspect.Parameter.empty else Any
123+
)
124+
125+
# Extract FieldInfo from Annotated[T, Field(...)] if present
126+
field_info = _extract_field_info_from_annotated(annotation)
127+
128+
# Extract the base type from Annotated[T, ...] for the model field
129+
base_type = _extract_base_type_from_annotated(annotation)
130+
131+
# Determine the default value
132+
default = (
133+
param.default
134+
if param.default != inspect.Parameter.empty
135+
else pydantic_fields.PydanticUndefined
136+
)
137+
138+
# Get description from FieldInfo if available
139+
description = field_info.description if field_info else None
140+
141+
fields_dict[name] = (
142+
base_type,
143+
pydantic.Field(
144+
default=default,
145+
description=description,
146+
),
147+
)
148+
93149
return fields_dict
94150

95151

96152
def _annotate_nullable_fields(schema: Dict):
97-
for _, property_schema in schema.get('properties', {}).items():
153+
for _, property_schema in schema.get("properties", {}).items():
98154
# for Optional[T], the pydantic schema is:
99155
# {
100156
# "type": "object",
@@ -109,53 +165,53 @@ def _annotate_nullable_fields(schema: Dict):
109165
# ]
110166
# }
111167
# }
112-
for type_ in property_schema.get('anyOf', []):
113-
if type_.get('type') == 'null':
114-
property_schema['nullable'] = True
115-
property_schema['anyOf'].remove(type_)
168+
for type_ in property_schema.get("anyOf", []):
169+
if type_.get("type") == "null":
170+
property_schema["nullable"] = True
171+
property_schema["anyOf"].remove(type_)
116172
break
117173

118174

119175
def _annotate_required_fields(schema: Dict):
120176
required = [
121177
field_name
122-
for field_name, field_schema in schema.get('properties', {}).items()
123-
if not field_schema.get('nullable') and 'default' not in field_schema
178+
for field_name, field_schema in schema.get("properties", {}).items()
179+
if not field_schema.get("nullable") and "default" not in field_schema
124180
]
125-
schema['required'] = required
181+
schema["required"] = required
126182

127183

128184
def _remove_any_of(schema: Dict):
129-
for _, property_schema in schema.get('properties', {}).items():
130-
union_types = property_schema.pop('anyOf', None)
185+
for _, property_schema in schema.get("properties", {}).items():
186+
union_types = property_schema.pop("anyOf", None)
131187
# Take the first non-null type.
132188
if union_types:
133189
for type_ in union_types:
134-
if type_.get('type') != 'null':
190+
if type_.get("type") != "null":
135191
property_schema.update(type_)
136192

137193

138194
def _remove_default(schema: Dict):
139-
for _, property_schema in schema.get('properties', {}).items():
140-
property_schema.pop('default', None)
195+
for _, property_schema in schema.get("properties", {}).items():
196+
property_schema.pop("default", None)
141197

142198

143199
def _remove_nullable(schema: Dict):
144-
for _, property_schema in schema.get('properties', {}).items():
145-
property_schema.pop('nullable', None)
200+
for _, property_schema in schema.get("properties", {}).items():
201+
property_schema.pop("nullable", None)
146202

147203

148204
def _remove_title(schema: Dict):
149-
for _, property_schema in schema.get('properties', {}).items():
150-
property_schema.pop('title', None)
205+
for _, property_schema in schema.get("properties", {}).items():
206+
property_schema.pop("title", None)
151207

152208

153209
def _get_pydantic_schema(func: Callable) -> Dict:
154210
from ..utils.context_utils import find_context_parameter
155211

156212
fields_dict = _get_fields_dict(func)
157213
# Remove context parameter (detected by type or fallback to 'tool_context' name)
158-
context_param = find_context_parameter(func) or 'tool_context'
214+
context_param = find_context_parameter(func) or "tool_context"
159215
if context_param in fields_dict.keys():
160216
fields_dict.pop(context_param)
161217
return pydantic.create_model(func.__name__, **fields_dict).model_json_schema()
@@ -173,24 +229,24 @@ def _process_pydantic_schema(vertexai: bool, schema: Dict) -> Dict:
173229

174230

175231
def _map_pydantic_type_to_property_schema(property_schema: Dict):
176-
if 'type' in property_schema:
177-
property_schema['type'] = _py_type_2_schema_type.get(
178-
property_schema['type'], 'TYPE_UNSPECIFIED'
232+
if "type" in property_schema:
233+
property_schema["type"] = _py_type_2_schema_type.get(
234+
property_schema["type"], "TYPE_UNSPECIFIED"
179235
)
180-
if property_schema['type'] == 'ARRAY':
181-
_map_pydantic_type_to_property_schema(property_schema['items'])
182-
for type_ in property_schema.get('anyOf', []):
183-
if 'type' in type_:
184-
type_['type'] = _py_type_2_schema_type.get(
185-
type_['type'], 'TYPE_UNSPECIFIED'
236+
if property_schema["type"] == "ARRAY":
237+
_map_pydantic_type_to_property_schema(property_schema["items"])
238+
for type_ in property_schema.get("anyOf", []):
239+
if "type" in type_:
240+
type_["type"] = _py_type_2_schema_type.get(
241+
type_["type"], "TYPE_UNSPECIFIED"
186242
)
187243
# TODO: To investigate. Unclear why a Type is needed with 'anyOf' to
188244
# avoid google.genai.errors.ClientError: 400 INVALID_ARGUMENT.
189-
property_schema['type'] = type_['type']
245+
property_schema["type"] = type_["type"]
190246

191247

192248
def _map_pydantic_type_to_schema_type(schema: Dict):
193-
for _, property_schema in schema.get('properties', {}).items():
249+
for _, property_schema in schema.get("properties", {}).items():
194250
_map_pydantic_type_to_property_schema(property_schema)
195251

196252

@@ -266,13 +322,13 @@ def build_function_declaration_for_langchain(
266322
vertexai: bool, name, description, func, param_pydantic_schema
267323
) -> types.FunctionDeclaration:
268324
param_pydantic_schema = _process_pydantic_schema(
269-
vertexai, {'properties': param_pydantic_schema}
270-
)['properties']
325+
vertexai, {"properties": param_pydantic_schema}
326+
)["properties"]
271327
param_copy = param_pydantic_schema.copy()
272-
required_fields = param_copy.pop('required', [])
328+
required_fields = param_copy.pop("required", [])
273329
before_param_pydantic_schema = {
274-
'properties': param_copy,
275-
'required': required_fields,
330+
"properties": param_copy,
331+
"required": required_fields,
276332
}
277333
return build_function_declaration_util(
278334
vertexai, name, description, func, before_param_pydantic_schema
@@ -295,10 +351,10 @@ def build_function_declaration_util(
295351
vertexai: bool, name, description, func, before_param_pydantic_schema
296352
) -> types.FunctionDeclaration:
297353
_map_pydantic_type_to_schema_type(before_param_pydantic_schema)
298-
properties = before_param_pydantic_schema.get('properties', {})
354+
properties = before_param_pydantic_schema.get("properties", {})
299355
function_declaration = types.FunctionDeclaration(
300356
parameters=types.Schema(
301-
type='OBJECT',
357+
type="OBJECT",
302358
properties=properties,
303359
)
304360
if properties
@@ -317,7 +373,7 @@ def build_function_declaration_util(
317373
def from_function_with_options(
318374
func: Callable,
319375
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
320-
) -> 'types.FunctionDeclaration':
376+
) -> "types.FunctionDeclaration":
321377

322378
parameters_properties = {}
323379
parameters_json_schema = {}
@@ -379,7 +435,7 @@ def from_function_with_options(
379435
)
380436
if parameters_properties:
381437
declaration.parameters = types.Schema(
382-
type='OBJECT',
438+
type="OBJECT",
383439
properties=parameters_properties,
384440
)
385441
declaration.parameters.required = (
@@ -389,7 +445,7 @@ def from_function_with_options(
389445
)
390446
elif parameters_json_schema:
391447
declaration.parameters = types.Schema(
392-
type='OBJECT',
448+
type="OBJECT",
393449
properties=parameters_json_schema,
394450
)
395451

@@ -416,7 +472,7 @@ def from_function_with_options(
416472
if return_annotation is inspect._empty:
417473
# Functions with no return annotation can return any type
418474
return_value = inspect.Parameter(
419-
'return_value',
475+
"return_value",
420476
inspect.Parameter.POSITIONAL_OR_KEYWORD,
421477
annotation=typing.Any,
422478
)
@@ -433,11 +489,11 @@ def from_function_with_options(
433489
if (
434490
return_annotation is None
435491
or return_annotation is type(None)
436-
or (isinstance(return_annotation, str) and return_annotation == 'None')
492+
or (isinstance(return_annotation, str) and return_annotation == "None")
437493
):
438494
# Create a response schema for None/null return
439495
return_value = inspect.Parameter(
440-
'return_value',
496+
"return_value",
441497
inspect.Parameter.POSITIONAL_OR_KEYWORD,
442498
annotation=None,
443499
)
@@ -451,13 +507,13 @@ def from_function_with_options(
451507
return declaration
452508

453509
return_value = inspect.Parameter(
454-
'return_value',
510+
"return_value",
455511
inspect.Parameter.POSITIONAL_OR_KEYWORD,
456512
annotation=return_annotation,
457513
)
458514
if isinstance(return_value.annotation, str):
459515
return_value = return_value.replace(
460-
annotation=typing.get_type_hints(func)['return']
516+
annotation=typing.get_type_hints(func)["return"]
461517
)
462518

463519
response_schema: Optional[types.Schema] = None

0 commit comments

Comments
 (0)