Skip to content

Commit 50a3883

Browse files
committed
Make sentence_transformers an optional dependency (accept a preloaded model as input)
1 parent 850a5cc commit 50a3883

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

mambular/preprocessing/prepro_utils.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import pandas as pd
33
from sklearn.base import BaseEstimator, TransformerMixin
4-
from sentence_transformers import SentenceTransformer
54

65

76
class CustomBinner(TransformerMixin):
@@ -228,20 +227,29 @@ def transform(self, X):
228227
class LanguageEmbeddingTransformer(TransformerMixin, BaseEstimator):
229228
"""A transformer that encodes categorical text features into embeddings using a pre-trained language model."""
230229

231-
def __init__(self, model_name="paraphrase-MiniLM-L3-v2"):
230+
def __init__(self, model_name="paraphrase-MiniLM-L3-v2", model=None):
232231
"""
233232
Initializes the transformer with a language embedding model.
234233
235234
Parameters:
236-
- model_name (str): The name of the SentenceTransformer model to use.
235+
- model_name (str): The name of the SentenceTransformer model to use (if model is None).
236+
- model (object, optional): A preloaded SentenceTransformer model instance.
237237
"""
238238
self.model_name = model_name
239-
self.model = SentenceTransformer(model_name)
239+
self.model = model # Allow user to pass a preloaded model
240+
241+
if self.model is None:
242+
try:
243+
from sentence_transformers import SentenceTransformer
244+
245+
self.model = SentenceTransformer(model_name)
246+
except ImportError:
247+
raise ImportError(
248+
"sentence-transformers is not installed. Install it via `pip install sentence-transformers` or provide a preloaded model."
249+
)
240250

241251
def fit(self, X, y=None):
242-
"""
243-
Fit method (not required for a transformer but included for compatibility).
244-
"""
252+
"""Fit method (not required for a transformer but included for compatibility)."""
245253
self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1
246254
return self
247255

0 commit comments

Comments
 (0)