|
17 | 17 | from typing import Any, Optional |
18 | 18 |
|
19 | 19 | from google.adk.agents.callback_context import CallbackContext |
| 20 | +from google.adk.agents.invocation_context import InvocationContext |
20 | 21 | from google.adk.models.llm_request import LlmRequest |
21 | 22 | from google.adk.models.llm_response import LlmResponse |
| 23 | +from google.adk.plugins.base_plugin import BasePlugin |
22 | 24 | from google.adk.tools import BaseTool, ToolContext |
| 25 | +from google.genai import types |
23 | 26 | from opentelemetry import trace |
24 | 27 |
|
25 | 28 | from veadk.utils.logger import get_logger |
26 | 29 |
|
27 | 30 | logger = get_logger(__name__) |
28 | 31 |
|
29 | 32 |
|
| 33 | +class UserMessagePlugin(BasePlugin): |
| 34 | + def __init__(self, name: str): |
| 35 | + super().__init__(name) |
| 36 | + |
| 37 | + async def on_user_message_callback( |
| 38 | + self, |
| 39 | + *, |
| 40 | + invocation_context: InvocationContext, |
| 41 | + user_message: types.Content, |
| 42 | + ) -> Optional[types.Content]: |
| 43 | + """Callback executed when a user message is received before an invocation starts. |
| 44 | +
|
| 45 | + This callback helps logging and modifying the user message before the |
| 46 | + runner starts the invocation. |
| 47 | +
|
| 48 | + Args: |
| 49 | + invocation_context: The context for the entire invocation. |
| 50 | + user_message: The message content input by user. |
| 51 | +
|
| 52 | + Returns: |
| 53 | + An optional `types.Content` to be returned to the ADK. Returning a |
| 54 | + value to replace the user message. Returning `None` to proceed |
| 55 | + normally. |
| 56 | + """ |
| 57 | + trace.get_tracer("gcp.vertex.agent") |
| 58 | + span = trace.get_current_span() |
| 59 | + |
| 60 | + logger.debug(f"User message plugin works, catch {span}") |
| 61 | + span_name = getattr(span, "name", None) |
| 62 | + if span_name and span_name.startswith("invocation"): |
| 63 | + agent_name = invocation_context.agent.name |
| 64 | + invoke_branch = ( |
| 65 | + invocation_context.branch if invocation_context.branch else agent_name |
| 66 | + ) |
| 67 | + current_session = invocation_context.session |
| 68 | + |
| 69 | + span.set_attribute("app_name", current_session.app_name) |
| 70 | + span.set_attribute("user_id", current_session.user_id) |
| 71 | + span.set_attribute("session_id", current_session.id) |
| 72 | + |
| 73 | + span.set_attribute("agent_name", agent_name) |
| 74 | + span.set_attribute("invoke_branch", invoke_branch) |
| 75 | + |
| 76 | + logger.debug( |
| 77 | + f"Add attributes to {span_name}: app_name={current_session.app_name}, user_id={current_session.user_id}, session_id={current_session.id}, agent_name={agent_name}, invoke_branch={invoke_branch}" |
| 78 | + ) |
| 79 | + |
| 80 | + return None |
| 81 | + |
| 82 | + |
30 | 83 | def replace_bytes_with_empty(data): |
31 | 84 | """ |
32 | 85 | Recursively traverse the data structure and replace all bytes types with empty strings. |
|
0 commit comments