
Commit 5a56bf2

Fix snapshot: HF cache discovery, e2e benchmark on volume
- Fix _find_hf_cache_dir to check huggingface_hub constants and the HF_HUB_CACHE env var (not just HF_HOME) — fixes 0/339 tensor matching on RunPod, where HF uses /gpu-cli-workspaces/.cache/
- Move the e2e benchmark to /gpu-cli-workspaces/ for disk space
- Clear pip cache + HF cache between scenarios for honest numbers
- Uninstall system torchvision to avoid an operator conflict
- Add timing instrumentation to hydrate (python/tensors/reconstruct)
- Fix torch_dtype deprecation → dtype
- Include python/ in gpu.jsonc sync

Results (Qwen2.5-7B, RTX 4090):
- Snapshot: 0.2s (339/339 tensors matched to safetensors)
- Hydrate (warm runtime): 0.4s
- Hydrate (cold runtime): 11.6s (cloudpickle triggers the transformers import)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
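The headline fix is the first bullet: the hub library resolves its cache from HF_HUB_CACHE, which an HF_HOME-only lookup never sees. A minimal sketch of the mismatch, with the env value assumed from the RunPod setup above:

    import os

    # Assumed from the commit message: RunPod points the hub cache at the volume.
    os.environ.setdefault("HF_HUB_CACHE", "/gpu-cli-workspaces/.cache/huggingface/hub")

    # Old lookup, derived only from HF_HOME: resolves to a directory with no models.
    old = os.path.join(os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")), "hub")
    # New lookup: HF_HUB_CACHE wins when set, matching where the hub actually wrote files.
    new = os.environ.get("HF_HUB_CACHE", old)
    print(old)  # ~/.cache/huggingface/hub (0/339 tensors matched here)
    print(new)  # /gpu-cli-workspaces/.cache/huggingface/hub (339/339 matched)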
1 parent 928fa74 commit 5a56bf2

3 files changed, 122 additions and 44 deletions


gpu.jsonc

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     { "type": "RTX A6000" }
   ],
   "min_vram": 24,
-  "include": ["bin/"],
+  "include": ["bin/", "tests/", "python/"],
   "outputs": ["benches/results/"],
   "environment": {
     "system": {

python/zerostart/snapshot.py

Lines changed: 67 additions & 17 deletions
@@ -204,24 +204,57 @@ def _find_safetensors_for_model(module: Any) -> list[Path]:

 def _find_hf_cache_dir(model_id: str) -> Path | None:
     """Find the HF hub cache directory for a model."""
-    hf_home = Path(os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")))
-    hub_dir = hf_home / "hub"
-
-    # HF cache structure: hub/models--org--name/snapshots/<hash>/
     safe_id = model_id.replace("/", "--")
-    model_dir = hub_dir / f"models--{safe_id}"
+    model_subdir = f"models--{safe_id}"
+
+    # Check multiple possible HF cache locations in order of priority:
+    # 1. huggingface_hub constants (most reliable — reads all HF env vars)
+    # 2. HF_HUB_CACHE env var
+    # 3. HF_HOME env var + /hub
+    # 4. Default ~/.cache/huggingface/hub
+    candidates: list[Path] = []
+
+    try:
+        from huggingface_hub import constants
+        candidates.append(Path(constants.HF_HUB_CACHE))
+    except ImportError:
+        pass
+
+    if hf_hub_cache := os.environ.get("HF_HUB_CACHE"):
+        candidates.append(Path(hf_hub_cache))
+
+    if hf_home := os.environ.get("HF_HOME"):
+        candidates.append(Path(hf_home) / "hub")

-    if not model_dir.is_dir():
-        return None
+    candidates.append(Path(os.path.expanduser("~/.cache/huggingface/hub")))

-    # Find the latest snapshot
-    snapshots = model_dir / "snapshots"
-    if not snapshots.is_dir():
-        return None
+    # Dedupe preserving order
+    seen: set[str] = set()
+    unique: list[Path] = []
+    for c in candidates:
+        key = str(c)
+        if key not in seen:
+            seen.add(key)
+            unique.append(c)

-    # Get the most recent snapshot directory
-    snap_dirs = sorted(snapshots.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
-    return snap_dirs[0] if snap_dirs else None
+    for hub_dir in unique:
+        model_dir = hub_dir / model_subdir
+        if not model_dir.is_dir():
+            continue
+
+        snapshots = model_dir / "snapshots"
+        if not snapshots.is_dir():
+            continue
+
+        snap_dirs = sorted(snapshots.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
+        if snap_dirs:
+            result = snap_dirs[0]
+            sf_count = len(list(result.glob("*.safetensors")))
+            log.info("Found HF cache for %s at %s (%d safetensors files)", model_id, result, sf_count)
+            return result
+
+    log.warning("Could not find HF cache for %s in %s", model_id, [str(c) for c in unique])
+    return None


 def _build_tensor_to_file_map(
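A quick way to confirm the top-priority candidate on a given box (a sketch; assumes huggingface_hub is installed, the same optional dependency the try/except above tolerates missing):

    from huggingface_hub import constants

    # On the RunPod pod from the commit message this resolves under
    # /gpu-cli-workspaces/.cache/ rather than ~/.cache/huggingface/hub.
    print(constants.HF_HUB_CACHE)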
@@ -552,6 +585,7 @@ def hydrate(
     log.info("Tensors loaded via mmap (%.3fs, %d tensors)", t_mmap, len(loaded_tensors))

     # 4. Reconstruct state: wire tensors back into Python objects
+    t_reconstruct_start = time.monotonic()
     restored_state: dict[str, Any] = {}

     for key, value in cleaned_state.items():
@@ -570,10 +604,11 @@ def hydrate(
         else:
             restored_state[key] = value

+    t_reconstruct = time.monotonic() - t_reconstruct_start
     elapsed = time.monotonic() - t0
     log.info(
-        "Hydration complete (%.3fs total: %.3fs python + %.3fs tensors)",
-        elapsed, t_python, t_mmap,
+        "Hydration complete (%.3fs total: %.3fs python + %.3fs tensors + %.3fs reconstruct)",
+        elapsed, t_python, t_mmap, t_reconstruct,
     )

     return restored_state
@@ -595,22 +630,29 @@ def _reconstruct_module(
     if config and config.get("_type") == "transformers":
         try:
             import importlib
+            t_import = time.monotonic()
             model_module = importlib.import_module(config["_module"])
             model_class = getattr(model_module, config["_class"])

             config_module = importlib.import_module(config["config_module"])
             config_class = getattr(config_module, config["config_class"])
+            t_import_done = time.monotonic()

             model_config = config_class.from_dict(config["config_dict"])

             # Create model on meta device (zero memory allocation) then
             # replace meta tensors with real data via load_state_dict(assign=True).
             # This matches how from_pretrained works: 0.4s instead of 80s.
+            t_meta = time.monotonic()
             with _no_init_weights():
                 with torch.device("meta"):
                     module = model_class(model_config)
+            t_meta_done = time.monotonic()

-            log.info("Reconstructed %s from config (meta device)", config["_class"])
+            log.info(
+                "Reconstructed %s (import=%.2fs, meta_init=%.2fs)",
+                config["_class"], t_import_done - t_import, t_meta_done - t_meta,
+            )
         except Exception as e:
             log.warning("Failed to reconstruct from config: %s", e)
@@ -634,21 +676,29 @@ def _reconstruct_module(
         state_dict[param_name] = tensor

     if state_dict:
+        t_load = time.monotonic()
         try:
             module.load_state_dict(state_dict, strict=False, assign=True)
         except TypeError:
             # older torch doesn't have assign=True
             module.load_state_dict(state_dict, strict=False)
+        t_load_done = time.monotonic()

         # Re-tie weights (e.g., lm_head.weight = wte.weight in GPT-2)
         if hasattr(module, "tie_weights"):
             module.tie_weights()

+        log.info("load_state_dict: %.2fs (%d tensors)", t_load_done - t_load, len(state_dict))
+
     # Materialize any remaining meta tensors (computed buffers like
     # rotary_emb.inv_freq that aren't in state_dict/safetensors).
     # These need to be recreated by calling the module's init logic
     # just for the specific submodules that still have meta tensors.
+    t_mat = time.monotonic()
     _materialize_meta_tensors(module)
+    t_mat_done = time.monotonic()
+    if t_mat_done - t_mat > 0.01:
+        log.info("materialize_meta_tensors: %.2fs", t_mat_done - t_mat)

     # Move to device if requested
     if device:
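_materialize_meta_tensors is the project's own helper; a generic way to list what it still has to fix up (a sketch, not the project's code) is to scan for tensors left on the meta device after load_state_dict:

    import torch.nn as nn

    def remaining_meta(module: nn.Module) -> list[str]:
        # Whatever load_state_dict did not cover stays on meta, e.g. computed
        # buffers like rotary_emb.inv_freq that never appear in safetensors.
        named = list(module.named_parameters()) + list(module.named_buffers())
        return [name for name, t in named if t.is_meta]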

tests/test_e2e_cold_start.sh

Lines changed: 54 additions & 26 deletions
@@ -4,7 +4,6 @@ set -uo pipefail
 echo "=== End-to-End Cold Start Benchmark ==="
 echo "Date: $(date -u)"
 echo "GPU: $(nvidia-smi --query-gpu=name,memory.total --format=csv,noheader 2>/dev/null || echo 'none')"
-df -h /tmp | tail -1 | awk '{print "Disk: " $4 " free"}'
 echo ""

 SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
@@ -15,15 +14,30 @@ export PYTHONPATH="$PROJECT_DIR/python:${PYTHONPATH:-}"

 MODEL_ID="${SNAP_MODEL:-Qwen/Qwen2.5-7B}"

+# All temp data on the volume (more space than /tmp)
+BENCH_DIR="/gpu-cli-workspaces/.bench-e2e"
+rm -rf "$BENCH_DIR"
+mkdir -p "$BENCH_DIR"
+
+# Remove system torchvision that conflicts with fresh torch installs
+pip uninstall -y torchvision 2>/dev/null || true
+
+# Clear pip cache so scenario 1 is a true cold install
+pip cache purge 2>/dev/null || true
+
+# Clear HF cache so model download is truly cold
+export HF_HOME="$BENCH_DIR/hf-cache"
+
+df -h /gpu-cli-workspaces | tail -1 | awk '{print "Disk: " $4 " free on /gpu-cli-workspaces"}'
+echo ""
+
 # ============================================================
 # Scenario 1: pip install + from_pretrained (traditional)
 # ============================================================
 echo "--- Scenario 1: pip install + from_pretrained ---"
-# Clean slate
-rm -rf /tmp/.pip-bench-venv
 BENCH_START=$(date +%s%3N)

-cat > /tmp/bench_pip.py << PYEOF
+cat > "$BENCH_DIR/bench_pip.py" << PYEOF
 import time
 t_script = time.monotonic()
@@ -33,7 +47,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 t_import = time.monotonic()

 tokenizer = AutoTokenizer.from_pretrained("$MODEL_ID")
-model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", torch_dtype=torch.bfloat16, device_map="cpu")
+model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", dtype=torch.bfloat16, device_map="cpu")
 model.eval()
 t_model = time.monotonic()
@@ -47,27 +61,30 @@ print(f"RESULT: {result}")
 print(f"TIME import={t_import-t_script:.2f}s model={t_model-t_import:.2f}s inference={t_inf-t_model:.2f}s total={t_inf-t_script:.2f}s")
 PYEOF

-# Install into a fresh venv (simulates cold container)
-python3 -m venv /tmp/.pip-bench-venv
-/tmp/.pip-bench-venv/bin/pip install -q torch transformers accelerate 2>&1 | tail -3
+# Fresh venv, no pip cache — true cold install
+python3 -m venv "$BENCH_DIR/pip-venv"
+"$BENCH_DIR/pip-venv/bin/pip" install --no-cache-dir -q torch transformers accelerate 2>&1 | tail -3
 PIP_DONE=$(date +%s%3N)
 echo "  pip install: $(( PIP_DONE - BENCH_START ))ms"

-/tmp/.pip-bench-venv/bin/python /tmp/bench_pip.py 2>&1 | grep -E "^(RESULT|TIME)"
+"$BENCH_DIR/pip-venv/bin/python" "$BENCH_DIR/bench_pip.py" 2>&1 | tail -30
 BENCH_END=$(date +%s%3N)
 echo "  Total wall clock (install + load + inference): $(( BENCH_END - BENCH_START ))ms"
-rm -rf /tmp/.pip-bench-venv
+rm -rf "$BENCH_DIR/pip-venv"
 echo ""

+# Clear HF cache so scenario 2 also downloads fresh
+rm -rf "$HF_HOME"
+
 # ============================================================
 # Scenario 2: zerostart cold + from_pretrained
 # ============================================================
 echo "--- Scenario 2: zerostart cold + from_pretrained ---"
-export ZEROSTART_CACHE="/tmp/.zs-e2e-bench"
+export ZEROSTART_CACHE="$BENCH_DIR/zs-cache"
 export ZS_NO_SHARED_CACHE=1
 rm -rf "$ZEROSTART_CACHE"

-cat > /tmp/bench_zs_cold.py << PYEOF
+cat > "$BENCH_DIR/bench_zs_cold.py" << PYEOF
 import time
 t_script = time.monotonic()
@@ -77,7 +94,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 t_import = time.monotonic()

 tokenizer = AutoTokenizer.from_pretrained("$MODEL_ID")
-model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", torch_dtype=torch.bfloat16, device_map="cpu")
+model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", dtype=torch.bfloat16, device_map="cpu")
 model.eval()
 t_model = time.monotonic()
@@ -92,7 +109,7 @@ print(f"TIME import={t_import-t_script:.2f}s model={t_model-t_import:.2f}s infer
 PYEOF

 ZS_START=$(date +%s%3N)
-$ZS run -p torch -p transformers -p accelerate /tmp/bench_zs_cold.py 2>&1 | grep -E "^(RESULT|TIME|Resolved|Daemon|Environment|Cache)"
+$ZS run -p torch -p transformers -p accelerate "$BENCH_DIR/bench_zs_cold.py" 2>&1 | tail -30
 ZS_END=$(date +%s%3N)
 echo "  Total wall clock (zerostart cold + load + inference): $(( ZS_END - ZS_START ))ms"
 echo ""
@@ -101,10 +118,11 @@ echo ""
 # Scenario 3: zerostart warm + from_pretrained
 # ============================================================
 echo "--- Scenario 3: zerostart warm + from_pretrained ---"
-# Cache is now populated from Scenario 2
+# zerostart package cache is warm from Scenario 2
+# HF model cache is warm from Scenario 2

 ZS_WARM_START=$(date +%s%3N)
-$ZS run -p torch -p transformers -p accelerate /tmp/bench_zs_cold.py 2>&1 | grep -E "^(RESULT|TIME|Cache)"
+$ZS run -p torch -p transformers -p accelerate "$BENCH_DIR/bench_zs_cold.py" 2>&1 | tail -30
 ZS_WARM_END=$(date +%s%3N)
 echo "  Total wall clock (zerostart warm + load + inference): $(( ZS_WARM_END - ZS_WARM_START ))ms"
 echo ""
@@ -114,39 +132,46 @@ echo ""
 # ============================================================
 echo "--- Scenario 4: Create snapshot for hydrate ---"

-cat > /tmp/bench_create_snap.py << PYEOF
-import time
+cat > "$BENCH_DIR/bench_create_snap.py" << PYEOF
+import time, logging
+logging.basicConfig(level=logging.INFO, format="%(name)-20s %(message)s")
 t0 = time.monotonic()
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from zerostart.snapshot import snapshot

 tokenizer = AutoTokenizer.from_pretrained("$MODEL_ID")
-model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", torch_dtype=torch.bfloat16, device_map="cpu")
+model = AutoModelForCausalLM.from_pretrained("$MODEL_ID", dtype=torch.bfloat16, device_map="cpu")
 model.eval()

 import shutil
-shutil.rmtree("/tmp/e2e-snapshot", ignore_errors=True)
-snapshot(state={"model": model, "tokenizer": tokenizer}, path="/tmp/e2e-snapshot")
+shutil.rmtree("$BENCH_DIR/e2e-snapshot", ignore_errors=True)
+snapshot(state={"model": model, "tokenizer": tokenizer}, path="$BENCH_DIR/e2e-snapshot")
 t1 = time.monotonic()
 print(f"Snapshot created in {t1-t0:.2f}s")
 PYEOF

-$ZS run -p torch -p transformers -p accelerate -p cloudpickle /tmp/bench_create_snap.py 2>&1 | grep -E "^(Snapshot|Cache)"
+$ZS run -p torch -p transformers -p accelerate -p cloudpickle "$BENCH_DIR/bench_create_snap.py" 2>&1 | tail -30

 echo ""
 echo "--- Scenario 4: zerostart warm + hydrate + inference ---"

-cat > /tmp/bench_hydrate.py << PYEOF
-import time
+# Reuse the warm zerostart package cache from scenarios 2/3 —
+# the comparison is model loading (hydrate vs from_pretrained),
+# not package installation.
+export ZEROSTART_CACHE="$BENCH_DIR/zs-cache"
+
+cat > "$BENCH_DIR/bench_hydrate.py" << PYEOF
+import time, logging
+logging.basicConfig(level=logging.INFO, format="%(name)-20s %(message)s")
 t_script = time.monotonic()

 import torch
 from zerostart.snapshot import hydrate

 t_import = time.monotonic()

-restored = hydrate("/tmp/e2e-snapshot")
+restored = hydrate("$BENCH_DIR/e2e-snapshot")
 model = restored["model"]
 model.eval()
 tokenizer = restored["tokenizer"]
@@ -163,7 +188,7 @@ print(f"TIME import={t_import-t_script:.2f}s hydrate={t_hydrate-t_import:.2f}s i
 PYEOF

 ZS_HYD_START=$(date +%s%3N)
-$ZS run -p torch -p transformers -p accelerate -p cloudpickle /tmp/bench_hydrate.py 2>&1 | grep -E "^(RESULT|TIME|Cache)"
+$ZS run -p torch -p transformers -p accelerate -p cloudpickle "$BENCH_DIR/bench_hydrate.py" 2>&1 | tail -30
 ZS_HYD_END=$(date +%s%3N)
 echo "  Total wall clock (zerostart warm + hydrate + inference): $(( ZS_HYD_END - ZS_HYD_START ))ms"
 echo ""
@@ -180,3 +205,6 @@ echo "  2. zerostart cold + from_pretrained: $(( ZS_END - ZS_START ))ms"
 echo "  3. zerostart warm + from_pretrained: $(( ZS_WARM_END - ZS_WARM_START ))ms"
 echo "  4. zerostart warm + hydrate (snapshot): $(( ZS_HYD_END - ZS_HYD_START ))ms"
 echo "============================================================"
+
+# Cleanup
+rm -rf "$BENCH_DIR"
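To rerun against a smaller model, SNAP_MODEL overrides the default at the top of the script, for example SNAP_MODEL=Qwen/Qwen2.5-0.5B bash tests/test_e2e_cold_start.sh (model choice illustrative). Everything a run touches now lives under $BENCH_DIR on the volume, so this final rm -rf is the only cleanup needed.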
