
Commit 3f482fa

Merge pull request #78 from basf/models

add build_model and get_num_params methods

2 parents e16ba03 + e7ca3a2

3 files changed: 412 additions & 20 deletions


mambular/models/sklearn_base_classifier.py

Lines changed: 155 additions & 20 deletions
@@ -108,30 +108,24 @@ def set_params(self, **parameters):
 
         return self
 
-    def fit(
+    def build_model(
         self,
         X,
         y,
         val_size: float = 0.2,
         X_val=None,
         y_val=None,
-        max_epochs: int = 100,
         random_state: int = 101,
         batch_size: int = 128,
         shuffle: bool = True,
-        patience: int = 15,
-        monitor: str = "val_loss",
-        mode: str = "min",
         lr: float = 1e-4,
         lr_patience: int = 10,
         factor: float = 0.1,
         weight_decay: float = 1e-06,
-        checkpoint_path="model_checkpoints",
         dataloader_kwargs={},
-        **trainer_kwargs
     ):
         """
-        Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
+        Builds the model using the provided training data.
 
         Parameters
         ----------
@@ -145,20 +139,12 @@ def fit(
             The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
         y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
             The validation target values. Required if `X_val` is provided.
-        max_epochs : int, default=100
-            Maximum number of epochs for training.
         random_state : int, default=101
             Controls the shuffling applied to the data before applying the split.
         batch_size : int, default=64
             Number of samples per gradient update.
         shuffle : bool, default=True
             Whether to shuffle the training data before each epoch.
-        patience : int, default=10
-            Number of epochs with no improvement on the validation loss to wait before early stopping.
-        monitor : str, default="val_loss"
-            The metric to monitor for early stopping.
-        mode : str, default="min"
-            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
         lr : float, default=1e-3
             Learning rate for the optimizer.
         lr_patience : int, default=10
@@ -167,17 +153,15 @@ def fit(
             Factor by which the learning rate will be reduced.
         weight_decay : float, default=0.025
             Weight decay (L2 penalty) coefficient.
-        checkpoint_path : str, default="model_checkpoints"
-            Path where the checkpoints are being saved.
         dataloader_kwargs: dict, default={}
             The kwargs for the pytorch dataloader class.
-        **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
 
 
         Returns
         -------
         self : object
-            The fitted regressor.
+            The built classifier.
         """
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
@@ -219,6 +203,157 @@ def fit(
             weight_decay=weight_decay,
         )
 
+        self.built = True
+
+        return self
+
+    def get_number_of_params(self, requires_grad=True):
+        """
+        Calculate the number of parameters in the model.
+
+        Parameters
+        ----------
+        requires_grad : bool, optional
+            If True, only count the parameters that require gradients (trainable parameters).
+            If False, count all parameters. Default is True.
+
+        Returns
+        -------
+        int
+            The total number of parameters in the model.
+
+        Raises
+        ------
+        ValueError
+            If the model has not been built prior to calling this method.
+        """
+        if not self.built:
+            raise ValueError(
+                "The model must be built before the number of parameters can be estimated"
+            )
+        else:
+            if requires_grad:
+                return sum(
+                    p.numel() for p in self.model.parameters() if p.requires_grad
+                )
+            else:
+                return sum(p.numel() for p in self.model.parameters())
+
+    def fit(
+        self,
+        X,
+        y,
+        val_size: float = 0.2,
+        X_val=None,
+        y_val=None,
+        max_epochs: int = 100,
+        random_state: int = 101,
+        batch_size: int = 128,
+        shuffle: bool = True,
+        patience: int = 15,
+        monitor: str = "val_loss",
+        mode: str = "min",
+        lr: float = 1e-4,
+        lr_patience: int = 10,
+        factor: float = 0.1,
+        weight_decay: float = 1e-06,
+        checkpoint_path="model_checkpoints",
+        dataloader_kwargs={},
+        rebuild=True,
+        **trainer_kwargs
+    ):
+        """
+        Trains the classification model using the provided training data. Optionally, a separate validation set can be used.
+
+        Parameters
+        ----------
+        X : DataFrame or array-like, shape (n_samples, n_features)
+            The training input samples.
+        y : array-like, shape (n_samples,) or (n_samples, n_targets)
+            The target values (real numbers).
+        val_size : float, default=0.2
+            The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+            The validation target values. Required if `X_val` is provided.
+        max_epochs : int, default=100
+            Maximum number of epochs for training.
+        random_state : int, default=101
+            Controls the shuffling applied to the data before applying the split.
+        batch_size : int, default=64
+            Number of samples per gradient update.
+        shuffle : bool, default=True
+            Whether to shuffle the training data before each epoch.
+        patience : int, default=10
+            Number of epochs with no improvement on the validation loss to wait before early stopping.
+        monitor : str, default="val_loss"
+            The metric to monitor for early stopping.
+        mode : str, default="min"
+            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
+        lr : float, default=1e-3
+            Learning rate for the optimizer.
+        lr_patience : int, default=10
+            Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+        factor : float, default=0.1
+            Factor by which the learning rate will be reduced.
+        weight_decay : float, default=0.025
+            Weight decay (L2 penalty) coefficient.
+        checkpoint_path : str, default="model_checkpoints"
+            Path where the checkpoints are being saved.
+        dataloader_kwargs: dict, default={}
+            The kwargs for the pytorch dataloader class.
+        rebuild: bool, default=True
+            Whether to rebuild the model when it already was built.
+        **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
+
+        Returns
+        -------
+        self : object
+            The fitted classifier.
+        """
+        if not self.built and not rebuild:
+            if not isinstance(X, pd.DataFrame):
+                X = pd.DataFrame(X)
+            if isinstance(y, pd.Series):
+                y = y.values
+            if X_val:
+                if not isinstance(X_val, pd.DataFrame):
+                    X_val = pd.DataFrame(X_val)
+                if isinstance(y_val, pd.Series):
+                    y_val = y_val.values
+
+            self.data_module = MambularDataModule(
+                preprocessor=self.preprocessor,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                X_val=X_val,
+                y_val=y_val,
+                val_size=val_size,
+                random_state=random_state,
+                regression=False,
+                **dataloader_kwargs
+            )
+
+            self.data_module.preprocess_data(
+                X, y, X_val, y_val, val_size=val_size, random_state=random_state
+            )
+
+            num_classes = len(np.unique(y))
+
+            self.model = TaskModel(
+                model_class=self.base_model,
+                num_classes=num_classes,
+                config=self.config,
+                cat_feature_info=self.data_module.cat_feature_info,
+                num_feature_info=self.data_module.num_feature_info,
+                lr=lr,
+                lr_patience=lr_patience,
+                lr_factor=factor,
+                weight_decay=weight_decay,
+            )
+
         early_stop_callback = EarlyStopping(
             monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
         )
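For orientation, here is a minimal usage sketch of the two new methods as called through the sklearn-style API. The class name MambularClassifier and the toy data are illustrative assumptions; only build_model, get_number_of_params, and the rebuild flag come from this diff:

import numpy as np
from mambular.models import MambularClassifier  # assumed public wrapper built on SklearnBaseClassifier

X = np.random.rand(200, 5)           # toy feature matrix
y = np.random.randint(0, 2, 200)     # toy binary labels

clf = MambularClassifier()
clf.build_model(X, y, batch_size=64)                  # instantiate TaskModel without training
print(clf.get_number_of_params())                     # trainable parameters (default)
print(clf.get_number_of_params(requires_grad=False))  # all parameters
clf.fit(X, y, max_epochs=5, rebuild=False)            # per the docstring, reuse the already-built model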

mambular/models/sklearn_base_lss.py

Lines changed: 129 additions & 0 deletions
@@ -130,6 +130,135 @@ def set_params(self, **parameters):
 
         return self
 
+    def build_model(
+        self,
+        X,
+        y,
+        val_size: float = 0.2,
+        X_val=None,
+        y_val=None,
+        random_state: int = 101,
+        batch_size: int = 128,
+        shuffle: bool = True,
+        lr: float = 1e-4,
+        lr_patience: int = 10,
+        factor: float = 0.1,
+        weight_decay: float = 1e-06,
+        dataloader_kwargs={},
+    ):
+        """
+        Builds the model using the provided training data.
+
+        Parameters
+        ----------
+        X : DataFrame or array-like, shape (n_samples, n_features)
+            The training input samples.
+        y : array-like, shape (n_samples,) or (n_samples, n_targets)
+            The target values (real numbers).
+        val_size : float, default=0.2
+            The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+            The validation target values. Required if `X_val` is provided.
+        random_state : int, default=101
+            Controls the shuffling applied to the data before applying the split.
+        batch_size : int, default=64
+            Number of samples per gradient update.
+        shuffle : bool, default=True
+            Whether to shuffle the training data before each epoch.
+        lr : float, default=1e-3
+            Learning rate for the optimizer.
+        lr_patience : int, default=10
+            Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+        factor : float, default=0.1
+            Factor by which the learning rate will be reduced.
+        weight_decay : float, default=0.025
+            Weight decay (L2 penalty) coefficient.
+        dataloader_kwargs: dict, default={}
+            The kwargs for the pytorch dataloader class.
+
+        Returns
+        -------
+        self : object
+            The built distributional regressor.
+        """
+        if not isinstance(X, pd.DataFrame):
+            X = pd.DataFrame(X)
+        if isinstance(y, pd.Series):
+            y = y.values
+        if X_val:
+            if not isinstance(X_val, pd.DataFrame):
+                X_val = pd.DataFrame(X_val)
+            if isinstance(y_val, pd.Series):
+                y_val = y_val.values
+
+        self.data_module = MambularDataModule(
+            preprocessor=self.preprocessor,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            X_val=X_val,
+            y_val=y_val,
+            val_size=val_size,
+            random_state=random_state,
+            regression=False,
+            **dataloader_kwargs
+        )
+
+        self.data_module.preprocess_data(
+            X, y, X_val, y_val, val_size=val_size, random_state=random_state
+        )
+
+        num_classes = len(np.unique(y))
+
+        self.model = TaskModel(
+            model_class=self.base_model,
+            num_classes=num_classes,
+            config=self.config,
+            cat_feature_info=self.data_module.cat_feature_info,
+            num_feature_info=self.data_module.num_feature_info,
+            lr=lr,
+            lr_patience=lr_patience,
+            lr_factor=factor,
+            weight_decay=weight_decay,
+        )
+
+        self.built = True
+
+        return self
+
+    def get_number_of_params(self, requires_grad=True):
+        """
+        Calculate the number of parameters in the model.
+
+        Parameters
+        ----------
+        requires_grad : bool, optional
+            If True, only count the parameters that require gradients (trainable parameters).
+            If False, count all parameters. Default is True.
+
+        Returns
+        -------
+        int
+            The total number of parameters in the model.
+
+        Raises
+        ------
+        ValueError
+            If the model has not been built prior to calling this method.
+        """
+        if not self.built:
+            raise ValueError(
+                "The model must be built before the number of parameters can be estimated"
+            )
+        else:
+            if requires_grad:
+                return sum(
+                    p.numel() for p in self.model.parameters() if p.requires_grad
+                )
+            else:
+                return sum(p.numel() for p in self.model.parameters())
+
     def fit(
         self,
         X,
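The counting logic added in both files is the standard PyTorch idiom. A self-contained sketch of the same computation on a plain torch.nn module (the toy model below is a made-up stand-in, not the Mambular architecture):

import torch.nn as nn

# stand-in model: 5 inputs -> 16 hidden units -> 2 outputs
model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 2))

# numel() gives the element count of each parameter tensor
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # equal here (130, 130) unless some parameters are frozen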
