Skip to content

Commit b16af74

Browse files
authored
Merge pull request #84 from basf/trainer_fix
Trainer fix
2 parents 6da2cf8 + be1879d commit b16af74

4 files changed

Lines changed: 26 additions & 11 deletions

File tree

mambular/base_models/lightning_wrapper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
lss=False,
3838
family=None,
3939
loss_fct: callable = None,
40-
**kwargs
40+
**kwargs,
4141
):
4242
super().__init__()
4343
self.num_classes = num_classes
@@ -300,7 +300,7 @@ def configure_optimizers(self):
300300
A dictionary containing the optimizer and lr_scheduler configurations.
301301
"""
302302
optimizer = torch.optim.Adam(
303-
self.parameters(),
303+
self.model.parameters(),
304304
lr=self.lr,
305305
weight_decay=self.weight_decay,
306306
)

mambular/models/sklearn_base_classifier.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ..data_utils.datamodule import MambularDataModule
1010
from ..preprocessing import Preprocessor
1111
import numpy as np
12+
from lightning.pytorch.callbacks import ModelSummary
1213

1314

1415
class SklearnBaseClassifier(BaseEstimator):
@@ -367,12 +368,16 @@ def fit(
367368
)
368369

369370
# Initialize the trainer and train the model
370-
trainer = pl.Trainer(
371+
self.trainer = pl.Trainer(
371372
max_epochs=max_epochs,
372-
callbacks=[early_stop_callback, checkpoint_callback],
373+
callbacks=[
374+
early_stop_callback,
375+
checkpoint_callback,
376+
ModelSummary(max_depth=2),
377+
],
373378
**trainer_kwargs
374379
)
375-
trainer.fit(self.model, self.data_module)
380+
self.trainer.fit(self.model, self.data_module)
376381

377382
best_model_path = checkpoint_callback.best_model_path
378383
if best_model_path:

mambular/models/sklearn_base_lss.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
PoissonDistribution,
3232
StudentTDistribution,
3333
)
34+
from lightning.pytorch.callbacks import ModelSummary
3435

3536

3637
class SklearnBaseLSS(BaseEstimator):
@@ -409,12 +410,16 @@ def fit(
409410
)
410411

411412
# Initialize the trainer and train the model
412-
trainer = pl.Trainer(
413+
self.trainer = pl.Trainer(
413414
max_epochs=max_epochs,
414-
callbacks=[early_stop_callback, checkpoint_callback],
415+
callbacks=[
416+
early_stop_callback,
417+
checkpoint_callback,
418+
ModelSummary(max_depth=2),
419+
],
415420
**trainer_kwargs
416421
)
417-
trainer.fit(self.model, self.data_module)
422+
self.trainer.fit(self.model, self.data_module)
418423

419424
best_model_path = checkpoint_callback.best_model_path
420425
if best_model_path:

mambular/models/sklearn_base_regressor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ..base_models.lightning_wrapper import TaskModel
99
from ..data_utils.datamodule import MambularDataModule
1010
from ..preprocessing import Preprocessor
11+
from lightning.pytorch.callbacks import ModelSummary
1112

1213

1314
class SklearnBaseRegressor(BaseEstimator):
@@ -356,12 +357,16 @@ def fit(
356357
)
357358

358359
# Initialize the trainer and train the model
359-
trainer = pl.Trainer(
360+
self.trainer = pl.Trainer(
360361
max_epochs=max_epochs,
361-
callbacks=[early_stop_callback, checkpoint_callback],
362+
callbacks=[
363+
early_stop_callback,
364+
checkpoint_callback,
365+
ModelSummary(max_depth=2),
366+
],
362367
**trainer_kwargs
363368
)
364-
trainer.fit(self.model, self.data_module)
369+
self.trainer.fit(self.model, self.data_module)
365370

366371
best_model_path = checkpoint_callback.best_model_path
367372
if best_model_path:

0 commit comments

Comments (0)