|
| 1 | +""" |
| 2 | +ISIC 2018 — Binary melanoma classification under artifact modes. |
| 3 | +
|
| 4 | +Replicates image-mode experiments from: |
| 5 | + "A Study of Artifacts on Melanoma Classification under Diffusion-Based Perturbations" |
| 6 | +
|
| 7 | +Supports all 12 preprocessing modes (whole, lesion, background, bbox, bbox70, |
| 8 | +bbox90, high_whole, low_whole, high_lesion, low_lesion, high_background, |
| 9 | +low_background) via ``--mode``. |
| 10 | +
|
| 11 | +5-fold stratified cross-validation splits are generated from sample labels. |
| 12 | +
|
| 13 | +Expected directory structure under ``--root``:: |
| 14 | +
|
| 15 | + <root>/ |
| 16 | + <annotations_csv> ← annotation file (--annotations_csv) |
| 17 | + <image_dir>/ ← images (--image_dir) |
| 18 | + <mask_dir>/ ← masks (--mask_dir) |
| 19 | +
|
| 20 | +Annotation file: |
| 21 | + https://github.com/alceubissoto/debiasing-skin/tree/main/artefacts-annotation |
| 22 | +
|
| 23 | +Images / masks (~9 GB total): |
| 24 | + https://challenge.isic-archive.com/data/#2018 |
| 25 | + (or pass ``--download`` to fetch automatically) |
| 26 | +
|
| 27 | +Usage:: |
| 28 | +
|
| 29 | + python isic2018_artifacts_classification.py --root /path/to/isic2018_data |
| 30 | + python isic2018_artifacts_classification.py --root /path/to/isic2018_data --mode lesion |
| 31 | +""" |
| 32 | + |
| 33 | +import argparse |
| 34 | +import logging |
| 35 | +import os |
| 36 | + |
| 37 | +import numpy as np |
| 38 | +from sklearn.model_selection import StratifiedKFold |
| 39 | + |
| 40 | +from pyhealth.datasets import ISIC2018ArtifactsDataset, get_dataloader |
| 41 | +from pyhealth.models import TorchvisionModel |
| 42 | +from pyhealth.processors.dermoscopic_image_processor import VALID_MODES |
| 43 | +from pyhealth.tasks import ISIC2018ArtifactsBinaryClassification |
| 44 | +from pyhealth.trainer import Trainer |
| 45 | + |
def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the command-line interface for the ISIC2018 artifact script.

    Returns:
        An ``argparse.ArgumentParser`` exposing data-location, preprocessing,
        model, and training hyper-parameter options.
    """
    p = argparse.ArgumentParser(description="Train ISIC2018 artifact classifier")
    p.add_argument(
        "--root",
        type=str,
        required=True,
        help="Root directory containing the annotation CSV, images, and masks.",
    )
    p.add_argument(
        "--image_dir",
        type=str,
        default="ISIC2018_Task1-2_Training_Input",
        help="Sub-directory (relative to root, or absolute path) for ISIC images.",
    )
    p.add_argument(
        "--mask_dir",
        type=str,
        default="ISIC2018_Task1_Training_GroundTruth",
        help="Sub-directory (relative to root, or absolute path) for segmentation masks.",
    )
    p.add_argument(
        "--annotations_csv",
        type=str,
        default="isic_bias.csv",
        help="Annotation CSV filename (relative to root, or absolute path).",
    )
    p.add_argument(
        "--mode",
        type=str,
        default="whole",
        choices=VALID_MODES,
        help="Image preprocessing mode.",
    )
    p.add_argument(
        "--model",
        type=str,
        default="resnet50",
        help="Torchvision model backbone (e.g. resnet50, vit_b_16).",
    )
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--n_splits", type=int, default=5)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--download", action="store_true", help="Auto-download data.")
    return p


parser = _build_arg_parser()
args = parser.parse_args()

# Route PyHealth trainer logs to stdout so per-epoch metrics are visible.
_trainer_logger = logging.getLogger("pyhealth.trainer")
_handler = logging.StreamHandler()
_handler.setFormatter(logging.Formatter("%(message)s"))
_trainer_logger.addHandler(_handler)
_trainer_logger.setLevel(logging.INFO)
| 97 | + |
| 98 | + |
| 99 | +# ============================================================================= |
| 100 | +# Example run & results |
| 101 | +# ============================================================================= |
| 102 | +# Command: |
#   python isic2018_artifacts_classification.py --root /path/to/isic2018_data
| 104 | +# |
| 105 | +# Parameters: |
| 106 | +# --mode whole |
| 107 | +# --model resnet50 |
| 108 | +# --epochs 10 |
| 109 | +# --batch_size 32 |
| 110 | +# --lr 1e-4 |
| 111 | +# --n_splits 5 |
| 112 | +# --seed 42 |
| 113 | +# |
| 114 | +# 5-fold stratified CV results (whole mode, ResNet-50, ImageNet pretrained): |
| 115 | +# |
| 116 | +# Split 1 AUROC: 0.800 Accuracy: 0.844 |
| 117 | +# Split 2 AUROC: 0.803 Accuracy: 0.829 |
| 118 | +# Split 3 AUROC: 0.758 Accuracy: 0.788 |
| 119 | +# Split 4 AUROC: 0.790 Accuracy: 0.807 |
| 120 | +# Split 5 AUROC: 0.829 Accuracy: 0.840 |
| 121 | +# ───────────────────────────────────── |
| 122 | +# Mean AUROC: 0.796 Accuracy: 0.822 |
| 123 | +# |
| 124 | +# Matches findings from: |
| 125 | +# "A Study of Artifacts on Melanoma Classification under |
| 126 | +# Diffusion-Based Perturbations" |
| 127 | +# ============================================================================= |
| 128 | + |
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # 1. Build dataset — all path resolution delegated to the loader
    # ------------------------------------------------------------------
    dataset = ISIC2018ArtifactsDataset(
        root=args.root,
        annotations_csv=args.annotations_csv,
        image_dir=args.image_dir,
        mask_dir=args.mask_dir,
        mode=args.mode,
        download=args.download,
    )
    dataset.stats()

    # ------------------------------------------------------------------
    # 2. Apply task → SampleDataset with binary labels
    # ------------------------------------------------------------------
    task = ISIC2018ArtifactsBinaryClassification()
    samples = dataset.set_task(task)

    # ------------------------------------------------------------------
    # 3. Generate stratified K-fold splits from sample labels
    # ------------------------------------------------------------------
    labels = np.array([samples[i]["label"] for i in range(len(samples))])
    indices = np.arange(len(labels))

    skf = StratifiedKFold(n_splits=args.n_splits, shuffle=True, random_state=args.seed)

    # Per-mode directory for trainer output (checkpoints / logs).
    output_dir = os.path.join(args.root, "checkpoints", args.mode)
    os.makedirs(output_dir, exist_ok=True)

    fold_scores = []  # one metrics dict per fold; summarized after the loop

    for fold, (train_val_idx, test_idx) in enumerate(skf.split(indices, labels), start=1):
        print(f"\n{'='*60}")
        print(f" Mode: {args.mode} | Split {fold}/{args.n_splits}")
        print(f"{'='*60}")

        # Carve 10% of the train+val portion off as a validation set.
        # Shuffle with a per-fold seeded generator so splits are reproducible.
        val_size = max(1, int(0.1 * len(train_val_idx)))
        rng = np.random.default_rng(args.seed + fold)
        rng.shuffle(train_val_idx)
        val_idx = train_val_idx[:val_size]
        train_idx = train_val_idx[val_size:]

        train_loader = get_dataloader(
            samples.subset(train_idx), batch_size=args.batch_size, shuffle=True
        )
        val_loader = get_dataloader(
            samples.subset(val_idx), batch_size=args.batch_size, shuffle=False
        )
        test_loader = get_dataloader(
            samples.subset(test_idx), batch_size=args.batch_size, shuffle=False
        )

        # --------------------------------------------------------------
        # 4. Fresh ImageNet-pretrained model per fold (avoids leakage
        #    of weights across folds)
        # --------------------------------------------------------------
        model = TorchvisionModel(
            dataset=samples,
            model_name=args.model,
            model_config={"weights": "DEFAULT"},
        )

        # --------------------------------------------------------------
        # 5. Train — fix: route checkpoints/logs into output_dir, which
        #    was previously created but never used.
        # --------------------------------------------------------------
        trainer = Trainer(
            model=model,
            metrics=["accuracy", "roc_auc"],
            output_path=output_dir,
        )
        trainer.train(
            train_dataloader=train_loader,
            val_dataloader=val_loader,
            epochs=args.epochs,
            optimizer_params={"lr": args.lr},
        )

        # --------------------------------------------------------------
        # 6. Evaluate on the held-out test fold
        # --------------------------------------------------------------
        scores = trainer.evaluate(test_loader)
        print(f"Mode: {args.mode} Split {fold} test results: {scores}")
        fold_scores.append(scores)

    # ------------------------------------------------------------------
    # 7. Cross-validation summary — fix: report the mean of each metric
    #    across folds (previously only per-fold results were printed,
    #    although the header documents mean AUROC/accuracy).
    # ------------------------------------------------------------------
    if fold_scores:
        print(f"\n{'='*60}")
        print(f" Mode: {args.mode} | {args.n_splits}-fold CV summary")
        print(f"{'='*60}")
        metric_names = sorted({name for s in fold_scores for name in s})
        for name in metric_names:
            values = [s[name] for s in fold_scores if s.get(name) is not None]
            if values:
                print(f"  mean {name}: {float(np.mean(values)):.4f}")

    samples.close()
| 212 | + |