Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 3f72d37

Browse files
committed
Merge remote-tracking branch 'origin/main' into main
2 parents d3f14b6 + dbbbb27 commit 3f72d37

2 files changed

Lines changed: 91 additions & 101 deletions

File tree

tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb

Lines changed: 58 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,6 @@
88
"# 3D Latent Diffusion Model"
99
]
1010
},
11-
{
12-
"cell_type": "markdown",
13-
"id": "bcbbb4a3",
14-
"metadata": {},
15-
"source": [
16-
"## Set up environment using Colab\n"
17-
]
18-
},
19-
{
20-
"cell_type": "code",
21-
"execution_count": 1,
22-
"id": "8caae787",
23-
"metadata": {},
24-
"outputs": [],
25-
"source": [
26-
"!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"\n",
27-
"!python -c \"import matplotlib\" || pip install -q matplotlib\n",
28-
"%matplotlib inline"
29-
]
30-
},
3111
{
3212
"cell_type": "markdown",
3313
"id": "da9e6b23",
@@ -78,27 +58,26 @@
7858
],
7959
"source": [
8060
"import os\n",
61+
"import shutil\n",
8162
"import tempfile\n",
63+
"\n",
8264
"import matplotlib.pyplot as plt\n",
83-
"from tqdm import tqdm\n",
84-
"import shutil\n",
8565
"import torch\n",
8666
"import torch.nn.functional as F\n",
87-
"from torch.cuda.amp import GradScaler, autocast\n",
88-
"\n",
8967
"from monai import transforms\n",
9068
"from monai.apps import DecathlonDataset\n",
9169
"from monai.config import print_config\n",
9270
"from monai.data import DataLoader\n",
9371
"from monai.utils import first, set_determinism\n",
72+
"from torch.cuda.amp import GradScaler, autocast\n",
73+
"from torch.nn import L1Loss\n",
74+
"from tqdm import tqdm\n",
9475
"\n",
95-
"\n",
96-
"from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n",
9776
"from generative.inferers import LatentDiffusionInferer\n",
98-
"from generative.schedulers import DDPMScheduler\n",
9977
"from generative.losses.adversarial_loss import PatchAdversarialLoss\n",
10078
"from generative.losses.perceptual import PerceptualLoss\n",
101-
"from torch.nn import L1Loss\n",
79+
"from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n",
80+
"from generative.networks.schedulers import DDPMScheduler\n",
10281
"\n",
10382
"print_config()"
10483
]
@@ -183,29 +162,36 @@
183162
],
184163
"source": [
185164
"batch_size = 2\n",
186-
"channel = 0 # 0 = Flair\n",
187-
"assert channel in [0,1,2,3], 'Choose a valid channel'\n",
165+
"channel = 0 # 0 = Flair\n",
166+
"assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n",
188167
"\n",
189168
"train_transforms = transforms.Compose(\n",
190169
" [\n",
191170
" transforms.LoadImaged(keys=[\"image\"]),\n",
192171
" transforms.EnsureChannelFirstd(keys=[\"image\"]),\n",
193-
" transforms.Lambdad(keys=\"image\", func=lambda x: x[channel,:, :, :]),\n",
172+
" transforms.Lambdad(keys=\"image\", func=lambda x: x[channel, :, :, :]),\n",
194173
" transforms.AddChanneld(keys=[\"image\"]),\n",
195174
" transforms.EnsureTyped(keys=[\"image\"]),\n",
196175
" transforms.Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n",
197-
" transforms.Spacingd(keys=[\"image\"], pixdim=(2.4, 2.4, 2.2), mode=(\"bilinear\"),),\n",
198-
" transforms.CenterSpatialCropd(keys=[\"image\"],roi_size = (96, 96, 64)),\n",
199-
" transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower= 0, upper= 99.5, b_min= 0, b_max= 1),\n",
176+
" transforms.Spacingd(\n",
177+
" keys=[\"image\"],\n",
178+
" pixdim=(2.4, 2.4, 2.2),\n",
179+
" mode=(\"bilinear\"),\n",
180+
" ),\n",
181+
" transforms.CenterSpatialCropd(keys=[\"image\"], roi_size=(96, 96, 64)),\n",
182+
" transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n",
200183
" ]\n",
201184
")\n",
202-
"train_ds = DecathlonDataset(root_dir=root_dir, \n",
203-
" task='Task01_BrainTumour', \n",
204-
" section=\"training\", #validation\n",
205-
" cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n",
206-
" num_workers=4,\n",
207-
"                          download=False, # Set download to True if the dataset hasn't been downloaded yet\n",
208-
" seed=0, transform = train_transforms) \n",
185+
"train_ds = DecathlonDataset(\n",
186+
" root_dir=root_dir,\n",
187+
" task=\"Task01_BrainTumour\",\n",
188+
" section=\"training\", # validation\n",
189+
" cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n",
190+
" num_workers=4,\n",
191+
"    download=False,  # Set download to True if the dataset hasn't been downloaded yet\n",
192+
" seed=0,\n",
193+
" transform=train_transforms,\n",
194+
")\n",
209195
"train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)\n",
210196
"print(f'Image shape {train_ds[0][\"image\"].shape}')"
211197
]
@@ -252,16 +238,17 @@
252238
"check_data = first(train_loader)\n",
253239
"idx = 0\n",
254240
"\n",
255-
"img = check_data[\"image\"][idx,0]\n",
241+
"img = check_data[\"image\"][idx, 0]\n",
256242
"fig, axs = plt.subplots(nrows=1, ncols=3)\n",
257243
"for ax in axs:\n",
258244
" ax.axis(\"off\")\n",
259245
"ax = axs[0]\n",
260-
"ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n",
246+
"ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n",
261247
"ax = axs[1]\n",
262-
"ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n",
248+
"ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n",
263249
"ax = axs[2]\n",
264-
"ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")"
250+
"ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")\n",
251+
"# plt.savefig(\"training_examples.png\")"
265252
]
266253
},
267254
{
@@ -316,7 +303,7 @@
316303
" in_channels=3,\n",
317304
" out_channels=3,\n",
318305
" num_res_blocks=1,\n",
319-
" num_channels=[32,64,64],\n",
306+
" num_channels=[32, 64, 64],\n",
320307
" attention_levels=(False, True, True),\n",
321308
" num_head_channels=1,\n",
322309
")\n",
@@ -344,7 +331,7 @@
344331
" beta_end=0.0195,\n",
345332
")\n",
346333
"\n",
347-
"inferer = LatentDiffusionInferer(scheduler);"
334+
"inferer = LatentDiffusionInferer(scheduler)"
348335
]
349336
},
350337
{
@@ -364,12 +351,15 @@
364351
"source": [
365352
"l1_loss = L1Loss()\n",
366353
"adv_loss = PatchAdversarialLoss(criterion=\"least_squares\")\n",
367-
"loss_perceptual = PerceptualLoss(spatial_dims=3, network_type='squeeze', is_fake_3d=True, fake_3d_ratio=0.2)\n",
354+
"loss_perceptual = PerceptualLoss(spatial_dims=3, network_type=\"squeeze\", is_fake_3d=True, fake_3d_ratio=0.2)\n",
368355
"loss_perceptual.to(device)\n",
356+
"\n",
357+
"\n",
369358
"def KL_loss(z_mu, z_sigma):\n",
370-
" kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim = [1, 2, 3, 4])\n",
359+
" kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])\n",
371360
" return torch.sum(kl_loss) / kl_loss.shape[0]\n",
372361
"\n",
362+
"\n",
373363
"adv_weight = 0.01\n",
374364
"perceptual_weight = 0.001\n",
375365
"kl_weight = 1e-6"
@@ -527,25 +517,25 @@
527517
" progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)\n",
528518
" progress_bar.set_description(f\"Epoch {epoch}\")\n",
529519
" for step, batch in progress_bar:\n",
530-
" images = batch[\"image\"].to(device) # choose only one of Brats channels\n",
531-
" \n",
520+
" images = batch[\"image\"].to(device) # choose only one of Brats channels\n",
521+
"\n",
532522
" # Generator part\n",
533523
" optimizer_g.zero_grad(set_to_none=True)\n",
534-
" reconstruction, z_mu, z_sigma = autoencoder(images)\n",
524+
" reconstruction, z_mu, z_sigma = autoencoder(images)\n",
535525
" kl_loss = KL_loss(z_mu, z_sigma)\n",
536526
"\n",
537527
" recons_loss = l1_loss(reconstruction.float(), images.float())\n",
538528
" p_loss = loss_perceptual(reconstruction.float(), images.float())\n",
539529
" loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss\n",
540-
" \n",
530+
"\n",
541531
" if epoch > autoencoder_warm_up_n_epochs:\n",
542532
" logits_fake = discriminator(reconstruction.contiguous().float())[-1]\n",
543533
" generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)\n",
544534
" loss_g += adv_weight * generator_loss\n",
545535
"\n",
546536
" loss_g.backward()\n",
547537
" optimizer_g.step()\n",
548-
" \n",
538+
"\n",
549539
" if epoch > autoencoder_warm_up_n_epochs:\n",
550540
" # Discriminator part\n",
551541
" optimizer_d.zero_grad(set_to_none=True)\n",
@@ -604,7 +594,7 @@
604594
"source": [
605595
"plt.style.use(\"ggplot\")\n",
606596
"plt.title(\"Learning Curves\", fontsize=20)\n",
607-
"plt.plot(epoch_recon_loss_list) \n",
597+
"plt.plot(epoch_recon_loss_list)\n",
608598
"plt.yticks(fontsize=12)\n",
609599
"plt.xticks(fontsize=12)\n",
610600
"plt.xlabel(\"Epochs\", fontsize=16)\n",
@@ -685,11 +675,11 @@
685675
"for ax in axs:\n",
686676
" ax.axis(\"off\")\n",
687677
"ax = axs[0]\n",
688-
"ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n",
678+
"ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n",
689679
"ax = axs[1]\n",
690-
"ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n",
680+
"ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n",
691681
"ax = axs[2]\n",
692-
"ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")"
682+
"ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")"
693683
]
694684
},
695685
{
@@ -733,12 +723,12 @@
733723
" for step, batch in progress_bar:\n",
734724
" images = batch[\"image\"].to(device)\n",
735725
" optimizer_diff.zero_grad(set_to_none=True)\n",
736-
" \n",
726+
"\n",
737727
" with autocast(enabled=True):\n",
738728
" # Generate random noise\n",
739729
" noise = torch.randn_like(z).to(device)\n",
740730
" # Get model prediction\n",
741-
" noise_pred = inferer(inputs=images, autoencoder_model = autoencoder, diffusion_model=unet, noise=noise)\n",
731+
" noise_pred = inferer(inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise)\n",
742732
"\n",
743733
" loss = F.mse_loss(noise_pred.float(), noise.float())\n",
744734
"\n",
@@ -781,9 +771,9 @@
781771
}
782772
],
783773
"source": [
784-
"plt.plot(epoch_loss_list);\n",
774+
"plt.plot(epoch_loss_list)\n",
785775
"plt.title(\"Learning Curves\", fontsize=20)\n",
786-
"plt.plot(epoch_loss_list) \n",
776+
"plt.plot(epoch_loss_list)\n",
787777
"plt.yticks(fontsize=12)\n",
788778
"plt.xticks(fontsize=12)\n",
789779
"plt.xlabel(\"Epochs\", fontsize=16)\n",
@@ -863,16 +853,16 @@
863853
],
864854
"source": [
865855
"idx = 0\n",
866-
"img = synthetic_images[idx, channel].detach().cpu().numpy() # images\n",
856+
"img = synthetic_images[idx, channel].detach().cpu().numpy() # images\n",
867857
"fig, axs = plt.subplots(nrows=1, ncols=3)\n",
868858
"for ax in axs:\n",
869859
" ax.axis(\"off\")\n",
870860
"ax = axs[0]\n",
871-
"ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n",
861+
"ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n",
872862
"ax = axs[1]\n",
873-
"ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n",
863+
"ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n",
874864
"ax = axs[2]\n",
875-
"ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")"
865+
"ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")"
876866
]
877867
},
878868
{
@@ -902,7 +892,7 @@
902892
"main_language": "python"
903893
},
904894
"kernelspec": {
905-
"display_name": "Python 3.8.2 ('torch_gpu')",
895+
"display_name": "Python 3",
906896
"language": "python",
907897
"name": "python3"
908898
},
@@ -916,7 +906,7 @@
916906
"name": "python",
917907
"nbconvert_exporter": "python",
918908
"pygments_lexer": "ipython3",
919-
"version": "3.8.2"
909+
"version": "3.8.12"
920910
},
921911
"vscode": {
922912
"interpreter": {

0 commit comments

Comments (0)