from kornia import image_to_tensor, tensor_to_image
from kornia.augmentation import ColorJitter, RandomChannelShuffle, RandomHorizontalFlip, RandomThinPlateSpline
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
1920
20- AVAIL_GPUS = min (1 , torch .cuda .device_count ())
21-
2221# %% [markdown]
2322# ## Define Data Augmentations module
2423#
@@ -106,10 +105,11 @@ def __init__(self):
106105
107106 self .transform = DataAugmentation () # per batch augmentation_kornia
108107
109- self .accuracy = torchmetrics .Accuracy ()
108+ self .train_accuracy = torchmetrics .Accuracy ()
109+ self .val_accuracy = torchmetrics .Accuracy ()
110110
111111 def forward (self , x ):
112- return F . softmax ( self .model (x ) )
112+ return self .model (x )
113113
114114 def compute_loss (self , y_hat , y ):
115115 return F .cross_entropy (y_hat , y )
@@ -127,21 +127,28 @@ def _to_vis(data):
127127 plt .figure (figsize = win_size )
128128 plt .imshow (_to_vis (imgs_aug ))
129129
130+ def on_after_batch_transfer (self , batch , dataloader_idx ):
131+ x , y = batch
132+ if self .trainer .training :
133+ x = self .transform (x ) # => we perform GPU/Batched data augmentation
134+ return x , y
135+
130136 def training_step (self , batch , batch_idx ):
131137 x , y = batch
132- x_aug = self .transform (x ) # => we perform GPU/Batched data augmentation
133- y_hat = self (x_aug )
138+ y_hat = self (x )
134139 loss = self .compute_loss (y_hat , y )
140+ self .train_accuracy .update (y_hat , y )
135141 self .log ("train_loss" , loss , prog_bar = False )
136- self .log ("train_acc" , self .accuracy ( y_hat , y ) , prog_bar = False )
142+ self .log ("train_acc" , self .train_accuracy , prog_bar = False )
137143 return loss
138144
139145 def validation_step (self , batch , batch_idx ):
140146 x , y = batch
141147 y_hat = self (x )
142148 loss = self .compute_loss (y_hat , y )
149+ self .val_accuracy .update (y_hat , y )
143150 self .log ("valid_loss" , loss , prog_bar = False )
144- self .log ("valid_acc" , self .accuracy ( y_hat , y ) , prog_bar = True )
151+ self .log ("valid_acc" , self .val_accuracy , prog_bar = True )
145152
146153 def configure_optimizers (self ):
147154 optimizer = torch .optim .AdamW (self .model .parameters (), lr = 1e-4 )
@@ -158,7 +165,7 @@ def train_dataloader(self):
158165 return loader
159166
160167 def val_dataloader (self ):
161- dataset = CIFAR10 (os .getcwd (), train = True , download = True , transform = self .preprocess )
168+ dataset = CIFAR10 (os .getcwd (), train = False , download = True , transform = self .preprocess )
162169 loader = DataLoader (dataset , batch_size = 32 )
163170 return loader
164171
@@ -179,8 +186,9 @@ def val_dataloader(self):
# %%
# Initialize a trainer
trainer = Trainer(
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    accelerator="auto",
    # Limit to a single device for interactive (IPython/notebook) runs.
    devices=1 if torch.cuda.is_available() else None,
    max_epochs=10,
    logger=CSVLogger(save_dir="logs/", name="cifar10-resnet18"),
)
0 commit comments