Skip to content

Commit 6e52342

Browse files
authored
feat: Add loading from sentence-transformers (st) (#151)
* Add ability to load from sentence-transformers * Remove legacy loading
1 parent: e25988e · commit: 6e52342

3 files changed

Lines changed: 60 additions & 42 deletions

File tree

model2vec/hf_utils.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _create_model_card(
8484

8585

8686
def load_pretrained(
87-
folder_or_repo_path: str | Path, token: str | None = None
87+
folder_or_repo_path: str | Path, token: str | None = None, from_sentence_transformers: bool = False
8888
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
8989
"""
9090
Loads a pretrained model from a folder.
@@ -93,26 +93,31 @@ def load_pretrained(
9393
- If this is a local path, we will load from the local path.
9494
- If the local path is not found, we will attempt to load from the huggingface hub.
9595
:param token: The huggingface token to use.
96+
:param from_sentence_transformers: Whether to load the model from a sentence transformers model.
9697
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
9798
:return: The embeddings, tokenizer, config, and metadata.
9899
99100
"""
101+
if from_sentence_transformers:
102+
model_file = "0_StaticEmbedding/model.safetensors"
103+
tokenizer_file = "0_StaticEmbedding/tokenizer.json"
104+
config_name = "config_sentence_transformers.json"
105+
else:
106+
model_file = "model.safetensors"
107+
tokenizer_file = "tokenizer.json"
108+
config_name = "config.json"
109+
100110
folder_or_repo_path = Path(folder_or_repo_path)
101111
if folder_or_repo_path.exists():
102-
embeddings_path = folder_or_repo_path / "model.safetensors"
112+
embeddings_path = folder_or_repo_path / model_file
103113
if not embeddings_path.exists():
104-
old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
105-
if old_embeddings_path.exists():
106-
logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
107-
embeddings_path = old_embeddings_path
108-
else:
109-
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")
110-
111-
config_path = folder_or_repo_path / "config.json"
114+
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")
115+
116+
config_path = folder_or_repo_path / config_name
112117
if not config_path.exists():
113118
raise FileNotFoundError(f"Config file does not exist in {folder_or_repo_path}")
114119

115-
tokenizer_path = folder_or_repo_path / "tokenizer.json"
120+
tokenizer_path = folder_or_repo_path / tokenizer_file
116121
if not tokenizer_path.exists():
117122
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")
118123

@@ -122,18 +127,7 @@ def load_pretrained(
122127

123128
else:
124129
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
125-
try:
126-
embeddings_path = huggingface_hub.hf_hub_download(
127-
folder_or_repo_path.as_posix(), "model.safetensors", token=token
128-
)
129-
except huggingface_hub.utils.EntryNotFoundError as e:
130-
try:
131-
embeddings_path = huggingface_hub.hf_hub_download(
132-
folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
133-
)
134-
except huggingface_hub.utils.EntryNotFoundError:
135-
# Raise original exception.
136-
raise e
130+
embeddings_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), model_file, token=token)
137131

138132
try:
139133
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
@@ -142,11 +136,14 @@ def load_pretrained(
142136
logger.info("No README found in the model folder. No model card loaded.")
143137
metadata = {}
144138

145-
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
146-
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)
139+
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), config_name, token=token)
140+
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), tokenizer_file, token=token)
147141

148142
opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
149-
embeddings = opened_tensor_file.get_tensor("embeddings")
143+
if from_sentence_transformers:
144+
embeddings = opened_tensor_file.get_tensor("embedding.weight")
145+
else:
146+
embeddings = opened_tensor_file.get_tensor("embeddings")
150147

151148
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
152149
config = json.load(open(config_path))

model2vec/model.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,33 @@ def from_pretrained(
160160
"""
161161
from model2vec.hf_utils import load_pretrained
162162

163-
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)
163+
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)
164164

165165
return cls(
166166
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
167167
)
168168

169+
@classmethod
170+
def from_sentence_transformers(
171+
cls: type[StaticModel],
172+
path: PathLike,
173+
token: str | None = None,
174+
) -> StaticModel:
175+
"""
176+
Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
177+
178+
NOTE: if you load a private model from the huggingface hub, you need to pass a token.
179+
180+
:param path: The path to load your static model from.
181+
:param token: The huggingface token to use.
182+
:return: A StaticModel
183+
"""
184+
from model2vec.hf_utils import load_pretrained
185+
186+
embeddings, tokenizer, config, _ = load_pretrained(path, token=token, from_sentence_transformers=True)
187+
188+
return cls(embeddings, tokenizer, config, base_model_name=None, language=None)
189+
169190
def encode_as_sequence(
170191
self,
171192
sentences: list[str] | str,

uv.lock

Lines changed: 15 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)