Skip to content

Commit b10ff52

Browse files
committed
adapt embedding layer to new preprocessing
1 parent c3e9c90 commit b10ff52

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,10 @@ def forward(self, num_features=None, cat_features=None):
141141
# Process categorical embeddings
142142
if self.cat_embeddings and cat_features is not None:
143143
cat_embeddings = [
144-
emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
144+
emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
145+
for i, emb in enumerate(self.cat_embeddings)
145146
]
147+
146148
cat_embeddings = torch.stack(cat_embeddings, dim=1)
147149
cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
148150
if self.layer_norm_after_embedding:
@@ -175,6 +177,7 @@ def forward(self, num_features=None, cat_features=None):
175177

176178
# Combine categorical and numerical embeddings
177179
if cat_embeddings is not None and num_embeddings is not None:
180+
178181
x = torch.cat([cat_embeddings, num_embeddings], dim=1)
179182
elif cat_embeddings is not None:
180183
x = cat_embeddings

mambular/preprocessing/preprocessor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,9 @@ def fit(self, X, y=None):
464464
)
465465

466466
elif feature_preprocessing == "box-cox":
467+
numeric_transformer_steps.append(
468+
("minmax", MinMaxScaler(feature_range=(1e-03, 1)))
469+
)
467470
numeric_transformer_steps.append(
468471
(
469472
"box-cox",

0 commit comments

Comments
 (0)