Skip to content

Commit 39b86e4

Browse files
authored
Fix Interpretability Methods target_class_idx (#926)
* Rename arg name for chefer
* Initial attempts to fix the interpretability target_class_idx
* Support negative prediction for interpretability metric
* Fix tests
* Fix more tests
* Revert "Support negative prediction for interpretability metric." (reverts commit fe8c8ad)
* Reapply "Support all samples for interpretability metric"
* Initial attempt for the filter
* Fixup
* Fix sample_class handling
* Fixup
* Fix test
* Fix arg name
* Add example
* Fix docs
1 parent d7641e0 commit 39b86e4

24 files changed

Lines changed: 691 additions & 589 deletions

examples/cxr/covid19cxr_tutorial.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@
13391339
" # Input size is inferred automatically from image dimensions\n",
13401340
" result = chefer_gen.attribute(\n",
13411341
" interpolate=True,\n",
1342-
" class_index=pred_class,\n",
1342+
" target_class_idx=pred_class,\n",
13431343
" **batch\n",
13441344
" )\n",
13451345
" attr_map = result[\"image\"] # Keyed by task schema's feature key\n",

examples/cxr/covid19cxr_tutorial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@
131131
# Compute attribution for each class in the prediction set
132132
overlays = []
133133
for class_idx in predset_class_indices:
134-
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
134+
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
135135
_, _, overlay = visualize_image_attr(
136136
image=batch["image"][0],
137137
attribution=attr_map[0, 0],

examples/cxr/covid19cxr_tutorial_display.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
# Compute attribution for each class in the prediction set
129129
overlays = []
130130
for class_idx in predset_class_indices:
131-
attr_map = chefer.attribute(class_index=class_idx, **batch)["image"]
131+
attr_map = chefer.attribute(target_class_idx=class_idx, **batch)["image"]
132132
_, _, overlay = visualize_image_attr(
133133
image=batch["image"][0],
134134
attribution=attr_map[0, 0],
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""Evaluate all interpretability methods on StageNet + MIMIC-IV dataset using comprehensiveness
2+
and sufficiency metrics.
3+
4+
This example demonstrates:
5+
1. Loading a pre-trained StageNet model with processors and MIMIC-IV dataset
6+
2. Computing attributions with various interpretability methods
7+
3. Evaluating attribution faithfulness with Comprehensiveness & Sufficiency for each method
8+
4. Presenting results in a summary table
9+
"""
10+
11+
import datetime
12+
import argparse
13+
14+
import torch
15+
from pyhealth.datasets import MIMIC4Dataset, get_dataloader, split_by_patient
16+
from pyhealth.interpret.methods import *
17+
from pyhealth.metrics.interpretability import evaluate_attribution
18+
from pyhealth.metrics.interpretability.utils import SampleClass
19+
from pyhealth.models import Transformer
20+
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
21+
from pyhealth.trainer import Trainer
22+
from pyhealth.datasets.utils import load_processors
23+
from pathlib import Path
24+
import pandas as pd
25+
26+
# python -u examples/interpretability/custom_sample_filter.py --pos_threshold 0.5 --neg_threshold 0.1 --device cuda:2
27+
def main():
28+
parser = argparse.ArgumentParser(
29+
description="Comma separated list of interpretability methods to evaluate"
30+
)
31+
parser.add_argument(
32+
"--pos_threshold",
33+
type=float,
34+
default=None,
35+
help="Positive threshold for interpretability evaluation (default: 0.5).",
36+
)
37+
parser.add_argument(
38+
"--neg_threshold",
39+
type=float,
40+
default=None,
41+
help="Negative threshold for interpretability evaluation (default: 0.5).",
42+
)
43+
parser.add_argument(
44+
"--device",
45+
type=str,
46+
default="cuda:0",
47+
help="Device to use for evaluation (default: cuda:0)",
48+
)
49+
args = parser.parse_args()
50+
"""Main execution function."""
51+
print("=" * 70)
52+
print("Interpretability Metrics Example: Transformer + MIMIC-IV")
53+
print("=" * 70)
54+
55+
now = datetime.datetime.now()
56+
print(f"Start Time: {now.strftime('%Y-%m-%d %H:%M:%S')}")
57+
58+
# Set path
59+
CACHE_DIR = Path("/home/yongdaf2/interpret/cache/mp_mimic4")
60+
CKPTS_DIR = Path("/shared/eng/pyhealth_dka/ckpts/mp_transformer_mimic4")
61+
OUTPUT_DIR = Path("/home/yongdaf2/interpret/output/mp_transformer_mimic4")
62+
CACHE_DIR.mkdir(parents=True, exist_ok=True)
63+
CKPTS_DIR.mkdir(parents=True, exist_ok=True)
64+
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
65+
print(f"\nUsing cache dir: {CACHE_DIR}")
66+
print(f"Using checkpoints dir: {CKPTS_DIR}")
67+
print(f"Using output dir: {OUTPUT_DIR}")
68+
69+
# Set device
70+
device = args.device
71+
print(f"\nUsing device: {device}")
72+
73+
# Load MIMIC-IV dataset
74+
print("\n Loading MIMIC-IV dataset...")
75+
base_dataset = MIMIC4Dataset(
76+
ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
77+
ehr_tables=[
78+
"patients",
79+
"admissions",
80+
"diagnoses_icd",
81+
"procedures_icd",
82+
"labevents",
83+
],
84+
cache_dir=str(CACHE_DIR),
85+
num_workers=16,
86+
)
87+
88+
# Apply mortality prediction task
89+
if not (CKPTS_DIR / "input_processors.pkl").exists():
90+
raise FileNotFoundError(f"Input processors not found in {CKPTS_DIR}. ")
91+
if not (CKPTS_DIR / "output_processors.pkl").exists():
92+
raise FileNotFoundError(f"Output processors not found in {CKPTS_DIR}. ")
93+
input_processors, output_processors = load_processors(str(CKPTS_DIR))
94+
print("✓ Loaded input and output processors from checkpoint directory.")
95+
96+
sample_dataset = base_dataset.set_task(
97+
MortalityPredictionStageNetMIMIC4(),
98+
num_workers=16,
99+
input_processors=input_processors,
100+
output_processors=output_processors,
101+
)
102+
print(f"✓ Loaded {len(sample_dataset)} samples")
103+
104+
# Split dataset and get test loader
105+
_, _, test_dataset = split_by_patient(sample_dataset, [0.9, 0.09, 0.01], seed=233)
106+
test_loader = get_dataloader(test_dataset, batch_size=16, shuffle=False)
107+
print(f"✓ Test set: {len(test_dataset)} samples")
108+
109+
# Initialize and load pre-trained model
110+
print("\n Loading pre-trained Transformer model...")
111+
model = Transformer(
112+
dataset=sample_dataset,
113+
embedding_dim=128,
114+
heads=4,
115+
dropout=0.3,
116+
num_layers=3,
117+
)
118+
119+
trainer = Trainer(model=model, device=device)
120+
trainer.load_ckpt(str(CKPTS_DIR / "best.ckpt"))
121+
model = model.to(device)
122+
model.eval()
123+
print(f"✓ Loaded checkpoint: {CKPTS_DIR / 'best.ckpt'}")
124+
print(f"✓ Model moved to {device}")
125+
126+
pos_threshold = args.pos_threshold
127+
neg_threshold = args.neg_threshold
128+
def sample_filter_fn(
129+
y_probs: torch.Tensor,
130+
classifier_type: str,
131+
) -> torch.Tensor:
132+
"""
133+
Custom sample filter function that classifies samples based on
134+
positive and negative probability thresholds.
135+
136+
negative samples: 0 < y_probs < neg_threshold
137+
ignored samples: neg_threshold <= y_probs < pos_threshold
138+
positive samples: y_probs >= pos_threshold
139+
"""
140+
nonlocal pos_threshold, neg_threshold
141+
batch_size = y_probs.shape[0]
142+
result = torch.full(
143+
(batch_size,),
144+
SampleClass.POSITIVE,
145+
dtype=torch.long,
146+
device=y_probs.device,
147+
)
148+
if classifier_type in ("binary", "multilabel"):
149+
if pos_threshold is not None:
150+
result[y_probs < pos_threshold] = SampleClass.IGNORE
151+
if neg_threshold is not None:
152+
result[y_probs < neg_threshold] = SampleClass.NEGATIVE
153+
return result
154+
155+
interpreter = IntegratedGradients(model, use_embeddings=True)
156+
print(f"\nEvaluating using Integrated Gradients...")
157+
158+
# Option 1: Functional API (simple one-off evaluation)
159+
print("\nEvaluating with Functional API on full dataset...")
160+
print("Using: evaluate_attribution(model, dataloader, method, ...)")
161+
162+
results_functional = evaluate_attribution(
163+
model,
164+
test_loader,
165+
interpreter,
166+
metrics=["comprehensiveness", "sufficiency"],
167+
percentages=[25, 50, 99],
168+
sample_filter=sample_filter_fn,
169+
)
170+
171+
print("\n" + "=" * 70)
172+
print("Dataset-Wide Results (Functional API)")
173+
print("=" * 70)
174+
comp = results_functional["comprehensiveness"]
175+
suff = results_functional["sufficiency"]
176+
print(f"\nComprehensiveness: {comp:.4f}")
177+
print(f"Sufficiency: {suff:.4f}")
178+
179+
print("")
180+
print("=" * 70)
181+
print("Summary of Results for All Methods")
182+
print({"Method": "Integrated Gradients", "Comprehensiveness": comp, "Sufficiency": suff})
183+
184+
end = datetime.datetime.now()
185+
print(f"End Time: {end.strftime('%Y-%m-%d %H:%M:%S')}")
186+
print(f"Total Duration: {end - now}")
187+
188+
if __name__ == "__main__":
189+
main()

pyhealth/interpret/methods/base_interpreter.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from abc import ABC, abstractmethod
13-
from typing import Dict, cast
13+
from typing import Dict, Optional, cast
1414

1515
import torch
1616
import torch.nn as nn
@@ -138,8 +138,12 @@ def attribute(
138138
by the task's ``input_schema``.
139139
- Label key (optional): Ground truth labels, may be needed
140140
by some methods for loss computation.
141-
- ``class_index`` (optional): Target class for attribution.
142-
If not provided, uses the predicted class.
141+
- ``target_class_idx`` (Optional[int]): Target class for
142+
attribution. For binary classification (single logit
143+
output), this is a no-op because there is only one
144+
output. For multi-class or multi-label classification,
145+
specifies which class index to explain. If not provided,
146+
uses the argmax of logits.
143147
- Additional method-specific parameters (e.g., ``baseline``,
144148
``steps``, ``interpolate``).
145149
@@ -207,6 +211,42 @@ def attribute(
207211
"""
208212
pass
209213

214+
def _resolve_target_indices(
215+
self,
216+
logits: torch.Tensor,
217+
target_class_idx: Optional[int],
218+
) -> torch.Tensor:
219+
"""Resolve target class indices for attribution.
220+
221+
Returns a ``[batch]`` tensor of class indices identifying which
222+
logit to explain. All prediction modes share this single code
223+
path:
224+
225+
* **Binary** (single logit): ``target_class_idx`` is a no-op
226+
because there is only one output. Always returns zeros
227+
(index 0).
228+
* **Multi-class / multi-label**: uses ``target_class_idx`` if
229+
given, otherwise the argmax of logits.
230+
231+
Args:
232+
logits: Model output logits, shape ``[batch, num_classes]``.
233+
target_class_idx: Optional user-specified class index.
234+
235+
Returns:
236+
``torch.LongTensor`` of shape ``[batch]``.
237+
"""
238+
if logits.shape[-1] == 1:
239+
# Single logit output — nothing to select.
240+
return torch.zeros(
241+
logits.shape[0], device=logits.device, dtype=torch.long,
242+
)
243+
if target_class_idx is not None:
244+
return torch.full(
245+
(logits.shape[0],), target_class_idx,
246+
device=logits.device, dtype=torch.long,
247+
)
248+
return logits.argmax(dim=-1)
249+
210250
def _prediction_mode(self) -> str:
211251
"""Resolve the prediction mode from the model.
212252

pyhealth/interpret/methods/chefer.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ class CheferRelevance(BaseInterpreter):
128128
>>> print(attributions["conditions"].shape) # [batch, num_tokens]
129129
>>>
130130
>>> # Optional: attribute to a specific class (e.g., class 1)
131-
>>> attributions = interpreter.attribute(class_index=1, **batch)
131+
>>> attributions = interpreter.attribute(target_class_idx=1, **batch)
132132
"""
133133

134134
def __init__(self, model: BaseModel):
@@ -139,14 +139,16 @@ def __init__(self, model: BaseModel):
139139

140140
def attribute(
141141
self,
142-
class_index: Optional[int] = None,
142+
target_class_idx: Optional[int] = None,
143143
**data,
144144
) -> Dict[str, torch.Tensor]:
145145
"""Compute relevance scores for each input token.
146146
147147
Args:
148-
class_index: Target class index to compute attribution for.
149-
If None (default), uses the model's predicted class.
148+
target_class_idx: Target class index to compute attribution for.
149+
If None (default), uses the argmax of model output.
150+
For binary classification (single logit output), this is
151+
a no-op because there is only one output.
150152
**data: Input data from dataloader batch containing feature
151153
keys and label key.
152154
@@ -163,15 +165,10 @@ def attribute(
163165
self.model.set_attention_hooks(False)
164166

165167
# --- 2. Backward from target class ---
166-
if class_index is None:
167-
class_index_t = torch.argmax(logits, dim=-1)
168-
elif isinstance(class_index, int):
169-
class_index_t = torch.tensor(class_index)
170-
else:
171-
class_index_t = class_index
168+
target_indices = self._resolve_target_indices(logits, target_class_idx)
172169

173170
one_hot = F.one_hot(
174-
class_index_t.detach().clone(), logits.size(1)
171+
target_indices.detach().clone(), logits.size(1)
175172
).float()
176173
one_hot = one_hot.requires_grad_(True)
177174
scalar = torch.sum(one_hot.to(logits.device) * logits)

0 commit comments

Comments
 (0)