@@ -22,7 +22,7 @@ def __init__(
2222 self .save_hyperparameters (ignore = ["feature_information" ])
2323
2424 self .returns_ensemble = False
25- self .uses_nca_candidates = True
25+ self .uses_candidates = True
2626
2727 self .T = config .temperature
2828 self .sample_rate = config .sample_rate
@@ -31,6 +31,7 @@ def __init__(
3131 * feature_information ,
3232 config = config ,
3333 )
34+ print (self .embedding_layer )
3435 input_dim = np .sum (
3536 [len (info ) * self .hparams .d_model for info in feature_information ]
3637 )
@@ -75,7 +76,7 @@ def forward(self, *data):
7576 x = self .post_encoder (x )
7677 return self .tabular_head (x )
7778
78- def nca_train (self , * data , targets , candidate_x , candidate_y ):
79+ def train_with_candidates (self , * data , targets , candidate_x , candidate_y ):
7980 """NCA-style training forward pass selecting candidates."""
8081 if self .hparams .use_embeddings :
8182 x = self .embedding_layer (* data )
@@ -85,6 +86,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
8586 B , S , D = candidate_x .shape
8687 candidate_x = candidate_x .reshape (B , S * D )
8788 else :
89+
8890 x = torch .cat ([t for tensors in data for t in tensors ], dim = 1 )
8991 candidate_x = torch .cat (
9092 [t for tensors in candidate_x for t in tensors ], dim = 1
@@ -129,7 +131,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
129131
130132 return logits
131133
132- def nca_validate (self , * data , candidate_x , candidate_y ):
134+ def validate_with_candidates (self , * data , candidate_x , candidate_y ):
133135 """Validation forward pass with NCA-style candidate selection."""
134136 if self .hparams .use_embeddings :
135137 x = self .embedding_layer (* data )
@@ -172,7 +174,7 @@ def nca_validate(self, *data, candidate_x, candidate_y):
172174
173175 return logits
174176
175- def nca_predict (self , * data , candidate_x , candidate_y ):
177+ def predict_with_candidates (self , * data , candidate_x , candidate_y ):
176178 """Prediction forward pass with candidate selection."""
177179 if self .hparams .use_embeddings :
178180 x = self .embedding_layer (* data )
0 commit comments