Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 170 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import Mock, patch

from google.adk.agents.llm_agent import LlmAgent
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools import load_memory

from veadk import Agent
from veadk.consts import DEFAULT_MODEL_EXTRA_CONFIG
from veadk.consts import (
DEFAULT_AGENT_NAME,
DEFAULT_MODEL_AGENT_API_BASE,
DEFAULT_MODEL_AGENT_NAME,
DEFAULT_MODEL_AGENT_PROVIDER,
DEFAULT_MODEL_EXTRA_CONFIG,
)
from veadk.knowledgebase import KnowledgeBase
from veadk.memory.long_term_memory import LongTermMemory
from veadk.tools import load_knowledgebase_tool
Expand Down Expand Up @@ -46,7 +56,7 @@ def test_agent():
serve_url="",
)

assert agent.model.model == f"{agent.model_provider}/{agent.model_name}"
assert agent.model.model == f"{agent.model_provider}/{agent.model_name}" # type: ignore

expected_config = DEFAULT_MODEL_EXTRA_CONFIG.copy()
expected_config["extra_headers"] |= extra_config["extra_headers"]
Expand All @@ -55,9 +65,165 @@ def test_agent():
assert agent.model_extra_config == expected_config

assert agent.knowledgebase == knowledgebase
assert agent.knowledgebase.backend == "local"
assert agent.knowledgebase.backend == "local" # type: ignore
assert load_knowledgebase_tool.knowledgebase == agent.knowledgebase
assert load_knowledgebase_tool.load_knowledgebase_tool in agent.tools

assert agent.long_term_memory.backend == "local"
assert agent.long_term_memory.backend == "local" # type: ignore
assert load_memory in agent.tools


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_default_values():
agent = Agent()

assert agent.name == DEFAULT_AGENT_NAME

assert agent.model_name == DEFAULT_MODEL_AGENT_NAME
assert agent.model_provider == DEFAULT_MODEL_AGENT_PROVIDER
assert agent.model_api_base == DEFAULT_MODEL_AGENT_API_BASE

assert agent.tools == []
assert agent.sub_agents == []
assert agent.knowledgebase is None
assert agent.long_term_memory is None
assert agent.tracers == []

assert agent.serve_url == ""


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_without_knowledgebase():
agent = Agent()

assert agent.knowledgebase is None
assert load_knowledgebase_tool.load_knowledgebase_tool not in agent.tools


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_without_long_term_memory():
agent = Agent()

assert agent.long_term_memory is None
assert load_memory not in agent.tools


@patch("veadk.agent.LiteLlm")
def test_agent_model_creation(mock_lite_llm):
mock_model = Mock()
mock_lite_llm.return_value = mock_model

agent = Agent(
model_name="test_model",
model_provider="test_provider",
model_api_key="test_key",
model_api_base="test_base",
)

mock_lite_llm.assert_called_once()
assert agent.model == mock_model


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_with_existing_model():
existing_model = LiteLlm(model="test_model")
agent = Agent(model=existing_model)

assert agent.model == existing_model


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_model_extra_config_merge():
user_config = {
"extra_headers": {"custom": "header"},
"extra_body": {"custom": "body"},
"other_param": "value",
}

agent = Agent(model_extra_config=user_config)

expected_headers = DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"].copy()
expected_headers["custom"] = "header"

expected_body = DEFAULT_MODEL_EXTRA_CONFIG["extra_body"].copy()
expected_body["custom"] = "body"

assert agent.model_extra_config["extra_headers"] == expected_headers
assert agent.model_extra_config["extra_body"] == expected_body
assert agent.model_extra_config["other_param"] == "value"


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_empty_model_extra_config():
agent = Agent(model_extra_config={})

assert (
agent.model_extra_config["extra_headers"]
== DEFAULT_MODEL_EXTRA_CONFIG["extra_headers"]
)
assert (
agent.model_extra_config["extra_body"]
== DEFAULT_MODEL_EXTRA_CONFIG["extra_body"]
)


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_with_tools():
mock_tool = Mock()
agent = Agent(tools=[mock_tool])

assert mock_tool in agent.tools


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_with_sub_agents():
adk_agent = LlmAgent(name="agent")
veadk_agent = Agent(name="agent")
agent = Agent(sub_agents=[adk_agent, veadk_agent])

assert adk_agent in agent.sub_agents
assert veadk_agent in agent.sub_agents
assert adk_agent.parent_agent == agent
assert veadk_agent.parent_agent == agent


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_with_tracers():
tracer1 = OpentelemetryTracer()
tracer2 = OpentelemetryTracer()

agent = Agent(tracers=[tracer1, tracer2])

assert len(agent.tracers) == 2
assert tracer1 in agent.tracers
assert tracer2 in agent.tracers


@patch.dict(
"os.environ",
{"MODEL_AGENT_NAME": "env_model_name", "MODEL_AGENT_API_KEY": "mock_api_key"},
clear=True,
)
def test_agent_environment_variables():
agent = Agent()
print(agent)
assert agent.model_name == "env_model_name"


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_custom_name_and_description():
custom_name = "CustomAgent"
custom_description = "A custom agent for testing"

agent = Agent(name=custom_name, description=custom_description)

assert agent.name == custom_name
assert agent.description == custom_description


@patch.dict("os.environ", {"MODEL_AGENT_API_KEY": "mock_api_key"})
def test_agent_serve_url():
serve_url = "http://localhost:8080"
agent = Agent(serve_url=serve_url)

assert agent.serve_url == serve_url
19 changes: 15 additions & 4 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from veadk.config import getenv
from veadk.consts import (
DEFAULT_AGENT_NAME,
DEFAULT_MODEL_AGENT_API_BASE,
DEFAULT_MODEL_AGENT_NAME,
DEFAULT_MODEL_AGENT_PROVIDER,
Expand All @@ -53,7 +54,7 @@ class Agent(LlmAgent):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
"""The model config"""

name: str = "veAgent"
name: str = DEFAULT_AGENT_NAME
"""The name of the agent."""

description: str = DEFAULT_DESCRIPTION
Expand All @@ -62,13 +63,23 @@ class Agent(LlmAgent):
instruction: str = DEFAULT_INSTRUCTION
"""The instruction for the agent, such as principles of function calling."""

model_name: str = getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME)
model_name: str = Field(
default_factory=lambda: getenv("MODEL_AGENT_NAME", DEFAULT_MODEL_AGENT_NAME)
)
"""The name of the model for agent running."""

model_provider: str = getenv("MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER)
model_provider: str = Field(
default_factory=lambda: getenv(
"MODEL_AGENT_PROVIDER", DEFAULT_MODEL_AGENT_PROVIDER
)
)
"""The provider of the model for agent running."""

model_api_base: str = getenv("MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE)
model_api_base: str = Field(
default_factory=lambda: getenv(
"MODEL_AGENT_API_BASE", DEFAULT_MODEL_AGENT_API_BASE
)
)
"""The api base of the model for agent running."""

model_api_key: str = Field(default_factory=lambda: getenv("MODEL_AGENT_API_KEY"))
Expand Down
2 changes: 2 additions & 0 deletions veadk/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from veadk.config import getenv
from veadk.version import VERSION

DEFAULT_AGENT_NAME = "veAgent"

DEFAULT_MODEL_AGENT_NAME = "doubao-seed-1-6-250615"
DEFAULT_MODEL_AGENT_PROVIDER = "openai"
DEFAULT_MODEL_AGENT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3/"
Expand Down