-
Notifications
You must be signed in to change notification settings - Fork 757
[DRAFT] FEAT add Agent Threat Rules (ATR) adversarial payload dataset loader #1715
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
eeee2345
wants to merge
4
commits into
microsoft:main
Choose a base branch
from
eeee2345:feat/atr-dataset-loader
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+585
−0
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
deae6d6
FEAT: add Agent Threat Rules (ATR) adversarial payload dataset loader
eeee2345 44dce8b
DOC: drop manual 0_dataset.md entry per #1707
eeee2345 5f4490c
MAINT: address review feedback on ATR dataset loader
eeee2345 8d8cf4d
Merge branch 'main' into feat/atr-dataset-loader
eeee2345 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
320 changes: 320 additions & 0 deletions
320
pyrit/datasets/seed_datasets/remote/agent_threat_rules_dataset.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,320 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| import logging | ||
| from enum import Enum | ||
| from typing import Literal, Optional | ||
|
|
||
| from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( | ||
| _RemoteDatasetLoader, | ||
| ) | ||
| from pyrit.models import SeedDataset, SeedPrompt | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class ATRCategory(Enum): | ||
| """ | ||
| ATR taxonomy categories. | ||
|
|
||
| Reflects the full ATR rule taxonomy (ten categories). The autoresearch | ||
| payload corpus currently covers six of these; filtering by an uncovered | ||
| category returns an empty dataset. | ||
| """ | ||
|
|
||
| PROMPT_INJECTION = "prompt-injection" | ||
| TOOL_POISONING = "tool-poisoning" | ||
| CONTEXT_EXFILTRATION = "context-exfiltration" | ||
| AGENT_MANIPULATION = "agent-manipulation" | ||
| PRIVILEGE_ESCALATION = "privilege-escalation" | ||
| SKILL_COMPROMISE = "skill-compromise" | ||
| DATA_POISONING = "data-poisoning" | ||
| EXCESSIVE_AUTONOMY = "excessive-autonomy" | ||
| MODEL_ABUSE = "model-abuse" | ||
| MODEL_SECURITY = "model-security" | ||
|
|
||
|
|
||
| class ATRDetectionField(Enum): | ||
| """ | ||
| Agent surface that an ATR payload targets. | ||
|
|
||
| Each entry in adversarial-samples.json carries a ``detection_field`` value | ||
| indicating which agent input or output channel the payload is intended to | ||
| appear on. Useful for narrowing the dataset to the surface the user is | ||
| actually testing. | ||
| """ | ||
|
|
||
| USER_INPUT = "user_input" | ||
| CONTENT = "content" | ||
| TOOL_ARGS = "tool_args" | ||
| TOOL_NAME = "tool_name" | ||
| TOOL_RESPONSE = "tool_response" | ||
| AGENT_OUTPUT = "agent_output" | ||
|
|
||
|
|
||
| class ATRVariationType(Enum): | ||
| """ | ||
| Variation type label for an ATR payload. | ||
|
|
||
| Indicates whether a payload is an original seed entry or an | ||
| autoresearch-derived variant. | ||
| """ | ||
|
|
||
| ORIGINAL = "original" | ||
| GENERATED = "generated" | ||
|
|
||
|
|
||
| # Maps rule IDs in the autoresearch coverage to their ATR taxonomy category. | ||
| # This dict reflects the rules currently represented in adversarial-samples.json. | ||
| # When ATR extends autoresearch coverage to additional rules, add entries here. | ||
| # The ATRCategory enum is the single source of truth for the category strings — | ||
| # a typo in either side becomes a static error at import time rather than a | ||
| # silent data-quality bug. | ||
| _RULE_ID_TO_CATEGORY: dict[str, ATRCategory] = { | ||
| "ATR-2026-00001": ATRCategory.PROMPT_INJECTION, | ||
| "ATR-2026-00002": ATRCategory.PROMPT_INJECTION, | ||
| "ATR-2026-00003": ATRCategory.PROMPT_INJECTION, | ||
| "ATR-2026-00010": ATRCategory.TOOL_POISONING, | ||
| "ATR-2026-00020": ATRCategory.CONTEXT_EXFILTRATION, | ||
| "ATR-2026-00021": ATRCategory.CONTEXT_EXFILTRATION, | ||
| "ATR-2026-00030": ATRCategory.AGENT_MANIPULATION, | ||
| "ATR-2026-00040": ATRCategory.PRIVILEGE_ESCALATION, | ||
| "ATR-2026-00060": ATRCategory.SKILL_COMPROMISE, | ||
| "ATR-2026-00064": ATRCategory.SKILL_COMPROMISE, | ||
| } | ||
|
|
||
|
|
||
| # Default upstream URL pinned to a specific commit for reproducibility, mirroring | ||
| # the HarmBench loader convention. | ||
| _DEFAULT_SOURCE = ( | ||
| "https://raw.githubusercontent.com/Agent-Threat-Rule/agent-threat-rules/" | ||
| "db793f9/data/autoresearch/adversarial-samples.json" | ||
| ) | ||
|
|
||
|
|
||
| class _AgentThreatRulesDataset(_RemoteDatasetLoader): | ||
| """ | ||
| Loader for the Agent Threat Rules (ATR) adversarial payload corpus. | ||
|
|
||
| ATR is an open MIT-licensed detection standard for AI agent threats. The | ||
| upstream catalog ships rules across ten attack categories (prompt-injection, | ||
| tool-poisoning, skill-compromise, agent-manipulation, context-exfiltration, | ||
| data-poisoning, excessive-autonomy, model-abuse, model-security, | ||
| privilege-escalation) and 336 rules at the time of this loader's pin. | ||
|
|
||
| This loader surfaces the autoresearch adversarial-payload corpus | ||
| (``data/autoresearch/adversarial-samples.json``), which contains 1,054 | ||
| attack-prompt entries across ten base rule scenarios in six of the ten | ||
| categories. Each entry carries an attack technique label (paraphrase, | ||
| language_switch, encoding, role_play, and 17 others) and the agent surface | ||
| the payload targets (``user_input``, ``content``, ``tool_args``, | ||
| ``tool_name``, ``tool_response``, ``agent_output``). | ||
|
|
||
| Reference: https://github.com/Agent-Threat-Rule/agent-threat-rules | ||
| License: MIT. | ||
|
|
||
| Each entry is mapped to a SeedPrompt with the payload as ``value``. The | ||
| upstream metadata fields (``original_rule_id``, ``technique``, | ||
| ``detection_field``, ``variation_type``) are preserved on | ||
| ``SeedPrompt.metadata`` so downstream consumers can route, filter, or | ||
| score by them. ``harm_categories`` is set to the rule's ATR category | ||
| (single-element list). | ||
|
|
||
| The optional ``categories``, ``techniques``, ``detection_fields``, and | ||
| ``variation_types`` arguments narrow the dataset client-side after fetch. | ||
| Passing an empty list is rejected — pass ``None`` to disable a filter. | ||
| """ | ||
|
|
||
| # Class-attribute metadata picked up by SeedDatasetMetadata | ||
| harm_categories: list[str] = [ | ||
| "prompt-injection", | ||
| "tool-poisoning", | ||
| "context-exfiltration", | ||
| "agent-manipulation", | ||
| "privilege-escalation", | ||
| "skill-compromise", | ||
| ] | ||
| modalities: list[str] = ["text"] | ||
| size: str = "large" # 1,054 seeds | ||
| tags: set[str] = {"safety", "agent_security", "prompt_injection"} | ||
|
|
||
| def __init__( | ||
| self, | ||
| *, | ||
| source: str = _DEFAULT_SOURCE, | ||
| source_type: Literal["public_url", "file"] = "public_url", | ||
| categories: Optional[list[ATRCategory]] = None, | ||
| techniques: Optional[list[str]] = None, | ||
| detection_fields: Optional[list[ATRDetectionField]] = None, | ||
| variation_types: Optional[list[ATRVariationType]] = None, | ||
| ) -> None: | ||
| """ | ||
| Initialize the ATR dataset loader. | ||
|
|
||
| Args: | ||
| source: URL or local path to ``adversarial-samples.json``. Defaults | ||
| to a pinned commit on the upstream ATR repository for | ||
| reproducibility; pass the raw URL on ``main`` or a different | ||
| tag to track upstream. | ||
| source_type: ``"public_url"`` or ``"file"``. | ||
| categories: Optional non-empty list of ATRCategory values; if | ||
| provided, only payloads whose original rule maps to one of | ||
| these categories are returned. Pass ``None`` (not ``[]``) to | ||
| include all categories. | ||
| techniques: Optional non-empty list of technique strings (free | ||
| text, since the upstream taxonomy of techniques is open-set); | ||
| if provided, only payloads with a matching technique are | ||
| returned. Pass ``None`` (not ``[]``) to include all techniques. | ||
| detection_fields: Optional non-empty list of ATRDetectionField | ||
| values; if provided, only payloads targeting one of these | ||
| surfaces are returned. Pass ``None`` (not ``[]``) to include | ||
| all surfaces. | ||
| variation_types: Optional non-empty list of ATRVariationType | ||
| values; if provided, only payloads of those variation types | ||
| are returned. Pass ``None`` (not ``[]``) to include all | ||
| variation types. | ||
|
|
||
| Raises: | ||
| ValueError: If any filter is an empty list (``[]``), or if | ||
| ``categories``, ``detection_fields``, or ``variation_types`` | ||
| contain values that are not instances of their expected enum. | ||
| """ | ||
| # Reject empty-list filters — a silent empty filter that returns the | ||
| # full dataset is almost always a caller bug; force the caller to use | ||
| # ``None`` if they want all entries. | ||
| if categories is not None: | ||
| if not categories: | ||
| raise ValueError("`categories` must be a non-empty list (pass None to include all categories)") | ||
| self._validate_enums(categories, ATRCategory, "category") | ||
| if techniques is not None and not techniques: | ||
| raise ValueError("`techniques` must be a non-empty list (pass None to include all techniques)") | ||
| if detection_fields is not None: | ||
| if not detection_fields: | ||
| raise ValueError( | ||
| "`detection_fields` must be a non-empty list (pass None to include all detection fields)" | ||
| ) | ||
| self._validate_enums(detection_fields, ATRDetectionField, "detection_field") | ||
| if variation_types is not None: | ||
| if not variation_types: | ||
| raise ValueError( | ||
| "`variation_types` must be a non-empty list (pass None to include all variation types)" | ||
| ) | ||
| self._validate_enums(variation_types, ATRVariationType, "variation_type") | ||
|
|
||
| self.source = source | ||
| self.source_type: Literal["public_url", "file"] = source_type | ||
| self._categories = {c.value for c in categories} if categories else None | ||
| self._techniques = set(techniques) if techniques else None | ||
| self._detection_fields = {d.value for d in detection_fields} if detection_fields else None | ||
| self._variation_types = {v.value for v in variation_types} if variation_types else None | ||
|
|
||
| @property | ||
| def dataset_name(self) -> str: | ||
| """Return the dataset name.""" | ||
| return "agent_threat_rules" | ||
|
|
||
| async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: | ||
| """ | ||
| Fetch the ATR adversarial payload corpus and return as SeedDataset. | ||
|
|
||
| Args: | ||
| cache: Whether to cache the fetched dataset. Defaults to True. | ||
|
|
||
| Returns: | ||
| SeedDataset: A SeedDataset containing one SeedPrompt per matching | ||
| ATR payload entry. | ||
|
|
||
| Raises: | ||
| ValueError: If any entry is missing a required field. | ||
| """ | ||
| required_keys = { | ||
| "id", | ||
| "original_rule_id", | ||
| "technique", | ||
| "payload", | ||
| "detection_field", | ||
| "variation_type", | ||
| } | ||
|
|
||
| examples = self._fetch_from_url( | ||
| source=self.source, | ||
| source_type=self.source_type, | ||
| cache=cache, | ||
| ) | ||
|
|
||
| authors = ["ATR Community"] | ||
| source_url = "https://github.com/Agent-Threat-Rule/agent-threat-rules" | ||
|
|
||
| seeds: list[SeedPrompt] = [] | ||
| skipped_unknown_rule = 0 | ||
|
|
||
| for example in examples: | ||
| missing = required_keys - example.keys() | ||
| if missing: | ||
| raise ValueError(f"Missing keys in ATR entry: {', '.join(sorted(missing))}") | ||
|
|
||
| rule_id = example["original_rule_id"] | ||
| category = _RULE_ID_TO_CATEGORY.get(rule_id) | ||
| if category is None: | ||
| # Unknown rule — likely a new rule_id that landed upstream | ||
| # before the loader's mapping was extended. Skip rather than | ||
| # mislabel; warn in aggregate at end. | ||
| skipped_unknown_rule += 1 | ||
| continue | ||
|
|
||
| category_value = category.value | ||
|
|
||
| if self._categories and category_value not in self._categories: | ||
| continue | ||
| if self._techniques and example["technique"] not in self._techniques: | ||
| continue | ||
| if self._detection_fields and example["detection_field"] not in self._detection_fields: | ||
| continue | ||
| if self._variation_types and example["variation_type"] not in self._variation_types: | ||
| continue | ||
|
|
||
| metadata: dict[str, str | int] = { | ||
| "original_rule_id": rule_id, | ||
| "technique": example["technique"], | ||
| "detection_field": example["detection_field"], | ||
| "variation_type": example["variation_type"], | ||
| "atr_id": example["id"], | ||
| } | ||
|
|
||
| # Per-rule description so downstream consumers reading metadata see | ||
| # the category that actually applies to this seed (rather than a | ||
| # corpus-wide list that ignores the rule's specific family). | ||
| category_label = category_value.replace("-", " ") | ||
| description = ( | ||
| f"Agent Threat Rules (ATR) adversarial payload in the {category_label} family. Rule {rule_id}." | ||
| ) | ||
|
|
||
| seeds.append( | ||
| SeedPrompt( | ||
| value=example["payload"], | ||
| data_type="text", | ||
| name=rule_id, | ||
| dataset_name=self.dataset_name, | ||
| harm_categories=[category_value], | ||
| description=description, | ||
| authors=authors, | ||
| source=source_url, | ||
| metadata=metadata, | ||
| ) | ||
| ) | ||
|
|
||
| if skipped_unknown_rule: | ||
| logger.warning( | ||
| "Skipped %d ATR payload(s) whose original_rule_id is not in the " | ||
| "loader's category mapping. Update _RULE_ID_TO_CATEGORY in " | ||
| "agent_threat_rules_dataset.py to extend coverage.", | ||
| skipped_unknown_rule, | ||
| ) | ||
|
|
||
| logger.info( | ||
| "Successfully loaded %d ATR adversarial payloads (from %d total upstream entries)", | ||
| len(seeds), | ||
| len(examples), | ||
| ) | ||
|
|
||
| return SeedDataset(seeds=seeds, dataset_name=self.dataset_name) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Empty filter list silently disables the filter:
categories=[]becomesset(), which is falsy, soif self._categories and ...skips filtering entirely and the user gets back the whole dataset. Most users would expect "no categories selected → empty result."I'd either raise on empty (simplest) or normalize empty-to-
Noneonly after a deliberate decision is made about the semantics. Same applies totechniques,detection_fields,variation_types.