@@ -6,7 +6,7 @@
 
 
 class EmbeddingLayer(nn.Module):
-    def __init__(self, num_feature_info, cat_feature_info, config):
+    def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config):
         """Embedding layer that handles numerical and categorical embeddings.
 
         Parameters
@@ -28,6 +28,7 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         self.layer_norm_after_embedding = getattr(
             config, "layer_norm_after_embedding", False
         )
+        self.embedding_projection = getattr(config, "embedding_projection", True)
         self.use_cls = getattr(config, "use_cls", False)
         self.cls_position = getattr(config, "cls_position", 0)
         self.embedding_dropout = (
@@ -100,6 +101,22 @@ def __init__(self, num_feature_info, cat_feature_info, config):
             ]
         )
 
+        if len(emb_feature_info) >= 1:
+            if self.embedding_projection:
+                self.emb_embeddings = nn.ModuleList(
+                    [
+                        nn.Sequential(
+                            nn.Linear(
+                                feature_info["dimension"],
+                                self.d_model,
+                                bias=self.embedding_bias,
+                            ),
+                            self.embedding_activation,
+                        )
+                        for feature_name, feature_info in emb_feature_info.items()
+                    ]
+                )
+
         # Class token if required
         if self.use_cls:
             self.cls_token = nn.Parameter(torch.zeros(1, 1, self.d_model))
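
The new branch above assumes each entry of emb_feature_info describes one pretrained embedding column through a "dimension" key, and builds a separate Linear(dimension, d_model) projection (plus the shared embedding activation) per column. Here is a minimal standalone sketch of that idea; the feature names, widths, batch size, and the ReLU standing in for self.embedding_activation are illustrative assumptions, not taken from this commit:

import torch
import torch.nn as nn

# Hypothetical feature-info dict: two pretrained embedding columns of different widths.
emb_feature_info = {
    "text_emb": {"dimension": 384},   # e.g. a sentence-encoder output
    "image_emb": {"dimension": 512},  # e.g. an image-encoder output
}
d_model = 128

# One Linear + activation per column, mirroring the ModuleList built in the diff.
projections = nn.ModuleList(
    [
        nn.Sequential(nn.Linear(info["dimension"], d_model), nn.ReLU())
        for info in emb_feature_info.values()
    ]
)

batch = 32
emb_features = [torch.randn(batch, 384), torch.randn(batch, 512)]
projected = torch.stack(
    [proj(feat) for proj, feat in zip(projections, emb_features)], dim=1
)
print(projected.shape)  # torch.Size([32, 2, 128]) -- one d_model-wide token per column

Projecting every source into the shared token width is what lets the pretrained vectors sit alongside the categorical and numerical tokens in one sequence.
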
@@ -108,15 +125,12 @@ def __init__(self, num_feature_info, cat_feature_info, config):
         if self.layer_norm_after_embedding:
             self.embedding_norm = nn.LayerNorm(self.d_model)
 
-    def forward(self, num_features=None, cat_features=None):
+    def forward(self, num_features, cat_features, emb_features):
         """Defines the forward pass of the model.
 
         Parameters
         ----------
-        num_features : Tensor, optional
-            Tensor containing the numerical features.
-        cat_features : Tensor, optional
-            Tensor containing the categorical features.
+        num_features, cat_features, emb_features : lists of Tensors
 
         Returns
         -------
@@ -128,6 +142,7 @@ def forward(self, num_features=None, cat_features=None):
         ValueError
             If no features are provided to the model.
         """
+        num_embeddings, cat_embeddings, emb_embeddings = None, None, None
 
         # Class token initialization
         if self.use_cls:
@@ -149,8 +164,6 @@ def forward(self, num_features=None, cat_features=None):
             cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
             if self.layer_norm_after_embedding:
                 cat_embeddings = self.embedding_norm(cat_embeddings)
-        else:
-            cat_embeddings = None
 
         # Process numerical embeddings based on embedding_type
         if self.embedding_type == "plr":
@@ -163,26 +176,31 @@ def forward(self, num_features=None, cat_features=None):
                 num_embeddings = self.num_embeddings(num_features)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
-            else:
-                num_embeddings = None
         else:
             # For linear and ndt embeddings, handle each feature individually
             if self.num_embeddings and num_features is not None:
                 num_embeddings = [emb(num_features[i]) for i, emb in enumerate(self.num_embeddings)]  # type: ignore
                 num_embeddings = torch.stack(num_embeddings, dim=1)
                 if self.layer_norm_after_embedding:
                     num_embeddings = self.embedding_norm(num_embeddings)
+
+        if emb_features:
+            if self.embedding_projection:
+                emb_embeddings = [
+                    emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
+                ]
+                emb_embeddings = torch.stack(emb_embeddings, dim=1)
             else:
-                num_embeddings = None
-
-        # Combine categorical and numerical embeddings
-        if cat_embeddings is not None and num_embeddings is not None:
-
-            x = torch.cat([cat_embeddings, num_embeddings], dim=1)
-        elif cat_embeddings is not None:
-            x = cat_embeddings
-        elif num_embeddings is not None:
-            x = num_embeddings
+                emb_embeddings = torch.stack(emb_features, dim=1)
+            if self.layer_norm_after_embedding:
+                emb_embeddings = self.embedding_norm(emb_embeddings)
+
+        embeddings = [
+            e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None
+        ]
+
+        if embeddings:
+            x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
         else:
             raise ValueError("No features provided to the model.")
 
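
With this change, the tail of forward() gathers whichever embedding groups were produced and concatenates them along the token axis, replacing the old cat/num-only if/elif chain. A standalone sketch of that combination step, with placeholder random tensors and illustrative shapes (batch 32, d_model 128):

import torch

batch, d_model = 32, 128
cat_embeddings = torch.randn(batch, 3, d_model)  # e.g. 3 categorical tokens
num_embeddings = torch.randn(batch, 5, d_model)  # e.g. 5 numerical tokens
emb_embeddings = None                            # no pretrained embeddings this call

# Keep only the groups that exist, then concatenate along dim=1 (the token axis).
embeddings = [e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None]
if embeddings:
    x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
else:
    raise ValueError("No features provided to the model.")
print(x.shape)  # torch.Size([32, 8, 128])

Note that the no-projection path (embedding_projection=False) stacks the raw emb_features directly, so it implicitly requires every pretrained column to already be d_model wide; the projection path handles mismatched widths.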