Skip to content

Commit b4f266d

Browse files
Essozclaude
andcommitted
fix: KeyError on missing varid attribute in APIContainRelation online_check
When iterating all varids of a given type, not every variable instance has every tracked attribute (e.g. _TRAINCHECK_grad_ID may be absent if grad was never observed). Skip varids that don't have the attribute in varid_map rather than crashing with KeyError. Also remove the remaining bare 'raise e' in the API-based invariant check block — the var-based block was fixed earlier but this one was missed, causing the checker to crash and stop on any API invariant exception. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e8a7520 commit b4f266d

3 files changed

Lines changed: 98 additions & 6 deletions

File tree

traincheck/checker_online.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
ALL_INVS: list[Invariant] = []
2929
CURRENT_STEP: int | None = None
3030
CURRENT_STAGE: str | None = None
31+
SAMPLING_INTERVAL: int | None = None
32+
WARM_UP_STEPS: int | None = None
3133
TOTAL_INVARIANTS = 0
3234
RELATION_TOTALS: dict[str, int] = {}
3335
REPORTER: ReportEmitter | None = None
@@ -167,6 +169,31 @@ def _record_violation_details(
167169
detail["sample_trace"] = trace[:8]
168170

169171

172+
def _read_sampling_config(
173+
trace_folders: list[str] | None,
174+
) -> tuple[int | None, int | None]:
175+
"""Parse sampling_interval and warm_up_steps from env_dump.txt in a trace folder."""
176+
import re
177+
178+
for folder in trace_folders or []:
179+
env_dump_path = os.path.join(folder, "env_dump.txt")
180+
if not os.path.exists(env_dump_path):
181+
continue
182+
sampling_interval: int | None = None
183+
warm_up_steps: int | None = None
184+
with open(env_dump_path) as fh:
185+
for line in fh:
186+
m = re.match(r"^sampling_interval:\s*(\d+)", line)
187+
if m:
188+
sampling_interval = int(m.group(1))
189+
m = re.match(r"^warm_up_steps:\s*(\d+)", line)
190+
if m:
191+
warm_up_steps = int(m.group(1))
192+
if sampling_interval is not None or warm_up_steps is not None:
193+
return sampling_interval, warm_up_steps
194+
return None, None
195+
196+
170197
def _emit_report(force: bool = False):
171198
if REPORTER is None:
172199
return
@@ -182,6 +209,8 @@ def _emit_report(force: bool = False):
182209
all_invs=ALL_INVS,
183210
current_step=CURRENT_STEP,
184211
current_stage=CURRENT_STAGE,
212+
sampling_interval=SAMPLING_INTERVAL,
213+
warm_up_steps=WARM_UP_STEPS,
185214
)
186215
report_state = (NUM_VIOLATIONS, len(FAILED_INV))
187216
REPORTER.emit(report_data, force=force, report_state=report_state)
@@ -198,6 +227,8 @@ def check(
198227
global ALL_INVS
199228
global CURRENT_STEP
200229
global CURRENT_STAGE
230+
global SAMPLING_INTERVAL
231+
global WARM_UP_STEPS
201232
global TOTAL_INVARIANTS
202233
global RELATION_TOTALS
203234

@@ -207,6 +238,8 @@ def check(
207238
logger.addHandler(logging.StreamHandler())
208239
logger.info("Starting online checker")
209240

241+
SAMPLING_INTERVAL, WARM_UP_STEPS = _read_sampling_config(trace_folders)
242+
210243
invs, param_to_invs, vartype_to_invs, needed_data = sort_inv_file(invariants)
211244
TOTAL_INVARIANTS = len(invs)
212245
ALL_INVS = list(invs)
@@ -331,7 +364,6 @@ def check(
331364
logger.error(
332365
f"Error when checking invariant {inv.text_description} with trace {trace_record}: {e}"
333366
)
334-
raise e
335367

336368
_emit_report()
337369

traincheck/invariant/contain_relation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,8 @@ def online_check(
13241324
attr_name = child_param.attr_name
13251325
elif isinstance(child_param, VarTypeParam):
13261326
attr_name = child_param.attr_name
1327+
if attr_name not in checker_data.varid_map[varid]:
1328+
continue
13271329
for i in reversed(
13281330
range(1, len(checker_data.varid_map[varid][attr_name]))
13291331
):

traincheck/reporting/checker_report.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ def build_online_report_data(
272272
all_invs: list | None = None,
273273
current_step: int | None = None,
274274
current_stage: str | None = None,
275+
sampling_interval: int | None = None,
276+
warm_up_steps: int | None = None,
275277
) -> dict:
276278
relation_violations: dict[str, int] = defaultdict(int)
277279
for inv in failed_inv:
@@ -280,6 +282,16 @@ def build_online_report_data(
280282
if violation_details is None:
281283
violation_details = {}
282284

285+
# Estimate how many steps have been checked so far.
286+
checked_steps: int | None = None
287+
if (
288+
current_step is not None
289+
and sampling_interval is not None
290+
and warm_up_steps is not None
291+
and sampling_interval > 0
292+
):
293+
checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + 1
294+
283295
def _make_entry(inv: Invariant, count: int) -> dict:
284296
detail = violation_details.get(inv, {})
285297
step_stages: list[tuple] = detail.get("step_stages") or []
@@ -300,6 +312,10 @@ def _make_entry(inv: Invariant, count: int) -> dict:
300312
)
301313
# deduplicated, sorted (step, stage) pairs for the expanded view
302314
unique_step_stages = sorted(set(step_stages), key=lambda x: x[0])
315+
unique_viol_steps = len(set(s for s, _ in step_stages))
316+
viol_rate: float | None = None
317+
if checked_steps is not None and checked_steps > 0:
318+
viol_rate = round(unique_viol_steps / checked_steps * 100, 1)
303319
return {
304320
"label": _format_invariant_label(inv),
305321
"relation": inv.relation.__name__,
@@ -310,6 +326,9 @@ def _make_entry(inv: Invariant, count: int) -> dict:
310326
"last_stage": last_stage,
311327
"step_stages": unique_step_stages[:100], # cap for HTML size
312328
"sample_trace": _summarize_trace_records(sample_trace),
329+
"violation_step_count": unique_viol_steps,
330+
"checked_steps": checked_steps,
331+
"violation_rate": viol_rate,
313332
}
314333

315334
# Sort by first violation step (earliest first), then by count descending.
@@ -372,6 +391,9 @@ def _sort_key(item):
372391
"traces": [],
373392
"top_violations": top_violations,
374393
"not_triggered_labels": not_triggered_labels,
394+
"sampling_interval": sampling_interval,
395+
"warm_up_steps": warm_up_steps,
396+
"checked_steps": checked_steps,
375397
}
376398

377399

@@ -433,6 +455,11 @@ def percent(part: int, total: int) -> float:
433455
top_table_html = "" # used only in online mode below
434456

435457
if mode == "online":
458+
sampling_interval = report_data.get("sampling_interval")
459+
warm_up_steps_val = report_data.get("warm_up_steps")
460+
checked_steps_total = report_data.get("checked_steps")
461+
has_sampling = sampling_interval is not None
462+
436463
rows = []
437464
for entry in top_violations:
438465
label = esc(str(entry.get("label", "")))
@@ -444,6 +471,9 @@ def percent(part: int, total: int) -> float:
444471
last_stage = entry.get("last_stage")
445472
step_stages: list = entry.get("step_stages") or []
446473
sample_trace = entry.get("sample_trace") or []
474+
violation_step_count = entry.get("violation_step_count", 0)
475+
entry_checked = entry.get("checked_steps")
476+
viol_rate = entry.get("violation_rate")
447477

448478
def _step_with_badge(step, stage) -> str:
449479
if step is None:
@@ -489,19 +519,34 @@ def _step_with_badge(step, stage) -> str:
489519
else:
490520
expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
491521

522+
# Frequency cell: prefer rate when sampling info available
523+
if has_sampling and entry_checked is not None and entry_checked > 0:
524+
rate_str = f"{viol_rate}%" if viol_rate is not None else "?"
525+
freq_cell = (
526+
f'<span class="freq-rate">{rate_str}</span>'
527+
f'<span class="freq-detail">'
528+
f"{violation_step_count}/{entry_checked} steps"
529+
f"</span>"
530+
)
531+
else:
532+
freq_cell = f'<span class="freq-rate">{count}</span>'
533+
492534
rows.append(
493535
f"<tr>"
494536
f'<td><details><summary class="inv-label-summary">{label}</summary>'
495537
f'<div class="expand-body">{expand_content}</div></details>'
496538
f'<span class="inv-rel-tag">{relation}</span></td>'
497539
f'<td class="step-cell">{first_step_html}</td>'
498540
f'<td class="step-cell">{last_step_html}</td>'
499-
f'<td class="count-cell">{count}</td>'
541+
f'<td class="freq-cell">{freq_cell}</td>'
500542
f"</tr>"
501543
)
544+
545+
freq_col_header = "Frequency" if has_sampling else "Count"
502546
top_table_html = (
503-
'<table class="table viol-table"><thead>'
504-
"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th><th>Count</th></tr>"
547+
f'<table class="table viol-table"><thead>'
548+
f"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th>"
549+
f"<th>{freq_col_header}</th></tr>"
505550
f"</thead><tbody>{''.join(rows)}</tbody></table>"
506551
if rows
507552
else "<p>No violations yet.</p>"
@@ -791,8 +836,18 @@ def _step_with_badge(step, stage) -> str:
791836
first_step_note = ""
792837

793838
if mode == "online":
794-
panel_subtitle = (
795-
"Sorted by first violation step — click an invariant to expand trace"
839+
sampling_interval = report_data.get("sampling_interval")
840+
warm_up_steps_val = report_data.get("warm_up_steps")
841+
checked_steps_total = report_data.get("checked_steps")
842+
sampling_ctx = ""
843+
if sampling_interval is not None:
844+
sampling_ctx = f" · sampled every {sampling_interval} steps"
845+
if warm_up_steps_val is not None:
846+
sampling_ctx += f", warm-up {warm_up_steps_val}"
847+
if checked_steps_total is not None:
848+
sampling_ctx += f" ({checked_steps_total} steps checked)"
849+
panel_subtitle = esc(
850+
f"Sorted by first violation step — click to expand trace{sampling_ctx}"
796851
)
797852
panel_content = top_table_html
798853
else:
@@ -1016,6 +1071,9 @@ def _step_with_badge(step, stage) -> str:
10161071
.viol-table td {{ vertical-align: top; }}
10171072
.step-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 600; }}
10181073
.count-cell {{ white-space: nowrap; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }}
1074+
.freq-cell {{ white-space: nowrap; }}
1075+
.freq-rate {{ display: block; font-variant-numeric: tabular-nums; font-weight: 700; color: var(--failed); }}
1076+
.freq-detail {{ display: block; font-size: 11px; color: var(--muted); margin-top: 2px; }}
10191077
.stage-badge {{
10201078
display: inline-block;
10211079
font-size: 10px;

0 commit comments

Comments
 (0)