@@ -130,9 +130,7 @@ def preprocess_data(
130130 embeddings_val = [embeddings_val ]
131131
132132 split_data += embeddings_train
133- split_result = train_test_split (
134- * split_data , test_size = val_size , random_state = random_state
135- )
133+ split_result = train_test_split (* split_data , test_size = val_size , random_state = random_state )
136134
137135 self .X_train , self .X_val , self .y_train , self .y_val = split_result [:4 ]
138136 self .embeddings_train = split_result [4 ::2 ]
@@ -161,37 +159,31 @@ def preprocess_data(
161159 self .embeddings_val = None
162160
163161 # Fit the preprocessor on the combined training and validation data
164- combined_X = pd .concat ([self .X_train , self .X_val ], axis = 0 ).reset_index (
165- drop = True
166- )
162+ combined_X = pd .concat ([self .X_train , self .X_val ], axis = 0 ).reset_index (drop = True )
167163 combined_y = np .concatenate ((self .y_train , self .y_val ), axis = 0 )
168164
169165 if self .embeddings_train is not None and self .embeddings_val is not None :
170166 combined_embeddings = [
171167 np .concatenate ((emb_train , emb_val ), axis = 0 )
172- for emb_train , emb_val in zip (
173- self .embeddings_train , self .embeddings_val
174- )
168+ for emb_train , emb_val in zip (self .embeddings_train , self .embeddings_val , strict = False )
175169 ]
176170 else :
177171 combined_embeddings = None
178172
179173 self .preprocessor .fit (combined_X , combined_y , combined_embeddings )
180174
181175 # Update feature info based on the actual processed data
182- (self .num_feature_info , self .cat_feature_info , self .embedding_feature_info ) = (
183- self .preprocessor .get_feature_info ()
184- )
176+ (
177+ self .num_feature_info ,
178+ self .cat_feature_info ,
179+ self .embedding_feature_info ,
180+ ) = self .preprocessor .get_feature_info ()
185181
186182 def setup (self , stage : str ):
187183 """Transform the data and create DataLoaders."""
188184 if stage == "fit" :
189- train_preprocessed_data = self .preprocessor .transform (
190- self .X_train , self .embeddings_train
191- )
192- val_preprocessed_data = self .preprocessor .transform (
193- self .X_val , self .embeddings_val
194- )
185+ train_preprocessed_data = self .preprocessor .transform (self .X_train , self .embeddings_train )
186+ val_preprocessed_data = self .preprocessor .transform (self .X_val , self .embeddings_val )
195187
196188 # Initialize lists for tensors
197189 train_cat_tensors = []
@@ -205,75 +197,40 @@ def setup(self, stage: str):
205197 for key in self .cat_feature_info : # type: ignore
206198 dtype = (
207199 torch .float32
208- if any (
209- x in self .cat_feature_info [key ]["preprocessing" ]
210- for x in ["onehot" , "pretrained" ]
211- )
200+ if any (x in self .cat_feature_info [key ]["preprocessing" ] for x in ["onehot" , "pretrained" ]) # type: ignore
212201 else torch .long
213202 )
214203
215- cat_key = "cat_" + str (
216- key
217- ) # Assuming categorical keys are prefixed with 'cat_'
204+ cat_key = "cat_" + str (key ) # Assuming categorical keys are prefixed with 'cat_'
218205 if cat_key in train_preprocessed_data :
219- train_cat_tensors .append (
220- torch .tensor (train_preprocessed_data [cat_key ], dtype = dtype )
221- )
206+ train_cat_tensors .append (torch .tensor (train_preprocessed_data [cat_key ], dtype = dtype ))
222207 if cat_key in val_preprocessed_data :
223- val_cat_tensors .append (
224- torch .tensor (val_preprocessed_data [cat_key ], dtype = dtype )
225- )
208+ val_cat_tensors .append (torch .tensor (val_preprocessed_data [cat_key ], dtype = dtype ))
226209
227210 binned_key = "num_" + str (key ) # for binned features
228211 if binned_key in train_preprocessed_data :
229- train_cat_tensors .append (
230- torch .tensor (train_preprocessed_data [binned_key ], dtype = dtype )
231- )
212+ train_cat_tensors .append (torch .tensor (train_preprocessed_data [binned_key ], dtype = dtype ))
232213
233214 if binned_key in val_preprocessed_data :
234- val_cat_tensors .append (
235- torch .tensor (val_preprocessed_data [binned_key ], dtype = dtype )
236- )
215+ val_cat_tensors .append (torch .tensor (val_preprocessed_data [binned_key ], dtype = dtype ))
237216
238217 # Populate tensors for numerical features, if present in processed data
239218 for key in self .num_feature_info : # type: ignore
240- num_key = "num_" + str (
241- key
242- ) # Assuming numerical keys are prefixed with 'num_'
219+ num_key = "num_" + str (key ) # Assuming numerical keys are prefixed with 'num_'
243220 if num_key in train_preprocessed_data :
244- train_num_tensors .append (
245- torch .tensor (
246- train_preprocessed_data [num_key ], dtype = torch .float32
247- )
248- )
221+ train_num_tensors .append (torch .tensor (train_preprocessed_data [num_key ], dtype = torch .float32 ))
249222 if num_key in val_preprocessed_data :
250- val_num_tensors .append (
251- torch .tensor (
252- val_preprocessed_data [num_key ], dtype = torch .float32
253- )
254- )
223+ val_num_tensors .append (torch .tensor (val_preprocessed_data [num_key ], dtype = torch .float32 ))
255224
256225 if self .embedding_feature_info is not None :
257226 for key in self .embedding_feature_info :
258227 if key in train_preprocessed_data :
259- train_emb_tensors .append (
260- torch .tensor (
261- train_preprocessed_data [key ], dtype = torch .float32
262- )
263- )
228+ train_emb_tensors .append (torch .tensor (train_preprocessed_data [key ], dtype = torch .float32 ))
264229 if key in val_preprocessed_data :
265- val_emb_tensors .append (
266- torch .tensor (
267- val_preprocessed_data [key ], dtype = torch .float32
268- )
269- )
270-
271- train_labels = torch .tensor (
272- self .y_train , dtype = self .labels_dtype
273- ).unsqueeze (dim = 1 )
274- val_labels = torch .tensor (self .y_val , dtype = self .labels_dtype ).unsqueeze (
275- dim = 1
276- )
230+ val_emb_tensors .append (torch .tensor (val_preprocessed_data [key ], dtype = torch .float32 ))
231+
232+ train_labels = torch .tensor (self .y_train , dtype = self .labels_dtype ).unsqueeze (dim = 1 )
233+ val_labels = torch .tensor (self .y_val , dtype = self .labels_dtype ).unsqueeze (dim = 1 )
277234
278235 self .train_dataset = MambularDataset (
279236 train_cat_tensors ,
@@ -300,42 +257,27 @@ def preprocess_new_data(self, X, embeddings):
300257 for key in self .cat_feature_info : # type: ignore
301258 dtype = (
302259 torch .float32
303- if any (
304- x in self .cat_feature_info [key ]["preprocessing" ]
305- for x in ["onehot" , "pretrained" ]
306- )
260+ if any (x in self .cat_feature_info [key ]["preprocessing" ] for x in ["onehot" , "pretrained" ]) # type: ignore
307261 else torch .long
308262 )
309- cat_key = "cat_" + str (
310- key
311- ) # Assuming categorical keys are prefixed with 'cat_'
263+ cat_key = "cat_" + str (key ) # Assuming categorical keys are prefixed with 'cat_'
312264 if cat_key in preprocessed_data :
313- cat_tensors .append (
314- torch .tensor (preprocessed_data [cat_key ], dtype = dtype )
315- )
265+ cat_tensors .append (torch .tensor (preprocessed_data [cat_key ], dtype = dtype ))
316266
317267 binned_key = "num_" + str (key ) # for binned features
318268 if binned_key in preprocessed_data :
319- cat_tensors .append (
320- torch .tensor (preprocessed_data [binned_key ], dtype = dtype )
321- )
269+ cat_tensors .append (torch .tensor (preprocessed_data [binned_key ], dtype = dtype ))
322270
323271 # Populate tensors for numerical features, if present in processed data
324272 for key in self .num_feature_info : # type: ignore
325- num_key = "num_" + str (
326- key
327- ) # Assuming numerical keys are prefixed with 'num_'
273+ num_key = "num_" + str (key ) # Assuming numerical keys are prefixed with 'num_'
328274 if num_key in preprocessed_data :
329- num_tensors .append (
330- torch .tensor (preprocessed_data [num_key ], dtype = torch .float32 )
331- )
275+ num_tensors .append (torch .tensor (preprocessed_data [num_key ], dtype = torch .float32 ))
332276
333277 if self .embedding_feature_info is not None :
334278 for key in self .embedding_feature_info :
335279 if key in preprocessed_data :
336- emb_tensors .append (
337- torch .tensor (preprocessed_data [key ], dtype = torch .float32 )
338- )
280+ emb_tensors .append (torch .tensor (preprocessed_data [key ], dtype = torch .float32 ))
339281
340282 return MambularDataset (
341283 cat_tensors ,
@@ -374,9 +316,7 @@ def val_dataloader(self):
374316 DataLoader: DataLoader instance for the validation dataset.
375317 """
376318 if hasattr (self , "val_dataset" ):
377- return DataLoader (
378- self .val_dataset , batch_size = self .batch_size , ** self .dataloader_kwargs
379- )
319+ return DataLoader (self .val_dataset , batch_size = self .batch_size , ** self .dataloader_kwargs )
380320 else :
381321 raise ValueError ("No validation dataset provided!" )
382322
@@ -387,9 +327,7 @@ def test_dataloader(self):
387327 DataLoader: DataLoader instance for the test dataset.
388328 """
389329 if hasattr (self , "test_dataset" ):
390- return DataLoader (
391- self .test_dataset , batch_size = self .batch_size , ** self .dataloader_kwargs
392- )
330+ return DataLoader (self .test_dataset , batch_size = self .batch_size , ** self .dataloader_kwargs )
393331 else :
394332 raise ValueError ("No test dataset provided!" )
395333
0 commit comments