Skip to content

Commit f5407d8

Browse files
authored
Merge pull request #278 from MaxSchambach/fix-dataset
Fix `MambularDataset` length for data with only categorical features
2 parents 07c16a1 + 9e8043f commit f5407d8

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

mambular/data_utils/dataset.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(
2424
labels=None,
2525
regression=True,
2626
):
27+
assert cat_features_list or num_features_list
28+
2729
self.cat_features_list = cat_features_list # Categorical features tensors
2830
self.num_features_list = num_features_list # Numerical features tensors
2931
self.embeddings_list = embeddings_list # Embeddings tensors (optional)
@@ -44,7 +46,8 @@ def __init__(
4446
self.labels = None # No labels in prediction mode
4547

4648
def __len__(self):
47-
return len(self.num_features_list[0]) # Use numerical features length
49+
_feats = self.num_features_list if self.num_features_list else self.cat_features_list
50+
return len(_feats[0])
4851

4952
def __getitem__(self, idx):
5053
"""Retrieves the features and label for a given index.

0 commit comments

Comments
 (0)