Commit 2e87e87

Merge pull request #216 from basf/johnson_su

Johnson su

2 parents: c379a7a + d6380fd

3 files changed: 148 additions & 29 deletions

mambular/arch_utils/layer_utils/embedding_layer.py

Lines changed: 5 additions & 1 deletion
@@ -156,8 +156,10 @@ def forward(self, num_features, cat_features, emb_features):
         # Process categorical embeddings
         if self.cat_embeddings and cat_features is not None:
             cat_embeddings = [
-                emb(cat_features[i]) for i, emb in enumerate(self.cat_embeddings)
+                emb(cat_features[i]) if emb(cat_features[i]).ndim == 3 else emb(cat_features[i]).unsqueeze(1)
+                for i, emb in enumerate(self.cat_embeddings)
             ]
+
             cat_embeddings = torch.stack(cat_embeddings, dim=1)
             cat_embeddings = torch.squeeze(cat_embeddings, dim=2)
             if self.layer_norm_after_embedding:
@@ -189,6 +191,7 @@ def forward(self, num_features, cat_features, emb_features):
                 ]
                 emb_embeddings = torch.stack(emb_embeddings, dim=1)
             else:
+
                 emb_embeddings = torch.stack(emb_features, dim=1)
             if self.layer_norm_after_embedding:
                 emb_embeddings = self.embedding_norm(emb_embeddings)
@@ -199,6 +202,7 @@ def forward(self, num_features, cat_features, emb_features):
 
         if embeddings:
             x = torch.cat(embeddings, dim=1) if len(embeddings) > 1 else embeddings[0]
+
         else:
             raise ValueError("No features provided to the model.")
 
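The new conditional matters because `nn.Embedding` returns a 2-D tensor when a categorical feature arrives as a flat `(batch,)` index vector but a 3-D one for `(batch, 1)` input, and `torch.stack` needs every element to share a shape. A minimal sketch of the shape normalization (toy sizes, not mambular's actual tensors):

import torch
import torch.nn as nn

batch, d = 4, 8
emb = nn.Embedding(10, d)

out_3d = emb(torch.randint(0, 10, (batch, 1)))  # (4, 1, 8): already has a token dim
out_2d = emb(torch.randint(0, 10, (batch,)))    # (4, 8): token dim missing

# Mirror of the merged expression: pad 2-D outputs to (batch, 1, d) so both
# shapes stack cleanly, then squeeze the singleton dim away as in forward().
outs = [o if o.ndim == 3 else o.unsqueeze(1) for o in (out_3d, out_2d)]
stacked = torch.squeeze(torch.stack(outs, dim=1), dim=2)
print(stacked.shape)  # torch.Size([4, 2, 8])
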
mambular/preprocessing/preprocessor.py

Lines changed: 40 additions & 19 deletions
@@ -40,6 +40,14 @@ class Preprocessor:
 
     Parameters
     ----------
+    feature_preprocessing: dict or None
+        Dictionary mapping column names to preprocessing techniques. Example:
+        {
+            "num_feature1": "minmax",
+            "num_feature2": "ple",
+            "cat_feature1": "one-hot",
+            "cat_feature2": "int"
+        }
     n_bins : int, default=50
         The number of bins to use for numerical feature binning. This parameter is relevant
         only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'.
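
A hedged usage sketch of the new argument (column names hypothetical; import path assumed from the file layout). Per-column entries override the global defaults, which still apply to every unlisted column:

from mambular.preprocessing import Preprocessor

pre = Preprocessor(
    feature_preprocessing={
        "age": "minmax",    # overrides numerical_preprocessing for this column
        "city": "one-hot",  # overrides categorical_preprocessing for this column
    },
    numerical_preprocessing="ple",    # default for the remaining numeric columns
    categorical_preprocessing="int",  # default for the remaining categorical columns
)
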
@@ -94,6 +102,7 @@ class Preprocessor:
 
     def __init__(
         self,
+        feature_preprocessing=None,
         n_bins=64,
         numerical_preprocessing="ple",
         categorical_preprocessing="int",
@@ -153,6 +162,7 @@ def __init__(
         )
 
         self.use_decision_tree_bins = use_decision_tree_bins
+        self.feature_preprocessing = feature_preprocessing or {}
         self.column_transformer = None
         self.fitted = False
         self.binning_strategy = binning_strategy
@@ -300,6 +310,10 @@ def fit(self, X, y=None, embeddings=None):
 
         if numerical_features:
             for feature in numerical_features:
+                feature_preprocessing = self.feature_preprocessing.get(
+                    feature, self.numerical_preprocessing
+                )
+
                 # extended the annotation list if new transformer is added, either from sklearn or custom
                 numeric_transformer_steps: list[
                     tuple[
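
The override mechanism is plain `dict.get` with the global setting as the fallback, evaluated once per column; a toy illustration (values invented):

feature_overrides = {"age": "minmax"}  # stands in for self.feature_preprocessing
numerical_default = "ple"              # stands in for self.numerical_preprocessing

print(feature_overrides.get("age", numerical_default))     # minmax (explicit override)
print(feature_overrides.get("income", numerical_default))  # ple (falls back to default)
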
@@ -322,7 +336,7 @@ def fit(self, X, y=None, embeddings=None):
                         | SigmoidExpansion,
                     ]
                 ] = [("imputer", SimpleImputer(strategy="mean"))]
-                if self.numerical_preprocessing in ["binning", "one-hot"]:
+                if feature_preprocessing in ["binning", "one-hot"]:
                     bins = (
                         self._get_decision_tree_bins(X[[feature]], y, [feature])
                         if self.use_decision_tree_bins
@@ -356,22 +370,22 @@ def fit(self, X, y=None, embeddings=None):
                         ]
                     )
 
-                if self.numerical_preprocessing == "one-hot":
+                if feature_preprocessing == "one-hot":
                     numeric_transformer_steps.extend(
                         [
                             ("onehot_from_ordinal", OneHotFromOrdinal()),
                         ]
                     )
 
-                elif self.numerical_preprocessing == "standardization":
+                elif feature_preprocessing == "standardization":
                     numeric_transformer_steps.append(("scaler", StandardScaler()))
 
-                elif self.numerical_preprocessing == "minmax":
+                elif feature_preprocessing == "minmax":
                     numeric_transformer_steps.append(
                         ("minmax", MinMaxScaler(feature_range=(-1, 1)))
                     )
 
-                elif self.numerical_preprocessing == "quantile":
+                elif feature_preprocessing == "quantile":
                     numeric_transformer_steps.append(
                         (
                             "quantile",
@@ -381,7 +395,7 @@ def fit(self, X, y=None, embeddings=None):
                         )
                     )
 
-                elif self.numerical_preprocessing == "polynomial":
+                elif feature_preprocessing == "polynomial":
                     if self.scaling_strategy == "standardization":
                         numeric_transformer_steps.append(("scaler", StandardScaler()))
                     elif self.scaling_strategy == "minmax":
@@ -395,10 +409,10 @@ def fit(self, X, y=None, embeddings=None):
                         )
                     )
 
-                elif self.numerical_preprocessing == "robust":
+                elif feature_preprocessing == "robust":
                     numeric_transformer_steps.append(("robust", RobustScaler()))
 
-                elif self.numerical_preprocessing == "splines":
+                elif feature_preprocessing == "splines":
                     if self.scaling_strategy == "standardization":
                         numeric_transformer_steps.append(("scaler", StandardScaler()))
                     elif self.scaling_strategy == "minmax":
@@ -419,7 +433,7 @@ def fit(self, X, y=None, embeddings=None):
                         ),
                     )
 
-                elif self.numerical_preprocessing == "rbf":
+                elif feature_preprocessing == "rbf":
                     if self.scaling_strategy == "standardization":
                         numeric_transformer_steps.append(("scaler", StandardScaler()))
                     elif self.scaling_strategy == "minmax":
@@ -438,7 +452,7 @@ def fit(self, X, y=None, embeddings=None):
                         )
                     )
 
-                elif self.numerical_preprocessing == "sigmoid":
+                elif feature_preprocessing == "sigmoid":
                     if self.scaling_strategy == "standardization":
                         numeric_transformer_steps.append(("scaler", StandardScaler()))
                     elif self.scaling_strategy == "minmax":
@@ -457,15 +471,19 @@ def fit(self, X, y=None, embeddings=None):
                         )
                     )
 
-                elif self.numerical_preprocessing == "ple":
+
+                elif feature_preprocessing == "ple":
                     numeric_transformer_steps.append(
                         ("minmax", MinMaxScaler(feature_range=(-1, 1)))
                     )
                     numeric_transformer_steps.append(
                         ("ple", PLE(n_bins=self.n_bins, task=self.task))
                     )
 
-                elif self.numerical_preprocessing == "box-cox":
+                elif feature_preprocessing == "box-cox":
+                    numeric_transformer_steps.append(
+                        ("minmax", MinMaxScaler(feature_range=(1e-03, 1)))
+                    )
                     numeric_transformer_steps.append(
                         ("check_positive", MinMaxScaler(feature_range=(1e-3, 1)))
                     )
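
For context on the box-cox branch: Box-Cox is only defined for strictly positive inputs, so scikit-learn's PowerTransformer(method="box-cox") raises on zeros or negatives, and rescaling into (1e-3, 1] first makes the step safe (yeo-johnson handles non-positive values, which is why that branch has no such guard). A standalone sketch with toy data, not the library's actual pipeline:

import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, PowerTransformer

X = np.array([[-3.0], [0.0], [2.5], [7.0]])  # contains non-positive values

pipe = Pipeline([
    # Shift everything into (0, 1] so Box-Cox's positivity requirement holds.
    ("check_positive", MinMaxScaler(feature_range=(1e-3, 1))),
    ("box-cox", PowerTransformer(method="box-cox", standardize=True)),
])
print(pipe.fit_transform(X).ravel())  # succeeds; box-cox alone would raise here
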
@@ -476,15 +494,15 @@ def fit(self, X, y=None, embeddings=None):
                         )
                     )
 
-                elif self.numerical_preprocessing == "yeo-johnson":
+                elif feature_preprocessing == "yeo-johnson":
                     numeric_transformer_steps.append(
                         (
                             "yeo-johnson",
                             PowerTransformer(method="yeo-johnson", standardize=True),
                         )
                     )
 
-                elif self.numerical_preprocessing == "none":
+                elif feature_preprocessing == "none":
                     numeric_transformer_steps.append(
                         (
                             "none",
@@ -498,15 +516,18 @@ def fit(self, X, y=None, embeddings=None):
 
         if categorical_features:
             for feature in categorical_features:
-                if self.categorical_preprocessing == "int":
+                feature_preprocessing = self.feature_preprocessing.get(
+                    feature, self.categorical_preprocessing
+                )
+                if feature_preprocessing == "int":
                     # Use ContinuousOrdinalEncoder for "int"
                     categorical_transformer = Pipeline(
                         [
                             ("imputer", SimpleImputer(strategy="most_frequent")),
                             ("continuous_ordinal", ContinuousOrdinalEncoder()),
                         ]
                     )
-                elif self.categorical_preprocessing == "one-hot":
+                elif feature_preprocessing == "one-hot":
                     # Use OneHotEncoder for "one-hot"
                     categorical_transformer = Pipeline(
                         [
@@ -516,15 +537,15 @@ def fit(self, X, y=None, embeddings=None):
                         ]
                     )
 
-                elif self.categorical_preprocessing == "none":
+                elif feature_preprocessing == "none":
                     # Use OneHotEncoder for "one-hot"
                     categorical_transformer = Pipeline(
                         [
                             ("imputer", SimpleImputer(strategy="most_frequent")),
                             ("none", NoTransformer()),
                         ]
                     )
-                elif self.categorical_preprocessing == "pretrained":
+                elif feature_preprocessing == "pretrained":
                     categorical_transformer = Pipeline(
                         [
                             ("imputer", SimpleImputer(strategy="most_frequent")),
@@ -533,7 +554,7 @@ def fit(self, X, y=None, embeddings=None):
                     )
                 else:
                     raise ValueError(
-                        f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}"
+                        f"Unknown categorical_preprocessing type: {feature_preprocessing}"
                     )
 
                 # Append the transformer for the current categorical feature
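
Putting the two dispatch loops together, a hedged end-to-end sketch (hypothetical column names; assumes the Preprocessor API shown in this diff):

import pandas as pd
from mambular.preprocessing import Preprocessor

X = pd.DataFrame({
    "age": [23, 35, 41, 29],
    "income": [40_000, 52_000, 75_000, 61_000],
    "city": ["berlin", "munich", "berlin", "hamburg"],
})

pre = Preprocessor(
    feature_preprocessing={"age": "minmax", "city": "one-hot"},
    numerical_preprocessing="ple",    # fallback, applies to "income"
    categorical_preprocessing="int",  # fallback for unlisted categorical columns
)
pre.fit(X, y=[0, 1, 1, 0])  # y is optional; target-aware steps can make use of it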
