From 3a14c6b4013abc16dc0ac0b32c0b1486ce904d74 Mon Sep 17 00:00:00 2001 From: "fangyaozheng@bytedance.com" Date: Mon, 8 Sep 2025 09:33:34 +0800 Subject: [PATCH] refine(tests): refine tests for agent --- tests/test_agent.py | 174 +++++++++++++++++++++++++++++++++++++++++++- veadk/agent.py | 19 ++++- veadk/consts.py | 2 + 3 files changed, 187 insertions(+), 8 deletions(-) diff --git a/tests/test_agent.py b/tests/test_agent.py index 2ecf4522..02bb8b6a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -12,10 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import Mock, patch + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.models.lite_llm import LiteLlm from google.adk.tools import load_memory from veadk import Agent -from veadk.consts import DEFAULT_MODEL_EXTRA_CONFIG +from veadk.consts import ( + DEFAULT_AGENT_NAME, + DEFAULT_MODEL_AGENT_API_BASE, + DEFAULT_MODEL_AGENT_NAME, + DEFAULT_MODEL_AGENT_PROVIDER, + DEFAULT_MODEL_EXTRA_CONFIG, +) from veadk.knowledgebase import KnowledgeBase from veadk.memory.long_term_memory import LongTermMemory from veadk.tools import load_knowledgebase_tool @@ -46,7 +56,7 @@ def test_agent(): serve_url="", ) - assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" + assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" # type: ignore expected_config = DEFAULT_MODEL_EXTRA_CONFIG.copy() expected_config["extra_headers"] |= extra_config["extra_headers"] @@ -55,9 +65,165 @@ def test_agent(): assert agent.model_extra_config == expected_config assert agent.knowledgebase == knowledgebase - assert agent.knowledgebase.backend == "local" + assert agent.knowledgebase.backend == "local" # type: ignore assert load_knowledgebase_tool.knowledgebase == agent.knowledgebase assert load_knowledgebase_tool.load_knowledgebase_tool in agent.tools - assert agent.long_term_memory.backend == "local" + assert agent.long_term_memory.backend == "local" # type: ignore assert load_memory in agent.tools + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_default_values(): + agent = Agent() + + assert agent.name == DEFAULT_AGENT_NAME + + assert agent.model_name == DEFAULT_MODEL_AGENT_NAME + assert agent.model_provider == DEFAULT_MODEL_AGENT_PROVIDER + assert agent.model_api_base == DEFAULT_MODEL_AGENT_API_BASE + + assert agent.tools == [] + assert agent.sub_agents == [] + assert agent.knowledgebase is None + assert agent.long_term_memory is None + assert agent.tracers == [] + + assert agent.serve_url == "" + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_without_knowledgebase(): + agent = Agent() + + assert agent.knowledgebase is None + assert load_knowledgebase_tool.load_knowledgebase_tool not in agent.tools + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_without_long_term_memory(): + agent = Agent() + + assert agent.long_term_memory is None + assert load_memory not in agent.tools + + +@patch("veadk.agent.LiteLlm") +def test_agent_model_creation(mock_lite_llm): + mock_model = Mock() + mock_lite_llm.return_value = mock_model + + agent = Agent( + model_name="test_model", + model_provider="test_provider", + model_api_key="test_key", + model_api_base="test_base", + ) + + mock_lite_llm.assert_called_once() + assert agent.model == mock_model + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_with_existing_model(): + existing_model = LiteLlm(model="test_model") + agent = Agent(model=existing_model) + + assert agent.model == existing_model + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_model_extra_config_merge(): + user_config = { + "extra_headers": {"custom": "header"}, + "extra_body": {"custom": "body"}, + "other_param": "value", + } + + agent = Agent(model_extra_config=user_config) + + expected_headers = DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"].copy() + expected_headers["custom"] = "header" + + expected_body = DEFAULT_MODEL_EXTRA_CONFIG["extra_body"].copy() + expected_body["custom"] = "body" + + assert agent.model_extra_config["extra_headers"] == expected_headers + assert agent.model_extra_config["extra_body"] == expected_body + assert agent.model_extra_config["other_param"] == "value" + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_empty_model_extra_config(): + agent = Agent(model_extra_config={}) + + assert ( + agent.model_extra_config["extra_headers"] + == DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"] + ) + assert ( + agent.model_extra_config["extra_body"] + == DEFAULT_MODEL_EXTRA_CONFIG["extra_body"] + ) + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_with_tools(): + mock_tool = Mock() + agent = Agent(tools=[mock_tool]) + + assert mock_tool in agent.tools + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_with_sub_agents(): + adk_agent = LlmAgent(name="agent") + veadk_agent = Agent(name="agent") + agent = Agent(sub_agents=[adk_agent, veadk_agent]) + + assert adk_agent in agent.sub_agents + assert veadk_agent in agent.sub_agents + assert adk_agent.parent_agent == agent + assert veadk_agent.parent_agent == agent + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_with_tracers(): + tracer1 = OpentelemetryTracer() + tracer2 = OpentelemetryTracer() + + agent = Agent(tracers=[tracer1, tracer2]) + + assert len(agent.tracers) == 2 + assert tracer1 in agent.tracers + assert tracer2 in agent.tracers + + +@patch.dict( + "os.environ", + {"MODEL_AGENT_NAME": "env_model_name", "MODEL_AGENT_API_KEY": "mock_api_key"}, + clear=True, +) +def test_agent_environment_variables(): + agent = Agent() + print(agent) + assert agent.model_name == "env_model_name" + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_custom_name_and_description(): + custom_name = "CustomAgent" + custom_description = "A custom agent for testing" + + agent = Agent(name=custom_name, description=custom_description) + + assert agent.name == custom_name + assert agent.description == custom_description + + +@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"}) +def test_agent_serve_url(): + serve_url = "http://localhost:8080" + agent = Agent(serve_url=serve_url) + + assert agent.serve_url == serve_url diff --git a/veadk/agent.py b/veadk/agent.py index 248b8559..fc0e6c63 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -28,6 +28,7 @@ from veadk.config import getenv from veadk.consts import ( + DEFAULT_AGENT_NAME, DEFAULT_MODEL_AGENT_API_BASE, DEFAULT_MODEL_AGENT_NAME, DEFAULT_MODEL_AGENT_PROVIDER, @@ -53,7 +54,7 @@ class Agent(LlmAgent): model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") """The model config""" - name: str = "veAgent" + name: str = DEFAULT_AGENT_NAME """The name of the agent.""" description: str = DEFAULT_DESCRIPTION @@ -62,13 +63,23 @@ class Agent(LlmAgent): instruction: str = DEFAULT_INSTRUCTION """The instruction for the agent, such as principles of function calling.""" - model_name: str = getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME) + model_name: str = Field( + default_factory=lambda: getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME) + ) """The name of the model for agent running.""" - model_provider: str = getenv("MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER) + model_provider: str = Field( + default_factory=lambda: getenv( + "MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER + ) + ) """The provider of the model for agent running.""" - model_api_base: str = getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE) + model_api_base: str = Field( + default_factory=lambda: getenv( + "MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE + ) + ) """The api base of the model for agent running.""" model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY")) diff --git a/veadk/consts.py b/veadk/consts.py index 9320d7ad..b7be3ec2 100644 --- a/veadk/consts.py +++ b/veadk/consts.py @@ -17,6 +17,8 @@ from veadk.config import getenv from veadk.version import VERSION +DEFAULT_AGENT_NAME = "veAgent" + DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615" DEFAULT_MODEL_AGENT_PROVIDER = "openai" DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"