Skip to content

Commit 7169e96

Browse files
committed
rbf and sigmoid expansion
1 parent 8e8579c commit 7169e96

1 file changed

Lines changed: 42 additions & 6 deletions

File tree

mambular/preprocessing/preprocessor.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
1919

20-
from .basis_expansion import SplineExpansion
20+
from .basis_expansion import RBFExpansion, SigmoidExpansion, SplineExpansion
2121
from .ple_encoding import PLE
2222
from .prepro_utils import ContinuousOrdinalEncoder, CustomBinner, NoTransformer, OneHotFromOrdinal, ToFloatTransformer
2323

@@ -38,7 +38,7 @@ class Preprocessor:
3838
only if `numerical_preprocessing` is set to 'binning', 'ple' or 'one-hot'.
3939
numerical_preprocessing : str, default="ple"
4040
The preprocessing strategy for numerical features. Valid options are
41-
'ple', 'binning', 'one-hot', 'standardization', 'min-max', 'quantile', 'polynomial', 'robust',
41+
'ple', 'binning', 'one-hot', 'standardization', 'min-max', 'quantile', 'polynomial', 'robust', 'rbf', 'sigmoid'.
4242
'splines', 'box-cox', 'yeo-johnson' and None
4343
categorical_preprocessing : str, default="int"
4444
The preprocessing strategy for categorical features. Valid options are
@@ -67,7 +67,7 @@ class Preprocessor:
6767
The number of knots to be used in spline transformations.
6868
use_decision_tree_knots : bool, default=True
6969
If True, uses decision tree regression to determine optimal knot positions for splines.
70-
knots_strategy : str, default="uniform"
70+
knots_strategy : str, default="quantile"
7171
Defines the strategy for determining knot positions in spline transformations
7272
if `use_decision_tree_knots` is False. Options include 'uniform', 'quantile'.
7373
spline_implementation : str, default="sklearn"
@@ -117,11 +117,14 @@ def __init__(
117117
"splines",
118118
"box-cox",
119119
"yeo-johnson",
120+
"rbf",
121+
"sigmoid",
120122
"none",
121123
]:
122124
raise ValueError(
123125
"Invalid numerical_preprocessing value. Supported values are 'ple', 'binning', 'box-cox', \
124-
'one-hot', 'standardization', 'quantile', 'polynomial', 'splines', 'minmax' , 'robust' or 'None'."
126+
'one-hot', 'standardization', 'quantile', 'polynomial', 'splines', 'minmax' , 'robust',\
127+
'rbf', 'sigmoid', or 'None'."
125128
)
126129

127130
if self.categorical_preprocessing not in ["int", "one-hot", "none"]:
@@ -247,6 +250,8 @@ def fit(self, X, y=None):
247250
X = pd.DataFrame(X)
248251

249252
numerical_features, categorical_features = self._detect_column_types(X)
253+
print("Numerical features:", numerical_features)
254+
print("Categorical features:", categorical_features)
250255
transformers = []
251256

252257
if numerical_features:
@@ -268,7 +273,9 @@ def fit(self, X, y=None):
268273
| PLE
269274
| PowerTransformer
270275
| NoTransformer
271-
| SplineExpansion,
276+
| SplineExpansion
277+
| RBFExpansion
278+
| SigmoidExpansion,
272279
]
273280
] = [("imputer", SimpleImputer(strategy="mean"))]
274281
if self.numerical_preprocessing in ["binning", "one-hot"]:
@@ -335,7 +342,8 @@ def fit(self, X, y=None):
335342
numeric_transformer_steps.append(("robust", RobustScaler()))
336343

337344
elif self.numerical_preprocessing == "splines":
338-
numeric_transformer_steps.append(("scaler", StandardScaler()))
345+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
346+
# numeric_transformer_steps.append(("scaler", StandardScaler()))
339347
numeric_transformer_steps.append(
340348
(
341349
"splines",
@@ -350,6 +358,34 @@ def fit(self, X, y=None):
350358
),
351359
)
352360

361+
elif self.numerical_preprocessing == "rbf":
362+
numeric_transformer_steps.append(("scaler", StandardScaler()))
363+
numeric_transformer_steps.append(
364+
(
365+
"rbf",
366+
RBFExpansion(
367+
n_centers=self.n_knots,
368+
use_decision_tree=self.use_decision_tree_knots,
369+
strategy=self.knots_strategy,
370+
task=self.task,
371+
),
372+
)
373+
)
374+
375+
elif self.numerical_preprocessing == "sigmoid":
376+
numeric_transformer_steps.append(("scaler", StandardScaler()))
377+
numeric_transformer_steps.append(
378+
(
379+
"sigmoid",
380+
SigmoidExpansion(
381+
n_centers=self.n_knots,
382+
use_decision_tree=self.use_decision_tree_knots,
383+
strategy=self.knots_strategy,
384+
task=self.task,
385+
),
386+
)
387+
)
388+
353389
elif self.numerical_preprocessing == "ple":
354390
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
355391
numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task)))

0 commit comments

Comments
 (0)