We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 07c16a1 + 9e8043f commit f5407d8Copy full SHA for f5407d8
1 file changed
mambular/data_utils/dataset.py
@@ -24,6 +24,8 @@ def __init__(
24
labels=None,
25
regression=True,
26
):
27
+ assert cat_features_list or num_features_list
28
+
29
self.cat_features_list = cat_features_list # Categorical features tensors
30
self.num_features_list = num_features_list # Numerical features tensors
31
self.embeddings_list = embeddings_list # Embeddings tensors (optional)
@@ -44,7 +46,8 @@ def __init__(
44
46
self.labels = None # No labels in prediction mode
45
47
48
def __len__(self):
- return len(self.num_features_list[0]) # Use numerical features length
49
+ _feats = self.num_features_list if self.num_features_list else self.cat_features_list
50
+ return len(_feats[0])
51
52
def __getitem__(self, idx):
53
"""Retrieves the features and label for a given index.
0 commit comments