Skip to content

Commit d413fd8

Browse files
committed
adding option to one-hot encode cat features in embedding layer
1 parent 029c6d8 commit d413fd8

3 files changed

Lines changed: 14 additions & 5 deletions

File tree

mambular/base_models/ft_transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,12 @@ def __init__(
132132
embedding_activation=self.hparams.get(
133133
"embedding_activation", config.embedding_activation
134134
),
135-
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
135+
layer_norm_after_embedding=self.hparams.get(
136+
"layer_norm_after_embedding", config.layer_norm_after_embedding
137+
),
136138
use_cls=True,
137139
cls_position=0,
140+
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
138141
)
139142

140143
head_activation = self.hparams.get("head_activation", config.head_activation)

mambular/base_models/mambular.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,12 @@ def __init__(
150150
embedding_activation=self.hparams.get(
151151
"embedding_activation", config.embedding_activation
152152
),
153-
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
154-
use_cls=True,
155-
cls_position=0,
153+
layer_norm_after_embedding=self.hparams.get(
154+
"layer_norm_after_embedding", config.layer_norm_after_embedding
155+
),
156+
use_cls=False,
157+
cls_position=-1,
158+
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
156159
)
157160

158161
head_activation = self.hparams.get("head_activation", config.head_activation)

mambular/base_models/tabtransformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,12 @@ def __init__(
139139
embedding_activation=self.hparams.get(
140140
"embedding_activation", config.embedding_activation
141141
),
142-
layer_norm_after_embedding=self.hparams.get("layer_norm_after_embedding"),
142+
layer_norm_after_embedding=self.hparams.get(
143+
"layer_norm_after_embedding", config.layer_norm_after_embedding
144+
),
143145
use_cls=True,
144146
cls_position=0,
147+
cat_encoding=self.hparams.get("cat_encoding", config.cat_encoding),
145148
)
146149

147150
head_activation = self.hparams.get("head_activation", config.head_activation)

0 commit comments

Comments
 (0)