|
8 | 8 | "# 3D Latent Diffusion Model" |
9 | 9 | ] |
10 | 10 | }, |
11 | | - { |
12 | | - "cell_type": "markdown", |
13 | | - "id": "bcbbb4a3", |
14 | | - "metadata": {}, |
15 | | - "source": [ |
16 | | - "## Set up environment using Colab\n" |
17 | | - ] |
18 | | - }, |
19 | | - { |
20 | | - "cell_type": "code", |
21 | | - "execution_count": 1, |
22 | | - "id": "8caae787", |
23 | | - "metadata": {}, |
24 | | - "outputs": [], |
25 | | - "source": [ |
26 | | - "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"\n", |
27 | | - "!python -c \"import matplotlib\" || pip install -q matplotlib\n", |
28 | | - "%matplotlib inline" |
29 | | - ] |
30 | | - }, |
31 | 11 | { |
32 | 12 | "cell_type": "markdown", |
33 | 13 | "id": "da9e6b23", |
|
78 | 58 | ], |
79 | 59 | "source": [ |
80 | 60 | "import os\n", |
| 61 | + "import shutil\n", |
81 | 62 | "import tempfile\n", |
| 63 | + "\n", |
82 | 64 | "import matplotlib.pyplot as plt\n", |
83 | | - "from tqdm import tqdm\n", |
84 | | - "import shutil\n", |
85 | 65 | "import torch\n", |
86 | 66 | "import torch.nn.functional as F\n", |
87 | | - "from torch.cuda.amp import GradScaler, autocast\n", |
88 | | - "\n", |
89 | 67 | "from monai import transforms\n", |
90 | 68 | "from monai.apps import DecathlonDataset\n", |
91 | 69 | "from monai.config import print_config\n", |
92 | 70 | "from monai.data import DataLoader\n", |
93 | 71 | "from monai.utils import first, set_determinism\n", |
| 72 | + "from torch.cuda.amp import GradScaler, autocast\n", |
| 73 | + "from torch.nn import L1Loss\n", |
| 74 | + "from tqdm import tqdm\n", |
94 | 75 | "\n", |
95 | | - "\n", |
96 | | - "from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n", |
97 | 76 | "from generative.inferers import LatentDiffusionInferer\n", |
98 | | - "from generative.schedulers import DDPMScheduler\n", |
99 | 77 | "from generative.losses.adversarial_loss import PatchAdversarialLoss\n", |
100 | 78 | "from generative.losses.perceptual import PerceptualLoss\n", |
101 | | - "from torch.nn import L1Loss\n", |
| 79 | + "from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n", |
| 80 | + "from generative.networks.schedulers import DDPMScheduler\n", |
102 | 81 | "\n", |
103 | 82 | "print_config()" |
104 | 83 | ] |
|
183 | 162 | ], |
184 | 163 | "source": [ |
185 | 164 | "batch_size = 2\n", |
186 | | - "channel = 0 # 0 = Flair\n", |
187 | | - "assert channel in [0,1,2,3], 'Choose a valid channel'\n", |
| 165 | + "channel = 0 # 0 = Flair\n", |
| 166 | + "assert channel in [0, 1, 2, 3], \"Choose a valid channel\"\n", |
188 | 167 | "\n", |
189 | 168 | "train_transforms = transforms.Compose(\n", |
190 | 169 | " [\n", |
191 | 170 | " transforms.LoadImaged(keys=[\"image\"]),\n", |
192 | 171 | " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", |
193 | | - " transforms.Lambdad(keys=\"image\", func=lambda x: x[channel,:, :, :]),\n", |
| 172 | + " transforms.Lambdad(keys=\"image\", func=lambda x: x[channel, :, :, :]),\n", |
194 | 173 | " transforms.AddChanneld(keys=[\"image\"]),\n", |
195 | 174 | " transforms.EnsureTyped(keys=[\"image\"]),\n", |
196 | 175 | " transforms.Orientationd(keys=[\"image\"], axcodes=\"RAS\"),\n", |
197 | | - " transforms.Spacingd(keys=[\"image\"], pixdim=(2.4, 2.4, 2.2), mode=(\"bilinear\"),),\n", |
198 | | - " transforms.CenterSpatialCropd(keys=[\"image\"],roi_size = (96, 96, 64)),\n", |
199 | | - " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower= 0, upper= 99.5, b_min= 0, b_max= 1),\n", |
| 176 | + " transforms.Spacingd(\n", |
| 177 | + " keys=[\"image\"],\n", |
| 178 | + " pixdim=(2.4, 2.4, 2.2),\n", |
| 179 | + " mode=(\"bilinear\"),\n", |
| 180 | + " ),\n", |
| 181 | + " transforms.CenterSpatialCropd(keys=[\"image\"], roi_size=(96, 96, 64)),\n", |
| 182 | + " transforms.ScaleIntensityRangePercentilesd(keys=\"image\", lower=0, upper=99.5, b_min=0, b_max=1),\n", |
200 | 183 | " ]\n", |
201 | 184 | ")\n", |
202 | | - "train_ds = DecathlonDataset(root_dir=root_dir, \n", |
203 | | - " task='Task01_BrainTumour', \n", |
204 | | - " section=\"training\", #validation\n", |
205 | | - " cache_rate=1.0, # you may need a few Gb of RAM... Set to 0 otherwise\n", |
206 | | - " num_workers=4,\n", |
207 | | - " download=False, # Set download to True if the dataset hasnt been downloaded yet\n", |
208 | | - " seed=0, transform = train_transforms) \n", |
| 185 | + "train_ds = DecathlonDataset(\n", |
| 186 | + " root_dir=root_dir,\n", |
| 187 | + " task=\"Task01_BrainTumour\",\n", |
| 188 | + " section=\"training\", # validation\n", |
 | 189 | + " cache_rate=1.0, # caching everything needs a few GB of RAM; set to 0 otherwise\n", |
| 190 | + " num_workers=4,\n", |
 | 191 | + " download=False, # Set download to True if the dataset hasn't been downloaded yet\n", |
| 192 | + " seed=0,\n", |
| 193 | + " transform=train_transforms,\n", |
| 194 | + ")\n", |
209 | 195 | "train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)\n", |
210 | 196 | "print(f'Image shape {train_ds[0][\"image\"].shape}')" |
211 | 197 | ] |
|
252 | 238 | "check_data = first(train_loader)\n", |
253 | 239 | "idx = 0\n", |
254 | 240 | "\n", |
255 | | - "img = check_data[\"image\"][idx,0]\n", |
| 241 | + "img = check_data[\"image\"][idx, 0]\n", |
256 | 242 | "fig, axs = plt.subplots(nrows=1, ncols=3)\n", |
257 | 243 | "for ax in axs:\n", |
258 | 244 | " ax.axis(\"off\")\n", |
259 | 245 | "ax = axs[0]\n", |
260 | | - "ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n", |
| 246 | + "ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n", |
261 | 247 | "ax = axs[1]\n", |
262 | | - "ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n", |
| 248 | + "ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n", |
263 | 249 | "ax = axs[2]\n", |
264 | | - "ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")" |
| 250 | + "ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")\n", |
| 251 | + "# plt.savefig(\"training_examples.png\")" |
265 | 252 | ] |
266 | 253 | }, |
267 | 254 | { |
|
316 | 303 | " in_channels=3,\n", |
317 | 304 | " out_channels=3,\n", |
318 | 305 | " num_res_blocks=1,\n", |
319 | | - " num_channels=[32,64,64],\n", |
| 306 | + " num_channels=[32, 64, 64],\n", |
320 | 307 | " attention_levels=(False, True, True),\n", |
321 | 308 | " num_head_channels=1,\n", |
322 | 309 | ")\n", |
|
344 | 331 | " beta_end=0.0195,\n", |
345 | 332 | ")\n", |
346 | 333 | "\n", |
347 | | - "inferer = LatentDiffusionInferer(scheduler);" |
| 334 | + "inferer = LatentDiffusionInferer(scheduler)" |
348 | 335 | ] |
349 | 336 | }, |
350 | 337 | { |
|
364 | 351 | "source": [ |
365 | 352 | "l1_loss = L1Loss()\n", |
366 | 353 | "adv_loss = PatchAdversarialLoss(criterion=\"least_squares\")\n", |
367 | | - "loss_perceptual = PerceptualLoss(spatial_dims=3, network_type='squeeze', is_fake_3d=True, fake_3d_ratio=0.2)\n", |
| 354 | + "loss_perceptual = PerceptualLoss(spatial_dims=3, network_type=\"squeeze\", is_fake_3d=True, fake_3d_ratio=0.2)\n", |
368 | 355 | "loss_perceptual.to(device)\n", |
| 356 | + "\n", |
| 357 | + "\n", |
369 | 358 | "def KL_loss(z_mu, z_sigma):\n", |
370 | | - " kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim = [1, 2, 3, 4])\n", |
| 359 | + " kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])\n", |
371 | 360 | " return torch.sum(kl_loss) / kl_loss.shape[0]\n", |
372 | 361 | "\n", |
| 362 | + "\n", |
373 | 363 | "adv_weight = 0.01\n", |
374 | 364 | "perceptual_weight = 0.001\n", |
375 | 365 | "kl_weight = 1e-6" |
|
527 | 517 | " progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), ncols=110)\n", |
528 | 518 | " progress_bar.set_description(f\"Epoch {epoch}\")\n", |
529 | 519 | " for step, batch in progress_bar:\n", |
530 | | - " images = batch[\"image\"].to(device) # choose only one of Brats channels\n", |
531 | | - " \n", |
| 520 | + " images = batch[\"image\"].to(device) # choose only one of Brats channels\n", |
| 521 | + "\n", |
532 | 522 | " # Generator part\n", |
533 | 523 | " optimizer_g.zero_grad(set_to_none=True)\n", |
534 | | - " reconstruction, z_mu, z_sigma = autoencoder(images)\n", |
| 524 | + " reconstruction, z_mu, z_sigma = autoencoder(images)\n", |
535 | 525 | " kl_loss = KL_loss(z_mu, z_sigma)\n", |
536 | 526 | "\n", |
537 | 527 | " recons_loss = l1_loss(reconstruction.float(), images.float())\n", |
538 | 528 | " p_loss = loss_perceptual(reconstruction.float(), images.float())\n", |
539 | 529 | " loss_g = recons_loss + kl_weight * kl_loss + perceptual_weight * p_loss\n", |
540 | | - " \n", |
| 530 | + "\n", |
541 | 531 | " if epoch > autoencoder_warm_up_n_epochs:\n", |
542 | 532 | " logits_fake = discriminator(reconstruction.contiguous().float())[-1]\n", |
543 | 533 | " generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)\n", |
544 | 534 | " loss_g += adv_weight * generator_loss\n", |
545 | 535 | "\n", |
546 | 536 | " loss_g.backward()\n", |
547 | 537 | " optimizer_g.step()\n", |
548 | | - " \n", |
| 538 | + "\n", |
549 | 539 | " if epoch > autoencoder_warm_up_n_epochs:\n", |
550 | 540 | " # Discriminator part\n", |
551 | 541 | " optimizer_d.zero_grad(set_to_none=True)\n", |
|
604 | 594 | "source": [ |
605 | 595 | "plt.style.use(\"ggplot\")\n", |
606 | 596 | "plt.title(\"Learning Curves\", fontsize=20)\n", |
607 | | - "plt.plot(epoch_recon_loss_list) \n", |
| 597 | + "plt.plot(epoch_recon_loss_list)\n", |
608 | 598 | "plt.yticks(fontsize=12)\n", |
609 | 599 | "plt.xticks(fontsize=12)\n", |
610 | 600 | "plt.xlabel(\"Epochs\", fontsize=16)\n", |
|
685 | 675 | "for ax in axs:\n", |
686 | 676 | " ax.axis(\"off\")\n", |
687 | 677 | "ax = axs[0]\n", |
688 | | - "ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n", |
| 678 | + "ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n", |
689 | 679 | "ax = axs[1]\n", |
690 | | - "ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n", |
| 680 | + "ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n", |
691 | 681 | "ax = axs[2]\n", |
692 | | - "ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")" |
| 682 | + "ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")" |
693 | 683 | ] |
694 | 684 | }, |
695 | 685 | { |
|
733 | 723 | " for step, batch in progress_bar:\n", |
734 | 724 | " images = batch[\"image\"].to(device)\n", |
735 | 725 | " optimizer_diff.zero_grad(set_to_none=True)\n", |
736 | | - " \n", |
| 726 | + "\n", |
737 | 727 | " with autocast(enabled=True):\n", |
738 | 728 | " # Generate random noise\n", |
739 | 729 | " noise = torch.randn_like(z).to(device)\n", |
740 | 730 | " # Get model prediction\n", |
741 | | - " noise_pred = inferer(inputs=images, autoencoder_model = autoencoder, diffusion_model=unet, noise=noise)\n", |
| 731 | + " noise_pred = inferer(inputs=images, autoencoder_model=autoencoder, diffusion_model=unet, noise=noise)\n", |
742 | 732 | "\n", |
743 | 733 | " loss = F.mse_loss(noise_pred.float(), noise.float())\n", |
744 | 734 | "\n", |
|
781 | 771 | } |
782 | 772 | ], |
783 | 773 | "source": [ |
784 | | - "plt.plot(epoch_loss_list);\n", |
| 774 | + "plt.plot(epoch_loss_list)\n", |
785 | 775 | "plt.title(\"Learning Curves\", fontsize=20)\n", |
786 | | - "plt.plot(epoch_loss_list) \n", |
| 776 | + "plt.plot(epoch_loss_list)\n", |
787 | 777 | "plt.yticks(fontsize=12)\n", |
788 | 778 | "plt.xticks(fontsize=12)\n", |
789 | 779 | "plt.xlabel(\"Epochs\", fontsize=16)\n", |
|
863 | 853 | ], |
864 | 854 | "source": [ |
865 | 855 | "idx = 0\n", |
866 | | - "img = synthetic_images[idx, channel].detach().cpu().numpy() # images\n", |
| 856 | + "img = synthetic_images[idx, channel].detach().cpu().numpy() # images\n", |
867 | 857 | "fig, axs = plt.subplots(nrows=1, ncols=3)\n", |
868 | 858 | "for ax in axs:\n", |
869 | 859 | " ax.axis(\"off\")\n", |
870 | 860 | "ax = axs[0]\n", |
871 | | - "ax.imshow(img[...,img.shape[2]//2], cmap=\"gray\")\n", |
| 861 | + "ax.imshow(img[..., img.shape[2] // 2], cmap=\"gray\")\n", |
872 | 862 | "ax = axs[1]\n", |
873 | | - "ax.imshow(img[:,img.shape[1]//2, ...], cmap=\"gray\")\n", |
| 863 | + "ax.imshow(img[:, img.shape[1] // 2, ...], cmap=\"gray\")\n", |
874 | 864 | "ax = axs[2]\n", |
875 | | - "ax.imshow(img[img.shape[0]//2, ...], cmap=\"gray\")" |
| 865 | + "ax.imshow(img[img.shape[0] // 2, ...], cmap=\"gray\")" |
876 | 866 | ] |
877 | 867 | }, |
878 | 868 | { |
|
902 | 892 | "main_language": "python" |
903 | 893 | }, |
904 | 894 | "kernelspec": { |
905 | | - "display_name": "Python 3.8.2 ('torch_gpu')", |
| 895 | + "display_name": "Python 3", |
906 | 896 | "language": "python", |
907 | 897 | "name": "python3" |
908 | 898 | }, |
|
916 | 906 | "name": "python", |
917 | 907 | "nbconvert_exporter": "python", |
918 | 908 | "pygments_lexer": "ipython3", |
919 | | - "version": "3.8.2" |
| 909 | + "version": "3.8.12" |
920 | 910 | }, |
921 | 911 | "vscode": { |
922 | 912 | "interpreter": { |
|
0 commit comments