Skip to content

Commit fac6a1f

Browse files
committed
include encoding function to create embeddings
1 parent 50a3883 commit fac6a1f

1 file changed

Lines changed: 34 additions & 2 deletions

File tree

mambular/base_models/basemodel.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]):
3333
List of keys to ignore while saving hyperparameters, by default [].
3434
"""
3535
# Filter the config and extra hparams for ignored keys
36-
config_hparams = {k: v for k, v in vars(self.config).items() if k not in ignore} if self.config else {}
36+
config_hparams = (
37+
{k: v for k, v in vars(self.config).items() if k not in ignore}
38+
if self.config
39+
else {}
40+
)
3741
extra_hparams = {k: v for k, v in self.extra_hparams.items() if k not in ignore}
3842
config_hparams.update(extra_hparams)
3943

@@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs):
148152
"""Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method."""
149153
if self.hparams.pooling_method == "learned_flatten":
150154
# Flattening + Linear layer
151-
self.learned_flatten_pooling = nn.Linear(n_inputs * config.dim_feedforward, config.dim_feedforward)
155+
self.learned_flatten_pooling = nn.Linear(
156+
n_inputs * config.dim_feedforward, config.dim_feedforward
157+
)
152158

153159
elif self.hparams.pooling_method == "attention":
154160
# Attention-based pooling with learnable attention weights
@@ -216,3 +222,29 @@ def pool_sequence(self, out):
216222
return out
217223
else:
218224
raise ValueError(f"Invalid pooling method: {self.hparams.pooling_method}")
225+
226+
def encode(self, num_features, cat_features):
227+
if not hasattr(self, "embedding_layer"):
228+
raise ValueError("The model does not have an embedding layer")
229+
230+
# Check if at least one of the contextualized embedding methods exists
231+
valid_layers = ["mamba", "rnn", "lstm", "encoder"]
232+
available_layer = next(
233+
(attr for attr in valid_layers if hasattr(self, attr)), None
234+
)
235+
236+
if not available_layer:
237+
raise ValueError("The model does not generate contextualized embeddings")
238+
239+
# Get the actual layer and call it
240+
x = self.embedding_layer(num_features=num_features, cat_features=cat_features)
241+
242+
if getattr(self.hparams, "shuffle_embeddings", False):
243+
x = x[:, self.perm, :]
244+
245+
layer = getattr(self, available_layer)
246+
if available_layer == "rnn":
247+
embeddings, _ = layer(x)
248+
else:
249+
embeddings = layer(x)
250+
return embeddings

0 commit comments

Comments
 (0)