1+ import numpy as np
12import torch
23from torch .utils .data import Dataset
3- import numpy as np
44
55
66class MambularDataset (Dataset ):
@@ -49,7 +49,8 @@ def __getitem__(self, idx):
4949 feature_tensor [idx ] for feature_tensor in self .cat_features_list
5050 ]
5151 num_features = [
52- torch .as_tensor (feature_tensor [idx ]).clone ().detach ().to (torch .float32 )
52+ torch .as_tensor (feature_tensor [idx ]).clone (
53+ ).detach ().to (torch .float32 )
5354 for feature_tensor in self .num_features_list
5455 ]
5556 label = self .labels [idx ]
@@ -62,57 +63,3 @@ def __getitem__(self, idx):
6263
6364 # Keep categorical and numerical features separate
6465 return cat_features , num_features , label
65-
66-
67- class EmbeddingMambularDataset (Dataset ):
68- """
69- A specialized version of MambularDataset intended for datasets related to protein studies, maintaining the
70- same structure and functionality.
71-
72- This class is designed to handle structured data with separate categorical and numerical features, suitable
73- for both regression and classification tasks within the context of protein studies.
74-
75- Parameters:
76- cat_features_list (list of Tensors): A list of tensors representing the categorical features.
77- num_features_list (list of Tensors): A list of tensors representing the numerical features.
78- labels (Tensor): A tensor of labels.
79- regression (bool, optional): A flag indicating if the dataset is for a regression task. Defaults to True.
80- """
81-
82- def __init__ (self , cat_features_list , num_features_list , labels , regression = True ):
83- self .cat_features_list = cat_features_list # Categorical features tensors
84- self .num_features_list = num_features_list # Numerical features tensors
85- self .labels = labels
86- self .regression = regression
87-
88- def __len__ (self ):
89- return len (self .labels )
90-
91- def __getitem__ (self , idx ):
92- """
93- Retrieves the features and label for a given index in the context of protein studies.
94-
95- Parameters:
96- idx (int): The index of the data point.
97-
98- Returns:
99- tuple: A tuple containing two lists of tensors (one for categorical features and one for numerical
100- features) and a single label (float if regression is True), specifically designed for protein data.
101- """
102- cat_features = [
103- feature_tensor [idx ] for feature_tensor in self .cat_features_list
104- ]
105- num_features = [
106- torch .tensor (feature_tensor [idx ], dtype = torch .float32 )
107- for feature_tensor in self .num_features_list
108- ]
109- label = self .labels [idx ]
110- if self .regression :
111- # Convert the label to float for regression tasks
112- # label = float(label)
113- label = torch .tensor (label , dtype = torch .float32 )
114- else :
115- label = torch .tensor (label , dtype = torch .long )
116-
117- # Keep categorical and numerical features separate
118- return cat_features , num_features , label
0 commit comments