Skip to content

Commit 72ebc7b

Browse files
authored
Merge pull request #258 from bishnukhadka/tabr
TabR integration + compatibility check for PLR embedding with pre-processing types.
2 parents 2957f09 + f0477f2 commit 72ebc7b

9 files changed

Lines changed: 631 additions & 193 deletions

File tree

.gitignore

Lines changed: 0 additions & 176 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ Mambular is a Python package that brings the power of advanced deep learning arc
8383
| `Trompt` | Trompt: Towards a Better Deep Neural Network for Tabular Data introduced [here](https://arxiv.org/abs/2305.18446). |
8484
| `Tangos` | Tangos: Regularizing Tabular Neural Networks through Gradient Orthogonalization and Specialization introduced [here](https://openreview.net/pdf?id=n6H86gW8u0d). |
8585
| `ModernNCA` | Revisiting Nearest Neighbor for Tabular Data: A Deep Tabular Baseline Two Decades Later introduced [here](https://arxiv.org/abs/2407.03257). |
86+
| `TabR` | TabR: Tabular Deep Learning Meets Nearest Neighbors in 2023 introduced [here](https://arxiv.org/abs/2307.14338). |
8687

8788

8889

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
125125
if self.layer_norm_after_embedding:
126126
self.embedding_norm = nn.LayerNorm(self.d_model)
127127

128+
self.feature_info = (num_feature_info, cat_feature_info, emb_feature_info)
129+
128130
def forward(self, num_features, cat_features, emb_features):
129131
"""Defines the forward pass of the model.
130132
@@ -171,6 +173,8 @@ def forward(self, num_features, cat_features, emb_features):
171173

172174
# Process numerical embeddings based on embedding_type
173175
if self.embedding_type == "plr":
176+
# check pre-processing type compatibility with plr
177+
self.check_plr_embedding_compatibility(self.feature_info)
174178
# For PLR, pass all numerical features together
175179
if num_features is not None:
176180
num_features = torch.stack(num_features, dim=1).squeeze(
@@ -226,6 +230,21 @@ def forward(self, num_features, cat_features, emb_features):
226230
x = self.embedding_dropout(x)
227231

228232
return x
233+
234+
def check_plr_embedding_compatibility(self, feature_info:tuple):
    """Validate that no feature's preprocessing conflicts with PLR embeddings.

    Parameters
    ----------
    feature_info : tuple
        Tuple of feature-info dicts (numerical, categorical, embedding); each
        maps a feature name to a properties dict that may carry a
        ``'preprocessing'`` string describing how the feature was transformed.

    Raises
    ------
    ValueError
        If any feature's preprocessing string contains a method known to be
        incompatible with the PLR embedding type.
    """
    # Preprocessing methods that cannot be combined with PLR embeddings.
    blocked = ['ple', 'one-hot', 'polynomial', 'splines', 'sigmoid', 'rbf']

    # Walk every feature of every group and stop at the first conflict found.
    for info_dict in feature_info:
        for props in info_dict.values():
            method = props.get('preprocessing', '')
            offending = next((b for b in blocked if b in method), None)
            if offending is not None:
                raise ValueError(f"PLR embedding type doesn't work with the '{offending}' pre-processing method.\n")
229248

230249

231250
class OneHotEncoding(nn.Module):

mambular/base_models/modern_nca.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(
2222
self.save_hyperparameters(ignore=["feature_information"])
2323

2424
self.returns_ensemble = False
25-
self.uses_nca_candidates = True
25+
self.uses_candidates = True
2626

2727
self.T = config.temperature
2828
self.sample_rate = config.sample_rate
@@ -31,6 +31,7 @@ def __init__(
3131
*feature_information,
3232
config=config,
3333
)
34+
3435
input_dim = np.sum(
3536
[len(info) * self.hparams.d_model for info in feature_information]
3637
)
@@ -75,7 +76,7 @@ def forward(self, *data):
7576
x = self.post_encoder(x)
7677
return self.tabular_head(x)
7778

78-
def nca_train(self, *data, targets, candidate_x, candidate_y):
79+
def train_with_candidates(self, *data, targets, candidate_x, candidate_y):
7980
"""NCA-style training forward pass selecting candidates."""
8081
if self.hparams.use_embeddings:
8182
x = self.embedding_layer(*data)
@@ -85,6 +86,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
8586
B, S, D = candidate_x.shape
8687
candidate_x = candidate_x.reshape(B, S * D)
8788
else:
89+
8890
x = torch.cat([t for tensors in data for t in tensors], dim=1)
8991
candidate_x = torch.cat(
9092
[t for tensors in candidate_x for t in tensors], dim=1
@@ -129,7 +131,7 @@ def nca_train(self, *data, targets, candidate_x, candidate_y):
129131

130132
return logits
131133

132-
def nca_validate(self, *data, candidate_x, candidate_y):
134+
def validate_with_candidates(self, *data, candidate_x, candidate_y):
133135
"""Validation forward pass with NCA-style candidate selection."""
134136
if self.hparams.use_embeddings:
135137
x = self.embedding_layer(*data)
@@ -172,7 +174,7 @@ def nca_validate(self, *data, candidate_x, candidate_y):
172174

173175
return logits
174176

175-
def nca_predict(self, *data, candidate_x, candidate_y):
177+
def predict_with_candidates(self, *data, candidate_x, candidate_y):
176178
"""Prediction forward pass with candidate selection."""
177179
if self.hparams.use_embeddings:
178180
x = self.embedding_layer(*data)

0 commit comments

Comments
 (0)