Skip to content

Commit 4001dcc

Browse files
test: updated testing suite for multi-source dataset pipeline and preprocessing
1 parent bb78057 commit 4001dcc

3 files changed

Lines changed: 642 additions & 0 deletions

File tree

Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
import sys
2+
import types
3+
from pathlib import Path
4+
5+
import pandas as pd
6+
import pytest
7+
import torch
8+
9+
import pepseqpred.apps.esm_cli as esm_cli
10+
import pepseqpred.apps.labels_cli as labels_cli
11+
import pepseqpred.apps.train_ffnn_cli as train_cli
12+
from pepseqpred.core.io.keys import parse_fullname
13+
from pepseqpred.core.preprocess.preparedataset import prepare_dataset
14+
from pepseqpred.core.train.split import split_ids_grouped
15+
16+
pytestmark = [pytest.mark.integration, pytest.mark.slow]
17+
18+
19+
class FakeAlphabet:
20+
def get_batch_converter(self):
21+
def _batch_converter(pairs):
22+
labels = [name for name, _seq in pairs]
23+
seqs = [seq for _name, seq in pairs]
24+
max_len = max((len(seq) for seq in seqs), default=0)
25+
tokens = torch.zeros((len(seqs), max_len + 2), dtype=torch.long)
26+
for i, seq in enumerate(seqs):
27+
seq_len = len(seq)
28+
tokens[i, 1:1 + seq_len] = 1
29+
tokens[i, 1 + seq_len] = 2
30+
return labels, seqs, tokens
31+
32+
return _batch_converter
33+
34+
35+
class FakeESMModel(torch.nn.Module):
36+
def __init__(self):
37+
super().__init__()
38+
self.p = torch.nn.Parameter(torch.zeros(1))
39+
40+
def forward(self, batch_tokens, repr_layers, return_contacts=False):
41+
_ = return_contacts
42+
batch_size, token_len = batch_tokens.shape
43+
rep_dim = 3 # append_seq_len => final emb dim=4
44+
reps = torch.ones((batch_size, token_len, rep_dim),
45+
dtype=torch.float32)
46+
return {"representations": {repr_layers[0]: reps}}
47+
48+
49+
def _write_code_list(path: Path, codes: list[str]) -> None:
50+
path.write_text("Sequence name\n" + "\n".join(codes) +
51+
"\n", encoding="utf-8")
52+
53+
54+
def _append_fasta_records(path: Path, records: list[tuple[str, str]]) -> None:
55+
with path.open("a", encoding="utf-8") as out_f:
56+
for header, seq in records:
57+
out_f.write(f">{header}\n{seq}\n")
58+
59+
60+
def _build_pv1_inputs(root: Path) -> tuple[Path, Path, Path]:
61+
root.mkdir(parents=True, exist_ok=True)
62+
meta = root / "pv1_meta.tsv"
63+
z = root / "pv1_z.tsv"
64+
fasta = root / "pv1_targets.fasta"
65+
66+
pd.DataFrame(
67+
[
68+
{
69+
"CodeName": "pv1_pep_1",
70+
"Category": "SetCover",
71+
"SpeciesID": "1",
72+
"Species": "PV1",
73+
"Protein": "Prot",
74+
"FullName": "ID=PV1P001 AC=A1 OXX=11,22,301_0_4",
75+
"Peptide": "MNPQ",
76+
"Encoding": "enc",
77+
},
78+
{
79+
"CodeName": "pv1_pep_2",
80+
"Category": "SetCover",
81+
"SpeciesID": "1",
82+
"Species": "PV1",
83+
"Protein": "Prot",
84+
"FullName": "ID=PV1P001 AC=A1 OXX=11,22,301_2_6",
85+
"Peptide": "PQRS",
86+
"Encoding": "enc",
87+
},
88+
]
89+
).to_csv(meta, sep="\t", index=False)
90+
pd.DataFrame(
91+
[
92+
{"Sequence name": "pv1_pep_1", "VW_001": 30.0, "VW_002": 0.0},
93+
{"Sequence name": "pv1_pep_2", "VW_001": 1.0, "VW_002": 2.0},
94+
]
95+
).to_csv(z, sep="\t", index=False)
96+
fasta.write_text(
97+
">ID=PV1P001 AC=A1 OXX=11,22,301\nMNPQRS\n",
98+
encoding="utf-8",
99+
)
100+
return meta, z, fasta
101+
102+
103+
def _build_cwp_inputs(root: Path) -> tuple[Path, Path, Path, Path]:
104+
root.mkdir(parents=True, exist_ok=True)
105+
meta = root / "cwp_meta.tsv"
106+
reactive = root / "cwp_reactive.tsv"
107+
nonreactive = root / "cwp_nonreactive.tsv"
108+
fasta = root / "cwp_targets.faa"
109+
110+
pd.DataFrame(
111+
[
112+
{
113+
"CodeName": "CWP_000001",
114+
"SequenceAccession": "A0CWP1",
115+
"Cluster50ID": "Cocci_id50_010",
116+
"StartIndex": 0,
117+
"StopIndex": 4,
118+
"PeptideSequence": "ACDE",
119+
},
120+
{
121+
"CodeName": "CWP_000002",
122+
"SequenceAccession": "A0CWP1",
123+
"Cluster50ID": "Cocci_id50_010",
124+
"StartIndex": 1,
125+
"StopIndex": 5,
126+
"PeptideSequence": "CDEF",
127+
},
128+
]
129+
).to_csv(meta, sep="\t", index=False)
130+
_write_code_list(reactive, ["CWP_000001"])
131+
_write_code_list(nonreactive, ["CWP_000002"])
132+
fasta.write_text(">tr|A0CWP1|A0CWP1_FAKE\nACDEFG\n", encoding="utf-8")
133+
return meta, reactive, nonreactive, fasta
134+
135+
136+
def _build_bkp_inputs(root: Path) -> tuple[Path, Path, Path, Path]:
137+
root.mkdir(parents=True, exist_ok=True)
138+
meta = root / "bkp_meta.tsv"
139+
reactive = root / "bkp_reactive.tsv"
140+
nonreactive = root / "bkp_nonreactive.tsv"
141+
fasta = root / "bkp_targets.faa"
142+
143+
pd.DataFrame(
144+
[
145+
{
146+
"CodeName": "BKP_000001",
147+
"SequenceAccession": "A0BKP1",
148+
"reClusterID_70": "BKP1_id70_200",
149+
"alignStart": "0.0",
150+
"alignStop": "4.0",
151+
"PeptideSequence": "WXYZ",
152+
},
153+
{
154+
"CodeName": "BKP_000002",
155+
"SequenceAccession": "A0BKP1",
156+
"reClusterID_70": "BKP1_id70_200",
157+
"alignStart": "1.0",
158+
"alignStop": "5.0",
159+
"PeptideSequence": "XYZA",
160+
},
161+
]
162+
).to_csv(meta, sep="\t", index=False)
163+
_write_code_list(reactive, ["BKP_000001"])
164+
_write_code_list(nonreactive, ["BKP_000002"])
165+
fasta.write_text(">tr|A0BKP1|A0BKP1_FAKE\nWXYZAB\n", encoding="utf-8")
166+
return meta, reactive, nonreactive, fasta
167+
168+
169+
def test_prepare_dataset_multisource_pipeline_smoke(monkeypatch, tmp_path: Path):
170+
# Build three mini datasets.
171+
pv1_meta, pv1_z, pv1_fasta = _build_pv1_inputs(tmp_path / "pv1")
172+
cwp_meta, cwp_reactive, cwp_nonreactive, cwp_fasta = _build_cwp_inputs(
173+
tmp_path / "cwp")
174+
bkp_meta, bkp_reactive, bkp_nonreactive, bkp_fasta = _build_bkp_inputs(
175+
tmp_path / "bkp")
176+
177+
out_pv1 = tmp_path / "out_pv1"
178+
out_cwp = tmp_path / "out_cwp"
179+
out_bkp = tmp_path / "out_bkp"
180+
181+
prepare_dataset(
182+
dataset_kind="pv1",
183+
meta_path=pv1_meta,
184+
z_path=pv1_z,
185+
output_dir=out_pv1,
186+
protein_fasta=pv1_fasta,
187+
is_epitope_min_subjects=1,
188+
)
189+
prepare_dataset(
190+
dataset_kind="cwp",
191+
meta_path=cwp_meta,
192+
output_dir=out_cwp,
193+
protein_fasta=cwp_fasta,
194+
reactive_codes=cwp_reactive,
195+
nonreactive_codes=cwp_nonreactive,
196+
group_id_offset=100_000_000,
197+
)
198+
prepare_dataset(
199+
dataset_kind="bkp",
200+
meta_path=bkp_meta,
201+
output_dir=out_bkp,
202+
protein_fasta=bkp_fasta,
203+
reactive_codes=bkp_reactive,
204+
nonreactive_codes=bkp_nonreactive,
205+
group_id_offset=200_000_000,
206+
)
207+
208+
# Combine prepared artifacts.
209+
combined_dir = tmp_path / "combined"
210+
combined_dir.mkdir(parents=True, exist_ok=True)
211+
combined_fasta = combined_dir / "prepared_targets.fasta"
212+
combined_meta = combined_dir / "prepared_labels_metadata.tsv"
213+
combined_emb_meta = combined_dir / "prepared_embedding_metadata.tsv"
214+
combined_fasta.write_text("", encoding="utf-8")
215+
216+
for source in [out_pv1, out_cwp, out_bkp]:
217+
recs = []
218+
header = None
219+
seq_lines = []
220+
for raw in (source / "prepared_targets.fasta").read_text(encoding="utf-8").splitlines():
221+
line = raw.strip()
222+
if line == "":
223+
continue
224+
if line.startswith(">"):
225+
if header is not None:
226+
recs.append((header, "".join(seq_lines)))
227+
header = line[1:].strip()
228+
seq_lines = []
229+
else:
230+
seq_lines.append(line)
231+
if header is not None:
232+
recs.append((header, "".join(seq_lines)))
233+
_append_fasta_records(combined_fasta, recs)
234+
235+
labels_df = pd.concat(
236+
[
237+
pd.read_csv(out_pv1 / "prepared_labels_metadata.tsv", sep="\t"),
238+
pd.read_csv(out_cwp / "prepared_labels_metadata.tsv", sep="\t"),
239+
pd.read_csv(out_bkp / "prepared_labels_metadata.tsv", sep="\t"),
240+
],
241+
ignore_index=True,
242+
)
243+
labels_df.to_csv(combined_meta, sep="\t", index=False)
244+
245+
emb_meta_df = pd.concat(
246+
[
247+
pd.read_csv(out_pv1 / "prepared_embedding_metadata.tsv", sep="\t"),
248+
pd.read_csv(out_cwp / "prepared_embedding_metadata.tsv", sep="\t"),
249+
pd.read_csv(out_bkp / "prepared_embedding_metadata.tsv", sep="\t"),
250+
],
251+
ignore_index=True,
252+
).drop_duplicates(subset=["Name", "Family"])
253+
emb_meta_df.to_csv(combined_emb_meta, sep="\t", index=False)
254+
255+
# Assert grouped split behavior: no family overlap between train/val IDs.
256+
id_to_family = {
257+
parse_fullname(str(name))[0]: str(int(family))
258+
for name, family in emb_meta_df[["Name", "Family"]].itertuples(index=False, name=None)
259+
}
260+
all_ids = sorted(id_to_family.keys())
261+
train_ids, val_ids = split_ids_grouped(
262+
all_ids,
263+
val_frac=0.34,
264+
seed=11,
265+
groups=id_to_family,
266+
)
267+
train_fams = {id_to_family[pid] for pid in train_ids}
268+
val_fams = {id_to_family[pid] for pid in val_ids}
269+
assert train_fams.isdisjoint(val_fams)
270+
271+
# Run ESM CLI with fake model.
272+
fake_pretrained = types.SimpleNamespace(
273+
fake_model=lambda: (FakeESMModel(), FakeAlphabet())
274+
)
275+
monkeypatch.setattr(esm_cli.esm, "pretrained", fake_pretrained)
276+
monkeypatch.setattr(esm_cli.torch.cuda, "is_available", lambda: False)
277+
monkeypatch.setattr(esm_cli.torch.cuda, "device_count", lambda: 0)
278+
279+
embs_out = tmp_path / "esm_out"
280+
monkeypatch.setattr(
281+
sys,
282+
"argv",
283+
[
284+
"esm_cli.py",
285+
"--fasta-file",
286+
str(combined_fasta),
287+
"--metadata-file",
288+
str(combined_emb_meta),
289+
"--out-dir",
290+
str(embs_out),
291+
"--embedding-key-mode",
292+
"id-family",
293+
"--model-name",
294+
"fake_model",
295+
"--max-tokens",
296+
"16",
297+
"--batch-size",
298+
"4",
299+
],
300+
)
301+
esm_cli.main()
302+
303+
emb_dir = embs_out / "artifacts" / "pts"
304+
assert emb_dir.exists()
305+
306+
# Run labels CLI against prepared metadata and generated embeddings.
307+
labels_pt = tmp_path / "labels.pt"
308+
monkeypatch.setattr(
309+
sys,
310+
"argv",
311+
[
312+
"labels_cli.py",
313+
str(combined_meta),
314+
str(labels_pt),
315+
"--emb-dir",
316+
str(emb_dir),
317+
"--embedding-key-delim",
318+
"-",
319+
"--calc-pos-weight",
320+
],
321+
)
322+
labels_cli.main()
323+
assert labels_pt.exists()
324+
325+
# Train smoke run using grouped split (id-family) over all three datasets.
326+
save_dir = tmp_path / "train_out"
327+
monkeypatch.setattr(
328+
sys,
329+
"argv",
330+
[
331+
"train_ffnn_cli.py",
332+
"--embedding-dirs",
333+
str(emb_dir),
334+
"--label-shards",
335+
str(labels_pt),
336+
"--epochs",
337+
"1",
338+
"--batch-size",
339+
"2",
340+
"--num-workers",
341+
"0",
342+
"--hidden-sizes",
343+
"8",
344+
"--dropouts",
345+
"0.1",
346+
"--val-frac",
347+
"0.34",
348+
"--split-type",
349+
"id-family",
350+
"--split-seeds",
351+
"11",
352+
"--train-seeds",
353+
"101",
354+
"--save-path",
355+
str(save_dir),
356+
"--results-csv",
357+
str(save_dir / "runs.csv"),
358+
],
359+
)
360+
train_cli.main()
361+
362+
assert (save_dir / "runs.csv").exists()
363+
run_dirs = list(save_dir.glob("run_*"))
364+
assert run_dirs
365+
assert (run_dirs[0] / "fully_connected.pt").exists()

0 commit comments

Comments
 (0)