Skip to content

Commit 9ab2e88

Browse files
committed
feat: allow model to make more aggresive changes in the train.py
1 parent dc4eb76 commit 9ab2e88

1 file changed

Lines changed: 20 additions & 7 deletions

File tree

autoresearch/algo.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
MAX_ITERATIONS = 200
3232
TRAINING_TIMEOUT = 600 # 10 minutes
3333
MAX_MODEL_LEN = 65536 # larger context — dedicated GPU has plenty of room
34-
MAX_OUTPUT_TOKENS = 16384 # max tokens for LLM output (enough for full train.py)
34+
MAX_OUTPUT_TOKENS = 4096 # train.py is ~2K tokens; 4K is plenty and keeps generation fast
3535
TEMPERATURE = 0.7
36-
STAGNATION_THRESHOLD = 5 # consecutive non-improvements before nudge
36+
STAGNATION_THRESHOLD = 3 # consecutive non-improvements before nudge
3737
MAX_HISTORY_IN_PROMPT = 20 # only show last N iterations in prompt
3838
MAX_CONSECUTIVE_CRASHES = 3 # after N crashes with same pattern, force new direction
3939

@@ -227,11 +227,11 @@ def summarize_tried_directions(iterations):
227227

228228
summary = []
229229
if unique_crashes:
230-
summary.append("**Approaches that CRASHED** (do not repeat these):")
230+
summary.append("**Approaches that CRASHED** — understand WHY and fix the root cause if you want to try something similar:")
231231
for d in unique_crashes[-10:]: # last 10 unique crashes
232232
summary.append(f"- {d}")
233233
if successful_descs:
234-
summary.append("\n**Approaches that ran but did NOT improve** (try something different):")
234+
summary.append("\n**Approaches that ran but did NOT improve** try something meaningfully different:")
235235
for d in successful_descs[-8:]: # last 8
236236
summary.append(f"- {d}")
237237
return "\n".join(summary)
@@ -346,11 +346,11 @@ def build_prompt(program_md, prepare_py_summary, best_train_py, results,
346346
parts.append("You have been stuck for a LONG time. Make a BOLD change you haven't tried yet. ")
347347
parts.append("Consider: very different model depth/width ratios, different batch sizes (2x or 0.5x), ")
348348
parts.append("different warmdown ratios, or significantly different learning rates. ")
349+
parts.append("If a previous direction crashed, diagnose the root cause and try a corrected version of it. ")
349350
else:
350351
parts.append("Try a significantly different hyperparameter configuration: ")
351352
parts.append("different learning rates, batch size, depth, width, or warmdown ratio. ")
352-
parts.append("Do NOT try to simplify or remove architecture components — that always causes crashes. ")
353-
parts.append("Keep the same architecture but change the numbers.\n\n")
353+
parts.append("\n\n")
354354

355355
parts.append("Now propose your next experiment. Start with 'Changes:' followed by a 1-2 sentence description, "
356356
"then the COMPLETE train.py in a ```python code block. "
@@ -489,6 +489,7 @@ def main():
489489
best_val_bpb = baseline_entry["val_bpb"]
490490
last_error = None
491491
consecutive_non_improvements = 0
492+
consecutive_crashes = 0
492493

493494
for iteration in range(1, MAX_ITERATIONS + 1):
494495
log(f"--- Iteration {iteration}/{MAX_ITERATIONS} ---")
@@ -541,6 +542,7 @@ def main():
541542
"You MUST include the COMPLETE train.py inside a ```python code block. "
542543
"Do not use any other format.")
543544
consecutive_non_improvements += 1
545+
consecutive_crashes += 1
544546
continue
545547

546548
# Syntax check
@@ -560,6 +562,7 @@ def main():
560562
save_results(results)
561563
last_error = f"Your code had a syntax error:\n{syntax_err}\nPlease output the COMPLETE, valid train.py file."
562564
consecutive_non_improvements += 1
565+
consecutive_crashes += 1
563566
continue
564567

565568
# Run training
@@ -581,12 +584,22 @@ def main():
581584
save_results(results)
582585
# Revert to best
583586
write_file(TRAIN_PY_PATH, best_train_py)
584-
last_error = run_result["error"]
587+
consecutive_crashes += 1
585588
consecutive_non_improvements += 1
589+
if consecutive_crashes >= MAX_CONSECUTIVE_CRASHES:
590+
log(f"WARNING: {consecutive_crashes} consecutive crashes — forcing direction change in prompt")
591+
last_error = (run_result["error"] +
592+
f"\n\nYou have crashed {consecutive_crashes} times in a row. "
593+
"STOP repeating the same approach. Try something completely different: "
594+
"change the model architecture style, batch size, or learning rate schedule entirely.")
595+
consecutive_crashes = 0 # reset so we get a fresh window
596+
else:
597+
last_error = run_result["error"]
586598
continue
587599

588600
val_bpb = run_result["metrics"]["val_bpb"]
589601
peak_vram_mb = run_result["metrics"].get("peak_vram_mb", 0)
602+
consecutive_crashes = 0 # successful run resets crash streak
590603

591604
# Decision: keep or discard
592605
if val_bpb < best_val_bpb:

0 commit comments

Comments
 (0)