|
24 | 24 | TRAINING_TIMEOUT = 600 # 10 minutes |
25 | 25 | MAX_MODEL_LEN = 40960 # total context window (input + output) |
26 | 26 | MAX_OUTPUT_TOKENS = 16384 # max tokens for LLM output (enough for full train.py) |
27 | | -TEMPERATURE = 0.5 |
| 27 | +TEMPERATURE = 0.7 |
28 | 28 | STAGNATION_THRESHOLD = 5 # consecutive non-improvements before nudge |
29 | 29 | MAX_HISTORY_IN_PROMPT = 20 # only show last N iterations in prompt |
| 30 | +MAX_CONSECUTIVE_CRASHES = 3 # after N crashes with same pattern, force new direction |
30 | 31 |
|
31 | 32 | RESULTS_PATH = "/app/results.json" |
32 | 33 | RESULTS_OUTPUT_PATH = "/data/outputs/results.json" |
@@ -189,6 +190,58 @@ def run_training(train_py_content): |
189 | 190 | } |
190 | 191 |
|
191 | 192 |
|
def summarize_tried_directions(iterations):
    """Summarize what has been tried to help the LLM avoid repeating itself.

    Args:
        iterations: Sequence of iteration records (dicts). Each record is
            expected to carry a "status" key; "crash" and "discard" records
            with a meaningful "description" are summarized. Discarded
            records are also expected to carry a numeric "val_bpb".
            (Assumes "status" is always present — TODO confirm upstream.)

    Returns:
        A markdown-formatted string listing crashed approaches and
        approaches that ran without improving; empty string if nothing
        noteworthy was found.
    """
    crashes = []
    ran_without_improvement = []
    for record in iterations:
        desc = record.get("description", "")[:100]
        # Skip records with no useful description.
        if not desc or desc in ("baseline", "no description"):
            continue
        status = record["status"]
        if status == "crash":
            crashes.append(desc)
        elif status == "discard":
            ran_without_improvement.append(f"bpb={record['val_bpb']:.4f}: {desc}")

    # Collapse near-duplicate crashes: two descriptions whose first 40
    # characters match case-insensitively count as the same approach.
    # dict preserves insertion order, and setdefault keeps the first
    # occurrence of each key.
    first_by_prefix = {}
    for desc in crashes:
        first_by_prefix.setdefault(desc[:40].lower(), desc)
    unique_crashes = list(first_by_prefix.values())

    lines = []
    if unique_crashes:
        lines.append("**Approaches that CRASHED** (do not repeat these):")
        # Only the most recent 10 unique crashes, to bound prompt size.
        lines.extend(f"- {desc}" for desc in unique_crashes[-10:])
    if ran_without_improvement:
        lines.append("\n**Approaches that ran but did NOT improve** (try something different):")
        # Only the most recent 8.
        lines.extend(f"- {desc}" for desc in ran_without_improvement[-8:])
    return "\n".join(lines)
| 225 | + |
| 226 | + |
def classify_error(error_text):
    """Classify a runtime error to give the LLM actionable feedback.

    Args:
        error_text: Raw error/traceback text from a failed run; may be
            None or empty.

    Returns:
        One of "OOM", "TIMEOUT", "SHAPE_MISMATCH", "IMPORT_ERROR",
        "NUMERICAL", "RUNTIME_ERROR", or None when there is no error text.
    """
    import re

    if not error_text:
        return None
    e = error_text.lower()

    # Short tokens must match as whole words: bare substring checks would
    # misclassify "zoom"/"bedroom" as OOM and — far worse — the ubiquitous
    # "info" as NUMERICAL via "inf". Multi-word phrases are safe as
    # substrings. Check order is significant and matches the prompt text.
    def has_word(word):
        return re.search(r"\b" + word + r"\b", e) is not None

    if "cuda out of memory" in e or "out of memory" in e or has_word("oom"):
        return "OOM"
    if "timed out" in e or "timeout" in e:
        return "TIMEOUT"
    if "shape" in e or "size mismatch" in e or "dimension" in e:
        return "SHAPE_MISMATCH"
    if "import" in e or "no module" in e or "cannot import" in e:
        return "IMPORT_ERROR"
    # "infinity" kept as a substring check since \binf\b would miss it.
    if has_word("nan") or has_word("inf") or "infinity" in e:
        return "NUMERICAL"
    return "RUNTIME_ERROR"
| 243 | + |
| 244 | + |
192 | 245 | def build_prompt(program_md, prepare_py_summary, best_train_py, results, |
193 | 246 | last_error=None, consecutive_non_improvements=0): |
194 | 247 | """Build the full prompt for the LLM. |
@@ -241,20 +294,48 @@ def build_prompt(program_md, prepare_py_summary, best_train_py, results, |
241 | 294 | best = results["best"] |
242 | 295 | parts.append(f"**Current best**: iteration {best['iteration']}, val_bpb={best['val_bpb']:.6f}\n\n") |
243 | 296 |
|
244 | | - # Error from last iteration |
| 297 | + # Summary of tried directions (helps avoid repetition) |
| 298 | + if len(iterations) > 5: |
| 299 | + tried_summary = summarize_tried_directions(iterations) |
| 300 | + if tried_summary: |
| 301 | + parts.append("## What Has Been Tried\n") |
| 302 | + parts.append(tried_summary) |
| 303 | + parts.append("\n\n") |
| 304 | + |
| 305 | + # Error from last iteration (with classification) |
245 | 306 | if last_error: |
246 | | - # Truncate error to avoid blowing up the prompt |
| 307 | + error_type = classify_error(last_error) |
247 | 308 | truncated_error = last_error[:500] |
248 | 309 | parts.append("## Previous Iteration Error\n") |
249 | | - parts.append(f"The previous iteration failed:\n```\n{truncated_error}\n```\n") |
250 | | - parts.append("Please fix the issue or try a different approach.\n\n") |
| 310 | + if error_type == "OOM": |
| 311 | + parts.append("**Error type: OUT OF MEMORY.** Your model was too large. " |
| 312 | + "Reduce model size, batch size, or sequence length.\n") |
| 313 | + elif error_type == "TIMEOUT": |
| 314 | + parts.append("**Error type: TIMEOUT.** Training took too long. " |
| 315 | + "Reduce model size or batch size to fit within 5 minutes.\n") |
| 316 | + elif error_type == "SHAPE_MISMATCH": |
| 317 | + parts.append("**Error type: TENSOR SHAPE MISMATCH.** Check that dimensions are consistent " |
| 318 | + "across model config, attention heads, and embeddings.\n") |
| 319 | + elif error_type == "IMPORT_ERROR": |
| 320 | + parts.append("**Error type: IMPORT ERROR.** You used a module that doesn't exist. " |
| 321 | + "Only use imports from the baseline.\n") |
| 322 | + elif error_type == "NUMERICAL": |
| 323 | + parts.append("**Error type: NUMERICAL INSTABILITY (NaN/Inf).** " |
| 324 | + "Learning rate may be too high, or initialization is unstable.\n") |
| 325 | + parts.append(f"```\n{truncated_error}\n```\n") |
| 326 | + parts.append("Fix the issue or try a completely different approach.\n\n") |
251 | 327 |
|
252 | 328 | # Stagnation nudge |
253 | 329 | if consecutive_non_improvements >= STAGNATION_THRESHOLD: |
254 | 330 | parts.append("## NOTE: Stagnation Detected\n") |
255 | 331 | parts.append(f"The last {consecutive_non_improvements} iterations did not improve val_bpb. ") |
256 | | - parts.append("Try a significantly different hyperparameter configuration: ") |
257 | | - parts.append("different learning rates, batch size, depth, width, or warmdown ratio. ") |
| 332 | + if consecutive_non_improvements >= STAGNATION_THRESHOLD * 2: |
| 333 | + parts.append("You have been stuck for a LONG time. Make a BOLD change you haven't tried yet. ") |
| 334 | + parts.append("Consider: very different model depth/width ratios, different batch sizes (2x or 0.5x), ") |
| 335 | + parts.append("different warmdown ratios, or significantly different learning rates. ") |
| 336 | + else: |
| 337 | + parts.append("Try a significantly different hyperparameter configuration: ") |
| 338 | + parts.append("different learning rates, batch size, depth, width, or warmdown ratio. ") |
258 | 339 | parts.append("Do NOT try to simplify or remove architecture components — that always causes crashes. ") |
259 | 340 | parts.append("Keep the same architecture but change the numbers.\n\n") |
260 | 341 |
|
|
0 commit comments