Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config.yaml.full
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ model:
name: doubao-seed-1-6-250615
api_base: https://ark.cn-beijing.volces.com/api/v3/
api_key:
encrypted: true # true | false
caching: enabled # enabled | disabled
max_llm_calls: 100
# [optional] for llm-as-a-judge a evaluation
judge:
name: doubao-seed-1-6-250615
Expand Down
5 changes: 4 additions & 1 deletion config.yaml.simple
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ model:
provider: openai
name: doubao-seed-1-6-250615
api_base: https://ark.cn-beijing.volces.com/api/v3/
api_key:
api_key:
encrypted: true # true | false
caching: enabled # enabled | disabled
max_llm_calls: 100
1 change: 1 addition & 0 deletions docs/docs/agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Agent 中主要包括如下属性:
| model_provider | str | Agent 中内置模型的提供商,默认从环境变量中获取 |
| model_api_base | str | Agent 中内置模型的 API Base,默认从环境变量中获取 |
| model_api_key | str | Agent 中内置模型的 API Key,默认从环境变量中获取 |
| model_extra_config | dict | Agent 进行模型请求时的额外参数,Key 值为属性名,Value 值为属性值 |
| tools | list | Function call 中的工具列表,既可以是本地工具,也可以是 MCP 工具 |
| sub_agents | list | 子 Agent 列表,用于多 Agent 之间交互 |
| knowledgebase | KnowledgeBase | 知识库,后端支持本地内存(local)和数据库(opensearch、viking、redis、mysql),通常设置为一个能够检索的向量数据库 |
Expand Down
18 changes: 12 additions & 6 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from google.adk.tools import load_memory

from veadk import Agent
from veadk.consts import DEFAULT_MODEL_EXTRA_HEADERS
from veadk.consts import DEFAULT_MODEL_EXTRA_CONFIG
from veadk.knowledgebase import KnowledgeBase
from veadk.memory.long_term_memory import LongTermMemory
from veadk.tools import load_knowledgebase_tool
Expand All @@ -27,14 +27,17 @@ def test_agent():
long_term_memory = LongTermMemory(backend="local")
tracer = OpentelemetryTracer()

model_extra_headers = {"test-header": "test-value"}
extra_config = {
"extra_headers": {"thinking": "test"},
"extra_body": {"content": "test"},
}

agent = Agent(
model_name="test_model_name",
model_provider="test_model_provider",
model_api_key="test_model_api_key",
model_api_base="test_model_api_base",
model_extra_headers=model_extra_headers,
model_extra_config=extra_config,
tools=[],
sub_agents=[],
knowledgebase=knowledgebase,
Expand All @@ -43,10 +46,13 @@ def test_agent():
serve_url="",
)

model_extra_headers |= DEFAULT_MODEL_EXTRA_HEADERS

assert agent.model.model == f"{agent.model_provider}/{agent.model_name}"
assert agent.model_extra_headers == model_extra_headers

expected_config = DEFAULT_MODEL_EXTRA_CONFIG.copy()
expected_config["extra_headers"] |= extra_config["extra_headers"]
expected_config["extra_body"] |= extra_config["extra_body"]

assert agent.model_extra_config == expected_config

assert agent.knowledgebase == knowledgebase
assert agent.knowledgebase.backend == "local"
Expand Down
26 changes: 19 additions & 7 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@

from veadk.config import getenv
from veadk.consts import (
DEFAULT_MODEL_AGENT_PROVIDER,
DEFAULT_MODEL_AGENT_API_BASE,
DEFAULT_MODEL_AGENT_NAME,
DEFAULT_MODEL_EXTRA_HEADERS,
DEFAULT_MODEL_AGENT_PROVIDER,
DEFAULT_MODEL_EXTRA_CONFIG,
)
from veadk.evaluation import EvalSetRecorder
from veadk.knowledgebase import KnowledgeBase
Expand Down Expand Up @@ -101,11 +101,23 @@ class Agent(LlmAgent):
def model_post_init(self, __context: Any) -> None:
super().model_post_init(None) # for sub_agents init

# add model request source (veadk) in extra headers
if self.model_extra_config and "extra_headers" in self.model_extra_config:
self.model_extra_config["extra_headers"] |= DEFAULT_MODEL_EXTRA_HEADERS
else:
self.model_extra_config["extra_headers"] = DEFAULT_MODEL_EXTRA_HEADERS
# combine user model config with VeADK defaults
headers = DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"].copy()
body = DEFAULT_MODEL_EXTRA_CONFIG["extra_body"].copy()

if self.model_extra_config:
user_headers = self.model_extra_config.get("extra_headers", {})
user_body = self.model_extra_config.get("extra_body", {})

headers |= user_headers
body |= user_body

self.model_extra_config |= {
"extra_headers": headers,
"extra_body": body,
}

logger.info(f"Model extra config: {self.model_extra_config}")

if not self.model:
self.model = LiteLlm(
Expand Down
17 changes: 16 additions & 1 deletion veadk/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from veadk.config import getenv
from veadk.version import VERSION

DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
DEFAULT_MODEL_AGENT_PROVIDER = "openai"
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
DEFAULT_MODEL_EXTRA_HEADERS = {"veadk-source": "veadk", "veadk-version": VERSION}
DEFAULT_MODEL_EXTRA_CONFIG = {
"extra_headers": {
"x-is-encrypted": getenv("MODEL_AGENT_ENCRYPTED", "true"),
"veadk-source": "veadk",
"veadk-version": VERSION,
},
"extra_body": {
"caching": {
"type": getenv("MODEL_AGENT_CACHING", "enabled"),
},
"expire_at": int(time.time()) + 3600, # expire after 1 hour
},
}

DEFAULT_APMPLUS_OTEL_EXPORTER_ENDPOINT = "http://apmplus-cn-beijing.volces.com:4317"
DEFAULT_APMPLUS_OTEL_EXPORTER_SERVICE_NAME = "veadk_tracing"
Expand Down
19 changes: 16 additions & 3 deletions veadk/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
from veadk.agents.loop_agent import LoopAgent
from veadk.agents.parallel_agent import ParallelAgent
from veadk.agents.sequential_agent import SequentialAgent
from veadk.config import getenv
from veadk.evaluation import EvalSetRecorder
from veadk.integrations.ve_tos.ve_tos import VeTOS
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.types import MediaMessage
from veadk.utils.logger import get_logger
from veadk.utils.misc import read_png_to_bytes
from veadk.integrations.ve_tos.ve_tos import VeTOS

logger = get_logger(__name__)

Expand Down Expand Up @@ -142,7 +143,13 @@ async def _run(
if run_config is not None:
stream_mode = run_config.streaming_mode
else:
run_config = RunConfig(streaming_mode=stream_mode)
run_config = RunConfig(
streaming_mode=stream_mode,
max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100)),
)

logger.info(f"Run config: {run_config}")

try:

async def event_generator():
Expand Down Expand Up @@ -231,7 +238,13 @@ async def run_with_raw_message(
session_id: str,
run_config: RunConfig | None = None,
):
run_config = RunConfig() if not run_config else run_config
run_config = (
RunConfig(max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100)))
if not run_config
else run_config
)

logger.info(f"Run config: {run_config}")

await self.short_term_memory.create_session(
app_name=self.app_name, user_id=self.user_id, session_id=session_id
Expand Down