Skip to content

Commit 3a769c1

Browse files
committed
formatting, refactor (used exception instead of assert)
1 parent 4bbf174 commit 3a769c1

1 file changed

Lines changed: 42 additions & 96 deletions

File tree

mambular/preprocessing/preprocessor.py

Lines changed: 42 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -120,14 +120,10 @@ def __init__(
120120
):
121121
self.n_bins = n_bins
122122
self.numerical_preprocessing = (
123-
numerical_preprocessing.lower()
124-
if numerical_preprocessing is not None
125-
else "none"
123+
numerical_preprocessing.lower() if numerical_preprocessing is not None else "none"
126124
)
127125
self.categorical_preprocessing = (
128-
categorical_preprocessing.lower()
129-
if categorical_preprocessing is not None
130-
else "none"
126+
categorical_preprocessing.lower() if categorical_preprocessing is not None else "none"
131127
)
132128
if self.numerical_preprocessing not in [
133129
"ple",
@@ -251,19 +247,13 @@ def _detect_column_types(self, X):
251247
numerical_features.append(col)
252248
else:
253249
if isinstance(self.cat_cutoff, float):
254-
cutoff_condition = (
255-
num_unique_values / total_samples
256-
) < self.cat_cutoff
250+
cutoff_condition = (num_unique_values / total_samples) < self.cat_cutoff
257251
elif isinstance(self.cat_cutoff, int):
258252
cutoff_condition = num_unique_values < self.cat_cutoff
259253
else:
260-
raise ValueError(
261-
"cat_cutoff should be either a float or an integer."
262-
)
254+
raise ValueError("cat_cutoff should be either a float or an integer.")
263255

264-
if X[col].dtype.kind not in "iufc" or (
265-
X[col].dtype.kind == "i" and cutoff_condition
266-
):
256+
if X[col].dtype.kind not in "iufc" or (X[col].dtype.kind == "i" and cutoff_condition):
267257
categorical_features.append(col)
268258
else:
269259
numerical_features.append(col)
@@ -276,11 +266,9 @@ def _fit_embeddings(self, embeddings):
276266
self.embedding_dimensions = {}
277267
if isinstance(embeddings, np.ndarray):
278268
self.embedding_dimensions["embeddings_1"] = embeddings.shape[1]
279-
elif isinstance(embeddings, list) and all(
280-
isinstance(e, np.ndarray) for e in embeddings
281-
):
269+
elif isinstance(embeddings, list) and all(isinstance(e, np.ndarray) for e in embeddings):
282270
for idx, e in enumerate(embeddings):
283-
self.embedding_dimensions[f"embedding_{idx+1}"] = e.shape[1]
271+
self.embedding_dimensions[f"embedding_{idx + 1}"] = e.shape[1]
284272
else:
285273
self.embeddings = False
286274

@@ -310,9 +298,7 @@ def fit(self, X, y=None, embeddings=None):
310298

311299
if numerical_features:
312300
for feature in numerical_features:
313-
feature_preprocessing = self.feature_preprocessing.get(
314-
feature, self.numerical_preprocessing
315-
)
301+
feature_preprocessing = self.feature_preprocessing.get(feature, self.numerical_preprocessing)
316302

317303
# extend the annotation list if a new transformer is added, either from sklearn or custom
318304
numeric_transformer_steps: list[
@@ -348,11 +334,7 @@ def fit(self, X, y=None, embeddings=None):
348334
(
349335
"discretizer",
350336
KBinsDiscretizer(
351-
n_bins=(
352-
bins
353-
if isinstance(bins, int)
354-
else len(bins) - 1
355-
),
337+
n_bins=(bins if isinstance(bins, int) else len(bins) - 1),
356338
encode="ordinal",
357339
strategy=self.binning_strategy, # type: ignore
358340
subsample=200_000 if len(X) > 200_000 else None,
@@ -381,27 +363,21 @@ def fit(self, X, y=None, embeddings=None):
381363
numeric_transformer_steps.append(("scaler", StandardScaler()))
382364

383365
elif feature_preprocessing == "minmax":
384-
numeric_transformer_steps.append(
385-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
386-
)
366+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
387367

388368
elif feature_preprocessing == "quantile":
389369
numeric_transformer_steps.append(
390370
(
391371
"quantile",
392-
QuantileTransformer(
393-
n_quantiles=self.n_bins, random_state=101
394-
),
372+
QuantileTransformer(n_quantiles=self.n_bins, random_state=101),
395373
)
396374
)
397375

398376
elif feature_preprocessing == "polynomial":
399377
if self.scaling_strategy == "standardization":
400378
numeric_transformer_steps.append(("scaler", StandardScaler()))
401379
elif self.scaling_strategy == "minmax":
402-
numeric_transformer_steps.append(
403-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
404-
)
380+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
405381
numeric_transformer_steps.append(
406382
(
407383
"polynomial",
@@ -416,9 +392,7 @@ def fit(self, X, y=None, embeddings=None):
416392
if self.scaling_strategy == "standardization":
417393
numeric_transformer_steps.append(("scaler", StandardScaler()))
418394
elif self.scaling_strategy == "minmax":
419-
numeric_transformer_steps.append(
420-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
421-
)
395+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
422396
numeric_transformer_steps.append(
423397
(
424398
"splines",
@@ -437,9 +411,7 @@ def fit(self, X, y=None, embeddings=None):
437411
if self.scaling_strategy == "standardization":
438412
numeric_transformer_steps.append(("scaler", StandardScaler()))
439413
elif self.scaling_strategy == "minmax":
440-
numeric_transformer_steps.append(
441-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
442-
)
414+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
443415
numeric_transformer_steps.append(
444416
(
445417
"rbf",
@@ -456,9 +428,7 @@ def fit(self, X, y=None, embeddings=None):
456428
if self.scaling_strategy == "standardization":
457429
numeric_transformer_steps.append(("scaler", StandardScaler()))
458430
elif self.scaling_strategy == "minmax":
459-
numeric_transformer_steps.append(
460-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
461-
)
431+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
462432
numeric_transformer_steps.append(
463433
(
464434
"sigmoid",
@@ -471,21 +441,16 @@ def fit(self, X, y=None, embeddings=None):
471441
)
472442
)
473443

474-
475444
elif feature_preprocessing == "ple":
476-
numeric_transformer_steps.append(
477-
("minmax", MinMaxScaler(feature_range=(-1, 1)))
478-
)
479-
numeric_transformer_steps.append(
480-
("ple", PLE(n_bins=self.n_bins, task=self.task))
481-
)
445+
numeric_transformer_steps.append(("minmax", MinMaxScaler(feature_range=(-1, 1))))
446+
numeric_transformer_steps.append(("ple", PLE(n_bins=self.n_bins, task=self.task)))
482447

483448
elif feature_preprocessing == "box-cox":
484449
numeric_transformer_steps.append(
485-
("minmax", MinMaxScaler(feature_range=(1e-03, 1)))
450+
("minmax", MinMaxScaler(feature_range=(1e-03, 1))) # type: ignore
486451
)
487452
numeric_transformer_steps.append(
488-
("check_positive", MinMaxScaler(feature_range=(1e-3, 1)))
453+
("check_positive", MinMaxScaler(feature_range=(1e-3, 1))) # type: ignore
489454
)
490455
numeric_transformer_steps.append(
491456
(
@@ -516,9 +481,7 @@ def fit(self, X, y=None, embeddings=None):
516481

517482
if categorical_features:
518483
for feature in categorical_features:
519-
feature_preprocessing = self.feature_preprocessing.get(
520-
feature, self.categorical_preprocessing
521-
)
484+
feature_preprocessing = self.feature_preprocessing.get(feature, self.categorical_preprocessing)
522485
if feature_preprocessing == "int":
523486
# Use ContinuousOrdinalEncoder for "int"
524487
categorical_transformer = Pipeline(
@@ -553,18 +516,12 @@ def fit(self, X, y=None, embeddings=None):
553516
]
554517
)
555518
else:
556-
raise ValueError(
557-
f"Unknown categorical_preprocessing type: {feature_preprocessing}"
558-
)
519+
raise ValueError(f"Unknown categorical_preprocessing type: {feature_preprocessing}")
559520

560521
# Append the transformer for the current categorical feature
561-
transformers.append(
562-
(f"cat_{feature}", categorical_transformer, [feature])
563-
)
522+
transformers.append((f"cat_{feature}", categorical_transformer, [feature]))
564523

565-
self.column_transformer = ColumnTransformer(
566-
transformers=transformers, remainder="passthrough"
567-
)
524+
self.column_transformer = ColumnTransformer(transformers=transformers, remainder="passthrough")
568525
self.column_transformer.fit(X, y)
569526

570527
self.fitted = True
@@ -590,17 +547,13 @@ def _get_decision_tree_bins(self, X, y, numerical_features):
590547
bins = []
591548
for feature in numerical_features:
592549
tree_model = (
593-
DecisionTreeClassifier(max_depth=3)
594-
if y.dtype.kind in "bi"
595-
else DecisionTreeRegressor(max_depth=3)
550+
DecisionTreeClassifier(max_depth=3) if y.dtype.kind in "bi" else DecisionTreeRegressor(max_depth=3)
596551
)
597552
tree_model.fit(X[[feature]], y)
598553
thresholds = tree_model.tree_.threshold[tree_model.tree_.feature != -2] # type: ignore
599554
bin_edges = np.sort(np.unique(thresholds))
600555

601-
bins.append(
602-
np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()]))
603-
)
556+
bins.append(np.concatenate(([X[feature].min()], bin_edges, [X[feature].max()])))
604557
return bins
605558

606559
def transform(self, X, embeddings=None):
@@ -634,30 +587,27 @@ def transform(self, X, embeddings=None):
634587
# Now let's convert this into a dictionary of arrays, one per column
635588
transformed_dict = self._split_transformed_output(X, transformed_X)
636589
if embeddings is not None:
637-
assert self.embeddings is True, "self.embeddings should be True but is not."
590+
if not self.embeddings:
591+
raise ValueError("self.embeddings should be True but is not.")
638592

639593
if isinstance(embeddings, np.ndarray):
640-
assert (
641-
self.embedding_dimensions["embedding_1"] == embeddings.shape[1]
642-
), (
643-
f"Expected embedding dimension {self.embedding_dimensions['embeddings']}, "
644-
f"but got {embeddings.shape[1]}"
645-
)
594+
if self.embedding_dimensions["embedding_1"] != embeddings.shape[1]:
595+
raise ValueError(
596+
f"Expected embedding dimension {self.embedding_dimensions['embedding_1']}, "
597+
f"but got {embeddings.shape[1]}"
598+
)
646599
transformed_dict["embedding_1"] = embeddings.astype(np.float32)
647-
elif isinstance(embeddings, list) and all(
648-
isinstance(e, np.ndarray) for e in embeddings
649-
):
600+
elif isinstance(embeddings, list) and all(isinstance(e, np.ndarray) for e in embeddings):
650601
for idx, e in enumerate(embeddings):
651-
key = f"embedding_{idx+1}"
652-
assert self.embedding_dimensions[key] == e.shape[1], (
653-
f"Expected embedding dimension {self.embedding_dimensions[key]} for {key}, "
654-
f"but got {e.shape[1]}"
655-
)
602+
key = f"embedding_{idx + 1}"
603+
if self.embedding_dimensions[key] != e.shape[1]:
604+
raise ValueError(
605+
f"Expected embedding dimension {self.embedding_dimensions[key]} for {key}, but got {e.shape[1]}"
606+
)
656607
transformed_dict[key] = e.astype(np.float32)
657608
else:
658-
assert (
659-
self.embeddings is False
660-
), "self.embeddings should be False when embeddings are None."
609+
if self.embeddings is not False:
610+
raise ValueError("self.embeddings should be False when embeddings are None.")
661611
self.embeddings = False
662612

663613
return transformed_dict
@@ -790,9 +740,7 @@ def get_feature_info(self, verbose=True):
790740
"categories": None,
791741
}
792742
if verbose:
793-
print(
794-
f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}"
795-
)
743+
print(f"Numerical Feature: {feature_name}, Info: {numerical_feature_info[feature_name]}")
796744

797745
elif "continuous_ordinal" in steps:
798746
step = transformer_pipeline.named_steps["continuous_ordinal"]
@@ -842,9 +790,7 @@ def get_feature_info(self, verbose=True):
842790
"categories": None,
843791
}
844792
if verbose:
845-
print(
846-
f"Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}"
847-
)
793+
print(f"Feature: {feature_name}, Info: {preprocessing_type}, Dimension: {dimension}")
848794

849795
if verbose:
850796
print("-" * 50)

0 commit comments

Comments
 (0)