Skip to content

Commit 4ff042f

Browse files
committed
Update classifier.py
1 parent d8f7981 commit 4ff042f

1 file changed

Lines changed: 39 additions & 3 deletions

File tree

src/adaptive_classifier/classifier.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,9 @@ def push_to_hub(
10441044
repo_id: str,
10451045
include_onnx: bool = True,
10461046
quantize_onnx: bool = True,
1047+
token: Optional[str] = None,
1048+
commit_message: Optional[str] = None,
1049+
private: bool = False,
10471050
**kwargs
10481051
):
10491052
"""Push model to HuggingFace Hub with ONNX export by default.
@@ -1052,14 +1055,27 @@ def push_to_hub(
10521055
repo_id: Repository ID on HuggingFace Hub (e.g., "username/model-name")
10531056
include_onnx: Whether to include ONNX version of the model (default: True)
10541057
quantize_onnx: Whether to quantize the ONNX model (requires include_onnx=True)
1055-
**kwargs: Additional arguments passed to ModelHub push_to_hub
1058+
token: HuggingFace Hub authentication token (or set HF_TOKEN env var)
1059+
commit_message: Commit message for the push
1060+
private: Whether to create a private repository
1061+
**kwargs: Additional arguments passed to HfApi.upload_folder
10561062
10571063
Examples:
10581064
>>> classifier.push_to_hub("my-org/my-classifier") # ONNX included by default
10591065
>>> classifier.push_to_hub("my-org/my-classifier", quantize_onnx=True)
10601066
>>> classifier.push_to_hub("my-org/my-classifier", include_onnx=False) # Opt-out
10611067
"""
10621068
import tempfile
1069+
import os
1070+
from huggingface_hub import HfApi
1071+
1072+
# Get token from parameter or environment
1073+
token = token or os.environ.get("HF_TOKEN")
1074+
if not token:
1075+
logger.warning(
1076+
"No HuggingFace token provided. Set HF_TOKEN environment variable or pass token parameter. "
1077+
"You may need to login with `huggingface-cli login`"
1078+
)
10631079

10641080
# Create temporary directory for saving
10651081
with tempfile.TemporaryDirectory() as tmpdir:
@@ -1072,12 +1088,32 @@ def push_to_hub(
10721088
quantize_onnx=quantize_onnx
10731089
)
10741090

1075-
# Use parent class push_to_hub to upload all files
1076-
super().push_to_hub(
1091+
# Use HfApi to upload the folder directly
1092+
api = HfApi()
1093+
1094+
# Create repo if it doesn't exist
1095+
try:
1096+
api.create_repo(
1097+
repo_id=repo_id,
1098+
token=token,
1099+
private=private,
1100+
exist_ok=True
1101+
)
1102+
except Exception as e:
1103+
logger.warning(f"Could not create repo (may already exist): {e}")
1104+
1105+
# Upload all files from the temp directory
1106+
commit_info = api.upload_folder(
1107+
folder_path=str(save_path),
10771108
repo_id=repo_id,
1109+
token=token,
1110+
commit_message=commit_message or "Upload model with adaptive-classifier",
10781111
**kwargs
10791112
)
10801113

1114+
logger.info(f"Successfully pushed model to https://huggingface.co/{repo_id}")
1115+
return f"https://huggingface.co/{repo_id}"
1116+
10811117
# Keep existing save/load methods for backwards compatibility
10821118
def save(self, save_dir: str, include_onnx: bool = True, quantize_onnx: bool = True):
10831119
"""Legacy save method for backwards compatibility.

0 commit comments

Comments
 (0)