diff --git a/.gitignore b/.gitignore index ece76c1..1e81ee9 100644 --- a/.gitignore +++ b/.gitignore @@ -42,7 +42,12 @@ pytest-*.xml *.pyz *.metadata *.json -localdata/evals/* +*.png +*.svg + +# GitHub and PyPI PepSeqPred logos +!PepSeqPred_logo_black.png +!PepSeqPred_logo_white.png # Bundled pretrained model artifacts shipped with the package !src/pepseqpred/api/pretrained_artifacts/**/*.pt diff --git a/scripts/hpc/trainffnn.sh b/scripts/hpc/trainffnn.sh index 314dede..a845946 100644 --- a/scripts/hpc/trainffnn.sh +++ b/scripts/hpc/trainffnn.sh @@ -81,6 +81,9 @@ WINDOW_SIZE="${WINDOW_SIZE:-1000}" STRIDE="${STRIDE:-900}" SPLIT_TYPE="${SPLIT_TYPE:-id-family}" # id-family or id LABEL_CACHE_MODE="${LABEL_CACHE_MODE:-current}" # current or all +SAVE_VAL_CURVES="${SAVE_VAL_CURVES:-0}" # 1 to enable validation ROC/PR artifacts +VAL_CURVE_MAX_POINTS="${VAL_CURVE_MAX_POINTS:-2048}" +VAL_PLOT_FORMATS="${VAL_PLOT_FORMATS:-png}" mkdir -p "${SAVE_PATH}" @@ -114,6 +117,13 @@ else TRAIN_MODE_ARGS+=(--train-seeds "$TRAIN_SEEDS") fi +VAL_CURVE_ARGS=() +if [ "${SAVE_VAL_CURVES}" -eq 1 ]; then + VAL_CURVE_ARGS+=(--save-val-curves) + VAL_CURVE_ARGS+=(--val-curve-max-points "$VAL_CURVE_MAX_POINTS") + VAL_CURVE_ARGS+=(--val-plot-formats "$VAL_PLOT_FORMATS") +fi + ${LAUNCHER} torchrun --nproc_per_node=4 train_ffnn.pyz \ --embedding-dirs "${EMBEDDING_DIRS[@]}" \ --label-shards "${LABEL_SHARDS[@]}" \ @@ -133,6 +143,7 @@ ${LAUNCHER} torchrun --nproc_per_node=4 train_ffnn.pyz \ --results-csv "$RESULTS_CSV" \ --num-workers "$NUM_WORKERS" \ --window-size "$WINDOW_SIZE" \ - --stride "$STRIDE" + --stride "$STRIDE" \ + "${VAL_CURVE_ARGS[@]}" # USAGE: sbatch trainffnn.sh /scratch/$USER/esm2/artifacts/pts/shard_000 /scratch/$USER/esm2/artifacts/pts/shard_001 /scratch/$USER/esm2/artifacts/pts/shard_002 /scratch/$USER/esm2/artifacts/pts/shard_003 -- /scratch/$USER/labels/labels_shard_000.pt /scratch/$USER/labels/labels_shard_001.pt /scratch/$USER/labels/labels_shard_002.pt /scratch/$USER/labels/labels_shard_003.pt diff --git a/src/pepseqpred/apps/train_ffnn_cli.py b/src/pepseqpred/apps/train_ffnn_cli.py index b3f9fc1..fe5b943 100644 --- a/src/pepseqpred/apps/train_ffnn_cli.py +++ b/src/pepseqpred/apps/train_ffnn_cli.py @@ -22,7 +22,8 @@ import time from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Any, Tuple +from typing import Dict, List, Any, Mapping, Sequence, Tuple +import numpy as np import torch from torch.utils.data import DataLoader import torch.distributed as dist @@ -32,7 +33,11 @@ from pepseqpred.core.io.write import append_csv_row from pepseqpred.core.data.proteindataset import ProteinDataset, pad_collate from pepseqpred.core.models.ffnn import PepSeqFFNN -from pepseqpred.core.train.trainer import Trainer, TrainerConfig +from pepseqpred.core.train.trainer import ( + Trainer, + TrainerConfig, + ValidationCurveArtifactConfig +) from pepseqpred.core.train.ddp import init_ddp from pepseqpred.core.train.split import ( split_ids, @@ -100,6 +105,509 @@ def _finite_or_none(value: Any) -> float | None: return num if math.isfinite(num) else None +def _parse_plot_formats(raw: str) -> Tuple[str, ...]: + """Parses comma-separated plot file formats.""" + tokens = [t.strip().lower() for t in str(raw).split(",")] + formats = tuple(t for t in tokens if len(t) > 0) + if len(formats) < 1: + raise ValueError("--val-plot-formats must include at least one format") + allowed = {"png", "svg", "pdf"} + bad = [fmt for fmt in formats if fmt not in allowed] + if len(bad) > 0: + raise ValueError( + f"Unsupported --val-plot-formats values={bad}; allowed={sorted(allowed)}" + ) + return formats + + +def _as_optional_int(value: Any) -> int | None: + """Parses optional integer-like value.""" + if value is None: + return None + try: + return int(value) + except (TypeError, ValueError): + return None + + +def _resolve_pr_zoom_limits( + fold_evaluations: Sequence[Mapping[str, Any]], + baseline_y: float | None, + recall_xmax: float = 0.20, +) -> Tuple[float, float]: + """Computes stable PR-plot zoom y-limits for highly imbalanced datasets.""" + precision_values: List[float] = [] + xmax = float(recall_xmax) + if not math.isfinite(xmax) or xmax <= 0.0: + xmax = 0.20 + for fold_eval in fold_evaluations: + pr_obj = fold_eval.get("pr_curve") + if not isinstance(pr_obj, Mapping): + continue + if not bool(pr_obj.get("available", False)): + continue + recalls = pr_obj.get("recall") + precisions = pr_obj.get("precision") + if not isinstance(recalls, list) or not isinstance(precisions, list): + continue + if len(recalls) != len(precisions): + continue + for recall, precision in zip(recalls, precisions): + try: + recall_f = float(recall) + precision_f = float(precision) + except (TypeError, ValueError): + continue + if not math.isfinite(recall_f) or not math.isfinite(precision_f): + continue + if 0.0 < recall_f <= xmax: + precision_values.append(precision_f) + + if len(precision_values) < 1: + y_top = 0.05 + else: + q95 = float(np.quantile(np.asarray(precision_values), 0.95)) + y_top = max(0.02, q95 * 1.15) + if baseline_y is not None and math.isfinite(float(baseline_y)): + y_top = max(y_top, float(baseline_y) * 6.0) + y_top = min(1.0, max(0.02, y_top)) + return (0.0, y_top) + + +def _resolve_pr_zoom_xmax( + fold_evaluations: Sequence[Mapping[str, Any]], + preferred_xmax: float = 0.20, +) -> float: + """Chooses a PR zoom x-maximum with enough sampled points to draw visible lines.""" + candidate_maxes = [float(preferred_xmax), 0.30, 0.50, 1.00] + valid_maxes = [x for x in candidate_maxes if math.isfinite(x) and x > 0.0] + if len(valid_maxes) < 1: + valid_maxes = [0.20, 0.30, 0.50, 1.00] + + recalls: List[float] = [] + for fold_eval in fold_evaluations: + pr_obj = fold_eval.get("pr_curve") + if not isinstance(pr_obj, Mapping): + continue + if not bool(pr_obj.get("available", False)): + continue + rec = pr_obj.get("recall") + if not isinstance(rec, list): + continue + for raw in rec: + try: + value = float(raw) + except (TypeError, ValueError): + continue + if math.isfinite(value) and value >= 0.0: + recalls.append(value) + + if len(recalls) < 1: + return float(valid_maxes[0]) + + for xmax in valid_maxes: + n_zoom = sum(1 for r in recalls if 0.0 <= r <= xmax) + if n_zoom >= 2: + return float(xmax) + return 1.0 + + +def _plot_fold_curves( + fold_evaluations: Sequence[Mapping[str, Any]], + plot_path_base: Path, + formats: Sequence[str], + curve_key: str, + title: str, + x_label: str, + y_label: str, + x_key: str, + y_key: str, + metric_key: str, + chance_line: bool, + baseline_y: float | None = None, + secondary_metric_key: str | None = None, + metric_label: str | None = None, + secondary_metric_label: str | None = None, + x_limits: Tuple[float, float] | None = None, + y_limits: Tuple[float, float] | None = None, + legend_loc: str = "lower right", +) -> List[str]: + """Writes a fold-only curve panel for ROC or PR.""" + try: + import matplotlib.pyplot as plt + except Exception as e: + raise RuntimeError( + "Matplotlib is required for plotting. Install matplotlib to use --save-val-curves." + ) from e + + fig, ax = plt.subplots(figsize=(7.0, 6.0), dpi=150) + plotted = 0 + for fold_eval in fold_evaluations: + curve_obj = fold_eval.get(curve_key) + if not isinstance(curve_obj, Mapping): + continue + if not bool(curve_obj.get("available", False)): + continue + x_vals = curve_obj.get(x_key) + y_vals = curve_obj.get(y_key) + if not isinstance(x_vals, list) or not isinstance(y_vals, list): + continue + if len(x_vals) < 2 or len(y_vals) < 2: + continue + fold_index = _as_optional_int(fold_eval.get("fold_index")) + fold_label = int(fold_index) if fold_index is not None else (plotted + 1) + metrics_obj = fold_eval.get("metrics") + metric_value = float("nan") + if isinstance(metrics_obj, Mapping): + raw_metric = metrics_obj.get(metric_key, float("nan")) + try: + metric_value = float(raw_metric) + except (TypeError, ValueError): + metric_value = float("nan") + if not math.isfinite(metric_value): + if curve_key == "pr_curve" and metric_key == "pr_auc": + raw_ap = curve_obj.get("ap", float("nan")) + try: + metric_value = float(raw_ap) + except (TypeError, ValueError): + metric_value = float("nan") + primary_label = metric_label if metric_label is not None else metric_key.upper() + primary_metric_text = ( + f"{primary_label}: {metric_value:.3f}" + if math.isfinite(metric_value) + else f"{primary_label}: nan" + ) + secondary_metric_text = None + if secondary_metric_key is not None and isinstance(metrics_obj, Mapping): + raw_secondary = metrics_obj.get(secondary_metric_key, float("nan")) + try: + secondary_value = float(raw_secondary) + except (TypeError, ValueError): + secondary_value = float("nan") + if ( + not math.isfinite(secondary_value) + and curve_key == "pr_curve" + and secondary_metric_key == "pr_auc_trapz" + ): + raw_trapz = curve_obj.get("auprc_trapz", float("nan")) + try: + secondary_value = float(raw_trapz) + except (TypeError, ValueError): + secondary_value = float("nan") + secondary_label = ( + secondary_metric_label + if secondary_metric_label is not None + else secondary_metric_key.upper() + ) + if math.isfinite(secondary_value): + secondary_metric_text = f"{secondary_label}: {secondary_value:.3f}" + else: + secondary_metric_text = f"{secondary_label}: nan" + label_parts = [f"Fold {fold_label}", primary_metric_text] + if secondary_metric_text is not None: + label_parts.append(secondary_metric_text) + ax.plot( + x_vals, + y_vals, + linewidth=1.8, + alpha=0.95, + label=" | ".join(label_parts), + ) + plotted += 1 + + if chance_line: + ax.plot( + [0.0, 1.0], + [0.0, 1.0], + linestyle="--", + linewidth=1.0, + color="#C77DBB", + label="Chance", + ) + if baseline_y is not None and math.isfinite(float(baseline_y)): + y_val = float(baseline_y) + ax.plot( + [0.0, 1.0], + [y_val, y_val], + linestyle="--", + linewidth=1.0, + color="#888888", + label=f"Prevalence baseline: {y_val:.4f}", + ) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + if x_limits is None: + ax.set_xlim(0.0, 1.0) + else: + ax.set_xlim(float(x_limits[0]), float(x_limits[1])) + if y_limits is None: + ax.set_ylim(0.0, 1.0) + else: + ax.set_ylim(float(y_limits[0]), float(y_limits[1])) + ax.grid(alpha=0.2) + if plotted > 0: + ax.legend(loc=legend_loc, frameon=False, fontsize=9) + fig.tight_layout() + + saved_paths: List[str] = [] + plot_path_base.parent.mkdir(parents=True, exist_ok=True) + for fmt in formats: + out_path = plot_path_base.with_suffix(f".{fmt}") + fig.savefig(out_path, dpi=300, bbox_inches="tight") + saved_paths.append(str(out_path)) + plt.close(fig) + return saved_paths + + +def _build_set_best_fold_curve_payload( + run_rows: Sequence[Mapping[str, Any]], + set_index: int, + val_curve_subdir: str, +) -> Dict[str, Any]: + """Builds fold-level best-epoch curve payload for one ensemble set.""" + fold_map: Dict[int, Dict[str, Any]] = {} + warnings: List[str] = [] + + rows_in_set = sorted( + ( + row + for row in run_rows + if int(_as_optional_int(row.get("EnsembleSetIndex")) or 1) == int(set_index) + ), + key=lambda row: int(_as_optional_int(row.get("RunIndex")) or int(1e9)), + ) + + for row in rows_in_set: + fold_idx = _as_optional_int(row.get("FoldIndex")) + if fold_idx is None: + continue + if fold_idx in fold_map: + warnings.append( + f"set_index={set_index} fold_index={fold_idx} has multiple runs; keeping first by RunIndex." + ) + continue + + run_save_dir_raw = row.get("RunSaveDir") + best_epoch = _as_optional_int(row.get("BestEpoch")) + status = str(row.get("Status", "")).strip().upper() + run_index = _as_optional_int(row.get("RunIndex")) + + if status != "OK": + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: status={status or 'UNKNOWN'}" + ) + continue + if run_save_dir_raw is None: + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: missing RunSaveDir." + ) + continue + if best_epoch is None or best_epoch < 0: + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: invalid BestEpoch={row.get('BestEpoch')}." + ) + continue + + run_save_dir = Path(str(run_save_dir_raw)) + curve_json = run_save_dir / str(val_curve_subdir) / f"epoch_{best_epoch:04d}_curves.json" + if not curve_json.exists(): + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: missing curve JSON {curve_json}." + ) + continue + + try: + curve_payload = json.loads(curve_json.read_text(encoding="utf-8")) + except Exception as e: + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: could not read {curve_json} ({e})." + ) + continue + if not isinstance(curve_payload, Mapping): + warnings.append( + f"Skipping fold {fold_idx} run_index={run_index}: curve JSON is not an object ({curve_json})." + ) + continue + + metrics_obj = curve_payload.get("eval_metrics") + roc_obj = curve_payload.get("roc_curve") + pr_obj = curve_payload.get("pr_curve") + if not isinstance(metrics_obj, Mapping): + metrics_obj = {} + if not isinstance(roc_obj, Mapping): + roc_obj = {"available": False, "reason": "missing-roc-curve", "fpr": [], "tpr": []} + if not isinstance(pr_obj, Mapping): + pr_obj = { + "available": False, + "reason": "missing-pr-curve", + "recall": [], + "precision": [], + "baseline_positive_rate": None, + "ap": None, + "auprc_trapz": None, + } + + fold_map[int(fold_idx)] = { + "fold_index": int(fold_idx), + "run_index": int(run_index) if run_index is not None else None, + "best_epoch": int(best_epoch), + "run_save_dir": str(run_save_dir), + "curve_json": str(curve_json), + "metrics": dict(metrics_obj), + "roc_curve": dict(roc_obj), + "pr_curve": dict(pr_obj), + } + + fold_entries = [fold_map[k] for k in sorted(fold_map.keys())] + baseline_values: List[float] = [] + for fold_eval in fold_entries: + pr_obj = fold_eval.get("pr_curve") + if not isinstance(pr_obj, Mapping): + continue + baseline_raw = pr_obj.get("baseline_positive_rate") + baseline_value = _finite_or_none(baseline_raw) + if baseline_value is not None: + baseline_values.append(float(baseline_value)) + pr_baseline = ( + float(sum(baseline_values) / len(baseline_values)) + if len(baseline_values) > 0 + else None + ) + return { + "set_index": int(set_index), + "n_folds_with_curves": int(len(fold_entries)), + "folds": fold_entries, + "pr_baseline_mean": pr_baseline, + "warnings": warnings, + } + + +def _write_ensemble_validation_curve_artifacts( + run_rows: Sequence[Mapping[str, Any]], + set_index: int, + output_dir: Path, + plot_formats: Sequence[str], + val_curve_subdir: str = "validation_curves", +) -> Dict[str, Any]: + """Writes fold-level best-epoch ROC/PR plots and sidecar JSON for one set.""" + payload = _build_set_best_fold_curve_payload( + run_rows=run_rows, + set_index=set_index, + val_curve_subdir=val_curve_subdir, + ) + fold_entries = payload["folds"] + baseline_y = payload.get("pr_baseline_mean") + baseline_float = ( + float(baseline_y) + if baseline_y is not None and math.isfinite(float(baseline_y)) + else None + ) + + output_dir.mkdir(parents=True, exist_ok=True) + + plot_status = "ok" + plot_outputs = { + "roc_auc_folds": [], + "pr_auc_folds": [], + "pr_auc_folds_zoom": [], + } + if len(fold_entries) > 0: + try: + pr_zoom_xmax = _resolve_pr_zoom_xmax( + fold_evaluations=fold_entries, + preferred_xmax=0.20, + ) + plot_outputs["roc_auc_folds"] = _plot_fold_curves( + fold_evaluations=fold_entries, + plot_path_base=output_dir / "roc_auc_folds", + formats=plot_formats, + curve_key="roc_curve", + title="Validation ROC Curves by Fold (Best Epoch)", + x_label="FPR", + y_label="TPR", + x_key="fpr", + y_key="tpr", + metric_key="auc", + chance_line=True, + metric_label="AUC", + ) + plot_outputs["pr_auc_folds"] = _plot_fold_curves( + fold_evaluations=fold_entries, + plot_path_base=output_dir / "pr_auc_folds", + formats=plot_formats, + curve_key="pr_curve", + title="Validation PR Curves by Fold (Best Epoch)", + x_label="Recall", + y_label="Precision", + x_key="recall", + y_key="precision", + metric_key="pr_auc", + chance_line=False, + secondary_metric_key="pr_auc_trapz", + metric_label="AP", + secondary_metric_label="AUPRC(trapz)", + baseline_y=baseline_float, + ) + plot_outputs["pr_auc_folds_zoom"] = _plot_fold_curves( + fold_evaluations=fold_entries, + plot_path_base=output_dir / "pr_auc_folds_zoom", + formats=plot_formats, + curve_key="pr_curve", + title="Validation PR Curves by Fold (Best Epoch, Zoomed)", + x_label="Recall", + y_label="Precision", + x_key="recall", + y_key="precision", + metric_key="pr_auc", + chance_line=False, + secondary_metric_key="pr_auc_trapz", + metric_label="AP", + secondary_metric_label="AUPRC(trapz)", + baseline_y=baseline_float, + x_limits=(0.0, float(pr_zoom_xmax)), + y_limits=_resolve_pr_zoom_limits( + fold_evaluations=fold_entries, + baseline_y=baseline_float, + recall_xmax=float(pr_zoom_xmax), + ), + legend_loc="upper right", + ) + except RuntimeError: + plot_status = "matplotlib_unavailable" + else: + plot_status = "no_fold_curves" + + out_payload = { + "set_index": int(set_index), + "plot_status": str(plot_status), + "plot_formats": [str(fmt) for fmt in plot_formats], + "plot_dir": str(output_dir), + "n_folds_with_curves": int(payload["n_folds_with_curves"]), + "pr_baseline_mean": payload.get("pr_baseline_mean"), + "warnings": list(payload["warnings"]), + "plot_outputs": plot_outputs, + } + out_payload["folds"] = list(fold_entries) + out_path = output_dir / "ensemble_fold_validation_curves.json" + out_path.write_text( + json.dumps(_sanitize_for_json(out_payload), indent=2, allow_nan=False), + encoding="utf-8", + ) + return { + "set_index": int(set_index), + "plot_status": str(plot_status), + "plot_formats": [str(fmt) for fmt in plot_formats], + "plot_dir": str(output_dir), + "artifact_json": str(out_path), + "n_folds_with_curves": int(payload["n_folds_with_curves"]), + "pr_baseline_mean": payload.get("pr_baseline_mean"), + "warnings": list(payload["warnings"]), + "plot_outputs": plot_outputs, + } + + @dataclass class RunPlan: """Execution plan for one train/validation run.""" @@ -472,6 +980,27 @@ def main() -> None: type=Path, default=None, help="Optional CSV output path for per-run results") + parser.add_argument("--save-val-curves", + action="store_true", + dest="save_val_curves", + default=False, + help=( + "If set, save per-epoch validation ROC/PR curve data and plots. " + "In ensemble-kfold mode, also writes set-level fold consistency ROC/PR plots " + "using each fold's best epoch." + )) + parser.add_argument("--val-curve-max-points", + action="store", + dest="val_curve_max_points", + type=int, + default=2048, + help="Maximum number of points saved per validation ROC/PR curve.") + parser.add_argument("--val-plot-formats", + action="store", + dest="val_plot_formats", + type=str, + default="png", + help="Comma-separated plot formats for validation curves (png,svg,pdf).") parser.add_argument("--ensemble-manifest", type=Path, default=None, @@ -499,6 +1028,16 @@ def main() -> None: if ddp is not None and rank != 0: logger.disabled = True + val_curve_artifacts = None + if args.save_val_curves: + if args.val_curve_max_points < 2: + raise ValueError("--val-curve-max-points must be >= 2") + val_curve_artifacts = ValidationCurveArtifactConfig( + max_points=int(args.val_curve_max_points), + plot_formats=_parse_plot_formats(args.val_plot_formats), + output_subdir="validation_curves" + ) + results_csv = args.results_csv or ( args.save_path / "multi_run_results.csv") run_rows: List[Dict[str, Any]] = [] @@ -747,7 +1286,10 @@ def main() -> None: run_save_dir = None t0 = time.time() fit_summary = trainer.fit( - save_dir=run_save_dir, score_key=args.best_model_metric) + save_dir=run_save_dir, + score_key=args.best_model_metric, + val_curve_artifacts=val_curve_artifacts + ) elapsed_s = time.time() - t0 if rank == 0: @@ -904,6 +1446,43 @@ def main() -> None: "best_metric_value": row.get("BestMetricValue") }) + set_curve_artifacts: Dict[int, Dict[str, Any]] = {} + if args.save_val_curves: + for set_index in sorted(sets_map.keys()): + entry = sets_map[set_index] + set_dir_raw = entry.get("set_dir") + if set_dir_raw: + curve_out_dir = ( + Path(str(set_dir_raw)) + / "validation_curves" + / "ensemble_folds" + ) + else: + curve_out_dir = ( + args.save_path + / "validation_curves" + / "ensemble_folds" + ) + curve_artifacts = _write_ensemble_validation_curve_artifacts( + run_rows=run_rows, + set_index=int(set_index), + output_dir=curve_out_dir, + plot_formats=tuple(val_curve_artifacts.plot_formats) + if val_curve_artifacts is not None + else ("png",), + ) + set_curve_artifacts[int(set_index)] = curve_artifacts + for warning in curve_artifacts.get("warnings", []): + logger.warning( + "ensemble_validation_curve_warning", + extra={ + "extra": { + "set_index": int(set_index), + "message": str(warning), + } + }, + ) + set_payloads: List[Dict[str, Any]] = [] for set_index in sorted(sets_map.keys()): entry = sets_map[set_index] @@ -936,6 +1515,9 @@ def main() -> None: }, "members": members } + curve_artifacts = set_curve_artifacts.get(int(set_index)) + if curve_artifacts is not None: + set_payload["validation_curve_artifacts"] = curve_artifacts if n_sets == 1: set_manifest_path = args.ensemble_manifest or ( diff --git a/src/pepseqpred/core/train/curveartifacts.py b/src/pepseqpred/core/train/curveartifacts.py new file mode 100644 index 0000000..b8436cb --- /dev/null +++ b/src/pepseqpred/core/train/curveartifacts.py @@ -0,0 +1,318 @@ +"""curve_artifacts.py + +Validation ROC/PR curve artifact helpers for training. + +This module builds deterministic, downsampled ROC/PR payloads from validation +arrays, writes JSON sidecars, and optionally renders plots when matplotlib is +available. +""" + +import json +import math +from pathlib import Path +from typing import Any, Dict, List, Mapping, Sequence, Tuple +import numpy as np +from sklearn.metrics import ( + average_precision_score, + precision_recall_curve, + roc_curve +) + + +def _sanitize_for_json(value: Any) -> Any: + """Recursively converts non-finite floats to None for strict JSON output.""" + if isinstance(value, dict): + return {k: _sanitize_for_json(v) for k, v in value.items()} + if isinstance(value, list): + return [_sanitize_for_json(v) for v in value] + if isinstance(value, float): + return value if math.isfinite(value) else None + return value + + +def _downsample_curve( + x: np.ndarray, + y: np.ndarray, + max_points: int +) -> Tuple[List[float], List[float]]: + """Downsamples curve points deterministically to keep payloads bounded.""" + if x.size != y.size: + raise ValueError(f"Curve x/y size mismatch: x={x.size} y={y.size}") + if max_points < 2: + raise ValueError("--val-curve-max-points must be >= 2") + if x.size <= max_points: + return [float(v) for v in x], [float(v) for v in y] + + idx = np.linspace(0, x.size - 1, num=max_points, dtype=np.int64) + idx = np.unique(idx) + return [float(v) for v in x[idx]], [float(v) for v in y[idx]] + + +def build_roc_curve_payload( + y_true: np.ndarray, + y_prob: np.ndarray, + max_points: int +) -> Dict[str, Any]: + """Builds ROC curve payload from validation labels/probabilities.""" + if max_points < 2: + raise ValueError("--val-curve-max-points must be >= 2") + if y_true.size < 1 or np.unique(y_true).size < 2: + return { + "available": False, + "reason": "single-class-labels", + "fpr": [], + "tpr": [] + } + + fpr, tpr, _ = roc_curve(y_true, y_prob) + fpr_points, tpr_points = _downsample_curve( + x=np.asarray(fpr, dtype=np.float64), + y=np.asarray(tpr, dtype=np.float64), + max_points=max_points + ) + return { + "available": True, + "reason": None, + "fpr": fpr_points, + "tpr": tpr_points + } + + +def _compute_pr_auc_trapezoid(y_true: np.ndarray, y_prob: np.ndarray) -> float: + """Computes PR AUC by trapezoidal integration.""" + if y_true.size < 1: + return float("nan") + + unique_labels = np.unique(y_true) + if unique_labels.size < 2: + return 1.0 if int(unique_labels[0]) == 1 else 0.0 + + precision, recall, _ = precision_recall_curve(y_true, y_prob) + recall = np.asarray(recall, dtype=np.float64)[::-1] + precision = np.asarray(precision, dtype=np.float64)[::-1] + return float(np.trapezoid(precision, recall)) + + +def build_pr_curve_payload( + y_true: np.ndarray, + y_prob: np.ndarray, + max_points: int +) -> Dict[str, Any]: + """Builds PR curve payload from validation labels/probabilities.""" + if max_points < 2: + raise ValueError("--val-curve-max-points must be >= 2") + if y_true.size < 1: + return { + "available": False, + "reason": "no-valid-residues", + "precision": [], + "recall": [], + "baseline_positive_rate": None, + "ap": None, + "auprc_trapz": None + } + + precision, recall, _ = precision_recall_curve(y_true, y_prob) + recall = np.asarray(recall, dtype=np.float64)[::-1] + precision = np.asarray(precision, dtype=np.float64)[::-1] + + try: + ap = float(average_precision_score(y_true, y_prob)) + except Exception: + ap = float("nan") + try: + auprc_trapz = float(_compute_pr_auc_trapezoid(y_true, y_prob)) + except Exception: + auprc_trapz = float("nan") + + recall_points, precision_points = _downsample_curve( + x=recall, + y=precision, + max_points=max_points + ) + pos_rate = float((y_true == 1).mean()) if y_true.size > 0 else float("nan") + return { + "available": True, + "reason": None, + "precision": precision_points, + "recall": recall_points, + "baseline_positive_rate": ( + float(pos_rate) if math.isfinite(pos_rate) else None + ), + "ap": float(ap) if math.isfinite(ap) else None, + "auprc_trapz": float(auprc_trapz) if math.isfinite(auprc_trapz) else None + } + + +def _plot_curve( + plt: Any, + curve_payload: Mapping[str, Any], + path_base: Path, + formats: Sequence[str], + title: str, + x_label: str, + y_label: str, + x_key: str, + y_key: str, + chance_line: bool = False, + baseline_y: float | None = None +) -> List[str]: + """Plots a single curve panel and writes one file per requested format.""" + fig, ax = plt.subplots(figsize=(7.0, 6.0), dpi=150) + + available = bool(curve_payload.get("available", False)) + if available: + x_vals = curve_payload.get(x_key, []) + y_vals = curve_payload.get(y_key, []) + if not isinstance(x_vals, list) or not isinstance(y_vals, list): + x_vals = [] + y_vals = [] + if len(x_vals) >= 2 and len(y_vals) >= 2: + ax.plot(x_vals, y_vals, linewidth=2.0, color="#1f77b4") + else: + available = False + + if not available: + reason = str(curve_payload.get("reason", "unavailable")) + ax.text( + 0.5, + 0.5, + f"Curve unavailable:\n{reason}", + ha="center", + va="center", + fontsize=11 + ) + + if chance_line: + ax.plot([0.0, 1.0], [0.0, 1.0], linestyle="--", + linewidth=1.0, color="#888888") + if baseline_y is not None and math.isfinite(float(baseline_y)): + y_val = float(baseline_y) + ax.plot([0.0, 1.0], [y_val, y_val], linestyle="--", + linewidth=1.0, color="#B05A2B") + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_xlim(0.0, 1.0) + ax.set_ylim(0.0, 1.0) + ax.grid(alpha=0.2) + fig.tight_layout() + + saved_paths: List[str] = [] + path_base.parent.mkdir(parents=True, exist_ok=True) + for fmt in formats: + out_path = path_base.with_suffix(f".{fmt}") + fig.savefig(out_path, dpi=300, bbox_inches="tight") + saved_paths.append(str(out_path)) + plt.close(fig) + return saved_paths + + +def write_validation_curve_artifacts( + epoch: int, + y_true: np.ndarray, + y_prob: np.ndarray, + metrics: Mapping[str, Any] | None, + output_dir: Path | str, + max_points: int = 2048, + plot_formats: Sequence[str] = ("png",) +) -> Dict[str, Any]: + """Writes validation curve JSON payload and optional ROC/PR plots.""" + if max_points < 2: + raise ValueError("--val-curve-max-points must be >= 2") + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + y_true_np = np.asarray(y_true).reshape(-1).astype(np.int64, copy=False) + y_prob_np = np.asarray(y_prob).reshape(-1).astype(np.float64, copy=False) + if y_true_np.size != y_prob_np.size: + raise ValueError( + f"Expected y_true/y_prob length match, got y_true={y_true_np.size} y_prob={y_prob_np.size}" + ) + + fmts = [str(fmt).strip().lower() for fmt in plot_formats if str(fmt).strip()] + if len(fmts) < 1: + raise ValueError("--val-plot-formats must include at least one format") + allowed = {"png", "svg", "pdf"} + bad = [fmt for fmt in fmts if fmt not in allowed] + if len(bad) > 0: + raise ValueError( + f"Unsupported --val-plot-formats values={bad}; allowed={sorted(allowed)}" + ) + + roc_payload = build_roc_curve_payload( + y_true=y_true_np, + y_prob=y_prob_np, + max_points=max_points + ) + pr_payload = build_pr_curve_payload( + y_true=y_true_np, + y_prob=y_prob_np, + max_points=max_points + ) + + payload: Dict[str, Any] = { + "epoch": int(epoch), + "n_residues": int(y_true_np.size), + "max_points": int(max_points), + "plot_formats": [str(fmt) for fmt in fmts], + "eval_metrics": dict(metrics) if isinstance(metrics, Mapping) else {}, + "roc_curve": roc_payload, + "pr_curve": pr_payload + } + + epoch_tag = f"epoch_{int(epoch):04d}" + json_path = output_dir / f"{epoch_tag}_curves.json" + json_path.write_text( + json.dumps(_sanitize_for_json(payload), indent=2, allow_nan=False), + encoding="utf-8" + ) + + artifact_paths: Dict[str, Any] = { + "curve_json": str(json_path), + "roc_auc_plots": [], + "pr_auc_plots": [], + "plot_status": "ok" + } + + try: + import matplotlib.pyplot as plt + except Exception: + artifact_paths["plot_status"] = "matplotlib_unavailable" + return artifact_paths + + baseline_y = pr_payload.get("baseline_positive_rate") + baseline_float = ( + float(baseline_y) + if baseline_y is not None and math.isfinite(float(baseline_y)) + else None + ) + + artifact_paths["roc_auc_plots"] = _plot_curve( + plt=plt, + curve_payload=roc_payload, + path_base=output_dir / f"{epoch_tag}_roc_auc", + formats=fmts, + title=f"Validation ROC Curve (Epoch {int(epoch)})", + x_label="False Positive Rate", + y_label="True Positive Rate", + x_key="fpr", + y_key="tpr", + chance_line=True + ) + artifact_paths["pr_auc_plots"] = _plot_curve( + plt=plt, + curve_payload=pr_payload, + path_base=output_dir / f"{epoch_tag}_pr_auc", + formats=fmts, + title=f"Validation PR Curve (Epoch {int(epoch)})", + x_label="Recall", + y_label="Precision", + x_key="recall", + y_key="precision", + chance_line=False, + baseline_y=baseline_float + ) + return artifact_paths diff --git a/src/pepseqpred/core/train/trainer.py b/src/pepseqpred/core/train/trainer.py index 5936c8e..fabf369 100644 --- a/src/pepseqpred/core/train/trainer.py +++ b/src/pepseqpred/core/train/trainer.py @@ -19,6 +19,7 @@ from .ddp import ddp_rank, ddp_all_reduce_sum, ddp_gather_all_1d from .metrics import compute_eval_metrics from .threshold import find_threshold_max_recall_min_precision +from .curveartifacts import write_validation_curve_artifacts @dataclass @@ -34,6 +35,14 @@ class TrainerConfig: pos_weight: Optional[float] = None +@dataclass(frozen=True) +class ValidationCurveArtifactConfig: + """Configuration for optional validation ROC/PR artifact generation.""" + max_points: int = 2048 + plot_formats: Tuple[str, ...] = ("png",) + output_subdir: str = "validation_curves" + + class Trainer: """ Trainer class used to facilitate model training. Can take in most types of neural networks as input for training. @@ -156,7 +165,12 @@ def _batch_step(self, batch: torch.Tensor, train: bool = True) -> Dict[str, Any] "y": y_flat.detach().cpu(), "probs": probs_flat.detach().cpu()} - def _run_epoch(self, epoch: int, train: bool = True) -> Dict[str, Any]: + def _run_epoch( + self, + epoch: int, + train: bool = True, + capture_eval_arrays: bool = False + ) -> Dict[str, Any]: """Runs one complete epoch (training step) from start to finish.""" loader = self.train_loader if train else self.val_loader if loader is None: @@ -243,6 +257,8 @@ def _run_epoch(self, epoch: int, train: bool = True) -> Dict[str, Any]: # compute eval metrics cm = None eval_metrics = None + y_true_epoch: Optional[np.ndarray] = None + y_prob_epoch: Optional[np.ndarray] = None if not train: # all ranks must participate in gathers to avoid DDP divergence. if len(all_y) > 0: @@ -272,6 +288,9 @@ def _run_epoch(self, epoch: int, train: bool = True) -> Dict[str, Any]: ys) > 0 else np.array([]) y_prob = torch.cat(ps, dim=0).numpy() if len( ps) > 0 else np.array([]) + if capture_eval_arrays: + y_true_epoch = y_true + y_prob_epoch = y_prob # guard against invalid or non-existent residues globally if y_true.size == 0: @@ -344,9 +363,19 @@ def _run_epoch(self, epoch: int, train: bool = True) -> Dict[str, Any]: if not train: out["acc"] = avg_acc out["eval_metrics"] = eval_metrics + if capture_eval_arrays and ddp_rank() == 0: + out["eval_arrays"] = { + "y_true": y_true_epoch if y_true_epoch is not None else np.array([]), + "y_prob": y_prob_epoch if y_prob_epoch is not None else np.array([]) + } return out - def fit(self, save_dir: Optional[Path | str] = None, score_key: str = "loss") -> Dict[str, Any]: + def fit( + self, + save_dir: Optional[Path | str] = None, + score_key: str = "loss", + val_curve_artifacts: Optional[ValidationCurveArtifactConfig] = None + ) -> Dict[str, Any]: """ Fits a neural network model to the data provided. @@ -356,6 +385,8 @@ def fit(self, save_dir: Optional[Path | str] = None, score_key: str = "loss") -> An optional path to a directory to save the best performing model to. score_key : str Score key used to determine the "best" model trained/evaluated so far. Default is `"loss"`. + val_curve_artifacts : ValidationCurveArtifactConfig or None + Optional config to save per-epoch validation ROC/PR curve data and plots. Returns ------- @@ -393,7 +424,13 @@ def fit(self, save_dir: Optional[Path | str] = None, score_key: str = "loss") -> eval_out = None if self.val_loader is not None: - eval_out = self._run_epoch(epoch, train=False) + eval_out = self._run_epoch( + epoch, + train=False, + capture_eval_arrays=( + val_curve_artifacts is not None and save_dir is not None + ) + ) if ddp_rank() == 0 and eval_out["eval_metrics"] is not None: self.logger.info("eval_epoch_summary", extra={"extra": { @@ -412,6 +449,34 @@ def fit(self, save_dir: Optional[Path | str] = None, score_key: str = "loss") -> "pr_auc": float(eval_out["eval_metrics"]["pr_auc"]) }}) + if (save_dir is not None and ddp_rank() == 0 + and val_curve_artifacts is not None + and eval_out.get("eval_metrics") is not None): + eval_arrays = eval_out.get("eval_arrays") + if isinstance(eval_arrays, dict): + y_true = eval_arrays.get("y_true") + y_prob = eval_arrays.get("y_prob") + if isinstance(y_true, np.ndarray) and isinstance(y_prob, np.ndarray): + artifact_paths = write_validation_curve_artifacts( + epoch=epoch, + y_true=y_true, + y_prob=y_prob, + metrics=eval_out["eval_metrics"], + output_dir=Path(save_dir) / + val_curve_artifacts.output_subdir, + max_points=int(val_curve_artifacts.max_points), + plot_formats=tuple( + val_curve_artifacts.plot_formats) + ) + self.logger.info("val_curve_artifacts_written", + extra={"extra": { + "epoch": int(epoch), + "curve_json": artifact_paths.get("curve_json"), + "plot_status": artifact_paths.get("plot_status"), + "roc_auc_plots": artifact_paths.get("roc_auc_plots", []), + "pr_auc_plots": artifact_paths.get("pr_auc_plots", []) + }}) + # save from validated model only if save_dir is not None and ddp_rank() == 0 and eval_out["eval_metrics"] is not None: metric_loss = float(eval_out["loss"]) diff --git a/tests/integration/test_train_clis_inprocess.py b/tests/integration/test_train_clis_inprocess.py index 26d0044..04af1dd 100644 --- a/tests/integration/test_train_clis_inprocess.py +++ b/tests/integration/test_train_clis_inprocess.py @@ -52,6 +52,62 @@ def test_train_ffnn_cli_main_inprocess(training_artifacts, tmp_path: Path, monke assert (save_dir / "multi_run_summary.json").exists() +def test_train_ffnn_cli_main_inprocess_with_val_curves( + training_artifacts, tmp_path: Path, monkeypatch +): + save_dir = tmp_path / "train_out_curves" + + monkeypatch.setattr( + sys, + "argv", + [ + "train_ffnn_cli.py", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--val-frac", + "0.5", + "--split-seeds", + "11", + "--train-seeds", + "101", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + "--save-val-curves", + "--val-curve-max-points", + "128", + "--val-plot-formats", + "png" + ] + ) + + train_cli.main() + + run_dirs = list(save_dir.glob("run_*")) + assert run_dirs + curves_dir = run_dirs[0] / "validation_curves" + assert (curves_dir / "epoch_0000_curves.json").exists() + + roc_plot = curves_dir / "epoch_0000_roc_auc.png" + pr_plot = curves_dir / "epoch_0000_pr_auc.png" + if roc_plot.exists() or pr_plot.exists(): + assert roc_plot.exists() + assert pr_plot.exists() + + def test_train_ffnn_cli_ensemble_kfold_inprocess(training_artifacts, tmp_path: Path, monkeypatch): save_dir = tmp_path / "ensemble_out" @@ -118,6 +174,94 @@ def test_train_ffnn_cli_ensemble_kfold_inprocess(training_artifacts, tmp_path: P assert all(Path(x["manifest_path"]).exists() for x in payload["sets"]) +def test_train_ffnn_cli_ensemble_kfold_with_aggregate_val_curves( + training_artifacts, tmp_path: Path, monkeypatch +): + save_dir = tmp_path / "ensemble_out_curves" + + monkeypatch.setattr( + sys, + "argv", + [ + "train_ffnn_cli.py", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--split-type", + "id-family", + "--train-mode", + "ensemble-kfold", + "--n-folds", + "2", + "--split-seeds", + "17", + "--train-seeds", + "101", + "--save-val-curves", + "--val-curve-max-points", + "128", + "--val-plot-formats", + "png", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + ], + ) + + train_cli.main() + + fold_dirs = sorted(save_dir.glob("fold_*")) + assert len(fold_dirs) == 2 + assert all((fold_dir / "fully_connected.pt").exists() for fold_dir in fold_dirs) + assert all( + (fold_dir / "validation_curves" / "epoch_0000_curves.json").exists() + for fold_dir in fold_dirs + ) + + aggregate_json = ( + save_dir + / "validation_curves" + / "ensemble_folds" + / "ensemble_fold_validation_curves.json" + ) + assert aggregate_json.exists() + aggregate_payload = json.loads(aggregate_json.read_text(encoding="utf-8")) + assert int(aggregate_payload["set_index"]) == 1 + assert int(aggregate_payload["n_folds_with_curves"]) == 2 + assert len(aggregate_payload["folds"]) == 2 + + manifest_payload = json.loads( + (save_dir / "ensemble_manifest.json").read_text(encoding="utf-8") + ) + curve_meta = manifest_payload.get("validation_curve_artifacts") + assert isinstance(curve_meta, dict) + assert Path(curve_meta["artifact_json"]).exists() + + if str(curve_meta.get("plot_status")) == "ok": + roc_paths = [Path(p) for p in curve_meta["plot_outputs"]["roc_auc_folds"]] + pr_paths = [Path(p) for p in curve_meta["plot_outputs"]["pr_auc_folds"]] + pr_zoom_paths = [Path(p) + for p in curve_meta["plot_outputs"]["pr_auc_folds_zoom"]] + assert len(roc_paths) == 1 + assert len(pr_paths) == 1 + assert len(pr_zoom_paths) == 1 + assert roc_paths[0].exists() + assert pr_paths[0].exists() + assert pr_zoom_paths[0].exists() + + @pytest.mark.slow def test_train_ffnn_optuna_cli_main_inprocess( training_artifacts, tmp_path: Path, monkeypatch diff --git a/tests/unit/apps/test_train_cli_coverage.py b/tests/unit/apps/test_train_cli_coverage.py new file mode 100644 index 0000000..fff23f4 --- /dev/null +++ b/tests/unit/apps/test_train_cli_coverage.py @@ -0,0 +1,318 @@ +import json +import shutil +import sys +from pathlib import Path +from uuid import uuid4 + +import optuna +import pandas as pd +import pytest +import torch + +import pepseqpred.apps.train_ffnn_cli as train_cli +import pepseqpred.apps.train_ffnn_optuna_cli as optuna_cli + +pytestmark = pytest.mark.unit + + +def _mk_case_dir(tag: str) -> Path: + case_dir = Path("localdata") / f"unit_realcov_{tag}_{uuid4().hex[:8]}" + case_dir.mkdir(parents=True, exist_ok=True) + return case_dir + + +def _write_training_artifacts(case_dir: Path, *, all_uncertain: bool) -> tuple[Path, Path]: + emb_dir = case_dir / "embeddings" + emb_dir.mkdir(parents=True, exist_ok=True) + label_shard = case_dir / "labels_000.pt" + + labels = {} + pos_count = 0 + neg_count = 0 + ids = [("P001", "111"), ("P002", "111"), ("P003", "222"), ("P004", "222")] + for i, (protein_id, family) in enumerate(ids): + torch.manual_seed(10 + i) + emb = torch.randn(6, 4, dtype=torch.float32) + torch.save(emb, emb_dir / f"{protein_id}-{family}.pt") + + if all_uncertain: + y = torch.tensor( + [[0.0, 1.0, 0.0]] * 6, + dtype=torch.float32 + ) + else: + y = torch.tensor([1, 0, 0, 1, 0, 0], dtype=torch.float32) + pos_count += int((y == 1).sum().item()) + neg_count += int((y == 0).sum().item()) + labels[protein_id] = y + + if all_uncertain: + # Keep positive class weight sane for the test run. + pos_count = 1 + neg_count = 1 + + torch.save( + {"labels": labels, "class_stats": {"pos_count": pos_count, "neg_count": neg_count}}, + label_shard + ) + return emb_dir, label_shard + + +def _run_main(entrypoint, argv: list[str]) -> None: + old_argv = list(sys.argv) + sys.argv = argv + try: + entrypoint() + finally: + sys.argv = old_argv + + +def test_train_cli_helper_parsers_and_numeric_summary(): + summary_empty = train_cli.summarize_numeric( + pd.Series([float("nan"), float("inf"), -float("inf")]) + ) + assert summary_empty["count"] == 0 + assert summary_empty["mean"] is None + + summary_ok = train_cli.summarize_numeric(pd.Series([1.0, 2.0, 3.0])) + assert summary_ok["count"] == 3 + assert summary_ok["mean"] == pytest.approx(2.0) + + assert train_cli._finite_or_none("nan") is None + assert train_cli._finite_or_none("abc") is None + assert train_cli._finite_or_none("1.5") == pytest.approx(1.5) + + assert train_cli._parse_plot_formats("png, svg") == ("png", "svg") + with pytest.raises(ValueError, match="val-plot-formats"): + train_cli._parse_plot_formats(",,") + with pytest.raises(ValueError, match="Unsupported"): + train_cli._parse_plot_formats("png,jpg") + + +def test_train_ffnn_cli_real_no_valid_score_with_val_curve_artifacts(): + case_dir = _mk_case_dir("ffnn_no_valid") + emb_dir, label_shard = _write_training_artifacts( + case_dir, all_uncertain=True + ) + save_dir = case_dir / "train_out" + + try: + _run_main( + train_cli.main, + [ + "train_ffnn_cli.py", + "--embedding-dirs", + str(emb_dir), + "--label-shards", + str(label_shard), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--val-frac", + "0.5", + "--split-seeds", + "11", + "--train-seeds", + "101", + "--best-model-metric", + "f1", + "--save-val-curves", + "--val-curve-max-points", + "64", + "--val-plot-formats", + "png", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + ], + ) + + run_dirs = sorted(save_dir.glob("run_*")) + assert len(run_dirs) == 1 + assert (run_dirs[0] / "fully_connected.pt").exists() + assert (run_dirs[0] / "validation_curves" / + "epoch_0000_curves.json").exists() + + runs_df = pd.read_csv(save_dir / "runs.csv") + assert int(runs_df.shape[0]) == 1 + assert str(runs_df.iloc[0]["BestMetricKey"]) == "f1" + assert str(runs_df.iloc[0]["Status"]) == "NO_VALID_SCORE" + + summary = json.loads( + (save_dir / "multi_run_summary.json").read_text(encoding="utf-8") + ) + assert int(summary["n_runs"]) == 1 + finally: + shutil.rmtree(case_dir, ignore_errors=True) + + +def test_train_ffnn_cli_real_ensemble_manifest_generation(): + case_dir = _mk_case_dir("ffnn_ensemble") + emb_dir, label_shard = _write_training_artifacts( + case_dir, all_uncertain=False + ) + save_dir = case_dir / "ensemble_out" + + try: + _run_main( + train_cli.main, + [ + "train_ffnn_cli.py", + "--embedding-dirs", + str(emb_dir), + "--label-shards", + str(label_shard), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--train-mode", + "ensemble-kfold", + "--split-type", + "id-family", + "--n-folds", + "2", + "--split-seeds", + "17,19", + "--train-seeds", + "101,202", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + "--ensemble-manifest", + str(save_dir / "ensemble_manifest.json"), + ], + ) + + payload = json.loads( + (save_dir / "ensemble_manifest.json").read_text(encoding="utf-8") + ) + assert payload["train_mode"] == "ensemble-kfold" + assert payload["n_sets"] == 2 + assert len(payload["sets"]) == 2 + assert all(int(x["n_members"]) == 2 for x in payload["sets"]) + finally: + shutil.rmtree(case_dir, ignore_errors=True) + + +def test_train_ffnn_optuna_cli_real_with_storage_and_helpers(): + assert optuna_cli._broadcast_params({"a": 1}, None) == {"a": 1} + + study = optuna.create_study(sampler=optuna.samplers.RandomSampler(seed=7)) + trial_flat = study.ask() + sizes_flat, drop_flat, depth_flat, _ = optuna_cli.build_hidden_sizes( + trial=trial_flat, + depth_min=2, + depth_max=2, + width_min=64, + width_max=64, + mode="flat" + ) + assert depth_flat == 2 + assert sizes_flat == (64, 64) + assert len(drop_flat) == 2 + + trial_bottle = study.ask() + sizes_bottle, _, _, _ = optuna_cli.build_hidden_sizes( + trial=trial_bottle, + depth_min=3, + depth_max=3, + width_min=32, + width_max=96, + mode="bottleneck" + ) + assert all( + sizes_bottle[i] >= sizes_bottle[i + 1] + for i in range(len(sizes_bottle) - 1) + ) + + trial_pyramid = study.ask() + sizes_pyramid, _, _, _ = optuna_cli.build_hidden_sizes( + trial=trial_pyramid, + depth_min=3, + depth_max=3, + width_min=32, + width_max=96, + mode="pyramid" + ) + assert all( + sizes_pyramid[i] <= sizes_pyramid[i + 1] + for i in range(len(sizes_pyramid) - 1) + ) + + case_dir = _mk_case_dir("optuna_storage") + emb_dir, label_shard = _write_training_artifacts( + case_dir, all_uncertain=False + ) + save_dir = case_dir / "optuna_out" + csv_path = save_dir / "trials.csv" + storage_uri = f"sqlite:///{(case_dir / 'study.db').as_posix()}" + + try: + _run_main( + optuna_cli.main, + [ + "train_ffnn_optuna_cli.py", + "--embedding-dirs", + str(emb_dir), + "--label-shards", + str(label_shard), + "--storage", + storage_uri, + "--n-trials", + "1", + "--epochs", + "1", + "--val-frac", + "0.5", + "--subset", + "4", + "--batch-sizes", + "2", + "--num-workers", + "0", + "--metric", + "auc", + "--arch-mode", + "flat", + "--depth-min", + "1", + "--depth-max", + "1", + "--width-min", + "64", + "--width-max", + "64", + "--save-path", + str(save_dir), + "--csv-path", + str(csv_path), + "--study-name", + "unit_realcov_study", + ], + ) + + best_payload = json.loads( + (save_dir / "best_trial.json").read_text(encoding="utf-8") + ) + assert best_payload["study_name"] == "unit_realcov_study" + assert best_payload["metric"] == "auc" + assert csv_path.exists() + assert (case_dir / "study.db").exists() + finally: + shutil.rmtree(case_dir, ignore_errors=True) diff --git a/tests/unit/core/train/test_curve_artifacts.py b/tests/unit/core/train/test_curve_artifacts.py new file mode 100644 index 0000000..7b88d68 --- /dev/null +++ b/tests/unit/core/train/test_curve_artifacts.py @@ -0,0 +1,92 @@ +import json +import numpy as np +import pytest +from pepseqpred.core.train.curveartifacts import ( + _downsample_curve, + build_roc_curve_payload, + build_pr_curve_payload, + write_validation_curve_artifacts +) + +pytestmark = pytest.mark.unit + + +def test_downsample_curve_is_deterministic(): + x = np.linspace(0.0, 1.0, num=11, dtype=np.float64) + y = x ** 2 + x_out, y_out = _downsample_curve(x=x, y=y, max_points=5) + + assert len(x_out) == 5 + assert len(y_out) == 5 + assert x_out[0] == pytest.approx(0.0) + assert x_out[-1] == pytest.approx(1.0) + assert y_out[0] == pytest.approx(0.0) + assert y_out[-1] == pytest.approx(1.0) + + +def test_build_roc_curve_payload_handles_single_class(): + y_true = np.zeros(8, dtype=np.int64) + y_prob = np.linspace(0.0, 1.0, num=8, dtype=np.float64) + out = build_roc_curve_payload(y_true=y_true, y_prob=y_prob, max_points=32) + + assert out["available"] is False + assert out["reason"] == "single-class-labels" + assert out["fpr"] == [] + assert out["tpr"] == [] + + +def test_build_pr_curve_payload_handles_no_valid_residues(): + y_true = np.asarray([], dtype=np.int64) + y_prob = np.asarray([], dtype=np.float64) + out = build_pr_curve_payload(y_true=y_true, y_prob=y_prob, max_points=32) + + assert out["available"] is False + assert out["reason"] == "no-valid-residues" + assert out["precision"] == [] + assert out["recall"] == [] + assert out["ap"] is None + assert out["auprc_trapz"] is None + + +def test_write_validation_curve_artifacts_rejects_invalid_max_points(tmp_path): + y_true = np.asarray([0, 1], dtype=np.int64) + y_prob = np.asarray([0.2, 0.8], dtype=np.float64) + with pytest.raises(ValueError, match="val-curve-max-points"): + write_validation_curve_artifacts( + epoch=0, + y_true=y_true, + y_prob=y_prob, + metrics={}, + output_dir=tmp_path, + max_points=1, + plot_formats=("png",) + ) + + +def test_write_validation_curve_artifacts_writes_json_and_plots_when_available(tmp_path): + y_true = np.asarray([0, 1, 1, 0, 1, 0], dtype=np.int64) + y_prob = np.asarray([0.05, 0.9, 0.65, 0.3, 0.8, 0.25], dtype=np.float64) + out = write_validation_curve_artifacts( + epoch=0, + y_true=y_true, + y_prob=y_prob, + metrics={"auc": float("nan"), "pr_auc": 0.77}, + output_dir=tmp_path, + max_points=16, + plot_formats=("png",) + ) + + json_path = tmp_path / "epoch_0000_curves.json" + assert json_path.exists() + payload = json.loads(json_path.read_text(encoding="utf-8")) + assert payload["epoch"] == 0 + assert payload["eval_metrics"]["auc"] is None + assert payload["eval_metrics"]["pr_auc"] == pytest.approx(0.77) + assert isinstance(payload["roc_curve"]["fpr"], list) + assert isinstance(payload["pr_curve"]["precision"], list) + + if out["plot_status"] == "ok": + assert (tmp_path / "epoch_0000_roc_auc.png").exists() + assert (tmp_path / "epoch_0000_pr_auc.png").exists() + else: + assert out["plot_status"] == "matplotlib_unavailable" diff --git a/tests/unit/core/train/test_trainer_fit.py b/tests/unit/core/train/test_trainer_fit.py index c01ab62..3b290c7 100644 --- a/tests/unit/core/train/test_trainer_fit.py +++ b/tests/unit/core/train/test_trainer_fit.py @@ -4,7 +4,11 @@ import pytest import torch from pepseqpred.core.models.ffnn import PepSeqFFNN -from pepseqpred.core.train.trainer import Trainer, TrainerConfig +from pepseqpred.core.train.trainer import ( + Trainer, + TrainerConfig, + ValidationCurveArtifactConfig +) pytestmark = pytest.mark.unit @@ -66,6 +70,30 @@ def test_fit_without_validation_saves_no_val_checkpoint(tmp_path: Path): assert (tmp_path / "fully_connected_no_val.pt").exists() +def test_fit_with_validation_curve_artifacts(tmp_path: Path): + trainer = _make_trainer( + _make_batches(), _make_batches(), emb_dim=4, epochs=1) + summary = trainer.fit( + save_dir=tmp_path, + score_key="loss", + val_curve_artifacts=ValidationCurveArtifactConfig( + max_points=32, + plot_formats=("png",), + output_subdir="validation_curves" + ) + ) + + assert summary["best_epoch"] >= 0 + curves_dir = tmp_path / "validation_curves" + assert (curves_dir / "epoch_0000_curves.json").exists() + + roc_plot = curves_dir / "epoch_0000_roc_auc.png" + pr_plot = curves_dir / "epoch_0000_pr_auc.png" + if roc_plot.exists() or pr_plot.exists(): + assert roc_plot.exists() + assert pr_plot.exists() + + def test_run_epoch_eval_no_valid_residues(): train_loader = _make_batches(n_batches=1, mask_value=1) val_loader = _make_batches(n_batches=1, mask_value=0)