diff --git a/veadk/database/database_adapter.py b/veadk/database/database_adapter.py index b789cf03..55081b0d 100644 --- a/veadk/database/database_adapter.py +++ b/veadk/database/database_adapter.py @@ -28,7 +28,26 @@ def __init__(self, client): self.client: RedisDatabase = client - def add(self, data: list[str], index: str): + def index_exists(self, index: str) -> bool: + """ + Check if the index (key) exists in Redis. + + Args: + index: The Redis key to check + + Returns: + bool: True if the key exists, False otherwise + """ + try: + # Use Redis EXISTS command to check if key exists + return bool(self.client._client.exists(index)) + except Exception as e: + logger.error( + f"Failed to check if key exists in Redis: index={index} error={e}" + ) + return False + + def add(self, data: list[str], index: str, **kwargs): logger.debug(f"Adding documents to Redis database: index={index}") try: @@ -78,7 +97,7 @@ def delete_doc(self, index: str, id: str) -> bool: ) return False - def list_docs(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: + def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: logger.debug(f"Listing documents from Redis database: index={index}") try: # Get all documents from Redis @@ -99,6 +118,24 @@ def __init__(self, client): self.client: MysqlDatabase = client + def index_exists(self, index: str) -> bool: + """ + Check if the table (index) exists in MySQL database. 
+ + Args: + index: The table name to check + + Returns: + bool: True if the table exists, False otherwise + """ + try: + return self.client.table_exists(index) + except Exception as e: + logger.error( + f"Failed to check if table exists in MySQL: index={index} error={e}" + ) + return False + def create_table(self, table_name: str): logger.debug(f"Creating table for SQL database: table_name={table_name}") @@ -111,7 +148,7 @@ def create_table(self, table_name: str): """ self.client.add(sql) - def add(self, data: list[str], index: str): + def add(self, data: list[str], index: str, **kwargs): logger.debug( f"Adding documents to SQL database: table_name={index} data_len={len(data)}" ) @@ -188,6 +225,25 @@ def __init__(self, client): self.client: OpenSearchVectorDatabase = client + def index_exists(self, index: str) -> bool: + """ + Check if the collection (index) exists in OpenSearch. + + Args: + index: The collection name to check + + Returns: + bool: True if the collection exists, False otherwise + """ + try: + self._validate_index(index) + return self.client.collection_exists(index) + except Exception as e: + logger.error( + f"Failed to check if collection exists in OpenSearch: index={index} error={e}" + ) + return False + def _validate_index(self, index: str): """ Verify whether the string conforms to the naming rules of index_name in OpenSearch. 
@@ -203,7 +259,7 @@ def _validate_index(self, index: str): "The index name does not conform to the naming rules of OpenSearch" ) - def add(self, data: list[str], index: str): + def add(self, data: list[str], index: str, **kwargs): self._validate_index(index) logger.debug( @@ -247,7 +303,7 @@ def delete_doc(self, index: str, id: str) -> bool: ) return False - def list_docs(self, index: str, offset: int = 0, limit: int = 1000) -> list[dict]: + def list_chunks(self, index: str, offset: int = 0, limit: int = 1000) -> list[dict]: self._validate_index(index) logger.debug(f"Listing documents from vector database: index={index}") return self.client.list_docs(collection_name=index, offset=offset, limit=limit) @@ -259,6 +315,25 @@ def __init__(self, client): self.client: VikingDatabase = client + def index_exists(self, index: str) -> bool: + """ + Check if the collection (index) exists in VikingDB. + + Args: + index: The collection name to check + + Returns: + bool: True if the collection exists, False otherwise + """ + try: + self._validate_index(index) + return self.client.collection_exists(index) + except Exception as e: + logger.error( + f"Failed to check if collection exists in VikingDB: index={index} error={e}" + ) + return False + def _validate_index(self, index: str): """ Only English letters, numbers, and underscores (_) are allowed. 
@@ -322,6 +397,13 @@ def delete_doc(self, index: str, id: str) -> bool: logger.debug(f"Deleting documents from vector database: index={index} id={id}") return self.client.delete_by_id(collection_name=index, id=id) + def list_chunks(self, index: str, offset: int, limit: int) -> list[dict]: + self._validate_index(index) + logger.debug(f"Listing documents from vector database: index={index}") + return self.client.list_chunks( + collection_name=index, offset=offset, limit=limit + ) + def list_docs(self, index: str, offset: int, limit: int) -> list[dict]: self._validate_index(index) logger.debug(f"Listing documents from vector database: index={index}") @@ -334,6 +416,25 @@ def __init__(self, client): self.client: VikingMemoryDatabase = client + def index_exists(self, index: str) -> bool: + """ + Check if the collection (index) exists in VikingMemoryDB. + + Note: + VikingMemoryDatabase does not support checking if a collection exists. + This method always raises NotImplementedError. + + Args: + index: The collection name to check + + Raises: + NotImplementedError: Always raised, as VikingMemoryDatabase does not support this functionality + """ + logger.warning( + "VikingMemoryDatabase does not support checking if a collection exists" + ) + raise NotImplementedError("VikingMemoryDatabase does not support index_exists") + def _validate_index(self, index: str): if not ( isinstance(index, str) @@ -371,7 +472,7 @@ def delete(self, index: str) -> bool: def delete_docs(self, index: str, ids: list[int]): raise NotImplementedError("VikingMemoryDatabase does not support delete_docs") - def list_docs(self, index: str): + def list_chunks(self, index: str): raise NotImplementedError("VikingMemoryDatabase does not support list_docs") @@ -381,6 +482,23 @@ def __init__(self, client): self.client: LocalDataBase = client + def index_exists(self, index: str) -> bool: + """ + Check if the index exists in LocalDataBase. + + Note: + LocalDataBase does not support checking if an index exists.
+ This method always returns True. + + Args: + index: The index name to check (not used in LocalDataBase) + + Returns: + bool: Always returns True, as LocalDataBase cannot verify index existence and assumes the index exists + """ + logger.warning("LocalDataBase does not support checking if an index exists") + return True + def add(self, data: list[str], **kwargs): self.client.add(data) @@ -393,7 +511,7 @@ def delete(self, index: str) -> bool: def delete_doc(self, index: str, id: str) -> bool: return self.client.delete_doc(id) - def list_docs(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: + def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]: return self.client.list_docs(offset=offset, limit=limit) diff --git a/veadk/database/viking/viking_database.py b/veadk/database/viking/viking_database.py index 18474768..3127b0e2 100644 --- a/veadk/database/viking/viking_database.py +++ b/veadk/database/viking/viking_database.py @@ -41,8 +41,9 @@ doc_del_path = "/api/knowledge/collection/delete" doc_add_path = "/api/knowledge/doc/add" doc_info_path = "/api/knowledge/doc/info" -list_docs_path = "/api/knowledge/point/list" -delete_docs_path = "/api/knowledge/point/delete" +list_point_path = "/api/knowledge/point/list" +list_docs_path = "/api/knowledge/doc/list" +delete_docs_path = "/api/knowledge/doc/delete" class VolcengineTOSConfig(BaseModel): @@ -136,11 +137,25 @@ def _upload_to_tos( self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any, - ): - file_ext = kwargs.get( - "file_ext", ".pdf" - ) # when bytes data, file_ext is required + ) -> tuple[int, str]: + """ + Upload data to TOS (Tinder Object Storage). + + Args: + data: The data to be uploaded. Can be one of the following types: + - str: File path or string data + - list[str]: List of strings + - TextIO: File object (text) + - BinaryIO: File object (binary) + - bytes: Binary data + **kwargs: Additional keyword arguments. + - file_name (str): The file name (including suffix).
+ Returns: + tuple: A tuple containing the status code and TOS URL. + - status_code (int): HTTP status code + - tos_url (str): The URL of the uploaded file in TOS + """ ak = self.config.volcengine_ak sk = self.config.volcengine_sk @@ -151,21 +166,31 @@ def _upload_to_tos( client = tos.TosClientV2(ak, sk, tos_endpoint, tos_region, max_connections=1024) + # Extract file_name from kwargs - this is now required and includes the extension + file_names = kwargs.get("file_name") + if isinstance(data, str) and os.path.isfile(data): # Process file path - file_ext = os.path.splitext(data)[1] - new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}" + # Use provided file_name which includes the extension + new_key = f"{tos_key}/{file_names}" with open(data, "rb") as f: upload_data = f.read() + elif ( + isinstance(data, list) + and all(isinstance(item, str) for item in data) + and all(os.path.isfile(item) for item in data) + ): + # Process list of file paths - this should be handled at a higher level + raise ValueError( + "Uploading multiple files through a list of file paths is not supported in _upload_to_tos directly. Please call this function for each file individually." 
+ ) + elif isinstance( data, (io.TextIOWrapper, io.BufferedReader), # file type: TextIO | BinaryIO ): # Process file stream - # Try to get the file extension from the file name, and use the default value if there is none - file_ext = ".unknown" - if hasattr(data, "name"): - _, file_ext = os.path.splitext(data.name) - new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}" + # Use provided file_name which includes the extension + new_key = f"{tos_key}/{file_names}" if isinstance(data, TextIO): # Encode the text stream content into bytes upload_data = data.read().encode("utf-8") @@ -174,16 +199,19 @@ def _upload_to_tos( upload_data = data.read() elif isinstance(data, str): # Process ordinary strings - new_key = f"{tos_key}/{str(uuid.uuid4())}.txt" + # Use provided file_name which includes the extension + new_key = f"{tos_key}/{file_names}" upload_data = data.encode("utf-8") # Encode as byte type elif isinstance(data, list): # Process list of strings - new_key = f"{tos_key}/{str(uuid.uuid4())}.txt" + # Use provided file_name which includes the extension + new_key = f"{tos_key}/{file_names}" # Join the strings in the list with newlines and encode as byte type upload_data = "\n".join(data).encode("utf-8") elif isinstance(data, bytes): # Process bytes data - new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}" + # Use provided file_name which includes the extension + new_key = f"{tos_key}/{file_names}" upload_data = data else: @@ -231,28 +259,136 @@ def add( **kwargs, ): """ + Add documents to the Viking database. Args: - data: str, file path or file stream: Both file or file.read() are acceptable. - **kwargs: collection_name(required) + data: The data to be added. Can be one of the following types: + - str: File path or string data + - list[str]: List of file paths or list of strings + - TextIO: File object (text) + - BinaryIO: File object (binary) + - bytes: Binary data + collection_name: The name of the collection to add documents to. 
+ **kwargs: Additional keyword arguments. + - file_name (str | list[str]): The file name or a list of file names (including suffix). + - doc_id (str): The document ID. If not provided, a UUID will be generated. Returns: - { + dict or list: A dictionary containing the TOS URL and document ID, or a list of such dictionaries for multiple file uploads. + Format: { "tos_url": "tos:///", "doc_id": "", } """ - - status, tos_url = self._upload_to_tos(data=data, **kwargs) - if status != 200: - raise ValueError(f"Error in upload_to_tos: {status}") - doc_id = self._add_doc( - collection_name=collection_name, - tos_url=tos_url, - doc_id=str(uuid.uuid4()), - ) - return { - "tos_url": f"tos://{tos_url}", - "doc_id": doc_id, - } + # Handle list of file paths (multiple file upload) + if ( + isinstance(data, list) + and all(isinstance(item, str) for item in data) + and all(os.path.isfile(item) for item in data) + ): + # Handle multiple file upload + file_names = kwargs.get("file_name") + if ( + not file_names + or not isinstance(file_names, list) + or len(file_names) != len(data) + ): + raise ValueError( + "For multiple file upload, file_name must be provided as a list with the same length as data" + ) + + results = [] + for i, file_path in enumerate(data): + # Create kwargs for this specific file + single_kwargs = kwargs.copy() + single_kwargs["file_name"] = file_names[i] + + # Generate or use provided doc_id for this file + doc_id = single_kwargs.get("doc_id") + if not doc_id: + doc_id = str(uuid.uuid4()) + single_kwargs["doc_id"] = doc_id + + status, tos_url = self._upload_to_tos(data=file_path, **single_kwargs) + if status != 200: + raise ValueError( + f"Error in upload_to_tos for file {file_path}: {status}" + ) + + doc_id = self._add_doc( + collection_name=collection_name, + tos_url=tos_url, + doc_id=doc_id, + ) + + results.append( + { + "tos_url": f"tos://{tos_url}", + "doc_id": doc_id, + } + ) + + return results + + # Handle list of strings (multiple string upload) + elif 
isinstance(data, list) and all(isinstance(item, str) for item in data): + # Handle multiple string upload + file_names = kwargs.get("file_name") + if ( + not file_names + or not isinstance(file_names, list) + or len(file_names) != len(data) + ): + raise ValueError( + "For multiple string upload, file_name must be provided as a list with the same length as data" + ) + + results = [] + for i, content in enumerate(data): + # Create kwargs for this specific string + single_kwargs = kwargs.copy() + single_kwargs["file_name"] = file_names[i] + + # Generate or use provided doc_id for this string + doc_id = single_kwargs.get("doc_id") + if not doc_id: + doc_id = str(uuid.uuid4()) + single_kwargs["doc_id"] = doc_id + + status, tos_url = self._upload_to_tos(data=content, **single_kwargs) + if status != 200: + raise ValueError(f"Error in upload_to_tos for string {i}: {status}") + + doc_id = self._add_doc( + collection_name=collection_name, + tos_url=tos_url, + doc_id=doc_id, + ) + + results.append( + { + "tos_url": f"tos://{tos_url}", + "doc_id": doc_id, + } + ) + + return results + + # Handle single file upload or other data types + else: + # Handle doc_id from kwargs or generate a new one + doc_id = kwargs.get("doc_id", str(uuid.uuid4())) + + status, tos_url = self._upload_to_tos(data=data, **kwargs) + if status != 200: + raise ValueError(f"Error in upload_to_tos: {status}") + doc_id = self._add_doc( + collection_name=collection_name, + tos_url=tos_url, + doc_id=doc_id, + ) + return { + "tos_url": f"tos://{tos_url}", + "doc_id": doc_id, + } def delete(self, **kwargs: Any): name = kwargs.get("name") @@ -403,7 +539,7 @@ def collection_exists(self, collection_name: str) -> bool: else: return False - def list_docs( + def list_chunks( self, collection_name: str, offset: int = 0, limit: int = -1 ) -> list[dict]: request_params = { @@ -415,7 +551,7 @@ def list_docs( list_doc_req = prepare_request( method="POST", - path=list_docs_path, + path=list_point_path, config=self.config, 
data=request_params, ) @@ -431,6 +567,9 @@ def list_docs( logger.error(f"Error in list_docs: {result['message']}") raise ValueError(f"Error in list_docs: {result['message']}") + if not result["data"].get("point_list", []): + return [] + data = [ { "id": res["point_id"], @@ -441,11 +580,43 @@ def list_docs( ] return data + def list_docs( + self, collection_name: str, offset: int = 0, limit: int = -1 + ) -> list[dict]: + request_params = { + "collection_name": collection_name, + "project": self.config.project, + "offset": offset, + "limit": limit, + } + + list_doc_req = prepare_request( + method="POST", + path=list_docs_path, + config=self.config, + data=request_params, + ) + resp = requests.request( + method=list_doc_req.method, + url="https://{}{}".format(g_knowledge_base_domain, list_doc_req.path), + headers=list_doc_req.headers, + data=list_doc_req.body, + ) + + result = resp.json() + if result["code"] != 0: + logger.error(f"Error in list_docs: {result['message']}") + raise ValueError(f"Error in list_docs: {result['message']}") + + if not result["data"].get("doc_list", []): + return [] + return result["data"]["doc_list"] + def delete_by_id(self, collection_name: str, id: str) -> bool: request_params = { "collection_name": collection_name, "project": self.config.project, - "point_id": id, + "doc_id": id, } delete_by_id_req = prepare_request( diff --git a/veadk/knowledgebase/knowledgebase.py b/veadk/knowledgebase/knowledgebase.py index beab826f..2fa9e833 100644 --- a/veadk/knowledgebase/knowledgebase.py +++ b/veadk/knowledgebase/knowledgebase.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import io +import os.path from typing import Any, BinaryIO, Literal, TextIO from pydantic import BaseModel from veadk.database.database_adapter import get_knowledgebase_database_adapter from veadk.database.database_factory import DatabaseFactory +from veadk.utils.misc import formatted_timestamp from veadk.utils.logger import get_logger logger = get_logger(__name__) @@ -54,9 +56,16 @@ def add( ): """ Add documents to the vector database. - You can only upload files or file characters when the adapter type used is vikingdb. - In addition, if you upload data of the bytes type, - for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'. + Args: + data (str | list[str] | TextIO | BinaryIO | bytes): The data to be added. + - str: A single file path. (viking only) + - list[str]: A list of file paths. + - TextIO: A file object (TextIO). (viking only) file descriptor + - BinaryIO: A file object (BinaryIO). (viking only) file descriptor + - bytes: Binary data. (viking only) binary data (f.read()) + app_name: index name + **kwargs: Additional keyword arguments. + - file_name (str | list[str]): The file name or a list of file names (including suffix). 
(viking only) """ if self.backend != "viking" and not ( isinstance(data, str) or isinstance(data, list) @@ -66,10 +75,68 @@ def add( ) index = build_knowledgebase_index(app_name) - logger.info(f"Adding documents to knowledgebase: index={index}") - self._adapter.add(data=data, index=index) + if self.backend == "viking": + # Case 1: Handling file paths or lists of file paths (str) + if isinstance(data, str) and os.path.isfile(data): + # Get the file name (including the suffix) + if "file_name" not in kwargs or not kwargs["file_name"]: + kwargs["file_name"] = os.path.basename(data) + return self._adapter.add(data=data, index=index, **kwargs) + # Case 2: Handling when list[str] is a full path (list[str]) + if isinstance(data, list): + if all(isinstance(item, str) for item in data): + all_paths = all(os.path.isfile(item) for item in data) + all_not_paths = all(not os.path.isfile(item) for item in data) + if all_paths: + if "file_name" not in kwargs or not kwargs["file_name"]: + kwargs["file_name"] = [ + os.path.basename(item) for item in data + ] + return self._adapter.add(data=data, index=index, **kwargs) + elif ( + not all_not_paths + ): # Prevent the occurrence of non-existent paths + # There is a mixture of paths and non-paths + raise ValueError( + "Mixed file paths and content strings in list are not allowed" + ) + # Case 3: Handling strings or string arrays (content) (str or list[str]) + if isinstance(data, str) or ( + isinstance(data, list) and all(isinstance(item, str) for item in data) + ): + if "file_name" not in kwargs or not kwargs["file_name"]: + if isinstance(data, str): + kwargs["file_name"] = f"{formatted_timestamp()}.txt" + else: # list[str] without file_names + prefix_file_name = formatted_timestamp() + kwargs["file_name"] = [ + f"{prefix_file_name}_{i}.txt" for i in range(len(data)) + ] + return self._adapter.add(data=data, index=index, **kwargs) + + # Case 4: Handling binary data (bytes) + if isinstance(data, bytes): + # user must give file_name + if 
"file_name" not in kwargs: + raise ValueError("file_name must be provided for binary data") + return self._adapter.add(data=data, index=index, **kwargs) + + # Case 5: Handling file objects TextIO or BinaryIO + if isinstance(data, (io.TextIOWrapper, io.BufferedReader)): + if not kwargs.get("file_name") and hasattr(data, "name"): + kwargs["file_name"] = os.path.basename(data.name) + return self._adapter.add(data=data, index=index, **kwargs) + # Case6: Unsupported data type + raise TypeError(f"Unsupported data type: {type(data)}") + + if not isinstance(data, list): + raise TypeError( + f"Unsupported data type: {type(data)}. Only viking support file_path and file bytes" + ) + # not viking + return self._adapter.add(data=data, index=index, **kwargs) def search(self, query: str, app_name: str, top_k: int | None = None) -> list[str]: top_k = self.top_k if top_k is None else top_k @@ -85,12 +152,27 @@ def search(self, query: str, app_name: str, top_k: int | None = None) -> list[st def delete(self, app_name: str) -> bool: index = build_knowledgebase_index(app_name) - return self.adapter.delete(index=index) + return self._adapter.delete(index=index) def delete_doc(self, app_name: str, id: str) -> bool: index = build_knowledgebase_index(app_name) return self._adapter.delete_doc(index=index, id=id) + def list_chunks( + self, app_name: str, offset: int = 0, limit: int = 100 + ) -> list[dict]: + index = build_knowledgebase_index(app_name) + return self._adapter.list_chunks(index=index, offset=offset, limit=limit) + def list_docs(self, app_name: str, offset: int = 0, limit: int = 100) -> list[dict]: + if self.backend == "viking": + index = build_knowledgebase_index(app_name) + return self._adapter.list_docs(index=index, offset=offset, limit=limit) + else: + raise NotImplementedError( + f"list_docs not supported for {self.backend}, only viking support list_docs" + ) + + def exists(self, app_name: str) -> bool: index = build_knowledgebase_index(app_name) - return 
self._adapter.list_docs(index=index, offset=offset, limit=limit) + return self._adapter.index_exists(index=index)