Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 4044cbf

Browse files
committed
Corrections tutorials
1 parent 2f8cad5 commit 4044cbf

2 files changed

Lines changed: 108 additions & 76 deletions

File tree

tutorials/generative/2d_diffusion_autoencoder/2d_diffusion_autoencoder_tutorial.ipynb

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@
107107
"import os\n",
108108
"import tempfile\n",
109109
"import time\n",
110-
"import os\n",
111110
"import matplotlib.pyplot as plt\n",
112111
"import numpy as np\n",
113112
"import torch\n",
@@ -119,8 +118,6 @@
119118
"from monai.config import print_config\n",
120119
"from monai.data import DataLoader\n",
121120
"from monai.utils import set_determinism\n",
122-
"from torch.cuda.amp import GradScaler, autocast\n",
123-
"from tqdm import tqdm\n",
124121
"from sklearn.linear_model import LogisticRegression\n",
125122
"\n",
126123
"from generative.inferers import DiffusionInferer\n",
@@ -152,8 +149,7 @@
152149
"outputs": [],
153150
"source": [
154151
"directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
155-
"root_dir = tempfile.mkdtemp() if directory is None else directory\n",
156-
"root_dir = '/home/s2086085/pedro_idcom/experiment_data'"
152+
"root_dir = tempfile.mkdtemp() if directory is None else directory"
157153
]
158154
},
159155
{
@@ -192,6 +188,7 @@
192188
]
193189
},
194190
{
191+
"attachments": {},
195192
"cell_type": "markdown",
196193
"id": "6986f55c",
197194
"metadata": {},
@@ -200,27 +197,19 @@
200197
"\n",
201198
"1. `LoadImaged` loads the brain images from files.\n",
202199
"2. `EnsureChannelFirstd` ensures the original data has a \"channel first\" shape.\n",
203-
"3. The first `Lambdad` transform chooses the first channel of the image, which is the T1-weighted image.\n",
204-
"4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm to match the original paper.\n",
200+
"3. The first `Lambdad` transform chooses the first channel of the image, which is the Flair image.\n",
201+
"4. `Spacingd` resamples the image to the specified voxel spacing, we use 3,3,2 mm.\n",
205202
"5. `ScaleIntensityRangePercentilesd` Apply range scaling to a numpy array based on the intensity distribution of the input. Transform is very common with MRI images.\n",
206203
"6. `RandSpatialCropd` randomly crop out a 2D patch from the 3D image.\n",
207204
"7. The last `Lambdad` transform obtains `slice_label` by summing up the label to have a single scalar value (healthy `=1` or not `=2`)."
208205
]
209206
},
210207
{
211208
"cell_type": "code",
212-
"execution_count": 5,
209+
"execution_count": null,
213210
"id": "c68d2d91-9a0b-4ac1-ae49-f4a64edbd82a",
214211
"metadata": {},
215-
"outputs": [
216-
{
217-
"name": "stderr",
218-
"output_type": "stream",
219-
"text": [
220-
"<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.\n"
221-
]
222-
}
223-
],
212+
"outputs": [],
224213
"source": [
225214
"channel = 0 # 0 = Flair\n",
226215
"assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n",
@@ -229,8 +218,7 @@
229218
" [\n",
230219
" transforms.LoadImaged(keys=[\"image\", \"label\"]),\n",
231220
" transforms.EnsureChannelFirstd(keys=[\"image\", \"label\"]),\n",
232-
" transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, :, :, :]),\n",
233-
" transforms.AddChanneld(keys=[\"image\"]),\n",
221+
" transforms.Lambdad(keys=[\"image\"], func=lambda x: x[channel, None, :, :, :]),\n",
234222
" transforms.EnsureTyped(keys=[\"image\", \"label\"]),\n",
235223
" transforms.Orientationd(keys=[\"image\", \"label\"], axcodes=\"RAS\"),\n",
236224
" transforms.Spacingd(keys=[\"image\", \"label\"], pixdim=(3.0, 3.0, 2.0), mode=(\"bilinear\", \"nearest\")),\n",
@@ -292,7 +280,7 @@
292280
" section=\"training\",\n",
293281
" cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n",
294282
" num_workers=4,\n",
295-
" download=False, # Set download to True if the dataset hasnt been downloaded yet\n",
283+
" download=True, # Set download to True if the dataset hasn't been downloaded yet\n",
296284
" seed=0,\n",
297285
" transform=train_transforms,\n",
298286
")\n",
@@ -336,7 +324,7 @@
336324
" section=\"validation\",\n",
337325
" cache_rate=1, # you may need a few Gb of RAM... Set to 0 otherwise\n",
338326
" num_workers=4,\n",
339-
" download=False, # Set download to True if the dataset hasnt been downloaded yet\n",
327+
" download=True, # Set download to True if the dataset hasn't been downloaded yet\n",
340328
" seed=0,\n",
341329
" transform=train_transforms,\n",
342330
")\n",
@@ -645,7 +633,7 @@
645633
}
646634
],
647635
"source": [
648-
"n_iterations = 1e4 # training for longer helps a lot with reconstruction quality, even if the loss is already low\n",
636+
"n_iterations = 1e4 # training for longer (1e4 ~ 3h) helps a lot with reconstruction quality, even if the loss is already low\n",
649637
"batch_size = 64\n",
650638
"val_interval = 100\n",
651639
"iter_loss_list, val_iter_loss_list = [], []\n",
@@ -810,7 +798,9 @@
810798
"# get latent space of training set\n",
811799
"latents_train = []\n",
812800
"classes_train = []\n",
813-
"for i in range(15): # 15 slices from each volume\n",
801+
"# 15 slices from each volume\n",
802+
"nb_slices_per_volume = 15\n",
803+
"for _ in range(nb_slices_per_volume): \n",
814804
" for batch in train_loader:\n",
815805
" images = batch[\"image\"].to(device)\n",
816806
" latent = model.semantic_encoder(images)\n",

0 commit comments

Comments
 (0)