@@ -188,7 +188,7 @@ def training_step(self, batch, batch_idx): # type: ignore
188188 Tensor
189189 Training loss.
190190 """
191- cat_features , num_features , labels = batch
191+ num_features , cat_features , labels = batch
192192
193193 # Check if the model has a `penalty_forward` method
194194 if hasattr (self .base_model , "penalty_forward" ):
@@ -235,7 +235,7 @@ def validation_step(self, batch, batch_idx): # type: ignore
235235 Validation loss.
236236 """
237237
238- cat_features , num_features , labels = batch
238+ num_features , cat_features , labels = batch
239239 preds = self (num_features = num_features , cat_features = cat_features )
240240 val_loss = self .compute_loss (preds , labels )
241241
@@ -277,7 +277,7 @@ def test_step(self, batch, batch_idx): # type: ignore
277277 Tensor
278278 Test loss.
279279 """
280- cat_features , num_features , labels = batch
280+ num_features , cat_features , labels = batch
281281 preds = self (num_features = num_features , cat_features = cat_features )
282282 test_loss = self .compute_loss (preds , labels )
283283
@@ -308,7 +308,7 @@ def predict_step(self, batch, batch_idx):
308308 Predictions.
309309 """
310310
311- cat_features , num_features = batch
311+ num_features , cat_features = batch
312312 preds = self (num_features = num_features , cat_features = cat_features )
313313
314314 return preds
0 commit comments