Skip to content

Commit 53b77c5

Browse files
committed
adjusting class attribute in lightning wrapper
1 parent 71cc68e commit 53b77c5

1 file changed

Lines changed: 6 additions & 36 deletions

File tree

mambular/base_models/lightning_wrapper.py

Lines changed: 6 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def __init__(
8282
else:
8383
output_dim = num_classes
8484

85-
self.model = model_class(
85+
self.base_model = model_class(
8686
config=config,
8787
num_feature_info=num_feature_info,
8888
cat_feature_info=cat_feature_info,
@@ -107,7 +107,7 @@ def forward(self, num_features, cat_features):
107107
Model output.
108108
"""
109109

110-
return self.model.forward(num_features, cat_features)
110+
return self.base_model.forward(num_features, cat_features)
111111

112112
def compute_loss(self, predictions, y_true):
113113
"""
@@ -168,16 +168,6 @@ def training_step(self, batch, batch_idx):
168168
prog_bar=True,
169169
logger=True,
170170
)
171-
elif isinstance(self.loss_fct, nn.MSELoss):
172-
rmse = torch.sqrt(loss)
173-
self.log(
174-
"train_rmse",
175-
rmse,
176-
on_step=True,
177-
on_epoch=True,
178-
prog_bar=True,
179-
logger=True,
180-
)
181171

182172
return loss
183173

@@ -205,7 +195,7 @@ def validation_step(self, batch, batch_idx):
205195
self.log(
206196
"val_loss",
207197
val_loss,
208-
on_step=True,
198+
on_step=False,
209199
on_epoch=True,
210200
prog_bar=True,
211201
logger=True,
@@ -218,17 +208,7 @@ def validation_step(self, batch, batch_idx):
218208
self.log(
219209
"val_acc",
220210
acc,
221-
on_step=True,
222-
on_epoch=True,
223-
prog_bar=True,
224-
logger=True,
225-
)
226-
elif isinstance(self.loss_fct, nn.MSELoss):
227-
rmse = torch.sqrt(val_loss)
228-
self.log(
229-
"val_rmse",
230-
rmse,
231-
on_step=True,
211+
on_step=False,
232212
on_epoch=True,
233213
prog_bar=True,
234214
logger=True,
@@ -272,17 +252,7 @@ def test_step(self, batch, batch_idx):
272252
self.log(
273253
"test_acc",
274254
acc,
275-
on_step=True,
276-
on_epoch=True,
277-
prog_bar=True,
278-
logger=True,
279-
)
280-
elif isinstance(self.loss_fct, nn.MSELoss):
281-
rmse = torch.sqrt(test_loss)
282-
self.log(
283-
"test_rmse",
284-
rmse,
285-
on_step=True,
255+
on_step=False,
286256
on_epoch=True,
287257
prog_bar=True,
288258
logger=True,
@@ -300,7 +270,7 @@ def configure_optimizers(self):
300270
A dictionary containing the optimizer and lr_scheduler configurations.
301271
"""
302272
optimizer = torch.optim.Adam(
303-
self.model.parameters(),
273+
self.base_model.parameters(),
304274
lr=self.lr,
305275
weight_decay=self.weight_decay,
306276
)

0 commit comments

Comments (0)