11 changes: 9 additions & 2 deletions README.md
@@ -74,7 +74,14 @@ persona.templated_view # short attribute-based system prompt
persona.biography_view # full biography text
persona.statements # list of Statement

qa_pairs = dataset.get_qa(persona.id, type="implicit", item_type="mcq")
qa_pairs_train = dataset.get_qa(persona.id, item_type="frq")
qa_pairs_test = dataset.get_qa(persona.id, item_type="mcq", scope="shared")
explicit_interview = dataset.get_qa(
persona.id,
type="explicit",
item_type="mcq",
source="interview",
)
```

### PersonaGuess
@@ -106,7 +113,7 @@ full_prompt, response_start_idx = format_messages(messages, tokenizer)

When iterating over variants in an experiment, pass `"templated"` or `"biography"` to `format_prompt(persona, variant)`. Calling `format_prompt()` yields the persona-less Assistant baseline prompt. Use `BASELINE_PERSONA_ID` and `BASELINE_PERSONA_NAME` from `persona_data.prompts` for the shared baseline identity in artifacts and UI labels.
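
The variant dispatch described above can be sketched with stand-in data. This is a minimal illustration, not the library's implementation: the dict persona, the view attribute names, and the final prompt wording are assumptions based on the README.

```python
# Minimal sketch of format_prompt's variant dispatch, using a plain dict as
# a stand-in persona. The view names mirror the README; the prompt wording
# is illustrative only.
BASELINE_PERSONA_NAME = "Assistant"
_PERSONA_VIEWS = {"templated": "templated_view", "biography": "biography_view"}

def format_prompt_sketch(persona=None, variant="baseline"):
    if isinstance(persona, str):
        profile = persona                 # raw profile text passed through
    elif variant == "baseline":
        profile = BASELINE_PERSONA_NAME   # persona-less Assistant baseline
    elif variant not in _PERSONA_VIEWS:
        raise ValueError(f"Unsupported persona prompt variant: {variant!r}")
    else:
        profile = persona[_PERSONA_VIEWS[variant]]
    return f"You are {profile}."

persona = {"templated_view": "Ada, a 34-year-old glassblower",
           "biography_view": "Ada grew up in a small coastal town..."}
baseline = format_prompt_sketch()                    # baseline Assistant prompt
templated = format_prompt_sketch(persona, "templated")
```

The real `format_prompt` also takes a `mode` argument and appends mode-specific suffixes; this sketch only shows how the three variants select the profile text.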

For multiple-choice prompts, use `format_mc_question(qa)` to render the question, choices, and trailing answer-only instruction. Use `mc_answer_only_instruction(n_choices)` if you need just the instruction text, and `mc_correct_letter(qa)` to get the gold label.
For multiple-choice prompts, use `format_mc_question(qa)` to render the question, choices, and trailing answer-only instruction. Use `mc_answer_only_instruction(n_choices)` if you need just the instruction text, and `mc_correct_letter(qa)` to get the gold label. `QAPair.split_group_id` is an optional split key shared by rows in the same leakage group. `QAPair.related_qids` is an optional list of qids for rows derived from multiple source rows; for SynthPersona implicit shared MCQs, it contains the source free-response qids.
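
The MC helpers above can be approximated self-containedly as follows; the exact wording the library renders may differ, so treat the instruction text here as an assumption:

```python
from string import ascii_uppercase

# Illustrative sketch of how the MC helpers fit together. The real
# format_mc_question / mc_correct_letter live in persona_data and may
# phrase the instruction differently.
def render_mcq(question, choices):
    lines = [question]
    lines += [f"{ascii_uppercase[i]}. {c}" for i, c in enumerate(choices)]
    last = ascii_uppercase[len(choices) - 1]
    lines.append(f"Answer with only the letter ({ascii_uppercase[0]}-{last}).")
    return "\n".join(lines)

def correct_letter(correct_choice_index):
    # Gold label: the letter at the correct choice's index.
    return ascii_uppercase[correct_choice_index]

prompt = render_mcq("Favorite season?", ["Winter", "Spring", "Summer"])
gold = correct_letter(1)  # "B"
```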

`format_messages` handles tokenizers that do not support the `"system"` role (for example Gemma 2) by merging system content into the first user message. Pass `add_generation_prompt=True` to render an inference-ready prompt (messages ending in a `user` turn); the returned `response_start_idx` then equals the prompt length, ready to slice `model.generate` output.
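
The system-merge fallback can be sketched on plain chat messages (no tokenizer). The real `format_messages` also applies the chat template and computes `response_start_idx`; this only shows the role rewrite for models without a `"system"` role:

```python
# Sketch of merging system content into the first user message, as done for
# tokenizers (e.g. Gemma 2) that reject the "system" role.
def merge_system_into_user(messages):
    if not messages or messages[0]["role"] != "system":
        return list(messages)
    system, rest = messages[0], messages[1:]
    if rest and rest[0]["role"] == "user":
        merged = {"role": "user",
                  "content": system["content"] + "\n\n" + rest[0]["content"]}
        return [merged] + rest[1:]
    # No user turn to merge into: downgrade the system turn itself.
    return [{"role": "user", "content": system["content"]}] + rest

msgs = [{"role": "system", "content": "You are Ada."},
        {"role": "user", "content": "Hi!"}]
merged = merge_system_into_user(msgs)
```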

28 changes: 24 additions & 4 deletions docs/synth_persona.md
@@ -11,12 +11,13 @@ dataset = SynthPersonaDataset()
small = SynthPersonaDataset(sample_size=100)
```

The default dataset source is `implicit-personalization/synth-persona`. The loader reads two JSONL files:
The default dataset source is `implicit-personalization/synth-persona`. The loader downloads these files:

- `dataset_personas.jsonl`
- `dataset_qa.jsonl`
- `implicit_shared_mc_bank.json`

`sample_size` keeps the leading personas and filters QA rows to the loaded persona IDs.
`implicit_shared_mc_bank.json` is read internally to hydrate `QAPair.related_qids`; user code normally only calls `get_qa()`. `sample_size` keeps the leading personas and filters QA rows to the loaded persona IDs.

## Records

@@ -26,7 +27,16 @@ The default dataset source is `implicit-personalization/synth-persona`. The load

`Statement` includes `sid`, `category`, `claim`, and `support_turns`.

`QAPair` includes `qid`, `type`, `item_type`, `question`, `answer`, `choices`, `correct_choice_index`, `evidence_sids`, and `tags`.
`QAPair` keeps the cross-dataset interface small:

- core task fields: `qid`, `type`, `scope`, `item_type`, `question`, and `answer`
- multiple-choice fields: `choices`, `choice_labels`, `correct_choice_index`, and `correct_choice_letter`
- common query fields: `source`, `bank_id`, and `family_id`
- common provenance fields: `evidence_sids`, `tags`, and `is_baseline`
- split helpers: `split_group_id` and `related_qids`
- `metadata`: a dictionary containing dataset-specific public fields such as `source_turn`, `attribute_keys`, `statement_category`, `family_name`, `domain`, and `axis`

`QAPair.split_group_id` is a single leakage-group key when one exists; for SynthPersona explicit rows it is `explicit:{bank_id}`. `QAPair.related_qids` lists source rows when a row is derived from several of them; for SynthPersona implicit shared multiple-choice rows, it contains the public qids of the individual implicit free-response rows used as source evidence for the shared item.
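
Leakage-safe splitting with `split_group_id` can be sketched on dict rows standing in for `QAPair`; the group keys and qids below are invented for illustration:

```python
from collections import defaultdict

# Rows sharing a split_group_id must land in the same split. Ungrouped rows
# (split_group_id is None) fall back to their qid, forming singleton groups.
def group_rows(rows):
    groups = defaultdict(list)
    for row in rows:
        groups[row["split_group_id"] or row["qid"]].append(row)
    return groups

rows = [
    {"qid": "q1", "split_group_id": "explicit:b7"},
    {"qid": "q2", "split_group_id": "explicit:b7"},
    {"qid": "q3", "split_group_id": None},
]
groups = group_rows(rows)
# Assign whole groups (not individual rows) to train or test.
```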

## Persona fields

@@ -47,7 +57,14 @@ It also exposes `name` as a derived property.
persona = dataset[0]

qa_pairs = dataset.get_qa(persona.id)
qa_pairs = dataset.get_qa(persona.id, type="explicit", item_type="mcq")
train_frq = dataset.get_qa(persona.id, item_type="frq")
test_mcq = dataset.get_qa(persona.id, item_type="mcq", scope="shared")
explicit_interview = dataset.get_qa(
persona.id,
type="explicit",
item_type="mcq",
source="interview",
)

loaded_persona = dataset.get_persona("p1")
```
@@ -58,5 +75,8 @@ loaded_persona = dataset.get_persona("p1")

- `type` can be `"explicit"` or `"implicit"`.
- `item_type` can be `"mcq"` (multiple-choice) or `"frq"` (free-response).
- `scope` can be `"individual"` (persona-specific row) or `"shared"` (common bank row).
- `source` is only populated for explicit rows and can be `"seed_attribute"`, `"interview"`, or `"statement"`.
- `bank_id` and `family_id` can filter related rows for analysis or split construction.
- `sample_size` keeps a leading slice rather than sampling randomly.
- The loader keeps the dataset eager and notebook-friendly rather than streaming.
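
The filter semantics above are plain AND-filters. A self-contained sketch over dict rows (the rows and field values here are invented; the real `get_qa` operates on `QAPair` objects):

```python
# get_qa-style filtering: every non-None keyword must match exactly.
def filter_qa(rows, **filters):
    active = {k: v for k, v in filters.items() if v is not None}
    return [r for r in rows if all(r.get(k) == v for k, v in active.items())]

rows = [
    {"qid": "q1", "type": "explicit", "item_type": "mcq", "scope": "individual"},
    {"qid": "q2", "type": "implicit", "item_type": "frq", "scope": "individual"},
    {"qid": "q3", "type": "implicit", "item_type": "mcq", "scope": "shared"},
]
shared_mcq = filter_qa(rows, item_type="mcq", scope="shared")
```
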
7 changes: 4 additions & 3 deletions src/persona_data/prompts.py
@@ -15,8 +15,9 @@

_CONVERSATIONAL_SUFFIX = "\n\nAnswer naturally and conversationally as this person."

BASELINE_PERSONA_ID = "baseline"
BASELINE_PERSONA_ID = "baseline_assistant"
BASELINE_PERSONA_NAME = "Assistant"
BASELINE_VARIANT = "baseline"

_PERSONA_VIEWS = {"templated": "templated_view", "biography": "biography_view"}

@@ -31,7 +32,7 @@ def mc_answer_only_instruction(n_choices: int) -> str:

def format_prompt(
persona: PersonaData | str | None = None,
variant: Literal["baseline", "templated", "biography"] = BASELINE_PERSONA_ID,
variant: Literal["baseline", "templated", "biography"] = BASELINE_VARIANT,
mode: Literal["roleplay", "conversational"] = "roleplay",
) -> str:
"""Build a standard persona system prompt.
@@ -49,7 +50,7 @@

if isinstance(persona, str):
profile_text = persona
elif variant == BASELINE_PERSONA_ID:
elif variant == BASELINE_VARIANT:
profile_text = BASELINE_PERSONA_NAME
elif variant not in _PERSONA_VIEWS:
raise ValueError(f"Unsupported persona prompt variant: {variant!r}")
125 changes: 105 additions & 20 deletions src/persona_data/synth_persona.py
@@ -1,22 +1,37 @@
import json
from collections import defaultdict
from collections.abc import Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator, Literal
from typing import Any, Iterator, Literal

QAType = Literal["explicit", "implicit"]
QAItemType = Literal["mcq", "frq"]
QAScope = Literal["individual", "shared"]
QASource = Literal["seed_attribute", "interview", "statement"]


@dataclass
class QAPair:
qid: str
type: Literal["explicit", "implicit"]
item_type: Literal["mcq", "frq"]
scope: Literal["individual", "shared"]
type: QAType
item_type: QAItemType
scope: QAScope
question: str
answer: str
choices: list[str] = field(default_factory=list)
choice_labels: list[str] = field(default_factory=list)
correct_choice_index: int | None = None
correct_choice_letter: str | None = None
source: QASource | str | None = None
bank_id: str | None = None
family_id: str | None = None
evidence_sids: list[str] = field(default_factory=list)
tags: list[str] = field(default_factory=list)
is_baseline: bool = False
split_group_id: str | None = None
related_qids: list[str] = field(default_factory=list)
metadata: dict[str, Any] = field(default_factory=dict)

def __repr__(self):
return f"QAPair(qid={self.qid!r}, type={self.type!r}, item_type={self.item_type!r})"
@@ -50,6 +65,20 @@ def __repr__(self):
return f"PersonaData(id={self.id!r}, name={self.name!r})"


def _load_related_qids_by_bank_id(bank_path: Path | str) -> dict[str, list[str]]:
"""Load generic related-qid mappings from a shared QA bank file."""
with open(bank_path) as f:
payload = json.load(f)
related: dict[str, list[str]] = {}
for item in payload.get("items", []):
bank_id = item.get("bank_id")
if bank_id:
related[str(bank_id)] = item.get("related_qids") or item.get("related_frq_qids") or []
return related


class PersonaDataset:
"""Persona dataset loaded from local JSONL files."""

@@ -58,6 +87,7 @@ def __init__(
personas_path: Path | str,
qa_path: Path | str,
*,
related_qids_by_bank_id: Mapping[str, list[str]] | None = None,
sample_size: int | None = None,
) -> None:
self.sample_size = sample_size
@@ -91,13 +121,49 @@ def __init__(

loaded_ids = set(self._personas_by_id)
self._qa: dict[str, list[QAPair]] = defaultdict(list)
related_qids_by_bank_id = related_qids_by_bank_id or {}

with open(qa_path) as f:
for line in f:
if not line.strip():
continue
d = json.loads(line)
if d["id"] not in loaded_ids:
continue
related_qids = d.get("related_qids") or d.get("related_frq_qids") or []
bank_id = d.get("bank_id")
if bank_id is not None and str(bank_id) in related_qids_by_bank_id:
related_qids = list(related_qids_by_bank_id[str(bank_id)])
split_group_id = d.get("split_group_id")
if split_group_id is None and d.get("type") == "explicit" and d.get("bank_id"):
split_group_id = f"explicit:{d['bank_id']}"
metadata = {
key: value
for key, value in d.items()
if key
not in {
"id",
"qid",
"type",
"item_type",
"scope",
"question",
"answer",
"choices",
"choice_labels",
"correct_choice_index",
"correct_choice_letter",
"source",
"bank_id",
"family_id",
"evidence_sids",
"tags",
"is_baseline",
"split_group_id",
"related_qids",
"related_frq_qids",
}
}
self._qa[d["id"]].append(
QAPair(
qid=d["qid"],
Expand All @@ -107,9 +173,18 @@ def __init__(
question=d["question"],
answer=d["answer"],
choices=d.get("choices") or [],
choice_labels=d.get("choice_labels") or [],
correct_choice_index=d.get("correct_choice_index"),
evidence_sids=d.get("evidence_sids", []),
tags=d.get("tags", []),
correct_choice_letter=d.get("correct_choice_letter"),
source=d.get("source"),
bank_id=d.get("bank_id"),
family_id=d.get("family_id"),
evidence_sids=d.get("evidence_sids") or [],
tags=d.get("tags") or [],
is_baseline=d.get("is_baseline", False),
split_group_id=split_group_id,
related_qids=related_qids,
metadata=metadata,
)
)

@@ -131,45 +206,51 @@ def get_persona(self, persona_id: str) -> PersonaData | None:
def get_qa(
self,
persona_id: str,
type: Literal["explicit", "implicit"] | None = None,
item_type: Literal["mcq", "frq"] | None = None,
scope: Literal["individual", "shared"] | None = None,
type: QAType | None = None,
item_type: QAItemType | None = None,
scope: QAScope | None = None,
source: QASource | str | None = None,
bank_id: str | None = None,
family_id: str | None = None,
) -> list[QAPair]:
"""Return QA pairs for a persona, optionally filtered by type / item_type."""
"""Return QA pairs for one persona, optionally filtered by schema fields."""
pairs = self._qa.get(persona_id, [])
if type is not None:
pairs = [p for p in pairs if p.type == type]
if item_type is not None:
pairs = [p for p in pairs if p.item_type == item_type]
if scope is not None:
pairs = [p for p in pairs if p.scope == scope]
if source is not None:
pairs = [p for p in pairs if p.source == source]
if bank_id is not None:
pairs = [p for p in pairs if p.bank_id == bank_id]
if family_id is not None:
pairs = [p for p in pairs if p.family_id == family_id]
return pairs

def train_test_split(
self,
persona_id: str,
*,
n_train: int = 25,
train_type: Literal["explicit", "implicit"] | None = "explicit",
train_item_type: Literal["mcq", "frq"] | None = "mcq",
train_type: QAType | None = "explicit",
train_item_type: QAItemType | None = "mcq",
) -> tuple[list[QAPair], list[QAPair]]:
"""Return ``(train, test)`` QA splits for one persona.

Train: ``scope='individual'`` QAs — persona-specific items used to
derive a persona representation (Doc-to-LoRA conditioning, steering
vector, SAE features, …). Capped at ``n_train`` and by default
narrowed to explicit MCQs.
Test: ``scope='shared'`` QAs — the questions every persona answers,
kept as a held-out evaluation set with whatever mix of
explicit/implicit is present in the data.
Train: individual explicit multiple-choice QAs by default. This keeps
the historical helper behavior stable. Use `get_qa(..., item_type="frq")`
directly when building FRQ-trained steering vectors.
Test: shared multiple-choice QAs by default.
"""
train = self.get_qa(
persona_id,
type=train_type,
item_type=train_item_type,
scope="individual",
)[:n_train]
test = self.get_qa(persona_id, scope="shared")
test = self.get_qa(persona_id, item_type="mcq", scope="shared")
return train, test


@@ -185,10 +266,14 @@ def __init__(
from huggingface_hub import hf_hub_download

# HF Hub caches locally under HF_HOME so repeat runs are instant.
implicit_bank_path = hf_hub_download(
hf_repo, "implicit_shared_mc_bank.json", repo_type="dataset"
)
super().__init__(
personas_path=hf_hub_download(
hf_repo, "dataset_personas.jsonl", repo_type="dataset"
),
qa_path=hf_hub_download(hf_repo, "dataset_qa.jsonl", repo_type="dataset"),
related_qids_by_bank_id=_load_related_qids_by_bank_id(implicit_bank_path),
sample_size=sample_size,
)