diff --git a/agentic_ai/agents/agent_framework/STATE_MANAGEMENT.md b/agentic_ai/agents/agent_framework/STATE_MANAGEMENT.md index c4b1d2923..972870621 100644 --- a/agentic_ai/agents/agent_framework/STATE_MANAGEMENT.md +++ b/agentic_ai/agents/agent_framework/STATE_MANAGEMENT.md @@ -253,7 +253,43 @@ Executors can include arbitrary JSON-friendly payloads in their `snapshot_state` ## 6. External checkpoint storage implementations -### 6.1 Redis-backed CheckpointStorage +Starting with `agent-framework` 1.2.1 the framework ships ready-made +`CheckpointStorage` implementations, so most workshops should use those +directly rather than rolling their own. The samples below are kept for +reference / advanced customization. + +### 6.0 Built-in storages (recommended) + +```python +# In-process (default) – great for tests and single-process demos. +from agent_framework import InMemoryCheckpointStorage +storage = InMemoryCheckpointStorage() + +# JSON-on-disk – atomic writes, path-traversal protection, type-restricted +# pickle deserialization for safety. +from agent_framework import FileCheckpointStorage +storage = FileCheckpointStorage("/var/lib/workshop/checkpoints") + +# Durable, partitioned by workflow_name. Auto-creates the database and +# container on first use; supports DefaultAzureCredential or an account key. +from agent_framework_azure_cosmos import CosmosCheckpointStorage +from azure.identity.aio import DefaultAzureCredential +storage = CosmosCheckpointStorage( + endpoint="https://my-account.documents.azure.com:443/", + credential=DefaultAzureCredential(), + database_name="agent-db", + container_name="checkpoints", +) +``` + +In this workshop the multi-agent agents pick a storage via +`agentic_ai/agents/agent_framework/multi_agent/checkpoint_storage.py` based on +the `WORKFLOW_CHECKPOINT_BACKEND` environment variable +(`memory` (default) | `file` | `cosmos`). For `file` the location can be +overridden with `WORKFLOW_CHECKPOINT_DIR`; for `cosmos` the standard +`AZURE_COSMOS_*` environment variables apply. + +### 6.1 Custom Redis-backed CheckpointStorage ```python import json @@ -306,6 +342,11 @@ class RedisCheckpointStorage(CheckpointStorage): ### 6.2 Azure Cosmos DB CheckpointStorage +> Prefer the built-in `agent_framework_azure_cosmos.CosmosCheckpointStorage` +> shown in §6.0 for production deployments. The custom implementation +> below is kept for reference if you need to integrate with an existing +> Cosmos schema. + ```python from azure.cosmos.aio import CosmosClient from agent_framework._workflow._checkpoint import CheckpointStorage, WorkflowCheckpoint diff --git a/agentic_ai/agents/agent_framework/multi_agent/checkpoint_storage.py b/agentic_ai/agents/agent_framework/multi_agent/checkpoint_storage.py new file mode 100644 index 000000000..ec56b4f8b --- /dev/null +++ b/agentic_ai/agents/agent_framework/multi_agent/checkpoint_storage.py @@ -0,0 +1,208 @@ +"""Checkpoint storage factory for multi-agent workflows. + +This module exposes a single ``create_checkpoint_storage`` helper that returns +a ``CheckpointStorage`` instance using the storages shipped with +``agent-framework`` 1.2.1: + +* ``InMemoryCheckpointStorage`` — default; in-process, lost on restart. +* ``FileCheckpointStorage`` — JSON-on-disk with atomic writes and + path-traversal protection. +* ``CosmosCheckpointStorage`` — durable, partitioned by ``workflow_name``, + shipped in ``agent_framework_azure_cosmos``. + +Selection is driven by the ``WORKFLOW_CHECKPOINT_BACKEND`` environment +variable (``memory`` | ``file`` | ``cosmos``) so deployments can opt into +durable checkpointing without touching agent code. + +Storage instances are cached per (backend, session) tuple so successive +agent invocations within the same process share the same in-memory state +(matching the behaviour of the previous hand-rolled ``DictCheckpointStorage``). + +Helpers: + +* ``prune_checkpoints`` — bound the number of saved checkpoints per workflow + using only the public ``CheckpointStorage`` protocol, replacing the + ``_RETENTION`` cap that lived inside the old custom storage classes. +""" + +from __future__ import annotations + +import logging +import os +from threading import Lock +from typing import Any, Dict, Optional, Tuple + +from agent_framework import CheckpointStorage, FileCheckpointStorage, InMemoryCheckpointStorage + +logger = logging.getLogger(__name__) + + +_BACKEND_ENV = "WORKFLOW_CHECKPOINT_BACKEND" +_FILE_DIR_ENV = "WORKFLOW_CHECKPOINT_DIR" +_DEFAULT_FILE_DIR = ".checkpoints" + +_storage_cache: Dict[Tuple[str, str], CheckpointStorage] = {} +_cache_lock = Lock() + + +def _resolve_backend() -> str: + backend = (os.getenv(_BACKEND_ENV) or "memory").strip().lower() + if backend not in {"memory", "file", "cosmos"}: + logger.warning( + "Unknown %s=%r; falling back to 'memory'. Allowed values: memory, file, cosmos.", + _BACKEND_ENV, + backend, + ) + backend = "memory" + return backend + + +def _build_file_storage(session_id: str) -> CheckpointStorage: + base_dir = os.getenv(_FILE_DIR_ENV) or _DEFAULT_FILE_DIR + # Scope per session so concurrent sessions cannot accidentally read each + # other's checkpoint files. Session IDs are sanitized by collapsing any + # path-traversal characters before joining. + safe_session = "".join(ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in session_id) + storage_path = os.path.join(base_dir, safe_session) + return FileCheckpointStorage(storage_path) + + +def _build_cosmos_storage() -> CheckpointStorage: + # Imported lazily so the cosmos extra is only required when actually used. + try: + from agent_framework_azure_cosmos import CosmosCheckpointStorage + except ImportError as exc: # pragma: no cover - defensive + raise RuntimeError( + "WORKFLOW_CHECKPOINT_BACKEND=cosmos requires the " + "'agent-framework-azure-cosmos' package to be installed." + ) from exc + + # Try managed-identity first when no AZURE_COSMOS_KEY is configured. The + # CosmosCheckpointStorage already reads endpoint / database / container / + # key from AZURE_COSMOS_* environment variables, so we only need to + # supply a credential for the keyless case. + if os.getenv("AZURE_COSMOS_KEY"): + return CosmosCheckpointStorage() + + try: + from azure.identity.aio import DefaultAzureCredential + except ImportError as exc: # pragma: no cover - defensive + raise RuntimeError( + "WORKFLOW_CHECKPOINT_BACKEND=cosmos without AZURE_COSMOS_KEY " + "requires 'azure-identity' to be installed for managed-identity auth." + ) from exc + + return CosmosCheckpointStorage(credential=DefaultAzureCredential()) + + +def create_checkpoint_storage(session_id: str) -> CheckpointStorage: + """Return a per-session ``CheckpointStorage`` from configuration. + + Args: + session_id: Used to scope file-backed storage to a per-session + directory and to key the in-process cache so successive calls + within the same process share state. + + Returns: + A storage instance compatible with the 1.2.x ``CheckpointStorage`` + protocol. + """ + backend = _resolve_backend() + cache_key = (backend, session_id) + + with _cache_lock: + existing = _storage_cache.get(cache_key) + if existing is not None: + return existing + + if backend == "file": + storage: CheckpointStorage = _build_file_storage(session_id) + elif backend == "cosmos": + storage = _build_cosmos_storage() + else: + storage = InMemoryCheckpointStorage() + + _storage_cache[cache_key] = storage + logger.info("Created %s checkpoint storage for session=%s", backend, session_id) + return storage + + +def reset_storage_cache() -> None: + """Clear the in-process storage cache. Intended for tests.""" + with _cache_lock: + _storage_cache.clear() + + +async def prune_checkpoints( + storage: CheckpointStorage, + workflow_name: str, + *, + retain: int, +) -> None: + """Bound the number of checkpoints retained for ``workflow_name``. + + Only the most recent ``retain`` checkpoints (by ``timestamp``) are kept; + older ones are deleted via the public ``CheckpointStorage.delete`` method. + Failures are logged and swallowed so checkpoint hygiene cannot break a + chat turn. + """ + if retain <= 0: + return + try: + checkpoints = await storage.list_checkpoints(workflow_name=workflow_name) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Unable to list checkpoints for pruning (%s): %s", workflow_name, exc) + return + + if len(checkpoints) <= retain: + return + + # Most-recent-first ordering using the timestamp field of WorkflowCheckpoint. + checkpoints.sort(key=lambda cp: getattr(cp, "timestamp", "") or "", reverse=True) + for stale in checkpoints[retain:]: + try: + await storage.delete(stale.checkpoint_id) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Failed to prune checkpoint %s: %s", stale.checkpoint_id, exc) + + +async def purge_checkpoints(storage: CheckpointStorage, workflow_name: Optional[str]) -> None: + """Delete every checkpoint for ``workflow_name`` using the public protocol. + + No-ops when ``workflow_name`` is missing (the protocol cannot enumerate + across workflows). + """ + if not workflow_name: + return + try: + ids = await storage.list_checkpoint_ids(workflow_name=workflow_name) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Unable to list checkpoint ids for purge (%s): %s", workflow_name, exc) + return + for checkpoint_id in ids: + try: + await storage.delete(checkpoint_id) + except Exception as exc: # pragma: no cover - defensive + logger.debug("Failed to delete checkpoint %s during purge: %s", checkpoint_id, exc) + + +__all__ = [ + "create_checkpoint_storage", + "prune_checkpoints", + "purge_checkpoints", + "reset_storage_cache", +] + + +def _coerce_checkpoint_storage(candidate: Any) -> Optional[CheckpointStorage]: + """Validate that ``candidate`` looks like a ``CheckpointStorage`` instance. + + Used by callers that accept storage overrides from configuration so that + test doubles can be substituted without inheriting from the protocol class. + """ + if candidate is None: + return None + for method_name in ("save", "load", "delete", "get_latest", "list_checkpoints", "list_checkpoint_ids"): + if not callable(getattr(candidate, method_name, None)): + return None + return candidate # type: ignore[return-value] diff --git a/agentic_ai/agents/agent_framework/multi_agent/handoff_multi_domain_agent.py b/agentic_ai/agents/agent_framework/multi_agent/handoff_multi_domain_agent.py index ab8ea3a08..d60878dc8 100644 --- a/agentic_ai/agents/agent_framework/multi_agent/handoff_multi_domain_agent.py +++ b/agentic_ai/agents/agent_framework/multi_agent/handoff_multi_domain_agent.py @@ -26,18 +26,14 @@ from __future__ import annotations -import asyncio import logging import os -from threading import Lock as ThreadLock from typing import Any, Dict, List, Optional from agent_framework import ( Agent as FrameworkAgent, ChatOptions, - CheckpointStorage, MCPStreamableHTTPTool, - WorkflowCheckpoint, ) from agent_framework.openai import OpenAIChatClient from agent_framework_orchestrations import ( @@ -48,6 +44,10 @@ from agents.base_agent import BaseAgent, ToolCallTrackingMixin from agents.agent_framework.utils import create_filtered_tool_list +from agents.agent_framework.multi_agent.checkpoint_storage import ( + create_checkpoint_storage, + prune_checkpoints, +) logger = logging.getLogger(__name__) @@ -147,74 +147,10 @@ } -class _DictCheckpointStorage(CheckpointStorage): - """Dictionary-backed ``CheckpointStorage`` shared via the session state store. - - Survives across BaseAgent instances within the same session so that the - HandoffBuilder workflow can resume mid-conversation on subsequent - requests. - """ - - _RETENTION = 5 +class _CheckpointRetention: + """How many checkpoints per workflow to keep on disk/in-memory.""" - def __init__(self, backing: Dict[str, Any]) -> None: - self._backing = backing - self._checkpoints: Dict[str, Dict[str, Any]] = backing.setdefault("checkpoints", {}) - self._async_lock = asyncio.Lock() - self._sync_lock = ThreadLock() - - async def save(self, checkpoint: WorkflowCheckpoint) -> str: - async with self._async_lock: - self._checkpoints[checkpoint.checkpoint_id] = checkpoint.to_dict() - self._backing["latest_checkpoint"] = checkpoint.checkpoint_id - self._backing["workflow_name"] = checkpoint.workflow_name - - if len(self._checkpoints) > self._RETENTION: - sorted_ids = sorted( - self._checkpoints.items(), - key=lambda item: (item[1].get("timestamp", ""), item[1].get("iteration_count", 0)), - ) - for cid, _ in sorted_ids[: -self._RETENTION]: - self._checkpoints.pop(cid, None) - return checkpoint.checkpoint_id - - async def load(self, checkpoint_id: str) -> WorkflowCheckpoint | None: - async with self._async_lock: - data = self._checkpoints.get(checkpoint_id) - if not data: - return None - return WorkflowCheckpoint.from_dict(data) - - async def list_checkpoint_ids(self, *, workflow_name: str) -> List[str]: - async with self._async_lock: - return [cid for cid, d in self._checkpoints.items() if d.get("workflow_name") == workflow_name] - - async def list_checkpoints(self, *, workflow_name: str) -> List[WorkflowCheckpoint]: - async with self._async_lock: - ids = [cid for cid, d in self._checkpoints.items() if d.get("workflow_name") == workflow_name] - return [WorkflowCheckpoint.from_dict(self._checkpoints[cid]) for cid in ids] - - async def delete(self, checkpoint_id: str) -> bool: - async with self._async_lock: - removed = self._checkpoints.pop(checkpoint_id, None) - if removed and self._backing.get("latest_checkpoint") == checkpoint_id: - self._backing.pop("latest_checkpoint", None) - return removed is not None - - async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: - async with self._async_lock: - latest_id = self._backing.get("latest_checkpoint") - if not latest_id: - return None - data = self._checkpoints.get(latest_id) - if not data or data.get("workflow_name") != workflow_name: - return None - return WorkflowCheckpoint.from_dict(data) - - @property - def latest_checkpoint_id(self) -> str | None: - with self._sync_lock: - return self._backing.get("latest_checkpoint") + DEFAULT = 5 class Agent(ToolCallTrackingMixin, BaseAgent): @@ -235,11 +171,13 @@ def __init__( self._domain_agents: Dict[str, FrameworkAgent] = {} self._mcp_tool: Optional[MCPStreamableHTTPTool] = None - # Checkpoint storage is backed by the per-session state_store so the - # workflow's conversation state survives across HTTP requests. - self._handoff_state_key = f"{session_id}_handoff_state" - backing = state_store.setdefault(self._handoff_state_key, {}) - self._checkpoint_storage = _DictCheckpointStorage(backing) + # Checkpoint storage uses the built-in 1.2.1 backends (in-memory by + # default; FileCheckpointStorage / CosmosCheckpointStorage when + # WORKFLOW_CHECKPOINT_BACKEND is set). The helper caches one storage + # instance per session inside the process so successive HTTP requests + # share state. + self._workflow_name = f"handoff-{session_id}" + self._checkpoint_storage = create_checkpoint_storage(session_id) # Track the pending ``request_info`` ID so the next turn can resume. self._pending_request_id_key = f"{session_id}_handoff_pending_req" @@ -340,7 +278,7 @@ async def _setup(self) -> None: description=cfg["description"], instructions=cfg["instructions"], tools=domain_tools, - default_options=ChatOptions(model=self.openai_model_name), + default_options=ChatOptions(model=self.azure_deployment), require_per_service_call_history_persistence=True, ) await agent.__aenter__() @@ -361,7 +299,7 @@ async def _setup(self) -> None: # can route anywhere" behaviour. self._workflow = ( HandoffBuilder( - name=f"handoff-{self.session_id}", + name=self._workflow_name, participants=list(self._domain_agents.values()), ) .with_start_agent(self._domain_agents[start_id]) @@ -404,7 +342,13 @@ async def chat_async(self, prompt: str) -> str: self._current_turn += 1 self.state_store[self._turn_key] = self._current_turn - latest_checkpoint = self._checkpoint_storage.latest_checkpoint_id + # Look up the most recent checkpoint via the public 1.2.x + # CheckpointStorage protocol so any backend (memory / file / cosmos) + # works without bespoke plumbing. + latest_checkpoint_obj = await self._checkpoint_storage.get_latest( + workflow_name=self._workflow_name + ) + latest_checkpoint = latest_checkpoint_obj.checkpoint_id if latest_checkpoint_obj else None # Resume an in-flight workflow (typical path after the first turn) by # responding to the pending HandoffAgentUserRequest with the new user @@ -500,6 +444,15 @@ async def chat_async(self, prompt: str) -> str: ) self._setstate({"mode": "handoff_multi_domain", "current_domain": self._current_domain}) + # Cap retained checkpoints to avoid unbounded growth across long + # conversations; mirrors the previous _RETENTION=5 behaviour but uses + # the public CheckpointStorage protocol so any backend benefits. + await prune_checkpoints( + self._checkpoint_storage, + self._workflow_name, + retain=_CheckpointRetention.DEFAULT, + ) + return assistant_response # ------------------------------------------------------------------ diff --git a/agentic_ai/agents/agent_framework/multi_agent/magentic_group.py b/agentic_ai/agents/agent_framework/multi_agent/magentic_group.py index a8208d799..f78c6d440 100644 --- a/agentic_ai/agents/agent_framework/multi_agent/magentic_group.py +++ b/agentic_ai/agents/agent_framework/multi_agent/magentic_group.py @@ -3,14 +3,12 @@ import json import logging import os -from threading import Lock as ThreadLock from typing import Any, Callable, Dict, Iterable, List, Optional, cast from agent_framework import ( Agent as FrameworkAgent, ChatOptions, MCPStreamableHTTPTool, - WorkflowCheckpoint, WorkflowEvent, CheckpointStorage, ResponseStream, @@ -27,91 +25,17 @@ from agents.base_agent import BaseAgent, ToolCallTrackingMixin from agents.agent_framework.utils import create_filtered_tool_list +from agents.agent_framework.multi_agent.checkpoint_storage import ( + _coerce_checkpoint_storage, + create_checkpoint_storage, + prune_checkpoints, + purge_checkpoints, +) logger = logging.getLogger(__name__) -class DictCheckpointStorage(CheckpointStorage): - """Dictionary-backed checkpoint storage that persists across Agent instances.""" - - _RETENTION = 5 - - def __init__(self, backing_store: Dict[str, Any]) -> None: - self._backing = backing_store - self._checkpoints: Dict[str, Dict[str, Any]] = backing_store.setdefault("checkpoints", {}) - self._async_lock = asyncio.Lock() - self._sync_lock = ThreadLock() - - async def save(self, checkpoint: WorkflowCheckpoint) -> str: - async with self._async_lock: - self._checkpoints[checkpoint.checkpoint_id] = checkpoint.to_dict() - self._backing["latest_checkpoint"] = checkpoint.checkpoint_id - self._backing["workflow_name"] = checkpoint.workflow_name - - if len(self._checkpoints) > self._RETENTION: - sorted_ids = sorted( - self._checkpoints.items(), - key=lambda item: (item[1].get("timestamp", ""), item[1].get("iteration_count", 0)), - ) - for checkpoint_id, _ in sorted_ids[:-self._RETENTION]: - self._checkpoints.pop(checkpoint_id, None) - return checkpoint.checkpoint_id - - async def load(self, checkpoint_id: str) -> WorkflowCheckpoint | None: - async with self._async_lock: - data = self._checkpoints.get(checkpoint_id) - if not data: - return None - return WorkflowCheckpoint.from_dict(data) - - async def list_checkpoint_ids(self, *, workflow_name: str) -> List[str]: - async with self._async_lock: - return [cid for cid, data in self._checkpoints.items() if data.get("workflow_name") == workflow_name] - - async def list_checkpoints(self, *, workflow_name: str) -> List[WorkflowCheckpoint]: - async with self._async_lock: - ids = [cid for cid, data in self._checkpoints.items() if data.get("workflow_name") == workflow_name] - return [WorkflowCheckpoint.from_dict(self._checkpoints[cid]) for cid in ids] - - async def delete(self, checkpoint_id: str) -> bool: - async with self._async_lock: - removed = self._checkpoints.pop(checkpoint_id, None) - if removed and self._backing.get("latest_checkpoint") == checkpoint_id: - self._backing.pop("latest_checkpoint", None) - return removed is not None - - async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None: - async with self._async_lock: - latest_id = self._backing.get("latest_checkpoint") - if not latest_id: - return None - data = self._checkpoints.get(latest_id) - if not data or data.get("workflow_name") != workflow_name: - return None - return WorkflowCheckpoint.from_dict(data) - - @property - def latest_checkpoint_id(self) -> str | None: - with self._sync_lock: - return self._backing.get("latest_checkpoint") - - def mark_pending_prompt(self, prompt: str) -> None: - with self._sync_lock: - self._backing["pending_prompt"] = prompt - - def consume_pending_prompt(self) -> str | None: - with self._sync_lock: - prompt = self._backing.get("pending_prompt") - if prompt is not None: - self._backing.pop("pending_prompt", None) - return prompt - - def clear_all(self) -> None: - with self._sync_lock: - self._checkpoints.clear() - self._backing.pop("latest_checkpoint", None) - self._backing.pop("workflow_id", None) - self._backing.pop("pending_prompt", None) +_CHECKPOINT_RETENTION = 5 class Agent(ToolCallTrackingMixin, BaseAgent): @@ -212,7 +136,7 @@ def __init__( or self.state_store.get("magentic_checkpoint_storage_factory") ) storage_override = self.state_store.get("magentic_checkpoint_storage") - self._checkpoint_storage_override: Optional[CheckpointStorage] = self._coerce_checkpoint_storage( + self._checkpoint_storage_override: Optional[CheckpointStorage] = _coerce_checkpoint_storage( storage_override ) if storage_override and not self._checkpoint_storage_override: @@ -231,12 +155,15 @@ def __init__( self._max_reset_count = int(self._config.get("max_reset_count", 1)) self._participant_overrides: Dict[str, Dict[str, Any]] = self._config.get("participant_overrides", {}) self._pending_prompt_state_key = f"{self.session_id}_magentic_pending_prompt" - self._in_memory_checkpoint_storage: Optional[DictCheckpointStorage] = None + # The workflow name is captured after the first build() and reused so + # successive HTTP requests can locate prior checkpoints via the + # standardized CheckpointStorage protocol (get_latest(workflow_name=...)). + self._workflow_name_state_key = f"{self.session_id}_magentic_workflow_name" self._ws_manager = None # Will be set from backend if available self._stream_agent_id: Optional[str] = None self._stream_line_open: bool = False self._last_agent_message: Optional[str] = None # Track last agent message for deduplication - + # Initialize tool tracking from mixin self.init_tool_tracking() @@ -247,8 +174,7 @@ def set_websocket_manager(self, manager: Any) -> None: async def chat_async(self, prompt: str) -> str: self._validate_configuration() - checkpoint_state = self.state_store.setdefault(f"{self.session_id}_magentic_checkpoint", {}) - checkpoint_storage = self._create_checkpoint_storage(checkpoint_state) + checkpoint_storage = self._create_checkpoint_storage() headers = self._build_headers() tools = await self._maybe_create_tools(headers) @@ -262,7 +188,7 @@ async def chat_async(self, prompt: str) -> str: manager_client = self._get_manager_client() task = self._render_task_with_history(prompt) - await self._mark_pending_prompt(checkpoint_storage, prompt) + self._mark_pending_prompt(prompt) workflow = await self._build_workflow(participant_client, manager_client, tools, checkpoint_storage) @@ -295,10 +221,18 @@ async def chat_async(self, prompt: str) -> str: return cleaned_answer def _validate_configuration(self) -> None: - if not all([self.azure_openai_key, self.azure_deployment, self.azure_openai_endpoint, self.api_version]): + if not all([self.azure_deployment, self.azure_openai_endpoint, self.api_version]): + raise RuntimeError( + "Azure OpenAI configuration is incomplete. Ensure AZURE_OPENAI_CHAT_DEPLOYMENT, " + "AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_API_VERSION are set." + ) + # Either an API key or a managed-identity credential must be available. + # _build_chat_client falls back to azure_credential when azure_openai_key + # is unset; the other agents in this package follow the same pattern. + if not self.azure_openai_key and not getattr(self, "azure_credential", None): raise RuntimeError( - "Azure OpenAI configuration is incomplete. Ensure AZURE_OPENAI_API_KEY, " - "AZURE_OPENAI_CHAT_DEPLOYMENT, AZURE_OPENAI_ENDPOINT, and AZURE_OPENAI_API_VERSION are set." + "Azure OpenAI authentication is not configured. Set AZURE_OPENAI_API_KEY " + "or provide a managed-identity credential." ) def _build_headers(self) -> Dict[str, str]: @@ -421,7 +355,7 @@ async def _resume_previous_run( if cleaned_answer is None: await self._reset_checkpoint_progress(checkpoint_storage) return None - original_prompt = await self._consume_pending_prompt(checkpoint_storage) + original_prompt = self._consume_pending_prompt() if original_prompt: self.append_to_chat_history( [ @@ -472,7 +406,13 @@ async def _build_workflow( enable_plan_review=self._enable_plan_review, ) - return builder.build() + workflow = builder.build() + # MagenticBuilder generates a random workflow name (e.g. + # ``WorkflowBuilder-``) per build. Persist it in the per-session + # state_store so subsequent requests can call + # ``storage.get_latest(workflow_name=...)`` to find prior checkpoints. + self.state_store[self._workflow_name_state_key] = workflow.name + return workflow async def _create_participants( self, @@ -579,7 +519,7 @@ async def _create_participants( agent_kwargs: Dict[str, Any] = { **defaults, "client": participant_client, - "default_options": ChatOptions(model=self.openai_model_name), + "default_options": ChatOptions(model=self.azure_deployment), } # Apply tool filtering for this participant's domain @@ -894,40 +834,39 @@ def _sanitize_final_answer(self, final_answer: Optional[str]) -> Optional[str]: # No marker found - return cleaned text return final_answer.strip() or None - def _create_checkpoint_storage(self, checkpoint_state: Dict[str, Any]) -> CheckpointStorage: - if self._checkpoint_storage_override: + def _create_checkpoint_storage(self) -> CheckpointStorage: + """Resolve the active checkpoint storage for this session. + + Resolution order: + + 1. An explicit per-process override (set by tests or hosts that wire + up a custom storage via ``state_store["magentic_checkpoint_storage"]``). + 2. A factory callable supplied via constructor or + ``state_store["magentic_checkpoint_storage_factory"]``. + 3. The standard built-in storages exposed by + ``checkpoint_storage.create_checkpoint_storage`` (in-memory by + default; FileCheckpointStorage / CosmosCheckpointStorage when + ``WORKFLOW_CHECKPOINT_BACKEND`` is set). + """ + if self._checkpoint_storage_override is not None: return self._checkpoint_storage_override if self._checkpoint_storage_factory: - storage = self._checkpoint_storage_factory(checkpoint_state, self.session_id) - if storage: + # The legacy factory signature passed (state_dict, session_id); + # preserve it so existing hosts continue to work. + storage_candidate: Any = self._checkpoint_storage_factory({}, self.session_id) + storage = _coerce_checkpoint_storage(storage_candidate) + if storage is not None: if self._config.get("cache_factory_storage", True): self.state_store["magentic_checkpoint_storage"] = storage self._checkpoint_storage_override = storage return storage logger.warning( - "[AgentFramework-Magentic] Provided checkpoint storage factory returned None; falling back to in-memory storage." + "[AgentFramework-Magentic] Provided checkpoint storage factory returned an " + "object that does not implement CheckpointStorage; falling back to built-in storage." ) - if self._in_memory_checkpoint_storage is None: - self._in_memory_checkpoint_storage = DictCheckpointStorage(checkpoint_state) - return self._in_memory_checkpoint_storage - - def _coerce_checkpoint_storage(self, candidate: Any) -> Optional[CheckpointStorage]: - if candidate is None: - return None - - required_methods = [ - "save", - "load", - ] - - for method_name in required_methods: - method = getattr(candidate, method_name, None) - if not callable(method): - return None - - return cast(CheckpointStorage, candidate) + return create_checkpoint_storage(self.session_id) def _load_effective_config(self, runtime_config: Optional[Dict[str, Any]]) -> Dict[str, Any]: merged: Dict[str, Any] = {} @@ -1033,133 +972,34 @@ def _maybe_parse_bool(self, value: Optional[str]) -> Optional[bool]: return False return None - async def _mark_pending_prompt(self, storage: CheckpointStorage, prompt: str) -> None: - """Mark a pending prompt in storage.""" + def _mark_pending_prompt(self, prompt: str) -> None: + """Persist the in-flight prompt so a resume after a crash can re-emit it.""" self.state_store[self._pending_prompt_state_key] = prompt - mark_fn = getattr(storage, "mark_pending_prompt", None) - if callable(mark_fn): - try: - await self._call_maybe_async(mark_fn, prompt) - except Exception as exc: - logger.debug("Failed to mark pending prompt: %s", exc) - - async def _consume_pending_prompt(self, storage: CheckpointStorage) -> Optional[str]: - """Consume and return pending prompt from storage.""" - stored_prompt = self.state_store.get(self._pending_prompt_state_key) - storage_prompt = None - - consume_fn = getattr(storage, "consume_pending_prompt", None) - if callable(consume_fn): - try: - storage_prompt = await self._call_maybe_async(consume_fn) - except Exception as exc: - logger.debug("Failed to consume pending prompt: %s", exc) - if stored_prompt or storage_prompt: - self.state_store.pop(self._pending_prompt_state_key, None) - - return storage_prompt or stored_prompt + def _consume_pending_prompt(self) -> Optional[str]: + """Pop and return the previously-stored in-flight prompt, if any.""" + return self.state_store.pop(self._pending_prompt_state_key, None) + + def _current_workflow_name(self) -> Optional[str]: + """Return the workflow name recorded by the most recent build, if any.""" + return self.state_store.get(self._workflow_name_state_key) async def _reset_checkpoint_progress(self, storage: CheckpointStorage) -> None: - await self._purge_checkpoint_storage(storage) + """Wipe checkpoints for the active workflow and clear the pending prompt.""" + await purge_checkpoints(storage, self._current_workflow_name()) self.state_store.pop(self._pending_prompt_state_key, None) - async def _purge_checkpoint_storage(self, storage: CheckpointStorage) -> None: - """Delete all checkpoints from storage.""" - # Try clear_all first - clear_fn = getattr(storage, "clear_all", None) - if callable(clear_fn): - try: - await self._call_maybe_async(clear_fn) - return - except Exception as exc: - logger.debug("clear_all failed: %s", exc) - - # Fallback: list and delete individually using the 1.2.x CheckpointStorage protocol. - # ``list_checkpoint_ids`` now requires a keyword-only ``workflow_name``. - list_fn = getattr(storage, "list_checkpoint_ids", None) - delete_fn = getattr(storage, "delete", None) - if not (callable(list_fn) and callable(delete_fn)): - return - - try: - workflow_name = self._workflow_name_for_storage(storage) - checkpoint_ids = await self._call_maybe_async(list_fn, workflow_name=workflow_name) if workflow_name else [] - if checkpoint_ids: - for checkpoint_id in checkpoint_ids: - try: - await self._call_maybe_async(delete_fn, checkpoint_id) - except Exception as exc: - logger.debug("Failed to delete checkpoint %s: %s", checkpoint_id, exc) - except Exception as exc: - logger.debug("Unable to enumerate checkpoints: %s", exc) - - @staticmethod - def _workflow_name_for_storage(storage: CheckpointStorage) -> str | None: - """Best-effort lookup of the active workflow name for a storage instance. - - The DictCheckpointStorage shipped with this module records the workflow - name on every ``save()``; for other storages we cannot infer it. - """ - backing = getattr(storage, "_backing", None) - if isinstance(backing, dict): - return backing.get("workflow_name") - return None - async def _get_latest_checkpoint_id(self, storage: CheckpointStorage) -> Optional[str]: - """Get the most recent checkpoint ID from storage.""" - # Try latest_checkpoint_id property/method first (nonstandard convenience - # exposed by the in-process DictCheckpointStorage in this module). - latest_id_attr = getattr(storage, "latest_checkpoint_id", None) - if callable(latest_id_attr): - try: - latest_id = await self._call_maybe_async(latest_id_attr) - if isinstance(latest_id, str): - return latest_id - except Exception: - pass - elif isinstance(latest_id_attr, str): - return latest_id_attr - - # Best-effort: the 1.2.x ``CheckpointStorage`` protocol requires a - # keyword-only ``workflow_name`` on ``get_latest`` / ``list_checkpoints`` - # / ``list_checkpoint_ids``. Without one we cannot call those methods. - workflow_name = self._workflow_name_for_storage(storage) - - # Try the 1.2.x ``get_latest`` shortcut. - get_latest_fn = getattr(storage, "get_latest", None) - if callable(get_latest_fn) and workflow_name: - try: - latest = await self._call_maybe_async(get_latest_fn, workflow_name=workflow_name) - if latest is not None: - checkpoint_id = getattr(latest, "checkpoint_id", None) - if isinstance(checkpoint_id, str): - return checkpoint_id - except Exception: - pass - - # Try list_checkpoints and pick the most recent entry. - list_checkpoints_fn = getattr(storage, "list_checkpoints", None) - if callable(list_checkpoints_fn) and workflow_name: - try: - checkpoints = await self._call_maybe_async(list_checkpoints_fn, workflow_name=workflow_name) - if checkpoints: - latest = max(checkpoints, key=lambda cp: ( - getattr(cp, "timestamp", ""), - getattr(cp, "iteration_count", 0), - )) - return latest.checkpoint_id - except Exception: - pass - - # Fallback: list checkpoint IDs and return last. - list_ids_fn = getattr(storage, "list_checkpoint_ids", None) - if callable(list_ids_fn) and workflow_name: - try: - checkpoint_ids = await self._call_maybe_async(list_ids_fn, workflow_name=workflow_name) - if checkpoint_ids: - return checkpoint_ids[-1] - except Exception: - pass - - return None + """Resolve the most recent checkpoint ID using only the public 1.2.x protocol.""" + workflow_name = self._current_workflow_name() + if not workflow_name: + return None + try: + latest = await storage.get_latest(workflow_name=workflow_name) + except Exception as exc: # pragma: no cover - defensive + logger.debug("get_latest failed for workflow %s: %s", workflow_name, exc) + return None + if latest is None: + return None + checkpoint_id = getattr(latest, "checkpoint_id", None) + return checkpoint_id if isinstance(checkpoint_id, str) else None diff --git a/agentic_ai/agents/agent_framework/multi_agent/reflection_agent.py b/agentic_ai/agents/agent_framework/multi_agent/reflection_agent.py index f01158643..a1c3fedc7 100644 --- a/agentic_ai/agents/agent_framework/multi_agent/reflection_agent.py +++ b/agentic_ai/agents/agent_framework/multi_agent/reflection_agent.py @@ -108,7 +108,7 @@ async def _setup_agents(self) -> None: name="PrimaryAgent", instructions=PRIMARY_AGENT_INSTRUCTIONS, tools=tools, - default_options=ChatOptions(model=self.openai_model_name), + default_options=ChatOptions(model=self.azure_deployment), ) self._reviewer = FrameworkAgent( @@ -116,7 +116,7 @@ async def _setup_agents(self) -> None: name="Reviewer", instructions=REVIEWER_INSTRUCTIONS, tools=tools, - default_options=ChatOptions(model=self.openai_model_name), + default_options=ChatOptions(model=self.azure_deployment), ) # Initialize agents diff --git a/agentic_ai/agents/agent_framework/single_agent.py b/agentic_ai/agents/agent_framework/single_agent.py index 1a099c21c..b06d1e270 100644 --- a/agentic_ai/agents/agent_framework/single_agent.py +++ b/agentic_ai/agents/agent_framework/single_agent.py @@ -85,7 +85,7 @@ async def _setup_single_agent(self) -> None: name="ai_assistant", instructions=instructions, tools=tools, - default_options=ChatOptions(model=self.openai_model_name), + default_options=ChatOptions(model=self.azure_deployment), ) try: diff --git a/tests/test_agent_framework_1_2_1_regression.py b/tests/test_agent_framework_1_2_1_regression.py index 2b1af5111..4e56ee354 100644 --- a/tests/test_agent_framework_1_2_1_regression.py +++ b/tests/test_agent_framework_1_2_1_regression.py @@ -149,11 +149,25 @@ def test_framework_agent_constructor_signature(self): # description was wired through in 1.2.x and is consumed by HandoffBuilder assert 'description' in params - def test_chat_options_model_id(self): - """ChatOptions accepts model_id for specifying the model.""" + def test_chat_options_model(self): + """ChatOptions exposes ``model`` (not ``model_id``) in 1.2.x. + + Regression for the CI failure where agents constructed + ``ChatOptions(model_id=...)``. ``ChatOptions`` is a ``TypedDict`` with + ``total=False``, so unknown keys are silently accepted at construction + time and only blow up later when forwarded as kwargs to the OpenAI + ``responses.create`` call (``TypeError: unexpected keyword argument + 'model_id'``). Assert against ``__annotations__`` so the test fails + loudly if the field name changes again. + """ from agent_framework import ChatOptions - opts = ChatOptions(model_id="gpt-4o") - assert opts["model_id"] == "gpt-4o" + opts = ChatOptions(model="gpt-4o") + assert opts["model"] == "gpt-4o" + assert "model" in ChatOptions.__annotations__ + assert "model_id" not in ChatOptions.__annotations__, ( + "ChatOptions field is `model` in agent-framework 1.2.x; " + "passing model_id= silently constructs a bad dict that fails downstream." + ) def test_openai_client_constructor_supports_azure_kwargs(self): """OpenAIChatClient accepts model, api_key, credential, azure_endpoint, api_version.""" @@ -570,46 +584,66 @@ def test_magentic_agent_init(self): assert agent._max_round_count == 4 assert agent._max_stall_count == 2 - def test_magentic_checkpoint_storage(self): - """DictCheckpointStorage works for checkpointing.""" - from agents.agent_framework.multi_agent.magentic_group import DictCheckpointStorage - from agent_framework import WorkflowCheckpoint - - backing = {} - storage = DictCheckpointStorage(backing) - assert storage.latest_checkpoint_id is None + def test_magentic_uses_builtin_checkpoint_storage(self): + """Magentic agent's _create_checkpoint_storage returns a built-in 1.2.1 storage.""" + from agents.agent_framework.multi_agent.magentic_group import Agent + from agent_framework import InMemoryCheckpointStorage + from agents.agent_framework.multi_agent import checkpoint_storage as cps - def test_magentic_checkpoint_storage_implements_1_2_1_protocol(self): - """DictCheckpointStorage implements the 1.2.x CheckpointStorage protocol. + cps.reset_storage_cache() + agent = Agent(state_store={}, session_id="magentic-builtin-test") + storage = agent._create_checkpoint_storage() + # Default backend is in-memory + assert isinstance(storage, InMemoryCheckpointStorage) - In 1.0.0rc1 the methods were named save_checkpoint/load_checkpoint/...; in - 1.2.x they are save/load/delete/get_latest with workflow_name kwargs. - """ - from agents.agent_framework.multi_agent.magentic_group import DictCheckpointStorage + def test_handoff_uses_builtin_checkpoint_storage(self): + """Handoff agent picks up the built-in checkpoint storage by default.""" + from agents.agent_framework.multi_agent.handoff_multi_domain_agent import Agent + from agent_framework import InMemoryCheckpointStorage + from agents.agent_framework.multi_agent import checkpoint_storage as cps + + cps.reset_storage_cache() + agent = Agent(state_store={}, session_id="handoff-builtin-test") + assert isinstance(agent._checkpoint_storage, InMemoryCheckpointStorage) + assert agent._workflow_name == "handoff-handoff-builtin-test" + + def test_checkpoint_storage_factory_implements_1_2_1_protocol(self): + """Built-in storage exposes the 1.2.x CheckpointStorage protocol surface.""" + from agents.agent_framework.multi_agent.checkpoint_storage import ( + create_checkpoint_storage, + reset_storage_cache, + ) import inspect + + reset_storage_cache() + storage = create_checkpoint_storage("protocol-test") for name in ("save", "load", "delete", "get_latest", "list_checkpoints", "list_checkpoint_ids"): - assert callable(getattr(DictCheckpointStorage, name, None)), ( - f"DictCheckpointStorage must implement 1.2.x method {name!r}" + assert callable(getattr(storage, name, None)), ( + f"Built-in storage must expose 1.2.x method {name!r}" ) - # Old method names should NOT exist - assert not hasattr(DictCheckpointStorage, "save_checkpoint"), \ - "save_checkpoint was renamed to save in 1.2.x" - assert not hasattr(DictCheckpointStorage, "load_checkpoint"), \ - "load_checkpoint was renamed to load in 1.2.x" - # workflow_name keyword-only enforcement + # Old method names should NOT exist on the built-in storage. + assert not hasattr(storage, "save_checkpoint") + assert not hasattr(storage, "load_checkpoint") + # workflow_name keyword-only enforcement on protocol methods. for name in ("list_checkpoints", "list_checkpoint_ids", "get_latest"): - sig = inspect.signature(getattr(DictCheckpointStorage, name)) + sig = inspect.signature(getattr(storage, name)) assert "workflow_name" in sig.parameters, ( f"{name} must accept workflow_name keyword in 1.2.x" ) - def test_handoff_checkpoint_storage_implements_1_2_1_protocol(self): - """The handoff agent's _DictCheckpointStorage matches the 1.2.x protocol.""" - from agents.agent_framework.multi_agent.handoff_multi_domain_agent import _DictCheckpointStorage - for name in ("save", "load", "delete", "get_latest", "list_checkpoints", "list_checkpoint_ids"): - assert callable(getattr(_DictCheckpointStorage, name, None)), ( - f"_DictCheckpointStorage must implement 1.2.x method {name!r}" - ) + def test_checkpoint_storage_backend_selection(self, tmp_path, monkeypatch): + """WORKFLOW_CHECKPOINT_BACKEND env var selects file-backed storage.""" + from agent_framework import FileCheckpointStorage + from agents.agent_framework.multi_agent.checkpoint_storage import ( + create_checkpoint_storage, + reset_storage_cache, + ) + + monkeypatch.setenv("WORKFLOW_CHECKPOINT_BACKEND", "file") + monkeypatch.setenv("WORKFLOW_CHECKPOINT_DIR", str(tmp_path)) + reset_storage_cache() + storage = create_checkpoint_storage("file-session") + assert isinstance(storage, FileCheckpointStorage) def test_workflow_checkpoint_uses_workflow_name(self): """WorkflowCheckpoint uses workflow_name (was workflow_id in 1.0.0rc1).""" @@ -620,53 +654,65 @@ def test_workflow_checkpoint_uses_workflow_name(self): assert "workflow_id" not in sig.parameters, \ "workflow_id was renamed to workflow_name in 1.2.x" - def test_dict_checkpoint_storage_save_load_roundtrip(self): + def test_builtin_storage_save_load_roundtrip(self): """save() persists a checkpoint that load()/get_latest() can recover.""" - from agents.agent_framework.multi_agent.magentic_group import DictCheckpointStorage - from agent_framework import WorkflowCheckpoint + from agent_framework import InMemoryCheckpointStorage, WorkflowCheckpoint - storage = DictCheckpointStorage({}) + storage = InMemoryCheckpointStorage() cp = WorkflowCheckpoint(workflow_name="wf-A", graph_signature_hash="hash-1", checkpoint_id="cp-1") - cid = asyncio.get_event_loop().run_until_complete(storage.save(cp)) + cid = asyncio.run(storage.save(cp)) assert cid == "cp-1" - loaded = asyncio.get_event_loop().run_until_complete(storage.load("cp-1")) + loaded = asyncio.run(storage.load("cp-1")) assert loaded is not None assert loaded.workflow_name == "wf-A" assert loaded.checkpoint_id == "cp-1" - latest = asyncio.get_event_loop().run_until_complete(storage.get_latest(workflow_name="wf-A")) + latest = asyncio.run(storage.get_latest(workflow_name="wf-A")) assert latest is not None and latest.checkpoint_id == "cp-1" # get_latest with a non-matching workflow_name returns None - none = asyncio.get_event_loop().run_until_complete(storage.get_latest(workflow_name="other")) + none = asyncio.run(storage.get_latest(workflow_name="other")) assert none is None def test_get_latest_checkpoint_id_uses_workflow_name_kwarg(self): - """_get_latest_checkpoint_id resumes from a 1.2.x-protocol-only storage. - - Storages that follow the 1.2.x ``CheckpointStorage`` protocol require - the keyword-only ``workflow_name`` argument on ``get_latest`` / - ``list_checkpoints`` / ``list_checkpoint_ids`` and do not expose the - nonstandard ``latest_checkpoint_id`` convenience. The resume helper - must call those methods with ``workflow_name=`` so it can still find - the latest checkpoint. + """_get_latest_checkpoint_id resumes via the standardized 1.2.x protocol. + + After a previous build recorded the workflow name in state_store, the + helper must call ``storage.get_latest(workflow_name=...)`` with that + recorded name to find the latest checkpoint. """ - from agents.agent_framework.multi_agent.magentic_group import Agent, DictCheckpointStorage - from agent_framework import WorkflowCheckpoint + from agents.agent_framework.multi_agent.magentic_group import Agent + from agent_framework import InMemoryCheckpointStorage, WorkflowCheckpoint - # Back the storage with a dict that records the workflow_name so - # _workflow_name_for_storage can recover it (mirrors how save() works). - backing: dict = {"workflow_name": "wf-resume"} - storage = DictCheckpointStorage(backing) + storage = InMemoryCheckpointStorage() cp = WorkflowCheckpoint(workflow_name="wf-resume", graph_signature_hash="h", checkpoint_id="cp-latest") - asyncio.get_event_loop().run_until_complete(storage.save(cp)) + asyncio.run(storage.save(cp)) - agent = Agent(state_store={}, session_id="test") - latest_id = asyncio.get_event_loop().run_until_complete(agent._get_latest_checkpoint_id(storage)) + # Simulate a previous build having recorded the workflow_name. + state_store: dict = {} + agent = Agent(state_store=state_store, session_id="test") + state_store[agent._workflow_name_state_key] = "wf-resume" + + latest_id = asyncio.run(agent._get_latest_checkpoint_id(storage)) assert latest_id == "cp-latest" + # Without a recorded workflow_name, no resume target is returned. + state_store.pop(agent._workflow_name_state_key) + latest_id_none = asyncio.run(agent._get_latest_checkpoint_id(storage)) + assert latest_id_none is None + + def test_magentic_factory_override_still_works(self): + """Hosts can still inject a custom CheckpointStorage via state_store.""" + from agents.agent_framework.multi_agent.magentic_group import Agent + from agent_framework import InMemoryCheckpointStorage + + custom = InMemoryCheckpointStorage() + state_store = {"magentic_checkpoint_storage": custom} + agent = Agent(state_store=state_store, session_id="override-test") + assert agent._create_checkpoint_storage() is custom + def test_magentic_sanitize_final_answer(self): """FINAL_ANSWER prefix is stripped from workflow output.""" from agents.agent_framework.multi_agent.magentic_group import Agent @@ -700,9 +746,7 @@ def test_magentic_process_output_event(self): agent.set_websocket_manager(ws_manager) event = WorkflowEvent.output("exec1", "Final answer text") - asyncio.get_event_loop().run_until_complete( - agent._process_workflow_event(event) - ) + asyncio.run(agent._process_workflow_event(event)) ws_manager.broadcast.assert_called() call_args = ws_manager.broadcast.call_args[0] @@ -724,9 +768,7 @@ def test_magentic_process_data_event(self): mock_data.text = "streaming token" event = WorkflowEvent.emit("crm_billing", mock_data) - asyncio.get_event_loop().run_until_complete( - agent._process_workflow_event(event) - ) + asyncio.run(agent._process_workflow_event(event)) # Should broadcast agent_start and agent_token assert ws_manager.broadcast.call_count >= 1