Skip to content

Commit d8f7981

Browse files
committed
bump versions
1 parent 70c4608 commit d8f7981

3 files changed

Lines changed: 157 additions & 3 deletions

File tree

scripts/benchmark_onnx_speedup.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#!/usr/bin/env python3
2+
"""Benchmark ONNX vs PyTorch performance for adaptive classifier."""
3+
4+
import time
5+
import logging
6+
import datasets
7+
from adaptive_classifier import AdaptiveClassifier
8+
9+
logging.basicConfig(level=logging.INFO)
10+
logger = logging.getLogger(__name__)
11+
12+
def benchmark_model(model_id: str, test_texts: list, use_onnx: bool, num_runs: int = 3):
    """Benchmark one backend configuration of the adaptive classifier.

    Loads the classifier via ``AdaptiveClassifier.load``, performs one
    untimed warm-up call on the first five texts, then times ``num_runs``
    full ``predict_batch`` passes over ``test_texts``.

    Args:
        model_id: Identifier handed to ``AdaptiveClassifier.load``.
        test_texts: Texts to classify in every timed pass; must be non-empty.
        use_onnx: Forwarded to ``AdaptiveClassifier.load``; also selects the
            label used in the log output and in the returned ``mode``.
        num_runs: Number of timed passes; must be at least 1.

    Returns:
        A dict with the keys ``mode``, ``load_time``, ``avg_time``,
        ``throughput`` and ``times`` (the per-run durations in seconds).

    Raises:
        ValueError: If ``test_texts`` is empty or ``num_runs`` is below 1.
    """
    # Validate up front so a bad configuration fails before the (slow) model
    # load instead of with a ZeroDivisionError after it.
    if not test_texts:
        raise ValueError("test_texts must contain at least one sample")
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    num_samples = len(test_texts)
    mode = "ONNX (Quantized)" if use_onnx else "PyTorch"
    logger.info(f"\n{'='*60}")
    logger.info(f"Benchmarking: {mode}")
    logger.info(f"{'='*60}")

    # Load model
    logger.info(f"Loading model from {model_id}...")
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start = time.perf_counter()
    classifier = AdaptiveClassifier.load(model_id, use_onnx=use_onnx)
    load_time = time.perf_counter() - start
    logger.info(f"Model loaded in {load_time:.2f}s")

    # Warm-up run (not timed)
    logger.info("Warming up...")
    classifier.predict_batch(test_texts[:5])

    # Benchmark runs
    times = []
    for run in range(num_runs):
        logger.info(f"Run {run + 1}/{num_runs}...")
        start = time.perf_counter()
        classifier.predict_batch(test_texts)
        elapsed = time.perf_counter() - start
        times.append(elapsed)
        # Guard against a zero reading from the timer on very fast runs.
        run_rate = num_samples / elapsed if elapsed > 0 else float("inf")
        logger.info(f"  Completed in {elapsed:.3f}s ({run_rate:.1f} samples/sec)")

    avg_time = sum(times) / len(times)
    throughput = num_samples / avg_time if avg_time > 0 else float("inf")

    logger.info(f"\nResults for {mode}:")
    logger.info(f"  Average time: {avg_time:.3f}s")
    logger.info(f"  Throughput: {throughput:.1f} samples/sec")
    logger.info(f"  Per-sample latency: {avg_time*1000/num_samples:.1f}ms")

    return {
        'mode': mode,
        'load_time': load_time,
        'avg_time': avg_time,
        'throughput': throughput,
        'times': times,
    }
55+
56+
def main(
    model_id: str = "adaptive-classifier/llm-router",
    num_samples: int = 100,
    num_runs: int = 3,
):
    """Run the PyTorch-vs-ONNX benchmark and log a comparison summary.

    Loads the ``validation`` split of ``routellm/gpt4_dataset``, keeps the
    first ``num_samples`` rows (or fewer if the split is smaller), and uses
    each row's ``prompt`` field as input. The PyTorch configuration is
    benchmarked first, then the ONNX one, both through ``benchmark_model``.

    Args:
        model_id: Model identifier forwarded to ``benchmark_model``.
        num_samples: Upper bound on the number of dataset rows to use.
        num_runs: Number of timed runs per configuration.

    Returns:
        A dict with the per-configuration results under ``pytorch`` and
        ``onnx``, plus the derived ``speedup``, ``throughput_increase`` and
        ``latency_reduction`` (a percentage).
    """
    logger.info("Benchmark Configuration:")
    logger.info(f"  Model: {model_id}")
    logger.info(f"  Samples: {num_samples}")
    logger.info(f"  Runs per config: {num_runs}")

    # Load test data
    logger.info("\nLoading test dataset...")
    dataset = datasets.load_dataset("routellm/gpt4_dataset", split="validation")
    test_data = dataset.select(range(min(num_samples, len(dataset))))
    test_texts = [item['prompt'] for item in test_data]
    logger.info(f"Loaded {len(test_texts)} test samples")

    # Benchmark PyTorch version
    pytorch_results = benchmark_model(model_id, test_texts, use_onnx=False, num_runs=num_runs)

    # Benchmark ONNX version
    onnx_results = benchmark_model(model_id, test_texts, use_onnx=True, num_runs=num_runs)

    # Compare results
    logger.info(f"\n{'='*60}")
    logger.info("COMPARISON SUMMARY")
    logger.info(f"{'='*60}")

    speedup = pytorch_results['avg_time'] / onnx_results['avg_time']
    throughput_increase = onnx_results['throughput'] / pytorch_results['throughput']
    latency_reduction = (1 - onnx_results['avg_time'] / pytorch_results['avg_time']) * 100

    logger.info("\nPyTorch (Baseline):")
    logger.info(f"  Average time: {pytorch_results['avg_time']:.3f}s")
    logger.info(f"  Throughput: {pytorch_results['throughput']:.1f} samples/sec")

    logger.info("\nONNX Quantized:")
    logger.info(f"  Average time: {onnx_results['avg_time']:.3f}s")
    logger.info(f"  Throughput: {onnx_results['throughput']:.1f} samples/sec")

    logger.info("\nSpeedup:")
    logger.info(f"  🚀 {speedup:.2f}x faster")
    logger.info(f"  📈 {throughput_increase:.2f}x throughput increase")
    logger.info(f"  ⏱️  {latency_reduction:.1f}% latency reduction")

    # NOTE: the size figures below are static text, not measured by this script.
    logger.info("\nModel Size Comparison:")
    logger.info("  PyTorch: Uses full precision weights")
    logger.info("  ONNX Quantized: 65.6 MB (4x smaller than unquantized)")

    logger.info(f"\n{'='*60}")
    logger.info("BENCHMARK COMPLETE")
    logger.info(f"{'='*60}")

    return {
        'pytorch': pytorch_results,
        'onnx': onnx_results,
        'speedup': speedup,
        'throughput_increase': throughput_increase,
        'latency_reduction': latency_reduction,
    }


if __name__ == "__main__":
    # Keep the summary bound at module level, as the original script did.
    results = main()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setup(
1717
name="adaptive-classifier",
18-
version="0.0.19",
18+
version="0.1.0",
1919
author="codelion",
2020
author_email="codelion@okyasoft.com",
2121
description="A flexible, adaptive classification system for dynamic text classification",

src/adaptive_classifier/classifier.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,41 @@ def _from_pretrained(
701701
token=token,
702702
local_files_only=local_files_only,
703703
)
704+
705+
# Try to download ONNX files if they exist
706+
try:
707+
# Download quantized ONNX model (primary)
708+
hf_hub_download(
709+
repo_id=model_id,
710+
filename="onnx/model_quantized.onnx",
711+
revision=revision,
712+
cache_dir=cache_dir,
713+
force_download=force_download,
714+
proxies=proxies,
715+
resume_download=resume_download,
716+
token=token,
717+
local_files_only=local_files_only,
718+
)
719+
# Download ONNX config files
720+
for onnx_file in ["config.json", "ort_config.json", "tokenizer.json",
721+
"tokenizer_config.json", "special_tokens_map.json", "vocab.txt"]:
722+
try:
723+
hf_hub_download(
724+
repo_id=model_id,
725+
filename=f"onnx/{onnx_file}",
726+
revision=revision,
727+
cache_dir=cache_dir,
728+
force_download=force_download,
729+
proxies=proxies,
730+
resume_download=resume_download,
731+
token=token,
732+
local_files_only=local_files_only,
733+
)
734+
except:
735+
pass # Some files might not exist
736+
logger.info("Downloaded ONNX model files from Hub")
737+
except Exception as e:
738+
logger.debug(f"ONNX model not available on Hub: {e}")
704739
except Exception as e:
705740
raise ValueError(f"Error loading model from {model_id}: {e}")
706741

@@ -712,9 +747,9 @@ def _from_pretrained(
712747
with open(model_path / "examples.json", "r", encoding="utf-8") as f:
713748
saved_examples = json.load(f)
714749

715-
# Check if ONNX model exists
750+
# Check if ONNX model exists (quantized or unquantized)
716751
onnx_path = model_path / "onnx"
717-
has_onnx = onnx_path.exists() and (onnx_path / "model.onnx").exists()
752+
has_onnx = onnx_path.exists() and ((onnx_path / "model_quantized.onnx").exists() or (onnx_path / "model.onnx").exists())
718753

719754
# Determine if we should use ONNX
720755
final_use_onnx = use_onnx

0 commit comments

Comments
 (0)