Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ class Agent(LlmAgent):
model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
"""The api key of the model for agent running."""

model_extra_headers: dict = Field(default_factory=dict)
"""The extra headers to include in the model requests."""
model_extra_config: dict = Field(default_factory=dict)
"""The extra config to include in the model requests."""

tools: list[ToolUnion] = []
"""The tools provided to agent."""
Expand All @@ -101,17 +101,21 @@ class Agent(LlmAgent):
def model_post_init(self, __context: Any) -> None:
super().model_post_init(None) # for sub_agents init

self.model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS
# add model request source (veadk) in extra headers
if self.model_extra_config and "extra_headers" in self.model_extra_config:
self.model_extra_config["extra_headers"] |= DEFAULT_MODEL_EXTRA_HEADERS
else:
self.model_extra_config["extra_headers"] = DEFAULT_MODEL_EXTRA_HEADERS

if not self.model:
self.model = LiteLlm(
model=f"{self.model_provider}/{self.model_name}",
api_key=self.model_api_key,
api_base=self.model_api_base,
extra_headers=self.model_extra_headers,
**self.model_extra_config,
)
logger.debug(
f"LiteLLM client created with extra headers: {self.model_extra_headers}"
f"LiteLLM client created with config: {self.model_extra_config}"
)
else:
logger.warning(
Expand All @@ -133,7 +137,7 @@ def model_post_init(self, __context: Any) -> None:

logger.info(f"{self.__class__.__name__} `{self.name}` init done.")
logger.debug(
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools', 'serve_url'})}"
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
)

async def _run(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def tool_gen_ai_operation_name(params: ToolAttributesParams) -> ExtractorRespons

def tool_gen_ai_tool_message(params: ToolAttributesParams) -> ExtractorResponse:
tool_input = {
"id": "123",
"role": "tool",
"content": json.dumps(
{
Expand Down