Skip to content

Commit d4fbeb5

Browse files
committed
formatting
1 parent 79718f7 commit d4fbeb5

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,11 @@ def forward(self, num_features, cat_features, emb_features):
156156
# Process categorical embeddings
157157
if self.cat_embeddings and cat_features is not None:
158158
cat_embeddings = [
159-
emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
159+
(
160+
emb(cat_features[i])
161+
if emb(cat_features[i]).ndim == 3
162+
else emb(cat_features[i]).unsqueeze(1)
163+
)
160164
for i, emb in enumerate(self.cat_embeddings)
161165
]
162166

mambular/base_models/tabm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ def __init__(
2222
# Pass config to BaseModel
2323
super().__init__(config=config, **kwargs)
2424

25-
26-
2725
# Save hparams including config attributes
2826
self.save_hyperparameters(ignore=["feature_information"])
2927
if not self.hparams.average_ensembles:

0 commit comments

Comments
 (0)