|
107 | 107 | "import os\n", |
108 | 108 | "import tempfile\n", |
109 | 109 | "import time\n", |
110 | | - "import os\n", |
111 | 110 | "import matplotlib.pyplot as plt\n", |
112 | 111 | "import numpy as np\n", |
113 | 112 | "import torch\n", |
|
119 | 118 | "from monai.config import print_config\n", |
120 | 119 | "from monai.data import DataLoader\n", |
121 | 120 | "from monai.utils import set_determinism\n", |
122 | | - "from torch.cuda.amp import GradScaler, autocast\n", |
123 | | - "from tqdm import tqdm\n", |
124 | 121 | "from sklearn.linear_model import LogisticRegression\n", |
125 | 122 | "\n", |
126 | 123 | "from generative.inferers import DiffusionInferer\n", |
|
152 | 149 | "outputs": [], |
153 | 150 | "source": [ |
154 | 151 | "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", |
155 | | - "root_dir = tempfile.mkdtemp() if directory is None else directory\n", |
156 | | - "root_dir = '/home/s2086085/pedro_idcom/experiment_data'" |
| 152 | + "root_dir = tempfile.mkdtemp() if directory is None else directory" |
157 | 153 | ] |
158 | 154 | }, |
159 | 155 | { |
|
192 | 188 | ] |
193 | 189 | }, |
194 | 190 | { |
| 191 | + "attachments": {}, |
195 | 192 | "cell_type": "markdown", |
196 | 193 | "id": "6986f55c", |
197 | 194 | "metadata": {}, |
|
200 | 197 | "\n", |
201 | 198 | "1. `LoadImaged` loads the brain images from files.\n", |
202 | 199 | "2. `EnsureChannelFirstd` ensures the original data is in \"channel first\" shape.\n", |
203 | | - "3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.\n", |
204 | | - "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.\n", |
| 200 | + "3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n", |
| 201 | + "4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n", |
205 | 202 | "5. `ScaleIntensityRangePercentilesd` applies range scaling to a numpy array based on the intensity distribution of the input. This transform is very common with MRI images.\n", |
206 | 203 | "6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n", |
207 | 204 | "7. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2`)." |
208 | 205 | ] |
209 | 206 | }, |
210 | 207 | { |
211 | 208 | "cell_type": "code", |
212 | | - "execution_count": 5, |
| 209 | + "execution_count": null, |
213 | 210 | "id": "c68d2d91-9a0b-4ac1-ae49-f4a64edbd82a", |
214 | 211 | "metadata": {}, |
215 | | - "outputs": [ |
216 | | - { |
217 | | - "name": "stderr", |
218 | | - "output_type": "stream", |
219 | | - "text": [ |
220 | | - "<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.\n" |
221 | | - ] |
222 | | - } |
223 | | - ], |
| 212 | + "outputs": [], |
224 | 213 | "source": [ |
225 | 214 | "channel = 0 # 0 = Flair\n", |
226 | 215 | "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", |
|
229 | 218 | " [\n", |
230 | 219 | " transforms.LoadImaged(keys=[\"image\", \"label\"]),\n", |
231 | 220 | " transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n", |
232 | | - " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n", |
233 | | - " transforms.AddChanneld(keys=[\"image\"]),\n", |
| 221 | + " transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, None, :, :, :]),\n", |
234 | 222 | " transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n", |
235 | 223 | " transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n", |
236 | 224 | " transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n", |
|
292 | 280 | " section=\"training\",\n", |
293 | 281 | " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", |
294 | 282 | " num_workers=4,\n", |
295 | | - " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", |
| 283 | + " download=True, # Set download to True if the dataset hasn't been downloaded yet\n", |
296 | 284 | " seed=0,\n", |
297 | 285 | " transform=train_transforms,\n", |
298 | 286 | ")\n", |
|
336 | 324 | " section=\"validation\",\n", |
337 | 325 | " cache_rate=1, # you may need a few Gb of RAM... Set to 0 otherwise\n", |
338 | 326 | " num_workers=4,\n", |
339 | | - " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", |
| 327 | + " download=True, # Set download to True if the dataset hasn't been downloaded yet\n", |
340 | 328 | " seed=0,\n", |
341 | 329 | " transform=train_transforms,\n", |
342 | 330 | ")\n", |
|
645 | 633 | } |
646 | 634 | ], |
647 | 635 | "source": [ |
648 | | - "n_iterations = 1e4 # training for longer helps a lot with reconstruction quality, even if the loss is already low\n", |
| 636 | + "n_iterations = 1e4 # training for longer (1e4 ~ 3h) helps a lot with reconstruction quality, even if the loss is already low\n", |
649 | 637 | "batch_size = 64\n", |
650 | 638 | "val_interval = 100\n", |
651 | 639 | "iter_loss_list, val_iter_loss_list = [], []\n", |
|
810 | 798 | "# get latent space of training set\n", |
811 | 799 | "latents_train = []\n", |
812 | 800 | "classes_train = []\n", |
813 | | - "for i in range(15): # 15 slices from each volume\n", |
| 801 | + "# 15 slices from each volume\n", |
| 802 | + "nb_slices_per_volume = 15\n", |
| 803 | + "for _ in range(nb_slices_per_volume): \n", |
814 | 804 | " for batch in train_loader:\n", |
815 | 805 | " images = batch[\"image\"].to(device)\n", |
816 | 806 | " latent = model.semantic_encoder(images)\n", |
|
0 commit comments