File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -72,7 +72,7 @@ def __init__(
7272 self .initial_layer = nn .Linear (input_dim , config .d_model )
7373 self .norm_f = LayerNorm (config .d_model )
7474
75- self .embedding_activation = self .hparams .num_embedding_activation
75+ self .embedding_activation = self .hparams .embedding_activation
7676
7777 self .axis = config .axis
7878
Original file line number Diff line number Diff line change 44from ..arch_utils .transformer_utils import RowColTransformer
55from ..configs .saint_config import DefaultSAINTConfig
66from .basemodel import BaseModel
7+ import numpy as np
78
89
910class SAINT (BaseModel ):
Original file line number Diff line number Diff line change @@ -93,8 +93,8 @@ def __init__(
9393 )
9494
9595 mlp_input_dim = 0
96- for feature_name , input_shape in num_feature_info .items ():
97- mlp_input_dim += input_shape
96+ for feature_name , info in num_feature_info .items ():
97+ mlp_input_dim += info [ "dimension" ]
9898 mlp_input_dim += self .hparams .d_model
9999
100100 self .tabular_head = MLPhead (
You can’t perform that action at this time.
0 commit comments