diff --git a/doc/code/targets/0_prompt_targets.md b/doc/code/targets/0_prompt_targets.md index e00983f769..25fc2b476c 100644 --- a/doc/code/targets/0_prompt_targets.md +++ b/doc/code/targets/0_prompt_targets.md @@ -25,7 +25,7 @@ A `PromptTarget` is a generic place to send a prompt. With PyRIT, the idea is th With some algorithms, you want to send a prompt, set a system prompt, and modify conversation history (including PAIR [@chao2023pair], TAP [@mehrotra2023tap], and flip attack [@li2024flipattack]). These algorithms require a target whose [`TargetCapabilities`](#target-capabilities) declare both `supports_multi_turn=True` and `supports_editable_history=True` — i.e. you can modify a conversation history. Consumers express this requirement via `CHAT_TARGET_REQUIREMENTS` and validate it against `target.configuration` at construction time. See [Target Capabilities](#target-capabilities) below for the full list of capabilities and how they compose into a `TargetConfiguration`. -Note: The previous `PromptChatTarget` class is **deprecated** as of v0.13.0 and will be removed in v0.15.0. Use `PromptTarget` directly with a `TargetConfiguration` declaring `supports_multi_turn=True` and `supports_editable_history=True`. See [Target Capabilities](#target-capabilities) for details. +Note: The previous `PromptChatTarget` class is **deprecated** as of v0.14.0 and will be removed in v0.16.0. Use `PromptTarget` directly with a `TargetConfiguration` declaring `supports_multi_turn=True` and `supports_editable_history=True`. See [Target Capabilities](#target-capabilities) for details. Here are some examples: @@ -107,6 +107,20 @@ target = MyHTTPTarget(custom_configuration=config, ...) The full implementation lives in [`pyrit/prompt_target/common/target_capabilities.py`](https://github.com/microsoft/PyRIT/blob/main/pyrit/prompt_target/common/target_capabilities.py) and [`pyrit/prompt_target/common/target_configuration.py`](https://github.com/microsoft/PyRIT/blob/main/pyrit/prompt_target/common/target_configuration.py). For runnable examples — inspecting capabilities on a real target, comparing known model profiles, and `ADAPT` vs `RAISE` in action — see [Target Capabilities](./6_1_target_capabilities.ipynb). +### Discovering live target capabilities + +Declared capabilities describe what a target *should* support. For deployments where actual behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models whose support drifts — you can probe what the target *actually* accepts at runtime: + +```python +from pyrit.prompt_target import discover_target_capabilities_async + +# Probe boolean capabilities and input modalities, returning a +# best-effort TargetCapabilities: +queried = await discover_target_capabilities_async(target=target) +``` + +Each probe sends a minimal request (bounded by `per_probe_timeout_s`, default 30s, with one retry on transient errors) and only marks a capability or modality as supported if the call returns cleanly. `discover_target_capabilities_async` returns a merged view: probed where possible, declared where probing is unavailable or out of scope. "Supported" here means *the request was accepted* — a target that silently ignores a system prompt or `response_format` directive is still reported as supporting it, so validate response content out of band when the distinction matters. 
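+
+When JSON support matters, you can confirm enforcement yourself by sending a JSON-hinted request and parsing the reply. A minimal sketch — it assumes `target` is one of the targets that translate PyRIT's JSON metadata hints into provider request fields (e.g. `OpenAIChatTarget`):
+
+```python
+import json
+import uuid
+
+from pyrit.models import Message, MessagePiece
+
+piece = MessagePiece(
+    role="user",
+    original_value='Reply with a JSON object: {"ok": true}',
+    original_value_data_type="text",
+    conversation_id=str(uuid.uuid4()),
+    prompt_metadata={"response_format": "json"},
+)
+responses = await target.send_prompt_async(message=Message([piece]))
+
+# Discovery only proves the request was accepted; parsing the reply checks
+# whether JSON mode was actually enforced.
+try:
+    json.loads(responses[0].message_pieces[0].converted_value)
+    print("JSON mode enforced")
+except (ValueError, IndexError):
+    print("request accepted, but the reply is not valid JSON")
+```
+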
This function is not safe to call concurrently with other operations on the same target instance: it temporarily mutates `target._configuration` and writes probe rows to memory (rows are tagged with `prompt_metadata["capability_probe"] == "1"` for filtering). See [Target Capabilities](./6_1_target_capabilities.ipynb) for runnable examples. + ## Multi-Modal Targets Like most of PyRIT, targets can be multi-modal. diff --git a/doc/code/targets/6_1_target_capabilities.ipynb b/doc/code/targets/6_1_target_capabilities.ipynb index 18c1902062..7630a560f6 100644 --- a/doc/code/targets/6_1_target_capabilities.ipynb +++ b/doc/code/targets/6_1_target_capabilities.ipynb @@ -53,13 +53,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "No new upgrade operations detected.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "No new upgrade operations detected.\n", "supports_multi_turn: True\n", "supports_editable_history: True\n", "supports_system_prompt: True\n", @@ -462,12 +456,297 @@ "try:\n", " no_editable_history.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY)\n", "except ValueError as exc:\n", - " print(exc)\n", - "# ---" + " print(exc)" + ] + }, + { + "cell_type": "markdown", + "id": "19", + "metadata": {}, + "source": [ + "## 7. Discovering live target capabilities\n", + "\n", + "Declared capabilities describe what a target *should* support. For deployments where the actual\n", + "behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models\n", + "whose support drifts over time — you can probe what the target *actually* accepts at runtime with\n", + "`discover_target_capabilities_async`. It runs both the boolean capability probes and the input\n", + "modality probes and returns a best-effort `TargetCapabilities`.\n", + "\n", + "Internally it walks each capability that has a registered probe (currently\n", + "`SYSTEM_PROMPT`, `MULTI_MESSAGE_PIECES`, `MULTI_TURN`, `JSON_OUTPUT`, `JSON_SCHEMA`), sends a\n", + "minimal request, and includes the capability in the result only if the call succeeds.\n", + "During probing the target's configuration is temporarily replaced with a permissive one so\n", + "`ensure_can_handle` does not short-circuit a probe for a capability the target declares as\n", + "unsupported. The original configuration is restored before the function returns. The same\n", + "treatment is applied to each input modality combination declared in\n", + "`capabilities.input_modalities`, sending a small payload built from optional `test_assets`.\n", + "\n", + "Each probe call is bounded by `per_probe_timeout_s` (default 30s) and is retried once on\n", + "transient errors before being declared failed. The returned `TargetCapabilities` is a merged\n", + "view: probed where possible, declared where probing is unavailable or out of scope.\n", + "\"Supported\" here means *the request was accepted* — a target that silently ignores a system\n", + "prompt or `response_format` directive will still be reported as supporting that capability.\n", + "\n", + "This function is **not safe to call concurrently** with other operations on the same target\n", + "instance: it temporarily mutates `target._configuration` and writes probe rows to\n", + "`target._memory`. 
Probe-written memory rows are tagged with\n", + "`prompt_metadata[\"capability_probe\"] == \"1\"` so consumers can filter them.\n", + "\n", + "Typical usage against a real endpoint:\n", + "\n", + "```python\n", + "from pyrit.prompt_target import discover_target_capabilities_async\n", + "\n", + "queried = await discover_target_capabilities_async(target=target)\n", + "print(queried)\n", + "```\n", + "\n", + "Below we mock the target's underlying transport (`_send_prompt_to_target_async`) so the notebook\n", + "stays self-contained — the result shape is the same as a live run. We mock the protected method\n", + "rather than `send_prompt_async` so the probe still exercises the real validation and memory\n", + "pipeline." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "queried capabilities:\n", + " - supports_editable_history\n", + " - supports_json_output\n", + " - supports_json_schema\n", + " - supports_multi_message_pieces\n", + " - supports_multi_turn\n", + " - supports_system_prompt\n" + ] + } + ], + "source": [ + "from unittest.mock import AsyncMock\n", + "\n", + "from pyrit.models import MessagePiece\n", + "from pyrit.prompt_target import discover_target_capabilities_async\n", + "\n", + "\n", + "def _ok_response():\n", + " return [\n", + " Message(\n", + " [\n", + " MessagePiece(\n", + " role=\"assistant\",\n", + " original_value=\"ok\",\n", + " original_value_data_type=\"text\",\n", + " conversation_id=\"probe\",\n", + " response_error=\"none\",\n", + " )\n", + " ]\n", + " )\n", + " ]\n", + "\n", + "\n", + "probe_target = OpenAIChatTarget(model_name=\"gpt-4o\", endpoint=\"https://example.invalid/\", api_key=\"sk-not-a-real-key\")\n", + "probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n", + "\n", + "queried = await discover_target_capabilities_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore\n", + "print(\"discover_target_capabilities_async result:\")\n", + "print(f\" supports_multi_turn: {queried.supports_multi_turn}\")\n", + "print(f\" supports_system_prompt: {queried.supports_system_prompt}\")\n", + "print(f\" supports_multi_message_pieces: {queried.supports_multi_message_pieces}\")\n", + "print(f\" supports_json_output: {queried.supports_json_output}\")\n", + "print(f\" supports_json_schema: {queried.supports_json_schema}\")\n", + "print(f\" input_modalities: {sorted(sorted(m) for m in queried.input_modalities)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "21", + "metadata": {}, + "source": [ + "To narrow the probe to specific capabilities (faster, fewer calls), pass `capabilities=`:\n", + "\n", + "```python\n", + "from pyrit.prompt_target.common.target_capabilities import CapabilityName\n", + "\n", + "queried = await discover_target_capabilities_async(\n", + " target=target,\n", + " capabilities=[CapabilityName.JSON_SCHEMA, CapabilityName.SYSTEM_PROMPT],\n", + ")\n", + "```\n", + "\n", + "Similarly, narrow the modality probe set with `test_modalities=` and override the\n", + "packaged default probe assets with `test_assets=`." + ] + }, + { + "cell_type": "markdown", + "id": "22", + "metadata": {}, + "source": [ + "### Discovering undeclared modalities\n", + "\n", + "By default `discover_target_capabilities_async` only probes modality combinations the target already\n", + "**declares** in `capabilities.input_modalities`. 
For an OpenAI-compatible endpoint that\n", + "claims text-only but might actually accept images, pass `test_modalities=` explicitly to\n", + "probe combinations beyond the declared baseline. Provide `test_assets=` as well if you need\n", + "to override the packaged defaults or probe a modality without one:\n", + "\n", + "```python\n", + "queried = await discover_target_capabilities_async(\n", + " target=target,\n", + " test_modalities={frozenset({\"text\"}), frozenset({\"text\", \"image_path\"})},\n", + " test_assets={\"image_path\": \"/path/to/test_image.png\"},\n", + ")\n", + "```\n", + "\n", + "Similarly, when narrowing the probe set with `capabilities=`, capabilities NOT in the\n", + "narrowed set are copied from the target's declared values rather than being reset to\n", + "`False` — narrowing controls *what is re-queried*, not what the returned dataclass\n", + "reports. This makes incremental probing safe:\n", + "\n", + "```python\n", + "# Re-query only JSON support; other declared flags pass through unchanged.\n", + "queried = await discover_target_capabilities_async(\n", + " target=target,\n", + " capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA},\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "23", + "metadata": {}, + "source": [ + "## 8. Applying probed capabilities back onto the target\n", + "\n", + "`discover_target_capabilities_async` is intentionally pure: it returns a `TargetCapabilities` without\n", + "mutating the target. That lets you inspect (or diff against the declared view, log, gate on\n", + "the result) before committing. Once you're satisfied, call `target.apply_capabilities(...)`\n", + "to install the probed view on the instance. The target's existing\n", + "`CapabilityHandlingPolicy` is preserved — policy expresses user intent (ADAPT vs RAISE),\n", + "which is independent of what the probe found.\n", + "\n", + "Why a two-step pattern rather than auto-apply? Probe results are an upper bound\n", + "(\"the request was accepted\"); a target that silently ignores a feature still passes its\n", + "probe. Keeping discovery separate from application lets callers diff, log, persist, or\n", + "reject the result before it affects subsequent sends.\n", + "\n", + "Below is the end-to-end pattern: construct a target whose declared capabilities are\n", + "pessimistic, discover what the endpoint actually accepts, diff the two views, then apply." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "declared (before probing):\n", + " supports_multi_turn: False\n", + " supports_system_prompt: False\n", + " supports_json_output: False\n", + "\n", + "probed (returned from discover_target_capabilities_async, target NOT yet updated):\n", + " supports_multi_turn: True\n", + " supports_system_prompt: True\n", + " supports_json_output: True\n", + " target.capabilities.supports_multi_turn (still declared): False\n", + "\n", + "flags probed True that were declared False: ['supports_multi_turn', 'supports_system_prompt', 'supports_multi_message_pieces', 'supports_json_output', 'supports_json_schema']\n", + "\n", + "after apply_capabilities:\n", + " supports_multi_turn: True\n", + " supports_system_prompt: True\n", + " supports_json_output: True\n", + " policy preserved: True\n", + "\n", + "CHAT_TARGET_REQUIREMENTS.validate now passes against the probed target\n" + ] + } + ], + "source": [ + "# Start with an instance that declares fewer capabilities than the endpoint actually has,\n", + "# e.g. a custom gateway whose support we're unsure about.\n", + "pessimistic_config = TargetConfiguration(\n", + " capabilities=TargetCapabilities(\n", + " supports_multi_turn=False,\n", + " supports_system_prompt=False,\n", + " supports_multi_message_pieces=False,\n", + " supports_json_output=False,\n", + " supports_json_schema=False,\n", + " # Editable history has no live probe and falls back to the declared value.\n", + " # Declare it True here so the probed view inherits it.\n", + " supports_editable_history=True,\n", + " ),\n", + ")\n", + "endpoint_target = OpenAIChatTarget(\n", + " model_name=\"custom-model\",\n", + " endpoint=\"https://example.invalid/\",\n", + " api_key=\"sk-not-a-real-key\",\n", + " custom_configuration=pessimistic_config,\n", + ")\n", + "endpoint_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign]\n", + "\n", + "print(\"declared (before probing):\")\n", + "print(f\" supports_multi_turn: {endpoint_target.capabilities.supports_multi_turn}\")\n", + "print(f\" supports_system_prompt: {endpoint_target.capabilities.supports_system_prompt}\")\n", + "print(f\" supports_json_output: {endpoint_target.capabilities.supports_json_output}\")\n", + "\n", + "# Step 1: discover. 
No mutation yet — `endpoint_target.capabilities` is unchanged.\n", + "probed_caps = await discover_target_capabilities_async(target=endpoint_target, per_probe_timeout_s=5.0) # type: ignore\n", + "\n", + "print(\"\\nprobed (returned from discover_target_capabilities_async, target NOT yet updated):\")\n", + "print(f\" supports_multi_turn: {probed_caps.supports_multi_turn}\")\n", + "print(f\" supports_system_prompt: {probed_caps.supports_system_prompt}\")\n", + "print(f\" supports_json_output: {probed_caps.supports_json_output}\")\n", + "print(f\" target.capabilities.supports_multi_turn (still declared): {endpoint_target.capabilities.supports_multi_turn}\")\n", + "\n", + "# Step 2: diff — see exactly what the probe upgraded.\n", + "declared = pessimistic_config.capabilities\n", + "upgraded = [\n", + " name\n", + " for name in (\n", + " \"supports_multi_turn\",\n", + " \"supports_system_prompt\",\n", + " \"supports_multi_message_pieces\",\n", + " \"supports_json_output\",\n", + " \"supports_json_schema\",\n", + " )\n", + " if getattr(probed_caps, name) and not getattr(declared, name)\n", + "]\n", + "print(f\"\\nflags probed True that were declared False: {upgraded}\")\n", + "\n", + "# Step 3: apply. Policy is preserved; the normalization pipeline is rebuilt.\n", + "original_policy = endpoint_target.configuration.policy\n", + "endpoint_target.apply_capabilities(capabilities=probed_caps)\n", + "\n", + "print(\"\\nafter apply_capabilities:\")\n", + "print(f\" supports_multi_turn: {endpoint_target.capabilities.supports_multi_turn}\")\n", + "print(f\" supports_system_prompt: {endpoint_target.capabilities.supports_system_prompt}\")\n", + "print(f\" supports_json_output: {endpoint_target.capabilities.supports_json_output}\")\n", + "print(f\" policy preserved: {endpoint_target.configuration.policy is original_policy}\")\n", + "\n", + "# Subsequent consumer checks now reflect the probed reality — for example, a chat-style\n", + "# requirement that would have failed against the pessimistic declaration now passes.\n", + "CHAT_TARGET_REQUIREMENTS.validate(target=endpoint_target)\n", + "print(\"\\nCHAT_TARGET_REQUIREMENTS.validate now passes against the probed target\")" ] } ], "metadata": { + "jupytext": { + "main_language": "python" + }, "language_info": { "codemirror_mode": { "name": "ipython", diff --git a/doc/code/targets/6_1_target_capabilities.py b/doc/code/targets/6_1_target_capabilities.py index 985374357b..2793083d12 100644 --- a/doc/code/targets/6_1_target_capabilities.py +++ b/doc/code/targets/6_1_target_capabilities.py @@ -5,7 +5,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.19.0 +# jupytext_version: 1.19.1 # --- # %% [markdown] @@ -245,4 +245,211 @@ no_editable_history.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY) except ValueError as exc: print(exc) -# --- + +# %% [markdown] +# ## 7. Discovering live target capabilities +# +# Declared capabilities describe what a target *should* support. For deployments where the actual +# behavior is uncertain — custom OpenAI-compatible endpoints, gateways that strip features, models +# whose support drifts over time — you can probe what the target *actually* accepts at runtime with +# `discover_target_capabilities_async`. It runs both the boolean capability probes and the input +# modality probes and returns a best-effort `TargetCapabilities`. 
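+#
+# Because the result is a plain `TargetCapabilities`, you can gate logic on it directly — a
+# small sketch (it assumes `queried` came from a call like the one executed further below):
+#
+# ```python
+# if queried.supports_json_schema:
+#     response_style = "json_schema"
+# elif queried.supports_json_output:
+#     response_style = "json"
+# else:
+#     response_style = "free_text"
+#
+# if frozenset({"text", "image_path"}) in queried.input_modalities:
+#     print("image + text prompts are accepted")
+# ```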
+# +# Internally it walks each capability that has a registered probe (currently +# `SYSTEM_PROMPT`, `MULTI_MESSAGE_PIECES`, `MULTI_TURN`, `JSON_OUTPUT`, `JSON_SCHEMA`), sends a +# minimal request, and includes the capability in the result only if the call succeeds. +# During probing the target's configuration is temporarily replaced with a permissive one so +# `ensure_can_handle` does not short-circuit a probe for a capability the target declares as +# unsupported. The original configuration is restored before the function returns. The same +# treatment is applied to each input modality combination declared in +# `capabilities.input_modalities`, sending a small payload built from optional `test_assets`. +# +# Each probe call is bounded by `per_probe_timeout_s` (default 30s) and is retried once on +# transient errors before being declared failed. The returned `TargetCapabilities` is a merged +# view: probed where possible, declared where probing is unavailable or out of scope. +# "Supported" here means *the request was accepted* — a target that silently ignores a system +# prompt or `response_format` directive will still be reported as supporting that capability. +# +# This function is **not safe to call concurrently** with other operations on the same target +# instance: it temporarily mutates `target._configuration` and writes probe rows to +# `target._memory`. Probe-written memory rows are tagged with +# `prompt_metadata["capability_probe"] == "1"` so consumers can filter them. +# +# Typical usage against a real endpoint: +# +# ```python +# from pyrit.prompt_target import discover_target_capabilities_async +# +# queried = await discover_target_capabilities_async(target=target) +# print(queried) +# ``` +# +# Below we mock the target's underlying transport (`_send_prompt_to_target_async`) so the notebook +# stays self-contained — the result shape is the same as a live run. We mock the protected method +# rather than `send_prompt_async` so the probe still exercises the real validation and memory +# pipeline. 
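+#
+# If you later aggregate memory and want to exclude probe traffic, the tag mentioned above makes
+# the filter a one-liner — a sketch over already-retrieved `MessagePiece` rows (fetching them
+# from your memory backend is not shown here):
+#
+# ```python
+# def is_capability_probe(piece) -> bool:
+#     return (piece.prompt_metadata or {}).get("capability_probe") == "1"
+#
+# non_probe_pieces = [p for p in all_pieces if not is_capability_probe(p)]
+# ```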
+ +# %% +from unittest.mock import AsyncMock + +from pyrit.models import MessagePiece +from pyrit.prompt_target import discover_target_capabilities_async + + +def _ok_response(): + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value="ok", + original_value_data_type="text", + conversation_id="probe", + response_error="none", + ) + ] + ) + ] + + +probe_target = OpenAIChatTarget(model_name="gpt-4o", endpoint="https://example.invalid/", api_key="sk-not-a-real-key") +probe_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + +queried = await discover_target_capabilities_async(target=probe_target, per_probe_timeout_s=5.0) # type: ignore +print("discover_target_capabilities_async result:") +print(f" supports_multi_turn: {queried.supports_multi_turn}") +print(f" supports_system_prompt: {queried.supports_system_prompt}") +print(f" supports_multi_message_pieces: {queried.supports_multi_message_pieces}") +print(f" supports_json_output: {queried.supports_json_output}") +print(f" supports_json_schema: {queried.supports_json_schema}") +print(f" input_modalities: {sorted(sorted(m) for m in queried.input_modalities)}") + +# %% [markdown] +# To narrow the probe to specific capabilities (faster, fewer calls), pass `capabilities=`: +# +# ```python +# from pyrit.prompt_target.common.target_capabilities import CapabilityName +# +# queried = await discover_target_capabilities_async( +# target=target, +# capabilities=[CapabilityName.JSON_SCHEMA, CapabilityName.SYSTEM_PROMPT], +# ) +# ``` +# +# Similarly, narrow the modality probe set with `test_modalities=` and override the +# packaged default probe assets with `test_assets=`. + +# %% [markdown] +# ### Discovering undeclared modalities +# +# By default `discover_target_capabilities_async` only probes modality combinations the target already +# **declares** in `capabilities.input_modalities`. For an OpenAI-compatible endpoint that +# claims text-only but might actually accept images, pass `test_modalities=` explicitly to +# probe combinations beyond the declared baseline. Provide `test_assets=` as well if you need +# to override the packaged defaults or probe a modality without one: +# +# ```python +# queried = await discover_target_capabilities_async( +# target=target, +# test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, +# test_assets={"image_path": "/path/to/test_image.png"}, +# ) +# ``` +# +# Similarly, when narrowing the probe set with `capabilities=`, capabilities NOT in the +# narrowed set are copied from the target's declared values rather than being reset to +# `False` — narrowing controls *what is re-queried*, not what the returned dataclass +# reports. This makes incremental probing safe: +# +# ```python +# # Re-query only JSON support; other declared flags pass through unchanged. +# queried = await discover_target_capabilities_async( +# target=target, +# capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA}, +# ) +# ``` + +# %% [markdown] +# ## 8. Applying probed capabilities back onto the target +# +# `discover_target_capabilities_async` is intentionally pure: it returns a `TargetCapabilities` without +# mutating the target. That lets you inspect (or diff against the declared view, log, gate on +# the result) before committing. Once you're satisfied, call `target.apply_capabilities(...)` +# to install the probed view on the instance. 
The target's existing +# `CapabilityHandlingPolicy` is preserved — policy expresses user intent (ADAPT vs RAISE), +# which is independent of what the probe found. +# +# Why a two-step pattern rather than auto-apply? Probe results are an upper bound +# ("the request was accepted"); a target that silently ignores a feature still passes its +# probe. Keeping discovery separate from application lets callers diff, log, persist, or +# reject the result before it affects subsequent sends. +# +# Below is the end-to-end pattern: construct a target whose declared capabilities are +# pessimistic, discover what the endpoint actually accepts, diff the two views, then apply. + +# %% +# Start with an instance that declares fewer capabilities than the endpoint actually has, +# e.g. a custom gateway whose support we're unsure about. +pessimistic_config = TargetConfiguration( + capabilities=TargetCapabilities( + supports_multi_turn=False, + supports_system_prompt=False, + supports_multi_message_pieces=False, + supports_json_output=False, + supports_json_schema=False, + # Editable history has no live probe and falls back to the declared value. + # Declare it True here so the probed view inherits it. + supports_editable_history=True, + ), +) +endpoint_target = OpenAIChatTarget( + model_name="custom-model", + endpoint="https://example.invalid/", + api_key="sk-not-a-real-key", + custom_configuration=pessimistic_config, +) +endpoint_target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + +print("declared (before probing):") +print(f" supports_multi_turn: {endpoint_target.capabilities.supports_multi_turn}") +print(f" supports_system_prompt: {endpoint_target.capabilities.supports_system_prompt}") +print(f" supports_json_output: {endpoint_target.capabilities.supports_json_output}") + +# Step 1: discover. No mutation yet — `endpoint_target.capabilities` is unchanged. +probed_caps = await discover_target_capabilities_async(target=endpoint_target, per_probe_timeout_s=5.0) # type: ignore + +print("\nprobed (returned from discover_target_capabilities_async, target NOT yet updated):") +print(f" supports_multi_turn: {probed_caps.supports_multi_turn}") +print(f" supports_system_prompt: {probed_caps.supports_system_prompt}") +print(f" supports_json_output: {probed_caps.supports_json_output}") +print(f" target.capabilities.supports_multi_turn (still declared): {endpoint_target.capabilities.supports_multi_turn}") + +# Step 2: diff — see exactly what the probe upgraded. +declared = pessimistic_config.capabilities +upgraded = [ + name + for name in ( + "supports_multi_turn", + "supports_system_prompt", + "supports_multi_message_pieces", + "supports_json_output", + "supports_json_schema", + ) + if getattr(probed_caps, name) and not getattr(declared, name) +] +print(f"\nflags probed True that were declared False: {upgraded}") + +# Step 3: apply. Policy is preserved; the normalization pipeline is rebuilt. 
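+# (apply_capabilities also refreshes the target's identifier, which is derived from the configuration.)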
+original_policy = endpoint_target.configuration.policy +endpoint_target.apply_capabilities(capabilities=probed_caps) + +print("\nafter apply_capabilities:") +print(f" supports_multi_turn: {endpoint_target.capabilities.supports_multi_turn}") +print(f" supports_system_prompt: {endpoint_target.capabilities.supports_system_prompt}") +print(f" supports_json_output: {endpoint_target.capabilities.supports_json_output}") +print(f" policy preserved: {endpoint_target.configuration.policy is original_policy}") + +# Subsequent consumer checks now reflect the probed reality — for example, a chat-style +# requirement that would have failed against the pessimistic declaration now passes. +CHAT_TARGET_REQUIREMENTS.validate(target=endpoint_target) +print("\nCHAT_TARGET_REQUIREMENTS.validate now passes against the probed target") diff --git a/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav b/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav new file mode 100644 index 0000000000..8dbde9545c Binary files /dev/null and b/pyrit/datasets/prompt_target/target_capabilities/probe_audio.wav differ diff --git a/pyrit/datasets/prompt_target/target_capabilities/probe_image.png b/pyrit/datasets/prompt_target/target_capabilities/probe_image.png new file mode 100644 index 0000000000..85dda3a6b1 Binary files /dev/null and b/pyrit/datasets/prompt_target/target_capabilities/probe_image.png differ diff --git a/pyrit/prompt_target/__init__.py b/pyrit/prompt_target/__init__.py index 489fe34900..82f897c156 100644 --- a/pyrit/prompt_target/__init__.py +++ b/pyrit/prompt_target/__init__.py @@ -14,6 +14,9 @@ from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline +from pyrit.prompt_target.common.discover_target_capabilities import ( + discover_target_capabilities_async, +) from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget from pyrit.prompt_target.common.prompt_target import PromptTarget from pyrit.prompt_target.common.target_capabilities import ( @@ -103,5 +106,6 @@ def __getattr__(name: str) -> object: "TargetRequirements", "UnsupportedCapabilityBehavior", "TextTarget", + "discover_target_capabilities_async", "WebSocketCopilotTarget", ] diff --git a/pyrit/prompt_target/common/discover_target_capabilities.py b/pyrit/prompt_target/common/discover_target_capabilities.py new file mode 100644 index 0000000000..859d07d428 --- /dev/null +++ b/pyrit/prompt_target/common/discover_target_capabilities.py @@ -0,0 +1,839 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Runtime capability and modality discovery for prompt targets. + +This module exposes two complementary probes: + +* :func:`_discover_capability_flags_async` discovers the boolean capability flags + defined on :class:`TargetCapabilities` (e.g. ``supports_system_prompt``, + ``supports_multi_message_pieces``). For each capability that has a probe + defined, a minimal request is sent to the target. If the request succeeds, + the capability is included in the returned set. Capabilities without a + registered probe fall back to the target's declared native support from + ``target.capabilities``. 
+* :func:`_discover_input_modalities_async` discovers which input modality + combinations a target actually supports by sending a minimal test request + for each combination declared in ``TargetCapabilities.input_modalities``. + +.. note:: + Output modality probing is intentionally not provided. Unlike inputs, + output modality is largely a property of the endpoint type (chat models + return text, image models return images, TTS endpoints return audio) + rather than something the caller controls per request, and there is no + PyRIT-level ``response_format=image`` style hint to assert against. + Eliciting non-text output reliably depends on prompt phrasing, costs + real compute per probe, and is prone to false negatives from safety + filters. Trust ``target.capabilities.output_modalities`` as declared. + +.. warning:: + These probes only verify that a request was *accepted*. They do not prove + that the endpoint enforced the feature, and the JSON probes are only + meaningful for targets that translate ``prompt_metadata`` JSON hints into + provider request fields. Treat the results as an upper bound on support and + validate response content separately when that distinction matters. +""" + +import asyncio +import json +import logging +import os +import uuid +from collections.abc import Awaitable, Callable, Iterable, Iterator +from contextlib import contextmanager +from dataclasses import replace + +from pyrit.common.path import DATASETS_PATH +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityName, + TargetCapabilities, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + +logger = logging.getLogger(__name__) + +# Per-call timeout (seconds) applied to every discovery request. Override per-call via +# the ``per_probe_timeout_s`` parameter on the public functions. +DEFAULT_PROBE_TIMEOUT_SECONDS: float = 30.0 +DEFAULT_PROBE_RETRY_BACKOFF_SECONDS: float = 0.1 +MAX_PROBE_RETRY_BACKOFF_SECONDS: float = 1.0 + +# Exceptions that are deterministic on the probe payload and will not become +# valid on a retry (malformed Message, type errors, missing attributes, etc.). +# These fail the probe immediately rather than wasting backoff time. +_NON_RETRYABLE_PROBE_EXCEPTIONS: tuple[type[BaseException], ...] = ( + ValueError, + TypeError, + AttributeError, +) + +# Marker stamped onto every MessagePiece this module writes to memory. Consumers +# that aggregate or display memory rows can filter probe-written rows by checking +# ``piece.prompt_metadata.get("capability_probe") == "1"``. Memory does not yet +# expose a delete-by-conversation-id API, so tagging is the cleanup mechanism. +PROBE_METADATA_KEY: str = "capability_probe" +PROBE_METADATA_VALUE: str = "1" + +_CapabilityProbe = Callable[[PromptTarget, float, int], Awaitable[bool]] + + +def _json_enforcing_target_types() -> tuple[type[PromptTarget], ...]: + """ + Return the tuple of target classes that translate ``prompt_metadata`` JSON + hints (``response_format``, ``json_schema``) into native provider request + fields. Used to suppress the "JSON probe is upper-bound only" debug log for + these targets and their subclasses. + + Imports are lazy to avoid a circular dependency at module load time and to + keep this concern entirely within the discovery module. 
Class objects (not + strings) are returned so renames are caught by import errors rather than + silently flipping the log behavior. + + Returns: + tuple[type[PromptTarget], ...]: The target classes that enforce JSON hints. + """ + from pyrit.prompt_target.openai.openai_chat_target import OpenAIChatTarget + from pyrit.prompt_target.openai.openai_response_target import OpenAIResponseTarget + + return (OpenAIChatTarget, OpenAIResponseTarget) + + +# Every text probe sends a text-only payload. Permissive overrides therefore +# always include this combination so that ``_validate_request``'s per-piece +# data-type check does not reject text probes against text-less targets. +_TEXT_MODALITY: frozenset[frozenset[PromptDataType]] = frozenset({frozenset({"text"})}) + +# Packaged fallback assets for non-text modality discovery. +_TARGET_CAPABILITIES_DATASET_PATH = DATASETS_PATH / "prompt_target" / "target_capabilities" + + +@contextmanager +def _permissive_configuration( + *, + target: PromptTarget, + extra_input_modalities: Iterable[frozenset[PromptDataType]] | None = None, +) -> Iterator[None]: + """ + Temporarily replace ``target``'s configuration with one that declares every + boolean capability as natively supported. + + This bypasses :meth:`PromptTarget._validate_request`, which would otherwise + short-circuit probes for capabilities the target declares as unsupported + before any API call is made. The original configuration is restored on exit. + + Args: + target (PromptTarget): The target whose configuration is temporarily replaced. + extra_input_modalities (Iterable[frozenset[PromptDataType]] | None): + Additional modality combinations to include in ``input_modalities`` + during the override. Used by modality probes so that + ``_validate_request``'s per-piece data-type check does not reject + combinations the caller asked us to test but the target does not + yet declare. Defaults to None. + + Yields: + None: Control returns to the ``with`` block while the permissive + configuration is in effect. + """ + original = target.configuration + merged_modalities = original.capabilities.input_modalities | _TEXT_MODALITY + if extra_input_modalities is not None: + merged_modalities = frozenset(merged_modalities | frozenset(extra_input_modalities)) + permissive_caps = replace( + original.capabilities, + supports_multi_turn=True, + supports_multi_message_pieces=True, + supports_json_schema=True, + supports_json_output=True, + supports_editable_history=True, + supports_system_prompt=True, + input_modalities=merged_modalities, + ) + # Rebuild a fresh configuration from the instance's native capabilities so + # probes bypass preflight validation without inheriting ADAPT policy or + # custom normalizer overrides from the target's runtime configuration. + probe_configuration = TargetConfiguration(capabilities=permissive_caps) + target._configuration = probe_configuration + try: + yield + finally: + target._configuration = original + + +def _new_conversation_id() -> str: + """ + Generate a unique conversation id for a single capability probe. + + Returns: + str: A conversation id of the form ``"capability-probe-"``. 
+ """ + return f"capability-probe-{uuid.uuid4()}" + + +def _probe_metadata(extra: dict[str, str | int] | None = None) -> dict[str, str | int]: + """Return a fresh ``prompt_metadata`` dict tagged as a capability probe.""" + metadata: dict[str, str | int] = {PROBE_METADATA_KEY: PROBE_METADATA_VALUE} + if extra: + metadata.update(extra) + return metadata + + +def _user_text_piece(*, value: str, conversation_id: str) -> MessagePiece: + """ + Build a single user-role text :class:`MessagePiece` for use in a probe. + + The piece's ``prompt_metadata`` is tagged with :data:`PROBE_METADATA_KEY` + so that consumers aggregating memory can filter out probe-written rows. + + Args: + value (str): The text payload to send. + conversation_id (str): The conversation id to attach to the piece. + + Returns: + MessagePiece: A user-role text piece bound to ``conversation_id``. + """ + return MessagePiece( + role="user", + original_value=value, + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + + +async def _send_and_check_async( + *, + target: PromptTarget, + message: Message, + timeout_s: float, + retries: int = 1, + label: str = "Capability probe", +) -> bool: + """ + Send ``message`` and report whether the call succeeded cleanly. + + Each attempt is bounded by ``timeout_s``. Transient errors (timeouts, + connection/OS errors) trigger up to ``retries`` retries with a short + exponential backoff. Deterministic errors that will not become valid on + a retry (``ValueError``, ``TypeError``, ``AttributeError`` — typically + from message validation or programmer error in a probe payload) fail + the probe immediately. An explicit error response from the target is + treated as deterministic and never retried. + + Args: + target (PromptTarget): The target to send the probe message to. + message (Message): The probe message to send. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only transient errors are retried; non-retryable errors and + non-error responses are final. Retry attempts use exponential + backoff starting at :data:`DEFAULT_PROBE_RETRY_BACKOFF_SECONDS`. + Defaults to 1. + label (str): Short label used in log messages. Defaults to + ``"Capability probe"``. + + Returns: + bool: ``True`` iff the call returned without raising and every response + piece reported ``response_error == "none"``; ``False`` otherwise. + Any other ``response_error`` value (``"blocked"``, ``"processing"``, + ``"empty"``, ``"unknown"``) is treated as failure. An empty response + list (or responses with no message pieces) is also treated as a failure. + """ + attempts = max(1, retries + 1) + last_exc: Exception | None = None + for attempt in range(attempts): + try: + responses = await asyncio.wait_for(target.send_prompt_async(message=message), timeout=timeout_s) + except asyncio.TimeoutError: + last_exc = TimeoutError(f"timed out after {timeout_s}s") + logger.debug("%s timed out (attempt %d/%d)", label, attempt + 1, attempts) + if attempt + 1 < attempts: + await _sleep_before_retry_async(attempt=attempt) + continue + except _NON_RETRYABLE_PROBE_EXCEPTIONS as exc: + # Deterministic on the probe payload — retrying will not help. 
+ logger.debug("%s failed with non-retryable error: %s", label, exc) + return False + except Exception as exc: + last_exc = exc + logger.debug("%s failed (attempt %d/%d): %s", label, attempt + 1, attempts, exc) + if attempt + 1 < attempts: + await _sleep_before_retry_async(attempt=attempt) + continue + + if not responses or any(not r.message_pieces for r in responses): + logger.debug("%s returned an empty response; treating as failure", label) + return False + for response in responses: + for piece in response.message_pieces: + if piece.response_error != "none": + logger.debug("%s returned error response: %s", label, piece.converted_value) + return False + return True + + logger.info("%s exhausted %d attempt(s); last error: %s", label, attempts, last_exc) + return False + + +def _retry_backoff_seconds(*, attempt: int) -> float: + """Return the exponential backoff delay for a retry attempt.""" + return min(DEFAULT_PROBE_RETRY_BACKOFF_SECONDS * (2**attempt), MAX_PROBE_RETRY_BACKOFF_SECONDS) + + +async def _sleep_before_retry_async(*, attempt: int) -> None: + """Sleep for the retry backoff associated with ``attempt``.""" + await asyncio.sleep(_retry_backoff_seconds(attempt=attempt)) + + +async def _probe_system_prompt_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a system prompt followed by a user message. + + Writes a system-role :class:`MessagePiece` directly to ``target._memory`` + rather than calling :meth:`pyrit.prompt_target.PromptChatTarget.set_system_prompt` + (which is only defined on ``PromptChatTarget`` subclasses anyway). + ``set_system_prompt`` can be overridden by subclasses (e.g. mocks) to do + nothing or to perform extra work, which would mask whether the underlying + API actually accepts a system message. A direct memory write also works + uniformly for plain ``PromptTarget`` subclasses that have no + ``set_system_prompt`` method, and guarantees the probe sees the same + multi-piece, system-then-user payload the target's wire layer would see + via the standard pipeline. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the system + user request succeeded; ``False`` otherwise. + """ + conversation_id = _new_conversation_id() + system_piece = MessagePiece( + role="system", + original_value="You are a helpful assistant.", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + try: + target._memory.add_message_to_memory(request=Message([system_piece])) + except Exception as exc: + logger.debug("System-prompt probe could not seed system message: %s", exc) + return False + user_piece = _user_text_piece(value="hi", conversation_id=conversation_id) + return await _send_and_check_async( + target=target, + message=Message([user_piece]), + timeout_s=timeout_s, + retries=retries, + label="System-prompt probe", + ) + + +async def _probe_multi_message_pieces_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a single message containing multiple pieces. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. 
+ Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the multi-piece request succeeded; ``False`` otherwise. + """ + conversation_id = _new_conversation_id() + pieces = [ + _user_text_piece(value="part one", conversation_id=conversation_id), + _user_text_piece(value="part two", conversation_id=conversation_id), + ] + return await _send_and_check_async( + target=target, + message=Message(pieces), + timeout_s=timeout_s, + retries=retries, + label="Multi-message-pieces probe", + ) + + +async def _probe_multi_turn_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request that includes prior conversation history. + + ``PromptTarget.send_prompt_async`` reads conversation history from memory but + does not write to it (persistence normally happens in the orchestrator + layer). To exercise true multi-turn behavior, this probe: + + 1. Sends an initial user message. + 2. Persists that user message and a synthetic assistant reply directly to + the target's memory under the same ``conversation_id``. + 3. Sends a second user message; ``send_prompt_async`` then fetches the + 2-message history and the target receives a real 3-message + multi-turn payload. + + The synthetic assistant reply's content is irrelevant — we are testing + whether the target's API accepts a multi-turn payload, not whether the + model recalls anything. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if both turns succeeded; ``False`` if either turn failed. + """ + conversation_id = _new_conversation_id() + first = _user_text_piece(value="My favorite color is blue.", conversation_id=conversation_id) + if not await _send_and_check_async( + target=target, message=Message([first]), timeout_s=timeout_s, retries=retries, label="Multi-turn probe (turn 1)" + ): + return False + + # Seed memory so the second send sees real prior history. + try: + target._memory.add_message_to_memory(request=Message([first])) + assistant_reply = MessagePiece( + role="assistant", + original_value="Got it.", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ).to_message() + target._memory.add_message_to_memory(request=assistant_reply) + except Exception as exc: + logger.debug("Multi-turn probe could not seed conversation history: %s", exc) + return False + + second = _user_text_piece(value="What did I just tell you?", conversation_id=conversation_id) + return await _send_and_check_async( + target=target, + message=Message([second]), + timeout_s=timeout_s, + retries=retries, + label="Multi-turn probe (turn 2)", + ) + + +async def _probe_json_output_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request asking for JSON-mode output. + + This probe is only meaningful for targets that translate PyRIT's JSON + metadata hints into native provider request fields. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. 
+ + Returns: + bool: ``True`` if the JSON-mode request succeeded; ``False`` otherwise. + """ + conversation_id = _new_conversation_id() + piece = MessagePiece( + role="user", + original_value='Respond with a JSON object: {"ok": true}.', + original_value_data_type="text", + conversation_id=conversation_id, + # This only becomes a real JSON-mode request on targets that honor + # PyRIT's JSON metadata contract when building the provider payload. + prompt_metadata=_probe_metadata({"response_format": "json"}), + ) + return await _send_and_check_async( + target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-output probe" + ) + + +async def _probe_json_schema_async(target: PromptTarget, timeout_s: float, retries: int = 1) -> bool: + """ + Probe whether ``target`` accepts a request constrained by a JSON schema. + + This probe is only meaningful for targets that translate PyRIT's JSON + metadata hints into native provider request fields. + + Args: + target (PromptTarget): The target to probe. + timeout_s (float): Per-attempt timeout in seconds. + retries (int): Number of additional attempts after the first failure. + Only exceptions/timeouts are retried; an explicit error response + is final. Defaults to 1. + + Returns: + bool: ``True`` if the schema-constrained request succeeded; ``False`` otherwise. + """ + schema = { + "type": "object", + "properties": {"ok": {"type": "boolean"}}, + "required": ["ok"], + "additionalProperties": False, + } + conversation_id = _new_conversation_id() + piece = MessagePiece( + role="user", + original_value='Respond with a JSON object matching the schema: {"ok": true}.', + original_value_data_type="text", + conversation_id=conversation_id, + # As above, this probe is only strong for targets that map these + # metadata keys to native JSON-schema request parameters. + prompt_metadata=_probe_metadata( + { + "response_format": "json", + "json_schema": json.dumps(schema), + } + ), + ) + return await _send_and_check_async( + target=target, message=Message([piece]), timeout_s=timeout_s, retries=retries, label="JSON-schema probe" + ) + + +# Registry of capabilities that can be queried via a live API call. +# Capabilities not present here fall back to the target's declared support. +_CAPABILITY_PROBES: dict[CapabilityName, _CapabilityProbe] = { + CapabilityName.SYSTEM_PROMPT: _probe_system_prompt_async, + CapabilityName.MULTI_MESSAGE_PIECES: _probe_multi_message_pieces_async, + CapabilityName.MULTI_TURN: _probe_multi_turn_async, + CapabilityName.JSON_OUTPUT: _probe_json_output_async, + CapabilityName.JSON_SCHEMA: _probe_json_schema_async, +} + + +async def _discover_capability_flags_async( + *, + target: PromptTarget, + capabilities: Iterable[CapabilityName] | None = None, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + retries: int = 1, +) -> set[CapabilityName]: + """ + Probe which capabilities ``target`` accepts. + + Registered capabilities are checked with live requests. Capabilities + without a live probe fall back to declared native support. + + Args: + target (PromptTarget): The target to probe. + capabilities (Iterable[CapabilityName] | None): Capabilities to check. + Defaults to every member of :class:`CapabilityName`. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. Defaults to + :data:`DEFAULT_PROBE_TIMEOUT_SECONDS`. + retries (int): Number of additional attempts after the first failure + for each probe. 
Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + + Returns: + set[CapabilityName]: The capabilities confirmed to work against the target. + """ + capabilities_to_check: list[CapabilityName] = ( + list(capabilities) if capabilities is not None else list(CapabilityName) + ) + + queried: set[CapabilityName] = set() + json_capabilities = {CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA} + queried_json_capabilities: set[CapabilityName] = set() + with _permissive_configuration(target=target): + for capability in capabilities_to_check: + probe = _CAPABILITY_PROBES.get(capability) + if probe is None: + # Capabilities without a probe are handled after the permissive + # override is removed so we can read the target's native flags. + continue + + try: + # "Supported" means the request was accepted. A target can + # still ignore the feature semantics after accepting the call. + if await probe(target, per_probe_timeout_s, retries): + queried.add(capability) + if capability in json_capabilities: + queried_json_capabilities.add(capability) + except Exception as exc: + logger.debug("Probe for %s raised: %s", capability.value, exc) + + # JSON probes only verify the target accepted the request, not that the + # target translated the JSON metadata into provider request fields. Emit + # a single summary line when probes succeeded against a target that does + # not enforce JSON hints, so the result is treated as an upper bound. + # ``isinstance`` covers user-defined subclasses of enforcing targets. + if queried_json_capabilities and not isinstance(target, _json_enforcing_target_types()): + logger.debug( + "JSON capability probes %s succeeded for %s, but this target does not translate " + "prompt_metadata JSON hints into provider request fields; treat the result as upper-bound support only.", + sorted(c.value for c in queried_json_capabilities), + type(target).__name__, + ) + + # Read unprobed capabilities from target.capabilities, not + # target.configuration, so ADAPTed behavior is not reported as native + # support. + for capability in capabilities_to_check: + if capability not in _CAPABILITY_PROBES and target.capabilities.includes(capability=capability): + queried.add(capability) + + return queried + + +# --------------------------------------------------------------------------- +# Modality query +# --------------------------------------------------------------------------- + + +# Default mapping of non-text modalities to packaged probe assets. Callers can +# override via the ``test_assets`` parameter of +# :func:`_discover_input_modalities_async`. Modalities whose assets do not exist +# on disk are skipped (logged and excluded from the result). +DEFAULT_TEST_ASSETS: dict[PromptDataType, str] = { + "audio_path": str(_TARGET_CAPABILITIES_DATASET_PATH / "probe_audio.wav"), + "image_path": str(_TARGET_CAPABILITIES_DATASET_PATH / "probe_image.png"), +} + + +async def _discover_input_modalities_async( + *, + target: PromptTarget, + test_modalities: set[frozenset[PromptDataType]] | None = None, + test_assets: dict[PromptDataType, str] | None = None, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + retries: int = 1, +) -> set[frozenset[PromptDataType]]: + """ + Probe which input modality combinations ``target`` accepts. + + Each modality combination is checked with a minimal request built from the + supplied test assets. + + Args: + target (PromptTarget): The target to probe. 
+ test_modalities (set[frozenset[PromptDataType]] | None): Specific + modality combinations to test. Defaults to the combinations + declared in ``target.capabilities.input_modalities``. + test_assets (dict[PromptDataType, str] | None): Mapping from + non-text modality to a file path used as the probe payload. + Defaults to :data:`DEFAULT_TEST_ASSETS`. Combinations whose + non-text assets are missing on disk are skipped. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. Defaults to + :data:`DEFAULT_PROBE_TIMEOUT_SECONDS`. + retries (int): Number of additional attempts after the first failure + for each probe. Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + + Returns: + set[frozenset[PromptDataType]]: The modality combinations confirmed + to work against the target. + """ + if test_modalities is None: + declared = target.capabilities.input_modalities + test_modalities = set(declared) + elif not test_modalities: + logger.info("_discover_input_modalities_async called with an empty test_modalities set; nothing to probe.") + return set() + + assets = test_assets if test_assets is not None else DEFAULT_TEST_ASSETS + + queried: set[frozenset[PromptDataType]] = set() + with _permissive_configuration(target=target, extra_input_modalities=test_modalities): + for combination in test_modalities: + try: + message = _create_test_message(modalities=combination, test_assets=assets) + except FileNotFoundError as exc: + # Skip combinations we cannot construct a valid probe payload for. + logger.info("Skipping modality %s: %s", combination, exc) + continue + except ValueError as exc: + logger.info("Skipping modality %s: %s", combination, exc) + continue + + # "Supported" means the request was accepted. A target may still + # ignore the non-text payload after accepting it. + if await _send_and_check_async( + target=target, + message=message, + timeout_s=per_probe_timeout_s, + retries=retries, + label=f"Modality probe {sorted(combination)}", + ): + queried.add(combination) + + return queried + + +async def discover_target_capabilities_async( + *, + target: PromptTarget, + per_probe_timeout_s: float = DEFAULT_PROBE_TIMEOUT_SECONDS, + test_modalities: set[frozenset[PromptDataType]] | None = None, + test_assets: dict[PromptDataType, str] | None = None, + capabilities: Iterable[CapabilityName] | None = None, + retries: int = 1, + apply: bool = False, +) -> TargetCapabilities: + """ + Probe both the boolean capability flags and the input modality combinations + that ``target`` accepts, and return a merged best-effort + :class:`TargetCapabilities`. + + Boolean capabilities with a registered probe are checked with live + requests; capabilities without a probe fall back to the target's + declared native support. Each input modality combination is checked + with a minimal request built from the supplied test assets. + "Supported" means the request was accepted — a target that silently + ignores a feature is still reported as supporting it. + + Args: + target (PromptTarget): The target to probe. + per_probe_timeout_s (float): Per-attempt timeout (seconds) applied to + each probe request. + test_modalities (set[frozenset[PromptDataType]] | None): Specific + modality combinations to probe. Defaults to the target's declared + ``input_modalities``. Combinations not listed here fall back to + the target's declared support. 
+ test_assets (dict[PromptDataType, str] | None): Mapping from non-text + modality to a file path used as the probe payload. Defaults to + :data:`DEFAULT_TEST_ASSETS`. Combinations whose non-text assets + are missing on disk are skipped. + capabilities (Iterable[CapabilityName] | None): Capabilities to probe. + Defaults to every member of :class:`CapabilityName`. Capabilities + not listed here fall back to the target's declared support. + retries (int): Number of additional attempts after the first failure + for each probe. Only exceptions/timeouts are retried; an explicit + error response is final. Set to ``0`` to disable retries. + Defaults to 1. + apply (bool): If True, install the discovered capabilities on ``target`` + via :meth:`PromptTarget.apply_capabilities` before returning. + Probe results are an upper bound (the request was accepted, not + necessarily honored), so leave this False when you want to inspect + or diff the result before committing to it. Defaults to False. + + Returns: + TargetCapabilities: A merged capability view: probed where possible, + declared where probing is unavailable or out of scope. + """ + capabilities_to_probe = list(capabilities) if capabilities is not None else None + + queried_caps = await _discover_capability_flags_async( + target=target, + capabilities=capabilities_to_probe, + per_probe_timeout_s=per_probe_timeout_s, + retries=retries, + ) + queried_modalities = await _discover_input_modalities_async( + target=target, + test_modalities=test_modalities, + test_assets=test_assets, + per_probe_timeout_s=per_probe_timeout_s, + retries=retries, + ) + + declared = target.capabilities + # If the caller narrows the capability set, leave the rest at their + # declared values instead of silently forcing them to False. + probed: set[CapabilityName] = ( + set(capabilities_to_probe) if capabilities_to_probe is not None else set(CapabilityName) + ) + + def _resolve(name: CapabilityName) -> bool: + if name in probed: + return name in queried_caps + return bool(getattr(declared, name.value)) + + resolved_multi_turn = _resolve(CapabilityName.MULTI_TURN) + # Editable history is only meaningful if multi-turn probing/declaration + # also resolved to True. + resolved_editable_history = declared.supports_editable_history and resolved_multi_turn + if test_modalities is None: + # Mirror the boolean fallback: combinations the probe could not confirm + # fall back to the target's declared support rather than being silently + # dropped (e.g. on transient network failure). + resolved_input_modalities = frozenset(queried_modalities | declared.input_modalities) + else: + resolved_input_modalities = frozenset( + queried_modalities | (declared.input_modalities - frozenset(test_modalities)) + ) + + resolved = TargetCapabilities( + supports_multi_turn=resolved_multi_turn, + supports_multi_message_pieces=_resolve(CapabilityName.MULTI_MESSAGE_PIECES), + supports_json_schema=_resolve(CapabilityName.JSON_SCHEMA), + supports_json_output=_resolve(CapabilityName.JSON_OUTPUT), + supports_editable_history=resolved_editable_history, + supports_system_prompt=_resolve(CapabilityName.SYSTEM_PROMPT), + input_modalities=resolved_input_modalities, + # Output modalities are still declarative because probing them would + # require target-specific response inspection. 
+ output_modalities=declared.output_modalities, + ) + + if apply: + target.apply_capabilities(capabilities=resolved) + + return resolved + + +def _create_test_message( + *, + modalities: frozenset[PromptDataType], + test_assets: dict[PromptDataType, str], +) -> Message: + """ + Build a minimal :class:`Message` that exercises ``modalities``. + + Args: + modalities (frozenset[PromptDataType]): The modalities to include. + test_assets (dict[PromptDataType, str]): Mapping from non-text + modality to a file path used for the probe. + + Returns: + Message: A message containing one piece per modality. + + Raises: + FileNotFoundError: If a configured asset path does not exist. + ValueError: If a non-text modality has no configured asset. + """ + conversation_id = f"modality-probe-{uuid.uuid4()}" + pieces: list[MessagePiece] = [] + + for modality in modalities: + if modality == "text": + pieces.append( + MessagePiece( + role="user", + original_value="test", + original_value_data_type="text", + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + ) + continue + + asset_path = test_assets.get(modality) + if asset_path is None: + raise ValueError(f"No test asset configured for modality '{modality}'.") + if not os.path.isfile(asset_path): + raise FileNotFoundError(f"Test asset for modality '{modality}' not found at: {asset_path}") + + pieces.append( + MessagePiece( + role="user", + original_value=asset_path, + original_value_data_type=modality, + conversation_id=conversation_id, + prompt_metadata=_probe_metadata(), + ) + ) + + return Message(pieces) diff --git a/pyrit/prompt_target/common/prompt_target.py b/pyrit/prompt_target/common/prompt_target.py index a4cf2a8a96..42f2688b01 100644 --- a/pyrit/prompt_target/common/prompt_target.py +++ b/pyrit/prompt_target/common/prompt_target.py @@ -406,6 +406,29 @@ def capabilities(self) -> TargetCapabilities: """ return self._configuration.capabilities + def apply_capabilities(self, *, capabilities: TargetCapabilities) -> None: + """ + Replace this target's capabilities, preserving the existing handling policy. + The normalization pipeline is rebuilt from the input capabilities and the + current policy. + + Policy is preserved because it expresses user intent (ADAPT vs RAISE), + independent of what the probe found. To change policy or normalizer + overrides, build a new :class:`TargetConfiguration` and pass it via + ``custom_configuration`` at construction time instead. + + Note: + This mutates the target's identifier (derived from the configuration). + + Args: + capabilities (TargetCapabilities): The capabilities to install on + this instance. 
+ """ + self._configuration = TargetConfiguration( + capabilities=capabilities, + policy=self._configuration.policy, + ) + @classmethod def get_default_configuration(cls, underlying_model: str | None = None) -> TargetConfiguration: """ diff --git a/tests/unit/common/test_common_net_utility.py b/tests/unit/common/test_common_net_utility.py index 58fff4b222..5088a166de 100644 --- a/tests/unit/common/test_common_net_utility.py +++ b/tests/unit/common/test_common_net_utility.py @@ -77,8 +77,14 @@ def response_callback(request): async def test_debug_is_false_by_default(): with patch("pyrit.common.net_utility.get_httpx_client") as mock_get_httpx_client: - mock_client_instance = MagicMock() - mock_get_httpx_client.return_value = mock_client_instance + mock_client_context = MagicMock() + mock_client = MagicMock() + mock_client.request = AsyncMock( + return_value=httpx.Response(status_code=200, request=httpx.Request("GET", "http://example.com")) + ) + mock_client_context.__aenter__ = AsyncMock(return_value=mock_client) + mock_client_context.__aexit__ = AsyncMock(return_value=None) + mock_get_httpx_client.return_value = mock_client_context await make_request_and_raise_if_error_async(endpoint_uri="http://example.com", method="GET") diff --git a/tests/unit/prompt_target/target/test_prompt_target.py b/tests/unit/prompt_target/target/test_prompt_target.py index 93bcc0bf19..f3174c2649 100644 --- a/tests/unit/prompt_target/target/test_prompt_target.py +++ b/tests/unit/prompt_target/target/test_prompt_target.py @@ -634,3 +634,50 @@ async def normalize_async(self, messages): # pragma: no cover - not exercised ) assert a.get_identifier().hash != b.get_identifier().hash + + +def test_apply_capabilities_replaces_capabilities_and_preserves_policy(patch_central_database): + initial_policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ) + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False), + policy=initial_policy, + ), + ) + + new_caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + target.apply_capabilities(capabilities=new_caps) + + assert target.capabilities == new_caps + # Policy is preserved by identity, not just by value. 
+ assert target.configuration.policy is initial_policy + + +def test_apply_capabilities_rebuilds_pipeline(patch_central_database): + adapt_policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + ) + target = OpenAIChatTarget( + model_name="gpt-4o", + endpoint="https://mock.azure.com/", + api_key="mock-api-key", + custom_configuration=TargetConfiguration( + capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True), + policy=adapt_policy, + ), + ) + assert target.configuration.pipeline._normalizers, "Expected ADAPT pipeline to be non-empty" + + target.apply_capabilities(capabilities=TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True)) + assert not target.configuration.pipeline._normalizers, "Expected pipeline to be rebuilt as empty" diff --git a/tests/unit/prompt_target/test_discover_target_capabilities.py b/tests/unit/prompt_target/test_discover_target_capabilities.py new file mode 100644 index 0000000000..e21f24a6f2 --- /dev/null +++ b/tests/unit/prompt_target/test_discover_target_capabilities.py @@ -0,0 +1,1078 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import asyncio +import json +import logging +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +from pyrit.models import Message, MessagePiece, PromptDataType +from pyrit.prompt_target.common.discover_target_capabilities import ( + _CAPABILITY_PROBES, + DEFAULT_TEST_ASSETS, + _create_test_message, + _discover_capability_flags_async, + _discover_input_modalities_async, + _permissive_configuration, + discover_target_capabilities_async, +) +from pyrit.prompt_target.common.prompt_target import PromptTarget +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration +from tests.unit.mocks import MockPromptTarget + + +class _RealValidationTarget(PromptTarget): + """ + Bare ``PromptTarget`` subclass that does NOT override ``_validate_request``. + + Tests that need to verify ``_permissive_configuration`` actually bypasses + the validation guard use this instead of ``MockPromptTarget`` (which + no-ops ``_validate_request``). 
+ """ + + _DEFAULT_CONFIGURATION: TargetConfiguration = TargetConfiguration( + capabilities=TargetCapabilities(), + ) + + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return _ok_response() + + +def _ok_response(*, conversation_id: str = "probe", text: str = "ok") -> list[Message]: + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value=text, + original_value_data_type="text", + conversation_id=conversation_id, + response_error="none", + ) + ] + ) + ] + + +def _error_response(*, conversation_id: str = "probe") -> list[Message]: + return [ + Message( + [ + MessagePiece( + role="assistant", + original_value="blocked", + original_value_data_type="text", + conversation_id=conversation_id, + response_error="blocked", + ) + ] + ) + ] + + +@pytest.mark.usefixtures("patch_central_database") +class TestPermissiveConfiguration: + def test_replaces_and_restores_configuration(self) -> None: + target = MockPromptTarget() + original = target.configuration + + with _permissive_configuration(target=target): + permissive = target.configuration + assert permissive is not original + for capability in CapabilityName: + assert permissive.includes(capability=capability) + + assert target.configuration is original + + def test_restores_on_exception(self) -> None: + target = MockPromptTarget() + original = target.configuration + + with pytest.raises(RuntimeError): + with _permissive_configuration(target=target): + raise RuntimeError("boom") + + assert target.configuration is original + + +@pytest.mark.usefixtures("patch_central_database") +class TestDiscoverTargetCapabilitiesAsync: + async def test_returns_only_supported_when_all_probes_succeed(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_capability_flags_async(target=target) + + # Every capability with a probe should be in the result. + for capability in _CAPABILITY_PROBES: + assert capability in result + + async def test_excludes_capabilities_when_probe_fails(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("nope")) # type: ignore[method-assign] + + result = await _discover_capability_flags_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability not in result + + async def test_excludes_capabilities_when_response_has_error(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_error_response()) # type: ignore[method-assign] + + result = await _discover_capability_flags_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability not in result + + async def test_filters_by_requested_capabilities(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + requested = {CapabilityName.SYSTEM_PROMPT, CapabilityName.MULTI_TURN} + result = await _discover_capability_flags_async(target=target, capabilities=requested) + + assert result == requested + + async def test_capability_without_probe_falls_back_to_declared_support(self) -> None: + target = MockPromptTarget() + # Override the configuration so editable_history is declared as supported. 
+ target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_editable_history=True), + ) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.EDITABLE_HISTORY}, + ) + + assert result == {CapabilityName.EDITABLE_HISTORY} + + async def test_capability_without_probe_excluded_when_not_declared(self) -> None: + target = MockPromptTarget() + # Override to a configuration that does NOT declare editable_history. + target._configuration = TargetConfiguration(capabilities=TargetCapabilities()) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.EDITABLE_HISTORY}, + ) + + assert result == set() + + async def test_capability_without_probe_excluded_when_only_adapted(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ + ADAPT in the policy must NOT count as native support for the fallback. + + Today every adaptable capability also has a probe, so this scenario only + arises if a future capability is declared adaptable without a probe. + We simulate that by removing SYSTEM_PROMPT from the registry and + configuring the target with ``ADAPT`` for it but no native support. + """ + from pyrit.prompt_target.common import discover_target_capabilities as qtc + from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + UnsupportedCapabilityBehavior, + ) + + patched_probes = {k: v for k, v in qtc._CAPABILITY_PROBES.items() if k is not CapabilityName.SYSTEM_PROMPT} + monkeypatch.setattr(qtc, "_CAPABILITY_PROBES", patched_probes) + + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(), # no native SYSTEM_PROMPT + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + + async def test_accepts_single_pass_iterable(self) -> None: + """Passing a generator must not silently drop fallback (non-probed) capabilities.""" + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_editable_history=True), + ) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + gen = (c for c in [CapabilityName.SYSTEM_PROMPT, CapabilityName.EDITABLE_HISTORY]) + result = await _discover_capability_flags_async(target=target, capabilities=gen) + + assert CapabilityName.SYSTEM_PROMPT in result + assert CapabilityName.EDITABLE_HISTORY in result + + async def test_retries_zero_disables_retry(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("boom")) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + retries=0, + ) + + assert result == set() + assert target._send_prompt_to_target_async.await_count == 1 + + async def test_retries_use_exponential_backoff(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = 
AsyncMock(side_effect=Exception("boom")) # type: ignore[method-assign] + + with patch( + "pyrit.prompt_target.common.discover_target_capabilities.asyncio.sleep", new_callable=AsyncMock + ) as sleep_mock: + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + retries=2, + ) + + assert result == set() + assert sleep_mock.await_args_list[0].args == (0.1,) + assert sleep_mock.await_args_list[1].args == (0.2,) + + async def test_non_retryable_validation_errors_fail_fast(self) -> None: + """ + Deterministic errors (ValueError/TypeError/AttributeError) come from + malformed payloads or programmer error and will not become valid on + a retry. They must fail the probe immediately without consuming the + retry budget or sleeping for backoff. + """ + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock( # type: ignore[method-assign] + side_effect=ValueError("malformed payload") + ) + + with patch( + "pyrit.prompt_target.common.discover_target_capabilities.asyncio.sleep", new_callable=AsyncMock + ) as sleep_mock: + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + retries=3, + ) + + assert result == set() + # No retries consumed and no backoff sleeps issued. + assert target._send_prompt_to_target_async.await_count == 1 + sleep_mock.assert_not_awaited() + + async def test_restores_configuration_after_probing(self) -> None: + target = MockPromptTarget() + original = target.configuration + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await _discover_capability_flags_async(target=target) + + assert target.configuration is original + + async def test_multi_turn_probe_sends_history_on_second_call(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + # Multi-turn probe sends two requests on the same conversation_id, and + # seeds memory between them so the second call carries real history. + calls = target._send_prompt_to_target_async.await_args_list + assert len(calls) == 2 + + first_conv = calls[0].kwargs["normalized_conversation"] + second_conv = calls[1].kwargs["normalized_conversation"] + + first_conv_id = first_conv[-1].message_pieces[0].conversation_id + second_conv_id = second_conv[-1].message_pieces[0].conversation_id + assert first_conv_id == second_conv_id + + # First call is a single-turn user message; the second call must include + # the seeded user + assistant history followed by the new user turn. + assert len(first_conv) == 1 + assert len(second_conv) >= 3 + roles = [msg.message_pieces[0]._role for msg in second_conv] + assert roles[-3:] == ["user", "assistant", "user"] + + async def test_multi_turn_probe_short_circuits_on_first_failure(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("first call fails")) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + assert result == set() + # _send_and_check_async retries once on exception, so the failing + # first turn is attempted twice; the second turn is never reached. 
+ assert target._send_prompt_to_target_async.await_count == 2 + + async def test_json_schema_probe_sends_schema_in_metadata(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_SCHEMA}, + ) + + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + metadata = normalized[-1].message_pieces[0].prompt_metadata + assert metadata is not None + assert metadata["response_format"] == "json" + # Schema is JSON-encoded into a string for prompt_metadata's value type. + schema = json.loads(metadata["json_schema"]) + assert schema["type"] == "object" + + @pytest.mark.parametrize("capability", [CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA]) + async def test_logs_debug_for_unenforced_json_probe( + self, capability: CapabilityName, caplog: pytest.LogCaptureFixture + ) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + with caplog.at_level(logging.DEBUG): + result = await _discover_capability_flags_async(target=target, capabilities={capability}) + + assert result == {capability} + matching = [r for r in caplog.records if r.message.startswith("JSON capability probes")] + assert len(matching) == 1 + assert capability.value in matching[0].message + + async def test_logs_unenforced_json_probe_summary_once_for_both_capabilities( + self, caplog: pytest.LogCaptureFixture + ) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + with caplog.at_level(logging.DEBUG): + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA}, + ) + + # A single summary line covers both probed JSON capabilities. 
+ matching = [r for r in caplog.records if r.message.startswith("JSON capability probes")] + assert len(matching) == 1 + assert CapabilityName.JSON_OUTPUT.value in matching[0].message + assert CapabilityName.JSON_SCHEMA.value in matching[0].message + + async def test_does_not_log_unenforced_json_probe_when_probe_fails(self, caplog: pytest.LogCaptureFixture) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("boom")) # type: ignore[method-assign] + + with caplog.at_level(logging.DEBUG): + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.JSON_SCHEMA}, + retries=0, + ) + + assert result == set() + assert not any(r.message.startswith("JSON capability probes") for r in caplog.records) + + async def test_does_not_log_debug_for_enforced_json_probe(self, caplog: pytest.LogCaptureFixture) -> None: + target_type = type("FakeEnforcingTarget", (MockPromptTarget,), {}) + target = target_type() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + with ( + patch( + "pyrit.prompt_target.common.discover_target_capabilities._json_enforcing_target_types", + return_value=(target_type,), + ), + caplog.at_level(logging.DEBUG), + ): + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + ) + + assert result == {CapabilityName.JSON_OUTPUT} + assert not any(r.message.startswith("JSON capability probes") for r in caplog.records) + + async def test_subclass_of_enforced_target_does_not_log(self, caplog: pytest.LogCaptureFixture) -> None: + # ``isinstance`` covers user-defined subclasses of enforcing targets. + base = type("EnforcingBase", (MockPromptTarget,), {}) + sub = type("UserSubclass", (base,), {}) + target = sub() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + with ( + patch( + "pyrit.prompt_target.common.discover_target_capabilities._json_enforcing_target_types", + return_value=(base,), + ), + caplog.at_level(logging.DEBUG), + ): + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + ) + + assert not any(r.message.startswith("JSON capability probes") for r in caplog.records) + + async def test_system_prompt_probe_installs_system_message_and_sends_user(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + # The probe writes a system message directly to memory (bypassing + # PromptTarget.set_system_prompt, which subclasses can override) and + # then sends a user-role message. Message.validate forbids mixed + # roles in a single Message, so the system and user turns are + # separate. Verify the system message is in memory and the wire + # payload contains the system + user history. + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + roles_sent = [piece._role for msg in normalized for piece in msg.message_pieces] + assert "system" in roles_sent + assert roles_sent[-1] == "user" + # The last sent Message itself should be user-only. 
+ assert [piece._role for piece in normalized[-1].message_pieces] == ["user"] + + async def test_multi_message_pieces_probe_sends_two_pieces(self) -> None: + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + normalized: list[Message] = target._send_prompt_to_target_async.await_args.kwargs["normalized_conversation"] + assert len(normalized[-1].message_pieces) == 2 + + async def test_probes_run_under_permissive_configuration(self) -> None: + """ + Even when the target declares no boolean capabilities, the probe should + still execute because the configuration is temporarily permissive. + + Uses ``_RealValidationTarget`` so that ``_validate_request`` actually + runs and would reject the multi-piece probe were the override absent. + """ + target = _RealValidationTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + # Probe was actually invoked through the full send_prompt_async pipeline, + # which means _validate_request ran and was satisfied by the permissive + # override (the bare target declares no capabilities natively). + assert send_mock.await_count >= 1 + assert CapabilityName.MULTI_MESSAGE_PIECES in result + + async def test_probed_capability_excluded_when_only_adapted(self) -> None: + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_system_prompt=False), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + async def reject_system_roles(*, normalized_conversation: list[Message]) -> list[Message]: + roles = [piece._role for message in normalized_conversation for piece in message.message_pieces] + if "system" in roles: + raise RuntimeError("system messages are not natively supported") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=reject_system_roles) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + + async def test_probe_configuration_does_not_reuse_adapted_pipeline(self) -> None: + target = MockPromptTarget() + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities(supports_system_prompt=False), + policy=CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + } + ), + ) + + async def require_native_system_role(*, normalized_conversation: list[Message]) -> list[Message]: + roles = [piece._role for message in normalized_conversation for piece in message.message_pieces] + if "system" not in roles: + raise RuntimeError("probe used adapted system-prompt shaping") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=require_native_system_role) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == 
{CapabilityName.SYSTEM_PROMPT} + + +@pytest.mark.usefixtures("patch_central_database") +class TestDiscoverTargetCapabilitiesIsolatedTarget: + """Tests using a bare PromptTarget subclass (no PromptChatTarget extras).""" + + async def test_with_minimal_target_subclass(self) -> None: + class _MinimalTarget(PromptTarget): + async def _send_prompt_to_target_async(self, *, normalized_conversation: list[Message]) -> list[Message]: + return _ok_response() + + target = _MinimalTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_capability_flags_async(target=target) + + for capability in _CAPABILITY_PROBES: + assert capability in result + + +# --------------------------------------------------------------------------- +# Modality query tests +# --------------------------------------------------------------------------- + + +def _set_input_modalities( + *, + target: MockPromptTarget, + modalities: set[frozenset[PromptDataType]], +) -> None: + target._configuration = TargetConfiguration( + capabilities=TargetCapabilities( + input_modalities=frozenset(modalities), + ), + ) + + +@pytest.fixture +def image_asset(tmp_path: Path) -> str: + """Create a tiny placeholder file usable as an image_path asset.""" + asset = tmp_path / "test_image.png" + asset.write_bytes(b"\x89PNG\r\n\x1a\n") + return str(asset) + + +@pytest.mark.usefixtures("patch_central_database") +class TestCreateTestMessage: + def test_default_assets_exist_for_packaged_modalities(self) -> None: + msg = _create_test_message( + modalities=frozenset({"audio_path", "image_path"}), + test_assets=DEFAULT_TEST_ASSETS, + ) + + types = {piece.original_value_data_type for piece in msg.message_pieces} + assert types == {"audio_path", "image_path"} + + def test_text_only(self) -> None: + msg = _create_test_message(modalities=frozenset({"text"}), test_assets={}) + assert len(msg.message_pieces) == 1 + assert msg.message_pieces[0].original_value_data_type == "text" + + def test_multimodal_uses_assets(self, image_asset: str) -> None: + msg = _create_test_message( + modalities=frozenset({"text", "image_path"}), + test_assets={"image_path": image_asset}, + ) + types = {piece.original_value_data_type for piece in msg.message_pieces} + assert types == {"text", "image_path"} + + # All pieces share the same conversation_id (Message.validate requires it). 
+ conv_ids = {piece.conversation_id for piece in msg.message_pieces} + assert len(conv_ids) == 1 + + def test_missing_asset_file_raises_filenotfound(self, tmp_path: Path) -> None: + missing_path = str(tmp_path / "does_not_exist.png") + with pytest.raises(FileNotFoundError): + _create_test_message( + modalities=frozenset({"image_path"}), + test_assets={"image_path": missing_path}, + ) + + def test_unconfigured_modality_raises_valueerror(self) -> None: + with pytest.raises(ValueError, match="No test asset configured"): + _create_test_message( + modalities=frozenset({"image_path"}), + test_assets={}, + ) + + +@pytest.mark.usefixtures("patch_central_database") +class TestVerifyTargetModalitiesAsync: + async def test_all_combinations_supported(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_input_modalities_async(target=target) + + assert frozenset({"text"}) in result + + async def test_exception_excludes_combination(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(side_effect=Exception("nope")) # type: ignore[method-assign] + + result = await _discover_input_modalities_async(target=target) + + assert result == set() + + async def test_error_response_excludes_combination(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_error_response()) # type: ignore[method-assign] + + result = await _discover_input_modalities_async(target=target) + + assert result == set() + + async def test_partial_support_via_selective_failure(self, image_asset: str) -> None: + target = MockPromptTarget() + _set_input_modalities( + target=target, + modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, + ) + + async def selective_send(*, normalized_conversation: list[Message]) -> list[Message]: + message = normalized_conversation[-1] + types = {p.original_value_data_type for p in message.message_pieces} + if "image_path" in types: + raise Exception("image not supported") + return _ok_response() + + target._send_prompt_to_target_async = selective_send # type: ignore[method-assign] + + result = await _discover_input_modalities_async( + target=target, + test_assets={"image_path": image_asset}, + ) + + assert frozenset({"text"}) in result + assert frozenset({"text", "image_path"}) not in result + + async def test_explicit_test_modalities_overrides_declared(self, image_asset: str) -> None: + target = MockPromptTarget() + # Declared as text-only, but caller asks us to probe text+image too. 
+ _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_input_modalities_async( + target=target, + test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})}, + test_assets={"image_path": image_asset}, + ) + + assert frozenset({"text"}) in result + assert frozenset({"text", "image_path"}) in result + + async def test_combination_skipped_when_asset_missing(self, tmp_path: Path) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text", "image_path"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + # An explicit empty mapping disables the packaged defaults, so + # image_path combinations are skipped instead of probed. + result = await _discover_input_modalities_async(target=target, test_assets={}) + + assert result == set() + assert target._send_prompt_to_target_async.await_count == 0 + + async def test_empty_test_modalities_returns_empty_without_probing(self) -> None: + target = MockPromptTarget() + _set_input_modalities(target=target, modalities={frozenset({"text"})}) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await _discover_input_modalities_async(target=target, test_modalities=set()) + + assert result == set() + assert target._send_prompt_to_target_async.await_count == 0 + + async def test_explicit_test_modalities_runs_under_permissive_configuration(self, image_asset: str) -> None: + """ + Probing a modality combination the target does NOT declare must still + succeed. Uses ``_RealValidationTarget`` so ``_validate_request`` runs + and would reject the multi-piece, non-text payload were the + permissive override absent. + """ + target = _RealValidationTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + result = await _discover_input_modalities_async( + target=target, + test_modalities={frozenset({"text", "image_path"})}, + test_assets={"image_path": image_asset}, + ) + + assert send_mock.await_count == 1 + assert frozenset({"text", "image_path"}) in result + + +@pytest.mark.usefixtures("patch_central_database") +class TestSendAndCheckTimeout: + async def test_timeout_returns_false_after_retries(self) -> None: + """ + When ``send_prompt_async`` exceeds ``per_probe_timeout_s``, the probe + is treated as failed. ``_send_and_check_async`` retries once on + timeout, so the underlying mock is awaited twice and the capability + is excluded from the queried set. + """ + target = MockPromptTarget() + + async def _hang(**_kwargs: object) -> list[Message]: + await asyncio.sleep(10) + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=_hang) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=0.01, + ) + + assert result == set() + # One initial attempt plus one retry. + assert target._send_prompt_to_target_async.await_count == 2 + + +@pytest.mark.usefixtures("patch_central_database") +class TestSystemPromptProbeMemoryFailure: + async def test_returns_false_when_memory_seed_raises(self) -> None: + """ + If seeding the system message into memory raises (e.g. 
backend + offline), the system-prompt probe returns False without attempting + the user send. + """ + target = MockPromptTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + with patch.object(target._memory, "add_message_to_memory", side_effect=RuntimeError("memory offline")): + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.SYSTEM_PROMPT}, + ) + + assert result == set() + # The user send is never attempted because seeding failed. + send_mock.assert_not_awaited() + + +@pytest.mark.usefixtures("patch_central_database") +class TestVerifyTargetAsync: + async def test_returns_target_capabilities_assembled_from_probes(self) -> None: + """ + ``discover_target_capabilities_async`` runs both the capability and modality probes + and assembles a :class:`TargetCapabilities` populated from the + queried results, copying ``output_modalities`` from the target's + declared capabilities and deriving editable history conservatively. + """ + declared = TargetCapabilities( + input_modalities=frozenset({frozenset({"text"})}), + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target, per_probe_timeout_s=5.0) + + assert isinstance(result, TargetCapabilities) + # Single-piece probes that don't touch memory always succeed when + # the underlying send returns a clean response. + assert result.supports_multi_message_pieces is True + assert result.supports_json_schema is True + assert result.supports_json_output is True + # Editable history is conservative and therefore cannot remain true + # when multi-turn support was not confirmed by probing. + assert result.supports_editable_history is False + # Modalities returned from the modality probe (text combination). + assert frozenset({"text"}) in result.input_modalities + # Output modalities copied through (not probed). + assert result.output_modalities == declared.output_modalities + + async def test_excludes_capabilities_when_probe_send_fails(self) -> None: + """ + When the underlying send raises, no capability or modality is + queried, but ``supports_editable_history``, ``output_modalities``, + and declared ``input_modalities`` are still preserved conservatively + from the declared capabilities. + """ + declared = TargetCapabilities( + supports_editable_history=True, + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target, per_probe_timeout_s=0.5) + + assert result.supports_multi_turn is False + assert result.supports_system_prompt is False + assert result.supports_json_output is False + assert result.supports_json_schema is False + assert result.supports_multi_message_pieces is False + # Editable history is derived conservatively and must fall when + # multi-turn probing disproves the prerequisite capability. 
+ assert result.supports_editable_history is False + # When probing cannot confirm modalities, declared modalities are + # preserved (mirroring the boolean fallback semantics). + assert result.input_modalities == declared.input_modalities + # Output modalities still copied. + assert result.output_modalities == declared.output_modalities + + async def test_empty_response_treated_as_failure(self) -> None: + """A target returning an empty response list must NOT be reported as supporting probes.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=[]) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.MULTI_MESSAGE_PIECES}, + ) + + assert result == set() + + async def test_response_with_no_pieces_treated_as_failure(self) -> None: + """Responses whose Messages have no pieces must also be rejected.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock( # type: ignore[method-assign] + return_value=[Message.__new__(Message)] + ) + # Bypass __init__ to construct a Message with no pieces (Message.__init__ rejects empty). + empty_msg = target._send_prompt_to_target_async.return_value[0] + empty_msg.message_pieces = [] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + ) + + assert result == set() + + async def test_mixed_empty_message_in_response_treated_as_failure(self) -> None: + """Any empty Message in a multi-message response must cause the probe to fail.""" + target = MockPromptTarget() + ok = _ok_response()[0] + empty = Message.__new__(Message) + empty.message_pieces = [] + target._send_prompt_to_target_async = AsyncMock(return_value=[ok, empty]) # type: ignore[method-assign] + + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + ) + + assert result == set() + + async def test_discover_target_capabilities_async_forwards_test_modalities(self, image_asset: str) -> None: + declared = TargetCapabilities(input_modalities=frozenset({frozenset({"text"})})) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) + + extra_combo = frozenset({"text", "image_path"}) + result = await discover_target_capabilities_async( + target=target, + test_modalities={extra_combo}, + test_assets={"image_path": image_asset}, + per_probe_timeout_s=2.0, + ) + + # The undeclared combination is in the result only if test_modalities was forwarded. 
+ assert extra_combo in result.input_modalities + + async def test_discover_target_capabilities_async_preserves_declared_modalities_when_test_modalities_narrowed( + self, image_asset: str + ) -> None: + declared_combo = frozenset({"text"}) + probed_combo = frozenset({"text", "image_path"}) + declared = TargetCapabilities(input_modalities=frozenset({declared_combo, probed_combo})) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) + + result = await discover_target_capabilities_async( + target=target, + test_modalities={probed_combo}, + test_assets={"image_path": image_asset}, + per_probe_timeout_s=2.0, + ) + + assert result.input_modalities == frozenset({declared_combo, probed_combo}) + + async def test_discover_target_capabilities_async_forwards_capabilities(self) -> None: + """``discover_target_capabilities_async`` must forward ``capabilities`` to narrow the probe set.""" + target = MockPromptTarget() + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=2.0, + ) + + # Only the JSON_OUTPUT probe (1 send) and the modality probe(s) should run; + # if `capabilities` were ignored, all 5 capability probes would fire (>= 6 sends + # because multi-turn issues 2 sends). + assert target._send_prompt_to_target_async.await_count <= 3 + + async def test_discover_target_capabilities_async_preserves_declared_when_capabilities_narrowed(self) -> None: + """ + When ``capabilities`` narrows the probe set, capabilities NOT in the + narrowed set must fall back to the target's declared values rather + than being silently reset to False. + """ + declared = TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=True, + supports_json_schema=True, + supports_editable_history=True, + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async( + target=target, + capabilities={CapabilityName.JSON_OUTPUT}, + per_probe_timeout_s=2.0, + ) + + # The probed capability reflects the queried result. + assert result.supports_json_output is True + # Non-probed capabilities fall back to declared values. + assert result.supports_multi_turn is True + assert result.supports_system_prompt is True + assert result.supports_json_schema is True + assert result.supports_editable_history is True + + async def test_discover_target_capabilities_async_drops_editable_history_when_multi_turn_probe_fails(self) -> None: + """Editable history must not remain true when probing disproves multi-turn support.""" + declared = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + output_modalities=frozenset({frozenset({"text"})}), + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + + async def selective_send(*, normalized_conversation: list[Message]) -> list[Message]: + latest_text = normalized_conversation[-1].message_pieces[0].original_value + if latest_text == "My favorite color is blue." 
or latest_text == "What did I just tell you?": + raise RuntimeError("multi-turn unsupported") + return _ok_response() + + target._send_prompt_to_target_async = AsyncMock(side_effect=selective_send) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target, per_probe_timeout_s=2.0) + + assert result.supports_multi_turn is False + assert result.supports_editable_history is False + + async def test_discover_target_capabilities_async_accepts_single_pass_iterable(self) -> None: + declared = TargetCapabilities( + supports_multi_turn=True, + supports_editable_history=True, + ) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + gen = (c for c in [CapabilityName.JSON_OUTPUT, CapabilityName.EDITABLE_HISTORY]) + result = await discover_target_capabilities_async( + target=target, + capabilities=gen, + per_probe_timeout_s=2.0, + ) + + assert result.supports_json_output is True + assert result.supports_editable_history is True + + async def test_discover_target_capabilities_async_apply_installs_capabilities_on_target(self) -> None: + """When ``apply=True``, the discovered capabilities are installed on the target.""" + declared = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + assert target.capabilities.supports_multi_turn is False + + result = await discover_target_capabilities_async(target=target, per_probe_timeout_s=2.0, apply=True) + + assert target.capabilities == result + assert target.capabilities.supports_multi_turn is True + + async def test_discover_target_capabilities_async_apply_defaults_to_false(self) -> None: + """By default, ``discover_target_capabilities_async`` must not mutate the target.""" + declared = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + target = MockPromptTarget() + target._configuration = TargetConfiguration(capabilities=declared) + target._send_prompt_to_target_async = AsyncMock(return_value=_ok_response()) # type: ignore[method-assign] + + result = await discover_target_capabilities_async(target=target, per_probe_timeout_s=2.0) + + # Result reflects probe; target capabilities remain at declared values. + assert result.supports_multi_turn is True + assert target.capabilities.supports_multi_turn is False + + +@pytest.mark.usefixtures("patch_central_database") +class TestMultiTurnProbeMemoryFailure: + async def test_returns_false_when_history_seed_raises(self) -> None: + """ + If seeding conversation history into memory raises, the multi-turn + probe returns False rather than proceeding with a half-seeded + conversation that would produce a false positive. + """ + target = MockPromptTarget() + send_mock = AsyncMock(return_value=_ok_response()) + target._send_prompt_to_target_async = send_mock # type: ignore[method-assign] + + with patch.object(target._memory, "add_message_to_memory", side_effect=RuntimeError("memory offline")): + result = await _discover_capability_flags_async( + target=target, + capabilities={CapabilityName.MULTI_TURN}, + ) + + assert result == set() + # The first turn ran (1 send); the second turn must NOT run because + # seeding failed, otherwise the probe would falsely succeed. 
+ assert send_mock.await_count == 1
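
Taken together, `discover_target_capabilities_async` and `apply_capabilities` support a probe-then-commit workflow. The following is a minimal sketch only, assuming `target` is an already-constructed `PromptTarget` and using the import paths from the tests above; the narrowed capability set, the extra modality combination, and `retries=0` are illustrative choices, not defaults.

```python
from pyrit.prompt_target import discover_target_capabilities_async
from pyrit.prompt_target.common.target_capabilities import CapabilityName

# Narrow the probe to a few boolean flags and one extra modality combination;
# anything not probed falls back to the target's declared values.
# retries=0 disables the single retry on transient errors.
discovered = await discover_target_capabilities_async(
    target=target,
    capabilities={CapabilityName.JSON_OUTPUT, CapabilityName.MULTI_TURN},
    test_modalities={frozenset({"text"}), frozenset({"text", "image_path"})},
    retries=0,
)

# Inspect before committing: apply defaults to False, so the target is not mutated yet.
if discovered.includes(capability=CapabilityName.JSON_OUTPUT):
    # Install the probed view; the existing ADAPT/RAISE policy is preserved
    # and the normalization pipeline is rebuilt from the new capabilities.
    target.apply_capabilities(capabilities=discovered)
```

Because no explicit `test_assets` mapping is passed, the `image_path` combination is probed with the packaged defaults in `DEFAULT_TEST_ASSETS`; combinations whose assets are missing on disk are skipped rather than reported as unsupported.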