diff --git a/config.yaml.full b/config.yaml.full index 3b5550ba..4ebfccc4 100644 --- a/config.yaml.full +++ b/config.yaml.full @@ -17,7 +17,7 @@ model: embedding: name: doubao-embedding-text-240715 dim: 2560 - api_base: https://ark.cn-beijing.volces.com/api/v3/embeddings + api_base: https://ark.cn-beijing.volces.com/api/v3/ api_key: video: name: doubao-seedance-1-0-pro-250528 diff --git a/pyproject.toml b/pyproject.toml index b83ea3fb..e937ae09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,20 +22,33 @@ dependencies = [ "opentelemetry-instrumentation-logging>=0.56b0", "wrapt>=1.17.2", # For patching built-in functions "openai<1.100", # For fix https://github.com/BerriAI/litellm/issues/13710 - "volcengine-python-sdk==4.0.3", # For Volcengine API + "volcengine-python-sdk>=4.0.3", # For Volcengine API + "volcengine>=1.0.193", # For Volcengine sign "agent-pilot-sdk>=0.0.9", # Prompt optimization by Volcengine AgentPilot/PromptPilot toolkits "fastmcp>=2.11.3", # For running MCP - "cookiecutter>=2.6.0", # For cloud deploy # For OpenSearch database - "opensearch-py==2.8.0", + "cookiecutter>=2.6.0", # For cloud deploy "omegaconf>=2.3.0", # For agent builder + "llama-index>=0.14.0", + "llama-index-embeddings-openai-like>=0.2.2", + "llama-index-llms-openai-like>=0.5.1", + "llama-index-vector-stores-opensearch>=0.6.1", + "psycopg2-binary>=2.9.10", # For PostgreSQL database (short term memory) + "pymysql>=1.1.1", # For MySQL database (short term memory) + "opensearch-py==2.8.0", ] [project.scripts] veadk = "veadk.cli.cli:veadk" [project.optional-dependencies] +extensions = [ + "redis>=5.0", # For Redis database + "tos>=2.8.4", # For TOS storage and Viking DB + "llama-index-vector-stores-redis>=0.6.1", + "mcp-server-vikingdb-memory", +] database = [ - "redis>=6.2.0", # For Redis database + "redis>=5.0", # For Redis database "pymysql>=1.1.1", # For MySQL database "volcengine>=1.0.193", # For Viking DB "tos>=2.8.4", # For TOS storage and Viking DB @@ -78,3 +91,6 @@ 
exclude = [ "veadk/integrations/ve_faas/template/*", "veadk/integrations/ve_faas/web_template/*" ] + +[tool.uv.sources] +mcp-server-vikingdb-memory = { git = "https://github.com/volcengine/mcp-server", subdirectory = "server/mcp_server_vikingdb_memory" } diff --git a/tests/test_agent.py b/tests/test_agent.py index 97a7b82c..cb014f54 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -33,7 +33,11 @@ def test_agent(): - knowledgebase = KnowledgeBase() + knowledgebase = KnowledgeBase( + index="test_index", + backend="local", + backend_config={"embedding_config": {"api_key": "test"}}, + ) long_term_memory = LongTermMemory(backend="local") tracer = OpentelemetryTracer() diff --git a/tests/test_knowledgebase.py b/tests/test_knowledgebase.py index c7f33377..b8e91b3c 100644 --- a/tests/test_knowledgebase.py +++ b/tests/test_knowledgebase.py @@ -12,24 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest from veadk.knowledgebase import KnowledgeBase +from veadk.knowledgebase.backends.in_memory_backend import InMemoryKnowledgeBackend @pytest.mark.asyncio async def test_knowledgebase(): app_name = "kb_test_app" - key = "Supercalifragilisticexpialidocious" - kb = KnowledgeBase(backend="local") - # Attempt to delete any existing data for the app_name before adding new data - kb.add( - data=[f"knowledgebase_id is {key}"], - app_name=app_name, - ) - res_list = kb.search( - query="knowledgebase_id", + kb = KnowledgeBase( + backend="local", app_name=app_name, + backend_config={"embedding_config": {"api_key": "test"}}, ) - res = "".join(res_list) - assert key in res, f"Test failed for backend local res is {res}" + + assert isinstance(kb._backend, InMemoryKnowledgeBackend) diff --git a/tests/test_long_term_memory.py b/tests/test_long_term_memory.py index 12532825..34c67813 100644 --- a/tests/test_long_term_memory.py +++ b/tests/test_long_term_memory.py @@ -12,11 +12,9 @@ # See the License for the 
specific language governing permissions and # limitations under the License. + import pytest -from google.adk.events import Event -from google.adk.sessions import Session from google.adk.tools import load_memory -from google.genai import types from veadk.agent import Agent from veadk.memory.long_term_memory import LongTermMemory @@ -27,7 +25,11 @@ @pytest.mark.asyncio async def test_long_term_memory(): - long_term_memory = LongTermMemory(backend="local") + long_term_memory = LongTermMemory( + backend="local", + # app_name=app_name, + # user_id=user_id, + ) agent = Agent( name="all_name", model_name="test_model_name", @@ -41,31 +43,8 @@ async def test_long_term_memory(): assert load_memory in agent.tools, "load_memory tool not found in agent tools" - # mock session - session = Session( - id="test_session_id", - app_name=app_name, - user_id=user_id, - events=[ - Event( - invocation_id="test_invocation_id", - author="user", - branch=None, - content=types.Content( - parts=[types.Part(text="My name is Alice.")], - role="user", - ), - ) - ], - ) - - await long_term_memory.add_session_to_memory(session) + assert not agent.long_term_memory._backend - memories = await long_term_memory.search_memory( - app_name=app_name, - user_id=user_id, - query="Alice", - ) - assert ( - "Alice" in memories.model_dump()["memories"][0]["content"]["parts"][0]["text"] - ) + # assert agent.long_term_memory._backend.index == build_long_term_memory_index( + # app_name, user_id + # ) diff --git a/tests/test_short_term_memory.py b/tests/test_short_term_memory.py index f0d4ff68..b1cbd0a2 100644 --- a/tests/test_short_term_memory.py +++ b/tests/test_short_term_memory.py @@ -15,7 +15,6 @@ import asyncio import os -import veadk.memory.short_term_memory from veadk.memory.short_term_memory import ShortTermMemory from veadk.utils.misc import formatted_timestamp @@ -35,11 +34,11 @@ def test_short_term_memory(): ) assert session is not None - # database - local - 
veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH = ( - f"/tmp/tmp_for_test_{formatted_timestamp()}.db" + # sqlite + memory = ShortTermMemory( + backend="sqlite", + local_database_path=f"/tmp/tmp_for_test_{formatted_timestamp()}.db", ) - memory = ShortTermMemory(backend="database") asyncio.run( memory.session_service.create_session( app_name="app", user_id="user", session_id="session" @@ -51,5 +50,5 @@ def test_short_term_memory(): ) ) assert session is not None - assert os.path.exists(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH) - os.remove(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH) + assert os.path.exists(memory.local_database_path) + os.remove(memory.local_database_path) diff --git a/veadk/agent.py b/veadk/agent.py index cc7701a3..33abe5ce 100644 --- a/veadk/agent.py +++ b/veadk/agent.py @@ -196,7 +196,6 @@ async def run( collect_runtime_data: bool = False, eval_set_id: str = "", save_session_to_memory: bool = False, - enable_memory_optimization: bool = False, ): """Running the agent. The runner and session service will be created automatically. @@ -226,7 +225,6 @@ async def run( # memory service short_term_memory = ShortTermMemory( backend="database" if load_history_sessions_from_db else "local", - enable_memory_optimization=enable_memory_optimization, db_url=db_url, ) session_service = short_term_memory.session_service diff --git a/veadk/auth/veauth/opensearch_veauth.py b/veadk/auth/veauth/opensearch_veauth.py new file mode 100644 index 00000000..4c2184fb --- /dev/null +++ b/veadk/auth/veauth/opensearch_veauth.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from typing_extensions import override + +from veadk.auth.veauth.base_veauth import BaseVeAuth +from veadk.utils.logger import get_logger + +# from veadk.utils.volcengine_sign import ve_request + +logger = get_logger(__name__) + + +class OpensearchVeAuth(BaseVeAuth): + def __init__( + self, + access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""), + secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""), + ) -> None: + super().__init__(access_key, secret_key) + + self._token: str = "" + + @override + def _fetch_token(self) -> None: + logger.info("Fetching Opensearch STS token...") + + # res = ve_request( + # request_body={}, + # action="GetOrCreatePromptPilotAPIKeys", + # ak=self.access_key, + # sk=self.secret_key, + # service="ark", + # version="2024-01-01", + # region="cn-beijing", + # host="open.volcengineapi.com", + # ) + # try: + # self._token = res["Result"]["APIKeys"][0]["APIKey"] + # except KeyError: + # raise ValueError(f"Failed to get Prompt Pilot token: {res}") + + @property + def token(self) -> str: + if self._token: + return self._token + self._fetch_token() + return self._token diff --git a/veadk/auth/veauth/postgresql_veauth.py b/veadk/auth/veauth/postgresql_veauth.py new file mode 100644 index 00000000..e85a0fac --- /dev/null +++ b/veadk/auth/veauth/postgresql_veauth.py @@ -0,0 +1,75 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from typing_extensions import override + +from veadk.auth.veauth.base_veauth import BaseVeAuth +from veadk.utils.logger import get_logger + +# from veadk.utils.volcengine_sign import ve_request + +logger = get_logger(__name__) + + +class PostgreSqlVeAuth(BaseVeAuth): + def __init__( + self, + access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""), + secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""), + ) -> None: + super().__init__(access_key, secret_key) + + self._token: str = "" + + @override + def _fetch_token(self) -> None: + logger.info("Fetching PostgreSQL STS token...") + + # res = ve_request( + # request_body={}, + # action="GetOrCreatePromptPilotAPIKeys", + # ak=self.access_key, + # sk=self.secret_key, + # service="ark", + # version="2024-01-01", + # region="cn-beijing", + # host="open.volcengineapi.com", + # ) + # try: + # self._token = res["Result"]["APIKeys"][0]["APIKey"] + # except KeyError: + # raise ValueError(f"Failed to get Prompt Pilot token: {res}") + + @property + def token(self) -> str: + if self._token: + return self._token + self._fetch_token() + return self._token diff --git a/veadk/configs/database_configs.py b/veadk/configs/database_configs.py index 86d8300f..c724af04 100644 --- a/veadk/configs/database_configs.py +++ b/veadk/configs/database_configs.py @@ -32,6 +32,8 @@ class 
OpensearchConfig(BaseSettings): password: str = "" + secret_token: str = "" + class MysqlConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="DATABASE_MYSQL_") @@ -46,6 +48,25 @@ class MysqlConfig(BaseSettings): charset: str = "utf8" + secret_token: str = "" + """STS token for MySQL auth, not supported yet.""" + + +class PostgreSqlConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="DATABASE_POSTGRESQL_") + + host: str = "" + + port: int = 5432 + + user: str = "" + + password: str = "" + + database: str = "" + + secret_token: str = "" + class RedisConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="DATABASE_REDIS_") @@ -58,6 +79,9 @@ class RedisConfig(BaseSettings): db: int = 0 + secret_token: str = "" + """STS token for Redis auth, not supported yet.""" + class VikingKnowledgebaseConfig(BaseSettings): model_config = SettingsConfigDict(env_prefix="DATABASE_VIKING_") @@ -81,3 +105,13 @@ def bucket(self) -> str: VeTOS(bucket_name=_bucket).create_bucket() return _bucket + + +class NormalTOSConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="DATABASE_TOS_") + + endpoint: str = "tos-cn-beijing.volces.com" + + region: str = "cn-beijing" + + bucket: str diff --git a/veadk/configs/model_configs.py b/veadk/configs/model_configs.py index e0efbdb6..8551ac3b 100644 --- a/veadk/configs/model_configs.py +++ b/veadk/configs/model_configs.py @@ -40,3 +40,35 @@ class ModelConfig(BaseSettings): @cached_property def api_key(self) -> str: return os.getenv("MODEL_AGENT_API_KEY") or ARKVeAuth().token + + +class EmbeddingModelConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="MODEL_EMBEDDING_") + + name: str = "doubao-embedding-text-240715" + """Model name for embedding.""" + + dim: int = 2560 + """Embedding dim is different from different models.""" + + api_base: str = "https://ark.cn-beijing.volces.com/api/v3/" + """The api base of the model for embedding.""" + + @cached_property + def api_key(self) -> 
str: + return os.getenv("MODEL_EMBEDDING_API_KEY") or ARKVeAuth().token + + +class NormalEmbeddingModelConfig(BaseSettings): + model_config = SettingsConfigDict(env_prefix="MODEL_EMBEDDING_") + + name: str = "doubao-embedding-text-240715" + """Model name for embedding.""" + + dim: int = 2560 + """Embedding dim is different from different models.""" + + api_base: str = "https://ark.cn-beijing.volces.com/api/v3/" + """The api base of the model for embedding.""" + + api_key: str diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py deleted file mode 100644 index 55081b0d..00000000 --- a/veadk/database/database_adapter.py +++ /dev/null @@ -1,533 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import re -import time -from typing import BinaryIO, TextIO - -from veadk.database.base_database import BaseDatabase -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -class KVDatabaseAdapter: - def __init__(self, client): - from veadk.database.kv.redis_database import RedisDatabase - - self.client: RedisDatabase = client - - def index_exists(self, index: str) -> bool: - """ - Check if the index (key) exists in Redis. 
- - Args: - index: The Redis key to check - - Returns: - bool: True if the key exists, False otherwise - """ - try: - # Use Redis EXISTS command to check if key exists - return bool(self.client._client.exists(index)) - except Exception as e: - logger.error( - f"Failed to check if key exists in Redis: index={index} error={e}" - ) - return False - - def add(self, data: list[str], index: str, **kwargs): - logger.debug(f"Adding documents to Redis database: index={index}") - - try: - for _data in data: - self.client.add(key=index, value=_data) - logger.debug(f"Added {len(data)} texts to Redis database: index={index}") - except Exception as e: - logger.error( - f"Failed to add data to Redis database: index={index} error={e}" - ) - raise e - - def query(self, query: str, index: str, top_k: int = 0) -> list: - logger.debug(f"Querying Redis database: index={index} query={query}") - - # ignore top_k, as KV search only return one result - _ = top_k - - try: - result = self.client.query(key=index, query=query) - return result - except Exception as e: - logger.error(f"Failed to search from Redis: index={index} error={e}") - raise e - - def delete(self, index: str) -> bool: - logger.debug(f"Deleting key from Redis database: index={index}") - try: - self.client.delete(key=index) - return True - except Exception as e: - logger.error( - f"Failed to delete key from Redis database: index={index} error={e}" - ) - return False - - def delete_doc(self, index: str, id: str) -> bool: - logger.debug(f"Deleting document from Redis database: index={index} id={id}") - try: - # For Redis, we need to handle deletion differently since RedisDatabase.delete_doc - # takes a key and a single id - result = self.client.delete_doc(key=index, id=id) - return result - except Exception as e: - logger.error( - f"Failed to delete document from Redis database: index={index} id={id} error={e}" - ) - return False - - def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: - 
logger.debug(f"Listing documents from Redis database: index={index}") - try: - # Get all documents from Redis - docs = self.client.list_docs(key=index) - - # Apply offset and limit for pagination - return docs[offset : offset + limit] - except Exception as e: - logger.error( - f"Failed to list documents from Redis database: index={index} error={e}" - ) - return [] - - -class RelationalDatabaseAdapter: - def __init__(self, client): - from veadk.database.relational.mysql_database import MysqlDatabase - - self.client: MysqlDatabase = client - - def index_exists(self, index: str) -> bool: - """ - Check if the table (index) exists in MySQL database. - - Args: - index: The table name to check - - Returns: - bool: True if the table exists, False otherwise - """ - try: - return self.client.table_exists(index) - except Exception as e: - logger.error( - f"Failed to check if table exists in MySQL: index={index} error={e}" - ) - return False - - def create_table(self, table_name: str): - logger.debug(f"Creating table for SQL database: table_name={table_name}") - - sql = f""" - CREATE TABLE `{table_name}` ( - `id` BIGINT AUTO_INCREMENT PRIMARY KEY, - `data` TEXT NOT NULL, - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) ENGINE=InnoDB DEFAULT CHARSET={self.client.config.charset}; - """ - self.client.add(sql) - - def add(self, data: list[str], index: str, **kwargs): - logger.debug( - f"Adding documents to SQL database: table_name={index} data_len={len(data)}" - ) - - if not self.client.table_exists(index): - logger.warning(f"Table {index} does not exist, creating a new table.") - self.create_table(index) - - for _data in data: - sql = f""" - INSERT INTO `{index}` (`data`) - VALUES (%s); - """ - self.client.add(sql, params=(_data,)) - logger.debug(f"Added {len(data)} texts to table {index}.") - - def query(self, query: str, index: str, top_k: int) -> list[str]: - logger.debug( - f"Querying SQL database: table_name={index} query={query} top_k={top_k}" - ) - - if not 
self.client.table_exists(index): - logger.warning( - f"Querying SQL database, but table `{index}` does not exist, returning empty list." - ) - return [] - - sql = f""" - SELECT `data` FROM `{index}` ORDER BY `created_at` DESC LIMIT {top_k}; - """ - results = self.client.query(sql) - - return [item["data"] for item in results] - - def delete(self, index: str) -> bool: - logger.debug(f"Deleting table from SQL database: table_name={index}") - try: - self.client.delete(table=index) - return True - except Exception as e: - logger.error( - f"Failed to delete table from SQL database: table_name={index} error={e}" - ) - return False - - def delete_doc(self, index: str, id: str) -> bool: - logger.debug(f"Deleting document from SQL database: table_name={index} id={id}") - try: - # Convert single id to list for the client method - result = self.client.delete_doc(table=index, ids=[int(id)]) - return result - except Exception as e: - logger.error( - f"Failed to delete document from SQL database: table_name={index} id={id} error={e}" - ) - return False - - def list_docs(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: - logger.debug(f"Listing documents from SQL database: table_name={index}") - try: - return self.client.list_docs(table=index, offset=offset, limit=limit) - except Exception as e: - logger.error( - f"Failed to list documents from SQL database: table_name={index} error={e}" - ) - return [] - - -class VectorDatabaseAdapter: - def __init__(self, client): - from veadk.database.vector.opensearch_vector_database import ( - OpenSearchVectorDatabase, - ) - - self.client: OpenSearchVectorDatabase = client - - def index_exists(self, index: str) -> bool: - """ - Check if the collection (index) exists in OpenSearch. 
- - Args: - index: The collection name to check - - Returns: - bool: True if the collection exists, False otherwise - """ - try: - self._validate_index(index) - return self.client.collection_exists(index) - except Exception as e: - logger.error( - f"Failed to check if collection exists in OpenSearch: index={index} error={e}" - ) - return False - - def _validate_index(self, index: str): - """ - Verify whether the string conforms to the naming rules of index_name in OpenSearch. - https://docs.opensearch.org/2.8/api-reference/index-apis/create-index/ - """ - if not ( - isinstance(index, str) - and not index.startswith(("_", "-")) - and index.islower() - and re.match(r"^[a-z0-9_\-.]+$", index) - ): - raise ValueError( - "The index name does not conform to the naming rules of OpenSearch" - ) - - def add(self, data: list[str], index: str, **kwargs): - self._validate_index(index) - - logger.debug( - f"Adding documents to vector database: index={index} data_len={len(data)}" - ) - - self.client.add(data, collection_name=index) - - def query(self, query: str, index: str, top_k: int) -> list[str]: - logger.debug( - f"Querying vector database: collection_name={index} query={query} top_k={top_k}" - ) - - return self.client.query( - query=query, - collection_name=index, - top_k=top_k, - ) - - def delete(self, index: str) -> bool: - self._validate_index(index) - logger.debug(f"Deleting collection from vector database: index={index}") - try: - self.client.delete(collection_name=index) - return True - except Exception as e: - logger.error( - f"Failed to delete collection from vector database: index={index} error={e}" - ) - return False - - def delete_doc(self, index: str, id: str) -> bool: - self._validate_index(index) - logger.debug(f"Deleting documents from vector database: index={index} id={id}") - try: - self.client.delete_by_id(collection_name=index, id=id) - return True - except Exception as e: - logger.error( - f"Failed to delete document from vector database: index={index} 
id={id} error={e}" - ) - return False - - def list_chunks(self, index: str, offset: int = 0, limit: int = 1000) -> list[dict]: - self._validate_index(index) - logger.debug(f"Listing documents from vector database: index={index}") - return self.client.list_docs(collection_name=index, offset=offset, limit=limit) - - -class VikingDatabaseAdapter: - def __init__(self, client): - from veadk.database.viking.viking_database import VikingDatabase - - self.client: VikingDatabase = client - - def index_exists(self, index: str) -> bool: - """ - Check if the collection (index) exists in VikingDB. - - Args: - index: The collection name to check - - Returns: - bool: True if the collection exists, False otherwise - """ - try: - self._validate_index(index) - return self.client.collection_exists(index) - except Exception as e: - logger.error( - f"Failed to check if collection exists in VikingDB: index={index} error={e}" - ) - return False - - def _validate_index(self, index: str): - """ - Only English letters, numbers, and underscores (_) are allowed. - It must start with an English letter and cannot be empty. Length requirement: [1, 128]. - For details, please see: https://www.volcengine.com/docs/84313/1254542?lang=zh - """ - if not ( - isinstance(index, str) - and 0 < len(index) <= 128 - and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) - ): - raise ValueError( - "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." - ) - - def get_or_create_collection(self, collection_name: str): - if not self.client.collection_exists(collection_name): - logger.warning( - f"Collection {collection_name} does not exist, creating a new collection." - ) - self.client.create_collection(collection_name) - - # After creation, it is necessary to wait for a while. 
- count = 0 - while not self.client.collection_exists(collection_name): - print("here") - time.sleep(1) - count += 1 - if count > 60: - raise TimeoutError( - f"Collection {collection_name} not created after 50 seconds" - ) - - def add( - self, data: str | list[str] | TextIO | BinaryIO | bytes, index: str, **kwargs - ): - self._validate_index(index) - - logger.debug(f"Adding documents to Viking database: collection_name={index}") - - self.get_or_create_collection(index) - self.client.add(data, collection_name=index, **kwargs) - - def query(self, query: str, index: str, top_k: int) -> list[str]: - self._validate_index(index) - - logger.debug(f"Querying Viking database: collection_name={index} query={query}") - - if not self.client.collection_exists(index): - return [] - - return self.client.query(query, collection_name=index, top_k=top_k) - - def delete(self, index: str) -> bool: - self._validate_index(index) - logger.debug(f"Deleting collection from Viking database: index={index}") - return self.client.delete(name=index) - - def delete_doc(self, index: str, id: str) -> bool: - self._validate_index(index) - logger.debug(f"Deleting documents from vector database: index={index} id={id}") - return self.client.delete_by_id(collection_name=index, id=id) - - def list_chunks(self, index: str, offset: int, limit: int) -> list[dict]: - self._validate_index(index) - logger.debug(f"Listing documents from vector database: index={index}") - return self.client.list_chunks( - collection_name=index, offset=offset, limit=limit - ) - - def list_docs(self, index: str, offset: int, limit: int) -> list[dict]: - self._validate_index(index) - logger.debug(f"Listing documents from vector database: index={index}") - return self.client.list_docs(collection_name=index, offset=offset, limit=limit) - - -class VikingMemoryDatabaseAdapter: - def __init__(self, client): - from veadk.database.viking.viking_memory_db import VikingMemoryDatabase - - self.client: VikingMemoryDatabase = client - - def 
index_exists(self, index: str) -> bool: - """ - Check if the collection (index) exists in VikingMemoryDB. - - Note: - VikingMemoryDatabase does not support checking if a collection exists. - This method always returns False. - - Args: - index: The collection name to check - - Returns: - bool: Always returns False as VikingMemoryDatabase does not support this functionality - """ - logger.warning( - "VikingMemoryDatabase does not support checking if a collection exists" - ) - raise NotImplementedError("VikingMemoryDatabase does not support index_exists") - - def _validate_index(self, index: str): - if not ( - isinstance(index, str) - and 1 <= len(index) <= 128 - and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index) - ): - raise ValueError( - "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." - ) - - def add(self, data: list[str], index: str, **kwargs): - self._validate_index(index) - - logger.debug( - f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}" - ) - - self.client.add(data, collection_name=index, **kwargs) - - def query(self, query: str, index: str, top_k: int, **kwargs): - self._validate_index(index) - - logger.debug( - f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}" - ) - - result = self.client.query(query, collection_name=index, top_k=top_k, **kwargs) - return result - - def delete(self, index: str) -> bool: - self._validate_index(index) - logger.debug(f"Deleting collection from Viking database memory: index={index}") - raise NotImplementedError("VikingMemoryDatabase does not support delete") - - def delete_docs(self, index: str, ids: list[int]): - raise NotImplementedError("VikingMemoryDatabase does not support delete_docs") - - def list_chunks(self, index: str): - raise NotImplementedError("VikingMemoryDatabase does not support list_docs") - - -class 
LocalDatabaseAdapter: - def __init__(self, client): - from veadk.database.local_database import LocalDataBase - - self.client: LocalDataBase = client - - def index_exists(self, index: str) -> bool: - """ - Check if the index exists in LocalDataBase. - - Note: - LocalDataBase does not support checking if an index exists. - This method always returns False. - - Args: - index: The index name to check (not used in LocalDataBase) - - Returns: - bool: Always returns False as LocalDataBase does not support this functionality - """ - logger.warning("LocalDataBase does not support checking if an index exists") - return True - - def add(self, data: list[str], **kwargs): - self.client.add(data) - - def query(self, query: str, **kwargs): - return self.client.query(query, **kwargs) - - def delete(self, index: str) -> bool: - return self.client.delete() - - def delete_doc(self, index: str, id: str) -> bool: - return self.client.delete_doc(id) - - def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: - return self.client.list_docs(offset=offset, limit=limit) - - -MAPPING = { - "RedisDatabase": KVDatabaseAdapter, - "MysqlDatabase": RelationalDatabaseAdapter, - "LocalDataBase": LocalDatabaseAdapter, - "VikingDatabase": VikingDatabaseAdapter, - "OpenSearchVectorDatabase": VectorDatabaseAdapter, - "VikingMemoryDatabase": VikingMemoryDatabaseAdapter, -} - - -def get_knowledgebase_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client).__name__](client=database_client) - - -def get_long_term_memory_database_adapter(database_client: BaseDatabase): - return MAPPING[type(database_client).__name__](client=database_client) diff --git a/veadk/database/database_factory.py b/veadk/database/database_factory.py deleted file mode 100644 index 838b008b..00000000 --- a/veadk/database/database_factory.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from veadk.utils.logger import get_logger - -from .base_database import BaseDatabase - -logger = get_logger(__name__) - - -class DatabaseBackend: - OPENSEARCH = "opensearch" - LOCAL = "local" - MYSQL = "mysql" - REDIS = "redis" - VIKING = "viking" - VIKING_MEM = "viking_mem" - - @classmethod - def get_attr(cls) -> set[str]: - return { - value - for attr, value in cls.__dict__.items() - if not attr.startswith("__") and attr != "get_attr" - } - - -class DatabaseFactory: - @staticmethod - def create(backend: str, config=None) -> BaseDatabase: - if backend not in DatabaseBackend.get_attr(): - logger.warning(f"Unknown backend: {backend}), change backend to `local`") - backend = "local" - - if backend == DatabaseBackend.LOCAL: - from .local_database import LocalDataBase - - return LocalDataBase() - if backend == DatabaseBackend.OPENSEARCH: - from .vector.opensearch_vector_database import OpenSearchVectorDatabase - - return ( - OpenSearchVectorDatabase() - if config is None - else OpenSearchVectorDatabase(config=config) - ) - if backend == DatabaseBackend.MYSQL: - from .relational.mysql_database import MysqlDatabase - - return MysqlDatabase() if config is None else MysqlDatabase(config=config) - if backend == DatabaseBackend.REDIS: - from .kv.redis_database import RedisDatabase - - return RedisDatabase() if config is None else RedisDatabase(config=config) - if backend == DatabaseBackend.VIKING: - from .viking.viking_database 
import VikingDatabase - - return VikingDatabase() if config is None else VikingDatabase(config=config) - - if backend == DatabaseBackend.VIKING_MEM: - from .viking.viking_memory_db import VikingMemoryDatabase - - return ( - VikingMemoryDatabase() - if config is None - else VikingMemoryDatabase(config=config) - ) - else: - raise ValueError(f"Unsupported database type: {backend}") diff --git a/veadk/database/kv/redis_database.py b/veadk/database/kv/redis_database.py deleted file mode 100644 index c79144aa..00000000 --- a/veadk/database/kv/redis_database.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from typing import Any - -import redis -from pydantic import BaseModel, Field -from typing_extensions import override - -from veadk.config import getenv -from veadk.utils.logger import get_logger - -from ..base_database import BaseDatabase - -logger = get_logger(__name__) - - -class RedisDatabaseConfig(BaseModel): - host: str = Field( - default_factory=lambda: getenv("DATABASE_REDIS_HOST"), - description="Redis host", - ) - port: int = Field( - default_factory=lambda: int(getenv("DATABASE_REDIS_PORT")), - description="Redis port", - ) - db: int = Field( - default_factory=lambda: int(getenv("DATABASE_REDIS_DB")), - description="Redis db", - ) - password: str = Field( - default_factory=lambda: getenv("DATABASE_REDIS_PASSWORD"), - description="Redis password", - ) - decode_responses: bool = Field( - default=True, - description="Redis decode responses", - ) - - -class RedisDatabase(BaseModel, BaseDatabase): - config: RedisDatabaseConfig = Field(default_factory=RedisDatabaseConfig) - - def model_post_init(self, context: Any, /) -> None: - try: - self._client = redis.StrictRedis( - host=self.config.host, - port=self.config.port, - db=self.config.db, - password=self.config.password, - decode_responses=self.config.decode_responses, - ) - - self._client.ping() - logger.info("Connected to Redis successfully.") - except Exception as e: - logger.error(f"Failed to connect to Redis: {e}") - raise e - - @override - def add(self, key: str, value: str, **kwargs): - try: - self._client.rpush(key, value) - except Exception as e: - logger.error(f"Failed to add value to Redis list key `{key}`: {e}") - raise e - - @override - def query(self, key: str, query: str = "", **kwargs) -> list: - try: - result = self._client.lrange(key, 0, -1) - return result # type: ignore - except Exception as e: - logger.error(f"Failed to search from Redis list key '{key}': {e}") - raise e - - @override - def delete(self, **kwargs): - """Delete Redis list key based on 
app_name, user_id and session_id, or directly by key.""" - key = kwargs.get("key") - if key is None: - app_name = kwargs.get("app_name") - user_id = kwargs.get("user_id") - session_id = kwargs.get("session_id") - key = f"{app_name}:{user_id}:{session_id}" - - try: - # For simple key deletion - # We use sync Redis client to delete the key - # so the result will be `int` - result = self._client.delete(key) - - if result > 0: # type: ignore - logger.info(f"Deleted key `{key}` from Redis.") - else: - logger.info(f"Key `{key}` not found in Redis. Skipping deletion.") - except Exception as e: - logger.error(f"Failed to delete key `{key}`: {e}") - raise e - - def delete_doc(self, key: str, id: str) -> bool: - """Delete a specific document by ID from a Redis list. - - Args: - key: The Redis key (list) to delete from - id: The ID of the document to delete - - Returns: - bool: True if deletion was successful, False otherwise - """ - try: - # Get all items in the list - items = self._client.lrange(key, 0, -1) - - # Find the index of the item to delete - for i, item in enumerate(items): - # Assuming the item is stored as a JSON string with an 'id' field - # If it's just the content, we'll use the list index as ID - if str(i) == id: - self._client.lrem(key, 1, item) - return True - - logger.warning(f"Document with id {id} not found in key {key}") - return False - except Exception as e: - logger.error(f"Failed to delete document with id {id} from key {key}: {e}") - return False - - def list_docs(self, key: str) -> list[dict]: - """List all documents in a Redis list. 
- - Args: - key: The Redis key (list) to list documents from - - Returns: - list[dict]: List of documents with id and content - """ - try: - items = self._client.lrange(key, 0, -1) - return [ - {"id": str(i), "content": item, "metadata": {}} - for i, item in enumerate(items) - ] - except Exception as e: - logger.error(f"Failed to list documents from key {key}: {e}") - return [] diff --git a/veadk/database/local_database.py b/veadk/database/local_database.py deleted file mode 100644 index e5d12290..00000000 --- a/veadk/database/local_database.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -from .base_database import BaseDatabase - - -class LocalDataBase(BaseDatabase): - """This database is only for basic demonstration. - It does not support the vector search function, and the `search` function will return all data. 
- """ - - def __init__(self, **kwargs): - super().__init__() - self.data = {} - self._type = "local" - self._next_id = 0 # Used to generate unique IDs - - def add_texts(self, texts: list[str], **kwargs): - for text in texts: - self.data[str(self._next_id)] = text - self._next_id += 1 - - def is_empty(self): - return len(self.data) == 0 - - def query(self, query: str, **kwargs: Any) -> list[str]: - return list(self.data.values()) - - def delete(self, **kwargs: Any): - self.data = {} - return True - - def add(self, texts: list[str], **kwargs: Any): - return self.add_texts(texts) - - def list_docs(self, **kwargs: Any) -> list[dict]: - return [ - {"id": id, "content": content, "metadata": {}} - for id, content in self.data.items() - ] - - def delete_doc(self, id: str, **kwargs: Any): - if id not in self.data: - raise ValueError(f"id {id} not found") - try: - del self.data[id] - return True - except Exception: - return False diff --git a/veadk/database/relational/mysql_database.py b/veadk/database/relational/mysql_database.py deleted file mode 100644 index 644be67e..00000000 --- a/veadk/database/relational/mysql_database.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -from typing import Any - -import pymysql -from pydantic import BaseModel, Field -from typing_extensions import override - -from veadk.config import getenv -from veadk.utils.logger import get_logger - -from ..base_database import BaseDatabase - -logger = get_logger(__name__) - - -class MysqlDatabaseConfig(BaseModel): - host: str = Field( - default_factory=lambda: getenv("DATABASE_MYSQL_HOST"), - description="Mysql host", - ) - user: str = Field( - default_factory=lambda: getenv("DATABASE_MYSQL_USER"), - description="Mysql user", - ) - password: str = Field( - default_factory=lambda: getenv("DATABASE_MYSQL_PASSWORD"), - description="Mysql password", - ) - database: str = Field( - default_factory=lambda: getenv("DATABASE_MYSQL_DATABASE"), - description="Mysql database", - ) - charset: str = Field( - default_factory=lambda: getenv("DATABASE_MYSQL_CHARSET", "utf8mb4"), - description="Mysql charset", - ) - - -class MysqlDatabase(BaseModel, BaseDatabase): - config: MysqlDatabaseConfig = Field(default_factory=MysqlDatabaseConfig) - - def model_post_init(self, context: Any, /) -> None: - self._connection = pymysql.connect( - host=self.config.host, - user=self.config.user, - password=self.config.password, - database=self.config.database, - charset=self.config.charset, - cursorclass=pymysql.cursors.DictCursor, - ) - self._connection.ping() - logger.info("Connected to MySQL successfully.") - - self._type = "mysql" - - def table_exists(self, table: str) -> bool: - with self._connection.cursor() as cursor: - cursor.execute( - "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s", - (self.config.database, table), - ) - result = cursor.fetchone() - return result is not None - - @override - def add(self, sql: str, params=None, **kwargs): - with self._connection.cursor() as cursor: - cursor.execute(sql, params) - self._connection.commit() - - @override - def query(self, sql: str, params=None, **kwargs) 
-> tuple[dict[str, Any], ...]: - with self._connection.cursor() as cursor: - cursor.execute(sql, params) - return cursor.fetchall() - - @override - def delete(self, **kwargs): - table = kwargs.get("table") - if table is None: - app_name = kwargs.get("app_name", "default") - table = app_name - - if not self.table_exists(table): - logger.warning(f"Table {table} does not exist. Skipping delete operation.") - return - - try: - with self._connection.cursor() as cursor: - # Drop the table directly - sql = f"DROP TABLE `{table}`" - cursor.execute(sql) - self._connection.commit() - logger.info(f"Dropped table {table}") - except Exception as e: - logger.error(f"Failed to drop table {table}: {e}") - raise e - - def delete_doc(self, table: str, ids: list[int]) -> bool: - """Delete documents by IDs from a MySQL table. - - Args: - table: The table name to delete from - ids: List of document IDs to delete - - Returns: - bool: True if deletion was successful, False otherwise - """ - if not self.table_exists(table): - logger.warning(f"Table {table} does not exist. Skipping delete operation.") - return False - - if not ids: - return True # Nothing to delete - - try: - with self._connection.cursor() as cursor: - # Create placeholders for the IDs - placeholders = ",".join(["%s"] * len(ids)) - sql = f"DELETE FROM `{table}` WHERE id IN ({placeholders})" - cursor.execute(sql, ids) - self._connection.commit() - logger.info(f"Deleted {cursor.rowcount} documents from table {table}") - return True - except Exception as e: - logger.error(f"Failed to delete documents from table {table}: {e}") - return False - - def list_docs(self, table: str, offset: int = 0, limit: int = 100) -> list[dict]: - """List documents from a MySQL table. 
- - Args: - table: The table name to list documents from - offset: Offset for pagination - limit: Limit for pagination - - Returns: - list[dict]: List of documents with id and content - """ - if not self.table_exists(table): - logger.warning(f"Table {table} does not exist. Returning empty list.") - return [] - - try: - with self._connection.cursor() as cursor: - sql = f"SELECT id, data FROM `{table}` ORDER BY created_at DESC LIMIT %s OFFSET %s" - cursor.execute(sql, (limit, offset)) - results = cursor.fetchall() - return [ - {"id": str(row["id"]), "content": row["data"], "metadata": {}} - for row in results - ] - except Exception as e: - logger.error(f"Failed to list documents from table {table}: {e}") - return [] - - def is_empty(self): - pass diff --git a/veadk/database/vector/opensearch_vector_database.py b/veadk/database/vector/opensearch_vector_database.py deleted file mode 100644 index 4ba71d7b..00000000 --- a/veadk/database/vector/opensearch_vector_database.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import annotations - -import os -from typing import Any, Literal, Optional - -from opensearchpy import OpenSearch, Urllib3HttpConnection, helpers -from pydantic import BaseModel, Field, PrivateAttr -from typing_extensions import override - -from veadk.config import getenv -from veadk.utils.logger import get_logger - -from ..base_database import BaseDatabase -from .type import Embeddings - -logger = get_logger(__name__) - - -class OpenSearchVectorDatabaseConfig(BaseModel): - host: str = Field( - default_factory=lambda: getenv("DATABASE_OPENSEARCH_HOST"), - description="OpenSearch host", - ) - - port: str | int = Field( - default_factory=lambda: getenv("DATABASE_OPENSEARCH_PORT"), - description="OpenSearch port", - ) - - username: Optional[str] = Field( - default_factory=lambda: getenv("DATABASE_OPENSEARCH_USERNAME"), - description="OpenSearch username", - ) - - password: Optional[str] = Field( - default_factory=lambda: getenv("DATABASE_OPENSEARCH_PASSWORD"), - description="OpenSearch password", - ) - - secure: bool = Field(default=True, description="Whether enable SSL") - - verify_certs: bool = Field(default=False, description="Whether verify SSL certs") - - auth_method: Literal["basic", "aws_managed_iam"] = Field( - default="basic", description="OpenSearch auth method" - ) - - def to_opensearch_params(self) -> dict[str, Any]: - params = { - "hosts": [{"host": self.host, "port": int(self.port)}], - "use_ssl": self.secure, - "verify_certs": self.verify_certs, - "connection_class": Urllib3HttpConnection, - "pool_maxsize": 20, - } - ca_cert_path = os.getenv("OPENSEARCH_CA_CERT") - if self.verify_certs and ca_cert_path: - params["ca_certs"] = ca_cert_path - - params["http_auth"] = (self.username, self.password) - - return params - - -class OpenSearchVectorDatabase(BaseModel, BaseDatabase): - config: OpenSearchVectorDatabaseConfig = Field( - default_factory=OpenSearchVectorDatabaseConfig - ) - - _embedding_client: Embeddings = PrivateAttr() - 
_opensearch_client: OpenSearch = PrivateAttr() - - def model_post_init(self, context: Any, /) -> None: - self._embedding_client = Embeddings() - self._opensearch_client = OpenSearch(**self.config.to_opensearch_params()) - - self._type = "opensearch" - - def _get_settings(self) -> dict: - settings = {"index": {"knn": True}} - return settings - - def _get_mappings(self, dim: int = 2560) -> dict: - mappings = { - "properties": { - "page_content": { - "type": "text", - }, - "vector": { - "type": "knn_vector", - "dimension": dim, - "method": { - "name": "hnsw", - "space_type": "l2", - "engine": "faiss", - "parameters": {"ef_construction": 64, "m": 8}, - }, - }, - } - } - return mappings - - def create_collection( - self, - collection_name: str, - embedding_dim: int, - ): - if not self._opensearch_client.indices.exists(index=collection_name): - self._opensearch_client.indices.create( - index=collection_name, - body={ - "mappings": self._get_mappings(dim=embedding_dim), - "settings": self._get_settings(), - }, - ) - else: - logger.warning(f"Collection {collection_name} already exists.") - - self._opensearch_client.indices.refresh(index=collection_name) - return - - def _search_by_vector( - self, collection_name: str, query_vector: list[float], **kwargs: Any - ) -> list[str]: - top_k = kwargs.get("top_k", 5) - query = { - "size": top_k, - "query": {"knn": {"vector": {"vector": query_vector, "k": top_k}}}, - } - response = self._opensearch_client.search(index=collection_name, body=query) - - result_list = [] - for hit in response["hits"]["hits"]: - result_list.append(hit["_source"]["page_content"]) - - return result_list - - def get_health(self): - response = self._opensearch_client.cat.health() - logger.info(response) - - def add(self, texts: list[str], **kwargs): - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "Collection name is required." 
- if not self._opensearch_client.indices.exists(index=collection_name): - self.create_collection( - embedding_dim=self._embedding_client.get_embedding_dim(), - collection_name=collection_name, - ) - - actions = [] - embeddings = self._embedding_client.embed_documents(texts) - for i in range(len(texts)): - action = { - "_op_type": "index", - "_index": collection_name, - "_source": { - "page_content": texts[i], - "vector": embeddings[i], - }, - } - actions.append(action) - - helpers.bulk( - client=self._opensearch_client, - actions=actions, - timeout=30, - max_retries=3, - ) - - self._opensearch_client.indices.refresh(index=collection_name) - return - - @override - def query(self, query: str, **kwargs: Any) -> list[str]: - collection_name = kwargs.get("collection_name") - top_k = kwargs.get("top_k", 5) - assert collection_name is not None, "Collection name is required." - if not self._opensearch_client.indices.exists(index=collection_name): - logger.warning( - f"querying {query}, but collection {collection_name} does not exist. return a empty list." 
- ) - return [] - query_vector = self._embedding_client.embed_query(query) - return self._search_by_vector( - collection_name=collection_name, query_vector=query_vector, top_k=top_k - ) - - @override - def delete(self, collection_name: str, **kwargs: Any): - """drop index""" - if not self._opensearch_client.indices.exists(index=collection_name): - raise ValueError(f"Collection {collection_name} does not exist.") - self._opensearch_client.indices.delete(index=collection_name) - - def is_empty(self, collection_name: str) -> bool: - response = self._opensearch_client.count(index=collection_name) - return response["count"] == 0 - - def collection_exists(self, collection_name: str) -> bool: - return self._opensearch_client.indices.exists(index=collection_name) - - def list_all_collection(self) -> list: - """List all index name of OpenSearch.""" - response = self._opensearch_client.indices.get_alias() - return list(response.keys()) - - def list_docs( - self, collection_name: str, offset: int = 0, limit: int = 10000 - ) -> list[dict]: - """Match all docs in one index of OpenSearch""" - if not self.collection_exists(collection_name): - logger.warning( - f"Get all docs, but collection {collection_name} does not exist. return a empty list." 
- ) - return [] - - query = {"size": limit, "from": offset, "query": {"match_all": {}}} - response = self._opensearch_client.search(index=collection_name, body=query) - return [ - { - "id": hit["_id"], - "content": hit["_source"]["page_content"], - "metadata": {}, - } - for hit in response["hits"]["hits"] - ] - - def delete_by_query(self, collection_name: str, query: str) -> Any: - """Delete docs by query in one index of OpenSearch""" - if not self.collection_exists(collection_name): - raise ValueError(f"Collection {collection_name} does not exist.") - - query_payload = {"query": {"match": {"page_content": query}}} - response = self._opensearch_client.delete_by_query( - index=collection_name, body=query_payload - ) - - self._opensearch_client.indices.refresh(index=collection_name) - return response - - def delete_by_id(self, collection_name: str, id: str): - """Delete docs by id in index of OpenSearch""" - if not self.collection_exists(collection_name): - raise ValueError(f"Collection {collection_name} does not exist.") - - response = self._opensearch_client.delete(index=collection_name, id=id) - self._opensearch_client.indices.refresh(index=collection_name) - return response diff --git a/veadk/database/vector/type.py b/veadk/database/vector/type.py deleted file mode 100644 index fe6f6302..00000000 --- a/veadk/database/vector/type.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import requests - -from veadk.config import getenv - - -class Embeddings: - def __init__( - self, - model: str = getenv("MODEL_EMBEDDING_NAME"), - api_base: str = getenv("MODEL_EMBEDDING_API_BASE"), - api_key: str = getenv("MODEL_EMBEDDING_API_KEY"), - dim: int = int(getenv("MODEL_EMBEDDING_DIM")), - ): - self.model = model - self.url = api_base - self.api_key = api_key - self.dim = dim - - self.headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}", - } - - def embed_documents(self, texts: list[str]) -> list[list[float]]: - MAX_CHARS = 4000 - data = {"model": self.model, "input": [text[:MAX_CHARS] for text in texts]} - response = requests.post(self.url, headers=self.headers, json=data) - response.raise_for_status() - result = response.json() - return [item["embedding"] for item in result["data"]] - - def embed_query(self, text: str) -> list[float]: - return self.embed_documents([text])[0] - - def get_embedding_dim(self) -> int: - return self.dim diff --git a/veadk/database/viking/__init__.py b/veadk/database/viking/__init__.py deleted file mode 100644 index 7f463206..00000000 --- a/veadk/database/viking/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/veadk/database/viking/viking_database.py b/veadk/database/viking/viking_database.py deleted file mode 100644 index 3127b0e2..00000000 --- a/veadk/database/viking/viking_database.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import json -import os -import uuid -from typing import Any, BinaryIO, Literal, TextIO - -import requests -import tos -from pydantic import BaseModel, Field -from volcengine.auth.SignerV4 import SignerV4 -from volcengine.base.Request import Request -from volcengine.Credentials import Credentials - -from veadk.config import getenv -from veadk.database.base_database import BaseDatabase -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - -# knowledge base domain -g_knowledge_base_domain = "api-knowledgebase.mlp.cn-beijing.volces.com" -# paths -create_collection_path = "/api/knowledge/collection/create" -search_knowledge_path = "/api/knowledge/collection/search_knowledge" -list_collections_path = "/api/knowledge/collection/list" -get_collections_path = "/api/knowledge/collection/info" -doc_del_path = "/api/knowledge/collection/delete" -doc_add_path = "/api/knowledge/doc/add" -doc_info_path = "/api/knowledge/doc/info" -list_point_path = "/api/knowledge/point/list" -list_docs_path = "/api/knowledge/doc/list" -delete_docs_path = "/api/knowledge/doc/delete" - - -class 
VolcengineTOSConfig(BaseModel): - endpoint: str = Field( - default_factory=lambda: getenv( - "DATABASE_TOS_ENDPOINT", "tos-cn-beijing.volces.com" - ), - description="VikingDB TOS endpoint", - ) - region: str = Field( - default_factory=lambda: getenv("DATABASE_TOS_REGION", "cn-beijing"), - description="VikingDB TOS region", - ) - bucket: str = Field( - default_factory=lambda: getenv("DATABASE_TOS_BUCKET"), - description="VikingDB TOS bucket", - ) - base_key: str = Field( - default="veadk", - description="VikingDB TOS base key", - ) - - -class VikingDatabaseConfig(BaseModel): - volcengine_ak: str = Field( - default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"), - description="VikingDB access key", - ) - volcengine_sk: str = Field( - default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"), - description="VikingDB secret key", - ) - project: str = Field( - default_factory=lambda: getenv("DATABASE_VIKING_PROJECT"), - description="VikingDB project name", - ) - region: str = Field( - default_factory=lambda: getenv("DATABASE_VIKING_REGION"), - description="VikingDB region", - ) - tos: VolcengineTOSConfig = Field( - default_factory=VolcengineTOSConfig, - description="VikingDB TOS configuration", - ) - - -def prepare_request( - method, path, config: VikingDatabaseConfig, params=None, data=None, doseq=0 -): - ak = config.volcengine_ak - sk = config.volcengine_sk - - if params: - for key in params: - if ( - type(params[key]) is int - or type(params[key]) is float - or type(params[key]) is bool - ): - params[key] = str(params[key]) - elif type(params[key]) is list: - if not doseq: - params[key] = ",".join(params[key]) - r = Request() - r.set_shema("https") - r.set_method(method) - r.set_connection_timeout(10) - r.set_socket_timeout(10) - mheaders = { - "Accept": "application/json", - "Content-Type": "application/json", - } - r.set_headers(mheaders) - if params: - r.set_query(params) - r.set_path(path) - if data is not None: - r.set_body(json.dumps(data)) - credentials = 
Credentials(ak, sk, "air", config.region) - SignerV4.sign(r, credentials) - return r - - -class VikingDatabase(BaseModel, BaseDatabase): - config: VikingDatabaseConfig = Field( - default_factory=VikingDatabaseConfig, - description="VikingDB configuration", - ) - - def _upload_to_tos( - self, - data: str | list[str] | TextIO | BinaryIO | bytes, - **kwargs: Any, - ) -> tuple[int, str]: - """ - Upload data to TOS (Tinder Object Storage). - - Args: - data: The data to be uploaded. Can be one of the following types: - - str: File path or string data - - list[str]: List of strings - - TextIO: File object (text) - - BinaryIO: File object (binary) - - bytes: Binary data - **kwargs: Additional keyword arguments. - - file_name (str): The file name (including suffix). - - Returns: - tuple: A tuple containing the status code and TOS URL. - - status_code (int): HTTP status code - - tos_url (str): The URL of the uploaded file in TOS - """ - ak = self.config.volcengine_ak - sk = self.config.volcengine_sk - - tos_bucket = self.config.tos.bucket - tos_endpoint = self.config.tos.endpoint - tos_region = self.config.tos.region - tos_key = self.config.tos.base_key - - client = tos.TosClientV2(ak, sk, tos_endpoint, tos_region, max_connections=1024) - - # Extract file_name from kwargs - this is now required and includes the extension - file_names = kwargs.get("file_name") - - if isinstance(data, str) and os.path.isfile(data): # Process file path - # Use provided file_name which includes the extension - new_key = f"{tos_key}/{file_names}" - with open(data, "rb") as f: - upload_data = f.read() - - elif ( - isinstance(data, list) - and all(isinstance(item, str) for item in data) - and all(os.path.isfile(item) for item in data) - ): - # Process list of file paths - this should be handled at a higher level - raise ValueError( - "Uploading multiple files through a list of file paths is not supported in _upload_to_tos directly. Please call this function for each file individually." 
- ) - - elif isinstance( - data, - (io.TextIOWrapper, io.BufferedReader), # file type: TextIO | BinaryIO - ): # Process file stream - # Use provided file_name which includes the extension - new_key = f"{tos_key}/{file_names}" - if isinstance(data, TextIO): - # Encode the text stream content into bytes - upload_data = data.read().encode("utf-8") - else: - # Read the content of the binary stream - upload_data = data.read() - - elif isinstance(data, str): # Process ordinary strings - # Use provided file_name which includes the extension - new_key = f"{tos_key}/{file_names}" - upload_data = data.encode("utf-8") # Encode as byte type - - elif isinstance(data, list): # Process list of strings - # Use provided file_name which includes the extension - new_key = f"{tos_key}/{file_names}" - # Join the strings in the list with newlines and encode as byte type - upload_data = "\n".join(data).encode("utf-8") - - elif isinstance(data, bytes): # Process bytes data - # Use provided file_name which includes the extension - new_key = f"{tos_key}/{file_names}" - upload_data = data - - else: - raise ValueError(f"Unsupported data type: {type(data)}") - - resp = client.put_object(tos_bucket, new_key, content=upload_data) - tos_url = f"{tos_bucket}/{new_key}" - - return resp.resp.status, tos_url - - def _add_doc(self, collection_name: str, tos_url: str, doc_id: str, **kwargs: Any): - request_params = { - "collection_name": collection_name, - "project": self.config.project, - "add_type": "tos", - "doc_id": doc_id, - "tos_path": tos_url, - } - - doc_add_req = prepare_request( - method="POST", path=doc_add_path, config=self.config, data=request_params - ) - rsp = requests.request( - method=doc_add_req.method, - url="https://{}{}".format(g_knowledge_base_domain, doc_add_req.path), - headers=doc_add_req.headers, - data=doc_add_req.body, - ) - - result = rsp.json() - if result["code"] != 0: - logger.error(f"Error in add_doc: {result['message']}") - return {"error": result["message"]} - - 
doc_add_data = result["data"] - if not doc_add_data: - raise ValueError(f"doc {doc_id} has no data.") - - return doc_id - - def add( - self, - data: str | list[str] | TextIO | BinaryIO | bytes, - collection_name: str, - **kwargs, - ): - """ - Add documents to the Viking database. - Args: - data: The data to be added. Can be one of the following types: - - str: File path or string data - - list[str]: List of file paths or list of strings - - TextIO: File object (text) - - BinaryIO: File object (binary) - - bytes: Binary data - collection_name: The name of the collection to add documents to. - **kwargs: Additional keyword arguments. - - file_name (str | list[str]): The file name or a list of file names (including suffix). - - doc_id (str): The document ID. If not provided, a UUID will be generated. - Returns: - dict or list: A dictionary containing the TOS URL and document ID, or a list of such dictionaries for multiple file uploads. - Format: { - "tos_url": "tos:///", - "doc_id": "", - } - """ - # Handle list of file paths (multiple file upload) - if ( - isinstance(data, list) - and all(isinstance(item, str) for item in data) - and all(os.path.isfile(item) for item in data) - ): - # Handle multiple file upload - file_names = kwargs.get("file_name") - if ( - not file_names - or not isinstance(file_names, list) - or len(file_names) != len(data) - ): - raise ValueError( - "For multiple file upload, file_name must be provided as a list with the same length as data" - ) - - results = [] - for i, file_path in enumerate(data): - # Create kwargs for this specific file - single_kwargs = kwargs.copy() - single_kwargs["file_name"] = file_names[i] - - # Generate or use provided doc_id for this file - doc_id = single_kwargs.get("doc_id") - if not doc_id: - doc_id = str(uuid.uuid4()) - single_kwargs["doc_id"] = doc_id - - status, tos_url = self._upload_to_tos(data=file_path, **single_kwargs) - if status != 200: - raise ValueError( - f"Error in upload_to_tos for file {file_path}: 
{status}" - ) - - doc_id = self._add_doc( - collection_name=collection_name, - tos_url=tos_url, - doc_id=doc_id, - ) - - results.append( - { - "tos_url": f"tos://{tos_url}", - "doc_id": doc_id, - } - ) - - return results - - # Handle list of strings (multiple string upload) - elif isinstance(data, list) and all(isinstance(item, str) for item in data): - # Handle multiple string upload - file_names = kwargs.get("file_name") - if ( - not file_names - or not isinstance(file_names, list) - or len(file_names) != len(data) - ): - raise ValueError( - "For multiple string upload, file_name must be provided as a list with the same length as data" - ) - - results = [] - for i, content in enumerate(data): - # Create kwargs for this specific string - single_kwargs = kwargs.copy() - single_kwargs["file_name"] = file_names[i] - - # Generate or use provided doc_id for this string - doc_id = single_kwargs.get("doc_id") - if not doc_id: - doc_id = str(uuid.uuid4()) - single_kwargs["doc_id"] = doc_id - - status, tos_url = self._upload_to_tos(data=content, **single_kwargs) - if status != 200: - raise ValueError(f"Error in upload_to_tos for string {i}: {status}") - - doc_id = self._add_doc( - collection_name=collection_name, - tos_url=tos_url, - doc_id=doc_id, - ) - - results.append( - { - "tos_url": f"tos://{tos_url}", - "doc_id": doc_id, - } - ) - - return results - - # Handle single file upload or other data types - else: - # Handle doc_id from kwargs or generate a new one - doc_id = kwargs.get("doc_id", str(uuid.uuid4())) - - status, tos_url = self._upload_to_tos(data=data, **kwargs) - if status != 200: - raise ValueError(f"Error in upload_to_tos: {status}") - doc_id = self._add_doc( - collection_name=collection_name, - tos_url=tos_url, - doc_id=doc_id, - ) - return { - "tos_url": f"tos://{tos_url}", - "doc_id": doc_id, - } - - def delete(self, **kwargs: Any): - name = kwargs.get("name") - project = kwargs.get("project", self.config.project) - request_param = {"name": name, 
"project": project} - doc_del_req = prepare_request( - method="POST", path=doc_del_path, config=self.config, data=request_param - ) - rsp = requests.request( - method=doc_del_req.method, - url="http://{}{}".format(g_knowledge_base_domain, doc_del_req.path), - headers=doc_del_req.headers, - data=doc_del_req.body, - ) - result = rsp.json() - if result["code"] != 0: - logger.error(f"Error in add_doc: {result['message']}") - return False - return True - - def query(self, query: str, **kwargs: Any) -> list[str]: - """ - Args: - query: query text - **kwargs: collection_name(required), top_k(optional, default 5) - - Returns: list of str, the search result - """ - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" - request_params = { - "query": query, - "limit": int(kwargs.get("top_k", 5)), - "name": collection_name, - "project": self.config.project, - } - search_req = prepare_request( - method="POST", - path=search_knowledge_path, - config=self.config, - data=request_params, - ) - resp = requests.request( - method=search_req.method, - url="https://{}{}".format(g_knowledge_base_domain, search_req.path), - headers=search_req.headers, - data=search_req.body, - ) - - result = resp.json() - if result["code"] != 0: - logger.error(f"Error in search_knowledge: {result['message']}") - raise ValueError(f"Error in search_knowledge: {result['message']}") - - if not result["data"]["result_list"]: - raise ValueError(f"No results found for collection {collection_name}") - - chunks = result["data"]["result_list"] - - search_result = [] - - for chunk in chunks: - search_result.append(chunk["content"]) - - return search_result - - def create_collection( - self, - collection_name: str, - description: str = "", - version: Literal[2, 4] = 4, - data_type: Literal[ - "unstructured_data", "structured_data" - ] = "unstructured_data", - chunking_strategy: Literal["custom_balance", "custom"] = "custom_balance", - chunk_length: int = 
500, - merge_small_chunks: bool = True, - ): - request_params = { - "name": collection_name, - "project": self.config.project, - "description": description, - "version": version, - "data_type": data_type, - "preprocessing": { - "chunking_strategy": chunking_strategy, - "chunk_length": chunk_length, - "merge_small_chunks": merge_small_chunks, - }, - } - - create_collection_req = prepare_request( - method="POST", - path=create_collection_path, - config=self.config, - data=request_params, - ) - resp = requests.request( - method=create_collection_req.method, - url="https://{}{}".format( - g_knowledge_base_domain, create_collection_req.path - ), - headers=create_collection_req.headers, - data=create_collection_req.body, - ) - - result = resp.json() - if result["code"] != 0: - logger.error(f"Error in create_collection: {result['message']}") - raise ValueError(f"Error in create_collection: {result['message']}") - return result - - def collection_exists(self, collection_name: str) -> bool: - request_params = { - "project": self.config.project, - } - list_collections_req = prepare_request( - method="POST", - path=list_collections_path, - config=self.config, - data=request_params, - ) - resp = requests.request( - method=list_collections_req.method, - url="https://{}{}".format( - g_knowledge_base_domain, list_collections_req.path - ), - headers=list_collections_req.headers, - data=list_collections_req.body, - ) - - result = resp.json() - if result["code"] != 0: - logger.error(f"Error in list_collections: {result['message']}") - raise ValueError(f"Error in list_collections: {result['message']}") - - collections = result["data"].get("collection_list", []) - if len(collections) == 0: - return False - - collection_list = set() - - for collection in collections: - collection_list.add(collection["collection_name"]) - # check the collection exist or not - if collection_name in collection_list: - return True - else: - return False - - def list_chunks( - self, collection_name: str, 
offset: int = 0, limit: int = -1 - ) -> list[dict]: - request_params = { - "collection_name": collection_name, - "project": self.config.project, - "offset": offset, - "limit": limit, - } - - list_doc_req = prepare_request( - method="POST", - path=list_point_path, - config=self.config, - data=request_params, - ) - resp = requests.request( - method=list_doc_req.method, - url="https://{}{}".format(g_knowledge_base_domain, list_doc_req.path), - headers=list_doc_req.headers, - data=list_doc_req.body, - ) - - result = resp.json() - if result["code"] != 0: - logger.error(f"Error in list_docs: {result['message']}") - raise ValueError(f"Error in list_docs: {result['message']}") - - if not result["data"].get("point_list", []): - return [] - - data = [ - { - "id": res["point_id"], - "content": res["content"], - "metadata": res["doc_info"], - } - for res in result["data"]["point_list"] - ] - return data - - def list_docs( - self, collection_name: str, offset: int = 0, limit: int = -1 - ) -> list[dict]: - request_params = { - "collection_name": collection_name, - "project": self.config.project, - "offset": offset, - "limit": limit, - } - - list_doc_req = prepare_request( - method="POST", - path=list_docs_path, - config=self.config, - data=request_params, - ) - resp = requests.request( - method=list_doc_req.method, - url="https://{}{}".format(g_knowledge_base_domain, list_doc_req.path), - headers=list_doc_req.headers, - data=list_doc_req.body, - ) - - result = resp.json() - if result["code"] != 0: - logger.error(f"Error in list_docs: {result['message']}") - raise ValueError(f"Error in list_docs: {result['message']}") - - if not result["data"].get("doc_list", []): - return [] - return result["data"]["doc_list"] - - def delete_by_id(self, collection_name: str, id: str) -> bool: - request_params = { - "collection_name": collection_name, - "project": self.config.project, - "doc_id": id, - } - - delete_by_id_req = prepare_request( - method="POST", - path=delete_docs_path, - 
config=self.config, - data=request_params, - ) - resp = requests.request( - method=delete_by_id_req.method, - url="https://{}{}".format(g_knowledge_base_domain, delete_by_id_req.path), - headers=delete_by_id_req.headers, - data=delete_by_id_req.body, - ) - - result = resp.json() - if result["code"] != 0: - return False - return True diff --git a/veadk/database/viking/viking_memory_db.py b/veadk/database/viking/viking_memory_db.py deleted file mode 100644 index e591cb52..00000000 --- a/veadk/database/viking/viking_memory_db.py +++ /dev/null @@ -1,525 +0,0 @@ -# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import json -import random -import string -import threading -import time -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field -from volcengine.ApiInfo import ApiInfo -from volcengine.auth.SignerV4 import SignerV4 -from volcengine.base.Service import Service -from volcengine.Credentials import Credentials -from volcengine.ServiceInfo import ServiceInfo - -from veadk.config import getenv -from veadk.database.base_database import BaseDatabase -from veadk.utils.logger import get_logger - -logger = get_logger(__name__) - - -class VikingMemConfig(BaseModel): - volcengine_ak: str = Field( - default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY"), - description="VikingDB access key", - ) - volcengine_sk: str = Field( - default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY"), - description="VikingDB secret key", - ) - project: str = Field( - default_factory=lambda: getenv("DATABASE_VIKING_PROJECT"), - description="VikingDB project name", - ) - region: str = Field( - default_factory=lambda: getenv("DATABASE_VIKING_REGION"), - description="VikingDB region", - ) - - -# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= -class VikingMemoryException(Exception): - def __init__(self, code, request_id, message=None): - self.code = code - self.request_id = request_id - self.message = "{}, code:{},request_id:{}".format( - message, self.code, self.request_id - ) - - def __str__(self): - return self.message - - -class VikingMemoryService(Service): - _instance_lock = threading.Lock() - - def __new__(cls, *args, **kwargs): - if not hasattr(VikingMemoryService, "_instance"): - with VikingMemoryService._instance_lock: - if not hasattr(VikingMemoryService, "_instance"): - VikingMemoryService._instance = object.__new__(cls) - return VikingMemoryService._instance - - def __init__( - self, - 
host="api-knowledgebase.mlp.cn-beijing.volces.com", - region="cn-beijing", - ak="", - sk="", - sts_token="", - scheme="http", - connection_timeout=30, - socket_timeout=30, - ): - self.service_info = VikingMemoryService.get_service_info( - host, region, scheme, connection_timeout, socket_timeout - ) - self.api_info = VikingMemoryService.get_api_info() - super(VikingMemoryService, self).__init__(self.service_info, self.api_info) - if ak: - self.set_ak(ak) - if sk: - self.set_sk(sk) - if sts_token: - self.set_session_token(session_token=sts_token) - try: - self.get_body("Ping", {}, json.dumps({})) - except Exception as e: - raise VikingMemoryException( - 1000028, "missed", "host or region is incorrect: {}".format(str(e)) - ) from None - - def setHeader(self, header): - api_info = VikingMemoryService.get_api_info() - for key in api_info: - for item in header: - api_info[key].header[item] = header[item] - self.api_info = api_info - - @staticmethod - def get_service_info(host, region, scheme, connection_timeout, socket_timeout): - service_info = ServiceInfo( - host, - {"Host": host}, - Credentials("", "", "air", region), - connection_timeout, - socket_timeout, - scheme=scheme, - ) - return service_info - - @staticmethod - def get_api_info(): - api_info = { - "CreateCollection": ApiInfo( - "POST", - "/api/memory/collection/create", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - "GetCollection": ApiInfo( - "POST", - "/api/memory/collection/info", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - "DropCollection": ApiInfo( - "POST", - "/api/memory/collection/delete", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - "UpdateCollection": ApiInfo( - "POST", - "/api/memory/collection/update", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - "SearchMemory": ApiInfo( - "POST", - "/api/memory/search", - {}, - {}, - 
{"Accept": "application/json", "Content-Type": "application/json"}, - ), - "AddMessages": ApiInfo( - "POST", - "/api/memory/messages/add", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - "Ping": ApiInfo( - "GET", - "/api/memory/ping", - {}, - {}, - {"Accept": "application/json", "Content-Type": "application/json"}, - ), - } - return api_info - - def get_body(self, api, params, body): - if api not in self.api_info: - raise Exception("no such api") - api_info = self.api_info[api] - r = self.prepare_request(api_info, params) - r.headers["Content-Type"] = "application/json" - r.headers["Traffic-Source"] = "SDK" - r.body = body - - SignerV4.sign(r, self.service_info.credentials) - - url = r.build() - resp = self.session.get( - url, - headers=r.headers, - data=r.body, - timeout=( - self.service_info.connection_timeout, - self.service_info.socket_timeout, - ), - ) - if resp.status_code == 200: - return json.dumps(resp.json()) - else: - raise Exception(resp.text.encode("utf-8")) - - def get_body_exception(self, api, params, body): - try: - res = self.get_body(api, params, body) - except Exception as e: - try: - res_json = json.loads(e.args[0].decode("utf-8")) - except Exception: - raise VikingMemoryException( - 1000028, "missed", "json load res error, res:{}".format(str(e)) - ) from None - code = res_json.get("code", 1000028) - request_id = res_json.get("request_id", 1000028) - message = res_json.get("message", None) - - raise VikingMemoryException(code, request_id, message) - - if res == "": - raise VikingMemoryException( - 1000028, - "missed", - "empty response due to unknown error, please contact customer service", - ) from None - return res - - def get_exception(self, api, params): - try: - res = self.get(api, params) - except Exception as e: - try: - res_json = json.loads(e.args[0].decode("utf-8")) - except Exception: - raise VikingMemoryException( - 1000028, "missed", "json load res error, res:{}".format(str(e)) - ) from None 
- code = res_json.get("code", 1000028) - request_id = res_json.get("request_id", 1000028) - message = res_json.get("message", None) - raise VikingMemoryException(code, request_id, message) - if res == "": - raise VikingMemoryException( - 1000028, - "missed", - "empty response due to unknown error, please contact customer service", - ) from None - return res - - def create_collection( - self, - collection_name, - description="", - custom_event_type_schemas=None, - custom_entity_type_schemas=None, - builtin_event_types=None, - builtin_entity_types=None, - ): - if custom_event_type_schemas is None: - custom_event_type_schemas = [] - if custom_entity_type_schemas is None: - custom_entity_type_schemas = [] - if builtin_entity_types is None: - builtin_entity_types = ["sys_profile_v1"] - if builtin_event_types is None: - builtin_event_types = ["sys_event_v1", "sys_profile_collect_v1"] - params = { - "CollectionName": collection_name, - "Description": description, - "CustomEventTypeSchemas": custom_event_type_schemas, - "CustomEntityTypeSchemas": custom_entity_type_schemas, - "BuiltinEventTypes": builtin_event_types, - "BuiltinEntityTypes": builtin_entity_types, - } - res = self.json("CreateCollection", {}, json.dumps(params)) - return json.loads(res) - - def get_collection(self, collection_name): - params = {"CollectionName": collection_name} - res = self.json("GetCollection", {}, json.dumps(params)) - return json.loads(res) - - def drop_collection(self, collection_name): - params = {"CollectionName": collection_name} - res = self.json("DropCollection", {}, json.dumps(params)) - return json.loads(res) - - def update_collection( - self, - collection_name, - custom_event_type_schemas=[], - custom_entity_type_schemas=[], - builtin_event_types=[], - builtin_entity_types=[], - ): - params = { - "CollectionName": collection_name, - "CustomEventTypeSchemas": custom_event_type_schemas, - "CustomEntityTypeSchemas": custom_entity_type_schemas, - "BuiltinEventTypes": 
builtin_event_types, - "BuiltinEntityTypes": builtin_entity_types, - } - res = self.json("UpdateCollection", {}, json.dumps(params)) - return json.loads(res) - - def search_memory(self, collection_name, query, filter, limit=10): - params = { - "collection_name": collection_name, - "limit": limit, - "filter": filter, - } - if query: - params["query"] = query - res = self.json("SearchMemory", {}, json.dumps(params)) - return json.loads(res) - - def add_messages( - self, collection_name, session_id, messages, metadata, entities=None - ): - params = { - "collection_name": collection_name, - "session_id": session_id, - "messages": messages, - "metadata": metadata, - } - if entities is not None: - params["entities"] = entities - res = self.json("AddMessages", {}, json.dumps(params)) - return json.loads(res) - - -def memory2event(role, text): - return json.dumps({"role": role, "parts": [{"text": text}]}, ensure_ascii=False) - - -def generate_random_letters(length): - # 生成包含所有大小写字母的字符集 - letters = string.ascii_letters - return "".join(random.choice(letters) for _ in range(length)) - - -def format_milliseconds(timestamp_ms): - """ - Convert the millisecond - level timestamp to a string in the 'YYYYMMDD HH:MM:SS' format. 
- - Parameters: - - timestamp_ms: Millisecond - level timestamp (integer or float) - - Returns: - - Formatted time string - - """ - # Convert milliseconds to seconds - timestamp_seconds = timestamp_ms / 1000 - - # Convert to a datetime object - dt = datetime.fromtimestamp(timestamp_seconds) - - # Output in the specified format - return dt.strftime("%Y%m%d %H:%M:%S") - - -# ======= adapted from https://github.com/volcengine/mcp-server/blob/main/server/mcp_server_vikingdb_memory/src/mcp_server_vikingdb_memory/common/memory_client.py ======= - - -class VikingMemoryDatabase(BaseModel, BaseDatabase): - config: VikingMemConfig = Field( - default_factory=VikingMemConfig, - description="VikingDB configuration", - ) - - def model_post_init(self, context: Any, /) -> None: - self._vm = VikingMemoryService( - ak=self.config.volcengine_ak, sk=self.config.volcengine_sk - ) - - def add_memories( - self, - collection_name: str, - text: str, - user_id: str, - ) -> str: - # Add Messages - session_id = generate_random_letters(10) - # proces - message = json.loads(text) - content = message["parts"][0]["text"] - role = ( - "user" if message["role"] == "user" else "assistant" - ) # field 'role': viking memory only allow 'assistant','system','user', - messages = [{"role": role, "content": content}] - metadata = { - "default_user_id": user_id, - "default_assistant_id": "assistant", - "time": int(time.time() * 1000), - } - - rsp = self._vm.add_messages( - collection_name=collection_name, - session_id=session_id, - messages=messages, - metadata=metadata, - ) - return str(rsp) - - def add(self, data: list[str], **kwargs): - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" - user_id = kwargs.get("user_id") - assert user_id is not None, "user_id is required" - try: - self._vm.get_collection(collection_name=collection_name) - except Exception: - self._vm.create_collection( - collection_name=collection_name, - ) - - for text in 
data: - self.add_memories( - collection_name=collection_name, text=text, user_id=user_id - ) - - return "success" - - def search_memory( - self, collection_name: str, query: str, user_id: str, top_k: int = 5 - ) -> list[str]: - """ - Search for stored memories. This method is called whenever a user asks any question. - If a search yields no results, do not repeat the search within the same conversation. - The retrieved memories are used to supplement your understanding of the user and to reply to the user's question. - Args: - collection_name: viking db collection_name - query: Any question asked by the user. - Returns: - The user's memories related to the query. - """ - - result = [] - try: - # ------- get profiles ----------- - try: - limit = 1 - filter = { - "user_id": user_id, - "memory_type": ["sys_profile_v1"], - } - rsp = self._vm.search_memory( - collection_name=collection_name, - query="sys_profile_v1", - filter=filter, - limit=limit, - ) - profiles = [ - item.get("memory_info").get("user_profile") - for item in rsp.get("data").get("result_list") - ] - if len(profiles) > 0: - result.append(memory2event("user", profiles[0])) - except Exception as e: - result.append( - memory2event("user", f"SearchMemory: Get Profiles Error: {str(e)}") - ) - - # -------- get memory ----------- - try: - # Search Memory - limit = top_k - filter = { - "user_id": user_id, - "memory_type": ["sys_event_v1"], - } - rsp = self._vm.search_memory( - collection_name=collection_name, - query=query, - filter=filter, - limit=limit, - ) - result_list = rsp.get("data").get("result_list") - - content = [ - memory2event("user", item.get("memory_info").get("summary")) - for item in result_list - ] - - result.extend(content) - - except Exception as e: - result.append( - memory2event("user", f"SearchMemory: Get Memory Error: {str(e)}") - ) - - return result - - except Exception as e: - logger.error(f"Error in get_doc: {str(e)}") - result.append( - memory2event("user", f"SearchMemory: Get Memory 
Error: {str(e)}") - ) - return result - - def query(self, query: str, **kwargs: Any) -> list[str]: - """ - Args: - query: query text - **kwargs: collection_name(required), top_k(optional, default 5) - - Returns: list of str, the search result - """ - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" - user_id = kwargs.get("user_id") - assert user_id is not None, "user_id is required" - top_k = kwargs.get("top_k", 5) - resp = self.search_memory(collection_name, query, user_id=user_id, top_k=top_k) - return resp - - def delete(self, **kwargs: Any): - collection_name = kwargs.get("collection_name") - assert collection_name is not None, "collection_name is required" - self._vm.drop_collection(collection_name) diff --git a/veadk/database/kv/__init__.py b/veadk/knowledgebase/backends/__init__.py similarity index 100% rename from veadk/database/kv/__init__.py rename to veadk/knowledgebase/backends/__init__.py diff --git a/veadk/knowledgebase/backends/base_backend.py b/veadk/knowledgebase/backends/base_backend.py new file mode 100644 index 00000000..939f41f4 --- /dev/null +++ b/veadk/knowledgebase/backends/base_backend.py @@ -0,0 +1,59 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod + +from pydantic import BaseModel + + +class BaseKnowledgebaseBackend(ABC, BaseModel): + index: str + """Index or collection name of the vector storage.""" + + @abstractmethod + def precheck_index_naming(self) -> None: + """Check the index name is valid or not. + + If index naming is not valid, raise an exception. + """ + + @abstractmethod + def add_from_directory(self, directory: str, **kwargs) -> bool: + """Add knowledge from file path to knowledgebase""" + + @abstractmethod + def add_from_files(self, files: list[str], **kwargs) -> bool: + """Add knowledge (e.g, documents, strings, ...) to knowledgebase""" + + @abstractmethod + def add_from_text(self, text: str | list[str], **kwargs) -> bool: + """Add knowledge from text to knowledgebase""" + + @abstractmethod + def search(self, **kwargs) -> list: + """Search knowledge from knowledgebase""" + + # Optional methods for future use: + # - `delete`: Delete collection or documents + # - `list_docs`: List original documents + # - `list_chunks`: List embedded document chunks + + # def delete(self, **kwargs) -> bool: + # """Delete knowledge from knowledgebase""" + + # def list_docs(self, **kwargs) -> None: + # """List original documents in knowledgebase""" + + # def list_chunks(self, **kwargs) -> None: + # """List embeded document chunks in knowledgebase""" diff --git a/veadk/knowledgebase/backends/in_memory_backend.py b/veadk/knowledgebase/backends/in_memory_backend.py new file mode 100644 index 00000000..ed8088a4 --- /dev/null +++ b/veadk/knowledgebase/backends/in_memory_backend.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llama_index.core import Document, SimpleDirectoryReader, VectorStoreIndex +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig +from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend +from veadk.knowledgebase.backends.utils import get_llama_index_splitter + + +class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend): + embedding_config: NormalEmbeddingModelConfig | EmbeddingModelConfig = Field( + default_factory=EmbeddingModelConfig + ) + """Embedding model configs""" + + def model_post_init(self, __context: Any) -> None: + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + self._vector_index = VectorStoreIndex([], embed_model=self._embed_model) + + @override + def precheck_index_naming(self) -> None: + # Checking is not needed + pass + + @override + def add_from_directory(self, directory: str) -> bool: + documents = SimpleDirectoryReader(input_dir=directory).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def add_from_files(self, files: list[str]) -> bool: + documents = SimpleDirectoryReader(input_files=files).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + 
return True + + @override + def add_from_text(self, text: str | list[str]) -> bool: + if isinstance(text, str): + documents = [Document(text=text)] + else: + documents = [Document(text=t) for t in text] + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search(self, query: str, top_k: int = 5) -> list[str]: + _retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/knowledgebase/backends/opensearch_backend.py b/veadk/knowledgebase/backends/opensearch_backend.py new file mode 100644 index 00000000..0b799c87 --- /dev/null +++ b/veadk/knowledgebase/backends/opensearch_backend.py @@ -0,0 +1,136 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re + +from llama_index.core import ( + Document, + SimpleDirectoryReader, + StorageContext, + VectorStoreIndex, +) +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import OpensearchConfig +from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig +from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend +from veadk.knowledgebase.backends.utils import get_llama_index_splitter + +try: + from llama_index.vector_stores.opensearch import ( + OpensearchVectorClient, + OpensearchVectorStore, + ) +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" + ) + + +class OpensearchKnowledgeBackend(BaseKnowledgebaseBackend): + opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig) + """Opensearch client configs""" + + embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field( + default_factory=EmbeddingModelConfig + ) + """Embedding model configs""" + + def model_post_init(self, __context: Any) -> None: + self.precheck_index_naming() + self._opensearch_client = OpensearchVectorClient( + endpoint=self.opensearch_config.host, + port=self.opensearch_config.port, + http_auth=( + self.opensearch_config.username, + self.opensearch_config.password, + ), + use_ssl=True, + verify_certs=False, + dim=self.embedding_config.dim, + index=self.index, # collection name + ) + + self._vector_store = OpensearchVectorStore(client=self._opensearch_client) + + self._storage_context = StorageContext.from_defaults( + vector_store=self._vector_store + ) + + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + + 
self._vector_index = VectorStoreIndex.from_documents( + documents=[], + storage_context=self._storage_context, + embed_model=self._embed_model, + ) + + @override + def precheck_index_naming(self) -> None: + if not ( + isinstance(self.index, str) + and not self.index.startswith(("_", "-")) + and self.index.islower() + and re.match(r"^[a-z0-9_\-.]+$", self.index) + ): + raise ValueError( + "The index name does not conform to the naming rules of OpenSearch" + ) + + @override + def add_from_directory(self, directory: str) -> bool: + documents = SimpleDirectoryReader(input_dir=directory).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def add_from_files(self, files: list[str]) -> bool: + documents = SimpleDirectoryReader(input_files=files).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def add_from_text(self, text: str | list[str]) -> bool: + if isinstance(text, str): + documents = [Document(text=text)] + else: + documents = [Document(text=t) for t in text] + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search(self, query: str, top_k: int = 5) -> list[str]: + _retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/knowledgebase/backends/redis_backend.py b/veadk/knowledgebase/backends/redis_backend.py new file mode 100644 index 00000000..508be5cf --- /dev/null +++ 
b/veadk/knowledgebase/backends/redis_backend.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llama_index.core import ( + Document, + SimpleDirectoryReader, + StorageContext, + VectorStoreIndex, +) +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import RedisConfig +from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig +from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend +from veadk.knowledgebase.backends.utils import get_llama_index_splitter + +try: + from llama_index.vector_stores.redis import RedisVectorStore + from llama_index.vector_stores.redis.schema import ( + RedisIndexInfo, + RedisVectorStoreSchema, + ) + from redis import Redis + from redisvl.schema.fields import BaseVectorFieldAttributes +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" + ) + + +class RedisKnowledgeBackend(BaseKnowledgebaseBackend): + redis_config: RedisConfig = Field(default_factory=RedisConfig) + """Redis client configs""" + + embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field( + default_factory=EmbeddingModelConfig + 
) + """Embedding model configs""" + + def model_post_init(self, __context: Any) -> None: + # We will use `from_url` to init Redis client once the + # AK/SK -> STS token is ready. + # self._redis_client = Redis.from_url(url=...) + + self._redis_client = Redis( + host=self.redis_config.host, + port=self.redis_config.port, + db=self.redis_config.db, + password=self.redis_config.password, + ) + + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + + self._schema = RedisVectorStoreSchema( + index=RedisIndexInfo(name=self.index), + ) + if "vector" in self._schema.fields: + vector_field = self._schema.fields["vector"] + if ( + vector_field + and vector_field.attrs + and isinstance(vector_field.attrs, BaseVectorFieldAttributes) + ): + vector_field.attrs.dims = self.embedding_config.dim + + self._vector_store = RedisVectorStore( + schema=self._schema, + redis_client=self._redis_client, + overwrite=True, + collection_name=self.index, + ) + + self._storage_context = StorageContext.from_defaults( + vector_store=self._vector_store + ) + + self._vector_index = VectorStoreIndex.from_documents( + documents=[], + storage_context=self._storage_context, + embed_model=self._embed_model, + ) + + @override + def precheck_index_naming(self) -> None: + # Checking is not needed + pass + + @override + def add_from_directory(self, directory: str) -> bool: + documents = SimpleDirectoryReader(input_dir=directory).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def add_from_files(self, files: list[str]) -> bool: + documents = SimpleDirectoryReader(input_files=files).load_data() + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def add_from_text(self, text: str | list[str]) -> bool: + if isinstance(text, str): + documents = 
[Document(text=text)] + else: + documents = [Document(text=t) for t in text] + nodes = self._split_documents(documents) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search(self, query: str, top_k: int = 5) -> list[str]: + _retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/knowledgebase/backends/utils.py b/veadk/knowledgebase/backends/utils.py new file mode 100644 index 00000000..d4ff0903 --- /dev/null +++ b/veadk/knowledgebase/backends/utils.py @@ -0,0 +1,91 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from pathlib import Path +from typing import Literal + +from llama_index.core.node_parser import ( + CodeSplitter, + HTMLNodeParser, + MarkdownNodeParser, + SentenceSplitter, +) +from volcengine.auth.SignerV4 import SignerV4 +from volcengine.base.Request import Request +from volcengine.Credentials import Credentials + + +def get_llama_index_splitter( + file_path: str, +) -> CodeSplitter | MarkdownNodeParser | HTMLNodeParser | SentenceSplitter: + suffix = Path(file_path).suffix.lower() + + if suffix in [".py", ".js", ".java", ".cpp"]: + return CodeSplitter(language=suffix.strip(".")) + elif suffix in [".md"]: + return MarkdownNodeParser() + elif suffix in [".html", ".htm"]: + return HTMLNodeParser() + else: + return SentenceSplitter(chunk_size=512, chunk_overlap=50) + + +def build_vikingdb_knowledgebase_request( + path: str, + volcengine_access_key: str, + volcengine_secret_key: str, + method: Literal["GET", "POST", "PUT", "DELETE"] = "POST", + region: str = "cn-beijing", + params=None, + data=None, + doseq=0, +) -> Request: + if params: + for key in params: + if ( + type(params[key]) is int + or type(params[key]) is float + or type(params[key]) is bool + ): + params[key] = str(params[key]) + elif type(params[key]) is list: + if not doseq: + params[key] = ",".join(params[key]) + + r = Request() + r.set_shema("https") + r.set_method(method) + r.set_connection_timeout(10) + r.set_socket_timeout(10) + + mheaders = { + "Accept": "application/json", + "Content-Type": "application/json", + } + r.set_headers(mheaders) + + if params: + r.set_query(params) + + r.set_path(path) + + if data is not None: + r.set_body(json.dumps(data)) + + credentials = Credentials( + volcengine_access_key, volcengine_secret_key, "air", region + ) + SignerV4.sign(r, credentials) + return r diff --git a/veadk/knowledgebase/backends/vikingdb_knowledge_backend.py b/veadk/knowledgebase/backends/vikingdb_knowledge_backend.py new file mode 100644 index 00000000..af5da720 --- /dev/null 
+++ b/veadk/knowledgebase/backends/vikingdb_knowledge_backend.py @@ -0,0 +1,377 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import re +from pathlib import Path +from typing import Any, Literal + +import requests +from pydantic import Field +from typing_extensions import override + +import veadk.config # noqa E401 +from veadk.config import getenv +from veadk.consts import DEFAULT_TOS_BUCKET_NAME +from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend +from veadk.knowledgebase.backends.utils import build_vikingdb_knowledgebase_request +from veadk.utils.logger import get_logger +from veadk.utils.misc import formatted_timestamp + +try: + from veadk.integrations.ve_tos.ve_tos import VeTOS +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" + ) + +logger = get_logger(__name__) + + +def _read_file_to_bytes(file_path: str) -> tuple[bytes, str]: + """Read file content to bytes, and file name""" + with open(file_path, "rb") as f: + file_content = f.read() + file_name = file_path.split("/")[-1] + return file_content, file_name + + +def _extract_tos_attributes(**kwargs) -> tuple[str, str]: + """Extract TOS attributes from kwargs""" + tos_bucket_name = kwargs.get("tos_bucket_name", DEFAULT_TOS_BUCKET_NAME) + tos_bucket_path = kwargs.get("tos_bucket_path", "knowledgebase") + return 
tos_bucket_name, tos_bucket_path + + +def get_files_in_directory(directory: str): + dir_path = Path(directory) + if not dir_path.is_dir(): + raise ValueError(f"The directory does not exist: {directory}") + file_paths = [str(file) for file in dir_path.iterdir() if file.is_file()] + return file_paths + + +def _upload_bytes_to_tos(content: bytes, tos_bucket_name: str, object_key: str) -> str: + ve_tos = VeTOS(bucket_name=tos_bucket_name) + asyncio.run(ve_tos.upload(object_key=object_key, data=content)) + return f"{ve_tos.bucket_name}/{object_key}" + + +class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend): + volcengine_access_key: str = Field( + default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY") + ) + + volcengine_secret_key: str = Field( + default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY") + ) + + volcengine_project: str = "default" + """VikingDB knowledgebase project in Volcengine console platform. Default by `default`""" + + region: str = "cn-beijing" + """VikingDB knowledgebase region""" + + def precheck_index_naming(self): + if not ( + isinstance(self.index, str) + and 0 < len(self.index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", self.index) + ): + raise ValueError( + "The index name does not conform to the rules: " + "it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." + ) + + def model_post_init(self, __context: Any) -> None: + self.precheck_index_naming() + # check whether collection exist, if not, create it + if not self.collection_status()["existed"]: + logger.warning( + f"VikingDB knowledgebase collection {self.index} does not exist, please create it first..." 
+ ) + + @override + def add_from_directory(self, directory: str, **kwargs) -> bool: + """ + Args: + directory: str, the directory to add to knowledgebase + **kwargs: + - tos_bucket_name: str, the bucket name of TOS + - tos_bucket_path: str, the path of TOS bucket + """ + tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs) + files = get_files_in_directory(directory=directory) + for _file in files: + content, file_name = _read_file_to_bytes(_file) + tos_url = _upload_bytes_to_tos( + content, + tos_bucket_name=tos_bucket_name, + object_key=f"{tos_bucket_path}/{file_name}", + ) + self._add_doc(tos_url=tos_url) + return True + + @override + def add_from_files(self, files: list[str], **kwargs) -> bool: + """ + Args: + files: list[str], the files to add to knowledgebase + **kwargs: + - tos_bucket_name: str, the bucket name of TOS + - tos_bucket_path: str, the path of TOS bucket + """ + tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs) + for _file in files: + content, file_name = _read_file_to_bytes(_file) + tos_url = _upload_bytes_to_tos( + content, + tos_bucket_name=tos_bucket_name, + object_key=f"{tos_bucket_path}/{file_name}", + ) + self._add_doc(tos_url=tos_url) + return True + + @override + def add_from_text(self, text: str | list[str], **kwargs) -> bool: + """ + Args: + text: str or list[str], the text to add to knowledgebase + **kwargs: + - tos_bucket_name: str, the bucket name of TOS + - tos_bucket_path: str, the path of TOS bucket + """ + tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs) + if isinstance(text, list): + object_keys = kwargs.get( + "tos_object_keys", + [ + f"{tos_bucket_path}/{formatted_timestamp()}-{i}.txt" + for i, _ in enumerate(text) + ], + ) + for _text, _object_key in zip(text, object_keys): + _content = _text.encode("utf-8") + tos_url = _upload_bytes_to_tos(_content, tos_bucket_name, _object_key) + self._add_doc(tos_url=tos_url) + return True + elif isinstance(text, str): + content = 
text.encode("utf-8") + object_key = kwargs.get( + "object_key", f"veadk/knowledgebase/{formatted_timestamp()}.txt" + ) + tos_url = _upload_bytes_to_tos(content, tos_bucket_name, object_key) + self._add_doc(tos_url=tos_url) + else: + raise ValueError("text must be str or list[str]") + return True + + @override + def search(self, query: str, top_k: int = 5) -> list: + return self._search_knowledge(query=query, top_k=top_k) + + def delete_collection(self) -> bool: + DELETE_COLLECTION_PATH = "/api/knowledge/collection/delete" + + response = self._do_request( + body={ + "name": self.index, + "project": self.volcengine_project, + }, + path=DELETE_COLLECTION_PATH, + method="POST", + ) + + if response.get("code") != 0: + logger.error(f"Error during collection deletion: {response}") + return False + return True + + def delete_doc_by_id(self, id: str) -> bool: + DELETE_DOC_PATH = "/api/knowledge/doc/delete" + response = self._do_request( + body={ + "collection_name": self.index, + "project": self.volcengine_project, + "doc_id": id, + }, + path=DELETE_DOC_PATH, + method="POST", + ) + + if response.get("code") != 0: + return False + return True + + def list_docs(self, offset: int = 0, limit: int = -1): + """List documents in collection. + + Args: + offset (int): The offset of the first document to return. + limit (int): The maximum number of documents to return. -1 means return all documents but max is 100. + """ + LIST_DOCS_PATH = "/api/knowledge/doc/list" + response = self._do_request( + body={ + "collection_name": self.index, + "project": self.volcengine_project, + "offset": offset, + "limit": limit, + }, + path=LIST_DOCS_PATH, + method="POST", + ) + if response.get("code") != 0: + raise ValueError(f"Error during list documents: {response.get('code')}") + if not response["data"].get("doc_list", []): + return [] + return response["data"]["doc_list"] + + def list_chunks(self, offset: int = 0, limit: int = -1): + """List chunks in collection. 
+ + Args: + offset (int): The offset of the first chunk to return. + limit (int): The maximum number of chunks to return. -1 means return all chunks but max is 100. + """ + LIST_CHUNKS_PATH = "/api/knowledge/point/list" + response = self._do_request( + body={ + "collection_name": self.index, + "project": self.volcengine_project, + "offset": offset, + "limit": limit, + }, + path=LIST_CHUNKS_PATH, + method="POST", + ) + + if response.get("code") != 0: + raise ValueError(f"Error during list chunks: {response}") + + if not response["data"].get("point_list", []): + return [] + data = [ + { + "id": res["point_id"], + "content": res["content"], + "metadata": res["doc_info"], + } + for res in response["data"]["point_list"] + ] + return data + + def collection_status(self): + COLLECTION_INFO_PATH = "/api/knowledge/collection/info" + response = self._do_request( + body={ + "name": self.index, + "project": self.volcengine_project, + }, + path=COLLECTION_INFO_PATH, + method="POST", + ) + if response["code"] == 0: + status = response["data"]["pipeline_list"][0]["index_list"][0]["status"] + return { + "existed": True, + "status": status, + } + elif response["code"] == 1000005: + return { + "existed": False, + "status": None, + } + else: + raise ValueError(f"Error during collection status: {response}") + + def create_collection(self) -> None: + CREATE_COLLECTION_PATH = "/api/knowledge/collection/create" + + response = self._do_request( + body={ + "name": self.index, + "project": "default", + "description": "Created by Volcengine Agent Development Kit (VeADK).", + }, + path=CREATE_COLLECTION_PATH, + method="POST", + ) + + if response.get("code") != 0: + raise ValueError( + f"Error during collection creation: {response.get('code')}" + ) + + def _add_doc(self, tos_url: str) -> Any: + ADD_DOC_PATH = "/api/knowledge/doc/add" + + response = self._do_request( + body={ + "collection_name": self.index, + "project": "default", + "add_type": "tos", + "tos_path": tos_url, + }, + 
path=ADD_DOC_PATH, + method="POST", + ) + return response + + def _search_knowledge(self, query: str, top_k: int = 5) -> list[str]: + SEARCH_KNOWLEDGE_PATH = "/api/knowledge/collection/search_knowledge" + + response = self._do_request( + body={ + "name": self.index, + "query": query, + "limit": top_k, + }, + path=SEARCH_KNOWLEDGE_PATH, + method="POST", + ) + + if response.get("code") != 0: + raise ValueError( + f"Error during knowledge search: {response.get('code')}, message: {response.get('message')}" + ) + + search_result_list = response.get("data", {}).get("result_list", []) + + return [ + search_result.get("content", "") for search_result in search_result_list + ] + + def _do_request( + self, + body: dict, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE"] = "POST", + ) -> dict: + VIKINGDB_KNOWLEDGEBASE_BASE_URL = "api-knowledgebase.mlp.cn-beijing.volces.com" + + request = build_vikingdb_knowledgebase_request( + path=path, + volcengine_access_key=self.volcengine_access_key, + volcengine_secret_key=self.volcengine_secret_key, + method=method, + data=body, + ) + response = requests.request( + method=method, + url=f"https://{VIKINGDB_KNOWLEDGEBASE_BASE_URL}{path}", + headers=request.headers, + data=request.body, + ) + return response.json() diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index 2fa9e833..48f61d1a 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -11,168 +11,134 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import io -import os.path -from typing import Any, BinaryIO, Literal, TextIO -from pydantic import BaseModel +from typing import Any, Callable, Literal -from veadk.database.database_adapter import get_knowledgebase_database_adapter -from veadk.database.database_factory import DatabaseFactory -from veadk.utils.misc import formatted_timestamp +from pydantic import BaseModel, Field +from typing_extensions import Union + +from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend from veadk.utils.logger import get_logger logger = get_logger(__name__) +def _get_backend_cls(backend: str) -> type[BaseKnowledgebaseBackend]: + match backend: + case "local": + from veadk.knowledgebase.backends.in_memory_backend import ( + InMemoryKnowledgeBackend, + ) + + return InMemoryKnowledgeBackend + case "opensearch": + from veadk.knowledgebase.backends.opensearch_backend import ( + OpensearchKnowledgeBackend, + ) + + return OpensearchKnowledgeBackend + case "viking": + from veadk.knowledgebase.backends.vikingdb_knowledge_backend import ( + VikingDBKnowledgeBackend, + ) + + return VikingDBKnowledgeBackend + case "redis": + from veadk.knowledgebase.backends.redis_backend import ( + RedisKnowledgeBackend, + ) + + return RedisKnowledgeBackend + + raise ValueError(f"Unsupported knowledgebase backend: {backend}") + + def build_knowledgebase_index(app_name: str): return f"veadk_kb_{app_name}" class KnowledgeBase(BaseModel): - backend: Literal["local", "opensearch", "viking", "redis", "mysql"] = "local" + backend: Union[ + Literal["local", "opensearch", "viking", "redis"], BaseKnowledgebaseBackend + ] = "local" + """Knowledgebase backend type. Supported backends are: + - local: In-memory knowledgebase, data will be lost when the program exits. + - opensearch: OpenSearch knowledgebase, requires an OpenSearch cluster. + - viking: Volcengine VikingDB knowledgebase, requires VikingDB service. + - redis: Redis knowledgebase, requires Redis with vector search capability. 
+ Default is `local`.""" + + backend_config: dict = Field(default_factory=dict) + """Configuration for the backend""" + top_k: int = 10 - db_config: Any | None = None + """Number of top similar documents to retrieve during search. + + Default is 10.""" - def model_post_init(self, __context: Any) -> None: - logger.info( - f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}" - ) + app_name: str = "" - self._db_client = DatabaseFactory.create( - backend=self.backend, config=self.db_config - ) - self._adapter = get_knowledgebase_database_adapter(self._db_client) + index: str = "" + """The name of the knowledgebase index. If not provided, it will be generated based on the `app_name`.""" - logger.info( - f"Initialized knowledgebase: db_client={self._db_client.__class__.__name__} adapter={self._adapter}" - ) + def model_post_init(self, __context: Any) -> None: + if isinstance(self.backend, BaseKnowledgebaseBackend): + self._backend = self.backend + logger.info( + f"Initialized knowledgebase with provided backend instance {self._backend.__class__.__name__}" + ) + return - def add( - self, - data: str | list[str] | TextIO | BinaryIO | bytes, - app_name: str, - **kwargs, - ): - """ - Add documents to the vector database. - Args: - data (str | list[str] | TextIO | BinaryIO | bytes): The data to be added. - - str: A single file path. (viking only) - - list[str]: A list of file paths. - - TextIO: A file object (TextIO). (viking only) file descriptor - - BinaryIO: A file object (BinaryIO). (viking only) file descriptor - - bytes: Binary data. (viking only) binary data (f.read()) - app_name: index name - **kwargs: Additional keyword arguments. - - file_name (str | list[str]): The file name or a list of file names (including suffix). 
(viking only) - """ - if self.backend != "viking" and not ( - isinstance(data, str) or isinstance(data, list) - ): + # must provide at least one of them + if not self.app_name and not self.index: raise ValueError( - "Only vikingdb supports uploading files or file characters." + "Either `app_name` or `index` must be provided one of them." ) - index = build_knowledgebase_index(app_name) - logger.info(f"Adding documents to knowledgebase: index={index}") - - if self.backend == "viking": - # Case 1: Handling file paths or lists of file paths (str) - if isinstance(data, str) and os.path.isfile(data): - # Get the file name (including the suffix) - if "file_name" not in kwargs or not kwargs["file_name"]: - kwargs["file_name"] = os.path.basename(data) - return self._adapter.add(data=data, index=index, **kwargs) - # Case 2: Handling when list[str] is a full path (list[str]) - if isinstance(data, list): - if all(isinstance(item, str) for item in data): - all_paths = all(os.path.isfile(item) for item in data) - all_not_paths = all(not os.path.isfile(item) for item in data) - if all_paths: - if "file_name" not in kwargs or not kwargs["file_name"]: - kwargs["file_name"] = [ - os.path.basename(item) for item in data - ] - return self._adapter.add(data=data, index=index, **kwargs) - elif ( - not all_not_paths - ): # Prevent the occurrence of non-existent paths - # There is a mixture of paths and non-paths - raise ValueError( - "Mixed file paths and content strings in list are not allowed" - ) - # Case 3: Handling strings or string arrays (content) (str or list[str]) - if isinstance(data, str) or ( - isinstance(data, list) and all(isinstance(item, str) for item in data) - ): - if "file_name" not in kwargs or not kwargs["file_name"]: - if isinstance(data, str): - kwargs["file_name"] = f"{formatted_timestamp()}.txt" - else: # list[str] without file_names - prefix_file_name = formatted_timestamp() - kwargs["file_name"] = [ - f"{prefix_file_name}_{i}.txt" for i in range(len(data)) - ] 
- return self._adapter.add(data=data, index=index, **kwargs) - - # Case 4: Handling binary data (bytes) - if isinstance(data, bytes): - # user must give file_name - if "file_name" not in kwargs: - raise ValueError("file_name must be provided for binary data") - return self._adapter.add(data=data, index=index, **kwargs) - - # Case 5: Handling file objects TextIO or BinaryIO - if isinstance(data, (io.TextIOWrapper, io.BufferedReader)): - if not kwargs.get("file_name") and hasattr(data, "name"): - kwargs["file_name"] = os.path.basename(data.name) - return self._adapter.add(data=data, index=index, **kwargs) - # Case6: Unsupported data type - raise TypeError(f"Unsupported data type: {type(data)}") - - if not isinstance(data, list): - raise TypeError( - f"Unsupported data type: {type(data)}. Only viking support file_path and file bytes" + # priority use index + if self.app_name and self.index: + logger.warning( + "`app_name` and `index` are both provided, using `index` as the knowledgebase index name." ) - # not viking - return self._adapter.add(data=data, index=index, **kwargs) - def search(self, query: str, app_name: str, top_k: int | None = None) -> list[str]: - top_k = self.top_k if top_k is None else top_k + # generate index name if `index` not provided but `app_name` is provided + if self.app_name and not self.index: + self.index = build_knowledgebase_index(self.app_name) + logger.info( + f"Knowledgebase index is set to {self.index} (generated by the app_name: {self.app_name})." + ) logger.info( - f"Searching knowledgebase: app_name={app_name} query={query} top_k={top_k}" + f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}" ) - index = build_knowledgebase_index(app_name) - result = self._adapter.query(query=query, index=index, top_k=top_k) - if len(result) == 0: - logger.warning(f"No documents found in knowledgebase. 
Query: {query}") - return result - - def delete(self, app_name: str) -> bool: - index = build_knowledgebase_index(app_name) - return self._adapter.delete(index=index) - - def delete_doc(self, app_name: str, id: str) -> bool: - index = build_knowledgebase_index(app_name) - return self._adapter.delete_doc(index=index, id=id) - - def list_chunks( - self, app_name: str, offset: int = 0, limit: int = 100 - ) -> list[dict]: - index = build_knowledgebase_index(app_name) - return self._adapter.list_chunks(index=index, offset=offset, limit=limit) - - def list_docs(self, app_name: str, offset: int = 0, limit: int = 100) -> list[dict]: - if self.backend == "viking": - index = build_knowledgebase_index(app_name) - return self._adapter.list_docs(index=index, offset=offset, limit=limit) - else: - raise NotImplementedError( - f"list_docs not supported for {self.backend}, only viking support list_docs" - ) + self._backend = _get_backend_cls(self.backend)( + index=self.index, **self.backend_config if self.backend_config else {} + ) + logger.info( + f"Initialized knowledgebase with backend {self._backend.__class__.__name__}" + ) + + def add_from_directory(self, directory: str, **kwargs) -> bool: + """Add knowledge from file path to knowledgebase""" + return self._backend.add_from_directory(directory=directory, **kwargs) + + def add_from_files(self, files: list[str], **kwargs) -> bool: + """Add knowledge (e.g, documents, strings, ...) 
to knowledgebase""" + return self._backend.add_from_files(files=files, **kwargs) - def exists(self, app_name: str) -> bool: - index = build_knowledgebase_index(app_name) - return self._adapter.index_exists(index=index) + def add_from_text(self, text: str | list[str], **kwargs) -> bool: + """Add knowledge from text to knowledgebase""" + return self._backend.add_from_text(text=text, **kwargs) + + def search(self, query: str, top_k: int = 0, **kwargs) -> list[str]: + """Search knowledge from knowledgebase""" + if top_k == 0: + top_k = self.top_k + return self._backend.search(query=query, top_k=top_k, **kwargs) + + def __getattr__(self, name) -> Callable: + """In case of knowledgebase have no backends' methods (`delete`, `list_chunks`, etc) + + For example, knowledgebase.delete(...) -> self._backend.delete(...) + """ + return getattr(self._backend, name) diff --git a/veadk/memory/__init__.py b/veadk/memory/__init__.py index 7f463206..c2ddbbd3 100644 --- a/veadk/memory/__init__.py +++ b/veadk/memory/__init__.py @@ -11,3 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from veadk.memory.long_term_memory import LongTermMemory + from veadk.memory.short_term_memory import ShortTermMemory + + +# Lazy loading for classes +def __getattr__(name): + if name == "ShortTermMemory": + from veadk.memory.short_term_memory import ShortTermMemory + + return ShortTermMemory + if name == "LongTermMemory": + from veadk.memory.long_term_memory import LongTermMemory + + return LongTermMemory + raise AttributeError(f"module 'veadk.memory' has no attribute '{name}'") + + +__all__ = ["ShortTermMemory", "LongTermMemory"] diff --git a/veadk/memory/long_term_memory.py b/veadk/memory/long_term_memory.py index e8d96bc9..e9c91a7a 100644 --- a/veadk/memory/long_term_memory.py +++ b/veadk/memory/long_term_memory.py @@ -25,45 +25,98 @@ from google.adk.memory.memory_entry import MemoryEntry from google.adk.sessions import Session from google.genai import types -from pydantic import BaseModel -from typing_extensions import override +from pydantic import BaseModel, Field +from typing_extensions import Union, override -from veadk.database import DatabaseFactory -from veadk.database.database_adapter import get_long_term_memory_database_adapter +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) from veadk.utils.logger import get_logger logger = get_logger(__name__) +def _get_backend_cls(backend: str) -> type[BaseLongTermMemoryBackend]: + match backend: + case "local": + from veadk.memory.long_term_memory_backends.in_memory_backend import ( + InMemoryLTMBackend, + ) + + return InMemoryLTMBackend + case "opensearch": + from veadk.memory.long_term_memory_backends.opensearch_backend import ( + OpensearchLTMBackend, + ) + + return OpensearchLTMBackend + case "viking": + from veadk.memory.long_term_memory_backends.vikingdb_memory_backend import ( + VikingDBLTMBackend, + ) + + return VikingDBLTMBackend + case "redis": + from 
veadk.memory.long_term_memory_backends.redis_backend import ( + RedisLTMBackend, + ) + + return RedisLTMBackend + + raise ValueError(f"Unsupported long term memory backend: {backend}") + + def build_long_term_memory_index(app_name: str, user_id: str): return f"{app_name}_{user_id}" class LongTermMemory(BaseMemoryService, BaseModel): - backend: Literal[ - "local", "opensearch", "redis", "mysql", "viking", "viking_mem" + backend: Union[ + Literal["local", "opensearch", "redis", "viking", "viking_mem"], + BaseLongTermMemoryBackend, ] = "opensearch" + """Long term memory backend type""" + + backend_config: dict = Field(default_factory=dict) + """Long term memory backend configuration""" + top_k: int = 5 + """Number of top similar documents to retrieve during search.""" + + app_name: str = "" + + user_id: str = "" def model_post_init(self, __context: Any) -> None: - if self.backend == "viking": - logger.warning( - "`viking` backend is deprecated, switching to `viking_mem` backend." - ) - self.backend = "viking_mem" + self._backend = None - logger.info( - f"Initializing long term memory: backend={self.backend} top_k={self.top_k}" - ) + # Once user define a backend instance, use it directly + if isinstance(self.backend, BaseLongTermMemoryBackend): + self._backend = self.backend + logger.info( + f"Initialized long term memory with provided backend instance {self._backend.__class__.__name__}" + ) + return - self._db_client = DatabaseFactory.create( - backend=self.backend, - ) - self._adapter = get_long_term_memory_database_adapter(self._db_client) + if self.backend_config: + logger.warning( + f"Initialized long term memory backend {self.backend} with config. We will ignore `app_name` and `user_id` if provided." 
+ ) + self._backend = _get_backend_cls(self.backend)(**self.backend_config) + return - logger.info( - f"Initialized long term memory: db_client={self._db_client.__class__.__name__} adapter={self._adapter}" - ) + if self.app_name and self.user_id: + self._index = build_long_term_memory_index( + app_name=self.app_name, user_id=self.user_id + ) + logger.info(f"Long term memory index set to {self._index}.") + self._backend = _get_backend_cls(self.backend)( + index=self._index, **self.backend_config if self.backend_config else {} + ) + else: + logger.warning( + "Neither `backend_instance`, `backend_config`, nor (`app_name`/`user_id`) is provided, the long term memory storage will initialize when adding a session." + ) def _filter_and_convert_events(self, events: list[Event]) -> list[str]: final_events = [] @@ -91,40 +144,57 @@ async def add_session_to_memory( self, session: Session, ): + app_name = session.app_name + user_id = session.user_id + + if getattr(self, "_index", None) is not None and self._index != build_long_term_memory_index(app_name, user_id): + logger.warning( + f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. 
Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}" + ) + return + + if not self._backend and isinstance(self.backend, str): + self._index = build_long_term_memory_index(app_name, user_id) + self._backend = _get_backend_cls(self.backend)( + index=self._index, **self.backend_config if self.backend_config else {} + ) + logger.info( + f"Initialize long term memory backend now, index is {self._index}" + ) + event_strings = self._filter_and_convert_events(session.events) - index = build_long_term_memory_index(session.app_name, session.user_id) logger.info( - f"Adding {len(event_strings)} events to long term memory: index={index}" + f"Adding {len(event_strings)} events to long term memory: index={self._index}" ) - # check if viking memory database, should give a user id: if/else - if self.backend == "viking_mem": - self._adapter.add(data=event_strings, index=index, user_id=session.user_id) - else: - self._adapter.add(data=event_strings, index=index) + if self._backend: + self._backend.save_memory(event_strings=event_strings, user_id=user_id) - logger.info( - f"Added {len(event_strings)} events to long term memory: index={index}" - ) + logger.info( + f"Added {len(event_strings)} events to long term memory: index={self._index}" + ) + else: + logger.error( + "Long term memory backend initialize failed, cannot add session to memory." 
+ ) @override async def search_memory(self, *, app_name: str, user_id: str, query: str): - index = build_long_term_memory_index(app_name, user_id) - logger.info( - f"Searching long term memory: query={query} index={index} top_k={self.top_k}" + f"Searching long term memory: query={query} index={self._index} top_k={self.top_k}" ) - # user id if viking memory db - if self.backend == "viking_mem": - memory_chunks = self._adapter.query( - query=query, index=index, top_k=self.top_k, user_id=user_id - ) - else: - memory_chunks = self._adapter.query( - query=query, index=index, top_k=self.top_k + # prevent model invoke `load_memory` before add session to this memory + if not self._backend: + logger.error( + "Long term memory backend is not initialized, cannot search memory." ) + return SearchMemoryResponse(memories=[]) + + memory_chunks = self._backend.search_memory( + query=query, top_k=self.top_k, user_id=user_id + ) memory_events = [] for memory in memory_chunks: @@ -152,6 +222,6 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str): ) logger.info( - f"Return {len(memory_events)} memory events for query: {query} index={index}" + f"Return {len(memory_events)} memory events for query: {query} index={self._index}" ) return SearchMemoryResponse(memories=memory_events) diff --git a/veadk/database/relational/__init__.py b/veadk/memory/long_term_memory_backends/__init__.py similarity index 100% rename from veadk/database/relational/__init__.py rename to veadk/memory/long_term_memory_backends/__init__.py diff --git a/veadk/database/base_database.py b/veadk/memory/long_term_memory_backends/base_backend.py similarity index 58% rename from veadk/database/base_database.py rename to veadk/memory/long_term_memory_backends/base_backend.py index b0256fd9..8bfbd16b 100644 --- a/veadk/database/base_database.py +++ b/veadk/memory/long_term_memory_backends/base_backend.py @@ -13,33 +13,21 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any +from pydantic import BaseModel -class DatabaseType: - LOCAL = "local" - RELATIONAL = "relational" - VECTOR = "vector" - KV = "kv" +class BaseLongTermMemoryBackend(ABC, BaseModel): + index: str -class BaseDatabase(ABC): - """Base class for database. - - Args: - type: type of the database - - Note: - No `update` function support currently. - """ - - def __init__(self): - pass - - def add(self, texts: list[Any], **kwargs: Any): ... + @abstractmethod + def precheck_index_naming(self): + """Check the index name is valid or not""" @abstractmethod - def delete(self, **kwargs: Any): ... + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + """Save memory to long term memory backend""" @abstractmethod - def query(self, query: str, **kwargs: Any) -> list[str]: ... + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + """Retrieve memory from long term memory backend""" diff --git a/veadk/memory/long_term_memory_backends/in_memory_backend.py b/veadk/memory/long_term_memory_backends/in_memory_backend.py new file mode 100644 index 00000000..d2d244cc --- /dev/null +++ b/veadk/memory/long_term_memory_backends/in_memory_backend.py @@ -0,0 +1,65 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from llama_index.core import Document, VectorStoreIndex +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +from veadk.configs.model_configs import EmbeddingModelConfig +from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) + + +class InMemoryLTMBackend(BaseLongTermMemoryBackend): + embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig) + """Embedding model configs""" + + def precheck_index_naming(self): + # no checking + pass + + def model_post_init(self, __context: Any) -> None: + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + self._vector_index = VectorStoreIndex([], embed_model=self._embed_model) + + @override + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + for event_string in event_strings: + document = Document(text=event_string) + nodes = self._split_documents([document]) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + _retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/memory/long_term_memory_backends/opensearch_backend.py b/veadk/memory/long_term_memory_backends/opensearch_backend.py 
new file mode 100644 index 00000000..ad2de4f9 --- /dev/null +++ b/veadk/memory/long_term_memory_backends/opensearch_backend.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from llama_index.core import ( + Document, + StorageContext, + VectorStoreIndex, +) +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import OpensearchConfig +from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig +from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) + +try: + from llama_index.vector_stores.opensearch import ( + OpensearchVectorClient, + OpensearchVectorStore, + ) +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" + ) + + +class OpensearchLTMBackend(BaseLongTermMemoryBackend): + opensearch_config: OpensearchConfig = Field(default_factory=OpensearchConfig) + """Opensearch client configs""" + + embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field( + default_factory=EmbeddingModelConfig + ) + """Embedding model configs""" 
+ + def precheck_index_naming(self): + if not ( + isinstance(self.index, str) + and not self.index.startswith(("_", "-")) + and self.index.islower() + and re.match(r"^[a-z0-9_\-.]+$", self.index) + ): + raise ValueError( + "The index name does not conform to the naming rules of OpenSearch" + ) + + def model_post_init(self, __context: Any) -> None: + self._opensearch_client = OpensearchVectorClient( + endpoint=self.opensearch_config.host, + port=self.opensearch_config.port, + http_auth=( + self.opensearch_config.username, + self.opensearch_config.password, + ), + use_ssl=True, + verify_certs=False, + dim=self.embedding_config.dim, + index=self.index, # collection name + ) + + self._vector_store = OpensearchVectorStore(client=self._opensearch_client) + + self._storage_context = StorageContext.from_defaults( + vector_store=self._vector_store + ) + + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + + self._vector_index = VectorStoreIndex.from_documents( + documents=[], + storage_context=self._storage_context, + embed_model=self._embed_model, + ) + self._retriever = self._vector_index.as_retriever() + + @override + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + for event_string in event_strings: + document = Document(text=event_string) + nodes = self._split_documents([document]) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + _retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = 
splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/memory/long_term_memory_backends/redis_backend.py b/veadk/memory/long_term_memory_backends/redis_backend.py new file mode 100644 index 00000000..fbf469da --- /dev/null +++ b/veadk/memory/long_term_memory_backends/redis_backend.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from llama_index.core import ( + Document, + StorageContext, + VectorStoreIndex, +) +from llama_index.core.schema import BaseNode +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from pydantic import Field +from typing_extensions import Any, override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import RedisConfig +from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig +from veadk.knowledgebase.backends.utils import get_llama_index_splitter +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) + +try: + from llama_index.vector_stores.redis import RedisVectorStore + from llama_index.vector_stores.redis.schema import ( + RedisIndexInfo, + RedisVectorStoreSchema, + ) + from redis import Redis + from redisvl.schema.fields import BaseVectorFieldAttributes +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" 
+ ) + + +class RedisLTMBackend(BaseLongTermMemoryBackend): + redis_config: RedisConfig = Field(default_factory=RedisConfig) + """Redis client configs""" + + embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field( + default_factory=EmbeddingModelConfig + ) + """Embedding model configs""" + + def precheck_index_naming(self): + # no checking + pass + + def model_post_init(self, __context: Any) -> None: + # We will use `from_url` to init Redis client once the + # AK/SK -> STS token is ready. + # self._redis_client = Redis.from_url(url=...) + + self._redis_client = Redis( + host=self.redis_config.host, + port=self.redis_config.port, + db=self.redis_config.db, + password=self.redis_config.password, + ) + + self._embed_model = OpenAILikeEmbedding( + model_name=self.embedding_config.name, + api_key=self.embedding_config.api_key, + api_base=self.embedding_config.api_base, + ) + + self._schema = RedisVectorStoreSchema( + index=RedisIndexInfo(name=self.index), + ) + if "vector" in self._schema.fields: + vector_field = self._schema.fields["vector"] + if ( + vector_field + and vector_field.attrs + and isinstance(vector_field.attrs, BaseVectorFieldAttributes) + ): + vector_field.attrs.dims = self.embedding_config.dim + self._vector_store = RedisVectorStore( + schema=self._schema, + redis_client=self._redis_client, + overwrite=True, + collection_name=self.index, + ) + + self._storage_context = StorageContext.from_defaults( + vector_store=self._vector_store + ) + + self._vector_index = VectorStoreIndex.from_documents( + documents=[], + storage_context=self._storage_context, + embed_model=self._embed_model, + ) + + @override + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + for event_string in event_strings: + document = Document(text=event_string) + nodes = self._split_documents([document]) + self._vector_index.insert_nodes(nodes) + return True + + @override + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + 
_retriever = self._vector_index.as_retriever(similarity_top_k=top_k) + retrieved_nodes = _retriever.retrieve(query) + return [node.text for node in retrieved_nodes] + + def _split_documents(self, documents: list[Document]) -> list[BaseNode]: + """Split document into chunks""" + nodes = [] + for document in documents: + splitter = get_llama_index_splitter(document.metadata.get("file_path", "")) + _nodes = splitter.get_nodes_from_documents([document]) + nodes.extend(_nodes) + return nodes diff --git a/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py new file mode 100644 index 00000000..cafe73c4 --- /dev/null +++ b/veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import re +import time +import uuid +from typing import Any + +from pydantic import Field +from typing_extensions import override + +import veadk.config # noqa E401 +from veadk.config import getenv +from veadk.memory.long_term_memory_backends.base_backend import ( + BaseLongTermMemoryBackend, +) +from veadk.utils.logger import get_logger + +try: + from mcp_server_vikingdb_memory.common.memory_client import VikingDBMemoryService +except ImportError: + raise ImportError( + "Please install VeADK extensions\npip install veadk-python[extensions]" + ) + +logger = get_logger(__name__) + + +class VikingDBLTMBackend(BaseLongTermMemoryBackend): + volcengine_access_key: str = Field( + default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY") + ) + + volcengine_secret_key: str = Field( + default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY") + ) + + region: str = "cn-beijing" + """VikingDB memory region""" + + def precheck_index_naming(self): + if not ( + isinstance(self.index, str) + and 1 <= len(self.index) <= 128 + and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", self.index) + ): + raise ValueError( + "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128." 
+ ) + + def model_post_init(self, __context: Any) -> None: + self._client = VikingDBMemoryService( + ak=self.volcengine_access_key, + sk=self.volcengine_secret_key, + region=self.region, + ) + + # check whether collection exist, if not, create it + if not self._collection_exist(): + self._create_collection() + + def _collection_exist(self) -> bool: + try: + self._client.get_collection(collection_name=self.index) + return True + except Exception: + return False + + def _create_collection(self) -> None: + response = self._client.create_collection( + collection_name=self.index, + description="Created by Volcengine Agent Development Kit VeADK", + builtin_event_types=["sys_event_v1"], + ) + return response + + @override + def save_memory(self, event_strings: list[str], **kwargs) -> bool: + user_id = kwargs.get("user_id") + if user_id is None: + raise ValueError("user_id is required") + session_id = str(uuid.uuid1()) + messages = [] + for raw_events in event_strings: + event = json.loads(raw_events) + content = event["parts"][0]["text"] + role = ( + "user" if event["role"] == "user" else "assistant" + ) # field 'role': viking memory only allow 'assistant','system','user', + messages.append({"role": role, "content": content}) + metadata = { + "default_user_id": user_id, + "default_assistant_id": "assistant", + "time": int(time.time() * 1000), + } + response = self._client.add_messages( + collection_name=self.index, + messages=messages, + metadata=metadata, + session_id=session_id, + ) + + if not response.get("code") == 0: + raise ValueError(f"Save VikingDB memory error: {response}") + + return True + + @override + def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]: + user_id = kwargs.get("user_id") + if user_id is None: + raise ValueError("user_id is required") + filter = { + "user_id": user_id, + "memory_type": ["sys_event_v1"], + } + response = self._client.search_memory( + collection_name=self.index, query=query, filter=filter, limit=top_k + ) + + 
 if not response.get("code") == 0: + raise ValueError(f"Search VikingDB memory error: {response}") + + result = response.get("data", {}).get("result_list", []) + if result: + return [ + json.dumps( + { + "role": "user", + "parts": [{"text": r.get("memory_info").get("summary")}], + }, + ensure_ascii=False, + ) + for r in result + ] + return [] diff --git a/veadk/memory/short_term_memory.py b/veadk/memory/short_term_memory.py index d3878c1f..9838b0d1 100644 --- a/veadk/memory/short_term_memory.py +++ b/veadk/memory/short_term_memory.py @@ -12,95 +12,103 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import Literal +from functools import wraps +from typing import Any, Callable, Literal + +from google.adk.sessions import ( + BaseSessionService, + DatabaseSessionService, + InMemorySessionService, +) +from pydantic import BaseModel, Field, PrivateAttr + +from veadk.memory.short_term_memory_backends.mysql_backend import ( + MysqlSTMBackend, +) +from veadk.memory.short_term_memory_backends.postgresql_backend import ( + PostgreSqlSTMBackend, +) +from veadk.memory.short_term_memory_backends.sqlite_backend import ( + SQLiteSTMBackend, +) +from veadk.utils.logger import get_logger -from google.adk.sessions import DatabaseSessionService, InMemorySessionService +logger = get_logger(__name__) -from veadk.config import getenv -from veadk.utils.logger import get_logger -from .short_term_memory_processor import ShortTermMemoryProcessor +def wrap_get_session_with_callbacks(obj, callback_fn: Callable): + get_session_fn = getattr(obj, "get_session") -logger = get_logger(__name__) + @wraps(get_session_fn) + async def wrapper(*args, **kwargs): + result = await get_session_fn(*args, **kwargs) + callback_fn(result, *args, **kwargs) + return result -DEFAULT_LOCAL_DATABASE_PATH = "/tmp/veadk_local_database.db" + setattr(obj, "get_session", wrapper) -class ShortTermMemory: - """ - Short term memory class. 
+class ShortTermMemory(BaseModel): + backend: Literal["local", "mysql", "sqlite", "postgresql", "database"] = "local" + """Short term memory backend. `local` for in-memory storage, `mysql` for MySQL storage, `postgresql` for PostgreSQL storage, `sqlite` for SQLite file storage. `database` is a deprecated alias for `sqlite`.""" - This class is used to store short term memory. - """ + backend_configs: dict = Field(default_factory=dict) + """Backend specific configurations.""" - def __init__( - self, - backend: Literal["local", "database", "mysql"] = "local", - db_url: str = "", - enable_memory_optimization: bool = False, - ): - self.backend = backend - self.db_url = db_url - - if self.backend == "mysql": - host = getenv("DATABASE_MYSQL_HOST") - user = getenv("DATABASE_MYSQL_USER") - password = getenv("DATABASE_MYSQL_PASSWORD") - database = getenv("DATABASE_MYSQL_DATABASE") - db_url = f"mysql+pymysql://{user}:{password}@{host}/{database}" - - self.db_url = db_url - self.backend = "database" - - if self.backend == "local": - logger.warning( - f"Short term memory backend: {self.backend}, the history will be lost after application shutdown." - ) - self.session_service = InMemorySessionService() - elif self.backend == "database": - if self.db_url == "" or self.db_url is None: - logger.warning("The `db_url` is an empty or None string.") - self._use_default_database() - else: - try: - self.session_service = DatabaseSessionService(db_url=self.db_url) - logger.info("Connected to database with db_url.") - except Exception as e: - logger.error(f"Failed to connect to database, error: {e}.") - self._use_default_database() - else: - raise ValueError(f"Unknown short term memory backend: {self.backend}") + db_url: str = "" + """Database connection URL, e.g. `sqlite:///./test.db`. 
Once set, it will override the `backend` parameter.""" - if enable_memory_optimization and backend == "database": - self.processor = ShortTermMemoryProcessor() - intercept_get_session = self.processor.patch() - self.session_service.get_session = intercept_get_session( - self.session_service.get_session - ) + local_database_path: str = "/tmp/veadk_local_database.db" + """Local database path, only used when `backend` is `sqlite`. Default to `/tmp/veadk_local_database.db`.""" + + after_load_memory_callback: Callable | None = None + """A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input.""" - def _use_default_database(self): - self.db_url = DEFAULT_LOCAL_DATABASE_PATH - logger.info(f"Using default local database {self.db_url}") - if not os.path.exists(self.db_url): - self.create_local_sqlite3_db(self.db_url) - self.session_service = DatabaseSessionService(db_url="sqlite:///" + self.db_url) + _session_service: BaseSessionService = PrivateAttr() - def create_local_sqlite3_db(self, path: str): - import sqlite3 + def model_post_init(self, __context: Any) -> None: + if self.db_url: + logger.info("The `db_url` is set, ignore `backend` option.") + self._session_service = DatabaseSessionService(db_url=self.db_url) + else: + if self.backend == "database": + logger.warning( + "Backend `database` is deprecated, use `sqlite` to create short term memory." 
+ ) + self.backend = "sqlite" + match self.backend: + case "local": + self._session_service = InMemorySessionService() + case "mysql": + self._session_service = MysqlSTMBackend( + **self.backend_configs + ).session_service + case "sqlite": + self._session_service = SQLiteSTMBackend( + local_path=self.local_database_path + ).session_service + case "postgresql": + self._session_service = PostgreSqlSTMBackend( + **self.backend_configs + ).session_service + + if self.after_load_memory_callback: + wrap_get_session_with_callbacks( + self._session_service, self.after_load_memory_callback + ) - conn = sqlite3.connect(path) - conn.close() - logger.debug(f"Create local sqlite3 database {path} done.") + @property + def session_service(self) -> BaseSessionService: + return self._session_service async def create_session( self, app_name: str, user_id: str, session_id: str, - ): - if isinstance(self.session_service, DatabaseSessionService): - list_sessions_response = await self.session_service.list_sessions( + ) -> None: + if isinstance(self._session_service, DatabaseSessionService): + list_sessions_response = await self._session_service.list_sessions( app_name=app_name, user_id=user_id ) @@ -109,12 +117,12 @@ async def create_session( ) if ( - await self.session_service.get_session( + await self._session_service.get_session( app_name=app_name, user_id=user_id, session_id=session_id ) is None ): # create a new session for this running - await self.session_service.create_session( + await self._session_service.create_session( app_name=app_name, user_id=user_id, session_id=session_id ) diff --git a/veadk/database/vector/__init__.py b/veadk/memory/short_term_memory_backends/__init__.py similarity index 100% rename from veadk/database/vector/__init__.py rename to veadk/memory/short_term_memory_backends/__init__.py diff --git a/veadk/database/__init__.py b/veadk/memory/short_term_memory_backends/base_backend.py similarity index 60% rename from veadk/database/__init__.py rename to 
veadk/memory/short_term_memory_backends/base_backend.py index b7a89d21..a7b62f27 100644 --- a/veadk/database/__init__.py +++ b/veadk/memory/short_term_memory_backends/base_backend.py @@ -12,6 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .database_factory import DatabaseFactory -__all__ = ["DatabaseFactory"] +from abc import ABC, abstractmethod +from functools import cached_property + +from google.adk.sessions import BaseSessionService +from pydantic import BaseModel + + +class BaseShortTermMemoryBackend(ABC, BaseModel): + """ + Base class for short term memory backend. + """ + + @cached_property + @abstractmethod + def session_service(self) -> BaseSessionService: + """Return the session service instance.""" diff --git a/veadk/memory/short_term_memory_backends/mysql_backend.py b/veadk/memory/short_term_memory_backends/mysql_backend.py new file mode 100644 index 00000000..15905440 --- /dev/null +++ b/veadk/memory/short_term_memory_backends/mysql_backend.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import cached_property +from typing import Any + +from google.adk.sessions import ( + BaseSessionService, + DatabaseSessionService, +) +from pydantic import Field +from typing_extensions import override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import MysqlConfig +from veadk.memory.short_term_memory_backends.base_backend import ( + BaseShortTermMemoryBackend, +) + + +class MysqlSTMBackend(BaseShortTermMemoryBackend): + mysql_config: MysqlConfig = Field(default_factory=MysqlConfig) + + def model_post_init(self, context: Any) -> None: + self._db_url = f"mysql+pymysql://{self.mysql_config.user}:{self.mysql_config.password}@{self.mysql_config.host}/{self.mysql_config.database}" + + @cached_property + @override + def session_service(self) -> BaseSessionService: + return DatabaseSessionService(db_url=self._db_url) diff --git a/veadk/memory/short_term_memory_backends/postgresql_backend.py b/veadk/memory/short_term_memory_backends/postgresql_backend.py new file mode 100644 index 00000000..296fcd5a --- /dev/null +++ b/veadk/memory/short_term_memory_backends/postgresql_backend.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import cached_property +from typing import Any + +from google.adk.sessions import ( + BaseSessionService, + DatabaseSessionService, +) +from pydantic import Field +from typing_extensions import override + +import veadk.config # noqa E401 +from veadk.configs.database_configs import PostgreSqlConfig +from veadk.memory.short_term_memory_backends.base_backend import ( + BaseShortTermMemoryBackend, +) + + +class PostgreSqlSTMBackend(BaseShortTermMemoryBackend): + postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig) + + def model_post_init(self, context: Any) -> None: + self._db_url = f"postgresql+psycopg2://{self.postgresql_config.user}:{self.postgresql_config.password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}" + + @cached_property + @override + def session_service(self) -> BaseSessionService: + return DatabaseSessionService(db_url=self._db_url) diff --git a/veadk/memory/short_term_memory_backends/sqlite_backend.py b/veadk/memory/short_term_memory_backends/sqlite_backend.py new file mode 100644 index 00000000..4a3d1c44 --- /dev/null +++ b/veadk/memory/short_term_memory_backends/sqlite_backend.py @@ -0,0 +1,48 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sqlite3 +from functools import cached_property +from typing import Any + +from google.adk.sessions import ( + BaseSessionService, + DatabaseSessionService, +) +from typing_extensions import override + +from veadk.memory.short_term_memory_backends.base_backend import ( + BaseShortTermMemoryBackend, +) + + +class SQLiteSTMBackend(BaseShortTermMemoryBackend): + local_path: str + + def model_post_init(self, context: Any) -> None: + # if the DB file not exists, create it + if not self._db_exists(): + conn = sqlite3.connect(self.local_path) + conn.close() + + self._db_url = f"sqlite:///{self.local_path}" + + @cached_property + @override + def session_service(self) -> BaseSessionService: + return DatabaseSessionService(db_url=self._db_url) + + def _db_exists(self) -> bool: + return os.path.exists(self.local_path)