|
1 | 1 | import numpy as np |
2 | 2 | import pandas as pd |
3 | 3 | from sklearn.base import BaseEstimator, TransformerMixin |
4 | | -from sentence_transformers import SentenceTransformer |
5 | 4 |
|
6 | 5 |
|
7 | 6 | class CustomBinner(TransformerMixin): |
@@ -228,20 +227,29 @@ def transform(self, X): |
228 | 227 | class LanguageEmbeddingTransformer(TransformerMixin, BaseEstimator): |
229 | 228 | """A transformer that encodes categorical text features into embeddings using a pre-trained language model.""" |
230 | 229 |
|
231 | | - def __init__(self, model_name="paraphrase-MiniLM-L3-v2"): |
| 230 | + def __init__(self, model_name="paraphrase-MiniLM-L3-v2", model=None): |
232 | 231 | """ |
233 | 232 | Initializes the transformer with a language embedding model. |
234 | 233 |
|
235 | 234 | Parameters: |
236 | | - - model_name (str): The name of the SentenceTransformer model to use. |
| 235 | + - model_name (str): The name of the SentenceTransformer model to use (if model is None). |
| 236 | + - model (object, optional): A preloaded SentenceTransformer model instance. |
237 | 237 | """ |
238 | 238 | self.model_name = model_name |
239 | | - self.model = SentenceTransformer(model_name) |
| 239 | + self.model = model # Allow user to pass a preloaded model |
| 240 | + |
| 241 | + if self.model is None: |
| 242 | + try: |
| 243 | + from sentence_transformers import SentenceTransformer |
| 244 | + |
| 245 | + self.model = SentenceTransformer(model_name) |
| 246 | + except ImportError: |
| 247 | + raise ImportError( |
| 248 | + "sentence-transformers is not installed. Install it via `pip install sentence-transformers` or provide a preloaded model." |
| 249 | + ) |
240 | 250 |
|
241 | 251 | def fit(self, X, y=None): |
242 | | - """ |
243 | | - Fit method (not required for a transformer but included for compatibility). |
244 | | - """ |
| 252 | + """Fit method (not required for a transformer but included for compatibility).""" |
245 | 253 | self.n_features_in_ = X.shape[1] if len(X.shape) > 1 else 1 |
246 | 254 | return self |
247 | 255 |
|
|
0 commit comments