Skip to content

Commit d08af31

Browse files
committed
adjust order in __getitem__ functionality and batch for lightningmodule
1 parent fac6a1f commit d08af31

2 files changed

Lines changed: 12 additions & 9 deletions

File tree

mambular/base_models/lightning_wrapper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def training_step(self, batch, batch_idx): # type: ignore
188188
Tensor
189189
Training loss.
190190
"""
191-
cat_features, num_features, labels = batch
191+
num_features, cat_features, labels = batch
192192

193193
# Check if the model has a `penalty_forward` method
194194
if hasattr(self.base_model, "penalty_forward"):
@@ -235,7 +235,7 @@ def validation_step(self, batch, batch_idx): # type: ignore
235235
Validation loss.
236236
"""
237237

238-
cat_features, num_features, labels = batch
238+
num_features, cat_features, labels = batch
239239
preds = self(num_features=num_features, cat_features=cat_features)
240240
val_loss = self.compute_loss(preds, labels)
241241

@@ -277,7 +277,7 @@ def test_step(self, batch, batch_idx): # type: ignore
277277
Tensor
278278
Test loss.
279279
"""
280-
cat_features, num_features, labels = batch
280+
num_features, cat_features, labels = batch
281281
preds = self(num_features=num_features, cat_features=cat_features)
282282
test_loss = self.compute_loss(preds, labels)
283283

@@ -308,7 +308,7 @@ def predict_step(self, batch, batch_idx):
308308
Predictions.
309309
"""
310310

311-
cat_features, num_features = batch
311+
num_features, cat_features = batch
312312
preds = self(num_features=num_features, cat_features=cat_features)
313313

314314
return preds

mambular/data_utils/dataset.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ class MambularDataset(Dataset):
2020
regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True.
2121
"""
2222

23-
def __init__(self, cat_features_list, num_features_list, labels=None, regression=True):
23+
def __init__(
24+
self, cat_features_list, num_features_list, labels=None, regression=True
25+
):
2426
self.cat_features_list = cat_features_list # Categorical features tensors
2527
self.num_features_list = num_features_list # Numerical features tensors
2628
self.regression = regression
@@ -54,7 +56,9 @@ def __getitem__(self, idx):
5456
tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical features)
5557
and a single label (if available).
5658
"""
57-
cat_features = [feature_tensor[idx] for feature_tensor in self.cat_features_list]
59+
cat_features = [
60+
feature_tensor[idx] for feature_tensor in self.cat_features_list
61+
]
5862
num_features = [
5963
torch.as_tensor(feature_tensor[idx]).clone().detach().to(torch.float32)
6064
for feature_tensor in self.num_features_list
@@ -68,7 +72,6 @@ def __getitem__(self, idx):
6872
label = label.clone().detach().to(torch.float32)
6973
else:
7074
label = label.clone().detach().to(torch.long)
71-
return cat_features, num_features, label
75+
return num_features, cat_features, label
7276
else:
73-
return cat_features, num_features # No label in prediction mode
74-
77+
return num_features, cat_features # No label in prediction mode

0 commit comments

Comments
 (0)