Skip to content

Commit c3e9c90

Browse files
committed
add JohnsonSU and individual preprocessing
1 parent 4a76db9 commit c3e9c90

2 files changed

Lines changed: 200 additions & 47 deletions

File tree

mambular/preprocessing/preprocessor.py

Lines changed: 97 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ class Preprocessor:
4040
4141
Parameters
4242
----------
43+
feature_preprocessing: dict or None
Dictionary mapping individual column names to preprocessing techniques,
overriding the global `numerical_preprocessing` / `categorical_preprocessing`
defaults for those columns. Example:
45+
{
46+
"num_feature1": "minmax",
47+
"num_feature2": "ple",
48+
"cat_feature1": "one-hot",
49+
"cat_feature2": "int"
50+
}
4351
n_bins : int, default=50
4452
The number of bins to use for numerical feature binning. This parameter is relevant
4553
only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'.
@@ -94,6 +102,7 @@ class Preprocessor:
94102

95103
def __init__(
96104
self,
105+
feature_preprocessing=None,
97106
n_bins=64,
98107
numerical_preprocessing="ple",
99108
categorical_preprocessing="int",
@@ -111,10 +120,14 @@ def __init__(
111120
):
112121
self.n_bins = n_bins
113122
self.numerical_preprocessing = (
114-
numerical_preprocessing.lower() if numerical_preprocessing is not None else "none"
123+
numerical_preprocessing.lower()
124+
if numerical_preprocessing is not None
125+
else "none"
115126
)
116127
self.categorical_preprocessing = (
117-
categorical_preprocessing.lower() if categorical_preprocessing is not None else "none"
128+
categorical_preprocessing.lower()
129+
if categorical_preprocessing is not None
130+
else "none"
118131
)
119132
if self.numerical_preprocessing not in [
120133
"ple",
@@ -149,6 +162,7 @@ def __init__(
149162
)
150163

151164
self.use_decision_tree_bins = use_decision_tree_bins
165+
self.feature_preprocessing = feature_preprocessing or {}
152166
self.column_transformer = None
153167
self.fitted = False
154168
self.binning_strategy = binning_strategy
@@ -237,13 +251,19 @@ def _detect_column_types(self, X):
237251
numerical_features.append(col)
238252
else:
239253
if isinstance(self.cat_cutoff, float):
240-
cutoff_condition = (num_unique_values / total_samples) < self.cat_cutoff
254+
cutoff_condition = (
255+
num_unique_values / total_samples
256+
) < self.cat_cutoff
241257
elif isinstance(self.cat_cutoff, int):
242258
cutoff_condition = num_unique_values < self.cat_cutoff
243259
else:
244-
raise ValueError("cat_cutoff should be either a float or an integer.")
260+
raise ValueError(
261+
"cat_cutoff should be either a float or an integer."
262+
)
245263

246-
if X[col].dtype.kind not in "iufc" or (X[col].dtype.kind == "i" and cutoff_condition):
264+
if X[col].dtype.kind not in "iufc" or (
265+
X[col].dtype.kind == "i" and cutoff_condition
266+
):
247267
categorical_features.append(col)
248268
else:
249269
numerical_features.append(col)
@@ -274,6 +294,10 @@ def fit(self, X, y=None):
274294

275295
if numerical_features:
276296
for feature in numerical_features:
297+
feature_preprocessing = self.feature_preprocessing.get(
298+
feature, self.numerical_preprocessing
299+
)
300+
277301
# extended the annotation list if new transformer is added, either from sklearn or custom
278302
numeric_transformer_steps: list[
279303
tuple[
@@ -296,7 +320,7 @@ def fit(self, X, y=None):
296320
| SigmoidExpansion,
297321
]
298322
] = [("imputer", SimpleImputer(strategy="mean"))]
299-
if self.numerical_preprocessing in ["binning", "one-hot"]:
323+
if feature_preprocessing in ["binning", "one-hot"]:
300324
bins = (
301325
self._get_decision_tree_bins(X[[feature]], y, [feature])
302326
if self.use_decision_tree_bins
@@ -308,7 +332,11 @@ def fit(self, X, y=None):
308332
(
309333
"discretizer",
310334
KBinsDiscretizer(
311-
n_bins=(bins if isinstance(bins, int) else len(bins) - 1),
335+
n_bins=(
336+
bins
337+
if isinstance(bins, int)
338+
else len(bins) - 1
339+
),
312340
encode="ordinal",
313341
strategy=self.binning_strategy, # type: ignore
314342
subsample=200_000 if len(X) > 200_000 else None,
@@ -326,47 +354,55 @@ def fit(self, X, y=None):
326354
]
327355
)
328356

329-
if self.numerical_preprocessing == "one-hot":
357+
if feature_preprocessing == "one-hot":
330358
numeric_transformer_steps.extend(
331359
[
332360
("onehot_from_ordinal", OneHotFromOrdinal()),
333361
]
334362
)
335363

336-
elif self.numerical_preprocessing == "standardization":
364+
elif feature_preprocessing == "standardization":
337365
numeric_transformer_steps.append(("scaler", StandardScaler()))
338366

339-
elif self.numerical_preprocessing == "minmax":
340-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
367+
elif feature_preprocessing == "minmax":
368+
numeric_transformer_steps.append(
369+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
370+
)
341371

342-
elif self.numerical_preprocessing == "quantile":
372+
elif feature_preprocessing == "quantile":
343373
numeric_transformer_steps.append(
344374
(
345375
"quantile",
346-
QuantileTransformer(n_quantiles=self.n_bins, random_state=101),
376+
QuantileTransformer(
377+
n_quantiles=self.n_bins, random_state=101
378+
),
347379
)
348380
)
349381

350-
elif self.numerical_preprocessing == "polynomial":
382+
elif feature_preprocessing == "polynomial":
351383
if self.scaling_strategy == "standardization":
352384
numeric_transformer_steps.append(("scaler", StandardScaler()))
353385
elif self.scaling_strategy == "minmax":
354-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
386+
numeric_transformer_steps.append(
387+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
388+
)
355389
numeric_transformer_steps.append(
356390
(
357391
"polynomial",
358392
PolynomialFeatures(self.degree, include_bias=False),
359393
)
360394
)
361395

362-
elif self.numerical_preprocessing == "robust":
396+
elif feature_preprocessing == "robust":
363397
numeric_transformer_steps.append(("robust", RobustScaler()))
364398

365-
elif self.numerical_preprocessing == "splines":
399+
elif feature_preprocessing == "splines":
366400
if self.scaling_strategy == "standardization":
367401
numeric_transformer_steps.append(("scaler", StandardScaler()))
368402
elif self.scaling_strategy == "minmax":
369-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
403+
numeric_transformer_steps.append(
404+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
405+
)
370406
numeric_transformer_steps.append(
371407
(
372408
"splines",
@@ -381,11 +417,13 @@ def fit(self, X, y=None):
381417
),
382418
)
383419

384-
elif self.numerical_preprocessing == "rbf":
420+
elif feature_preprocessing == "rbf":
385421
if self.scaling_strategy == "standardization":
386422
numeric_transformer_steps.append(("scaler", StandardScaler()))
387423
elif self.scaling_strategy == "minmax":
388-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
424+
numeric_transformer_steps.append(
425+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
426+
)
389427
numeric_transformer_steps.append(
390428
(
391429
"rbf",
@@ -398,11 +436,13 @@ def fit(self, X, y=None):
398436
)
399437
)
400438

401-
elif self.numerical_preprocessing == "sigmoid":
439+
elif feature_preprocessing == "sigmoid":
402440
if self.scaling_strategy == "standardization":
403441
numeric_transformer_steps.append(("scaler", StandardScaler()))
404442
elif self.scaling_strategy == "minmax":
405-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
443+
numeric_transformer_steps.append(
444+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
445+
)
406446
numeric_transformer_steps.append(
407447
(
408448
"sigmoid",
@@ -415,27 +455,31 @@ def fit(self, X, y=None):
415455
)
416456
)
417457

418-
elif self.numerical_preprocessing == "ple":
419-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
420-
numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task)))
458+
elif feature_preprocessing == "ple":
459+
numeric_transformer_steps.append(
460+
("minmax", MinMaxScaler(feature_range=(-1, 1)))
461+
)
462+
numeric_transformer_steps.append(
463+
("ple", PLE(n_bins=self.n_bins, task=self.task))
464+
)
421465

422-
elif self.numerical_preprocessing == "box-cox":
466+
elif feature_preprocessing == "box-cox":
423467
numeric_transformer_steps.append(
424468
(
425469
"box-cox",
426470
PowerTransformer(method="box-cox", standardize=True),
427471
)
428472
)
429473

430-
elif self.numerical_preprocessing == "yeo-johnson":
474+
elif feature_preprocessing == "yeo-johnson":
431475
numeric_transformer_steps.append(
432476
(
433477
"yeo-johnson",
434478
PowerTransformer(method="yeo-johnson", standardize=True),
435479
)
436480
)
437481

438-
elif self.numerical_preprocessing == "none":
482+
elif feature_preprocessing == "none":
439483
numeric_transformer_steps.append(
440484
(
441485
"none",
@@ -449,15 +493,18 @@ def fit(self, X, y=None):
449493

450494
if categorical_features:
451495
for feature in categorical_features:
452-
if self.categorical_preprocessing == "int":
496+
feature_preprocessing = self.feature_preprocessing.get(
497+
feature, self.categorical_preprocessing
498+
)
499+
if feature_preprocessing == "int":
453500
# Use ContinuousOrdinalEncoder for "int"
454501
categorical_transformer = Pipeline(
455502
[
456503
("imputer", SimpleImputer(strategy="most_frequent")),
457504
("continuous_ordinal", ContinuousOrdinalEncoder()),
458505
]
459506
)
460-
elif self.categorical_preprocessing == "one-hot":
507+
elif feature_preprocessing == "one-hot":
461508
# Use OneHotEncoder for "one-hot"
462509
categorical_transformer = Pipeline(
463510
[
@@ -467,28 +514,34 @@ def fit(self, X, y=None):
467514
]
468515
)
469516

470-
elif self.categorical_preprocessing == "none":
517+
elif feature_preprocessing == "none":
471518
# Use NoTransformer to pass the feature through unchanged ("none")
472519
categorical_transformer = Pipeline(
473520
[
474521
("imputer", SimpleImputer(strategy="most_frequent")),
475522
("none", NoTransformer()),
476523
]
477524
)
478-
elif self.categorical_preprocessing == "pretrained":
525+
elif feature_preprocessing == "pretrained":
479526
categorical_transformer = Pipeline(
480527
[
481528
("imputer", SimpleImputer(strategy="most_frequent")),
482529
("pretrained", LanguageEmbeddingTransformer()),
483530
]
484531
)
485532
else:
486-
raise ValueError(f"Unknown categorical_preprocessing type: {self.categorical_preprocessing}")
533+
raise ValueError(
534+
f"Unknown categorical_preprocessing type: {feature_preprocessing}"
535+
)
487536

488537
# Append the transformer for the current categorical feature
489-
transformers.append((f"cat_{feature}", categorical_transformer, [feature]))
538+
transformers.append(
539+
(f"cat_{feature}", categorical_transformer, [feature])
540+
)
490541

491-
self.column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough")
542+
self.column_transformer = ColumnTransformer(
543+
transformers=transformers, remainder="passthrough"
544+
)
492545
self.column_transformer.fit(X, y)
493546

494547
self.fitted = True
@@ -514,13 +567,17 @@ def _get_decision_tree_bins(self, X, y, numerical_features):
514567
bins = []
515568
for feature in numerical_features:
516569
tree_model = (
517-
DecisionTreeClassifier(max_depth=3) if y.dtype.kind in "bi" else DecisionTreeRegressor(max_depth=3)
570+
DecisionTreeClassifier(max_depth=3)
571+
if y.dtype.kind in "bi"
572+
else DecisionTreeRegressor(max_depth=3)
518573
)
519574
tree_model.fit(X[[feature]], y)
520575
thresholds = tree_model.tree_.threshold[tree_model.tree_.feature != -2] # type: ignore
521576
bin_edges = np.sort(np.unique(thresholds))
522577

523-
bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])))
578+
bins.append(
579+
np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()]))
580+
)
524581
return bins
525582

526583
def transform(self, X):
@@ -676,7 +733,9 @@ def get_feature_info(self, verbose=True):
676733
"categories": None, # Numerical features don't have categories
677734
}
678735
if verbose:
679-
print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}")
736+
print(
737+
f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}"
738+
)
680739

681740
# Categorical features
682741
elif "continuous_ordinal" in steps:

0 commit comments

Comments
 (0)