
Commit a2c7845

committed: adapt lightning layer and preprocessor to account for no passed embeddings
1 parent 4ec70f8 · commit a2c7845

2 files changed · 26 additions & 24 deletions


mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 25 additions & 23 deletions
@@ -101,20 +101,21 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
             ]
         )
 
-        if self.embedding_projection:
-            self.emb_embeddings = nn.ModuleList(
-                [
-                    nn.Sequential(
-                        nn.Linear(
-                            feature_info["dimension"],
-                            self.d_model,
-                            bias=self.embedding_bias,
-                        ),
-                        self.embedding_activation,
-                    )
-                    for feature_name, feature_info in emb_feature_info.items()
-                ]
-            )
+        if len(emb_feature_info) >= 1:
+            if self.embedding_projection:
+                self.emb_embeddings = nn.ModuleList(
+                    [
+                        nn.Sequential(
+                            nn.Linear(
+                                feature_info["dimension"],
+                                self.d_model,
+                                bias=self.embedding_bias,
+                            ),
+                            self.embedding_activation,
+                        )
+                        for feature_name, feature_info in emb_feature_info.items()
+                    ]
+                )
 
         # Class token if required
         if self.use_cls:
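A minimal sketch of what the new constructor guard buys (names, dimensions, and the empty-dict value here are hypothetical, not taken from mambular): when the preprocessor reports no embedding features, the per-feature projection modules are never built, so the layer registers no parameters for features that do not exist.

    import torch.nn as nn

    d_model = 64
    emb_feature_info = {}  # hypothetical: preprocessor found no embedding columns

    emb_embeddings = None
    if len(emb_feature_info) >= 1:
        emb_embeddings = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(info["dimension"], d_model), nn.ReLU())
                for name, info in emb_feature_info.items()
            ]
        )

    print(emb_embeddings)  # None: no projection modules were registered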
@@ -181,15 +182,16 @@ def forward(self, num_features, cat_features, emb_features):
         if self.layer_norm_after_embedding:
             num_embeddings = self.embedding_norm(num_embeddings)
 
-        if self.embedding_projection:
-            emb_embeddings = [
-                emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
-            ]
-            emb_embeddings = torch.stack(emb_embeddings, dim=1)
-        else:
-            emb_embeddings = torch.stack(emb_features, dim=1)
-        if self.layer_norm_after_embedding:
-            emb_embeddings = self.embedding_norm(emb_embeddings)
+        if emb_features != []:
+            if self.embedding_projection:
+                emb_embeddings = [
+                    emb(emb_features[i]) for i, emb in enumerate(self.emb_embeddings)
+                ]
+                emb_embeddings = torch.stack(emb_embeddings, dim=1)
+            else:
+                emb_embeddings = torch.stack(emb_features, dim=1)
+            if self.layer_norm_after_embedding:
+                emb_embeddings = self.embedding_norm(emb_embeddings)
 
         embeddings = [
             e for e in [cat_embeddings, num_embeddings, emb_embeddings] if e is not None
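Why the matching forward-path guard is needed: torch.stack rejects an empty tensor list, so the old unguarded path raised at runtime whenever a batch arrived with emb_features=[]. A minimal, hypothetical reproduction:

    import torch

    emb_features = []  # hypothetical batch with no precomputed embeddings

    # The old, unguarded path fails on an empty list:
    try:
        torch.stack(emb_features, dim=1)
    except RuntimeError as err:
        print(err)  # "stack expects a non-empty TensorList"

    # The new guard simply skips the embedding branch for such a batch:
    if emb_features != []:
        emb_embeddings = torch.stack(emb_features, dim=1)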

mambular/preprocessing/preprocessor.py

Lines changed: 1 addition & 1 deletion
@@ -727,7 +727,7 @@ def get_feature_info(self, verbose=True):
                 "categories": None,
             }
         else:
-            embedding_feature_info = None
+            embedding_feature_info = {}
 
         if not self.column_transformer:
             raise RuntimeError("The preprocessor has not been fitted yet.")
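Returning {} rather than None keeps the preprocessor consistent with the new guard in embedding_layer.py: both len(...) and .items() work on an empty dict but raise on None. An illustrative sketch (standalone, not repo code):

    embedding_feature_info = None
    try:
        len(embedding_feature_info)  # the len(...) >= 1 guard would hit this
    except TypeError as err:
        print(err)  # object of type 'NoneType' has no len()

    embedding_feature_info = {}
    print(len(embedding_feature_info) >= 1)      # False: guard skipped cleanly
    print(list(embedding_feature_info.items()))  # []: safe to iterate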
