@@ -82,7 +82,7 @@ def __init__(
8282 else :
8383 output_dim = num_classes
8484
85- self .model = model_class (
85+ self .base_model = model_class (
8686 config = config ,
8787 num_feature_info = num_feature_info ,
8888 cat_feature_info = cat_feature_info ,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
107107 Model output.
108108 """
109109
110- return self .model .forward (num_features , cat_features )
110+ return self .base_model .forward (num_features , cat_features )
111111
112112 def compute_loss (self , predictions , y_true ):
113113 """
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
168168 prog_bar = True ,
169169 logger = True ,
170170 )
171- elif isinstance (self .loss_fct , nn .MSELoss ):
172- rmse = torch .sqrt (loss )
173- self .log (
174- "train_rmse" ,
175- rmse ,
176- on_step = True ,
177- on_epoch = True ,
178- prog_bar = True ,
179- logger = True ,
180- )
181171
182172 return loss
183173
@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
205195 self .log (
206196 "val_loss" ,
207197 val_loss ,
208- on_step = True ,
198+ on_step = False ,
209199 on_epoch = True ,
210200 prog_bar = True ,
211201 logger = True ,
@@ -218,17 +208,7 @@ def validation_step(self, batch, batch_idx):
218208 self .log (
219209 "val_acc" ,
220210 acc ,
221- on_step = True ,
222- on_epoch = True ,
223- prog_bar = True ,
224- logger = True ,
225- )
226- elif isinstance (self .loss_fct , nn .MSELoss ):
227- rmse = torch .sqrt (val_loss )
228- self .log (
229- "val_rmse" ,
230- rmse ,
231- on_step = True ,
211+ on_step = False ,
232212 on_epoch = True ,
233213 prog_bar = True ,
234214 logger = True ,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
272252 self .log (
273253 "test_acc" ,
274254 acc ,
275- on_step = True ,
276- on_epoch = True ,
277- prog_bar = True ,
278- logger = True ,
279- )
280- elif isinstance (self .loss_fct , nn .MSELoss ):
281- rmse = torch .sqrt (test_loss )
282- self .log (
283- "test_rmse" ,
284- rmse ,
285- on_step = True ,
255+ on_step = False ,
286256 on_epoch = True ,
287257 prog_bar = True ,
288258 logger = True ,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
300270 A dictionary containing the optimizer and lr_scheduler configurations.
301271 """
302272 optimizer = torch .optim .Adam (
303- self .model .parameters (),
273+ self .base_model .parameters (),
304274 lr = self .lr ,
305275 weight_decay = self .weight_decay ,
306276 )
0 commit comments