diff --git a/.pyrit_conf_example b/.pyrit_conf_example index 9d9e66305d..9694646b35 100644 --- a/.pyrit_conf_example +++ b/.pyrit_conf_example @@ -117,6 +117,20 @@ operation: op_trash_panda # Applies only to the pyrit_backend server. max_concurrent_scenario_runs: 3 +# Custom Initializer Registration (REST API) +# ------------------------------------------- +# When true, the REST API accepts POST /api/initializers to register custom +# initializer scripts and DELETE /api/initializers/{name} to remove custom +# initializers. +# +# ⚠️ WARNING: Enabling this allows arbitrary Python code execution on the +# server via the REST API. Only enable on trusted networks. +# The pyrit_backend default host is localhost, which limits exposure. +# If you bind to 0.0.0.0, ensure you are on a trusted network. +# +# Default: false +allow_custom_initializers: false + # Silent Mode # ----------- # If true, suppresses print statements during initialization. diff --git a/pyrit/backend/models/__init__.py b/pyrit/backend/models/__init__.py index b33901f560..388076fcd5 100644 --- a/pyrit/backend/models/__init__.py +++ b/pyrit/backend/models/__init__.py @@ -51,6 +51,7 @@ InitializerParameterSummary, ListRegisteredInitializersResponse, RegisteredInitializer, + RegisterInitializerRequest, ) from pyrit.backend.models.scenarios import ( ListRegisteredScenariosResponse, @@ -110,6 +111,7 @@ "InitializerParameterSummary", "ListRegisteredInitializersResponse", "RegisteredInitializer", + "RegisterInitializerRequest", # Targets "CreateTargetRequest", "TargetCapabilitiesInfo", diff --git a/pyrit/backend/models/initializers.py b/pyrit/backend/models/initializers.py index 15174dfd53..dfcc491de7 100644 --- a/pyrit/backend/models/initializers.py +++ b/pyrit/backend/models/initializers.py @@ -8,11 +8,10 @@ before scenario execution. These models represent initializer metadata. """ -from typing import Optional - from pydantic import BaseModel, Field from pyrit.backend.models.common import PaginationInfo +from pyrit.identifiers.class_name_utils import REGISTRY_NAME_PATTERN class InitializerParameterSummary(BaseModel): @@ -20,7 +19,7 @@ class InitializerParameterSummary(BaseModel): name: str = Field(..., description="Parameter name") description: str = Field(..., description="Human-readable description of the parameter") - default: Optional[list[str]] = Field(None, description="Default value(s), or None if required") + default: list[str] | None = Field(None, description="Default value(s), or None if required") class RegisteredInitializer(BaseModel): @@ -42,3 +41,14 @@ class ListRegisteredInitializersResponse(BaseModel): items: list[RegisteredInitializer] = Field(..., description="List of initializer summaries") pagination: PaginationInfo = Field(..., description="Pagination metadata") + + +class RegisterInitializerRequest(BaseModel): + """Request body for registering a custom initializer by uploading script content.""" + + name: str = Field( + ..., + pattern=REGISTRY_NAME_PATTERN, + description="Registry name for the initializer (e.g., 'my_custom')", + ) + script_content: str = Field(..., description="Python source code containing a PyRITInitializer subclass") diff --git a/pyrit/backend/routes/initializers.py b/pyrit/backend/routes/initializers.py index 7c10d7ad63..818f560d89 100644 --- a/pyrit/backend/routes/initializers.py +++ b/pyrit/backend/routes/initializers.py @@ -4,34 +4,56 @@ """ Initializer API routes. -Provides endpoints for listing available initializers and their metadata. +Provides endpoints for listing, registering, and removing initializers. Route structure: - /api/initializers — list all initializers - /api/initializers/{name} — get single initializer detail + GET /api/initializers — list all initializers + GET /api/initializers/{name} — get single initializer detail + POST /api/initializers — register initializer from script + DELETE /api/initializers/{name} — unregister an initializer """ -from typing import Optional - -from fastapi import APIRouter, HTTPException, Query, status +from fastapi import APIRouter, HTTPException, Query, Request, status from pyrit.backend.models.common import ProblemDetail from pyrit.backend.models.initializers import ( ListRegisteredInitializersResponse, RegisteredInitializer, + RegisterInitializerRequest, ) from pyrit.backend.services.initializer_service import get_initializer_service router = APIRouter(prefix="/initializers", tags=["initializers"]) +def _check_custom_initializers_allowed(request: Request) -> None: + """ + Check that allow_custom_initializers is enabled on the server. + + Args: + request: The incoming FastAPI request. + + Raises: + HTTPException: 403 if custom initializer operations are not enabled. + """ + allowed = getattr(request.app.state, "allow_custom_initializers", False) + if not allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=( + "Custom initializer operations are disabled. " + "Set allow_custom_initializers: true in .pyrit_conf to enable." + ), + ) + + @router.get( "", response_model=ListRegisteredInitializersResponse, ) async def list_initializers( limit: int = Query(50, ge=1, le=200, description="Maximum items per page"), - cursor: Optional[str] = Query(None, description="Pagination cursor (initializer_name to start after)"), + cursor: str | None = Query(None, description="Pagination cursor (initializer_name to start after)"), ) -> ListRegisteredInitializersResponse: """ List all available initializers. @@ -73,3 +95,81 @@ async def get_initializer(initializer_name: str) -> RegisteredInitializer: ) return initializer + + +@router.post( + "", + response_model=RegisteredInitializer, + status_code=status.HTTP_201_CREATED, + responses={ + 403: {"model": ProblemDetail, "description": "Custom initializer operations disabled"}, + 409: {"model": ProblemDetail, "description": "Initializer name already registered"}, + }, +) +async def register_initializer( + request: Request, + body: RegisterInitializerRequest, +) -> RegisteredInitializer: + """ + Register an initializer by uploading Python source code. + + The script must contain a concrete PyRITInitializer subclass. + Requires allow_custom_initializers to be enabled in pyrit_conf. + + Args: + request: The incoming FastAPI request. + body: Request body with name and script_content. + + Returns: + The newly registered initializer summary. + """ + _check_custom_initializers_allowed(request) + service = get_initializer_service() + + try: + return await service.register_initializer_async(name=body.name, script_content=body.script_content) + except ValueError as e: + detail = str(e) + if "already registered" in detail: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=detail) from None + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) from None + + +@router.delete( + "/{initializer_name}", + status_code=status.HTTP_204_NO_CONTENT, + responses={ + 400: {"model": ProblemDetail, "description": "Cannot remove built-in initializer"}, + 403: {"model": ProblemDetail, "description": "Custom initializer operations disabled"}, + 404: {"model": ProblemDetail, "description": "Initializer not found"}, + }, +) +async def unregister_initializer( + request: Request, + initializer_name: str, +) -> None: + """ + Remove a custom initializer from the registry. + + Built-in initializers cannot be removed. Requires + allow_custom_initializers to be enabled in pyrit_conf. + + Args: + request: The incoming FastAPI request. + initializer_name: Registry name of the initializer to remove. + """ + _check_custom_initializers_allowed(request) + service = get_initializer_service() + + try: + await service.unregister_initializer_async(initializer_name=initializer_name) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from None + except KeyError: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Initializer '{initializer_name}' not found", + ) from None diff --git a/pyrit/backend/services/initializer_service.py b/pyrit/backend/services/initializer_service.py index 77b0f2bf28..153ca59412 100644 --- a/pyrit/backend/services/initializer_service.py +++ b/pyrit/backend/services/initializer_service.py @@ -2,12 +2,13 @@ # Licensed under the MIT license. """ -Initializer service for listing available initializers. +Initializer service for listing, registering, and removing initializers. -Provides read-only access to the InitializerRegistry, exposing initializer +Provides access to the InitializerRegistry, exposing initializer metadata through the REST API. """ +import logging from functools import lru_cache from pyrit.backend.models.common import PaginationInfo @@ -18,6 +19,8 @@ ) from pyrit.registry import InitializerMetadata, InitializerRegistry +logger = logging.getLogger(__name__) + def _metadata_to_registered_initializer(metadata: InitializerMetadata) -> RegisteredInitializer: """ @@ -47,7 +50,7 @@ def _metadata_to_registered_initializer(metadata: InitializerMetadata) -> Regist class InitializerService: """ - Service for listing available initializers. + Service for listing, registering, and removing initializers. Uses InitializerRegistry as the source of truth for initializer metadata. """ @@ -99,6 +102,44 @@ async def get_initializer_async(self, *, initializer_name: str) -> RegisteredIni return _metadata_to_registered_initializer(metadata) return None + async def register_initializer_async( + self, + *, + name: str, + script_content: str, + ) -> RegisteredInitializer: + """ + Register an initializer from uploaded Python source code. + + Args: + name: Registry name for the new initializer. + script_content: Python source code containing a PyRITInitializer subclass. + + Returns: + The newly registered initializer summary. + + Raises: + ValueError: If the script is invalid or contains no initializer class. + """ + self._registry.register_from_content(name=name, script_content=script_content) + + initializer = await self.get_initializer_async(initializer_name=name) + if not initializer: + raise ValueError(f"Initializer '{name}' was registered but metadata could not be retrieved.") + return initializer + + async def unregister_initializer_async(self, *, initializer_name: str) -> None: + """ + Remove a custom initializer from the registry. + + Built-in initializers cannot be removed. + + Args: + initializer_name: The registry name to remove. + """ + self._registry.unregister_and_cleanup(initializer_name) + logger.info(f"Unregistered initializer: {initializer_name}") + @staticmethod def _paginate( *, diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index c17eb83b54..708e19c733 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -147,6 +147,7 @@ def __init__( self._operator = config.operator self._operation = config.operation self._max_concurrent_scenario_runs = config.max_concurrent_scenario_runs + self._allow_custom_initializers = config.allow_custom_initializers # Lazy-loaded registries self._scenario_registry: Optional[ScenarioRegistry] = None @@ -223,6 +224,7 @@ def with_overrides( derived._operator = self._operator derived._operation = self._operation derived._max_concurrent_scenario_runs = self._max_concurrent_scenario_runs + derived._allow_custom_initializers = self._allow_custom_initializers derived._scenario_config = self._scenario_config # Apply overrides or inherit diff --git a/pyrit/cli/pyrit_backend.py b/pyrit/cli/pyrit_backend.py index 8eed2cc929..819ad7baa9 100644 --- a/pyrit/cli/pyrit_backend.py +++ b/pyrit/cli/pyrit_backend.py @@ -199,8 +199,17 @@ async def initialize_and_run_async(*, parsed_args: Namespace) -> int: default_labels["operation"] = context._operation app.state.default_labels = default_labels app.state.max_concurrent_scenario_runs = context._max_concurrent_scenario_runs + app.state.allow_custom_initializers = context._allow_custom_initializers display_host = parsed_args.host + if context._allow_custom_initializers: + print("⚠️ WARNING: Custom initializer registration is ENABLED (allow_custom_initializers: true).") + print(" This allows arbitrary Python code execution via the REST API.") + if parsed_args.host == "0.0.0.0": + print(" 🚨 Server is bound to 0.0.0.0 — accessible from the NETWORK. Use only on trusted networks!") + else: + print(f" Server is bound to {display_host}.") + print(f"🚀 Starting PyRIT backend on http://{display_host}:{parsed_args.port}") print(f" API Docs: http://{display_host}:{parsed_args.port}/docs") if parsed_args.host == "0.0.0.0": diff --git a/pyrit/identifiers/__init__.py b/pyrit/identifiers/__init__.py index a5eabdb0b7..a85c2cacab 100644 --- a/pyrit/identifiers/__init__.py +++ b/pyrit/identifiers/__init__.py @@ -8,8 +8,10 @@ build_seed_identifier, ) from pyrit.identifiers.class_name_utils import ( + REGISTRY_NAME_PATTERN, class_name_to_snake_case, snake_case_to_class_name, + validate_registry_name, ) from pyrit.identifiers.component_identifier import ComponentIdentifier, Identifiable, config_hash from pyrit.identifiers.evaluation_identifier import ( @@ -31,8 +33,10 @@ "compute_eval_hash", "EvaluationIdentifier", "Identifiable", + "REGISTRY_NAME_PATTERN", "ScorerEvaluationIdentifier", "snake_case_to_class_name", + "validate_registry_name", "config_hash", "IdentifierFilter", "IdentifierType", diff --git a/pyrit/identifiers/class_name_utils.py b/pyrit/identifiers/class_name_utils.py index f1a4d715a6..2bd903be36 100644 --- a/pyrit/identifiers/class_name_utils.py +++ b/pyrit/identifiers/class_name_utils.py @@ -10,6 +10,31 @@ import re +# Valid registry names: lowercase letter followed by up to 63 lowercase +# letters, digits, or underscores. This matches the output of +# class_name_to_snake_case and is safe for use as filesystem components. +REGISTRY_NAME_PATTERN = r"^[a-z][a-z0-9_]{0,63}$" + +_REGISTRY_NAME_RE = re.compile(REGISTRY_NAME_PATTERN) + + +def validate_registry_name(name: str) -> None: + """ + Validate that *name* is a legal registry name. + + Args: + name: The name to validate. + + Raises: + ValueError: If *name* does not match the required pattern. + """ + if not _REGISTRY_NAME_RE.match(name): + raise ValueError( + f"Invalid registry name '{name}'. " + f"Names must match {REGISTRY_NAME_PATTERN} " + "(lowercase ASCII, digits, underscores; 1-64 chars; must start with a letter)." + ) + def class_name_to_snake_case(class_name: str, *, suffix: str = "") -> str: """ diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py index b291840491..7d251a9cba 100644 --- a/pyrit/registry/class_registries/base_class_registry.py +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -310,6 +310,23 @@ def register( self._class_entries[name] = entry self._metadata_cache = None + def unregister(self, name: str) -> None: + """ + Remove a registered class from the registry. + + Args: + name: The registry name of the class to remove. + + Raises: + KeyError: If the name is not registered. + """ + self._ensure_discovered() + if name not in self._class_entries: + available = ", ".join(self.get_names()) + raise KeyError(f"'{name}' not found in registry. Available: {available}") + del self._class_entries[name] + self._metadata_cache = None + def create_instance(self, name: str, **kwargs: object) -> T: """ Create an instance of a registered class. diff --git a/pyrit/registry/class_registries/initializer_registry.py b/pyrit/registry/class_registries/initializer_registry.py index 23c6d3e6f9..81d8475ac7 100644 --- a/pyrit/registry/class_registries/initializer_registry.py +++ b/pyrit/registry/class_registries/initializer_registry.py @@ -11,12 +11,13 @@ from __future__ import annotations import importlib.util +import inspect import logging from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Optional -from pyrit.identifiers.class_name_utils import class_name_to_snake_case +from pyrit.identifiers.class_name_utils import class_name_to_snake_case, validate_registry_name from pyrit.registry.base import ClassRegistryEntry from pyrit.registry.class_registries.base_class_registry import ( BaseClassRegistry, @@ -81,8 +82,14 @@ def __init__(self, *, discovery_path: Optional[Path] = None, lazy_discovery: boo if self._discovery_path is None: raise ValueError("self._discovery_path is not initialized") + self._builtin_names: set[str] = set() super().__init__(lazy_discovery=lazy_discovery) + def is_builtin(self, name: str) -> bool: + """Return True if *name* was registered during built-in discovery.""" + self._ensure_discovered() + return name in self._builtin_names + def _discover(self) -> None: """Discover all initializers from the specified discovery path.""" discovery_path = self._discovery_path @@ -96,7 +103,7 @@ def _discover(self) -> None: from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer if discovery_path.is_file(): - self._process_file(file_path=discovery_path, base_class=PyRITInitializer) + self._process_file(file_path=discovery_path, base_class=PyRITInitializer, builtin=True) else: for _file_stem, _file_path, initializer_class in discover_in_directory( directory=discovery_path, @@ -105,18 +112,18 @@ def _discover(self) -> None: ): self._register_initializer( initializer_class=initializer_class, + builtin=True, ) - def _process_file(self, *, file_path: Path, base_class: type) -> None: + def _process_file(self, *, file_path: Path, base_class: type, builtin: bool = False) -> None: """ Process a Python file to extract initializer subclasses. Args: file_path: Path to the Python file to process. base_class: The PyRITInitializer base class. + builtin: Whether discovered classes should be marked as built-in. """ - import inspect - short_name = file_path.stem try: @@ -137,6 +144,7 @@ def _process_file(self, *, file_path: Path, base_class: type) -> None: ): self._register_initializer( initializer_class=attr, + builtin=builtin, ) except Exception as e: @@ -146,12 +154,14 @@ def _register_initializer( self, *, initializer_class: type[PyRITInitializer], + builtin: bool = False, ) -> None: """ Register an initializer class. Args: initializer_class: The initializer class to register. + builtin: Whether this is a built-in initializer. """ try: # Convert class name to snake_case for registry name @@ -168,6 +178,8 @@ def _register_initializer( entry = ClassEntry(registered_class=initializer_class) self._class_entries[registry_name] = entry + if builtin: + self._builtin_names.add(registry_name) logger.debug(f"Registered initializer: {registry_name} ({initializer_class.__name__})") except Exception as e: @@ -208,6 +220,124 @@ def _build_metadata(self, name: str, entry: ClassEntry[PyRITInitializer]) -> Ini required_env_vars=(), ) + def register_from_content(self, *, name: str, script_content: str) -> str: + """ + Register an initializer from uploaded Python source code. + + Writes *script_content* to a managed directory, loads it as a + module, discovers the first concrete ``PyRITInitializer`` + subclass, and registers it under *name*. + + Note: + Registrations are runtime-only and are not rediscovered on + server restart. Script files persist on disk as import + artifacts for the current process. + + Args: + name: Registry name for the new initializer. + script_content: Python source code that defines a + ``PyRITInitializer`` subclass. + + Returns: + The registry name that was registered. + + Raises: + ValueError: If the source cannot be compiled, does not + contain a valid initializer class, or *name* collides + with an existing entry. + """ + self._ensure_discovered() + + validate_registry_name(name) + + if name in self._class_entries: + raise ValueError(f"Initializer '{name}' is already registered. Unregister it first to replace it.") + + # Deferred: importing pyrit.setup triggers heavy __init__.py chain + from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + + # Write to a managed directory so importlib can load it + managed_dir = self._get_custom_scripts_dir() + script_path = managed_dir / f"{name}.py" + try: + script_path.write_text(script_content, encoding="utf-8") + except OSError as e: + raise ValueError(f"Failed to write initializer script: {e}") from e + + try: + spec = importlib.util.spec_from_file_location(f"custom_initializer.{name}", script_path) + if not spec or not spec.loader: + raise ValueError(f"Could not load initializer script for '{name}'") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + discovered: type[PyRITInitializer] | None = None + for attr_name in dir(module): + attr = getattr(module, attr_name) + if ( + inspect.isclass(attr) + and issubclass(attr, PyRITInitializer) + and attr is not PyRITInitializer + and not inspect.isabstract(attr) + and attr.__module__ == module.__name__ + ): + discovered = attr + break + + if discovered is None: + raise ValueError(f"Uploaded script for '{name}' does not contain a concrete PyRITInitializer subclass.") + except ValueError: + script_path.unlink(missing_ok=True) + raise + except Exception as e: + script_path.unlink(missing_ok=True) + raise ValueError(f"Failed to load initializer script '{name}': {e}") from e + + entry = ClassEntry(registered_class=discovered) + self._class_entries[name] = entry + self._metadata_cache = None + logger.info(f"Registered custom initializer: {name} ({discovered.__name__})") + return name + + def unregister_and_cleanup(self, name: str) -> None: + """ + Unregister a custom initializer and clean up its script file. + + Built-in initializers cannot be removed. For custom initializers + added via ``register_from_content``, the saved script file is + also deleted. + + Args: + name: The registry name to remove. + + Raises: + KeyError: If the name is not registered. + ValueError: If the name refers to a built-in initializer. + """ + self._ensure_discovered() + if name in self._builtin_names: + raise ValueError(f"Cannot remove built-in initializer '{name}'.") + self.unregister(name) + + script_path = self._get_custom_scripts_dir() / f"{name}.py" + script_path.unlink(missing_ok=True) + + @staticmethod + def _get_custom_scripts_dir() -> Path: + """ + Get the directory for storing uploaded custom initializer scripts. + + Returns: + Path to ``~/.pyrit/custom_initializers/``, created if needed. + """ + # Deferred: importing pyrit.common.path triggers pyrit __init__.py + from pyrit.common.path import CONFIGURATION_DIRECTORY_PATH + + custom_dir = CONFIGURATION_DIRECTORY_PATH / "custom_initializers" + custom_dir.mkdir(parents=True, exist_ok=True) + return custom_dir + @staticmethod def resolve_script_paths(*, script_paths: list[str]) -> list[Path]: """ diff --git a/pyrit/setup/configuration_loader.py b/pyrit/setup/configuration_loader.py index 46c262bded..0fe2db0a2e 100644 --- a/pyrit/setup/configuration_loader.py +++ b/pyrit/setup/configuration_loader.py @@ -133,6 +133,7 @@ class ConfigurationLoader(YamlLoadable): operation: Optional[str] = None scenario: Optional[Union[str, dict[str, Any]]] = None max_concurrent_scenario_runs: int = 3 + allow_custom_initializers: bool = False extensions: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: diff --git a/tests/unit/backend/test_initializer_service.py b/tests/unit/backend/test_initializer_service.py index 8c3c5977d0..6f52c5647a 100644 --- a/tests/unit/backend/test_initializer_service.py +++ b/tests/unit/backend/test_initializer_service.py @@ -28,6 +28,14 @@ def client() -> TestClient: return TestClient(app) +@pytest.fixture +def client_with_custom_initializers_enabled(): + """Create a test client with allow_custom_initializers enabled.""" + app.state.allow_custom_initializers = True + yield TestClient(app) + app.state.allow_custom_initializers = False + + @pytest.fixture(autouse=True) def clear_service_cache(): """Clear the initializer service singleton cache between tests.""" @@ -283,3 +291,216 @@ def test_get_initializer_returns_404_when_not_found(self, client: TestClient) -> response = client.get("/api/initializers/nonexistent") assert response.status_code == status.HTTP_404_NOT_FOUND + + +# ============================================================================ +# Service Register/Unregister Tests +# ============================================================================ + + +_SAMPLE_SCRIPT = """ +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +class MyCustomInitializer(PyRITInitializer): + \"\"\"A custom test initializer.\"\"\" + + async def initialize_async(self) -> None: + pass +""" + + +class TestInitializerServiceRegister: + """Tests for InitializerService.register_initializer_async.""" + + async def test_register_initializer_calls_registry(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + mock_registry = MagicMock() + mock_registry.register_from_content.return_value = "my_custom" + mock_registry.list_metadata.return_value = [ + _make_initializer_metadata(registry_name="my_custom", class_name="MyCustomInitializer") + ] + service._registry = mock_registry + + result = await service.register_initializer_async(name="my_custom", script_content=_SAMPLE_SCRIPT) + + mock_registry.register_from_content.assert_called_once_with(name="my_custom", script_content=_SAMPLE_SCRIPT) + assert result.initializer_name == "my_custom" + + async def test_register_initializer_propagates_value_error(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + mock_registry = MagicMock() + mock_registry.register_from_content.side_effect = ValueError("no classes found") + service._registry = mock_registry + + with pytest.raises(ValueError): + await service.register_initializer_async(name="bad", script_content="x = 1") + + +class TestInitializerServiceUnregister: + """Tests for InitializerService.unregister_initializer_async.""" + + async def test_unregister_initializer_calls_registry(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + mock_registry = MagicMock() + service._registry = mock_registry + + await service.unregister_initializer_async(initializer_name="target") + + mock_registry.unregister_and_cleanup.assert_called_once_with("target") + + async def test_unregister_initializer_propagates_key_error(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + mock_registry = MagicMock() + mock_registry.unregister_and_cleanup.side_effect = KeyError("not found") + service._registry = mock_registry + + with pytest.raises(KeyError): + await service.unregister_initializer_async(initializer_name="nonexistent") + + async def test_unregister_initializer_propagates_value_error_for_builtin(self) -> None: + with patch.object(InitializerService, "__init__", lambda self: None): + service = InitializerService() + mock_registry = MagicMock() + mock_registry.unregister_and_cleanup.side_effect = ValueError("Cannot remove built-in") + service._registry = mock_registry + + with pytest.raises(ValueError, match="Cannot remove built-in"): + await service.unregister_initializer_async(initializer_name="simple") + + +# ============================================================================ +# POST / DELETE Route Tests +# ============================================================================ + + +class TestRegisterInitializerRoute: + """Tests for POST /api/initializers route.""" + + def test_post_returns_403_when_custom_initializers_disabled(self, client: TestClient) -> None: + app.state.allow_custom_initializers = False + response = client.post("/api/initializers", json={"name": "test", "script_content": _SAMPLE_SCRIPT}) + assert response.status_code == status.HTTP_403_FORBIDDEN + assert "disabled" in response.json()["detail"].lower() + + @pytest.mark.parametrize("bad_name", ["../traversal", "UPPER", "has space", "1digit", ""]) + def test_post_returns_422_for_invalid_name( + self, client_with_custom_initializers_enabled: TestClient, bad_name: str + ) -> None: + response = client_with_custom_initializers_enabled.post( + "/api/initializers", json={"name": bad_name, "script_content": _SAMPLE_SCRIPT} + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_post_returns_201_with_registered_initializer( + self, client_with_custom_initializers_enabled: TestClient + ) -> None: + summary = RegisteredInitializer( + initializer_name="my_custom", + initializer_type="MyCustomInitializer", + description="Custom init", + ) + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.register_initializer_async = AsyncMock(return_value=summary) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.post( + "/api/initializers", json={"name": "my_custom", "script_content": _SAMPLE_SCRIPT} + ) + + assert response.status_code == status.HTTP_201_CREATED + data = response.json() + assert data["initializer_name"] == "my_custom" + + def test_post_returns_400_for_invalid_script(self, client_with_custom_initializers_enabled: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.register_initializer_async = AsyncMock(side_effect=ValueError("no classes")) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.post( + "/api/initializers", json={"name": "bad", "script_content": "x = 1"} + ) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + + def test_post_forwards_name_and_content(self, client_with_custom_initializers_enabled: TestClient) -> None: + summary = RegisteredInitializer( + initializer_name="my_init", + initializer_type="MyInit", + description="desc", + ) + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.register_initializer_async = AsyncMock(return_value=summary) + mock_get_service.return_value = mock_service + + client_with_custom_initializers_enabled.post( + "/api/initializers", json={"name": "my_init", "script_content": _SAMPLE_SCRIPT} + ) + + call_kwargs = mock_service.register_initializer_async.call_args.kwargs + assert call_kwargs["name"] == "my_init" + assert call_kwargs["script_content"] == _SAMPLE_SCRIPT + + def test_post_returns_409_for_duplicate_name(self, client_with_custom_initializers_enabled: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.register_initializer_async = AsyncMock( + side_effect=ValueError("Initializer 'dup' is already registered.") + ) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.post( + "/api/initializers", json={"name": "dup", "script_content": _SAMPLE_SCRIPT} + ) + + assert response.status_code == status.HTTP_409_CONFLICT + + +class TestUnregisterInitializerRoute: + """Tests for DELETE /api/initializers/{name} route.""" + + def test_delete_returns_403_when_custom_initializers_disabled(self, client: TestClient) -> None: + app.state.allow_custom_initializers = False + response = client.delete("/api/initializers/target") + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_delete_returns_204_on_success(self, client_with_custom_initializers_enabled: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.unregister_initializer_async = AsyncMock(return_value=None) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.delete("/api/initializers/target") + + assert response.status_code == status.HTTP_204_NO_CONTENT + + def test_delete_returns_404_when_not_found(self, client_with_custom_initializers_enabled: TestClient) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.unregister_initializer_async = AsyncMock(side_effect=KeyError("not found")) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.delete("/api/initializers/nonexistent") + + assert response.status_code == status.HTTP_404_NOT_FOUND + + def test_delete_returns_400_for_builtin_initializer( + self, client_with_custom_initializers_enabled: TestClient + ) -> None: + with patch("pyrit.backend.routes.initializers.get_initializer_service") as mock_get_service: + mock_service = MagicMock() + mock_service.unregister_initializer_async = AsyncMock( + side_effect=ValueError("Cannot remove built-in initializer 'simple'.") + ) + mock_get_service.return_value = mock_service + + response = client_with_custom_initializers_enabled.delete("/api/initializers/simple") + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "built-in" in response.json()["detail"].lower() diff --git a/tests/unit/cli/test_pyrit_backend.py b/tests/unit/cli/test_pyrit_backend.py index 7ea08197ab..a6d568aad8 100644 --- a/tests/unit/cli/test_pyrit_backend.py +++ b/tests/unit/cli/test_pyrit_backend.py @@ -55,3 +55,58 @@ async def test_initialize_and_run_passes_config_file_to_frontend_core(self) -> N mock_uvicorn_config.assert_called_once() mock_uvicorn_server.assert_called_once() mock_server.serve.assert_awaited_once() + + async def test_startup_warning_when_custom_initializers_enabled(self, capsys) -> None: + """Should print a warning when allow_custom_initializers is True.""" + parsed_args = pyrit_backend.parse_args(args=[]) + + with ( + patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, + patch("uvicorn.Config"), + patch("uvicorn.Server") as mock_uvicorn_server, + ): + mock_core = MagicMock() + mock_core.initialize_async = AsyncMock() + mock_core._initializer_configs = None + mock_core._allow_custom_initializers = True + mock_core._operator = None + mock_core._operation = None + mock_core._max_concurrent_scenario_runs = 3 + mock_core_class.return_value = mock_core + + mock_server = MagicMock() + mock_server.serve = AsyncMock() + mock_uvicorn_server.return_value = mock_server + + await pyrit_backend.initialize_and_run_async(parsed_args=parsed_args) + + captured = capsys.readouterr() + assert "WARNING" in captured.out + assert "allow_custom_initializers" in captured.out + + async def test_no_startup_warning_when_custom_initializers_disabled(self, capsys) -> None: + """Should not print custom initializer warning when disabled.""" + parsed_args = pyrit_backend.parse_args(args=[]) + + with ( + patch("pyrit.cli.pyrit_backend.frontend_core.FrontendCore") as mock_core_class, + patch("uvicorn.Config"), + patch("uvicorn.Server") as mock_uvicorn_server, + ): + mock_core = MagicMock() + mock_core.initialize_async = AsyncMock() + mock_core._initializer_configs = None + mock_core._allow_custom_initializers = False + mock_core._operator = None + mock_core._operation = None + mock_core._max_concurrent_scenario_runs = 3 + mock_core_class.return_value = mock_core + + mock_server = MagicMock() + mock_server.serve = AsyncMock() + mock_uvicorn_server.return_value = mock_server + + await pyrit_backend.initialize_and_run_async(parsed_args=parsed_args) + + captured = capsys.readouterr() + assert "allow_custom_initializers" not in captured.out diff --git a/tests/unit/identifiers/test_class_name_utils.py b/tests/unit/identifiers/test_class_name_utils.py index f660ee35e9..4a01909366 100644 --- a/tests/unit/identifiers/test_class_name_utils.py +++ b/tests/unit/identifiers/test_class_name_utils.py @@ -3,7 +3,11 @@ import pytest -from pyrit.identifiers.class_name_utils import class_name_to_snake_case, snake_case_to_class_name +from pyrit.identifiers.class_name_utils import ( + class_name_to_snake_case, + snake_case_to_class_name, + validate_registry_name, +) # --- class_name_to_snake_case --- @@ -94,3 +98,44 @@ def test_round_trip_snake_to_class(class_name): snake = class_name_to_snake_case(class_name) result = snake_case_to_class_name(snake) assert result == class_name + + +# --- validate_registry_name --- + + +@pytest.mark.parametrize( + "name", + ["simple", "my_custom", "a", "target", "load_default_datasets", "x" * 64], +) +def test_validate_registry_name_accepts_valid(name): + validate_registry_name(name) # should not raise + + +@pytest.mark.parametrize( + "name", + [ + "", # empty + "1starts_digit", # starts with digit + "_leading", # starts with underscore + "UPPER", # uppercase + "has-dash", # dash + "has.dot", # dot + "has space", # space + "../traversal", # path traversal + "x" * 65, # too long + ], +) +def test_validate_registry_name_rejects_invalid(name): + with pytest.raises(ValueError, match="Invalid registry name"): + validate_registry_name(name) + + +@pytest.mark.parametrize( + "class_name", + ["SimpleInitializer", "TargetInitializer", "LoadDefaultDatasets", "AIRTInitializer"], +) +def test_validate_registry_name_accepts_snake_case_output(class_name): + """Names produced by class_name_to_snake_case should always be valid registry names.""" + snake = class_name_to_snake_case(class_name, suffix="Initializer") + if snake: # skip empty (suffix == class_name edge case) + validate_registry_name(snake) diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py index 728872576e..380718b554 100644 --- a/tests/unit/registry/test_base.py +++ b/tests/unit/registry/test_base.py @@ -3,8 +3,10 @@ from dataclasses import dataclass, field +import pytest + from pyrit.registry.base import ClassRegistryEntry, _matches_filters -from pyrit.registry.class_registries.base_class_registry import ClassEntry +from pyrit.registry.class_registries.base_class_registry import BaseClassRegistry, ClassEntry @dataclass(frozen=True) @@ -14,6 +16,21 @@ class MetadataWithTags(ClassRegistryEntry): tags: tuple[str, ...] = field(kw_only=True) +class _TestRegistry(BaseClassRegistry[object, ClassRegistryEntry]): + """Minimal concrete registry for testing BaseClassRegistry methods.""" + + def _discover(self) -> None: + pass + + def _build_metadata(self, name: str, entry: ClassEntry[object]) -> ClassRegistryEntry: + return ClassRegistryEntry( + class_name=entry.registered_class.__name__, + class_module=entry.registered_class.__module__, + class_description=entry.get_description(fallback=""), + registry_name=name, + ) + + class TestDescriptionFromDocstring: """Tests for ClassRegistryEntry.description_from_docstring.""" @@ -209,3 +226,97 @@ def test_matches_filters_combined_include_and_exclude(self): ) is False ) + + +# ============================================================================ +# BaseClassRegistry.unregister Tests +# ============================================================================ + + +class _DummyClass: + """A dummy class for registry testing.""" + + +class _AnotherClass: + """Another dummy class.""" + + +def test_unregister_removes_entry(): + """Test that unregister removes a registered entry.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="dummy") + assert "dummy" in registry + + registry.unregister("dummy") + assert "dummy" not in registry + assert len(registry) == 0 + + +def test_unregister_raises_key_error_for_missing(): + """Test that unregister raises KeyError when name is not registered.""" + registry = _TestRegistry(lazy_discovery=True) + + with pytest.raises(KeyError, match="not_here"): + registry.unregister("not_here") + + +def test_unregister_key_error_lists_available_names(): + """Test that the KeyError message includes available names.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="alpha") + registry.register(_AnotherClass, name="beta") + + with pytest.raises(KeyError, match="alpha"): + registry.unregister("missing") + + +def test_unregister_invalidates_metadata_cache(): + """Test that unregister clears the metadata cache.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="cached") + + registry.list_metadata() + assert registry._metadata_cache is not None + + registry.unregister("cached") + assert registry._metadata_cache is None + + +def test_unregister_does_not_affect_other_entries(): + """Test that unregistering one entry leaves others intact.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="keep") + registry.register(_AnotherClass, name="remove") + + registry.unregister("remove") + + assert "keep" in registry + assert "remove" not in registry + assert registry.get_class("keep") is _DummyClass + + +def test_unregister_then_re_register(): + """Test that an entry can be re-registered after being unregistered.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="reuse") + + registry.unregister("reuse") + assert "reuse" not in registry + + registry.register(_AnotherClass, name="reuse") + assert registry.get_class("reuse") is _AnotherClass + + +def test_unregister_makes_metadata_reflect_removal(): + """Test that list_metadata no longer includes the unregistered entry.""" + registry = _TestRegistry(lazy_discovery=True) + registry.register(_DummyClass, name="alpha") + registry.register(_AnotherClass, name="beta") + + assert len(registry.list_metadata()) == 2 + + registry.unregister("alpha") + metadata = registry.list_metadata() + + assert len(metadata) == 1 + assert metadata[0].registry_name == "beta" diff --git a/tests/unit/registry/test_initializer_registry.py b/tests/unit/registry/test_initializer_registry.py index 991019bcef..6396d5caad 100644 --- a/tests/unit/registry/test_initializer_registry.py +++ b/tests/unit/registry/test_initializer_registry.py @@ -1,7 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import tempfile from pathlib import Path +from unittest.mock import patch + +import pytest from pyrit.registry.class_registries.base_class_registry import ClassEntry from pyrit.registry.class_registries.initializer_registry import ( @@ -11,6 +15,14 @@ from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer +@pytest.fixture +def lazy_registry() -> InitializerRegistry: + """Create an InitializerRegistry with lazy discovery already marked as complete.""" + registry = InitializerRegistry(lazy_discovery=True) + registry._discovered = True + return registry + + def test_initializer_registry_default_discovery_path(): """Test that InitializerRegistry sets the default discovery path when None is passed.""" registry = InitializerRegistry(lazy_discovery=True) @@ -41,3 +53,163 @@ async def initialize_async(self) -> None: assert metadata.class_description == "A fake initializer for testing." assert metadata.class_name == "FakeInitializer" assert metadata.registry_name == "fake" + + +# ============================================================================ +# register_from_content Tests +# ============================================================================ + +_VALID_SCRIPT = """ +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +class ScriptTestInitializer(PyRITInitializer): + \"\"\"A test initializer from script.\"\"\" + + async def initialize_async(self) -> None: + pass +""" + + +def test_register_from_content_discovers_class(lazy_registry): + """Test registering an initializer from uploaded content.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + mock_dir.return_value = Path(tempfile.mkdtemp()) + name = lazy_registry.register_from_content(name="my_custom", script_content=_VALID_SCRIPT) + + assert name == "my_custom" + assert "my_custom" in lazy_registry + + +@pytest.mark.parametrize("bad_name", ["../traversal", "UPPER", "has space", "1digit", ""]) +def test_register_from_content_rejects_invalid_name(lazy_registry, bad_name): + """Test that register_from_content rejects names that fail registry name validation.""" + with pytest.raises(ValueError, match="Invalid registry name"): + lazy_registry.register_from_content(name=bad_name, script_content=_VALID_SCRIPT) + + +def test_register_from_content_no_classes_raises_value_error(lazy_registry): + """Test that ValueError is raised when content has no initializer classes.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + mock_dir.return_value = Path(tempfile.mkdtemp()) + + with pytest.raises(ValueError, match="does not contain"): + lazy_registry.register_from_content(name="empty", script_content="x = 1\n") + + +def test_register_from_content_bad_syntax_raises_value_error(lazy_registry): + """Test that a script with syntax errors raises ValueError.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + tmp_dir = Path(tempfile.mkdtemp()) + mock_dir.return_value = tmp_dir + + with pytest.raises(ValueError, match="Failed to load"): + lazy_registry.register_from_content(name="bad", script_content="def bad syntax(:\n") + + +def test_register_from_content_bad_syntax_cleans_up_file(lazy_registry): + """Test that a failed import cleans up the script file.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + tmp_dir = Path(tempfile.mkdtemp()) + mock_dir.return_value = tmp_dir + + with pytest.raises(ValueError): + lazy_registry.register_from_content(name="orphan", script_content="def bad syntax(:\n") + + assert not (tmp_dir / "orphan.py").exists() + + +def test_register_from_content_no_class_cleans_up_file(lazy_registry): + """Test that missing initializer class cleans up the script file.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + tmp_dir = Path(tempfile.mkdtemp()) + mock_dir.return_value = tmp_dir + + with pytest.raises(ValueError, match="does not contain"): + lazy_registry.register_from_content(name="no_class", script_content="x = 1\n") + + assert not (tmp_dir / "no_class.py").exists() + + +def test_register_from_content_rejects_duplicate_name(lazy_registry): + """Test that registering over an existing name raises ValueError.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + mock_dir.return_value = Path(tempfile.mkdtemp()) + lazy_registry.register_from_content(name="dup", script_content=_VALID_SCRIPT) + + with pytest.raises(ValueError, match="already registered"): + lazy_registry.register_from_content(name="dup", script_content=_VALID_SCRIPT) + + +def test_register_from_content_ignores_imported_classes(lazy_registry): + """Test that imported base classes are not registered.""" + script = """ +from pyrit.setup.initializers.simple import SimpleInitializer +from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer + +class LocalOnlyInitializer(PyRITInitializer): + \"\"\"Local only.\"\"\" + + async def initialize_async(self) -> None: + pass +""" + + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + mock_dir.return_value = Path(tempfile.mkdtemp()) + name = lazy_registry.register_from_content(name="local_only", script_content=script) + + assert name == "local_only" + cls = lazy_registry.get_class("local_only") + assert cls.__name__ == "LocalOnlyInitializer" + + +def test_unregister_and_cleanup_removes_entry_and_file(lazy_registry): + """Test that unregister_and_cleanup removes both registry entry and script file.""" + tmp_dir = Path(tempfile.mkdtemp()) + with patch.object(InitializerRegistry, "_get_custom_scripts_dir", return_value=tmp_dir): + lazy_registry.register_from_content(name="cleanup_test", script_content=_VALID_SCRIPT) + assert "cleanup_test" in lazy_registry + assert (tmp_dir / "cleanup_test.py").exists() + + lazy_registry.unregister_and_cleanup("cleanup_test") + assert "cleanup_test" not in lazy_registry + assert not (tmp_dir / "cleanup_test.py").exists() + + +def test_unregister_and_cleanup_rejects_builtin(lazy_registry): + """Test that unregister_and_cleanup raises ValueError for built-in initializers.""" + + class BuiltinInit(PyRITInitializer): + async def initialize_async(self) -> None: + pass + + entry = ClassEntry(registered_class=BuiltinInit) + lazy_registry._class_entries["builtin_test"] = entry + lazy_registry._builtin_names.add("builtin_test") + + with pytest.raises(ValueError, match="Cannot remove built-in"): + lazy_registry.unregister_and_cleanup("builtin_test") + + assert "builtin_test" in lazy_registry + + +def test_is_builtin_returns_true_for_discovered_initializers(lazy_registry): + """Test that is_builtin correctly identifies built-in entries.""" + + class FakeInit(PyRITInitializer): + async def initialize_async(self) -> None: + pass + + entry = ClassEntry(registered_class=FakeInit) + lazy_registry._class_entries["fake"] = entry + lazy_registry._builtin_names.add("fake") + + assert lazy_registry.is_builtin("fake") is True + + +def test_is_builtin_returns_false_for_custom_initializers(lazy_registry): + """Test that is_builtin returns False for custom-registered entries.""" + with patch.object(InitializerRegistry, "_get_custom_scripts_dir") as mock_dir: + mock_dir.return_value = Path(tempfile.mkdtemp()) + lazy_registry.register_from_content(name="custom", script_content=_VALID_SCRIPT) + + assert lazy_registry.is_builtin("custom") is False