@@ -84,7 +84,7 @@ def _create_model_card(


 def load_pretrained(
-    folder_or_repo_path: str | Path, token: str | None = None
+    folder_or_repo_path: str | Path, token: str | None = None, from_sentence_transformers: bool = False
 ) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
     """
     Loads a pretrained model from a folder.
@@ -93,26 +93,31 @@ def load_pretrained(
9393 - If this is a local path, we will load from the local path.
9494 - If the local path is not found, we will attempt to load from the huggingface hub.
9595 :param token: The huggingface token to use.
96+ :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
9697 :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
9798 :return: The embeddings, tokenizer, config, and metadata.
9899
99100 """
101+ if from_sentence_transformers :
102+ model_file = "0_StaticEmbedding/model.safetensors"
103+ tokenizer_file = "0_StaticEmbedding/tokenizer.json"
104+ config_name = "config_sentence_transformers.json"
105+ else :
106+ model_file = "model.safetensors"
107+ tokenizer_file = "tokenizer.json"
108+ config_name = "config.json"
109+
100110 folder_or_repo_path = Path (folder_or_repo_path )
101111 if folder_or_repo_path .exists ():
102- embeddings_path = folder_or_repo_path / "model.safetensors"
112+ embeddings_path = folder_or_repo_path / model_file
103113 if not embeddings_path .exists ():
104- old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
105- if old_embeddings_path .exists ():
106- logger .warning ("Old embeddings file found. Please rename to `model.safetensors` and re-save." )
107- embeddings_path = old_embeddings_path
108- else :
109- raise FileNotFoundError (f"Embeddings file does not exist in { folder_or_repo_path } " )
110-
111- config_path = folder_or_repo_path / "config.json"
114+ raise FileNotFoundError (f"Embeddings file does not exist in { folder_or_repo_path } " )
115+
116+ config_path = folder_or_repo_path / config_name
112117 if not config_path .exists ():
113118 raise FileNotFoundError (f"Config file does not exist in { folder_or_repo_path } " )
114119
115- tokenizer_path = folder_or_repo_path / "tokenizer.json"
120+ tokenizer_path = folder_or_repo_path / tokenizer_file
116121 if not tokenizer_path .exists ():
117122 raise FileNotFoundError (f"Tokenizer file does not exist in { folder_or_repo_path } " )
118123
@@ -122,18 +127,7 @@ def load_pretrained(
122127
123128 else :
124129 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
125- try :
126- embeddings_path = huggingface_hub .hf_hub_download (
127- folder_or_repo_path .as_posix (), "model.safetensors" , token = token
128- )
129- except huggingface_hub .utils .EntryNotFoundError as e :
130- try :
131- embeddings_path = huggingface_hub .hf_hub_download (
132- folder_or_repo_path .as_posix (), "embeddings.safetensors" , token = token
133- )
134- except huggingface_hub .utils .EntryNotFoundError :
135- # Raise original exception.
136- raise e
130+ embeddings_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), model_file , token = token )
137131
138132 try :
139133 readme_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), "README.md" , token = token )
@@ -142,11 +136,14 @@ def load_pretrained(
142136 logger .info ("No README found in the model folder. No model card loaded." )
143137 metadata = {}
144138
145- config_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), "config.json" , token = token )
146- tokenizer_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), "tokenizer.json" , token = token )
139+ config_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), config_name , token = token )
140+ tokenizer_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), tokenizer_file , token = token )
147141
148142 opened_tensor_file = cast (SafeOpenProtocol , safetensors .safe_open (embeddings_path , framework = "numpy" ))
149- embeddings = opened_tensor_file .get_tensor ("embeddings" )
143+ if from_sentence_transformers :
144+ embeddings = opened_tensor_file .get_tensor ("embedding.weight" )
145+ else :
146+ embeddings = opened_tensor_file .get_tensor ("embeddings" )
150147
151148 tokenizer : Tokenizer = Tokenizer .from_file (str (tokenizer_path ))
152149 config = json .load (open (config_path ))
0 commit comments