@@ -12,6 +12,7 @@ def __init__(
         layer_norm_after_embedding=False,
         use_cls=False,
         cls_position=0,
+        cat_encoding="int",
     ):
         """
         Embedding layer that handles numerical and categorical embeddings.
@@ -56,15 +57,23 @@ def __init__(
             ]
         )

-        self.cat_embeddings = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Embedding(num_categories + 1, d_model),
-                    self.embedding_activation,
+        self.cat_embeddings = nn.ModuleList()
+        for feature_name, num_categories in cat_feature_info.items():
+            if cat_encoding == "int":
+                self.cat_embeddings.append(
+                    nn.Sequential(
+                        nn.Embedding(num_categories + 1, d_model),
+                        self.embedding_activation,
+                    )
+                )
+            elif cat_encoding == "one-hot":
+                self.cat_embeddings.append(
+                    nn.Sequential(
+                        OneHotEncoding(num_categories),
+                        nn.Linear(num_categories, d_model, bias=False),
+                        self.embedding_activation,
+                    )
                 )
-                for feature_name, num_categories in cat_feature_info.items()
-            ]
-        )

         if self.use_cls:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
@@ -143,3 +152,12 @@ def forward(self, num_features=None, cat_features=None):
             )

         return x
+
+
+class OneHotEncoding(nn.Module):
+    def __init__(self, num_categories):
+        super(OneHotEncoding, self).__init__()
+        self.num_categories = num_categories
+
+    def forward(self, x):
+        return torch.nn.functional.one_hot(x, num_classes=self.num_categories).float()
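
A minimal standalone sketch of the new `cat_encoding="one-hot"` path added above. The sizes `num_categories=4` and `d_model=8` are hypothetical, and `nn.ReLU()` is a stand-in for `self.embedding_activation`; the real layer reads these from `cat_feature_info` and the model configuration:

import torch
import torch.nn as nn
import torch.nn.functional as F


class OneHotEncoding(nn.Module):
    """Integer category indices -> float one-hot vectors, as in the diff."""

    def __init__(self, num_categories):
        super().__init__()
        self.num_categories = num_categories

    def forward(self, x):
        return F.one_hot(x, num_classes=self.num_categories).float()


# Hypothetical sizes for illustration only.
num_categories, d_model = 4, 8

embed = nn.Sequential(
    OneHotEncoding(num_categories),
    nn.Linear(num_categories, d_model, bias=False),  # dense projection, as in the diff
    nn.ReLU(),  # stand-in for self.embedding_activation
)

idx = torch.tensor([0, 2, 3])  # a batch of category indices
out = embed(idx)
print(out.shape)  # torch.Size([3, 8]): one d_model vector per index

A bias-free linear layer applied to a one-hot vector selects a single column of its weight matrix, so this path is mathematically equivalent to an embedding lookup; unlike the `int` path, though, it allocates exactly `num_categories` rows rather than `num_categories + 1`.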