Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,30 @@
from veadk.runner import Runner


# Import the standalone function instead of accessing as class method
from veadk.runner import _convert_messages


def _test_convert_messages(runner):
"""Test message conversion logic using standalone _convert_messages function"""
# Test single text message conversion
message = "test message"
expected_message = [
types.Content(
parts=[types.Part(text=message)],
role="user",
)
]
actual_message = runner._convert_messages(
message, session_id="test_session_id", upload_inline_data_to_tos=True
# Modified: Call _convert_messages directly (not as runner method)
actual_message = _convert_messages(
message,
app_name=runner.app_name,
user_id=runner.user_id,
session_id="test_session_id",
)
assert actual_message == expected_message

# Test multiple text messages conversion
message = ["test message 1", "test message 2"]
expected_message = [
types.Content(
Expand All @@ -44,13 +55,18 @@ def _test_convert_messages(runner):
role="user",
),
]
actual_message = runner._convert_messages(
message, session_id="test_session_id", upload_inline_data_to_tos=True
# Modified: Call _convert_messages directly (not as runner method)
actual_message = _convert_messages(
message,
app_name=runner.app_name,
user_id=runner.user_id,
session_id="test_session_id",
)
assert actual_message == expected_message


def test_runner():
"""Test Runner class initialization and core properties"""
short_term_memory = ShortTermMemory()
long_term_memory = LongTermMemory(backend="local")
agent = Agent(
Expand All @@ -64,10 +80,9 @@ def test_runner():
runner = Runner(agent=agent, short_term_memory=short_term_memory)
assert runner.long_term_memory == agent.long_term_memory

adk_runner = runner.runner
assert adk_runner.memory_service == agent.long_term_memory
assert adk_runner.session_service == runner.short_term_memory.session_service
# Verify inherited ADKRunner properties
assert runner.memory_service == agent.long_term_memory
assert runner.session_service == runner.short_term_memory.session_service

_test_convert_messages(runner)
_test_convert_messages(runner)
# Run message conversion tests
_test_convert_messages(runner)
19 changes: 11 additions & 8 deletions veadk/integrations/ve_tos/ve_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from veadk.utils.logger import get_logger

if TYPE_CHECKING:
import tos
pass


# Initialize logger before using it
Expand All @@ -41,9 +41,12 @@ def __init__(
self.sk = sk if sk else os.getenv("VOLCENGINE_SECRET_KEY", "")
self.region = region
self.bucket_name = bucket_name
self._tos_module = None

try:
import tos

self._tos_module = tos
except ImportError as e:
logger.error(
"Failed to import 'tos' module. Please install it using: pip install tos\n"
Expand All @@ -54,7 +57,7 @@ def __init__(

self._client = None
try:
self._client = tos.TosClientV2(
self._client = self._tos_module.TosClientV2(
ak=self.ak,
sk=self.sk,
endpoint=f"tos-{self.region}.volces.com",
Expand All @@ -68,7 +71,7 @@ def _refresh_client(self):
try:
if self._client:
self._client.close()
self._client = tos.TosClientV2(
self._client = self._tos_module.TosClientV2(
self.ak,
self.sk,
endpoint=f"tos-{self.region}.volces.com",
Expand All @@ -87,13 +90,13 @@ def create_bucket(self) -> bool:
try:
self._client.head_bucket(self.bucket_name)
logger.info(f"Bucket {self.bucket_name} already exists")
except tos.exceptions.TosServerError as e:
except self._tos_module.exceptions.TosServerError as e:
if e.status_code == 404:
try:
self._client.create_bucket(
bucket=self.bucket_name,
storage_class=tos.StorageClassType.Storage_Class_Standard,
acl=tos.ACLType.ACL_Public_Read,
storage_class=self._tos_module.StorageClassType.Storage_Class_Standard,
acl=self._tos_module.ACLType.ACL_Public_Read,
)
logger.info(f"Bucket {self.bucket_name} created successfully")
self._refresh_client()
Expand All @@ -115,7 +118,7 @@ def _set_cors_rules(self) -> bool:
logger.error("TOS client is not initialized")
return False
try:
rule = tos.models2.CORSRule(
rule = self._tos_module.models2.CORSRule(
allowed_origins=["*"],
allowed_methods=["GET", "HEAD"],
allowed_headers=["*"],
Expand Down Expand Up @@ -174,7 +177,7 @@ def _do_upload_bytes(self, object_key: str, data: bytes) -> None:
self._client.put_object(
bucket=self.bucket_name, key=object_key, content=data
)
logger.debug(f"Upload success, object_key: {object_key}")
logger.debug(f"Upload success, url: {object_key}")
self._close()
return
except Exception as e:
Expand Down
Loading