|
| 1 | +#!/usr/bin/env bash |
| 2 | +# Anchor-strengthening mitigation test for the working memory cliff. |
| 3 | +# |
| 4 | +# Phase 2B identified the cliff failure mode as "primacy-biased document |
| 5 | +# continuation overflow" — the chat-template anchor at the start of the |
| 6 | +# prompt gets overpowered by the haystack continuation prior. This script |
| 7 | +# tests two cheap interventions that strengthen the anchor without |
| 8 | +# touching the model: |
| 9 | +# |
| 10 | +# ARM 1: BASELINE |
| 11 | +# Just haystack + final question. The current niah_test.sh format. |
| 12 | +# |
| 13 | +# ARM 2: PQRI (Periodic Question Re-Injection) |
| 14 | +# Insert "[REMINDER: <question>]" markers inside the haystack every |
| 15 | +# ~256 tokens. The chat template anchor is conceptually refreshed |
| 16 | +# because the question reappears throughout, not just at the end. |
| 17 | +# |
| 18 | +# ARM 3: CONVCHUNK (Conversational Chunking) |
| 19 | +# Split the haystack into N=4 chunks, wrap each as a separate |
| 20 | +# <|user|> turn with the SAME question, terminated by <|eot_id|>. |
| 21 | +# The chat template anchor is *literally* refreshed at every turn |
| 22 | +# boundary because each chunk is a fresh user message. |
| 23 | +# |
| 24 | +# Grid: 3B (quantisation per $MODEL; the script default is a Q8_0 file) × 4 contexts (1024, 1280, 1536, 2048) × 3 needles × 3 arms |
| 25 | +# = 36 trials. ~30 min on Metal. |
| 26 | + |
# Stop at the first unhandled command failure and pin the C locale for this
# shell and every child process.
set -e
export LC_ALL=C LANG=C

# Knobs the caller may override from the environment; an empty value is
# treated the same as unset.
: "${TQ:=./build_metal/quant}"
: "${MODEL:=models/Llama-3.2-3B-Instruct-Q8_0.gguf}"
: "${THREADS:=8}"

# Per-run output files, named after the UTC start time.
OUT_DIR=bench/results/niah
RUN_ID=$(date -u '+%Y%m%dT%H%M%S')
RAW_LOG="${OUT_DIR}/raw_anchor_${RUN_ID}.log"
RESULT_CSV="${OUT_DIR}/results_anchor_${RUN_ID}.csv"

mkdir -p "$OUT_DIR"
printf '%s\n' "arm,context,needle_idx,pass,response" > "$RESULT_CSV"

# Needle / question / keyword triples. Keyword alternatives are separated by
# `\|`; the run loop rewrites that to `|` before handing it to `grep -E`.
NEEDLES=()
QUESTIONS=()
KEYWORDS=()

NEEDLE_0="The chief financial officer of Northwind Logistics is Sarah Chen, hired in 2023."
QUESTION_0="Who is the chief financial officer of Northwind Logistics? Answer with the full name."
KEYWORD_0="Sarah\|Chen"
NEEDLES+=("$NEEDLE_0")
QUESTIONS+=("$QUESTION_0")
KEYWORDS+=("$KEYWORD_0")

NEEDLE_1="The launch date for Project Aurora is November 14th in San Francisco."
QUESTION_1="When and where will Project Aurora launch? Answer in one sentence."
KEYWORD_1="November\|San Francisco"
NEEDLES+=("$NEEDLE_1")
QUESTIONS+=("$QUESTION_1")
KEYWORDS+=("$KEYWORD_1")

NEEDLE_2="The reactor cooling tank at the Helios facility holds exactly eight thousand liters of distilled water."
QUESTION_2="How much distilled water does the reactor cooling tank at Helios hold?"
KEYWORD_2="eight thousand\|8000\|8,000"
NEEDLES+=("$NEEDLE_2")
QUESTIONS+=("$QUESTION_2")
KEYWORDS+=("$KEYWORD_2")

# Experiment grid.
ARMS=("baseline" "pqri" "convchunk")
CONTEXTS=(1024 1280 1536 2048)
# NOTE(review): DEPTH is not read by any code visible in this file; the
# prompt builders hard-code the midpoint (len // 2) instead — confirm.
DEPTH=0.5
| 61 | + |
| 62 | +# ---------------------------------------------------------------------------- |
| 63 | +# Prompt builders — one per arm |
| 64 | +# ---------------------------------------------------------------------------- |
| 65 | + |
#######################################
# ARM 1 (baseline) prompt: haystack with the needle spliced in, followed by
# the question once at the end.
# Arguments:
#   $1 - context size in tokens; the haystack is cut to ~3.6 chars per token
#   $2 - needle sentence to hide in the haystack
#   $3 - question appended after the haystack
# Outputs:
#   The prompt on stdout, UTF-8 encoded, with no trailing newline.
# Returns:
#   python3's exit status (non-zero e.g. when the corpus file is unreadable).
#######################################
build_prompt_baseline() {
  python3 - "$1" "$2" "$3" <<'PYEOF'
import sys

ctx_tokens = int(sys.argv[1])
needle = sys.argv[2]
question = sys.argv[3]

# Decode explicitly as UTF-8. This script exports LC_ALL=C, and without an
# explicit encoding Python falls back to a locale-derived default, which on
# some interpreter versions is ASCII and fails on non-ASCII corpus text.
with open("bench/data/wikitext2_test.txt", encoding="utf-8") as f:
    raw = f.read()

# Cut to the character budget, then back off to the last complete sentence.
target = int(ctx_tokens * 3.6)
hay = raw[:target]
end = hay.rfind(". ")
if end > 0:
    hay = hay[:end + 1]

# Place the needle at the last sentence start found in the first half of the
# haystack (or at the very beginning when there is none).
sb = hay.rfind(". ", 0, max(len(hay) // 2, 2))
sb = 0 if sb < 0 else sb + 2
h = hay[:sb] + needle + " " + hay[sb:]

# Write bytes so the output encoding does not depend on the locale either.
prompt = h + "\n\nQuestion: " + question
sys.stdout.buffer.write(prompt.encode("utf-8", "surrogateescape"))
PYEOF
}
| 82 | + |
#######################################
# ARM 2 (PQRI) prompt: same haystack + needle layout as the baseline, with a
# " [REMINDER: <question>] " marker inserted roughly every 920 characters
# (~256 tokens at 3.6 chars/token), then the question once at the end.
# Arguments:
#   $1 - context size in tokens; the haystack is cut to ~3.6 chars per token
#   $2 - needle sentence to hide in the haystack
#   $3 - question used for both the reminders and the final question
# Outputs:
#   The prompt on stdout, UTF-8 encoded, with no trailing newline.
# Returns:
#   python3's exit status (non-zero e.g. when the corpus file is unreadable).
#######################################
build_prompt_pqri() {
  python3 - "$1" "$2" "$3" <<'PYEOF'
import sys

CHARS_PER_TOKEN = 3.6
INTERVAL = 920      # characters between reminders (~256 tokens)
SNAP_WINDOW = 200   # how far past the interval a sentence end may be sought

ctx_tokens = int(sys.argv[1])
needle = sys.argv[2]
question = sys.argv[3]

# Decode explicitly as UTF-8. This script exports LC_ALL=C, and without an
# explicit encoding Python falls back to a locale-derived default, which on
# some interpreter versions is ASCII and fails on non-ASCII corpus text.
with open("bench/data/wikitext2_test.txt", encoding="utf-8") as f:
    raw = f.read()

# Cut to the character budget, then back off to the last complete sentence.
target = int(ctx_tokens * CHARS_PER_TOKEN)
hay = raw[:target]
end = hay.rfind(". ")
if end > 0:
    hay = hay[:end + 1]

# Place the needle at the last sentence start found in the first half of the
# haystack (or at the very beginning when there is none).
sb = hay.rfind(". ", 0, max(len(hay) // 2, 2))
sb = 0 if sb < 0 else sb + 2
h = hay[:sb] + needle + " " + hay[sb:]

# Reminder format — minimal, recognisable, not chat-template-token-disruptive
reminder = f" [REMINDER: {question}] "

# Slice the text every INTERVAL characters. A cut that is not at the very
# end is moved forward to just after the next ". " when one starts within
# SNAP_WINDOW characters; otherwise the cut stays where it is.
parts = []
pos = 0
while pos < len(h):
    end_pos = min(pos + INTERVAL, len(h))
    if end_pos < len(h):
        sb_next = h.find(". ", end_pos)
        if sb_next > 0 and sb_next - end_pos < SNAP_WINDOW:
            end_pos = sb_next + 2
    parts.append(h[pos:end_pos])
    pos = end_pos

# Reminders go between consecutive slices, never before the first or after
# the last one.
augmented = reminder.join(parts)

# Write bytes so the output encoding does not depend on the locale either.
prompt = augmented + "\n\nQuestion: " + question
sys.stdout.buffer.write(prompt.encode("utf-8", "surrogateescape"))
PYEOF
}
| 120 | + |
#######################################
# ARM 3 (CONVCHUNK) prompt: the haystack (with needle) is split into 4 chunks
# and each chunk is wrapped as its own user turn, terminated by <|eot_id|>,
# with the question repeated at the end of every turn. The prompt finishes
# with an open assistant header. The chat markup is written by hand, so the
# caller must NOT additionally pass --chat for this arm.
# Arguments:
#   $1 - context size in tokens; the haystack is cut to ~3.6 chars per token
#   $2 - needle sentence to hide in the haystack
#   $3 - question repeated in every user turn
# Outputs:
#   The prompt on stdout, UTF-8 encoded. It ends with "\n\n" after the
#   assistant header; a plain $(...) capture strips those newlines.
# Returns:
#   python3's exit status (non-zero e.g. when the corpus file is unreadable).
#######################################
build_prompt_convchunk() {
  python3 - "$1" "$2" "$3" <<'PYEOF'
import sys

CHARS_PER_TOKEN = 3.6
N_CHUNKS = 4
SNAP_WINDOW = 200   # how far past a cut a sentence end may be sought

USER_HEADER = "<|start_header_id|>user<|end_header_id|>\n\n"
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
END_OF_TURN = "<|eot_id|>"

ctx_tokens = int(sys.argv[1])
needle = sys.argv[2]
question = sys.argv[3]

# Decode explicitly as UTF-8. This script exports LC_ALL=C, and without an
# explicit encoding Python falls back to a locale-derived default, which on
# some interpreter versions is ASCII and fails on non-ASCII corpus text.
with open("bench/data/wikitext2_test.txt", encoding="utf-8") as f:
    raw = f.read()

# Cut to the character budget, then back off to the last complete sentence.
target = int(ctx_tokens * CHARS_PER_TOKEN)
hay = raw[:target]
end = hay.rfind(". ")
if end > 0:
    hay = hay[:end + 1]

# Place the needle at the last sentence start found in the first half of the
# haystack (or at the very beginning when there is none).
sb = hay.rfind(". ", 0, max(len(hay) // 2, 2))
sb = 0 if sb < 0 else sb + 2
h = hay[:sb] + needle + " " + hay[sb:]

# Split into N_CHUNKS roughly equal chunks. Each cut is moved forward to just
# after the next ". " when one starts within SNAP_WINDOW characters; the last
# chunk takes whatever remains.
chunks = []
chunk_target = len(h) // N_CHUNKS
pos = 0
for _ in range(N_CHUNKS - 1):
    end_pos = pos + chunk_target
    sb_next = h.find(". ", end_pos)
    if sb_next > 0 and sb_next - end_pos < SNAP_WINDOW:
        end_pos = sb_next + 2
    chunks.append(h[pos:end_pos])
    pos = end_pos
chunks.append(h[pos:])

# One user turn per chunk. Every turn ends with the question — that
# repetition is the intervention this arm is meant to test — so the question
# is not reserved for the final turn only.
turns = [
    USER_HEADER
    + f"Document part {i+1}/{N_CHUNKS}:\n{ch}\n\n"
    + f"Question: {question}"
    + END_OF_TURN
    for i, ch in enumerate(chunks)
]
prompt = "".join(turns) + ASSISTANT_HEADER

# Write bytes so the output encoding does not depend on the locale either.
sys.stdout.buffer.write(prompt.encode("utf-8", "surrogateescape"))
PYEOF
}
| 168 | + |
# ----------------------------------------------------------------------------
# Run loop
#
# For every (arm, context, needle) combination: build the prompt, run the
# inference binary once, pull the response out of its output, mark the trial
# PASS when the response matches the needle's keyword pattern, and append one
# row to $RESULT_CSV plus the full output to $RAW_LOG.
# ----------------------------------------------------------------------------
total=$(( ${#ARMS[@]} * ${#CONTEXTS[@]} * ${#NEEDLES[@]} ))
run_idx=0

# Fail fast when the binary is missing. Per-trial failures are tolerated on
# purpose further down (`|| true`), which would otherwise turn a missing
# binary into a CSV full of meaningless FAIL rows.
if ! command -v "$TQ" >/dev/null 2>&1; then
  printf 'error: inference binary not found or not executable: %s\n' "$TQ" >&2
  exit 1
fi

echo "==> Anchor mitigation test"
echo " binary: $TQ"
echo " model: $MODEL"
echo " contexts: ${CONTEXTS[*]}"
echo " arms: ${ARMS[*]}"
echo " needles: ${#NEEDLES[@]} total: $total"
echo " raw: $RAW_LOG"
echo " csv: $RESULT_CSV"
echo ""

for arm in "${ARMS[@]}"; do
  for ctx in "${CONTEXTS[@]}"; do
    # Generous cli_ctx: wikitext (with @-@ and == markup) tokenises closer
    # to 3 chars/token than 4, and PQRI/convchunk add 100-300 tokens of
    # reminder/wrap overhead. Use 2× ctx + 256 to be safe everywhere.
    cli_ctx=$(( ctx * 2 + 256 ))
    for ni in "${!NEEDLES[@]}"; do
      run_idx=$(( run_idx + 1 ))
      needle="${NEEDLES[$ni]}"
      question="${QUESTIONS[$ni]}"
      keyword="${KEYWORDS[$ni]}"

      # Extra CLI flags for this arm; an array so that "no flag" expands to
      # zero words without relying on unquoted word splitting.
      chat_args=()
      case "$arm" in
        baseline)
          prompt=$(build_prompt_baseline "$ctx" "$needle" "$question")
          chat_args=(--chat)
          ;;
        pqri)
          prompt=$(build_prompt_pqri "$ctx" "$needle" "$question")
          chat_args=(--chat)
          ;;
        convchunk)
          # The prompt is already chat-formatted, so no --chat. It ends with
          # "\n\n" after the assistant header, and $(...) deletes trailing
          # newlines; append a sentinel character inside the substitution and
          # strip it afterwards so the newlines reach the model.
          prompt=$(build_prompt_convchunk "$ctx" "$needle" "$question" \
            && printf 'x')
          prompt=${prompt%x}
          ;;
        *)
          printf 'error: unknown arm: %s\n' "$arm" >&2
          exit 1
          ;;
      esac

      printf "[%2d/%d] %-10s ctx=%-5d needle=%d " \
        "$run_idx" "$total" "$arm" "$ctx" "$ni"

      # A non-zero exit from the binary must not abort the whole grid; the
      # trial is then scored on whatever output was produced.
      out=$( "$TQ" "$MODEL" -p "$prompt" -n 32 -T 0.0 -j "$THREADS" \
        ${chat_args[@]+"${chat_args[@]}"} --ctx "$cli_ctx" -k fp32 2>&1 || true )

      # The response is taken as the lines after the first `---` line and
      # before the second, skipping `[tokenizer]` lines.
      resp=$(printf '%s\n' "$out" | awk '
        /^---$/ { n++; next }
        n==1 && /^\[tokenizer\]/ { next }
        n==1 { print }
      ' || true)
      # Fallback when no delimited section was found: third line from the end.
      if [[ -z "$resp" ]]; then
        resp=$(printf '%s\n' "$out" | tail -3 | head -1)
      fi
      # CSV-safe form: one line, embedded double quotes doubled.
      resp_csv=$(printf '%s\n' "$resp" | tr '\n' ' ' | sed 's/"/""/g')

      # Keywords are stored with `\|` alternation; grep -E wants a bare `|`.
      pattern=${keyword//\\|/|}
      if grep -qiE -- "$pattern" <<<"$resp"; then
        pass=1
        echo "PASS"
      else
        pass=0
        echo "FAIL: ${resp:0:55}"
      fi

      printf '%s,%s,%s,%s,"%s"\n' \
        "$arm" "$ctx" "$ni" "$pass" "$resp_csv" >> "$RESULT_CSV"
      {
        printf '===== %s ctx=%s needle=%s =====\n' "$arm" "$ctx" "$ni"
        printf '%s\n\n' "$out"
      } >> "$RAW_LOG"
    done
  done
done
| 239 | + |
# Final pass/total table: one row per arm, one column per context size,
# computed from the rows already written to $RESULT_CSV.

# awk program for a single cell: over the data rows (header skipped) whose
# arm and context match, sum the pass column and count the rows; print
# "passed/total", or "n/a" when nothing matched.
cell_prog='NR>1 && $1==a && $2==c {p+=$4; t++} END{if(t>0)printf "%d/%d", p, t; else print "n/a"}'

printf '\n'
printf '%s\n' "==> Summary by (arm × ctx):"

# Header row.
printf " %-10s" "arm"
for ctx in "${CONTEXTS[@]}"; do
  printf " %7d" "$ctx"
done
printf '\n'

# One row per arm.
for arm in "${ARMS[@]}"; do
  printf " %-10s" "$arm"
  for ctx in "${CONTEXTS[@]}"; do
    cell=$(awk -F, -v a="$arm" -v c="$ctx" "$cell_prog" "$RESULT_CSV")
    printf " %7s" "$cell"
  done
  printf '\n'
done
0 commit comments