diff --git a/tests/test_runner.py b/tests/test_runner.py index 9d7055a1..82e76f90 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -20,7 +20,13 @@ from veadk.runner import Runner +# Import the standalone function instead of accessing as class method +from veadk.runner import _convert_messages + + def _test_convert_messages(runner): + """Test message conversion logic using standalone _convert_messages function""" + # Test single text message conversion message = "test message" expected_message = [ types.Content( @@ -28,11 +34,16 @@ def _test_convert_messages(runner): role="user", ) ] - actual_message = runner._convert_messages( - message, session_id="test_session_id", upload_inline_data_to_tos=True + # Modified: Call _convert_messages directly (not as runner method) + actual_message = _convert_messages( + message, + app_name=runner.app_name, + user_id=runner.user_id, + session_id="test_session_id", ) assert actual_message == expected_message + # Test multiple text messages conversion message = ["test message 1", "test message 2"] expected_message = [ types.Content( @@ -44,13 +55,18 @@ def _test_convert_messages(runner): role="user", ), ] - actual_message = runner._convert_messages( - message, session_id="test_session_id", upload_inline_data_to_tos=True + # Modified: Call _convert_messages directly (not as runner method) + actual_message = _convert_messages( + message, + app_name=runner.app_name, + user_id=runner.user_id, + session_id="test_session_id", ) assert actual_message == expected_message def test_runner(): + """Test Runner class initialization and core properties""" short_term_memory = ShortTermMemory() long_term_memory = LongTermMemory(backend="local") agent = Agent( @@ -64,10 +80,9 @@ def test_runner(): runner = Runner(agent=agent, short_term_memory=short_term_memory) assert runner.long_term_memory == agent.long_term_memory - adk_runner = runner.runner - assert adk_runner.memory_service == agent.long_term_memory - assert adk_runner.session_service == runner.short_term_memory.session_service + # Verify inherited ADKRunner properties + assert runner.memory_service == agent.long_term_memory + assert runner.session_service == runner.short_term_memory.session_service - _test_convert_messages(runner) - _test_convert_messages(runner) + # Run message conversion tests _test_convert_messages(runner) diff --git a/veadk/integrations/ve_tos/ve_tos.py b/veadk/integrations/ve_tos/ve_tos.py index a52a952b..42eeceb6 100644 --- a/veadk/integrations/ve_tos/ve_tos.py +++ b/veadk/integrations/ve_tos/ve_tos.py @@ -22,7 +22,7 @@ from veadk.utils.logger import get_logger if TYPE_CHECKING: - import tos + pass # Initialize logger before using it @@ -41,9 +41,12 @@ def __init__( self.sk = sk if sk else os.getenv("VOLCENGINE_SECRET_KEY", "") self.region = region self.bucket_name = bucket_name + self._tos_module = None try: import tos + + self._tos_module = tos except ImportError as e: logger.error( "Failed to import 'tos' module. Please install it using: pip install tos\n" @@ -54,7 +57,7 @@ def __init__( self._client = None try: - self._client = tos.TosClientV2( + self._client = self._tos_module.TosClientV2( ak=self.ak, sk=self.sk, endpoint=f"tos-{self.region}.volces.com", @@ -68,7 +71,7 @@ def _refresh_client(self): try: if self._client: self._client.close() - self._client = tos.TosClientV2( + self._client = self._tos_module.TosClientV2( self.ak, self.sk, endpoint=f"tos-{self.region}.volces.com", @@ -87,13 +90,13 @@ def create_bucket(self) -> bool: try: self._client.head_bucket(self.bucket_name) logger.info(f"Bucket {self.bucket_name} already exists") - except tos.exceptions.TosServerError as e: + except self._tos_module.exceptions.TosServerError as e: if e.status_code == 404: try: self._client.create_bucket( bucket=self.bucket_name, - storage_class=tos.StorageClassType.Storage_Class_Standard, - acl=tos.ACLType.ACL_Public_Read, + storage_class=self._tos_module.StorageClassType.Storage_Class_Standard, + acl=self._tos_module.ACLType.ACL_Public_Read, ) logger.info(f"Bucket {self.bucket_name} created successfully") self._refresh_client() @@ -115,7 +118,7 @@ def _set_cors_rules(self) -> bool: logger.error("TOS client is not initialized") return False try: - rule = tos.models2.CORSRule( + rule = self._tos_module.models2.CORSRule( allowed_origins=["*"], allowed_methods=["GET", "HEAD"], allowed_headers=["*"], @@ -174,7 +177,7 @@ def _do_upload_bytes(self, object_key: str, data: bytes) -> None: self._client.put_object( bucket=self.bucket_name, key=object_key, content=data ) - logger.debug(f"Upload success, object_key: {object_key}") + logger.debug(f"Upload success, url: {object_key}") self._close() return except Exception as e: diff --git a/veadk/runner.py b/veadk/runner.py index a2b9d33b..b04909b6 100644 --- a/veadk/runner.py +++ b/veadk/runner.py @@ -11,31 +11,33 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import asyncio +import functools +from types import MethodType from typing import Union +from google import genai from google.adk.agents import RunConfig +from google.adk.agents.base_agent import BaseAgent from google.adk.agents.invocation_context import LlmCallsLimitExceededError -from google.adk.agents.run_config import StreamingMode -from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner as ADKRunner from google.genai import types from google.genai.types import Blob -from veadk.a2a.remote_ve_agent import RemoteVeAgent from veadk.agent import Agent from veadk.agents.loop_agent import LoopAgent from veadk.agents.parallel_agent import ParallelAgent from veadk.agents.sequential_agent import SequentialAgent +from veadk.config import getenv from veadk.evaluation import EvalSetRecorder from veadk.memory.short_term_memory import ShortTermMemory from veadk.types import MediaMessage from veadk.utils.logger import get_logger -from veadk.utils.misc import getenv, read_png_to_bytes +from veadk.utils.misc import formatted_timestamp, read_png_to_bytes logger = get_logger(__name__) - RunnerMessage = Union[ str, # single turn text-based prompt list[str], # multiple turn text-based prompt @@ -44,133 +46,227 @@ list[MediaMessage | str], # multiple turn prompt with media and text-based prompt ] -VeAgent = Union[Agent, RemoteVeAgent, SequentialAgent, ParallelAgent, LoopAgent] +def pre_run_process(self, process_func, new_message, user_id, session_id): + if new_message.parts: + for part in new_message.parts: + if ( + part.inline_data + and part.inline_data.mime_type == "image/png" + and self.upload_inline_data_to_tos + ): + process_func( + part, + self.app_name, + user_id, + session_id, + ) + return -class Runner: - def __init__( - self, - agent: VeAgent, - short_term_memory: ShortTermMemory | None = None, - plugins: list[BasePlugin] | None = None, - app_name: str = "veadk_default_app", - user_id: str = "veadk_default_user", - ): - self.app_name = app_name - self.user_id = user_id - self.agent = agent +def post_run_process(self): + return - if not short_term_memory: - logger.info( - "No short term memory provided, using a in-memory memory by default." + +def intercept_new_message(process_func): + def decorator(func): + @functools.wraps(func) + async def wrapper( + self, + *, + user_id: str, + session_id: str, + new_message: types.Content, + **kwargs, + ): + pre_run_process(self, process_func, new_message, user_id, session_id) + + async for event in func( + user_id=user_id, + session_id=session_id, + new_message=new_message, + **kwargs, + ): + yield event + + post_run_process(self) + + return wrapper + + return decorator + + +def _convert_messages( + messages: RunnerMessage, + app_name: str, + user_id: str, + session_id: str, +) -> list: + """Convert VeADK formatted messages to Google ADK formatted messages.""" + if isinstance(messages, str): + _messages = [types.Content(role="user", parts=[types.Part(text=messages)])] + elif isinstance(messages, MediaMessage): + assert messages.media.endswith(".png"), ( + "The MediaMessage only supports PNG format file for now." + ) + _messages = [ + types.Content( + role="user", + parts=[ + types.Part(text=messages.text), + types.Part( + inline_data=Blob( + display_name=messages.media, + data=read_png_to_bytes(messages.media), + mime_type="image/png", + ) + ), + ], ) - self.short_term_memory = ShortTermMemory() - else: - self.short_term_memory = short_term_memory + ] + elif isinstance(messages, list): + converted_messages = [] + for message in messages: + converted_messages.extend( + _convert_messages(message, app_name, user_id, session_id) + ) + _messages = converted_messages + else: + raise ValueError(f"Unknown message type: {type(messages)}") - self.session_service = self.short_term_memory.session_service + return _messages - # prevent VeRemoteAgent has no long-term memory attr - if isinstance(self.agent, Agent): - self.long_term_memory = self.agent.long_term_memory - else: - self.long_term_memory = None - self.runner = ADKRunner( - app_name=self.app_name, - agent=self.agent, - session_service=self.session_service, - memory_service=self.long_term_memory, - plugins=plugins, - ) +def _upload_image_to_tos( + part: genai.types.Part, app_name: str, user_id: str, session_id: str +) -> None: + try: + if part.inline_data and part.inline_data.display_name and part.inline_data.data: + from veadk.integrations.ve_tos.ve_tos import VeTOS + + ve_tos = VeTOS() - def _convert_messages( - self, messages, session_id, upload_inline_data_to_tos - ) -> list: - if isinstance(messages, str): - messages = [types.Content(role="user", parts=[types.Part(text=messages)])] - elif isinstance(messages, MediaMessage): - assert messages.media.endswith(".png"), ( - "The MediaMessage only supports PNG format file for now." + object_key, tos_url = ve_tos.build_tos_url( + user_id=user_id, + app_name=app_name, + session_id=session_id, + data_path=part.inline_data.display_name, ) - data = read_png_to_bytes(messages.media) - tos_url = "" - if upload_inline_data_to_tos: - try: - from veadk.integrations.ve_tos.ve_tos import VeTOS - - ve_tos = VeTOS() - object_key, tos_url = ve_tos.build_tos_url( - self.user_id, self.app_name, session_id, messages.media - ) - upload_task = ve_tos.upload(object_key, data) - if upload_task is not None: - asyncio.create_task(upload_task) - except Exception as e: - logger.error(f"Upload to TOS failed: {e}") - tos_url = None - else: + upload_task = ve_tos.upload(object_key, part.inline_data.data) + + if upload_task is not None: + asyncio.create_task(upload_task) + + part.inline_data.display_name = tos_url + except Exception as e: + logger.error(f"Upload to TOS failed: {e}") + + +class Runner(ADKRunner): + def __init__( + self, + agent: BaseAgent | Agent, + short_term_memory: ShortTermMemory | None = None, + app_name: str = "veadk_default_app", + user_id: str = "veadk_default_user", + upload_inline_data_to_tos: bool = False, + *args, + **kwargs, + ) -> None: + self.user_id = user_id + self.long_term_memory = None + self.short_term_memory = short_term_memory + self.upload_inline_data_to_tos = upload_inline_data_to_tos + + session_service = kwargs.pop("session_service", None) + memory_service = kwargs.pop("memory_service", None) + + if session_service: + if short_term_memory: logger.warning( - "Loss of multimodal data may occur in the tracing process." + "Short term memory is enabled, but session service is also provided. We will use session service from runner argument." ) - messages = [ - types.Content( - role="user", - parts=[ - types.Part(text=messages.text), - types.Part( - inline_data=Blob( - display_name=tos_url, - data=data, - mime_type="image/png", - ) - ), - ], + if not session_service: + if short_term_memory: + session_service = short_term_memory.session_service + logger.debug( + f"Use session service {session_service} from short term memory." ) - ] - elif isinstance(messages, list): - converted_messages = [] - for message in messages: - converted_messages.extend( - self._convert_messages( - message, session_id, upload_inline_data_to_tos - ) + else: + logger.warning( + "No short term memory or session service provided, use an in-memory one instead." ) - messages = converted_messages - else: - raise ValueError(f"Unknown message type: {type(messages)}") + short_term_memory = ShortTermMemory() + self.short_term_memory = short_term_memory + session_service = short_term_memory.session_service - return messages + if memory_service: + if hasattr(agent, "long_term_memory") and agent.long_term_memory: # type: ignore + self.long_term_memory = agent.long_term_memory # type: ignore + logger.warning( + "Long term memory in agent is enabled, but memory service is also provided. We will use memory service from runner argument." + ) + + if not memory_service: + if hasattr(agent, "long_term_memory") and agent.long_term_memory: # type: ignore + self.long_term_memory = agent.long_term_memory # type: ignore + memory_service = agent.long_term_memory # type: ignore + else: + logger.info("No long term memory provided.") + + super().__init__( + agent=agent, + session_service=session_service, + memory_service=memory_service, + app_name=app_name, + *args, + **kwargs, + ) - async def _run( + self.run_async = MethodType( + intercept_new_message(_upload_image_to_tos)(self.run_async), self + ) + + async def run( self, - session_id: str, - message: types.Content, + messages: RunnerMessage, + user_id: str = "", + session_id: str = f"tmp-session-{formatted_timestamp()}", run_config: RunConfig | None = None, - stream: bool = False, + save_tracing_data: bool = False, + upload_inline_data_to_tos: bool = False, ): - stream_mode = StreamingMode.SSE if stream else StreamingMode.NONE + if upload_inline_data_to_tos: + _upload_inline_data_to_tos = self.upload_inline_data_to_tos + self.upload_inline_data_to_tos = upload_inline_data_to_tos - if run_config is not None: - stream_mode = run_config.streaming_mode - else: + if not run_config: run_config = RunConfig( - streaming_mode=stream_mode, + # streaming_mode=stream_mode, max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100)), ) - logger.info(f"Run config: {run_config}") - try: + user_id = user_id or self.user_id + + converted_messages: list = _convert_messages( + messages, self.app_name, user_id, session_id + ) - async def event_generator(): - async for event in self.runner.run_async( + if self.short_term_memory: + await self.short_term_memory.create_session( + app_name=self.app_name, user_id=self.user_id, session_id=session_id + ) + + final_output = "" + for converted_message in converted_messages: + try: + async for event in self.run_async( user_id=self.user_id, session_id=session_id, - new_message=message, + new_message=converted_message, run_config=run_config, ): if event.get_function_calls(): @@ -182,45 +278,10 @@ async def event_generator(): and event.content.parts[0].text is not None and len(event.content.parts[0].text.strip()) > 0 ): - yield event.content.parts[0].text - - final_output = "" - async for chunk in event_generator(): - if stream: - print(chunk, end="", flush=True) - final_output += chunk - if stream: - print() # end with a new line - except LlmCallsLimitExceededError as e: - logger.warning(f"Max number of llm calls limit exceeded: {e}") - final_output = "" - - return final_output - - async def run( - self, - messages: RunnerMessage, - session_id: str, - stream: bool = False, - run_config: RunConfig | None = None, - save_tracing_data: bool = False, - upload_inline_data_to_tos: bool = False, - ): - converted_messages: list = self._convert_messages( - messages, session_id, upload_inline_data_to_tos - ) - - await self.short_term_memory.create_session( - app_name=self.app_name, user_id=self.user_id, session_id=session_id - ) - - logger.info("Begin to process user messages.") - - final_output = "" - for converted_message in converted_messages: - final_output = await self._run( - session_id, converted_message, run_config, stream - ) + final_output += event.content.parts[0].text + except LlmCallsLimitExceededError as e: + logger.warning(f"Max number of llm calls limit exceeded: {e}") + final_output = "" # try to save tracing file if save_tracing_data: @@ -228,6 +289,9 @@ async def run( self._print_trace_id() + if upload_inline_data_to_tos: + self.upload_inline_data_to_tos = _upload_inline_data_to_tos # type: ignore + return final_output def get_trace_id(self) -> str: @@ -250,54 +314,6 @@ def get_trace_id(self) -> str: logger.warning(f"Get tracer id failed as {e}") return "" - async def run_with_raw_message( - self, - message: types.Content, - session_id: str, - run_config: RunConfig | None = None, - ): - run_config = ( - RunConfig(max_llm_calls=int(getenv("MODEL_AGENT_MAX_LLM_CALLS", 100))) - if not run_config - else run_config - ) - - logger.info(f"Run config: {run_config}") - - await self.short_term_memory.create_session( - app_name=self.app_name, user_id=self.user_id, session_id=session_id - ) - - try: - - async def event_generator(): - async for event in self.runner.run_async( - user_id=self.user_id, - session_id=session_id, - new_message=message, - run_config=run_config, - ): - if event.get_function_calls(): - for function_call in event.get_function_calls(): - logger.debug(f"Function call: {function_call}") - elif ( - event.content is not None - and event.content.parts - and event.content.parts[0].text is not None - and len(event.content.parts[0].text.strip()) > 0 - ): - yield event.content.parts[0].text - - final_output = "" - - async for chunk in event_generator(): - final_output += chunk - except LlmCallsLimitExceededError as e: - logger.warning(f"Max number of llm calls limit exceeded: {e}") - final_output = "" - - return final_output - def _print_trace_id(self) -> None: if not isinstance(self.agent, Agent): logger.warning( @@ -368,61 +384,3 @@ async def save_session_to_long_term_memory(self, session_id: str) -> None: await self.long_term_memory.add_session_to_memory(session) logger.info(f"Add session `{session.id}` to long term memory.") - - # [deprecated] we will not host a chat-service in VeADK, so the following two methods are deprecated - - # async def run_with_final_event( - # self, - # messages: RunnerMessage, - # session_id: str, - # ): - # """non-streaming run with final event""" - # messages: list = self._convert_messages(messages) - - # await self.short_term_memory.create_session( - # app_name=self.app_name, user_id=self.user_id, session_id=session_id - # ) - - # logger.info("Begin to process user messages.") - - # final_event = "" - # async for event in self.runner.run_async( - # user_id=self.user_id, session_id=session_id, new_message=messages[0] - # ): - # if event.get_function_calls(): - # for function_call in event.get_function_calls(): - # logger.debug(f"Function call: {function_call}") - # elif ( - # not event.partial - # and event.content.parts[0].text is not None - # and len(event.content.parts[0].text.strip()) > 0 - # ): - # final_event = event.model_dump_json(exclude_none=True, by_alias=True) - - # return final_event - - # async def run_sse( - # self, - # session_id: str, - # prompt: str, - # ): - # message = types.Content(role="user", parts=[types.Part(text=prompt)]) - - # await self.short_term_memory.create_session( - # app_name=self.app_name, user_id=self.user_id, session_id=session_id - # ) - - # logger.info("Begin to process user messages under SSE method.") - - # async for event in self.runner.run_async( - # user_id=self.user_id, - # session_id=session_id, - # new_message=message, - # run_config=RunConfig(streaming_mode=StreamingMode.SSE), - # ): - # # Format as SSE data - # sse_event = event.model_dump_json(exclude_none=True, by_alias=True) - # if event.get_function_calls(): - # for function_call in event.get_function_calls(): - # logger.debug(f"SSE function call event: {sse_event}") - # yield f"data: {sse_event}\n\n"