From f33913d9b6a787e42a58f346d94507f775bf460f Mon Sep 17 00:00:00 2001 From: williamli-15 Date: Wed, 6 May 2026 16:06:54 +0800 Subject: [PATCH] Support SynthPersona QA schema filters --- README.md | 11 ++- docs/synth_persona.md | 28 ++++++- src/persona_data/prompts.py | 7 +- src/persona_data/synth_persona.py | 125 +++++++++++++++++++++++++----- tests/test_datasets.py | 56 ++++++++++++- 5 files changed, 197 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 1696471..77ee08e 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,14 @@ persona.templated_view # short attribute-based system prompt persona.biography_view # full biography text persona.statements # list of Statement -qa_pairs = dataset.get_qa(persona.id, type="implicit", item_type="mcq") +qa_pairs_train = dataset.get_qa(persona.id, item_type="frq") +qa_pairs_test = dataset.get_qa(persona.id, item_type="mcq", scope="shared") +explicit_interview = dataset.get_qa( + persona.id, + type="explicit", + item_type="mcq", + source="interview", +) ``` ### PersonaGuess @@ -106,7 +113,7 @@ full_prompt, response_start_idx = format_messages(messages, tokenizer) When iterating over variants in an experiment, pass `"templated"` or `"biography"` to `format_prompt(persona, variant)`. Calling `format_prompt()` yields the persona-less Assistant baseline prompt. Use `BASELINE_PERSONA_ID` and `BASELINE_PERSONA_NAME` from `persona_data.prompts` for the shared baseline identity in artifacts and UI labels. -For multiple-choice prompts, use `format_mc_question(qa)` to render the question, choices, and trailing answer-only instruction. Use `mc_answer_only_instruction(n_choices)` if you need just the instruction text, and `mc_correct_letter(qa)` to get the gold label. +For multiple-choice prompts, use `format_mc_question(qa)` to render the question, choices, and trailing answer-only instruction. Use `mc_answer_only_instruction(n_choices)` if you need just the instruction text, and `mc_correct_letter(qa)` to get the gold label. `QAPair.split_group_id` is an optional generic split key for rows that belong to one leakage group. `QAPair.related_qids` is an optional generic list for rows related to multiple source rows; for SynthPersona implicit shared MCQs, it contains source free-response qids. `format_messages` handles tokenizers that do not support the `"system"` role (for example Gemma 2) by merging system content into the first user message. Pass `add_generation_prompt=True` to render an inference-ready prompt (messages ending in a `user` turn); the returned `response_start_idx` then equals the prompt length, ready to slice `model.generate` output. diff --git a/docs/synth_persona.md b/docs/synth_persona.md index 69c006d..0db7cfb 100644 --- a/docs/synth_persona.md +++ b/docs/synth_persona.md @@ -11,12 +11,13 @@ dataset = SynthPersonaDataset() small = SynthPersonaDataset(sample_size=100) ``` -The default dataset source is `implicit-personalization/synth-persona`. The loader reads two JSONL files: +The default dataset source is `implicit-personalization/synth-persona`. The loader downloads these files: - `dataset_personas.jsonl` - `dataset_qa.jsonl` +- `implicit_shared_mc_bank.json` -`sample_size` keeps the leading personas and filters QA rows to the loaded persona IDs. +`implicit_shared_mc_bank.json` is read internally to hydrate `QAPair.related_qids`; user code normally only calls `get_qa()`. `sample_size` keeps the leading personas and filters QA rows to the loaded persona IDs. ## Records @@ -26,7 +27,16 @@ The default dataset source is `implicit-personalization/synth-persona`. The load `Statement` includes `sid`, `category`, `claim`, and `support_turns`. -`QAPair` includes `qid`, `type`, `item_type`, `question`, `answer`, `choices`, `correct_choice_index`, `evidence_sids`, and `tags`. +`QAPair` keeps the cross-dataset interface small: + +- core task fields: `qid`, `type`, `scope`, `item_type`, `question`, and `answer` +- multiple-choice fields: `choices`, `choice_labels`, `correct_choice_index`, and `correct_choice_letter` +- common query fields: `source`, `bank_id`, and `family_id` +- common provenance fields: `evidence_sids`, `tags`, and `is_baseline` +- split helpers: `split_group_id` and `related_qids` +- `metadata`: a dictionary containing dataset-specific public fields such as `source_turn`, `attribute_keys`, `statement_category`, `family_name`, `domain`, and `axis` + +`QAPair.split_group_id` is a single leakage-group key when one exists. For SynthPersona explicit rows, it is `explicit:{bank_id}`. `QAPair.related_qids` is used when a row has many source rows. For SynthPersona implicit shared multiple-choice rows, it contains the public qids of individual implicit free-response rows used as source evidence for the shared item. ## Persona fields @@ -47,7 +57,14 @@ It also exposes `name` as a derived property. persona = dataset[0] qa_pairs = dataset.get_qa(persona.id) -qa_pairs = dataset.get_qa(persona.id, type="explicit", item_type="mcq") +train_frq = dataset.get_qa(persona.id, item_type="frq") +test_mcq = dataset.get_qa(persona.id, item_type="mcq", scope="shared") +explicit_interview = dataset.get_qa( + persona.id, + type="explicit", + item_type="mcq", + source="interview", +) loaded_persona = dataset.get_persona("p1") ``` @@ -58,5 +75,8 @@ loaded_persona = dataset.get_persona("p1") - `type` can be `"explicit"` or `"implicit"`. - `item_type` can be `"mcq"` (multiple-choice) or `"frq"` (free-response). +- `scope` can be `"individual"` (persona-specific row) or `"shared"` (common bank row). +- `source` is only populated for explicit rows and can be `"seed_attribute"`, `"interview"`, or `"statement"`. +- `bank_id` and `family_id` can filter related rows for analysis or split construction. - `sample_size` keeps a leading slice rather than sampling randomly. - The loader keeps the dataset eager and notebook-friendly rather than streaming. diff --git a/src/persona_data/prompts.py b/src/persona_data/prompts.py index 513cb01..dc1d488 100644 --- a/src/persona_data/prompts.py +++ b/src/persona_data/prompts.py @@ -15,8 +15,9 @@ _CONVERSATIONAL_SUFFIX = "\n\nAnswer naturally and conversationally as this person." -BASELINE_PERSONA_ID = "baseline" +BASELINE_PERSONA_ID = "baseline_assistant" BASELINE_PERSONA_NAME = "Assistant" +BASELINE_VARIANT = "baseline" _PERSONA_VIEWS = {"templated": "templated_view", "biography": "biography_view"} @@ -31,7 +32,7 @@ def mc_answer_only_instruction(n_choices: int) -> str: def format_prompt( persona: PersonaData | str | None = None, - variant: Literal["baseline", "templated", "biography"] = BASELINE_PERSONA_ID, + variant: Literal["baseline", "templated", "biography"] = BASELINE_VARIANT, mode: Literal["roleplay", "conversational"] = "roleplay", ) -> str: """Build a standard persona system prompt. @@ -49,7 +50,7 @@ def format_prompt( if isinstance(persona, str): profile_text = persona - elif variant == BASELINE_PERSONA_ID: + elif variant == BASELINE_VARIANT: profile_text = BASELINE_PERSONA_NAME elif variant not in _PERSONA_VIEWS: raise ValueError(f"Unsupported persona prompt variant: {variant!r}") diff --git a/src/persona_data/synth_persona.py b/src/persona_data/synth_persona.py index 847cbd1..2c9c982 100644 --- a/src/persona_data/synth_persona.py +++ b/src/persona_data/synth_persona.py @@ -1,22 +1,37 @@ import json from collections import defaultdict +from collections.abc import Mapping from dataclasses import dataclass, field from pathlib import Path -from typing import Iterator, Literal +from typing import Any, Iterator, Literal + +QAType = Literal["explicit", "implicit"] +QAItemType = Literal["mcq", "frq"] +QAScope = Literal["individual", "shared"] +QASource = Literal["seed_attribute", "interview", "statement"] @dataclass class QAPair: qid: str - type: Literal["explicit", "implicit"] - item_type: Literal["mcq", "frq"] - scope: Literal["individual", "shared"] + type: QAType + item_type: QAItemType + scope: QAScope question: str answer: str choices: list[str] = field(default_factory=list) + choice_labels: list[str] = field(default_factory=list) correct_choice_index: int | None = None + correct_choice_letter: str | None = None + source: QASource | str | None = None + bank_id: str | None = None + family_id: str | None = None evidence_sids: list[str] = field(default_factory=list) tags: list[str] = field(default_factory=list) + is_baseline: bool = False + split_group_id: str | None = None + related_qids: list[str] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) def __repr__(self): return f"QAPair(qid={self.qid!r}, type={self.type!r}, item_type={self.item_type!r})" @@ -50,6 +65,20 @@ def __repr__(self): return f"PersonaData(id={self.id!r}, name={self.name!r})" +def _load_related_qids_by_bank_id(bank_path: Path | str) -> dict[str, list[str]]: + """Load generic related-qid mappings from a shared QA bank file.""" + with open(bank_path) as f: + payload = json.load(f) + related: dict[str, list[str]] = {} + for item in payload.get("items", []): + bank_id = item.get("bank_id") + if bank_id: + related[str(bank_id)] = item.get("related_qids") or item.get( + "related_frq_qids" + ) or [] + return related + + class PersonaDataset: """Persona dataset loaded from local JSONL files.""" @@ -58,6 +87,7 @@ def __init__( personas_path: Path | str, qa_path: Path | str, *, + related_qids_by_bank_id: Mapping[str, list[str]] | None = None, sample_size: int | None = None, ) -> None: self.sample_size = sample_size @@ -91,6 +121,8 @@ def __init__( loaded_ids = set(self._personas_by_id) self._qa: dict[str, list[QAPair]] = defaultdict(list) + related_qids_by_bank_id = related_qids_by_bank_id or {} + with open(qa_path) as f: for line in f: if not line.strip(): @@ -98,6 +130,40 @@ def __init__( d = json.loads(line) if d["id"] not in loaded_ids: continue + related_qids = d.get("related_qids") or d.get("related_frq_qids") or [] + bank_id = d.get("bank_id") + if bank_id in related_qids_by_bank_id: + related_qids = list(related_qids_by_bank_id[str(bank_id)]) + split_group_id = d.get("split_group_id") + if split_group_id is None and d.get("type") == "explicit" and d.get("bank_id"): + split_group_id = f"explicit:{d['bank_id']}" + metadata = { + key: value + for key, value in d.items() + if key + not in { + "id", + "qid", + "type", + "item_type", + "scope", + "question", + "answer", + "choices", + "choice_labels", + "correct_choice_index", + "correct_choice_letter", + "source", + "bank_id", + "family_id", + "evidence_sids", + "tags", + "is_baseline", + "split_group_id", + "related_qids", + "related_frq_qids", + } + } self._qa[d["id"]].append( QAPair( qid=d["qid"], @@ -107,9 +173,18 @@ def __init__( question=d["question"], answer=d["answer"], choices=d.get("choices") or [], + choice_labels=d.get("choice_labels") or [], correct_choice_index=d.get("correct_choice_index"), - evidence_sids=d.get("evidence_sids", []), - tags=d.get("tags", []), + correct_choice_letter=d.get("correct_choice_letter"), + source=d.get("source"), + bank_id=d.get("bank_id"), + family_id=d.get("family_id"), + evidence_sids=d.get("evidence_sids") or [], + tags=d.get("tags") or [], + is_baseline=d.get("is_baseline", False), + split_group_id=split_group_id, + related_qids=related_qids, + metadata=metadata, ) ) @@ -131,11 +206,14 @@ def get_persona(self, persona_id: str) -> PersonaData | None: def get_qa( self, persona_id: str, - type: Literal["explicit", "implicit"] | None = None, - item_type: Literal["mcq", "frq"] | None = None, - scope: Literal["individual", "shared"] | None = None, + type: QAType | None = None, + item_type: QAItemType | None = None, + scope: QAScope | None = None, + source: QASource | str | None = None, + bank_id: str | None = None, + family_id: str | None = None, ) -> list[QAPair]: - """Return QA pairs for a persona, optionally filtered by type / item_type.""" + """Return QA pairs for one persona, optionally filtered by schema fields.""" pairs = self._qa.get(persona_id, []) if type is not None: pairs = [p for p in pairs if p.type == type] @@ -143,6 +221,12 @@ def get_qa( pairs = [p for p in pairs if p.item_type == item_type] if scope is not None: pairs = [p for p in pairs if p.scope == scope] + if source is not None: + pairs = [p for p in pairs if p.source == source] + if bank_id is not None: + pairs = [p for p in pairs if p.bank_id == bank_id] + if family_id is not None: + pairs = [p for p in pairs if p.family_id == family_id] return pairs def train_test_split( @@ -150,18 +234,15 @@ def train_test_split( persona_id: str, *, n_train: int = 25, - train_type: Literal["explicit", "implicit"] | None = "explicit", - train_item_type: Literal["mcq", "frq"] | None = "mcq", + train_type: QAType | None = "explicit", + train_item_type: QAItemType | None = "mcq", ) -> tuple[list[QAPair], list[QAPair]]: """Return ``(train, test)`` QA splits for one persona. - Train: ``scope='individual'`` QAs — persona-specific items used to - derive a persona representation (Doc-to-LoRA conditioning, steering - vector, SAE features, …). Capped at ``n_train`` and by default - narrowed to explicit MCQs. - Test: ``scope='shared'`` QAs — the questions every persona answers, - kept as a held-out evaluation set with whatever mix of - explicit/implicit is present in the data. + Train: individual explicit multiple-choice QAs by default. This keeps + the historical helper behavior stable. Use `get_qa(..., item_type="frq")` + directly when building FRQ-trained steering vectors. + Test: shared multiple-choice QAs by default. """ train = self.get_qa( persona_id, @@ -169,7 +250,7 @@ def train_test_split( item_type=train_item_type, scope="individual", )[:n_train] - test = self.get_qa(persona_id, scope="shared") + test = self.get_qa(persona_id, item_type="mcq", scope="shared") return train, test @@ -185,10 +266,14 @@ def __init__( from huggingface_hub import hf_hub_download # HF Hub caches locally under HF_HOME so repeat runs are instant. + implicit_bank_path = hf_hub_download( + hf_repo, "implicit_shared_mc_bank.json", repo_type="dataset" + ) super().__init__( personas_path=hf_hub_download( hf_repo, "dataset_personas.jsonl", repo_type="dataset" ), qa_path=hf_hub_download(hf_repo, "dataset_qa.jsonl", repo_type="dataset"), + related_qids_by_bank_id=_load_related_qids_by_bank_id(implicit_bank_path), sample_size=sample_size, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index e469a9a..6ad985d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -50,7 +50,7 @@ def test_format_prompt_rejects_unknown_mode(): def test_format_prompt_baseline_default(): """Calling without args yields the persona-less Assistant baseline prompt.""" prompt = format_prompt() - assert BASELINE_PERSONA_ID == "baseline" + assert BASELINE_PERSONA_ID == "baseline_assistant" assert BASELINE_PERSONA_NAME in prompt @@ -146,6 +146,9 @@ def _write_persona_fixture(tmp: Path) -> tuple[Path, Path]: "scope": "individual", "question": "Who is it?", "answer": "Ada", + "source": "seed_attribute", + "bank_id": "explicit_name", + "attribute_keys": ["first_name"], }, { "id": "p1", @@ -157,6 +160,8 @@ def _write_persona_fixture(tmp: Path) -> tuple[Path, Path]: "answer": "math", "choices": ["math", "science"], "correct_choice_index": 0, + "family_id": "implicit_hint", + "axis": "curiosity", }, { "id": "p1", @@ -168,6 +173,10 @@ def _write_persona_fixture(tmp: Path) -> tuple[Path, Path]: "answer": "math", "choices": ["math", "art"], "correct_choice_index": 0, + "source": "interview", + "bank_id": "explicit_interview_field", + "source_turn": 7, + "source_turns": [7], }, { "id": "p1", @@ -179,6 +188,11 @@ def _write_persona_fixture(tmp: Path) -> tuple[Path, Path]: "answer": "yes", "choices": ["yes", "no"], "correct_choice_index": 0, + "bank_id": "implicit_shared", + "family_id": "implicit_shared", + "family_name": "Implicit Shared", + "domain": "identity", + "axis": "curiosity", }, { "id": "p1", @@ -189,7 +203,12 @@ def _write_persona_fixture(tmp: Path) -> tuple[Path, Path]: "question": "Shared explicit?", "answer": "yes", "choices": ["yes", "no"], + "choice_labels": ["A", "B"], "correct_choice_index": 0, + "correct_choice_letter": "A", + "source": "statement", + "bank_id": "explicit_statement_yes", + "statement_category": "values", }, { "id": "p2", @@ -226,6 +245,41 @@ def test_persona_dataset_qa_filters_by_type_and_item_type(tmp_path: Path): assert [q.qid for q in ds.get_qa("p1", type="explicit")] == ["q1", "q3", "q5"] assert [q.qid for q in ds.get_qa("p1", item_type="mcq")] == ["q2", "q3", "q4", "q5"] assert [q.qid for q in ds.get_qa("p1", scope="shared")] == ["q4", "q5"] + assert [q.qid for q in ds.get_qa("p1", source="interview")] == ["q3"] + assert [q.qid for q in ds.get_qa("p1", bank_id="explicit_statement_yes")] == ["q5"] + assert [q.qid for q in ds.get_qa("p1", family_id="implicit_shared")] == ["q4"] + + +def test_persona_dataset_exposes_current_public_metadata(tmp_path: Path): + personas_path, qa_path = _write_persona_fixture(tmp_path) + ds = PersonaDataset(personas_path, qa_path) + + qa = ds.get_qa("p1", bank_id="explicit_statement_yes")[0] + assert qa.choice_labels == ["A", "B"] + assert qa.correct_choice_letter == "A" + assert qa.source == "statement" + assert qa.metadata["statement_category"] == "values" + assert qa.split_group_id == "explicit:explicit_statement_yes" + + implicit = ds.get_qa("p1", family_id="implicit_shared")[0] + assert implicit.metadata["family_name"] == "Implicit Shared" + assert implicit.metadata["domain"] == "identity" + assert implicit.metadata["axis"] == "curiosity" + assert implicit.split_group_id is None + assert implicit.related_qids == [] + + +def test_persona_dataset_hydrates_related_qids_from_generic_bank_mapping(tmp_path: Path): + personas_path, qa_path = _write_persona_fixture(tmp_path) + + ds = PersonaDataset( + personas_path, + qa_path, + related_qids_by_bank_id={"implicit_shared": ["p1-q10", "p2-q20"]}, + ) + + implicit = ds.get_qa("p1", family_id="implicit_shared")[0] + assert implicit.related_qids == ["p1-q10", "p2-q20"] def test_persona_dataset_train_test_split_uses_individual_for_train(tmp_path: Path):