Skip to content

Commit 9cf8064

Browse files
Essoz and claude committed
feat: add step/stage/trace detail to offline report, W&B, and MLflow
- _count_failed_invariants now tracks last_step, step_stages (step→stage map from all violation traces), and sample_trace (first violation)
- Offline HTML violations panel and per-trace failed-invariants lists now use the same expandable table format as the online report: First Step / Last Step / Count columns, stage badges, collapsible step timeline and sample trace rows
- W&B violations table gains a last_step column; summary gains violations/last_step
- MLflow gains violations_last_step metric

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9d8d04c commit 9cf8064

1 file changed

Lines changed: 173 additions & 39 deletions

File tree

traincheck/reporting/checker_report.py

Lines changed: 173 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,9 @@ def _count_failed_invariants(
9191
) -> list[dict[str, object]]:
9292
counter: Counter[tuple[str, str]] = Counter()
9393
first_steps: dict[tuple[str, str], int | None] = {}
94+
last_steps: dict[tuple[str, str], int | None] = {}
95+
step_stage_maps: dict[tuple[str, str], dict] = defaultdict(dict)
96+
sample_traces: dict[tuple[str, str], list] = {}
9497
for res in results:
9598
if not res.check_passed:
9699
label = _format_invariant_label(res.invariant)
@@ -103,15 +106,36 @@ def _count_failed_invariants(
103106
first_steps[key] = (
104107
min(steps) if existing is None else min(existing, min(steps))
105108
)
109+
existing_last = last_steps.get(key)
110+
last_steps[key] = (
111+
max(steps)
112+
if existing_last is None
113+
else max(existing_last, max(steps))
114+
)
106115
elif key not in first_steps:
107116
first_steps[key] = None
117+
last_steps[key] = None
118+
# Accumulate step → stage (first stage seen per step wins)
119+
for rec in res.trace or []:
120+
if not isinstance(rec, dict):
121+
continue
122+
step = rec.get("meta_vars.step")
123+
stage = rec.get("meta_vars.stage")
124+
if step is not None and step not in step_stage_maps[key]:
125+
step_stage_maps[key][step] = stage
126+
# One sample trace per invariant (first violation wins)
127+
if key not in sample_traces and res.trace:
128+
sample_traces[key] = _summarize_trace_records(res.trace)
108129
top_pairs = counter.most_common(10)
109130
return [
110131
{
111132
"label": label,
112133
"relation": relation,
113134
"count": count,
114135
"first_step": first_steps.get((label, relation)),
136+
"last_step": last_steps.get((label, relation)),
137+
"step_stages": sorted(step_stage_maps[(label, relation)].items()),
138+
"sample_trace": sample_traces.get((label, relation), []),
115139
}
116140
for (label, relation), count in top_pairs
117141
]
@@ -434,11 +458,7 @@ def percent(part: int, total: int) -> float:
434458
traces = report_data.get("traces", [])
435459
top_violations = report_data.get("top_violations", [])
436460

437-
# Build top violations HTML differently per mode.
438-
# For online mode: step-sorted table with expandable trace rows.
439-
# For offline mode: simple list (unchanged).
440-
top_list = "" # used only in offline mode below
441-
top_table_html = "" # used only in online mode below
461+
top_table_html = ""
442462

443463
if mode == "online":
444464
sampling_interval = report_data.get("sampling_interval")
@@ -538,27 +558,77 @@ def _step_with_badge(step, stage) -> str:
538558
else "<p>No violations yet.</p>"
539559
)
540560
else:
541-
top_items = []
561+
rows = []
542562
for entry in top_violations:
543563
label = esc(str(entry.get("label", "")))
544564
relation = esc(str(entry.get("relation", "")))
545-
count = entry.get("count")
565+
count = entry.get("count", "")
546566
first_step = entry.get("first_step")
547-
count_html = f'<span class="inv-count">{count}</span>' if count else ""
548-
if first_step is not None:
549-
step_note = f"first seen at step {first_step}"
550-
if count and count > 1:
551-
step_note += f" · {count} occurrences"
552-
detail = esc(f"{entry.get('relation', '')}{step_note}")
567+
last_step = entry.get("last_step")
568+
off_step_stages: list = entry.get("step_stages") or []
569+
off_sample_trace = entry.get("sample_trace") or []
570+
571+
def _step_cell(step, _ss=off_step_stages) -> str:
572+
if step is None:
573+
return "—"
574+
stage = next((s for st, s in _ss if st == step), None)
575+
badge = _render_stage_badge(stage, esc)
576+
return f"{badge}{step}"
577+
578+
first_step_html = _step_cell(first_step)
579+
last_step_html = _step_cell(last_step)
580+
steps_html = _render_step_stages_html(off_step_stages, esc)
581+
582+
if off_sample_trace:
583+
off_keys: list[str] = []
584+
for rec in off_sample_trace:
585+
for k in rec:
586+
if k not in off_keys:
587+
off_keys.append(k)
588+
trace_head = "".join(f"<th>{esc(k)}</th>" for k in off_keys)
589+
trace_rows_html = []
590+
for rec in off_sample_trace:
591+
cells = []
592+
for k in off_keys:
593+
val = rec.get(k, "")
594+
if k == "meta_vars.stage" and val:
595+
style = _stage_badge_style(val)
596+
cell = (
597+
f'<td><span class="stage-badge" style="{style}">'
598+
f"{esc(val)}</span></td>"
599+
)
600+
else:
601+
cell = f"<td>{esc(str(val))}</td>"
602+
cells.append(cell)
603+
trace_rows_html.append(f"<tr>{''.join(cells)}</tr>")
604+
trace_body = "\n".join(trace_rows_html)
605+
expand_content = (
606+
f'<div class="trace-steps">Steps: {steps_html}</div>'
607+
f'<div class="trace-wrap"><table class="table trace-table">'
608+
f"<thead><tr>{trace_head}</tr></thead>"
609+
f"<tbody>{trace_body}</tbody></table></div>"
610+
)
553611
else:
554-
detail = relation
555-
if count and count > 1:
556-
detail = esc(f"{entry.get('relation', '')}{count} occurrences")
557-
top_items.append(
558-
f'<li><span class="inv-label">{label}</span>'
559-
f'<span class="inv-detail">{detail}</span>{count_html}</li>'
612+
expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
613+
614+
rows.append(
615+
f"<tr>"
616+
f'<td><details><summary class="inv-label-summary">{label}</summary>'
617+
f'<div class="expand-body">{expand_content}</div></details>'
618+
f'<span class="inv-rel-tag">{relation}</span></td>'
619+
f'<td class="step-cell">{first_step_html}</td>'
620+
f'<td class="step-cell">{last_step_html}</td>'
621+
f'<td class="freq-cell"><span class="freq-rate">{count}</span></td>'
622+
f"</tr>"
560623
)
561-
top_list = "".join(top_items) or "<li>None</li>"
624+
top_table_html = (
625+
f'<table class="table viol-table"><thead>'
626+
f"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th>"
627+
f"<th>Count</th></tr>"
628+
f"</thead><tbody>{''.join(rows)}</tbody></table>"
629+
if rows
630+
else "<p>No violations.</p>"
631+
)
562632

563633
trace_sections = []
564634
for trace in traces:
@@ -573,27 +643,74 @@ def _step_with_badge(step, stage) -> str:
573643
+ _render_bar_segment(percent(not_triggered, total), "bar-not-triggered")
574644
)
575645

576-
failed_list_items = []
646+
failed_rows = []
577647
for failed_item in trace["failed_invariants"][:10]:
578648
label = esc(str(failed_item.get("label", "")))
579649
relation = esc(str(failed_item.get("relation", "")))
580-
count = failed_item.get("count")
650+
count = failed_item.get("count", "")
581651
first_step = failed_item.get("first_step")
582-
count_html = f'<span class="inv-count">{count}</span>' if count else ""
583-
if first_step is not None:
584-
step_note = f"first seen at step {first_step}"
585-
if count and count > 1:
586-
step_note += f" · {count} occurrences"
587-
detail = esc(f"{relation}{step_note}")
652+
last_step = failed_item.get("last_step")
653+
item_step_stages: list = failed_item.get("step_stages") or []
654+
item_sample_trace = failed_item.get("sample_trace") or []
655+
656+
def _step_cell_trace(step) -> str:
657+
if step is None:
658+
return "—"
659+
stage = next((s for st, s in item_step_stages if st == step), None)
660+
badge = _render_stage_badge(stage, esc)
661+
return f"{badge}{step}"
662+
663+
steps_html = _render_step_stages_html(item_step_stages, esc)
664+
if item_sample_trace:
665+
item_keys: list[str] = []
666+
for rec in item_sample_trace:
667+
for k in rec:
668+
if k not in item_keys:
669+
item_keys.append(k)
670+
trace_head = "".join(f"<th>{esc(k)}</th>" for k in item_keys)
671+
trace_rows_html = []
672+
for rec in item_sample_trace:
673+
cells = []
674+
for k in item_keys:
675+
val = rec.get(k, "")
676+
if k == "meta_vars.stage" and val:
677+
style = _stage_badge_style(val)
678+
cell = (
679+
f'<td><span class="stage-badge" style="{style}">'
680+
f"{esc(val)}</span></td>"
681+
)
682+
else:
683+
cell = f"<td>{esc(str(val))}</td>"
684+
cells.append(cell)
685+
trace_rows_html.append(f"<tr>{''.join(cells)}</tr>")
686+
trace_body = "\n".join(trace_rows_html)
687+
expand_content = (
688+
f'<div class="trace-steps">Steps: {steps_html}</div>'
689+
f'<div class="trace-wrap"><table class="table trace-table">'
690+
f"<thead><tr>{trace_head}</tr></thead>"
691+
f"<tbody>{trace_body}</tbody></table></div>"
692+
)
588693
else:
589-
detail = relation
590-
if count and count > 1:
591-
detail = esc(f"{relation}{count} occurrences")
592-
failed_list_items.append(
593-
f'<li><span class="inv-label">{label}</span>'
594-
f'<span class="inv-detail">{detail}</span>{count_html}</li>'
694+
expand_content = f'<div class="trace-steps">Steps: {steps_html}</div>'
695+
696+
failed_rows.append(
697+
f"<tr>"
698+
f'<td><details><summary class="inv-label-summary">{label}</summary>'
699+
f'<div class="expand-body">{expand_content}</div></details>'
700+
f'<span class="inv-rel-tag">{relation}</span></td>'
701+
f'<td class="step-cell">{_step_cell_trace(first_step)}</td>'
702+
f'<td class="step-cell">{_step_cell_trace(last_step)}</td>'
703+
f'<td class="freq-cell"><span class="freq-rate">{count}</span></td>'
704+
f"</tr>"
595705
)
596-
failed_list_html = "".join(failed_list_items) or "<li>None</li>"
706+
failed_list_html = (
707+
f'<table class="table viol-table"><thead>'
708+
f"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th>"
709+
f"<th>Count</th></tr>"
710+
f"</thead><tbody>{''.join(failed_rows)}</tbody></table>"
711+
if failed_rows
712+
else "<p>None</p>"
713+
)
597714

598715
relation_rows = []
599716
for relation_name, rel_counts in sorted(trace["relations"].items()):
@@ -630,7 +747,7 @@ def _step_with_badge(step, stage) -> str:
630747
<div class="grid-two">
631748
<div>
632749
<h3>Failed invariants (top 10)</h3>
633-
<ul class="inv-list">{failed_list_html}</ul>
750+
{failed_list_html}
634751
</div>
635752
<div>
636753
<h3>Relation breakdown</h3>
@@ -837,8 +954,8 @@ def _step_with_badge(step, stage) -> str:
837954
)
838955
panel_content = top_table_html
839956
else:
840-
panel_subtitle = "Most frequent violations observed"
841-
panel_content = f'<ul class="inv-list">{top_list}</ul>'
957+
panel_subtitle = "Sorted by first violation step — click to expand trace"
958+
panel_content = top_table_html
842959

843960
top_panel = f"""
844961
<section class="panel">
@@ -1353,23 +1470,35 @@ def _log_wandb(
13531470
top_violations = report_data.get("top_violations", [])
13541471
if top_violations:
13551472
vtable = wandb.Table(
1356-
columns=["invariant", "relation_type", "occurrences", "first_step"]
1473+
columns=[
1474+
"invariant",
1475+
"relation_type",
1476+
"occurrences",
1477+
"first_step",
1478+
"last_step",
1479+
]
13571480
)
13581481
for v in top_violations:
13591482
vtable.add_data(
13601483
v.get("label", ""),
13611484
v.get("relation", ""),
13621485
v.get("count", 0),
13631486
v.get("first_step"),
1487+
v.get("last_step"),
13641488
)
13651489
wandb.log({"violations": vtable})
13661490

13671491
# --- summary metrics (shown in run comparison table) ---
13681492
first_steps = [
13691493
v["first_step"] for v in top_violations if v.get("first_step") is not None
13701494
]
1495+
last_steps_wandb = [
1496+
v["last_step"] for v in top_violations if v.get("last_step") is not None
1497+
]
13711498
if first_steps:
13721499
run.summary["violations/first_step"] = min(first_steps)
1500+
if last_steps_wandb:
1501+
run.summary["violations/last_step"] = max(last_steps_wandb)
13731502
run.summary["violations/distinct_invariants"] = len(top_violations)
13741503

13751504
# --- violations_summary.json as versioned artifact ---
@@ -1463,8 +1592,13 @@ def _log_mlflow(
14631592
first_steps = [
14641593
v["first_step"] for v in top_violations if v.get("first_step") is not None
14651594
]
1595+
last_steps = [
1596+
v["last_step"] for v in top_violations if v.get("last_step") is not None
1597+
]
14661598
if first_steps:
14671599
mlflow.log_metric("violations_first_step", min(first_steps))
1600+
if last_steps:
1601+
mlflow.log_metric("violations_last_step", max(last_steps))
14681602
mlflow.log_metric("violations_distinct_invariants", len(top_violations))
14691603

14701604
# --- violations table as JSON artifact ---

0 commit comments

Comments
 (0)