We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c3d4c01 commit fcb17c1Copy full SHA for fcb17c1
1 file changed
mambular/base_models/tabtransformer.py
@@ -11,6 +11,7 @@
11
)
12
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
13
from .basemodel import BaseModel
14
+from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
15
16
17
class TabTransformer(BaseModel):
@@ -91,7 +92,7 @@ def __init__(
91
92
"num_embedding_activation", config.num_embedding_activation
93
94
- encoder_layer = nn.TransformerEncoderLayer(
95
+ encoder_layer = CustomTransformerEncoderLayer(
96
d_model=self.hparams.get("d_model", config.d_model),
97
nhead=self.hparams.get("n_heads", config.n_heads),
98
batch_first=True,
0 commit comments