Skip to content

Commit 6727a74

Browse files
unamedkr and claude committed
refparity: --dtype option for memory-constrained HF reference runs
Adds FP32/BF16/FP16 selection to hf_reference.py; plumbs per-entry dtype through matrix.json → run_matrix.sh. Unblocks 4B-class models on 16 GB machines (BF16 halves memory so a ~4B model can sit next to a 4B GGUF engine run in the same 16 GB). Also adds `_disabled: true` entry filter for matrix.json and documents the intended (but currently oversized) DeltaNet-hybrid entry. Real 4B DeltaNet comparison target is TBD — none of Qwen's DeltaNet HF releases are <8B. Smoke-tested BF16 load: Qwen3-0.6B top1 matches FP32 exactly. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent f612c57 commit 6727a74

3 files changed

Lines changed: 32 additions & 12 deletions

File tree

tools/refparity/hf_reference.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,27 @@
2020
from transformers import AutoModelForCausalLM, AutoTokenizer
2121

2222

23-
def run(model_name: str, prompt: str, out_path: str) -> int:
24-
print(f"[refparity/hf] model={model_name}", file=sys.stderr)
23+
DTYPE_MAP = {
24+
"float32": torch.float32, "fp32": torch.float32,
25+
"bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
26+
"float16": torch.float16, "fp16": torch.float16,
27+
}
28+
29+
30+
def run(model_name: str, prompt: str, out_path: str,
31+
dtype: str = "float32") -> int:
32+
torch_dtype = DTYPE_MAP.get(dtype)
33+
if torch_dtype is None:
34+
print(f"error: unknown dtype {dtype!r}; valid: {list(DTYPE_MAP)}",
35+
file=sys.stderr)
36+
return 2
37+
print(f"[refparity/hf] model={model_name} dtype={dtype}", file=sys.stderr)
2538
print(f"[refparity/hf] prompt: {prompt[:80]!r}{'...' if len(prompt) > 80 else ''}",
2639
file=sys.stderr)
2740

2841
tok = AutoTokenizer.from_pretrained(model_name)
2942
model = AutoModelForCausalLM.from_pretrained(
30-
model_name, dtype=torch.float32, device_map="cpu")
43+
model_name, dtype=torch_dtype, device_map="cpu")
3144
model.eval()
3245

3346
ids = tok.encode(prompt, return_tensors="pt")
@@ -77,6 +90,10 @@ def main():
7790
group.add_argument("--prompt", help="literal prompt text")
7891
group.add_argument("--prompt-file", help="read prompt from file")
7992
ap.add_argument("--out", required=True, help="output .npz path")
93+
ap.add_argument("--dtype", default="float32",
94+
choices=list(DTYPE_MAP.keys()),
95+
help="HF model dtype (default: float32). Use bfloat16 "
96+
"for 4B+ models on 16 GB machines.")
8097
args = ap.parse_args()
8198

8299
if args.prompt:
@@ -85,7 +102,7 @@ def main():
85102
with open(args.prompt_file) as f:
86103
prompt = f.read().rstrip("\n")
87104

88-
return run(args.model, prompt, args.out)
105+
return run(args.model, prompt, args.out, dtype=args.dtype)
89106

90107

91108
if __name__ == "__main__":

tools/refparity/matrix.json

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,16 @@
1515
},
1616
{
1717
"name": "qwen3.5_4b_hybrid",
18-
"hf_model": "Qwen/Qwen3.5-4B-Thinking",
18+
"hf_model": "Qwen/Qwen3-Next-80B-A3B-Thinking",
1919
"engine_gguf": "Qwen3.5-4B-Q4_K_M.gguf",
20-
"_note": "Hybrid DeltaNet+self-attn reference. HF model may need download.",
20+
"_note": "Hybrid DeltaNet+self-attn. HF ref too big for 16 GB (80B). Use BF16 fallback or pick a smaller DeltaNet-class model. Currently DISABLED — kept for documentation.",
21+
"_disabled": true,
2122
"prompts": [
22-
"Hello",
23-
"def fibonacci(n):"
23+
"Hello"
2424
],
2525
"threshold_l2_rel": 0.05,
26-
"threshold_cosine": 0.90
26+
"threshold_cosine": 0.90,
27+
"dtype": "bfloat16"
2728
},
2829
{
2930
"name": "llama3.2_1b",

tools/refparity/run_matrix.sh

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ with open("$MATRIX") as f:
3939
m = json.load(f)
4040
filt = "$FILTER"
4141
for t in m.get("tests", []):
42+
if t.get("_disabled"):
43+
continue
4244
if filt and filt not in t["name"]:
4345
continue
4446
gguf = t["engine_gguf"]
4547
for prompt in t["prompts"]:
4648
# bash-escaped prompt (base64 roundtrip to avoid quoting hell)
4749
import base64
4850
p64 = base64.b64encode(prompt.encode()).decode()
49-
print(f"{t['name']}|{t['hf_model']}|{gguf}|{p64}|{t.get('threshold_l2_rel', 0.05)}|{t.get('threshold_cosine', 0.90)}")
51+
print(f"{t['name']}|{t['hf_model']}|{gguf}|{p64}|{t.get('threshold_l2_rel', 0.05)}|{t.get('threshold_cosine', 0.90)}|{t.get('dtype', 'float32')}")
5052
PY
5153
)
5254

@@ -66,7 +68,7 @@ echo ""
6668

6769
PREV_NAME=""
6870
IDX=0
69-
while IFS='|' read -r NAME HF_MODEL GGUF P64 TH_L2 TH_COS; do
71+
while IFS='|' read -r NAME HF_MODEL GGUF P64 TH_L2 TH_COS DTYPE; do
7072
[[ -z "$NAME" ]] && continue
7173
PROMPT=$(echo "$P64" | base64 -d)
7274
TOTAL=$((TOTAL + 1))
@@ -93,7 +95,7 @@ while IFS='|' read -r NAME HF_MODEL GGUF P64 TH_L2 TH_COS; do
9395

9496
# HF reference dump (one per slot — different prompts produce different tokens)
9597
REF_NPZ="$WORK_DIR/$SLOT.npz"
96-
if ! python hf_reference.py --model "$HF_MODEL" --prompt "$PROMPT" --out "$REF_NPZ" 2>"$WORK_DIR/hf.err"; then
98+
if ! python hf_reference.py --model "$HF_MODEL" --prompt "$PROMPT" --out "$REF_NPZ" --dtype "${DTYPE:-float32}" 2>"$WORK_DIR/hf.err"; then
9799
echo " [ERROR] HF reference failed:"
98100
sed 's/^/ /' "$WORK_DIR/hf.err"
99101
FAILED_ENTRIES+=("$SLOT: hf_reference failed")

0 commit comments

Comments (0)