@@ -125,6 +125,8 @@ def __init__(self, num_feature_info, cat_feature_info, emb_feature_info, config)
125125 if self .layer_norm_after_embedding :
126126 self .embedding_norm = nn .LayerNorm (self .d_model )
127127
128+ self .feature_info = (num_feature_info , cat_feature_info , emb_feature_info )
129+
128130 def forward (self , num_features , cat_features , emb_features ):
129131 """Defines the forward pass of the model.
130132
@@ -171,6 +173,8 @@ def forward(self, num_features, cat_features, emb_features):
171173
172174 # Process numerical embeddings based on embedding_type
173175 if self .embedding_type == "plr" :
176+ # check pre-processing type compatibility with plr
177+ self .check_plr_embedding_compatibility (self .feature_info )
174178 # For PLR, pass all numerical features together
175179 if num_features is not None :
176180 num_features = torch .stack (num_features , dim = 1 ).squeeze (
@@ -226,6 +230,21 @@ def forward(self, num_features, cat_features, emb_features):
226230 x = self .embedding_dropout (x )
227231
228232 return x
233+
234+ def check_plr_embedding_compatibility (self , feature_info :tuple ):
235+ # List of incompatible preprocessing terms for PLR embedding
236+ incompatible_terms = ['ple' , 'one-hot' , 'polynomial' , 'splines' , 'sigmoid' , 'rbf' ]
237+
238+ # Iterate through each dictionary in the tuple (data)
239+ for sub_dict in feature_info :
240+ # Iterate through each feature in the current dictionary
241+ for feature , properties in sub_dict .items ():
242+ preprocessing = properties .get ('preprocessing' , '' )
243+
244+ # Check for incompatible terms in the preprocessing string
245+ for term in incompatible_terms :
246+ if term in preprocessing :
247+ raise ValueError (f"PLR embedding type doesn't work with the '{ term } ' pre-processing method.\n " )
229248
230249
231250class OneHotEncoding (nn .Module ):
0 commit comments