Skip to content

Commit a44ec51

Browse files
committed
include data check in preprocessor
1 parent 0d7be4a commit a44ec51

4 files changed

Lines changed: 129 additions & 2 deletions

File tree

mambular/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@
1717

1818
# The following line *must* be the last in the module, exactly as formatted:
1919

20-
__version__ = "1.3.2"
20+
__version__ = "1.4.0"

mambular/preprocessing/preprocessor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
OneHotFromOrdinal,
2828
ToFloatTransformer,
2929
)
30+
from .utils import check_inputs
3031
from sklearn.base import TransformerMixin
3132

3233

@@ -118,6 +119,7 @@ def __init__(
118119
use_decision_tree_knots=True,
119120
knots_strategy="uniform",
120121
spline_implementation="sklearn",
122+
min_unique_vals=5,
121123
):
122124
self.n_bins = n_bins
123125
self.numerical_preprocessing = (
@@ -176,6 +178,7 @@ def __init__(
176178
self.use_decision_tree_knots = use_decision_tree_knots
177179
self.knots_strategy = knots_strategy
178180
self.spline_implementation = spline_implementation
181+
self.min_unique_vals = min_unique_vals
179182

180183
def get_params(self, deep=True):
181184
"""Get parameters for the preprocessor.
@@ -307,6 +310,15 @@ def fit(self, X, y=None, embeddings=None):
307310
self._fit_embeddings(embeddings)
308311

309312
numerical_features, categorical_features = self._detect_column_types(X)
313+
314+
check_inputs(
315+
X,
316+
y,
317+
numerical_features,
318+
categorical_features,
319+
task_type=self.task,
320+
min_samples=self.min_unique_vals,
321+
)
310322
transformers = []
311323

312324
if numerical_features:

mambular/preprocessing/utils.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import pandas as pd
2+
import numpy as np
3+
import warnings
4+
5+
6+
def check_inputs(
7+
X,
8+
y=None,
9+
numerical_columns=None,
10+
categorical_columns=None,
11+
task_type=None,
12+
min_samples=5,
13+
):
14+
"""
15+
Perform thorough validation on input features and target.
16+
17+
Parameters
18+
----------
19+
X : pd.DataFrame or dict
20+
Input features.
21+
y : array-like, optional
22+
Target values.
23+
numerical_columns : list of str
24+
Columns expected to be numerical.
25+
categorical_columns : list of str
26+
Columns expected to be categorical.
27+
task_type : str, optional
28+
One of {"regression", "binary", "multiclass"}. If specified, target checks will apply accordingly.
29+
min_samples : int, optional
30+
Minimum number of distinct values required in any feature or target.
31+
32+
Raises
33+
------
34+
ValueError
35+
If any feature or target fails validation checks.
36+
"""
37+
if isinstance(X, dict):
38+
X = pd.DataFrame(X)
39+
40+
if not isinstance(X, pd.DataFrame):
41+
raise TypeError("X must be a DataFrame or a dict convertible to DataFrame.")
42+
43+
if X.empty:
44+
raise ValueError("X must not be empty.")
45+
46+
if numerical_columns is None:
47+
numerical_columns = []
48+
if categorical_columns is None:
49+
categorical_columns = []
50+
51+
all_cols = set(numerical_columns) | set(categorical_columns)
52+
missing_cols = all_cols - set(X.columns)
53+
if missing_cols:
54+
raise ValueError(
55+
f"The following specified columns are missing in X: {missing_cols}"
56+
)
57+
58+
# Check numerical features
59+
for col in numerical_columns:
60+
series = X[col]
61+
if series.nunique(dropna=False) < min_samples:
62+
raise ValueError(
63+
f"Numerical feature '{col}' has less than {min_samples} unique values."
64+
)
65+
if not np.issubdtype(series.dtype, np.number):
66+
raise TypeError(f"Numerical feature '{col}' must be numeric.")
67+
if not np.all(np.isfinite(series.dropna())):
68+
raise ValueError(
69+
f"Numerical feature '{col}' contains non-finite values (inf or NaN)."
70+
)
71+
72+
# Check categorical features
73+
for col in categorical_columns:
74+
series = X[col]
75+
if series.nunique(dropna=False) < 2:
76+
raise ValueError(
77+
f"Categorical feature '{col}' has less only a single value ."
78+
)
79+
if pd.api.types.is_numeric_dtype(
80+
series
81+
) and not pd.api.types.is_categorical_dtype(series):
82+
# allow numerical dtypes only if user intends to encode them
83+
pass # optionally warn or convert
84+
if series.isnull().all():
85+
raise ValueError(f"Categorical feature '{col}' contains only NaNs.")
86+
87+
# Check y
88+
if y is not None:
89+
y = np.array(y)
90+
91+
if y.ndim != 1:
92+
raise ValueError("y must be a 1D array or Series.")
93+
94+
if len(y) != len(X):
95+
raise ValueError("X and y must have the same number of samples.")
96+
97+
unique_targets = np.unique(y[~pd.isnull(y)])
98+
n_classes = len(unique_targets)
99+
100+
if task_type == "regression":
101+
if not np.issubdtype(y.dtype, np.number):
102+
raise TypeError("For regression, target y must be numeric.")
103+
if not np.all(np.isfinite(y)):
104+
raise ValueError("Target y contains non-finite values.")
105+
106+
if n_classes <= 10:
107+
warnings.warn(
108+
f"Target y has only {n_classes} unique values. "
109+
"Consider if this should be a classification problem instead of regression.",
110+
UserWarning,
111+
)
112+
113+
elif task_type == "classification":
114+
if n_classes < 2:
115+
raise ValueError("Classification tasks requires more than 1 class.")

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[tool.poetry]
22
name = "mambular"
33

4-
version = "1.3.2"
4+
version = "1.4.0"
55

66
description = "A python package for tabular deep learning with mamba blocks."
77
authors = ["Anton Thielmann", "Manish Kumar", "Christoph Weisser"]

0 commit comments

Comments
 (0)