Skip to content

Commit 4e54bfc

Browse files
committed
fix input data for transform
1 parent c59367d commit 4e54bfc

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

mambular/preprocessing/preprocessor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,10 @@ def transform(self, X, embeddings=None):
582582
raise NotFittedError(
583583
"The preprocessor must be fitted before transforming new data. Use .fit or .fit_transform"
584584
)
585+
if isinstance(X, np.ndarray):
586+
X = pd.DataFrame(X)
587+
else:
588+
X = X.copy()
585589
transformed_X = self.column_transformer.transform(X) # type: ignore
586590

587591
# Now let's convert this into a dictionary of arrays, one per column

0 commit comments

Comments
 (0)