# Weakly supervised segmentation demo using PyMIC

In this example, we show scribble-supervised learning methods implemented in PyMIC.
Currently, the following are available in PyMIC:

|PyMIC Method|Reference|Remarks|
|---|---|---|
|WSLEntropyMinimization|[Grandvalet et al.][em_paper], NeurIPS 2005| Entropy minimization for regularization|
|WSLTotalVariation| [Luo et al.][tv_paper], arXiv 2022| Total variation for regularization|
|WSLMumfordShah| [Kim et al.][mumford_paper], TIP 2020| Mumford-Shah loss for regularization|
|WSLGatedCRF| [Obukhov et al.][gcrf_paper], arXiv 2019| Gated CRF for regularization|
|WSLUSTM| [Liu et al.][ustm_paper], PR 2022| Adapt USTM with transform consistency|
|WSLDMPLS| [Luo et al.][dmpls_paper], MICCAI 2022| Dynamically mixed pseudo-label supervision|

[em_paper]:https://papers.nips.cc/paper/2004/file/96f2b50b5d3613adf9c27049b2a888c7-Paper.pdf
[tv_paper]:https://arxiv.org/abs/2111.02403
[mumford_paper]:https://doi.org/10.1109/TIP.2019.2941265
[gcrf_paper]:http://arxiv.org/abs/1906.04651
[ustm_paper]:https://doi.org/10.1016/j.patcog.2021.108341
[dmpls_paper]:https://arxiv.org/abs/2203.02106

## Data
The [ACDC][ACDC_link] (Automatic Cardiac Diagnosis Challenge) dataset is used in this demo. It contains 200 short-axis cardiac cine MR images from 100 patients, and the classes for segmentation are: Right Ventricle (RV), Myocardium (Myo) and Left Ventricle (LV). [Valvano et al.][scribble_link] provided scribble annotations for this dataset. The images and scribble annotations are available in `PyMIC_data/ACDC/preprocess`, where we have normalized the intensity to [0, 1]. You can download `PyMIC_data` from .... The images are split at the patient level into 70%, 10% and 20% for training, validation and testing, respectively (see `config/data` for details).

[ACDC_link]:https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
[scribble_link]:https://gvalvano.github.io/wss-multiscale-adversarial-attention-gates/data

## Training
In this demo, we experiment with five methods: EM, TV, GatedCRF, USTM and DMPLS, and compare them with a baseline that learns from the annotated pixels with a partial CE loss. All these methods use UNet2D as the backbone network.

### Baseline Method
The dataset setting is similar to that in the `seg_ssl/ACDC` demo. Here we use a slightly different setting of data transforms:

```bash
tensor_type = float
task_type = seg
root_dir = /home/disk2t/projects/PyMIC_project/PyMIC_data/ACDC/preprocess
train_csv = config/data/image_train.csv
valid_csv = config/data/image_valid.csv
test_csv = config/data/image_test.csv
train_batch_size = 4

# data transforms
train_transform = [Pad, RandomCrop, RandomFlip, NormalizeWithMeanStd, PartialLabelToProbability]
valid_transform = [NormalizeWithMeanStd, Pad, LabelToProbability]
test_transform = [NormalizeWithMeanStd, Pad]

Pad_output_size = [4, 224, 224]
Pad_ceil_mode = False

RandomCrop_output_size = [3, 224, 224]
RandomCrop_foreground_focus = False
RandomCrop_foreground_ratio = None
RandomCrop_mask_label = None

RandomFlip_flip_depth = False
RandomFlip_flip_height = True
RandomFlip_flip_width = True

NormalizeWithMeanStd_channels = [0]
```

Please note that we use a `PartialLabelToProbability` transform to convert the partial labels into a one-hot segmentation map and a mask of annotated pixels. The mask is used as a pixel weighting map in `CrossEntropyLoss`, so that the partial CE loss is calculated as a weighted CE loss, i.e., the weight for unannotated pixels is 0.
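
The weighting idea can be sketched as follows. This is an illustrative re-implementation, not PyMIC's actual code; in particular, marking unannotated pixels with `255` is an assumption made here for the example:

```python
import numpy as np

def partial_ce_loss(prob, scribble, unlabeled=255):
    """Sketch of partial CE: build a one-hot map and a pixel weight mask from
    the scribble, so unannotated pixels contribute 0 to the loss.
    prob: (C, H, W) softmax probabilities; scribble: (H, W) integer labels."""
    num_classes = prob.shape[0]
    weight = (scribble != unlabeled).astype(np.float64)   # 1 for annotated pixels
    safe = np.where(scribble == unlabeled, 0, scribble)   # dummy class where unlabeled
    one_hot = np.eye(num_classes)[safe].transpose(2, 0, 1)  # (C, H, W)
    ce_map = -(one_hot * np.log(prob + 1e-8)).sum(axis=0)   # per-pixel CE
    return (weight * ce_map).sum() / max(weight.sum(), 1.0)
```

Only the two annotated pixels of a scribble contribute to the average, which is exactly the effect of the pixel weighting map described above.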

The configuration of 2D UNet is:

```bash
net_type = UNet2D
class_num = 4
in_chns = 1
feature_chns = [16, 32, 64, 128, 256]
dropout = [0.0, 0.0, 0.0, 0.5, 0.5]
bilinear = True
deep_supervise = False
```

For training, we use `CrossEntropyLoss` with pixel weighting (i.e., the partial CE loss), and train the network with the `Adam` optimizer. The maximal iteration number is 20k, and training is early-stopped if there is no performance improvement on the validation set for 8k iterations. The learning rate scheduler is `ReduceLROnPlateau`. The corresponding configuration is:

```bash
gpus = [0]
loss_type = CrossEntropyLoss

# for optimizers
optimizer = Adam
learning_rate = 1e-3
momentum = 0.9
weight_decay = 1e-5

# for lr scheduler
lr_scheduler = ReduceLROnPlateau
lr_gamma = 0.5
ReduceLROnPlateau_patience = 2000
early_stop_patience = 8000
ckpt_save_dir = model/unet2d_baseline

# start iter
iter_start = 0
iter_max = 20000
iter_valid = 100
iter_save = [2000, 20000]
```

During inference, we use a sliding window of 3x224x224, and post-process the results with `KeepLargestComponent`. The configuration is:
```bash
# checkpoint mode can be [0-latest, 1-best, 2-specified]
ckpt_mode = 1
output_dir = result/unet2d_baseline
post_process = KeepLargestComponent

sliding_window_enable = True
sliding_window_size = [3, 224, 224]
sliding_window_stride = [3, 224, 224]
```
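
The effect of largest-component post-processing on a binary 2D mask can be sketched as follows (an illustrative pure-NumPy version; PyMIC's own implementation may work per class and in 3D):

```python
import numpy as np
from collections import deque

def keep_largest_component(mask):
    """Label 4-connected foreground components of a binary 2D mask with a BFS
    flood fill, then zero out all but the largest component."""
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=int)
    sizes, current = {}, 0
    for i in range(h):
        for j in range(w):
            if mask[i, j] and labels[i, j] == 0:
                current += 1
                labels[i, j] = current
                q, count = deque([(i, j)]), 0
                while q:
                    y, x = q.popleft()
                    count += 1
                    for dy, dx in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                        ny, nx = y + dy, x + dx
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = current
                            q.append((ny, nx))
                sizes[current] = count
    if not sizes:
        return mask.copy()
    largest = max(sizes, key=sizes.get)
    return (labels == largest).astype(mask.dtype)
```

This removes small spurious islands, which is useful for cardiac structures that are expected to form a single connected region per slice.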

The following commands are used for training and inference with this method, respectively:

```bash
pymic_run train config/unet2d_baseline.cfg
pymic_run test config/unet2d_baseline.cfg
```
### Entropy Minimization
The configuration file for Entropy Minimization is `config/unet2d_em.cfg`. The data configuration has been described above, and the settings for data augmentation, network, optimizer, learning rate scheduler and inference are the same as those in the baseline method. The specific setting for Entropy Minimization is:

```bash
wsl_method = EntropyMinimization
regularize_w = 0.1
rampup_start = 2000
rampup_end = 15000
```

where we set the weight of the regularization loss to 0.1, and ramp-up is used to gradually increase it from 0 to 0.1.
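
The regularization term and the ramp-up of its weight can be sketched as follows (an illustrative version with a linear ramp-up; PyMIC's actual schedule may differ, e.g. a sigmoid ramp-up):

```python
import numpy as np

def entropy_loss(prob, eps=1e-8):
    """Mean pixel-wise entropy of a softmax output with shape (C, H, W);
    minimizing it encourages confident predictions on unannotated pixels."""
    return float((-prob * np.log(prob + eps)).sum(axis=0).mean())

def rampup_weight(it, w_max=0.1, start=2000, end=15000):
    """Linearly increase the regularization weight from 0 to w_max between
    iterations `start` and `end` (matching rampup_start / rampup_end above)."""
    if it < start:
        return 0.0
    if it >= end:
        return w_max
    return w_max * (it - start) / (end - start)
```

The total loss at iteration `it` would then be the partial CE loss plus `rampup_weight(it) * entropy_loss(prob)`.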

The following commands are used for training and inference with this method, respectively:

```bash
pymic_wsl train config/unet2d_em.cfg
pymic_run test config/unet2d_em.cfg
```
### TV
The configuration file for TV is `config/unet2d_tv.cfg`. The corresponding setting is:

```bash
wsl_method = TotalVariation
regularize_w = 0.1
rampup_start = 2000
rampup_end = 15000
```
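
What a total-variation regularizer computes on a softmax output can be sketched as follows (an illustrative anisotropic form; the exact formulation in PyMIC may differ):

```python
import numpy as np

def total_variation_loss(prob):
    """Mean absolute difference between neighboring pixels of a softmax
    output with shape (C, H, W); penalizes spatially noisy predictions."""
    dh = np.abs(prob[:, 1:, :] - prob[:, :-1, :]).mean()  # vertical neighbors
    dw = np.abs(prob[:, :, 1:] - prob[:, :, :-1]).mean()  # horizontal neighbors
    return float(dh + dw)
```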

The following commands are used for training and inference with this method, respectively:
```bash
pymic_wsl train config/unet2d_tv.cfg
pymic_run test config/unet2d_tv.cfg
```
### Gated CRF
The configuration file for Gated CRF is `config/unet2d_gcrf.cfg`. The corresponding setting is:

```bash
wsl_method = GatedCRF
regularize_w = 0.1
rampup_start = 2000
rampup_end = 15000
GatedCRFLoss_W0 = 1.0
GatedCRFLoss_XY0 = 5
GatedCRFLoss_rgb = 0.1
GatedCRFLoss_W1 = 1.0
GatedCRFLoss_XY1 = 3
GatedCRFLoss_Radius = 5
```

The following commands are used for training and inference with this method, respectively:

```bash
pymic_wsl train config/unet2d_gcrf.cfg
pymic_run test config/unet2d_gcrf.cfg
```
### USTM
The configuration file for USTM is `config/unet2d_ustm.cfg`. The corresponding setting is:

```bash
wsl_method = USTM
regularize_w = 0.1
rampup_start = 2000
rampup_end = 15000
```
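
The transform-consistency idea behind USTM can be sketched as follows, assuming a horizontal flip as the transform and a mean squared error as the discrepancy (an illustration of the idea, not PyMIC's implementation, which also involves a mean-teacher model and uncertainty weighting):

```python
import numpy as np

def consistency_loss(student_prob, teacher_prob_flipped):
    """The student's prediction on the original image should match the
    teacher's prediction on the flipped image, flipped back to the original
    orientation. Both inputs have shape (C, H, W)."""
    teacher_aligned = teacher_prob_flipped[:, :, ::-1]  # undo the flip
    return float(((student_prob - teacher_aligned) ** 2).mean())
```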

The commands for training and inference are:

```bash
pymic_wsl train config/unet2d_ustm.cfg
pymic_run test config/unet2d_ustm.cfg
```
### DMPLS
The configuration file for DMPLS is `config/unet2d_dmpls.cfg`, and the corresponding setting is:

```bash
wsl_method = DMPLS
regularize_w = 0.1
rampup_start = 2000
rampup_end = 15000
```
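
The dynamically mixed pseudo-label idea can be sketched as follows (illustrative; the function name and mixing details are assumptions, not PyMIC's exact code):

```python
import numpy as np

def dmpls_pseudo_label(prob1, prob2, rng=None):
    """Mix the softmax outputs of two decoders, each with shape (C, H, W),
    with a random weight drawn at every iteration; the argmax of the mixture
    serves as the pseudo label supervising unannotated pixels."""
    if rng is None:
        rng = np.random.default_rng()
    alpha = rng.random()                      # dynamic mixing weight in [0, 1)
    mixed = alpha * prob1 + (1.0 - alpha) * prob2
    return mixed.argmax(axis=0)               # (H, W) pseudo-label map
```

When the two decoders agree, the pseudo label is simply their shared argmax; the random mixing only matters where they disagree.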

The training and inference commands are:

```bash
pymic_wsl train config/unet2d_dmpls.cfg
pymic_run test config/unet2d_dmpls.cfg
```
## Evaluation
Use `pymic_eval_seg config/evaluation.cfg` for quantitative evaluation of the segmentation results. You need to edit `config/evaluation.cfg` first, for example:

```bash
metric = dice
label_list = [1,2,3]
organ_name = heart
ground_truth_folder_root = /home/disk2t/projects/PyMIC_project/PyMIC_data/ACDC/preprocess
segmentation_folder_root = ./result/unet2d_baseline
evaluation_image_pair = ./config/data/image_test_gt_seg.csv
```
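
For reference, the Dice metric computed for each label in `label_list` is 2|S∩G| / (|S| + |G|) over the binarized segmentation S and ground truth G, which can be sketched as:

```python
import numpy as np

def dice_score(seg, gt, label):
    """Dice similarity coefficient for one label over integer label maps."""
    s = (seg == label)
    g = (gt == label)
    denom = s.sum() + g.sum()
    return 2.0 * np.logical_and(s, g).sum() / denom if denom > 0 else 1.0
```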