Skip to content

Commit 029c6d8

Browse files
committed
adding one-hot encoding to embedding_layer
1 parent 8933cf7 commit 029c6d8

1 file changed

Lines changed: 26 additions & 8 deletions

File tree

mambular/arch_utils/embedding_layer.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(
1212
layer_norm_after_embedding=False,
1313
use_cls=False,
1414
cls_position=0,
15+
cat_encoding="int",
1516
):
1617
"""
1718
Embedding layer that handles numerical and categorical embeddings.
@@ -56,15 +57,23 @@ def __init__(
5657
]
5758
)
5859

59-
self.cat_embeddings = nn.ModuleList(
60-
[
61-
nn.Sequential(
62-
nn.Embedding(num_categories + 1, d_model),
63-
self.embedding_activation,
60+
self.cat_embeddings = nn.ModuleList()
61+
for feature_name, num_categories in cat_feature_info.items():
62+
if cat_encoding == "int":
63+
self.cat_embeddings.append(
64+
nn.Sequential(
65+
nn.Embedding(num_categories + 1, d_model),
66+
self.embedding_activation,
67+
)
68+
)
69+
elif cat_encoding == "one-hot":
70+
self.cat_embeddings.append(
71+
nn.Sequential(
72+
OneHotEncoding(num_categories),
73+
nn.Linear(num_categories, d_model, bias=False),
74+
self.embedding_activation,
75+
)
6476
)
65-
for feature_name, num_categories in cat_feature_info.items()
66-
]
67-
)
6877

6978
if self.use_cls:
7079
self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -143,3 +152,12 @@ def forward(self, num_features=None, cat_features=None):
143152
)
144153

145154
return x
155+
156+
157+
class OneHotEncoding(nn.Module):
158+
def __init__(self, num_categories):
159+
super(OneHotEncoding, self).__init__()
160+
self.num_categories = num_categories
161+
162+
def forward(self, x):
163+
return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()

0 commit comments

Comments
 (0)