Skip to content

Commit e44eafd

Browse files
committed
feat: make agent more creative | reduce stagnation through better prompt
1 parent 6cfc3eb commit e44eafd

1 file changed

Lines changed: 88 additions & 7 deletions

File tree

autoresearch/algo.py

Lines changed: 88 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
TRAINING_TIMEOUT = 600 # 10 minutes
2525
MAX_MODEL_LEN = 40960 # total context window (input + output)
2626
MAX_OUTPUT_TOKENS = 16384 # max tokens for LLM output (enough for full train.py)
27-
TEMPERATURE = 0.5
27+
TEMPERATURE = 0.7
2828
STAGNATION_THRESHOLD = 5 # consecutive non-improvements before nudge
2929
MAX_HISTORY_IN_PROMPT = 20 # only show last N iterations in prompt
30+
MAX_CONSECUTIVE_CRASHES = 3 # after N crashes with same pattern, force new direction
3031

3132
RESULTS_PATH = "/app/results.json"
3233
RESULTS_OUTPUT_PATH = "/data/outputs/results.json"
@@ -189,6 +190,58 @@ def run_training(train_py_content):
189190
}
190191

191192

193+
def summarize_tried_directions(iterations):
    """Summarize what has been tried to help the LLM avoid repeating itself."""
    crashed, ran_without_gain = [], []
    for record in iterations:
        blurb = record.get("description", "")[:100]
        if not blurb or blurb in ("baseline", "no description"):
            continue
        status = record["status"]
        if status == "crash":
            crashed.append(blurb)
        elif status == "discard":
            ran_without_gain.append(f"bpb={record['val_bpb']:.4f}: {blurb}")

    # Collapse near-duplicate crashes: descriptions sharing their first
    # 40 characters (case-insensitive) count as the same attempt; the
    # first occurrence wins, and insertion order is preserved.
    deduped = {}
    for blurb in crashed:
        deduped.setdefault(blurb[:40].lower(), blurb)
    unique_crashes = list(deduped.values())

    lines = []
    if unique_crashes:
        lines.append("**Approaches that CRASHED** (do not repeat these):")
        lines.extend(f"- {blurb}" for blurb in unique_crashes[-10:])  # last 10 unique crashes
    if ran_without_gain:
        lines.append("\n**Approaches that ran but did NOT improve** (try something different):")
        lines.extend(f"- {blurb}" for blurb in ran_without_gain[-8:])  # last 8
    return "\n".join(lines)
225+
226+
227+
def classify_error(error_text):
    """Classify a runtime error to give the LLM actionable feedback.

    Args:
        error_text: Raw error/traceback text from a failed run; may be
            None or empty.

    Returns:
        One of "OOM", "TIMEOUT", "SHAPE_MISMATCH", "IMPORT_ERROR",
        "NUMERICAL", "RUNTIME_ERROR", or None when there is no error text.
    """
    import re  # local import keeps the block self-contained

    if not error_text:
        return None
    e = error_text.lower()
    # Short tokens ("oom", "nan", "inf") use word-boundary matches so that
    # incidental substrings ("room", "finance", "info") do not trigger a
    # misclassification; longer phrases remain plain substring checks.
    if "cuda out of memory" in e or "out of memory" in e or re.search(r"\boom\b", e):
        return "OOM"
    if "timed out" in e or "timeout" in e:
        return "TIMEOUT"
    if "shape" in e or "size mismatch" in e or "dimension" in e:
        return "SHAPE_MISMATCH"
    # "cannot import" contains "import", so one check covers both phrasings.
    if "import" in e or "no module" in e:
        return "IMPORT_ERROR"
    if re.search(r"\bnan\b|\binf(?:inity)?\b", e):
        return "NUMERICAL"
    return "RUNTIME_ERROR"
243+
244+
192245
def build_prompt(program_md, prepare_py_summary, best_train_py, results,
193246
last_error=None, consecutive_non_improvements=0):
194247
"""Build the full prompt for the LLM.
@@ -241,20 +294,48 @@ def build_prompt(program_md, prepare_py_summary, best_train_py, results,
241294
best = results["best"]
242295
parts.append(f"**Current best**: iteration {best['iteration']}, val_bpb={best['val_bpb']:.6f}\n\n")
243296

244-
# Error from last iteration
297+
# Summary of tried directions (helps avoid repetition)
298+
if len(iterations) > 5:
299+
tried_summary = summarize_tried_directions(iterations)
300+
if tried_summary:
301+
parts.append("## What Has Been Tried\n")
302+
parts.append(tried_summary)
303+
parts.append("\n\n")
304+
305+
# Error from last iteration (with classification)
245306
if last_error:
246-
# Truncate error to avoid blowing up the prompt
307+
error_type = classify_error(last_error)
247308
truncated_error = last_error[:500]
248309
parts.append("## Previous Iteration Error\n")
249-
parts.append(f"The previous iteration failed:\n```\n{truncated_error}\n```\n")
250-
parts.append("Please fix the issue or try a different approach.\n\n")
310+
if error_type == "OOM":
311+
parts.append("**Error type: OUT OF MEMORY.** Your model was too large. "
312+
"Reduce model size, batch size, or sequence length.\n")
313+
elif error_type == "TIMEOUT":
314+
parts.append("**Error type: TIMEOUT.** Training took too long. "
315+
"Reduce model size or batch size to fit within 5 minutes.\n")
316+
elif error_type == "SHAPE_MISMATCH":
317+
parts.append("**Error type: TENSOR SHAPE MISMATCH.** Check that dimensions are consistent "
318+
"across model config, attention heads, and embeddings.\n")
319+
elif error_type == "IMPORT_ERROR":
320+
parts.append("**Error type: IMPORT ERROR.** You used a module that doesn't exist. "
321+
"Only use imports from the baseline.\n")
322+
elif error_type == "NUMERICAL":
323+
parts.append("**Error type: NUMERICAL INSTABILITY (NaN/Inf).** "
324+
"Learning rate may be too high, or initialization is unstable.\n")
325+
parts.append(f"```\n{truncated_error}\n```\n")
326+
parts.append("Fix the issue or try a completely different approach.\n\n")
251327

252328
# Stagnation nudge
253329
if consecutive_non_improvements >= STAGNATION_THRESHOLD:
254330
parts.append("## NOTE: Stagnation Detected\n")
255331
parts.append(f"The last {consecutive_non_improvements} iterations did not improve val_bpb. ")
256-
parts.append("Try a significantly different hyperparameter configuration: ")
257-
parts.append("different learning rates, batch size, depth, width, or warmdown ratio. ")
332+
if consecutive_non_improvements >= STAGNATION_THRESHOLD * 2:
333+
parts.append("You have been stuck for a LONG time. Make a BOLD change you haven't tried yet. ")
334+
parts.append("Consider: very different model depth/width ratios, different batch sizes (2x or 0.5x), ")
335+
parts.append("different warmdown ratios, or significantly different learning rates. ")
336+
else:
337+
parts.append("Try a significantly different hyperparameter configuration: ")
338+
parts.append("different learning rates, batch size, depth, width, or warmdown ratio. ")
258339
parts.append("Do NOT try to simplify or remove architecture components — that always causes crashes. ")
259340
parts.append("Keep the same architecture but change the numbers.\n\n")
260341

0 commit comments

Comments
 (0)