@@ -174,6 +174,11 @@ def __init__(
174174 torch .zeros (1 , 1 , self .hparams .get ("d_model" , config .d_model ))
175175 )
176176
177+ if self .pooling_method == "cls" :
178+ self .use_cls = True
179+ else :
180+ self .use_cls = self .hparams .get ("use_cls" , config .use_cls )
181+
177182 if self .hparams .get ("layer_norm_after_embedding" ):
178183 self .embedding_norm = nn .LayerNorm (
179184 self .hparams .get ("d_model" , config .d_model )
@@ -198,10 +203,13 @@ def forward(self, num_features, cat_features):
198203 Tensor
199204 The output predictions of the model.
200205 """
201- batch_size = (
202- cat_features [0 ].size (0 ) if cat_features != [] else num_features [0 ].size (0 )
203- )
204- cls_tokens = self .cls_token .expand (batch_size , - 1 , - 1 )
206+ if self .use_cls :
207+ batch_size = (
208+ cat_features [0 ].size (0 )
209+ if cat_features != []
210+ else num_features [0 ].size (0 )
211+ )
212+ cls_tokens = self .cls_token .expand (batch_size , - 1 , - 1 )
205213
206214 if len (self .cat_embeddings ) > 0 and cat_features :
207215 cat_embeddings = [
@@ -225,11 +233,20 @@ def forward(self, num_features, cat_features):
225233 num_embeddings = None
226234
227235 if cat_embeddings is not None and num_embeddings is not None :
228- x = torch .cat ([cls_tokens , cat_embeddings , num_embeddings ], dim = 1 )
236+ if self .use_cls :
237+ x = torch .cat ([cat_embeddings , num_embeddings , cls_tokens ], dim = 1 )
238+ else :
239+ x = torch .cat ([cat_embeddings , num_embeddings ], dim = 1 )
229240 elif cat_embeddings is not None :
230- x = torch .cat ([cls_tokens , cat_embeddings ], dim = 1 )
241+ if self .use_cls :
242+ x = torch .cat ([cat_embeddings , cls_tokens ], dim = 1 )
243+ else :
244+ x = cat_embeddings
231245 elif num_embeddings is not None :
232- x = torch .cat ([cls_tokens , num_embeddings ], dim = 1 )
246+ if self .use_cls :
247+ x = torch .cat ([num_embeddings , cls_tokens ], dim = 1 )
248+ else :
249+ x = num_embeddings
233250 else :
234251 raise ValueError ("No features provided to the model." )
235252
@@ -242,7 +259,9 @@ def forward(self, num_features, cat_features):
242259 elif self .pooling_method == "sum" :
243260 x = torch .sum (x , dim = 1 )
244261 elif self .pooling_method == "cls_token" :
245- x = x [:, 0 ]
262+ x = x [:, - 1 ]
263+ elif self .pooling_method == "last" :
264+ x = x [:, - 1 ]
246265 else :
247266 raise ValueError (f"Invalid pooling method: { self .pooling_method } " )
248267
0 commit comments