@@ -54,20 +54,17 @@ class NDTF(BaseModel):
5454
5555 def __init__ (
5656 self ,
57- cat_feature_info ,
58- num_feature_info ,
57+ feature_information : tuple , # Expecting (num_feature_info, cat_feature_info, embedding_feature_info)
5958 num_classes : int = 1 ,
6059 config : DefaultNDTFConfig = DefaultNDTFConfig (), # noqa: B008
6160 ** kwargs ,
6261 ):
6362 super ().__init__ (config = config , ** kwargs )
64- self .save_hyperparameters (ignore = ["cat_feature_info" , "num_feature_info " ])
63+ self .save_hyperparameters (ignore = ["feature_information " ])
6564
66- self .cat_feature_info = cat_feature_info
67- self .num_feature_info = num_feature_info
6865 self .returns_ensemble = False
6966
70- input_dim = get_feature_dimensions (num_feature_info , cat_feature_info )
67+ input_dim = get_feature_dimensions (* feature_information )
7168
7269 self .input_dimensions = [input_dim ]
7370
@@ -78,10 +75,13 @@ def __init__(
7875 [
7976 NeuralDecisionTree (
8077 input_dim = self .input_dimensions [idx ],
81- depth = np .random .randint (self .hparams .min_depth , self .hparams .max_depth ),
78+ depth = np .random .randint (
79+ self .hparams .min_depth , self .hparams .max_depth
80+ ),
8281 output_dim = num_classes ,
8382 lamda = self .hparams .lamda ,
84- temperature = self .hparams .temperature + np .abs (np .random .normal (0 , 0.1 )),
83+ temperature = self .hparams .temperature
84+ + np .abs (np .random .normal (0 , 0.1 )),
8585 node_sampling = self .hparams .node_sampling ,
8686 )
8787 for idx in range (self .hparams .n_ensembles )
@@ -103,21 +103,20 @@ def __init__(
103103 requires_grad = True ,
104104 )
105105
106- def forward (self , num_features , cat_features ) -> torch .Tensor :
106+ def forward (self , * data ) -> torch .Tensor :
107107 """Forward pass of the NDTF model.
108108
109109 Parameters
110110 ----------
111- x : torch.Tensor
112- Input tensor .
111+ data : tuple
112+ Input tuple of tensors of num_features, cat_features, embeddings .
113113
114114 Returns
115115 -------
116116 torch.Tensor
117117 Output tensor.
118118 """
119- x = num_features + cat_features
120- x = torch .cat (x , dim = 1 )
119+ x = torch .cat ([t for tensors in data for t in tensors ], dim = 1 )
121120 x = self .conv_layer (x .unsqueeze (2 ))
122121 x = x .transpose (1 , 2 ).squeeze (- 1 )
123122
@@ -131,21 +130,20 @@ def forward(self, num_features, cat_features) -> torch.Tensor:
131130
132131 return preds @ self .tree_weights
133132
134- def penalty_forward (self , num_features , cat_features ) -> torch .Tensor :
133+ def penalty_forward (self , * data ) -> torch .Tensor :
135134 """Forward pass of the NDTF model.
136135
137136 Parameters
138137 ----------
139- x : torch.Tensor
140- Input tensor .
138+ data : tuple
139+ Input tuple of tensors of num_features, cat_features, embeddings .
141140
142141 Returns
143142 -------
144143 torch.Tensor
145144 Output tensor.
146145 """
147- x = num_features + cat_features
148- x = torch .cat (x , dim = 1 )
146+ x = torch .cat ([t for tensors in data for t in tensors ], dim = 1 )
149147 x = self .conv_layer (x .unsqueeze (2 ))
150148 x = x .transpose (1 , 2 ).squeeze (- 1 )
151149
0 commit comments