From e9587fa62fe7ba05f0757500c73f661ad98242e4 Mon Sep 17 00:00:00 2001 From: saim256 <283483516+saim256@users.noreply.github.com> Date: Tue, 12 May 2026 17:31:05 -0400 Subject: [PATCH] fix: require admin auth for memory clear --- bounties/issue-2285/src/memory_routes.py | 28 +++++++++++ .../issue-2285/tests/test_memory_routes.py | 48 +++++++++++++++++-- 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/bounties/issue-2285/src/memory_routes.py b/bounties/issue-2285/src/memory_routes.py index d474bd0a5..1b8c47493 100644 --- a/bounties/issue-2285/src/memory_routes.py +++ b/bounties/issue-2285/src/memory_routes.py @@ -23,6 +23,8 @@ from __future__ import annotations +import hmac +import os from typing import Any, Dict, List, Optional from flask import Blueprint, jsonify, request, current_app @@ -35,6 +37,28 @@ memory_bp = Blueprint("agent_memory", __name__, url_prefix="/api/memory") +def _require_admin() -> Optional[tuple]: + """Require an admin key for destructive memory operations.""" + expected_key = os.environ.get("MEMORY_ADMIN_KEY", "") + if not expected_key: + return jsonify({ + "error": "unauthorized", + "message": "MEMORY_ADMIN_KEY not configured" + }), 401 + + provided_key = ( + request.headers.get("X-Admin-Key", "") + or request.headers.get("X-API-Key", "") + ) + if not hmac.compare_digest(provided_key, expected_key): + return jsonify({ + "error": "unauthorized", + "message": "Invalid admin key" + }), 401 + + return None + + def _get_engine() -> AgentMemoryEngine: """Get memory engine from Flask app config or create new. @@ -494,6 +518,10 @@ def clear_memory() -> tuple: } """ try: + auth_error = _require_admin() + if auth_error is not None: + return auth_error + agent_id = _validate_agent_id(request.args.get("agent_id")) engine = _get_engine() diff --git a/bounties/issue-2285/tests/test_memory_routes.py b/bounties/issue-2285/tests/test_memory_routes.py index e4e818d0e..9bb898368 100644 --- a/bounties/issue-2285/tests/test_memory_routes.py +++ b/bounties/issue-2285/tests/test_memory_routes.py @@ -15,6 +15,7 @@ import unittest from pathlib import Path from typing import Any, Dict +from unittest.mock import patch # Add src directory to path for imports sys.path.insert(0, str(Path(__file__).parent.parent / "src")) @@ -527,9 +528,11 @@ def test_clear_memory(self) -> None: content_type="application/json" ) - response = self.client.delete( - "/api/memory/clear?agent_id=test-agent" - ) + with patch.dict("os.environ", {"MEMORY_ADMIN_KEY": "test-admin"}, clear=False): + response = self.client.delete( + "/api/memory/clear?agent_id=test-agent", + headers={"X-Admin-Key": "test-admin"}, + ) self.assertEqual(response.status_code, 200) data = response.get_json() @@ -543,6 +546,45 @@ def test_clear_memory(self) -> None: recent_data = recent_response.get_json() self.assertEqual(len(recent_data["recalls"]), 0) + def test_clear_memory_requires_admin_key(self) -> None: + """Test clearing agent memory requires admin authentication.""" + self.client.post( + "/api/memory/record", + json={ + "agent_id": "test-agent", + "content_id": "video-unauthorized" + }, + content_type="application/json" + ) + + for headers in ({}, {"X-Admin-Key": "wrong-admin"}): + with self.subTest(headers=headers): + with patch.dict("os.environ", {"MEMORY_ADMIN_KEY": "test-admin"}, clear=False): + response = self.client.delete( + "/api/memory/clear?agent_id=test-agent", + headers=headers, + ) + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.get_json()["error"], "unauthorized") + + recent_response = self.client.get( + "/api/memory/recent?agent_id=test-agent" + ) + recent_data = recent_response.get_json() + self.assertEqual(len(recent_data["recalls"]), 1) + + def test_clear_memory_denies_when_admin_key_unconfigured(self) -> None: + """Test clear endpoint fails closed when MEMORY_ADMIN_KEY is absent.""" + with patch.dict("os.environ", {}, clear=True): + response = self.client.delete( + "/api/memory/clear?agent_id=test-agent", + headers={"X-Admin-Key": "test-admin"}, + ) + + self.assertEqual(response.status_code, 401) + self.assertEqual(response.get_json()["error"], "unauthorized") + def test_record_with_importance(self) -> None: """Test recording content with importance score.""" response = self.client.post(