diff --git a/config.yaml.full b/config.yaml.full index aaf601d0..6c0cef57 100644 --- a/config.yaml.full +++ b/config.yaml.full @@ -5,6 +5,9 @@ model: name: doubao-seed-1-6-250615 api_base: https://ark.cn-beijing.volces.com/api/v3/ api_key: + encrypted: true # true | false + caching: enabled # enabled | disabled + max_llm_calls: 100 # [optional] for llm-as-a-judge a evaluation judge: name: doubao-seed-1-6-250615 diff --git a/config.yaml.simple b/config.yaml.simple index 9476d107..7c31c4dc 100644 --- a/config.yaml.simple +++ b/config.yaml.simple @@ -3,4 +3,7 @@ model: provider: openai name: doubao-seed-1-6-250615 api_base: https://ark.cn-beijing.volces.com/api/v3/ - api_key: \ No newline at end of file + api_key: + encrypted: true # true | false + caching: enabled # enabled | disabled + max_llm_calls: 100 \ No newline at end of file diff --git a/docs/docs/agent.md b/docs/docs/agent.md index 056e7405..88f2dd62 100644 --- a/docs/docs/agent.md +++ b/docs/docs/agent.md @@ -15,6 +15,7 @@ Agent 中主要包括如下属性: | model_provider | str | Agent 中内置模型的提供商,默认从环境变量中获取 | | model_api_base | str | Agent 中内置模型的 API Base,默认从环境变量中获取 | | model_api_key | str | Agent 中内置模型的 API Key,默认从环境变量中获取 | +| model_extra_config | dict | Agent 进行模型请求时的额外参数,Key 值为属性名,Value 值为属性值 | | tools | list | Function call 中的工具列表,既可以是本地工具,也可以是 MCP 工具 | | sub_agents | list | 子 Agent 列表,用于多 Agent 之间交互 | | knowledgebase | KnowledgeBase | 知识库,后端支持本地内存(local)和数据库(opensearch、viking、redis、mysql),通常设置为一个能够检索的向量数据库 | diff --git a/tests/test_agent.py b/tests/test_agent.py index 8ad947a7..2ecf4522 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -15,7 +15,7 @@ from google.adk.tools import load_memory from veadk import Agent -from veadk.consts import DEFAULT_MODEL_EXTRA_HEADERS +from veadk.consts import DEFAULT_MODEL_EXTRA_CONFIG from veadk.knowledgebase import KnowledgeBase from veadk.memory.long_term_memory import LongTermMemory from veadk.tools import load_knowledgebase_tool @@ -27,14 +27,17 @@ def test_agent(): long_term_memory = LongTermMemory(backend="local") tracer = OpentelemetryTracer() - model_extra_headers = {"test-header": "test-value"} + extra_config = { + "extra_headers": {"thinking": "test"}, + "extra_body": {"content": "test"}, + } agent = Agent( model_name="test_model_name", model_provider="test_model_provider", model_api_key="test_model_api_key", model_api_base="test_model_api_base", - model_extra_headers=model_extra_headers, + model_extra_config=extra_config, tools=[], sub_agents=[], knowledgebase=knowledgebase, @@ -43,10 +46,13 @@ def test_agent(): serve_url="", ) - model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS - assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" - assert agent.model_extra_headers == model_extra_headers + + expected_config = DEFAULT_MODEL_EXTRA_CONFIG.copy() + expected_config["extra_headers"] |= extra_config["extra_headers"] + expected_config["extra_body"] |= extra_config["extra_body"] + + assert agent.model_extra_config == expected_config assert agent.knowledgebase == knowledgebase assert agent.knowledgebase.backend == "local" diff --git a/veadk/agent.py b/veadk/agent.py index 99c50ec0..248b8559 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -28,10 +28,10 @@ from veadk.config import getenv from veadk.consts import ( - DEFAULT_MODEL_AGENT_PROVIDER, DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_MODEL_AGENT_NAME, - DEFAULT_MODEL_EXTRA_HEADERS, + DEFAULT_MODEL_AGENT_PROVIDER, + DEFAULT_MODEL_EXTRA_CONFIG, ) from veadk.evaluation import EvalSetRecorder from veadk.knowledgebase import KnowledgeBase @@ -101,11 +101,23 @@ class Agent(LlmAgent): def model_post_init(self, __context: Any) -> None: super().model_post_init(None) # for sub_agents init - # add model request source (veadk) in extra headers - if self.model_extra_config and "extra_headers" in self.model_extra_config: - self.model_extra_config["extra_headers"] |= DEFAULT_MODEL_EXTRA_HEADERS - else: - self.model_extra_config["extra_headers"] = DEFAULT_MODEL_EXTRA_HEADERS + # combine user model config with VeADK defaults + headers = DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"].copy() + body = DEFAULT_MODEL_EXTRA_CONFIG["extra_body"].copy() + + if self.model_extra_config: + user_headers = self.model_extra_config.get("extra_headers", {}) + user_body = self.model_extra_config.get("extra_body", {}) + + headers |= user_headers + body |= user_body + + self.model_extra_config |= { + "extra_headers": headers, + "extra_body": body, + } + + logger.info(f"Model extra config: {self.model_extra_config}") if not self.model: self.model = LiteLlm( diff --git a/veadk/consts.py b/veadk/consts.py index 5c5ac71a..dd877690 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -12,12 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time + +from veadk.config import getenv from veadk.version import VERSION DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615" DEFAULT_MODEL_AGENT_PROVIDER = "openai" DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/" -DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION} +DEFAULT_MODEL_EXTRA_CONFIG = { + "extra_headers": { + "x-is-encrypted": getenv("MODEL_AGENT_ENCRYPTED", "true"), + "veadk-source": "veadk", + "veadk-version": VERSION, + }, + "extra_body": { + "caching": { + "type": getenv("MODEL_AGENT_CACHING", "enabled"), + }, + "expire_at": int(time.time()) + 3600, # expire after 1 hour + }, +} DEFAULT_APMPLUS_OTEL_EXPORTER_ENDPOINT = "http://apmplus-cn-beijing.volces.com:4317" DEFAULT_APMPLUS_OTEL_EXPORTER_SERVICE_NAME = "veadk_tracing" diff --git a/veadk/runner.py b/veadk/runner.py index 54d8442a..ee67aeb8 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -27,12 +27,13 @@ from veadk.agents.loop_agent import LoopAgent from veadk.agents.parallel_agent import ParallelAgent from veadk.agents.sequential_agent import SequentialAgent +from veadk.config import getenv from veadk.evaluation import EvalSetRecorder +from veadk.integrations.ve_tos.ve_tos import VeTOS from veadk.memory.short_term_memory import ShortTermMemory from veadk.types import MediaMessage from veadk.utils.logger import get_logger from veadk.utils.misc import read_png_to_bytes -from veadk.integrations.ve_tos.ve_tos import VeTOS logger = get_logger(__name__) @@ -142,7 +143,13 @@ async def _run( if run_config is not None: stream_mode = run_config.streaming_mode else: - run_config = RunConfig(streaming_mode=stream_mode) + run_config = RunConfig( + streaming_mode=stream_mode, + max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100)), + ) + + logger.info(f"Run config: {run_config}") + try: async def event_generator(): @@ -231,7 +238,13 @@ async def run_with_raw_message( session_id: str, run_config: RunConfig | None = None, ): - run_config = RunConfig() if not run_config else run_config + run_config = ( + RunConfig(max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100))) + if not run_config + else run_config + ) + + logger.info(f"Run config: {run_config}") await self.short_term_memory.create_session( app_name=self.app_name, user_id=self.user_id, session_id=session_id