Skip to content

Commit fcb17c1

Browse files
committed
include ReGLU in TabTransformer
1 parent c3d4c01 commit fcb17c1

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

mambular/base_models/tabtransformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from ..configs.tabtransformer_config import DefaultTabTransformerConfig
1313
from .basemodel import BaseModel
14+
from ..arch_utils.transformer_utils import CustomTransformerEncoderLayer
1415

1516

1617
class TabTransformer(BaseModel):
@@ -91,7 +92,7 @@ def __init__(
9192
"num_embedding_activation", config.num_embedding_activation
9293
)
9394

94-
encoder_layer = nn.TransformerEncoderLayer(
95+
encoder_layer = CustomTransformerEncoderLayer(
9596
d_model=self.hparams.get("d_model", config.d_model),
9697
nhead=self.hparams.get("n_heads", config.n_heads),
9798
batch_first=True,

0 commit comments

Comments
 (0)