@@ -101,20 +101,21 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
101101 ]
102102 )
103103
104- if self .embedding_projection :
105- self .emb_embeddings = nn .ModuleList (
106- [
107- nn .Sequential (
108- nn .Linear (
109- feature_info ["dimension" ],
110- self .d_model ,
111- bias = self .embedding_bias ,
112- ),
113- self .embedding_activation ,
114- )
115- for feature_name , feature_info in emb_feature_info .items ()
116- ]
117- )
104+ if len (emb_feature_info ) >= 1 :
105+ if self .embedding_projection :
106+ self .emb_embeddings = nn .ModuleList (
107+ [
108+ nn .Sequential (
109+ nn .Linear (
110+ feature_info ["dimension" ],
111+ self .d_model ,
112+ bias = self .embedding_bias ,
113+ ),
114+ self .embedding_activation ,
115+ )
116+ for feature_name , feature_info in emb_feature_info .items ()
117+ ]
118+ )
118119
119120 # Class token if required
120121 if self .use_cls :
@@ -181,15 +182,16 @@ def forward(self, num_features, cat_features, emb_features):
181182 if self .layer_norm_after_embedding :
182183 num_embeddings = self .embedding_norm (num_embeddings )
183184
184- if self .embedding_projection :
185- emb_embeddings = [
186- emb (emb_features [i ]) for i , emb in enumerate (self .emb_embeddings )
187- ]
188- emb_embeddings = torch .stack (emb_embeddings , dim = 1 )
189- else :
190- emb_embeddings = torch .stack (emb_features , dim = 1 )
191- if self .layer_norm_after_embedding :
192- emb_embeddings = self .embedding_norm (emb_embeddings )
185+ if emb_features != []:
186+ if self .embedding_projection :
187+ emb_embeddings = [
188+ emb (emb_features [i ]) for i , emb in enumerate (self .emb_embeddings )
189+ ]
190+ emb_embeddings = torch .stack (emb_embeddings , dim = 1 )
191+ else :
192+ emb_embeddings = torch .stack (emb_features , dim = 1 )
193+ if self .layer_norm_after_embedding :
194+ emb_embeddings = self .embedding_norm (emb_embeddings )
193195
194196 embeddings = [
195197 e for e in [cat_embeddings , num_embeddings , emb_embeddings ] if e is not None
0 commit comments