|
242 | 242 | "outputs": [], |
243 | 243 | "source": [ |
244 | 244 | "max_epochs = 5\n", |
245 | | - "model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(\"cuda:0\")\n", |
| 245 | + "device = torch.device(\"cuda:0\" if torch.cuda.device_count() > 0 else \"cpu\")\n", |
| 246 | + "model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)\n", |
246 | 247 | "\n", |
247 | 248 | "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", |
248 | 249 | "trainer = SupervisedTrainer(\n", |
249 | | - " device=torch.device(\"cuda:0\"),\n", |
| 250 | + " device=device,\n", |
250 | 251 | " max_epochs=max_epochs,\n", |
251 | 252 | " train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),\n", |
252 | 253 | " network=model,\n", |
|
312 | 313 | "max_items_to_print = 10\n", |
313 | 314 | "with eval_mode(model):\n", |
314 | 315 | " for item in DataLoader(testdata, batch_size=1, num_workers=0):\n", |
315 | | - " prob = np.array(model(item[\"image\"].to(\"cuda:0\")).detach().to(\"cpu\"))[0]\n", |
| 316 | + " prob = np.array(model(item[\"image\"].to(device)).detach().to(\"cpu\"))[0]\n", |
316 | 317 | " pred = class_names[prob.argmax()]\n", |
317 | 318 | " gt = item[\"class_name\"][0]\n", |
318 | 319 | " print(f\"Class prediction is {pred}. Ground-truth: {gt}\")\n", |
|
0 commit comments