Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config.yaml.full
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ model:
embedding:
name: doubao-embedding-text-240715
dim: 2560
api_base: https://ark.cn-beijing.volces.com/api/v3/embeddings
api_base: https://ark.cn-beijing.volces.com/api/v3/
api_key:
video:
name: doubao-seedance-1-0-pro-250528
Expand Down
24 changes: 20 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,33 @@ dependencies = [
"opentelemetry-instrumentation-logging>=0.56b0",
"wrapt>=1.17.2", # For patching built-in functions
"openai<1.100", # For fix https://github.com/BerriAI/litellm/issues/13710
"volcengine-python-sdk==4.0.3", # For Volcengine API
"volcengine-python-sdk>=4.0.3", # For Volcengine API
"volcengine>=1.0.193", # For Volcengine sign
"agent-pilot-sdk>=0.0.9", # Prompt optimization by Volcengine AgentPilot/PromptPilot toolkits
"fastmcp>=2.11.3", # For running MCP
"cookiecutter>=2.6.0", # For cloud deploy # For OpenSearch database
"opensearch-py==2.8.0",
"cookiecutter>=2.6.0", # For cloud deploy
"omegaconf>=2.3.0", # For agent builder
"llama-index>=0.14.0",
"llama-index-embeddings-openai-like>=0.2.2",
"llama-index-llms-openai-like>=0.5.1",
"llama-index-vector-stores-opensearch>=0.6.1",
"psycopg2-binary>=2.9.10", # For PostgreSQL database (short term memory)
"pymysql>=1.1.1", # For MySQL database (short term memory)
"opensearch-py==2.8.0",
]

[project.scripts]
veadk = "veadk.cli.cli:veadk"

[project.optional-dependencies]
extensions = [
"redis>=5.0", # For Redis database
"tos>=2.8.4", # For TOS storage and Viking DB
"llama-index-vector-stores-redis>=0.6.1",
"mcp-server-vikingdb-memory",
]
database = [
"redis>=6.2.0", # For Redis database
"redis>=5.0", # For Redis database
"pymysql>=1.1.1", # For MySQL database
"volcengine>=1.0.193", # For Viking DB
"tos>=2.8.4", # For TOS storage and Viking DB
Expand Down Expand Up @@ -78,3 +91,6 @@ exclude = [
"veadk/integrations/ve_faas/template/*",
"veadk/integrations/ve_faas/web_template/*"
]

[tool.uv.sources]
mcp-server-vikingdb-memory = { git = "https://github.com/volcengine/mcp-server", subdirectory = "server/mcp_server_vikingdb_memory" }
6 changes: 5 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@


def test_agent():
knowledgebase = KnowledgeBase()
knowledgebase = KnowledgeBase(
index="test_index",
backend="local",
backend_config={"embedding_config": {"api_key": "test"}},
)
long_term_memory = LongTermMemory(backend="local")
tracer = OpentelemetryTracer()

Expand Down
18 changes: 7 additions & 11 deletions tests/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from veadk.knowledgebase import KnowledgeBase
from veadk.knowledgebase.backends.in_memory_backend import InMemoryKnowledgeBackend


@pytest.mark.asyncio
async def test_knowledgebase():
app_name = "kb_test_app"
key = "Supercalifragilisticexpialidocious"
kb = KnowledgeBase(backend="local")
# Attempt to delete any existing data for the app_name before adding new data
kb.add(
data=[f"knowledgebase_id is {key}"],
app_name=app_name,
)
res_list = kb.search(
query="knowledgebase_id",
kb = KnowledgeBase(
backend="local",
app_name=app_name,
backend_config={"embedding_config": {"api_key": "test"}},
)
res = "".join(res_list)
assert key in res, f"Test failed for backend local res is {res}"

assert isinstance(kb._backend, InMemoryKnowledgeBackend)
41 changes: 10 additions & 31 deletions tests/test_long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest
from google.adk.events import Event
from google.adk.sessions import Session
from google.adk.tools import load_memory
from google.genai import types

from veadk.agent import Agent
from veadk.memory.long_term_memory import LongTermMemory
Expand All @@ -27,7 +25,11 @@

@pytest.mark.asyncio
async def test_long_term_memory():
long_term_memory = LongTermMemory(backend="local")
long_term_memory = LongTermMemory(
backend="local",
# app_name=app_name,
# user_id=user_id,
)
agent = Agent(
name="all_name",
model_name="test_model_name",
Expand All @@ -41,31 +43,8 @@ async def test_long_term_memory():

assert load_memory in agent.tools, "load_memory tool not found in agent tools"

# mock session
session = Session(
id="test_session_id",
app_name=app_name,
user_id=user_id,
events=[
Event(
invocation_id="test_invocation_id",
author="user",
branch=None,
content=types.Content(
parts=[types.Part(text="My name is Alice.")],
role="user",
),
)
],
)

await long_term_memory.add_session_to_memory(session)
assert not agent.long_term_memory._backend

memories = await long_term_memory.search_memory(
app_name=app_name,
user_id=user_id,
query="Alice",
)
assert (
"Alice" in memories.model_dump()["memories"][0]["content"]["parts"][0]["text"]
)
# assert agent.long_term_memory._backend.index == build_long_term_memory_index(
# app_name, user_id
# )
13 changes: 6 additions & 7 deletions tests/test_short_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import asyncio
import os

import veadk.memory.short_term_memory
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.utils.misc import formatted_timestamp

Expand All @@ -35,11 +34,11 @@ def test_short_term_memory():
)
assert session is not None

# database - local
veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH = (
f"/tmp/tmp_for_test_{formatted_timestamp()}.db"
# sqlite
memory = ShortTermMemory(
backend="sqlite",
local_database_path=f"/tmp/tmp_for_test_{formatted_timestamp()}.db",
)
memory = ShortTermMemory(backend="database")
asyncio.run(
memory.session_service.create_session(
app_name="app", user_id="user", session_id="session"
Expand All @@ -51,5 +50,5 @@ def test_short_term_memory():
)
)
assert session is not None
assert os.path.exists(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH)
os.remove(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH)
assert os.path.exists(memory.local_database_path)
os.remove(memory.local_database_path)
2 changes: 0 additions & 2 deletions veadk/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ async def run(
collect_runtime_data: bool = False,
eval_set_id: str = "",
save_session_to_memory: bool = False,
enable_memory_optimization: bool = False,
):
"""Running the agent. The runner and session service will be created automatically.

Expand Down Expand Up @@ -226,7 +225,6 @@ async def run(
# memory service
short_term_memory = ShortTermMemory(
backend="database" if load_history_sessions_from_db else "local",
enable_memory_optimization=enable_memory_optimization,
db_url=db_url,
)
session_service = short_term_memory.session_service
Expand Down
75 changes: 75 additions & 0 deletions veadk/auth/veauth/opensearch_veauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from typing_extensions import override

from veadk.auth.veauth.base_veauth import BaseVeAuth
from veadk.utils.logger import get_logger

# from veadk.utils.volcengine_sign import ve_request

logger = get_logger(__name__)


class OpensearchVeAuth(BaseVeAuth):
def __init__(
self,
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
) -> None:
super().__init__(access_key, secret_key)

self._token: str = ""

@override
def _fetch_token(self) -> None:
logger.info("Fetching Opensearch STS token...")

# res = ve_request(
# request_body={},
# action="GetOrCreatePromptPilotAPIKeys",
# ak=self.access_key,
# sk=self.secret_key,
# service="ark",
# version="2024-01-01",
# region="cn-beijing",
# host="open.volcengineapi.com",
# )
# try:
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
# except KeyError:
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")

@property
def token(self) -> str:
if self._token:
return self._token
self._fetch_token()
return self._token
75 changes: 75 additions & 0 deletions veadk/auth/veauth/postgresql_veauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

from typing_extensions import override

from veadk.auth.veauth.base_veauth import BaseVeAuth
from veadk.utils.logger import get_logger

# from veadk.utils.volcengine_sign import ve_request

logger = get_logger(__name__)


class PostgreSqlVeAuth(BaseVeAuth):
def __init__(
self,
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
) -> None:
super().__init__(access_key, secret_key)

self._token: str = ""

@override
def _fetch_token(self) -> None:
logger.info("Fetching PostgreSQL STS token...")

# res = ve_request(
# request_body={},
# action="GetOrCreatePromptPilotAPIKeys",
# ak=self.access_key,
# sk=self.secret_key,
# service="ark",
# version="2024-01-01",
# region="cn-beijing",
# host="open.volcengineapi.com",
# )
# try:
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
# except KeyError:
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")

@property
def token(self) -> str:
if self._token:
return self._token
self._fetch_token()
return self._token
Loading