Skip to content

Commit 44d05ac

Browse files
committed
TabR integration + compatibility checking of pre-processing method with plr embedding
1 parent af2a565 commit 44d05ac

2 files changed

Lines changed: 21 additions & 1 deletion

File tree

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,8 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
125125
if self.layer_norm_after_embedding:
126126
self.embedding_norm = nn.LayerNorm(self.d_model)
127127

128+
self.feature_info = (num_feature_info, cat_feature_info, emb_feature_info)
129+
128130
def forward(self, num_features, cat_features, emb_features):
129131
"""Defines the forward pass of the model.
130132
@@ -171,6 +173,8 @@ def forward(self, num_features, cat_features, emb_features):
171173

172174
# Process numerical embeddings based on embedding_type
173175
if self.embedding_type == "plr":
176+
# check pre-processing type compatibility with plr
177+
self.check_plr_embedding_compatibility(self.feature_info)
174178
# For PLR, pass all numerical features together
175179
if num_features is not None:
176180
num_features = torch.stack(num_features, dim=1).squeeze(
@@ -226,6 +230,21 @@ def forward(self, num_features, cat_features, emb_features):
226230
x = self.embedding_dropout(x)
227231

228232
return x
233+
234+
def check_plr_embedding_compatibility(self, feature_info: tuple):
    """Validate that no feature's preprocessing conflicts with PLR embeddings.

    Scans every feature-info dictionary in `feature_info` and raises if any
    feature was preprocessed with a method known to be incompatible with the
    PLR (periodic-linear-ReLU) numerical embedding type.

    Parameters
    ----------
    feature_info : tuple
        Tuple of dicts (e.g. numerical / categorical / embedding feature
        info), each mapping a feature name to a properties dict that may
        carry a 'preprocessing' string.

    Raises
    ------
    ValueError
        If any feature's 'preprocessing' string contains one of the
        incompatible method names.
    """
    # Preprocessing method names that cannot be combined with PLR embeddings.
    blocked = ('ple', 'one-hot', 'polynomial', 'splines', 'sigmoid', 'rbf')

    for group in feature_info:
        for props in group.values():
            # Missing 'preprocessing' keys are treated as compatible.
            method = props.get('preprocessing', '')
            # Substring match mirrors how preprocessing names are recorded
            # (e.g. variants like 'one-hot-dense' should also be rejected).
            hit = next((term for term in blocked if term in method), None)
            if hit is not None:
                raise ValueError(f"PLR embedding type doesn't work with the '{hit}' pre-processing method.\n")
229248

230249

231250
class OneHotEncoding(nn.Module):

mambular/base_models/utils/lightning_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ def validation_step(self, batch, batch_idx): # type: ignore
277277
data, labels = batch
278278
if hasattr(self.estimator, "validate_with_candidates") and self.train_features is not None:
279279
preds = self.estimator.validate_with_candidates(
280-
*data, candidate_x=self.train_features, candidate_y=self.train_targets
280+
*data,
281+
candidate_x=self.train_features, candidate_y=self.train_targets
281282
)
282283
else:
283284
preds = self(*data)

0 commit comments

Comments
 (0)