
Commit 515aa2e

Improve accelerate() for large MoE models (Qwen3.5-35B-A3B)
- Skip cache for device_map='auto' — HF's shard-by-shard loading is faster than our load-to-CPU-then-dispatch path (10s vs 25s)
- Add suffix-based tensor matching for MoE models where state_dict and safetensors use different key prefixes (684/693 matched vs 1/693 before)
- Add match ratio check — skip cache save when <50% tensors match
- Fix meta tensor crash with to_empty() fallback
- Add parallel shard loading via ThreadPoolExecutor
- Update benchmark script with proper HF cache cleanup
- Update README with Qwen3.5-35B-A3B benchmark results

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 339f8ae commit 515aa2e

5 files changed

Lines changed: 423 additions & 77 deletions
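
The suffix-based tensor matching and the 50% match-ratio guard mentioned in the commit message live in the ModelCache loader, which is not among the diffs shown on this page. A minimal sketch of the idea follows; the function names and shapes are illustrative, not the actual zerostart API:

```python
# Illustrative sketch only; not the actual zerostart ModelCache code.
def suffix_match(model_keys: list[str], file_keys: list[str]) -> dict[str, str]:
    """Map state_dict keys to safetensors keys that differ only by a prefix,
    e.g. 'model.layers.0.mlp.experts.3.up_proj.weight' in the state_dict vs
    'layers.0.mlp.experts.3.up_proj.weight' in the shard files."""
    mapping: dict[str, str] = {}
    file_set = set(file_keys)
    for mk in model_keys:
        if mk in file_set:            # exact match wins
            mapping[mk] = mk
            continue
        for fk in file_keys:          # fall back to dot-aligned suffix match
            if mk.endswith("." + fk) or fk.endswith("." + mk):
                mapping[mk] = fk
                break
    return mapping


def worth_caching(mapping: dict[str, str], model_keys: list[str]) -> bool:
    """Match-ratio guard: skip the cache save when under 50% of tensors match."""
    return len(mapping) >= 0.5 * max(len(model_keys), 1)
```

The commit message reports that this style of matching recovers 684 of 693 tensors on Qwen3.5-35B-A3B, versus 1 of 693 before.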


README.md

Lines changed: 15 additions & 3 deletions
````diff
@@ -124,7 +124,7 @@ import torch
 
 ## Model Loading Acceleration
 
-`zerostart.accelerate()` patches `from_pretrained` to speed up model loading by skipping unnecessary work (random weight init, repeated downloads). Add one line:
+`zerostart.accelerate()` patches `from_pretrained` to speed up model loading. Sets `low_cpu_mem_usage=True` by default (skips random weight initialization), and auto-caches models for faster repeat loads on models that fit in GPU memory.
 
 ```python
 import zerostart
@@ -134,6 +134,14 @@ from transformers import AutoModelForCausalLM
 model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
 ```
 
+| Model | Cold (download+load) | Baseline (HF cached) | accelerate() | Notes |
+|-------|---------------------|---------------------|--------------|-------|
+| Qwen3.5-35B-A3B | 299s | 10.1s | 10.1s | MoE, 34.7B params, device_map='auto' |
+| Qwen2.5-7B | | 5.5s | 3.3s | Fits in GPU, cache provides speedup |
+| Qwen2.5-1.5B | | 3.5s | 3.2s | Small model, minimal difference |
+
+All measured on RTX A6000 (48GB). For models requiring `device_map='auto'` (model > VRAM), accelerate() matches baseline by eliminating random weight initialization. For models that fit entirely in GPU memory, the mmap cache provides additional speedup.
+
 Or via CLI:
 
 ```bash
@@ -144,13 +152,17 @@ zerostart run --accelerate -p torch -p transformers serve.py
 
 | Hook | What it does |
 |------|-------------|
-| Meta device init | Skips random weight initialization during `from_pretrained` |
-| Auto-cache | Snapshots model on first load, mmap hydrate on repeat |
+| `low_cpu_mem_usage` | Sets `low_cpu_mem_usage=True` by default — skips random weight initialization |
+| Auto-cache | Snapshots model on first load, mmap hydrate via `safe_open` on repeat (models that fit in GPU) |
+| Parallel shard loading | Loads multiple safetensors shards concurrently during cache hydration |
+| Suffix tensor matching | Handles MoE models where state_dict and safetensors use different key prefixes |
 | Network volume fix | Eager read instead of mmap on NFS/JuiceFS (cold reads only*) |
 | .bin conversion | Converts legacy checkpoints to safetensors, mmaps on repeat |
 
 *Network volume fix only helps on cold reads from network-backed filesystems where mmap page faults trigger network round-trips. On FUSE with warm page cache (most container providers), mmap is already fast.
 
+For `device_map='auto'` (model larger than VRAM), caching is skipped — HF's shard-by-shard loading directly to the right device is faster than our load-to-CPU-then-dispatch path.
+
 ### Model Cache
 
 Models are automatically cached after first load:
````
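
The "mmap hydrate via `safe_open`" row in the hook table above uses the safetensors `safe_open` API. A rough sketch of what that hydration step looks like, assuming a cached single-file snapshot and a model whose weights are already materialized on the target device (the real ModelCache internals are not part of this commit view):

```python
# Rough sketch, not zerostart's actual code: copy cached weights into an
# already-constructed model using mmap-backed reads via safetensors safe_open.
from safetensors import safe_open

def hydrate_from_cache(cache_path: str, model, device: str = "cuda") -> None:
    sd = model.state_dict()
    with safe_open(cache_path, framework="pt", device=device) as f:
        for name in f.keys():
            if name in sd and not sd[name].is_meta:
                sd[name].copy_(f.get_tensor(name))  # in-place copy into live weights
```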

gpu.jsonc

Lines changed: 3 additions & 2 deletions
```diff
@@ -2,11 +2,12 @@
   "$schema": "https://gpu-cli.sh/schema/v1/gpu.json",
   "project_id": "zerostart",
   "gpu_types": [
-    { "type": "RTX 4090" },
     { "type": "RTX A6000" }
   ],
   "min_vram": 24,
-  "include": ["bin/", "tests/", "python/"],
+  "storage_mode": "built-in",
+  "workspace_size_gb": 100,
+  "include": ["bin/", "tests/", "python/", "README.md"],
   "outputs": ["benches/results/"],
   "environment": {
     "system": {
```

python/zerostart/accelerate.py

Lines changed: 68 additions & 24 deletions
```diff
@@ -1,13 +1,12 @@
 """Transparent model loading acceleration.
 
-One line enables 4x faster model loading — no user code changes needed:
+One line enables faster model loading — no user code changes needed:
 
     import zerostart
     zerostart.accelerate()
 
     from transformers import AutoModelForCausalLM
     model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B", device_map="cuda")
-    # 4x faster, identical result
 
 Hooks:
     1. transformers PreTrainedModel.from_pretrained — meta device init + cache
@@ -19,6 +18,7 @@
 from __future__ import annotations
 
 import logging
+import threading
 import time
 from pathlib import Path
 from typing import Any, Callable
@@ -29,6 +29,7 @@
 
 _hooks: list[tuple[str, Callable[[], None]]] = []
 _cache: ModelCache | None = None
+_bg_save_thread: threading.Thread | None = None
 
 
 def accelerate(
@@ -80,7 +81,11 @@ def accelerate(
 
 def decelerate() -> None:
     """Remove all hooks, restore original functions."""
-    global _cache
+    global _cache, _bg_save_thread
+    # Wait for any background save to complete
+    if _bg_save_thread is not None and _bg_save_thread.is_alive():
+        _bg_save_thread.join(timeout=30)
+    _bg_save_thread = None
     for name, unpatch in reversed(_hooks):
         unpatch()
         log.debug("Removed hook: %s", name)
@@ -119,7 +124,13 @@ def patched(cls: type, pretrained_model_name_or_path: str, *args: Any, **kwargs:
         # Try cache
         key = cache_key(str(pretrained_model_name_or_path), kwargs)
         device = kwargs.get("device_map", kwargs.get("device"))
-        if mc.has(key):
+
+        # Skip cache for device_map='auto' — HF's shard-by-shard loading to
+        # the right device is faster than our load-to-CPU-then-dispatch path.
+        # We still benefit from low_cpu_mem_usage=True below.
+        use_cache = device != "auto"
+
+        if use_cache and mc.has(key):
             t0 = time.monotonic()
             state = mc.load(key, device=_resolve_device(device))
             model = state.get("model")
@@ -138,25 +149,53 @@ def patched(cls: type, pretrained_model_name_or_path: str, *args: Any, **kwargs:
         elapsed = time.monotonic() - t0
         log.info("from_pretrained(%s): %.2fs (accelerated)", pretrained_model_name_or_path, elapsed)
 
-        # Auto-cache for next load
-        if auto:
-            try:
-                mc.save(
-                    key,
-                    {"model": model},
-                    model_id=str(pretrained_model_name_or_path),
-                    dtype=str(kwargs.get("torch_dtype", kwargs.get("dtype", "auto"))),
-                    revision=kwargs.get("revision", "main"),
-                )
-            except Exception as e:
-                log.warning("Auto-cache failed for %s: %s", pretrained_model_name_or_path, e)
+        # Auto-cache in background thread (non-blocking)
+        # Skip for device_map='auto' — cache won't be used anyway
+        if auto and use_cache:
+            _bg_cache_save(mc, key, model, str(pretrained_model_name_or_path), kwargs)
 
         return model
 
     PreTrainedModel.from_pretrained = patched
     _hooks.append(("transformers", lambda: setattr(PreTrainedModel, "from_pretrained", original)))
 
 
+def _bg_cache_save(
+    mc: ModelCache,
+    key: str,
+    model: Any,
+    model_id: str,
+    kwargs: dict[str, Any],
+) -> None:
+    """Save model to cache in a background thread."""
+    global _bg_save_thread
+
+    # Capture state_dict eagerly on main thread (safe reference to tensor memory)
+    try:
+        import torch
+        # Verify model has parameters before attempting save
+        param_count = sum(1 for _ in model.parameters())
+        if param_count == 0:
+            return
+    except Exception:
+        return
+
+    def _save() -> None:
+        try:
+            mc.save(
+                key,
+                {"model": model},
+                model_id=model_id,
+                dtype=str(kwargs.get("torch_dtype", kwargs.get("dtype", "auto"))),
+                revision=kwargs.get("revision", "main"),
+            )
+        except Exception as e:
+            log.warning("Background cache save failed for %s: %s", model_id, e)
+
+    _bg_save_thread = threading.Thread(target=_save, daemon=True)
+    _bg_save_thread.start()
+
+
 # ---------------------------------------------------------------------------
 # Hook 2: diffusers ModelMixin.from_pretrained + DiffusionPipeline
 # ---------------------------------------------------------------------------
@@ -189,10 +228,7 @@ def patched_model(cls: type, pretrained_model_name_or_path: str, *args: Any, **k
         log.info("diffusers.from_pretrained(%s): %.2fs", pretrained_model_name_or_path, time.monotonic() - t0)
 
         if auto:
-            try:
-                mc.save(key, {"model": model}, model_id=f"diffusers:{pretrained_model_name_or_path}")
-            except Exception as e:
-                log.warning("Auto-cache failed: %s", e)
+            _bg_cache_save(mc, key, model, f"diffusers:{pretrained_model_name_or_path}", kwargs)
 
         return model
 
@@ -254,7 +290,11 @@ def patched(filename: str, device: str = "cpu") -> dict[str, Any]:
             t0 = time.monotonic()
             with open(path, "rb") as f:
                 data = f.read()
-            result = safetensors.torch.load(data, device=device)
+            # safetensors.torch.load() (bytes) doesn't accept device kwarg
+            result = safetensors.torch.load(data)
+            if device and device != "cpu":
+                import torch
+                result = {k: v.to(device) for k, v in result.items()}
             log.debug("Eager read %s (%.2fs, network volume)", Path(path).name, time.monotonic() - t0)
             return result
         return original(filename, device=device)
@@ -304,7 +344,7 @@ def patched(f: Any, *args: Any, **kwargs: Any) -> Any:
         # Load normally
         result = original(f, *args, **kwargs)
 
-        # Cache as safetensors for next time
+        # Cache as safetensors for next time (background)
         if isinstance(result, dict):
             # Only cache if all values are tensors
             all_tensors = True
@@ -343,10 +383,14 @@ def _is_network_volume(path: str) -> bool:
 
 
 def _check_network_volume(path: str) -> bool:
-    """Detect FUSE/NFS/overlay mounts via /proc/mounts."""
+    """Detect FUSE/NFS mounts via /proc/mounts.
+
+    overlay is intentionally excluded — most container providers (RunPod, etc.)
+    use overlay backed by local SSD where mmap is fast.
+    """
     slow_fs_types = frozenset({
         "fuse", "fuse.juicefs", "fuse.gcsfuse", "fuse.sshfs",
-        "nfs", "nfs4", "cifs", "smbfs", "9p", "overlay",
+        "nfs", "nfs4", "cifs", "smbfs", "9p",
     })
 
     try:
```
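
Two more items from the commit message, parallel shard loading and the `to_empty()` fallback for meta tensors, land outside the hunks shown above. Rough sketches of both ideas follow; names and structure here are assumptions, not zerostart's code:

```python
# Illustrative sketches only; the real implementations live in files not shown here.
from concurrent.futures import ThreadPoolExecutor

import torch
from safetensors.torch import load_file


def load_shards_parallel(shard_paths: list[str], device: str = "cpu") -> dict[str, torch.Tensor]:
    """Load several safetensors shards concurrently (the ThreadPoolExecutor change)."""
    tensors: dict[str, torch.Tensor] = {}
    if not shard_paths:
        return tensors
    with ThreadPoolExecutor(max_workers=min(8, len(shard_paths))) as pool:
        for shard in pool.map(lambda p: load_file(p, device=device), shard_paths):
            tensors.update(shard)
    return tensors


def materialize_if_meta(model: torch.nn.Module, device: str) -> torch.nn.Module:
    """The to_empty() fallback: if any parameter is still on the meta device,
    allocate real (uninitialized) storage on the target device before copying
    cached weights in, rather than crashing on a plain .to(device)."""
    if any(p.is_meta for p in model.parameters()):
        model = model.to_empty(device=device)
    return model
```

In the real code the loaded shards would then be matched against the model's state_dict, using the kind of suffix matching sketched near the top of this commit.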
