Skip to content

Commit a4c5992

Browse files
committed
fix minor bugs related to imports and dim identification
1 parent b8bc5e9 commit a4c5992

3 files changed

Lines changed: 4 additions & 3 deletions

File tree

mambular/base_models/mambatab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
self.initial_layer = nn.Linear(input_dim, config.d_model)
7373
self.norm_f = LayerNorm(config.d_model)
7474

75-
self.embedding_activation = self.hparams.num_embedding_activation
75+
self.embedding_activation = self.hparams.embedding_activation
7676

7777
self.axis = config.axis
7878

mambular/base_models/saint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..arch_utils.transformer_utils import RowColTransformer
55
from ..configs.saint_config import DefaultSAINTConfig
66
from .basemodel import BaseModel
7+
import numpy as np
78

89

910
class SAINT(BaseModel):

mambular/base_models/tabtransformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ def __init__(
9393
)
9494

9595
mlp_input_dim = 0
96-
for feature_name, input_shape in num_feature_info.items():
97-
mlp_input_dim += input_shape
96+
for feature_name, info in num_feature_info.items():
97+
mlp_input_dim += info["dimension"]
9898
mlp_input_dim += self.hparams.d_model
9999

100100
self.tabular_head = MLPhead(

0 commit comments

Comments
 (0)