# 01: Transparent Hooks (`accelerate.py`)

## Overview

A single entry point that monkey-patches framework loading functions for transparent acceleration. No user code changes are required.

## API

```python
import zerostart

zerostart.accelerate(
    cache_dir=None,           # Auto-detect: /volume first, then ~/.cache/zerostart/models
    auto_cache=True,          # Snapshot models after first load for faster second load
    network_volume_fix=True,  # Detect FUSE/NFS and use eager read instead of mmap
)
```
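
The `cache_dir=None` auto-detection might look like the following minimal sketch; the `_default_cache_dir` helper name is an assumption for illustration, not part of the documented API:

```python
from pathlib import Path

def _default_cache_dir() -> Path:  # hypothetical helper
    """Prefer a mounted persistent volume; fall back to the user cache."""
    vol = Path("/volume")
    if vol.is_dir():
        return vol / "zerostart" / "models"
    return Path.home() / ".cache" / "zerostart" / "models"
```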

## Hook Registry

Each registered hook is a triple: `(name, patch_fn, unpatch_fn)`; the hooks below store `None` for `patch_fn` because patching happens at install time. Hooks are only installed if the target module is importable.

```python
from typing import Callable

_hooks: list[tuple[str, Callable | None, Callable]] = []

def accelerate(**kwargs):
    # Each _try_hook_* is a no-op if its target library is not installed
    _try_hook_transformers(kwargs)
    _try_hook_diffusers(kwargs)
    _try_hook_safetensors(kwargs)
    _try_hook_torch_load(kwargs)

def decelerate():
    # Unpatch in reverse installation order, restoring the originals
    for name, _, unpatch in reversed(_hooks):
        unpatch()
    _hooks.clear()
```
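
One detail the registry has to handle: calling `accelerate()` twice would stack patches. A small install helper keeps registration idempotent; the `_install` name is assumed here for illustration, not taken from the source:

```python
def _install(name: str, patch: Callable, unpatch: Callable) -> None:  # hypothetical
    """Apply and register a hook once; a second registration is a no-op."""
    if any(existing_name == name for existing_name, _, _ in _hooks):
        return
    patch()
    _hooks.append((name, patch, unpatch))
```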

## Hook 1: transformers `from_pretrained`

**Target:** `transformers.PreTrainedModel.from_pretrained`

**What it does:**
1. Checks the model cache; on a hit, hydrates and returns immediately (skipping everything else)
2. Forces `low_cpu_mem_usage=True` so weights are initialized on the meta device
3. Calls the original `from_pretrained`
4. Auto-snapshots the result for the next load

**Edge cases:**
- `quantization_config` (GPTQ, AWQ, BitsAndBytes): don't cache quantized models, since quantization is hardware-specific
- `device_map="auto"`: already uses the meta device, but we still cache
- `revision` parameter: include it in the cache key
- `from_pretrained(local_path)`: still cache, keyed by the resolved path (see the `_cache_key` sketch below)
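
Those last two edge cases pin down what `_cache_key`, used in the implementation below but not defined in this section, has to fold in. A minimal sketch; which extra options to include (here `revision` and `torch_dtype`) is an assumption beyond the edge cases listed:

```python
import hashlib
import os

def _cache_key(name_or_path, kwargs) -> str:
    """Key on the model id plus load options that change the resulting weights."""
    ident = str(name_or_path)
    if os.path.exists(ident):
        ident = os.path.realpath(ident)  # key local checkpoints by resolved path
    parts = (ident, str(kwargs.get("revision", "main")), str(kwargs.get("torch_dtype", "auto")))
    return hashlib.sha256("|".join(parts).encode()).hexdigest()[:16]
```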

**Implementation:**

```python
import logging
import time

log = logging.getLogger("zerostart")

def _try_hook_transformers(config):
    try:
        from transformers import PreTrainedModel
    except ImportError:
        return

    # Unwrap the classmethod to get the plain function
    original = PreTrainedModel.from_pretrained.__func__
    cache = config.get("_cache")

    @classmethod
    def patched(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Skip the cache for quantized models (hardware-specific weights)
        if kwargs.get("quantization_config"):
            kwargs.setdefault("low_cpu_mem_usage", True)
            return original(cls, pretrained_model_name_or_path, *args, **kwargs)

        # Try the cache
        cache_key = _cache_key(pretrained_model_name_or_path, kwargs)
        device = kwargs.get("device_map", kwargs.get("device"))
        if cache and cache.has(cache_key):
            log.info("Cache hit for %s", pretrained_model_name_or_path)
            state = cache.load(cache_key, device=device)
            return state["model"]

        # Accelerated load: meta-device init avoids a redundant CPU copy
        kwargs.setdefault("low_cpu_mem_usage", True)
        t0 = time.monotonic()
        model = original(cls, pretrained_model_name_or_path, *args, **kwargs)
        elapsed = time.monotonic() - t0
        log.info("from_pretrained(%s): %.2fs", pretrained_model_name_or_path, elapsed)

        # Auto-cache for the next load
        if cache and config.get("auto_cache", True):
            cache.save(cache_key, {"model": model})

        return model

    PreTrainedModel.from_pretrained = patched
    _hooks.append(("transformers", None,
                   lambda: setattr(PreTrainedModel, "from_pretrained", classmethod(original))))
```

## Hook 2: diffusers `from_pretrained`

**Target:** `diffusers.ModelMixin.from_pretrained` and `diffusers.DiffusionPipeline.from_pretrained`

Same pattern as transformers. `DiffusionPipeline` is special because it loads multiple sub-models; we cache the entire pipeline state as one snapshot, as in the sketch below.
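
A minimal sketch of the pipeline variant, mirroring Hook 1; the `cache.save`/`cache.load` payload shape (`{"pipeline": ...}`) is an assumption for illustration:

```python
def _try_hook_diffusers(config):
    try:
        from diffusers import DiffusionPipeline
    except ImportError:
        return

    original = DiffusionPipeline.from_pretrained.__func__
    cache = config.get("_cache")

    @classmethod
    def patched(cls, pretrained_model_name_or_path, *args, **kwargs):
        cache_key = _cache_key(pretrained_model_name_or_path, kwargs)
        if cache and cache.has(cache_key):
            # One snapshot covers every sub-model (UNet, VAE, text encoders, ...)
            return cache.load(cache_key)["pipeline"]
        pipe = original(cls, pretrained_model_name_or_path, *args, **kwargs)
        if cache and config.get("auto_cache", True):
            cache.save(cache_key, {"pipeline": pipe})
        return pipe

    DiffusionPipeline.from_pretrained = patched
    _hooks.append(("diffusers", None,
                   lambda: setattr(DiffusionPipeline, "from_pretrained", classmethod(original))))
```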

## Hook 3: safetensors network volume fix

**Target:** `safetensors.torch.load_file`

**Problem:** mmap is 30-50x slower on FUSE/NFS volumes (RunPod persistent volumes, vast.ai volumes). This affects ALL model loading, not just `from_pretrained`.

**Detection:**

```python
import os

def _is_network_volume(path: str) -> bool:
    """Check if path is on a FUSE/NFS filesystem where mmap is slow."""
    path = os.path.realpath(path)
    longest_mount, fs = "", ""
    try:
        # Linux: scan /proc/mounts, keeping the longest mount point that
        # prefixes the path (a first-match scan would stop at "/")
        with open("/proc/mounts") as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                mount_point, fs_type = parts[1], parts[2]
                if path.startswith(mount_point) and len(mount_point) >= len(longest_mount):
                    longest_mount, fs = mount_point, fs_type
    except FileNotFoundError:
        return False  # no /proc/mounts: not Linux, assume local
    return fs.startswith("fuse") or fs in (
        "nfs", "nfs4", "cifs", "smbfs", "9p", "overlay",
    )
```

**Patch:**

```python
def _try_hook_safetensors(config):
    try:
        import safetensors.torch
    except ImportError:
        return

    original = safetensors.torch.load_file

    def patched(filename, device="cpu"):
        if config.get("network_volume_fix", True) and _is_network_volume(str(filename)):
            # Eager read: pull the entire file into memory, then deserialize.
            # Avoids the mmap page-fault penalty on network volumes.
            with open(filename, "rb") as f:
                data = f.read()
            # safetensors.torch.load() deserializes from bytes to CPU tensors
            # (it takes no device argument), so move to the target device after
            tensors = safetensors.torch.load(data)
            if device not in (None, "cpu"):
                tensors = {k: v.to(device) for k, v in tensors.items()}
            return tensors
        return original(filename, device=device)

    safetensors.torch.load_file = patched
    _hooks.append(("safetensors", None,
                   lambda: setattr(safetensors.torch, "load_file", original)))
```
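
The eager read trades transient RAM (one full copy of the file) for sequential I/O: network filesystems serve a single streaming read far better than the many scattered page faults mmap generates.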

## Hook 4: torch.load → safetensors conversion

**Target:** `torch.load`

For legacy `.bin` checkpoints, convert to safetensors on first load and mmap on subsequent loads.

```python
from pathlib import Path

import torch

def _try_hook_torch_load(config):
    original = torch.load
    cache = config.get("_cache")

    def patched(f, *args, **kwargs):
        # Only intercept .bin paths, and only when a cache is available
        if cache is None or not isinstance(f, (str, Path)) or not str(f).endswith(".bin"):
            return original(f, *args, **kwargs)

        # Check if we have a safetensors conversion cached
        sf_path = cache.safetensors_path_for(str(f))
        if sf_path and sf_path.exists():
            from safetensors.torch import load_file
            device = kwargs.get("map_location") or "cpu"
            if isinstance(device, torch.device):
                device = str(device)
            if isinstance(device, str):  # callable/dict map_location: fall through
                return load_file(str(sf_path), device=device)

        # Load normally, convert to safetensors for next time
        result = original(f, *args, **kwargs)
        if isinstance(result, dict):
            cache.save_as_safetensors(str(f), result)
        return result

    torch.load = patched
    _hooks.append(("torch.load", None, lambda: setattr(torch, "load", original)))
```

## Testing

```python
import time

def test_accelerate_transformers():
    """from_pretrained is faster with acceleration."""
    import zerostart
    zerostart.accelerate(cache_dir="/tmp/zs-test")

    from transformers import AutoModelForCausalLM

    # First load: normal speed + auto-cache
    t0 = time.monotonic()
    model1 = AutoModelForCausalLM.from_pretrained("gpt2")
    first_load = time.monotonic() - t0

    del model1

    # Second load: should be much faster (cache hit)
    t0 = time.monotonic()
    model2 = AutoModelForCausalLM.from_pretrained("gpt2")
    second_load = time.monotonic() - t0

    assert second_load < first_load * 0.5  # At least 2x faster

    zerostart.decelerate()
```
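
A complementary check worth having is that `decelerate()` actually restores the originals. A sketch, assuming `safetensors` is installed so Hook 3 gets registered:

```python
def test_decelerate_restores_originals():
    """decelerate() must put the original functions back."""
    import safetensors.torch
    import torch
    import zerostart

    orig_load_file = safetensors.torch.load_file
    orig_torch_load = torch.load

    zerostart.accelerate(cache_dir="/tmp/zs-test")
    assert safetensors.torch.load_file is not orig_load_file

    zerostart.decelerate()
    assert safetensors.torch.load_file is orig_load_file
    assert torch.load is orig_torch_load
```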