@@ -108,30 +108,24 @@ def set_params(self, **parameters):
 
         return self
 
-    def fit(
+    def build_model(
         self,
         X,
         y,
         val_size: float = 0.2,
         X_val=None,
         y_val=None,
-        max_epochs: int = 100,
         random_state: int = 101,
         batch_size: int = 128,
         shuffle: bool = True,
-        patience: int = 15,
-        monitor: str = "val_loss",
-        mode: str = "min",
         lr: float = 1e-4,
         lr_patience: int = 10,
         factor: float = 0.1,
         weight_decay: float = 1e-06,
-        checkpoint_path="model_checkpoints",
         dataloader_kwargs={},
-        **trainer_kwargs
     ):
         """
-        Trains the regression model using the provided training data. Optionally, a separate validation set can be used.
+        Builds the model using the provided training data.
 
         Parameters
         ----------
@@ -145,20 +139,12 @@ def fit(
             The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
         y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
             The validation target values. Required if `X_val` is provided.
-        max_epochs : int, default=100
-            Maximum number of epochs for training.
         random_state : int, default=101
             Controls the shuffling applied to the data before applying the split.
         batch_size : int, default=128
             Number of samples per gradient update.
         shuffle : bool, default=True
             Whether to shuffle the training data before each epoch.
-        patience : int, default=10
-            Number of epochs with no improvement on the validation loss to wait before early stopping.
-        monitor : str, default="val_loss"
-            The metric to monitor for early stopping.
-        mode : str, default="min"
-            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
         lr : float, default=1e-4
             Learning rate for the optimizer.
         lr_patience : int, default=10
@@ -167,17 +153,15 @@ def fit(
             Factor by which the learning rate will be reduced.
         weight_decay : float, default=1e-06
             Weight decay (L2 penalty) coefficient.
-        checkpoint_path : str, default="model_checkpoints"
-            Path where the checkpoints are being saved.
         dataloader_kwargs : dict, default={}
            Additional keyword arguments passed to the PyTorch DataLoader class.
-        **trainer_kwargs : Additional keyword arguments for PyTorch Lightning's Trainer class.
+
 
 
         Returns
         -------
         self : object
-            The fitted regressor.
+            The built classifier.
         """
         if not isinstance(X, pd.DataFrame):
             X = pd.DataFrame(X)
@@ -219,6 +203,157 @@ def fit(
             weight_decay=weight_decay,
         )
 
+        self.built = True
+
+        return self
+
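A minimal usage sketch of the new two-step API, assuming a scikit-learn-style classifier (the class name `MambularClassifier`, its import path, and the synthetic data are illustrative assumptions, not part of this diff):

```python
# Sketch: build the model without training it.
from sklearn.datasets import make_classification
from mambular.models import MambularClassifier  # assumed import path

X, y = make_classification(n_samples=500, n_features=10, random_state=0)

clf = MambularClassifier()
clf.build_model(X, y, batch_size=128, lr=1e-4)  # sets clf.built = True
```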
+    def get_number_of_params(self, requires_grad=True):
+        """
+        Calculate the number of parameters in the model.
+
+        Parameters
+        ----------
+        requires_grad : bool, optional
+            If True, only count the parameters that require gradients (trainable parameters).
+            If False, count all parameters. Default is True.
+
+        Returns
+        -------
+        int
+            The total number of parameters in the model.
+
+        Raises
+        ------
+        ValueError
+            If the model has not been built prior to calling this method.
+        """
+        if not self.built:
+            raise ValueError(
+                "The model must be built before the number of parameters can be estimated"
+            )
+        else:
+            if requires_grad:
+                return sum(
+                    p.numel() for p in self.model.parameters() if p.requires_grad
+                )
+            else:
+                return sum(p.numel() for p in self.model.parameters())
+
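Once the model is built, `get_number_of_params` reports trainable versus total parameter counts; a short sketch continuing the example above:

```python
# Count parameters after build_model; calling this before build_model
# raises ValueError because self.built is still False.
n_trainable = clf.get_number_of_params(requires_grad=True)
n_total = clf.get_number_of_params(requires_grad=False)
print(f"{n_trainable} of {n_total} parameters are trainable")
```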
+    def fit(
+        self,
+        X,
+        y,
+        val_size: float = 0.2,
+        X_val=None,
+        y_val=None,
+        max_epochs: int = 100,
+        random_state: int = 101,
+        batch_size: int = 128,
+        shuffle: bool = True,
+        patience: int = 15,
+        monitor: str = "val_loss",
+        mode: str = "min",
+        lr: float = 1e-4,
+        lr_patience: int = 10,
+        factor: float = 0.1,
+        weight_decay: float = 1e-06,
+        checkpoint_path="model_checkpoints",
+        dataloader_kwargs={},
+        rebuild: bool = True,
+        **trainer_kwargs
+    ):
+        """
+        Trains the classification model using the provided training data. Optionally, a separate validation set can be used.
+
+        Parameters
+        ----------
+        X : DataFrame or array-like, shape (n_samples, n_features)
+            The training input samples.
+        y : array-like, shape (n_samples,) or (n_samples, n_targets)
+            The target values (class labels).
+        val_size : float, default=0.2
+            The proportion of the dataset to include in the validation split if `X_val` is None. Ignored if `X_val` is provided.
+        X_val : DataFrame or array-like, shape (n_samples, n_features), optional
+            The validation input samples. If provided, `X` and `y` are not split and this data is used for validation.
+        y_val : array-like, shape (n_samples,) or (n_samples, n_targets), optional
+            The validation target values. Required if `X_val` is provided.
+        max_epochs : int, default=100
+            Maximum number of epochs for training.
+        random_state : int, default=101
+            Controls the shuffling applied to the data before applying the split.
+        batch_size : int, default=128
+            Number of samples per gradient update.
+        shuffle : bool, default=True
+            Whether to shuffle the training data before each epoch.
+        patience : int, default=15
+            Number of epochs with no improvement on the validation loss to wait before early stopping.
+        monitor : str, default="val_loss"
+            The metric to monitor for early stopping.
+        mode : str, default="min"
+            Whether the monitored metric should be minimized (`min`) or maximized (`max`).
+        lr : float, default=1e-4
+            Learning rate for the optimizer.
+        lr_patience : int, default=10
+            Number of epochs with no improvement on the validation loss to wait before reducing the learning rate.
+        factor : float, default=0.1
+            Factor by which the learning rate will be reduced.
+        weight_decay : float, default=1e-06
+            Weight decay (L2 penalty) coefficient.
+        checkpoint_path : str, default="model_checkpoints"
+            Path where the checkpoints are saved.
+        dataloader_kwargs : dict, default={}
+            Additional keyword arguments passed to the PyTorch DataLoader class.
+        rebuild : bool, default=True
+            Whether to rebuild the model if it has already been built.
+        **trainer_kwargs
+            Additional keyword arguments for PyTorch Lightning's Trainer class.
+
+        Returns
+        -------
+        self : object
+            The fitted classifier.
+        """
+        if not self.built or rebuild:
+            if not isinstance(X, pd.DataFrame):
+                X = pd.DataFrame(X)
+            if isinstance(y, pd.Series):
+                y = y.values
+            if X_val is not None:
+                if not isinstance(X_val, pd.DataFrame):
+                    X_val = pd.DataFrame(X_val)
+                if isinstance(y_val, pd.Series):
+                    y_val = y_val.values
+
+            self.data_module = MambularDataModule(
+                preprocessor=self.preprocessor,
+                batch_size=batch_size,
+                shuffle=shuffle,
+                X_val=X_val,
+                y_val=y_val,
+                val_size=val_size,
+                random_state=random_state,
+                regression=False,
+                **dataloader_kwargs
+            )
+
+            self.data_module.preprocess_data(
+                X, y, X_val, y_val, val_size=val_size, random_state=random_state
+            )
+
+            num_classes = len(np.unique(y))
+
+            self.model = TaskModel(
+                model_class=self.base_model,
+                num_classes=num_classes,
+                config=self.config,
+                cat_feature_info=self.data_module.cat_feature_info,
+                num_feature_info=self.data_module.num_feature_info,
+                lr=lr,
+                lr_patience=lr_patience,
+                lr_factor=factor,
+                weight_decay=weight_decay,
+            )
+
         early_stop_callback = EarlyStopping(
             monitor=monitor, min_delta=0.00, patience=patience, verbose=False, mode=mode
         )
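End to end, the new flow looks roughly like this sketch (keyword values are the defaults from the signature; `accelerator` is an example `trainer_kwargs` entry assumed to be forwarded to PyTorch Lightning's Trainer):

```python
# Fit with an internal 80/20 validation split and early stopping on
# val_loss; rebuild=True (the default) rebuilds the model even if
# build_model was called beforehand.
clf.fit(
    X, y,
    val_size=0.2,
    max_epochs=100,
    patience=15,            # epochs without val_loss improvement
    monitor="val_loss",
    mode="min",
    checkpoint_path="model_checkpoints",
    accelerator="auto",     # example trainer kwarg (assumed pass-through)
)
```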