|
"""Evaluate an interpretability method on a Transformer + MIMIC-IV mortality
prediction model using comprehensiveness and sufficiency metrics.

This example demonstrates:
1. Loading a pre-trained Transformer model with cached processors and the MIMIC-IV dataset
2. Computing attributions with Integrated Gradients
3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency,
   using a custom sample filter driven by probability thresholds
4. Presenting the results in a summary line
"""
| 10 | + |
| 11 | +import datetime |
| 12 | +import argparse |
| 13 | + |
| 14 | +import torch |
| 15 | +from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient |
| 16 | +from pyhealth.interpret.methods import * |
| 17 | +from pyhealth.metrics.interpretability import evaluate_attribution |
| 18 | +from pyhealth.metrics.interpretability.utils import SampleClass |
| 19 | +from pyhealth.models import Transformer |
| 20 | +from pyhealth.tasks import MortalityPredictionStageNetMIMIC4 |
| 21 | +from pyhealth.trainer import Trainer |
| 22 | +from pyhealth.datasets.utils import load_processors |
| 23 | +from pathlib import Path |
| 24 | +import pandas as pd |
| 25 | + |
| 26 | +# python -u examples/interpretability/custom_sample_filter.py --pos_threshold 0.5 --neg_threshold 0.1 --device cuda:2 |
| 27 | +def main(): |
| 28 | + parser = argparse.ArgumentParser( |
| 29 | + description="Comma separated list of interpretability methods to evaluate" |
| 30 | + ) |
| 31 | + parser.add_argument( |
| 32 | + "--pos_threshold", |
| 33 | + type=float, |
| 34 | + default=None, |
| 35 | + help="Positive threshold for interpretability evaluation (default: 0.5).", |
| 36 | + ) |
| 37 | + parser.add_argument( |
| 38 | + "--neg_threshold", |
| 39 | + type=float, |
| 40 | + default=None, |
| 41 | + help="Negative threshold for interpretability evaluation (default: 0.5).", |
| 42 | + ) |
| 43 | + parser.add_argument( |
| 44 | + "--device", |
| 45 | + type=str, |
| 46 | + default="cuda:0", |
| 47 | + help="Device to use for evaluation (default: cuda:0)", |
| 48 | + ) |
| 49 | + args = parser.parse_args() |
| 50 | + """Main execution function.""" |
| 51 | + print("=" * 70) |
| 52 | + print("Interpretability Metrics Example: Transformer + MIMIC-IV") |
| 53 | + print("=" * 70) |
| 54 | + |
| 55 | + now = datetime.datetime.now() |
| 56 | + print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}") |
| 57 | + |
| 58 | + # Set path |
| 59 | + CACHE_DIR = Path("/home/yongdaf2/interpret/cache/mp_mimic4") |
| 60 | + CKPTS_DIR = Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4") |
| 61 | + OUTPUT_DIR = Path("/home/yongdaf2/interpret/output/mp_transformer_mimic4") |
| 62 | + CACHE_DIR.mkdir(parents=True, exist_ok=True) |
| 63 | + CKPTS_DIR.mkdir(parents=True, exist_ok=True) |
| 64 | + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) |
| 65 | + print(f"\nUsing cache dir: {CACHE_DIR}") |
| 66 | + print(f"Using checkpoints dir: {CKPTS_DIR}") |
| 67 | + print(f"Using output dir: {OUTPUT_DIR}") |
| 68 | + |
| 69 | + # Set device |
| 70 | + device = args.device |
| 71 | + print(f"\nUsing device: {device}") |
| 72 | + |
| 73 | + # Load MIMIC-IV dataset |
| 74 | + print("\n Loading MIMIC-IV dataset...") |
| 75 | + base_dataset = MIMIC4Dataset( |
| 76 | + ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/", |
| 77 | + ehr_tables=[ |
| 78 | + "patients", |
| 79 | + "admissions", |
| 80 | + "diagnoses_icd", |
| 81 | + "procedures_icd", |
| 82 | + "labevents", |
| 83 | + ], |
| 84 | + cache_dir=str(CACHE_DIR), |
| 85 | + num_workers=16, |
| 86 | + ) |
| 87 | + |
| 88 | + # Apply mortality prediction task |
| 89 | + if not (CKPTS_DIR / "input_processors.pkl").exists(): |
| 90 | + raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ") |
| 91 | + if not (CKPTS_DIR / "output_processors.pkl").exists(): |
| 92 | + raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. ") |
| 93 | + input_processors, output_processors = load_processors(str(CKPTS_DIR)) |
| 94 | + print("✓ Loaded input and output processors from checkpoint directory.") |
| 95 | + |
| 96 | + sample_dataset = base_dataset.set_task( |
| 97 | + MortalityPredictionStageNetMIMIC4(), |
| 98 | + num_workers=16, |
| 99 | + input_processors=input_processors, |
| 100 | + output_processors=output_processors, |
| 101 | + ) |
| 102 | + print(f"✓ Loaded {len(sample_dataset)} samples") |
| 103 | + |
| 104 | + # Split dataset and get test loader |
| 105 | + _, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233) |
| 106 | + test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False) |
| 107 | + print(f"✓ Test set: {len(test_dataset)} samples") |
| 108 | + |
| 109 | + # Initialize and load pre-trained model |
| 110 | + print("\n Loading pre-trained Transformer model...") |
| 111 | + model = Transformer( |
| 112 | + dataset=sample_dataset, |
| 113 | + embedding_dim=128, |
| 114 | + heads=4, |
| 115 | + dropout=0.3, |
| 116 | + num_layers=3, |
| 117 | + ) |
| 118 | + |
| 119 | + trainer = Trainer(model=model, device=device) |
| 120 | + trainer.load_ckpt(str(CKPTS_DIR / "best.ckpt")) |
| 121 | + model = model.to(device) |
| 122 | + model.eval() |
| 123 | + print(f"✓ Loaded checkpoint: {CKPTS_DIR / 'best.ckpt'}") |
| 124 | + print(f"✓ Model moved to {device}") |
| 125 | + |
| 126 | + pos_threshold = args.pos_threshold |
| 127 | + neg_threshold = args.neg_threshold |
| 128 | + def sample_filter_fn( |
| 129 | + y_probs: torch.Tensor, |
| 130 | + classifier_type: str, |
| 131 | + ) -> torch.Tensor: |
| 132 | + """ |
| 133 | + Custom sample filter function that classifies samples based on |
| 134 | + positive and negative probability thresholds. |
| 135 | +
|
| 136 | + negative samples: 0 < y_probs < neg_threshold |
| 137 | + ignored samples: neg_threshold <= y_probs < pos_threshold |
| 138 | + positive samples: y_probs >= pos_threshold |
| 139 | + """ |
| 140 | + nonlocal pos_threshold, neg_threshold |
| 141 | + batch_size = y_probs.shape[0] |
| 142 | + result = torch.full( |
| 143 | + (batch_size,), |
| 144 | + SampleClass.POSITIVE, |
| 145 | + dtype=torch.long, |
| 146 | + device=y_probs.device, |
| 147 | + ) |
| 148 | + if classifier_type in ("binary", "multilabel"): |
| 149 | + if pos_threshold is not None: |
| 150 | + result[y_probs < pos_threshold] = SampleClass.IGNORE |
| 151 | + if neg_threshold is not None: |
| 152 | + result[y_probs < neg_threshold] = SampleClass.NEGATIVE |
| 153 | + return result |
| 154 | + |
| 155 | + interpreter = IntegratedGradients(model, use_embeddings=True) |
| 156 | + print(f"\nEvaluating using Integrated Gradients...") |
| 157 | + |
| 158 | + # Option 1: Functional API (simple one-off evaluation) |
| 159 | + print("\nEvaluating with Functional API on full dataset...") |
| 160 | + print("Using: evaluate_attribution(model, dataloader, method, ...)") |
| 161 | + |
| 162 | + results_functional = evaluate_attribution( |
| 163 | + model, |
| 164 | + test_loader, |
| 165 | + interpreter, |
| 166 | + metrics=["comprehensiveness", "sufficiency"], |
| 167 | + percentages=[25, 50, 99], |
| 168 | + sample_filter=sample_filter_fn, |
| 169 | + ) |
| 170 | + |
| 171 | + print("\n" + "=" * 70) |
| 172 | + print("Dataset-Wide Results (Functional API)") |
| 173 | + print("=" * 70) |
| 174 | + comp = results_functional["comprehensiveness"] |
| 175 | + suff = results_functional["sufficiency"] |
| 176 | + print(f"\nComprehensiveness: {comp:.4f}") |
| 177 | + print(f"Sufficiency: {suff:.4f}") |
| 178 | + |
| 179 | + print("") |
| 180 | + print("=" * 70) |
| 181 | + print("Summary of Results for All Methods") |
| 182 | + print({"Method": "Integrated Gradients", "Comprehensiveness": comp, "Sufficiency": suff}) |
| 183 | + |
| 184 | + end = datetime.datetime.now() |
| 185 | + print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}") |
| 186 | + print(f"Total Duration: {end - now}") |
| 187 | + |
| 188 | +if __name__ == "__main__": |
| 189 | + main() |
0 commit comments