Skip to content

Commit af2a565

Browse files
committed
tabR integration
1 parent 2957f09 commit af2a565

7 files changed

Lines changed: 589 additions & 192 deletions

File tree

.gitignore

Lines changed: 0 additions & 176 deletions
This file was deleted.

mambular/base_models/modern_nca.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
self.save_hyperparameters(ignore=["feature_information"])
2323

2424
self.returns_ensemble = False
25-
self.uses_nca_candidates = True
25+
self.uses_candidates = True
2626

2727
self.T = config.temperature
2828
self.sample_rate = config.sample_rate
@@ -31,6 +31,7 @@ def __init__(
3131
*feature_information,
3232
config=config,
3333
)
34+
print(self.embedding_layer)
3435
input_dim = np.sum(
3536
[len(info) * self.hparams.d_model for info in feature_information]
3637
)
@@ -75,7 +76,7 @@ def forward(self, *data):
7576
x = self.post_encoder(x)
7677
return self.tabular_head(x)
7778

78-
def nca_train(self, *data, targets, candidate_x, candidate_y):
79+
def train_with_candidates(self, *data, targets, candidate_x, candidate_y):
7980
"""NCA-style training forward pass selecting candidates."""
8081
if self.hparams.use_embeddings:
8182
x = self.embedding_layer(*data)
@@ -85,6 +86,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
8586
B, S, D = candidate_x.shape
8687
candidate_x = candidate_x.reshape(B, S * D)
8788
else:
89+
8890
x = torch.cat([t for tensors in data for t in tensors], dim=1)
8991
candidate_x = torch.cat(
9092
[t for tensors in candidate_x for t in tensors], dim=1
@@ -129,7 +131,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
129131

130132
return logits
131133

132-
def nca_validate(self, *data, candidate_x, candidate_y):
134+
def validate_with_candidates(self, *data, candidate_x, candidate_y):
133135
"""Validation forward pass with NCA-style candidate selection."""
134136
if self.hparams.use_embeddings:
135137
x = self.embedding_layer(*data)
@@ -172,7 +174,7 @@ def nca_validate(self, *data, candidate_x, candidate_y):
172174

173175
return logits
174176

175-
def nca_predict(self, *data, candidate_x, candidate_y):
177+
def predict_with_candidates(self, *data, candidate_x, candidate_y):
176178
"""Prediction forward pass with candidate selection."""
177179
if self.hparams.use_embeddings:
178180
x = self.embedding_layer(*data)

0 commit comments

Comments
 (0)