Skip to content

Commit 27c3c21

Browse files
author
Fan Zhang
committed
Replicate isic2018 experiment with isic2018
1 parent ae16408 commit 27c3c21

9 files changed

Lines changed: 685 additions & 5 deletions
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""
2+
ISIC 2018 — Binary melanoma classification under artifact modes.
3+
4+
Replicates image-mode experiments from:
5+
"A Study of Artifacts on Melanoma Classification under Diffusion-Based Perturbations"
6+
7+
Supports all 12 preprocessing modes (whole, lesion, background, bbox, bbox70,
8+
bbox90, high_whole, low_whole, high_lesion, low_lesion, high_background,
9+
low_background) via ``--mode``.
10+
11+
5-fold stratified cross-validation splits are generated from sample labels.
12+
13+
Expected directory structure under ``--root``::
14+
15+
<root>/
16+
<annotations_csv> ← annotation file (--annotations_csv)
17+
<image_dir>/ ← images (--image_dir)
18+
<mask_dir>/ ← masks (--mask_dir)
19+
20+
Annotation file:
21+
https://github.com/alceubissoto/debiasing-skin/tree/main/artefacts-annotation
22+
23+
Images / masks (~9 GB total):
24+
https://challenge.isic-archive.com/data/#2018
25+
(or pass ``--download`` to fetch automatically)
26+
27+
Usage::
28+
29+
python isic2018_artifacts_classification.py --root /path/to/isic2018_data
30+
python isic2018_artifacts_classification.py --root /path/to/isic2018_data --mode lesion
31+
"""
32+
33+
import argparse
import logging
import os
import sys

import numpy as np
from sklearn.model_selection import StratifiedKFold

from pyhealth.datasets import ISIC2018ArtifactsDataset, get_dataloader
from pyhealth.models import TorchvisionModel
from pyhealth.processors.dermoscopic_image_processor import VALID_MODES
from pyhealth.tasks import ISIC2018ArtifactsBinaryClassification
from pyhealth.trainer import Trainer
45+
46+
# Command-line interface.
#
# Options are declared as (flag, keyword-arguments) pairs and registered in
# listing order, so the generated --help output keeps a stable ordering.
_CLI_OPTIONS = (
    (
        "--root",
        {
            "type": str,
            "required": True,
            "help": "Root directory containing the annotation CSV, images, and masks.",
        },
    ),
    (
        "--image_dir",
        {
            "type": str,
            "default": "ISIC2018_Task1-2_Training_Input",
            "help": "Sub-directory (relative to root, or absolute path) for ISIC images.",
        },
    ),
    (
        "--mask_dir",
        {
            "type": str,
            "default": "ISIC2018_Task1_Training_GroundTruth",
            "help": "Sub-directory (relative to root, or absolute path) for segmentation masks.",
        },
    ),
    (
        "--annotations_csv",
        {
            "type": str,
            "default": "isic_bias.csv",
            "help": "Annotation CSV filename (relative to root, or absolute path).",
        },
    ),
    (
        "--mode",
        {
            "type": str,
            "default": "whole",
            "choices": VALID_MODES,
            "help": "Image preprocessing mode.",
        },
    ),
    (
        "--model",
        {
            "type": str,
            "default": "resnet50",
            "help": "Torchvision model backbone (e.g. resnet50, vit_b_16).",
        },
    ),
    ("--epochs", {"type": int, "default": 10}),
    ("--batch_size", {"type": int, "default": 32}),
    ("--lr", {"type": float, "default": 1e-4}),
    ("--n_splits", {"type": int, "default": 5}),
    ("--seed", {"type": int, "default": 42}),
    ("--download", {"action": "store_true", "help": "Auto-download data."}),
)

parser = argparse.ArgumentParser(description="Train ISIC2018 artifact classifier")
for _flag, _option_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_option_kwargs)

# Parsed at module level: the training section below reads ``args`` directly.
args = parser.parse_args()
91+
92+
# Route PyHealth trainer logs to stdout so per-epoch metrics are visible
# alongside this script's own print() output. The stream is passed explicitly
# because logging.StreamHandler() writes to sys.stderr when given no argument.
_handler = logging.StreamHandler(sys.stdout)
# Emit the bare message only (no timestamp / level prefix).
_handler.setFormatter(logging.Formatter("%(message)s"))
_trainer_logger = logging.getLogger("pyhealth.trainer")
_trainer_logger.addHandler(_handler)
# INFO must be enabled on the logger itself, otherwise records are dropped
# before they ever reach the handler.
_trainer_logger.setLevel(logging.INFO)
97+
98+
99+
# =============================================================================
100+
# Example run & results
101+
# =============================================================================
102+
# Command:
103+
# python isic2018_artifacts_classification_resnet50.py --root /path/to/isic2018_data
104+
#
105+
# Parameters:
106+
# --mode whole
107+
# --model resnet50
108+
# --epochs 10
109+
# --batch_size 32
110+
# --lr 1e-4
111+
# --n_splits 5
112+
# --seed 42
113+
#
114+
# 5-fold stratified CV results (whole mode, ResNet-50, ImageNet pretrained):
115+
#
116+
# Split 1 AUROC: 0.800 Accuracy: 0.844
117+
# Split 2 AUROC: 0.803 Accuracy: 0.829
118+
# Split 3 AUROC: 0.758 Accuracy: 0.788
119+
# Split 4 AUROC: 0.790 Accuracy: 0.807
120+
# Split 5 AUROC: 0.829 Accuracy: 0.840
121+
# ─────────────────────────────────────
122+
# Mean AUROC: 0.796 Accuracy: 0.822
123+
#
124+
# Matches findings from:
125+
# "A Study of Artifacts on Melanoma Classification under
126+
# Diffusion-Based Perturbations"
127+
# =============================================================================
128+
129+
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # 1. Build dataset — all path resolution delegated to the loader
    # ------------------------------------------------------------------
    dataset = ISIC2018ArtifactsDataset(
        root=args.root,
        annotations_csv=args.annotations_csv,
        image_dir=args.image_dir,
        mask_dir=args.mask_dir,
        mode=args.mode,
        download=args.download,
    )
    dataset.stats()

    # ------------------------------------------------------------------
    # 2. Apply task → SampleDataset with binary labels
    # ------------------------------------------------------------------
    task = ISIC2018ArtifactsBinaryClassification()
    samples = dataset.set_task(task)

    # Everything that uses ``samples`` runs inside try/finally so the sample
    # dataset is closed even when a fold fails part-way through training.
    try:
        # --------------------------------------------------------------
        # 3. Generate stratified K-fold splits from sample labels
        # --------------------------------------------------------------
        labels = np.array([samples[i]["label"] for i in range(len(samples))])
        indices = np.arange(len(labels))

        skf = StratifiedKFold(
            n_splits=args.n_splits, shuffle=True, random_state=args.seed
        )

        # Per-mode output directory; each fold's Trainer is pointed at it
        # below so runs for different modes do not overwrite each other.
        output_dir = os.path.join(args.root, "checkpoints", args.mode)
        os.makedirs(output_dir, exist_ok=True)

        # Test-set scores of every fold, kept for the final summary.
        fold_scores = []

        for fold, (train_val_idx, test_idx) in enumerate(
            skf.split(indices, labels), start=1
        ):
            print(f"\n{'='*60}")
            print(f"  Mode: {args.mode} | Split {fold}/{args.n_splits}")
            print(f"{'='*60}")

            # Use 10% of train_val as validation. The shuffle is in place and
            # seeded per fold (seed + fold) so every run picks the same split.
            val_size = max(1, int(0.1 * len(train_val_idx)))
            rng = np.random.default_rng(args.seed + fold)
            rng.shuffle(train_val_idx)
            val_idx = train_val_idx[:val_size]
            train_idx = train_val_idx[val_size:]

            train_loader = get_dataloader(
                samples.subset(train_idx), batch_size=args.batch_size, shuffle=True
            )
            val_loader = get_dataloader(
                samples.subset(val_idx), batch_size=args.batch_size, shuffle=False
            )
            test_loader = get_dataloader(
                samples.subset(test_idx), batch_size=args.batch_size, shuffle=False
            )

            # ----------------------------------------------------------
            # 4. Fresh model per fold
            # ----------------------------------------------------------
            model = TorchvisionModel(
                dataset=samples,
                model_name=args.model,
                model_config={"weights": "DEFAULT"},
            )

            # ----------------------------------------------------------
            # 5. Train
            # ----------------------------------------------------------
            # NOTE(review): output_path / exp_name are expected to make the
            # Trainer write its logs and checkpoints under
            # <output_dir>/fold_<k> — confirm against the installed PyHealth
            # version's Trainer signature.
            trainer = Trainer(
                model=model,
                metrics=["accuracy", "roc_auc"],
                output_path=output_dir,
                exp_name=f"fold_{fold}",
            )
            trainer.train(
                train_dataloader=train_loader,
                val_dataloader=val_loader,
                epochs=args.epochs,
                optimizer_params={"lr": args.lr},
            )

            # ----------------------------------------------------------
            # 6. Evaluate
            # ----------------------------------------------------------
            scores = trainer.evaluate(test_loader)
            print(f"Mode: {args.mode} Split {fold} test results: {scores}")
            fold_scores.append(scores)

        # --------------------------------------------------------------
        # 7. Cross-validation summary
        # --------------------------------------------------------------
        # Assumes trainer.evaluate() returns a mapping of metric name to a
        # number, as printed per fold above — TODO confirm.
        if fold_scores:
            print(f"\n{'='*60}")
            print(f"  Mode: {args.mode} | Summary over {len(fold_scores)} splits")
            print(f"{'='*60}")
            for metric in fold_scores[0]:
                values = [float(fold_result[metric]) for fold_result in fold_scores]
                print(
                    f"  Mean {metric}: {np.mean(values):.3f} ± {np.std(values):.3f}"
                )
    finally:
        samples.close()
212+

0 commit comments

Comments
 (0)