Skip to content

Commit c3d4c01

Browse files
committed
include cls token at end of sequence
1 parent a9f5e4c commit c3d4c01

1 file changed

Lines changed: 27 additions & 8 deletions

File tree

mambular/base_models/mambular.py

Lines changed: 27 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -174,6 +174,11 @@ def __init__(
174174
torch.zeros(1, 1, self.hparams.get("d_model", config.d_model))
175175
)
176176

177+
if self.pooling_method == "cls":
178+
self.use_cls = True
179+
else:
180+
self.use_cls = self.hparams.get("use_cls", config.use_cls)
181+
177182
if self.hparams.get("layer_norm_after_embedding"):
178183
self.embedding_norm = nn.LayerNorm(
179184
self.hparams.get("d_model", config.d_model)
@@ -198,10 +203,13 @@ def forward(self, num_features, cat_features):
198203
Tensor
199204
The output predictions of the model.
200205
"""
201-
batch_size = (
202-
cat_features[0].size(0) if cat_features != [] else num_features[0].size(0)
203-
)
204-
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
206+
if self.use_cls:
207+
batch_size = (
208+
cat_features[0].size(0)
209+
if cat_features != []
210+
else num_features[0].size(0)
211+
)
212+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
205213

206214
if len(self.cat_embeddings) > 0 and cat_features:
207215
cat_embeddings = [
@@ -225,11 +233,20 @@ def forward(self, num_features, cat_features):
225233
num_embeddings = None
226234

227235
if cat_embeddings is not None and num_embeddings is not None:
228-
x = torch.cat([cls_tokens, cat_embeddings, num_embeddings], dim=1)
236+
if self.use_cls:
237+
x = torch.cat([cat_embeddings, num_embeddings, cls_tokens], dim=1)
238+
else:
239+
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
229240
elif cat_embeddings is not None:
230-
x = torch.cat([cls_tokens, cat_embeddings], dim=1)
241+
if self.use_cls:
242+
x = torch.cat([cat_embeddings, cls_tokens], dim=1)
243+
else:
244+
x = cat_embeddings
231245
elif num_embeddings is not None:
232-
x = torch.cat([cls_tokens, num_embeddings], dim=1)
246+
if self.use_cls:
247+
x = torch.cat([num_embeddings, cls_tokens], dim=1)
248+
else:
249+
x = num_embeddings
233250
else:
234251
raise ValueError("No features provided to the model.")
235252

@@ -242,7 +259,9 @@ def forward(self, num_features, cat_features):
242259
elif self.pooling_method == "sum":
243260
x = torch.sum(x, dim=1)
244261
elif self.pooling_method == "cls_token":
245-
x = x[:, 0]
262+
x = x[:, -1]
263+
elif self.pooling_method == "last":
264+
x = x[:, -1]
246265
else:
247266
raise ValueError(f"Invalid pooling method: {self.pooling_method}")
248267

0 commit comments

Comments
 (0)