Skip to content

Commit cf801f2

Browse files
committed
Add quantized ONNX model loading option
Enhanced the AdaptiveClassifier to support loading both quantized and unquantized ONNX models, with quantized as the default for improved performance. Updated the README with usage instructions and clarified the behavior of ONNX model selection for saving and loading.
1 parent 7ce0156 commit cf801f2

2 files changed

Lines changed: 54 additions & 9 deletions

File tree

README.md

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -221,19 +221,31 @@ predictions = classifier.predict("Fast inference!")
221221
#### Save & Deploy with ONNX
222222

223223
```python
224-
# Save with ONNX export (included by default)
224+
# Save with ONNX export (both quantized & unquantized versions)
225225
classifier.save("./model")
226226

227-
# Push to Hub with ONNX (included by default)
227+
# Push to Hub with ONNX (both versions included by default)
228228
classifier.push_to_hub("username/model")
229229

230-
# Load automatically uses ONNX on CPU
230+
# Load automatically uses quantized ONNX on CPU (fastest, 4x smaller)
231231
fast_classifier = AdaptiveClassifier.load("./model")
232232

233-
# Opt-out if you don't want ONNX export
233+
# Choose unquantized ONNX for maximum accuracy
234+
accurate_classifier = AdaptiveClassifier.load("./model", prefer_quantized=False)
235+
236+
# Force PyTorch (no ONNX)
237+
pytorch_classifier = AdaptiveClassifier.load("./model", use_onnx=False)
238+
239+
# Opt out of ONNX export when saving
234240
classifier.save("./model", include_onnx=False)
235241
```
236242

243+
**ONNX Model Versions:**
244+
- **Quantized (default)**: INT8 quantized, 4x smaller, ~1.14x faster on ARM, 2-4x faster on x86
245+
- **Unquantized**: Full precision, maximum accuracy, larger file size
246+
247+
By default, models are saved with both versions, and the quantized version is automatically loaded for best performance. Use `prefer_quantized=False` if you need maximum accuracy.
248+
237249
#### Benchmark Your Model
238250

239251
```bash

src/adaptive_classifier/classifier.py

Lines changed: 38 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -650,10 +650,22 @@ def _from_pretrained(
650650
local_files_only: Use local files only, don't download
651651
token: Authentication token for Hub
652652
use_onnx: Whether to use ONNX Runtime ("auto", True, False)
653+
prefer_quantized: Use quantized ONNX model if available (default: True)
654+
Set to False to use unquantized model for maximum accuracy
653655
**kwargs: Additional arguments passed to from_pretrained
654656
655657
Returns:
656658
Loaded AdaptiveClassifier instance
659+
660+
Examples:
661+
>>> # Load with quantized ONNX (default - faster, smaller)
662+
>>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router")
663+
>>>
664+
>>> # Load with unquantized ONNX (maximum accuracy)
665+
>>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router", prefer_quantized=False)
666+
>>>
667+
>>> # Force PyTorch (no ONNX)
668+
>>> classifier = AdaptiveClassifier.load("adaptive-classifier/llm-router", use_onnx=False)
657669
"""
658670

659671
# Check if model_id is a local directory
@@ -781,8 +793,28 @@ def _from_pretrained(
781793
classifier.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
782794
classifier.use_onnx = True
783795

784-
# Load ONNX model
785-
classifier.model = ORTModelForFeatureExtraction.from_pretrained(onnx_path)
796+
# Load ONNX model (prefer quantized by default)
797+
# Check which ONNX files exist
798+
has_quantized = (onnx_path / "model_quantized.onnx").exists()
799+
has_unquantized = (onnx_path / "model.onnx").exists()
800+
801+
# Determine which file to load
802+
if prefer_quantized and has_quantized:
803+
onnx_file = "model_quantized.onnx"
804+
logger.info("Loading quantized ONNX model for optimal performance")
805+
elif has_unquantized:
806+
onnx_file = "model.onnx"
807+
logger.info("Loading unquantized ONNX model")
808+
elif has_quantized:
809+
onnx_file = "model_quantized.onnx"
810+
logger.info("Loading quantized ONNX model (only version available)")
811+
else:
812+
raise ValueError(f"No ONNX model files found in {onnx_path}")
813+
814+
classifier.model = ORTModelForFeatureExtraction.from_pretrained(
815+
onnx_path,
816+
file_name=onnx_file
817+
)
786818
classifier.tokenizer = AutoTokenizer.from_pretrained(config_dict['model_name'])
787819

788820
# Initialize memory and other components
@@ -1133,18 +1165,19 @@ def save(self, save_dir: str, include_onnx: bool = True, quantize_onnx: bool = T
11331165
)
11341166

11351167
@classmethod
1136-
def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto") -> 'AdaptiveClassifier':
1168+
def load(cls, save_dir: str, device: Optional[str] = None, use_onnx: Optional[Union[bool, str]] = "auto", prefer_quantized: bool = True) -> 'AdaptiveClassifier':
11371169
"""Legacy load method for backwards compatibility.
11381170
11391171
Args:
11401172
save_dir: Directory to load from
11411173
device: Device to load model on
1142-
use_onnx: Whether to use ONNX ("auto", True, False)
1174+
use_onnx: Whether to use ONNX Runtime ("auto", True, False)
1175+
prefer_quantized: Use quantized ONNX model if available (default: True)
11431176
"""
11441177
kwargs = {}
11451178
if device is not None:
11461179
kwargs['device'] = device
1147-
return cls._from_pretrained(save_dir, use_onnx=use_onnx, **kwargs)
1180+
return cls._from_pretrained(save_dir, use_onnx=use_onnx, prefer_quantized=prefer_quantized, **kwargs)
11481181

11491182
def to(self, device: str) -> 'AdaptiveClassifier':
11501183
"""Move the model to specified device.

0 commit comments

Comments (0)