From 262a32b349d2c04fec8dab94b6fce5a178b2bf80 Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Tue, 26 Aug 2025 14:44:04 +0800 Subject: [PATCH] feat: support extra headers for model --- tests/test_agent.py | 7 +++++++ veadk/agent.py | 26 +++++++++++++++++++++----- veadk/consts.py | 3 +++ veadk/tracing/telemetry/telemetry.py | 2 +- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index 2843d03e..8ad947a7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -15,6 +15,7 @@ from google.adk.tools import load_memory from veadk import Agent +from veadk.consts import DEFAULT_MODEL_EXTRA_HEADERS from veadk.knowledgebase import KnowledgeBase from veadk.memory.long_term_memory import LongTermMemory from veadk.tools import load_knowledgebase_tool @@ -26,11 +27,14 @@ def test_agent(): long_term_memory = LongTermMemory(backend="local") tracer = OpentelemetryTracer() + model_extra_headers = {"test-header": "test-value"} + agent = Agent( model_name="test_model_name", model_provider="test_model_provider", model_api_key="test_model_api_key", model_api_base="test_model_api_base", + model_extra_headers=model_extra_headers, tools=[], sub_agents=[], knowledgebase=knowledgebase, @@ -39,7 +43,10 @@ def test_agent(): serve_url="", ) + model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS + assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" + assert agent.model_extra_headers == model_extra_headers assert agent.knowledgebase == knowledgebase assert agent.knowledgebase.backend == "local" diff --git a/veadk/agent.py b/veadk/agent.py index 75323dce..7d947252 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -31,6 +31,7 @@ DEFALUT_MODEL_AGENT_PROVIDER, DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_MODEL_AGENT_NAME, + DEFAULT_MODEL_EXTRA_HEADERS, ) from veadk.evaluation import EvalSetRecorder from veadk.knowledgebase import KnowledgeBase @@ -73,6 +74,9 @@ class Agent(LlmAgent): model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY")) """The api key of the model for agent running.""" + model_extra_headers: dict = Field(default_factory=dict) + """The extra headers to include in the model requests.""" + tools: list[ToolUnion] = [] """The tools provided to agent.""" @@ -96,11 +100,23 @@ class Agent(LlmAgent): def model_post_init(self, __context: Any) -> None: super().model_post_init(None) # for sub_agents init - self.model = LiteLlm( - model=f"{self.model_provider}/{self.model_name}", - api_key=self.model_api_key, - api_base=self.model_api_base, - ) + + self.model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS + + if not self.model: + self.model = LiteLlm( + model=f"{self.model_provider}/{self.model_name}", + api_key=self.model_api_key, + api_base=self.model_api_base, + extra_headers=self.model_extra_headers, + ) + logger.debug( + f"LiteLLM client created with extra headers: {self.model_extra_headers}" + ) + else: + logger.warning( + "You are trying to use your own LiteLLM client, some default request headers may be missing." + ) if self.knowledgebase: from veadk.tools import load_knowledgebase_tool diff --git a/veadk/consts.py b/veadk/consts.py index 79134219..b1ceb88d 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from veadk.version import VERSION + DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615" DEFALUT_MODEL_AGENT_PROVIDER = "openai" DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" +DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION} diff --git a/veadk/tracing/telemetry/telemetry.py b/veadk/tracing/telemetry/telemetry.py index 59dbf324..fe0218d5 100644 --- a/veadk/tracing/telemetry/telemetry.py +++ b/veadk/tracing/telemetry/telemetry.py @@ -64,7 +64,7 @@ def set_common_attributes( ) return - if isinstance(invocation_context.agent, Agent): + if isinstance(invocation_context.agent, Agent) and invocation_context.agent.tracers: try: from veadk.tracing.telemetry.opentelemetry_tracer import OpentelemetryTracer