File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -37,7 +37,7 @@ def __init__(
3737 lss = False ,
3838 family = None ,
3939 loss_fct : callable = None ,
40- ** kwargs
40+ ** kwargs ,
4141 ):
4242 super ().__init__ ()
4343 self .num_classes = num_classes
@@ -300,7 +300,7 @@ def configure_optimizers(self):
300300 A dictionary containing the optimizer and lr_scheduler configurations.
301301 """
302302 optimizer = torch .optim .Adam (
303- self .parameters (),
303+ self .model . parameters (),
304304 lr = self .lr ,
305305 weight_decay = self .weight_decay ,
306306 )
Original file line number Diff line number Diff line change 99from ..data_utils .datamodule import MambularDataModule
1010from ..preprocessing import Preprocessor
1111import numpy as np
12+ from lightning .pytorch .callbacks import ModelSummary
1213
1314
1415class SklearnBaseClassifier (BaseEstimator ):
@@ -367,12 +368,16 @@ def fit(
367368 )
368369
369370 # Initialize the trainer and train the model
370- trainer = pl .Trainer (
371+ self . trainer = pl .Trainer (
371372 max_epochs = max_epochs ,
372- callbacks = [early_stop_callback , checkpoint_callback ],
373+ callbacks = [
374+ early_stop_callback ,
375+ checkpoint_callback ,
376+ ModelSummary (max_depth = 2 ),
377+ ],
373378 ** trainer_kwargs
374379 )
375- trainer .fit (self .model , self .data_module )
380+ self . trainer .fit (self .model , self .data_module )
376381
377382 best_model_path = checkpoint_callback .best_model_path
378383 if best_model_path :
Original file line number Diff line number Diff line change 3131 PoissonDistribution ,
3232 StudentTDistribution ,
3333)
34+ from lightning .pytorch .callbacks import ModelSummary
3435
3536
3637class SklearnBaseLSS (BaseEstimator ):
@@ -409,12 +410,16 @@ def fit(
409410 )
410411
411412 # Initialize the trainer and train the model
412- trainer = pl .Trainer (
413+ self . trainer = pl .Trainer (
413414 max_epochs = max_epochs ,
414- callbacks = [early_stop_callback , checkpoint_callback ],
415+ callbacks = [
416+ early_stop_callback ,
417+ checkpoint_callback ,
418+ ModelSummary (max_depth = 2 ),
419+ ],
415420 ** trainer_kwargs
416421 )
417- trainer .fit (self .model , self .data_module )
422+ self . trainer .fit (self .model , self .data_module )
418423
419424 best_model_path = checkpoint_callback .best_model_path
420425 if best_model_path :
Original file line number Diff line number Diff line change 88from ..base_models .lightning_wrapper import TaskModel
99from ..data_utils .datamodule import MambularDataModule
1010from ..preprocessing import Preprocessor
11+ from lightning .pytorch .callbacks import ModelSummary
1112
1213
1314class SklearnBaseRegressor (BaseEstimator ):
@@ -356,12 +357,16 @@ def fit(
356357 )
357358
358359 # Initialize the trainer and train the model
359- trainer = pl .Trainer (
360+ self . trainer = pl .Trainer (
360361 max_epochs = max_epochs ,
361- callbacks = [early_stop_callback , checkpoint_callback ],
362+ callbacks = [
363+ early_stop_callback ,
364+ checkpoint_callback ,
365+ ModelSummary (max_depth = 2 ),
366+ ],
362367 ** trainer_kwargs
363368 )
364- trainer .fit (self .model , self .data_module )
369+ self . trainer .fit (self .model , self .data_module )
365370
366371 best_model_path = checkpoint_callback .best_model_path
367372 if best_model_path :
You can’t perform that action at this time.
0 commit comments