Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def _test_convert_messages(runner):
role="user",
)
]
actual_message = runner._convert_messages(message, session_id="test_session_id")
actual_message = runner._convert_messages(
message, session_id="test_session_id", upload_inline_data_to_tos=True
)
assert actual_message == expected_message

message = ["test message 1", "test message 2"]
Expand All @@ -42,7 +44,9 @@ def _test_convert_messages(runner):
role="user",
),
]
actual_message = runner._convert_messages(message, session_id="test_session_id")
actual_message = runner._convert_messages(
message, session_id="test_session_id", upload_inline_data_to_tos=True
)
assert actual_message == expected_message


Expand Down
67 changes: 62 additions & 5 deletions tests/test_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,29 @@

import pytest
from unittest import mock
import veadk.integrations.ve_tos.ve_tos as tos_mod

# Check if tos module is available
import importlib

TOS_AVAILABLE = False
try:
importlib.import_module("veadk.integrations.ve_tos.ve_tos")
TOS_AVAILABLE = True
except ImportError:
pass

# Skip tests that require tos module if it's not available
require_tos = pytest.mark.skipif(not TOS_AVAILABLE, reason="tos module not available")

# 使用 pytest-asyncio
pytest_plugins = ("pytest_asyncio",)


@pytest.fixture
@require_tos
def mock_client(monkeypatch):
import veadk.integrations.ve_tos.ve_tos as tos_mod

fake_client = mock.Mock()

monkeypatch.setenv("DATABASE_TOS_REGION", "test-region")
Expand All @@ -33,9 +48,17 @@ def mock_client(monkeypatch):

class FakeExceptions:
class TosServerError(Exception):
def __init__(self, msg):
def __init__(
self,
msg: str,
code: int = 0,
host_id: str = "",
resource: str = "",
request_id: str = "",
header=None,
):
super().__init__(msg)
self.status_code = None
self.status_code = code

monkeypatch.setattr(tos_mod.tos, "exceptions", FakeExceptions)
monkeypatch.setattr(
Expand All @@ -51,27 +74,34 @@ def __init__(self, msg):


@pytest.fixture
@require_tos
def tos_client(mock_client):
import veadk.integrations.ve_tos.ve_tos as tos_mod

return tos_mod.VeTOS()


@require_tos
def test_create_bucket_exists(tos_client, mock_client):
mock_client.head_bucket.return_value = None # head_bucket 正常返回表示存在
result = tos_client.create_bucket()
assert result is True
mock_client.create_bucket.assert_not_called()


@require_tos
def test_create_bucket_not_exists(tos_client, mock_client):
exc = tos_mod.tos.exceptions.TosServerError("not found")
exc.status_code = 404
import veadk.integrations.ve_tos.ve_tos as tos_mod

exc = tos_mod.tos.exceptions.TosServerError(msg="not found", code=404)
mock_client.head_bucket.side_effect = exc

result = tos_client.create_bucket()
assert result is True
mock_client.create_bucket.assert_called_once()


@require_tos
@pytest.mark.asyncio
async def test_upload_bytes_success(tos_client, mock_client):
mock_client.head_bucket.return_value = True
Expand All @@ -83,6 +113,7 @@ async def test_upload_bytes_success(tos_client, mock_client):
mock_client.close.assert_called_once()


@require_tos
@pytest.mark.asyncio
async def test_upload_file_success(tmp_path, tos_client, mock_client):
mock_client.head_bucket.return_value = True
Expand All @@ -95,6 +126,7 @@ async def test_upload_file_success(tmp_path, tos_client, mock_client):
mock_client.close.assert_called_once()


@require_tos
def test_download_success(tmp_path, tos_client, mock_client):
save_path = tmp_path / "out.txt"
mock_client.get_object.return_value = [b"abc", b"def"]
Expand All @@ -104,7 +136,32 @@ def test_download_success(tmp_path, tos_client, mock_client):
assert save_path.read_bytes() == b"abcdef"


@require_tos
def test_download_fail(tos_client, mock_client):
mock_client.get_object.side_effect = Exception("boom")
result = tos_client.download("obj-key", "somewhere.txt")
assert result is False


@require_tos
@pytest.mark.skipif(TOS_AVAILABLE, reason="tos module is available")
def test_tos_import_error():
"""Test VeTOS behavior when tos module is not installed"""
# Remove tos from sys.modules to simulate it's not installed
import sys

original_tos = sys.modules.get("tos")
if "tos" in sys.modules:
del sys.modules["tos"]

try:
# Try to import ve_tos module, which should raise ImportError
with pytest.raises(ImportError) as exc_info:
pass

# Check that the error message contains installation instructions
assert "pip install tos" in str(exc_info.value)
finally:
# Restore original state
if original_tos is not None:
sys.modules["tos"] = original_tos
38 changes: 28 additions & 10 deletions veadk/integrations/ve_tos/ve_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,27 @@
import os
from veadk.config import getenv
from veadk.utils.logger import get_logger
import tos
import asyncio
from typing import Union
from pydantic import BaseModel, Field
from typing import Any
from urllib.parse import urlparse
from datetime import datetime

# Initialize logger before using it
logger = get_logger(__name__)

# Try to import tos module, and provide helpful error message if it fails
try:
import tos
except ImportError as e:
logger.error(
"Failed to import 'tos' module. Please install it using: pip install tos\n"
)
raise ImportError(
"Missing 'tos' module. Please install it using: pip install tos\n"
) from e


class TOSConfig(BaseModel):
region: str = Field(
Expand Down Expand Up @@ -59,10 +70,13 @@ def model_post_init(self, __context: Any) -> None:
logger.info("Connected to TOS successfully.")
except Exception as e:
logger.error(f"Client initialization failed:{e}")
return None
self._client = None

def create_bucket(self) -> bool:
"""If the bucket does not exist, create it"""
if not self._client:
logger.error("TOS client is not initialized")
return False
try:
self._client.head_bucket(self.config.bucket_name)
logger.info(f"Bucket {self.config.bucket_name} already exists")
Expand All @@ -76,6 +90,9 @@ def create_bucket(self) -> bool:
)
logger.info(f"Bucket {self.config.bucket_name} created successfully")
return True
else:
logger.error(f"Bucket creation failed: {str(e)}")
return False
except Exception as e:
logger.error(f"Bucket creation failed: {str(e)}")
return False
Expand Down Expand Up @@ -103,26 +120,24 @@ def upload(
data: Union[str, bytes],
):
if isinstance(data, str):
data_type = "file"
# data is a file path
return asyncio.to_thread(self._do_upload_file, object_key, data)
elif isinstance(data, bytes):
data_type = "bytes"
# data is bytes content
return asyncio.to_thread(self._do_upload_bytes, object_key, data)
else:
error_msg = f"Upload failed: data type error. Only str (file path) and bytes are supported, got {type(data)}"
logger.error(error_msg)
raise ValueError(error_msg)
if data_type == "file":
return asyncio.to_thread(self._do_upload_file, object_key, data)
elif data_type == "bytes":
return asyncio.to_thread(self._do_upload_bytes, object_key, data)

def _do_upload_bytes(self, object_key: str, bytes: bytes) -> bool:
def _do_upload_bytes(self, object_key: str, data: bytes) -> bool:
try:
if not self._client:
return False
if not self.create_bucket():
return False
self._client.put_object(
bucket=self.config.bucket_name, key=object_key, content=bytes
bucket=self.config.bucket_name, key=object_key, content=data
)
logger.debug(f"Upload success, object_key: {object_key}")
self._close()
Expand Down Expand Up @@ -152,6 +167,9 @@ def _do_upload_file(self, object_key: str, file_path: str) -> bool:

def download(self, object_key: str, save_path: str) -> bool:
"""download image from TOS"""
if not self._client:
logger.error("TOS client is not initialized")
return False
try:
object_stream = self._client.get_object(self.config.bucket_name, object_key)

Expand Down
48 changes: 34 additions & 14 deletions veadk/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from veadk.agents.sequential_agent import SequentialAgent
from veadk.config import getenv
from veadk.evaluation import EvalSetRecorder
from veadk.integrations.ve_tos.ve_tos import VeTOS
from veadk.memory.short_term_memory import ShortTermMemory
from veadk.types import MediaMessage
from veadk.utils.logger import get_logger
Expand Down Expand Up @@ -87,24 +86,36 @@ def __init__(
plugins=plugins,
)

def _convert_messages(self, messages, session_id) -> list:
def _convert_messages(
self, messages, session_id, upload_inline_data_to_tos
) -> list:
if isinstance(messages, str):
messages = [types.Content(role="user", parts=[types.Part(text=messages)])]
elif isinstance(messages, MediaMessage):
assert messages.media.endswith(".png"), (
"The MediaMessage only supports PNG format file for now."
)
data = read_png_to_bytes(messages.media)

ve_tos = VeTOS()
object_key, tos_url = ve_tos.build_tos_url(
self.user_id, self.app_name, session_id, messages.media
)
try:
asyncio.create_task(ve_tos.upload(object_key, data))
except Exception as e:
logger.error(f"Upload to TOS failed: {e}")
tos_url = None
tos_url = "<tos_url>"
if upload_inline_data_to_tos:
try:
from veadk.integrations.ve_tos.ve_tos import VeTOS

ve_tos = VeTOS()
object_key, tos_url = ve_tos.build_tos_url(
self.user_id, self.app_name, session_id, messages.media
)
upload_task = ve_tos.upload(object_key, data)
if upload_task is not None:
asyncio.create_task(upload_task)
except Exception as e:
logger.error(f"Upload to TOS failed: {e}")
tos_url = None

else:
logger.warning(
"Loss of multimodal data may occur in the tracing process."
)

messages = [
types.Content(
Expand All @@ -124,7 +135,11 @@ def _convert_messages(self, messages, session_id) -> list:
elif isinstance(messages, list):
converted_messages = []
for message in messages:
converted_messages.extend(self._convert_messages(message, session_id))
converted_messages.extend(
self._convert_messages(
message, session_id, upload_inline_data_to_tos
)
)
messages = converted_messages
else:
raise ValueError(f"Unknown message type: {type(messages)}")
Expand Down Expand Up @@ -179,6 +194,7 @@ async def event_generator():
print() # end with a new line
except LlmCallsLimitExceededError as e:
logger.warning(f"Max number of llm calls limit exceeded: {e}")
final_output = ""

return final_output

Expand All @@ -189,8 +205,11 @@ async def run(
stream: bool = False,
run_config: RunConfig | None = None,
save_tracing_data: bool = False,
upload_inline_data_to_tos: bool = False,
):
converted_messages: list = self._convert_messages(messages, session_id)
converted_messages: list = self._convert_messages(
messages, session_id, upload_inline_data_to_tos
)

await self.short_term_memory.create_session(
app_name=self.app_name, user_id=self.user_id, session_id=session_id
Expand Down Expand Up @@ -276,6 +295,7 @@ async def event_generator():
final_output += chunk
except LlmCallsLimitExceededError as e:
logger.warning(f"Max number of llm calls limit exceeded: {e}")
final_output = ""

return final_output

Expand Down