@@ -33,7 +33,11 @@ def save_hyperparameters(self, ignore=[]):
3333 List of keys to ignore while saving hyperparameters, by default [].
3434 """
3535 # Filter the config and extra hparams for ignored keys
36- config_hparams = {k : v for k , v in vars (self .config ).items () if k not in ignore } if self .config else {}
36+ config_hparams = (
37+ {k : v for k , v in vars (self .config ).items () if k not in ignore }
38+ if self .config
39+ else {}
40+ )
3741 extra_hparams = {k : v for k , v in self .extra_hparams .items () if k not in ignore }
3842 config_hparams .update (extra_hparams )
3943
@@ -148,7 +152,9 @@ def initialize_pooling_layers(self, config, n_inputs):
148152 """Initializes the layers needed for learnable pooling methods based on self.hparams.pooling_method."""
149153 if self .hparams .pooling_method == "learned_flatten" :
150154 # Flattening + Linear layer
151- self .learned_flatten_pooling = nn .Linear (n_inputs * config .dim_feedforward , config .dim_feedforward )
155+ self .learned_flatten_pooling = nn .Linear (
156+ n_inputs * config .dim_feedforward , config .dim_feedforward
157+ )
152158
153159 elif self .hparams .pooling_method == "attention" :
154160 # Attention-based pooling with learnable attention weights
@@ -216,3 +222,29 @@ def pool_sequence(self, out):
216222 return out
217223 else :
218224 raise ValueError (f"Invalid pooling method: { self .hparams .pooling_method } " )
225+
def encode(self, num_features, cat_features):
    """Return contextualized embeddings for the given features.

    Embeds the raw inputs with ``self.embedding_layer`` and passes them
    through the first available sequence layer, searched in the order
    ``mamba``, ``rnn``, ``lstm``, ``encoder``.

    Parameters
    ----------
    num_features :
        Numerical feature inputs, forwarded to the embedding layer.
    cat_features :
        Categorical feature inputs, forwarded to the embedding layer.

    Returns
    -------
    The contextualized embedding sequence produced by the layer.

    Raises
    ------
    ValueError
        If the model has no ``embedding_layer`` attribute, or none of the
        supported contextualizing layers is present.
    """
    if not hasattr(self, "embedding_layer"):
        raise ValueError("The model does not have an embedding layer")

    # Check if at least one of the contextualized embedding methods exists
    valid_layers = ["mamba", "rnn", "lstm", "encoder"]
    available_layer = next(
        (attr for attr in valid_layers if hasattr(self, attr)), None
    )

    if not available_layer:
        raise ValueError("The model does not generate contextualized embeddings")

    # Embed the raw inputs before contextualizing them
    x = self.embedding_layer(num_features=num_features, cat_features=cat_features)

    # Optionally permute the sequence dimension; assumes self.perm was set
    # up wherever shuffle_embeddings was enabled — TODO confirm with caller.
    if getattr(self.hparams, "shuffle_embeddings", False):
        x = x[:, self.perm, :]

    # Get the actual layer and call it.
    layer = getattr(self, available_layer)
    out = layer(x)
    # Recurrent layers (nn.RNN / nn.LSTM) return (output, hidden_state);
    # keep only the output sequence. The previous code unpacked only the
    # "rnn" case, so an "lstm" layer leaked its (output, (h, c)) tuple.
    if isinstance(out, tuple):
        out = out[0]
    return out
0 commit comments