Skip to content

Commit 10fd848

Browse files
committed
scaling strategy included for ple, splines etc.
1 parent 7169e96 commit 10fd848

4 files changed

Lines changed: 25 additions & 5 deletions

File tree

mambular/models/sklearn_base_classifier.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(self, model, config, **kwargs):
2828
"cat_cutoff",
2929
"treat_all_integers_as_numerical",
3030
"degree",
31+
"scaling_strategy",
3132
"n_knots",
3233
"use_decision_tree_knots",
3334
"knots_strategy",

mambular/models/sklearn_base_lss.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, model, config, **kwargs):
4949
"cat_cutoff",
5050
"treat_all_integers_as_numerical",
5151
"degree",
52+
"scaling_strategy",
5253
"n_knots",
5354
"use_decision_tree_knots",
5455
"knots_strategy",

mambular/models/sklearn_base_regressor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(self, model, config, **kwargs):
2626
"cat_cutoff",
2727
"treat_all_integers_as_numerical",
2828
"degree",
29+
"scaling_strategy",
2930
"n_knots",
3031
"use_decision_tree_knots",
3132
"knots_strategy",

mambular/preprocessing/preprocessor.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class Preprocessor:
6060
treat_all_integers_as_numerical : bool, default=False
6161
If True, all integer columns will be treated as numerical, regardless
6262
of their unique value count or proportion.
63+
scaling_strategy : str, default="minmax"
64+
The scaling strategy applied to numerical features before the polynomial, splines, RBF or sigmoid transformations.
65+
Options are 'standardization' (StandardScaler), 'minmax' (MinMaxScaler to the range [-1, 1]) or 'none'; any other value also results in no scaling.
6366
degree : int, default=3
6467
The degree of the polynomial features to be used in preprocessing. It also affects the degree of
6568
splines if splines are used.
@@ -93,6 +96,7 @@ def __init__(
9396
cat_cutoff=0.03,
9497
treat_all_integers_as_numerical=False,
9598
degree=3,
99+
scaling_strategy="minmax",
96100
n_knots=64,
97101
use_decision_tree_knots=True,
98102
knots_strategy="uniform",
@@ -138,6 +142,7 @@ def __init__(
138142
self.cat_cutoff = cat_cutoff
139143
self.treat_all_integers_as_numerical = treat_all_integers_as_numerical
140144
self.degree = degree
145+
self.scaling_strategy = scaling_strategy
141146
self.n_knots = n_knots
142147
self.use_decision_tree_knots = use_decision_tree_knots
143148
self.knots_strategy = knots_strategy
@@ -166,6 +171,7 @@ def get_params(self, deep=True):
166171
"cat_cutoff": self.cat_cutoff,
167172
"treat_all_integers_as_numerical": self.treat_all_integers_as_numerical,
168173
"degree": self.degree,
174+
"scaling_strategy": self.scaling_strategy,
169175
"n_knots": self.n_knots,
170176
"use_decision_tree_knots": self.use_decision_tree_knots,
171177
"knots_strategy": self.knots_strategy,
@@ -330,7 +336,10 @@ def fit(self, X, y=None):
330336
)
331337

332338
elif self.numerical_preprocessing == "polynomial":
333-
numeric_transformer_steps.append(("scaler", StandardScaler()))
339+
if self.scaling_strategy == "standardization":
340+
numeric_transformer_steps.append(("scaler", StandardScaler()))
341+
elif self.scaling_strategy == "minmax":
342+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
334343
numeric_transformer_steps.append(
335344
(
336345
"polynomial",
@@ -342,8 +351,10 @@ def fit(self, X, y=None):
342351
numeric_transformer_steps.append(("robust", RobustScaler()))
343352

344353
elif self.numerical_preprocessing == "splines":
345-
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
346-
# numeric_transformer_steps.append(("scaler", StandardScaler()))
354+
if self.scaling_strategy == "standardization":
355+
numeric_transformer_steps.append(("scaler", StandardScaler()))
356+
elif self.scaling_strategy == "minmax":
357+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
347358
numeric_transformer_steps.append(
348359
(
349360
"splines",
@@ -359,7 +370,10 @@ def fit(self, X, y=None):
359370
)
360371

361372
elif self.numerical_preprocessing == "rbf":
362-
numeric_transformer_steps.append(("scaler", StandardScaler()))
373+
if self.scaling_strategy == "standardization":
374+
numeric_transformer_steps.append(("scaler", StandardScaler()))
375+
elif self.scaling_strategy == "minmax":
376+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
363377
numeric_transformer_steps.append(
364378
(
365379
"rbf",
@@ -373,7 +387,10 @@ def fit(self, X, y=None):
373387
)
374388

375389
elif self.numerical_preprocessing == "sigmoid":
376-
numeric_transformer_steps.append(("scaler", StandardScaler()))
390+
if self.scaling_strategy == "standardization":
391+
numeric_transformer_steps.append(("scaler", StandardScaler()))
392+
elif self.scaling_strategy == "minmax":
393+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
377394
numeric_transformer_steps.append(
378395
(
379396
"sigmoid",

0 commit comments

Comments
 (0)