Skip to content

Commit e8a7520

Browse files
Essoz and claude committed
feat: rich online HTML report with step/stage annotations and checking progress
checker_online.py:
- Track VIOLATION_DETAILS (step/stage pairs and sample trace per invariant)
- Track TRIGGERED_INV (invariants checked at least once), ALL_INVS, CURRENT_STEP and CURRENT_STAGE from each processed trace record
- Remove bare 'raise e' from API invariant exception handler so a single bad invariant check no longer crashes the entire checker loop
- Pass new tracking state to build_online_report_data on every report emit

checker_report.py:
- Violations sorted by first violation step (earliest first) instead of count
- Per-violation: first/last step with stage badge, full step list grouped by stage (e.g. [train] 1,2,3 · [eval] 100,101), expandable sample trace table
- Stage badges with distinct colors for train/eval/val/test/inference; unknown stages get a hash-derived color from a fallback palette
- New Checking Progress panel: stacked bar (passing/failing/not-triggered), collapsible list of not-yet-triggered invariants, pass rate card, and Current Step card showing latest step with stage badge

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 93a3e9c commit e8a7520

2 files changed

Lines changed: 680 additions & 53 deletions

File tree

traincheck/checker_online.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323
)
2424
NUM_VIOLATIONS = 0
2525
FAILED_INV: dict[Invariant, int] = {}
26+
VIOLATION_DETAILS: dict[Invariant, dict] = {}
27+
TRIGGERED_INV: set[Invariant] = set()
28+
ALL_INVS: list[Invariant] = []
29+
CURRENT_STEP: int | None = None
30+
CURRENT_STAGE: str | None = None
2631
TOTAL_INVARIANTS = 0
2732
RELATION_TOTALS: dict[str, int] = {}
2833
REPORTER: ReportEmitter | None = None
@@ -100,7 +105,7 @@ def sort_inv_file(invariants):
100105
vartype_to_invs: dict[str, dict[str, list[Invariant]]] = {}
101106
needed_vars = set()
102107
needed_apis = set()
103-
_get_api_args_map_to_check = set()
108+
all_needed_args_api = set()
104109
for inv in invs:
105110
assert (
106111
inv.precondition is not None
@@ -114,7 +119,7 @@ def sort_inv_file(invariants):
114119
if needed_api is not None:
115120
needed_apis.update(needed_api)
116121
if needed_args_api is not None:
117-
_get_api_args_map_to_check.update(needed_args_api)
122+
all_needed_args_api.update(needed_args_api)
118123
for param in params:
119124
if isinstance(param, VarTypeParam):
120125
if param.var_type not in vartype_to_invs:
@@ -127,7 +132,7 @@ def sort_inv_file(invariants):
127132
param_to_invs[param] = []
128133
param_to_invs[param].append(inv)
129134
logger.info("Sorting done.")
130-
needed_data = (needed_vars, needed_apis, _get_api_args_map_to_check)
135+
needed_data = (needed_vars, needed_apis, all_needed_args_api)
131136
return invs, param_to_invs, vartype_to_invs, needed_data
132137

133138

@@ -139,6 +144,29 @@ def get_violated_pair_hash(trace_pair):
139144
return tuple(sorted((h1, h2), reverse=True))
140145

141146

147+
_MAX_TRACKED_STEPS = 500  # cap on steps stored per invariant
_SAMPLE_TRACE_LEN = 8  # number of leading trace records kept as the sample


def _record_violation_details(
    inv: "Invariant",
    result,
    violation_details: "dict[Invariant, dict]",
    *,
    max_steps: "int | None" = None,
    sample_size: "int | None" = None,
) -> None:
    """Update per-invariant (step, stage) list and sample trace for the HTML report.

    Args:
        inv: Key under which the details are stored in ``violation_details``.
        result: Check result; only its ``trace`` attribute is read. A falsy
            trace (``None`` or empty) is treated as an empty list.
        violation_details: Mapping mutated in place. Each entry has the keys
            ``"step_stages"`` (list of ``(step, stage)`` tuples) and
            ``"sample_trace"`` (``None`` until a non-empty trace is seen).
        max_steps: Upper bound on the stored ``(step, stage)`` tuples for one
            invariant; defaults to ``_MAX_TRACKED_STEPS``.
        sample_size: Number of leading trace records kept as the sample;
            defaults to ``_SAMPLE_TRACE_LEN``.
    """
    if max_steps is None:
        max_steps = _MAX_TRACKED_STEPS
    if sample_size is None:
        sample_size = _SAMPLE_TRACE_LEN

    trace = result.trace or []

    # An entry is created even when the trace is empty, so every invariant
    # passed in here shows up in the mapping.
    detail = violation_details.setdefault(
        inv, {"step_stages": [], "sample_trace": None}
    )

    remaining = max_steps - len(detail["step_stages"])
    if remaining > 0:
        # Only dict records that carry a step are usable; the stage is
        # optional and is stored as None when absent.
        step_stages = [
            (record["meta_vars.step"], record.get("meta_vars.stage"))
            for record in trace
            if isinstance(record, dict)
            and record.get("meta_vars.step") is not None
        ]
        detail["step_stages"].extend(step_stages[:remaining])

    # The sample is taken from the first non-empty trace and never replaced.
    if detail["sample_trace"] is None and trace:
        detail["sample_trace"] = trace[:sample_size]
168+
169+
142170
def _emit_report(force: bool = False):
143171
if REPORTER is None:
144172
return
@@ -149,6 +177,11 @@ def _emit_report(force: bool = False):
149177
total_violations=NUM_VIOLATIONS,
150178
failed_inv=FAILED_INV,
151179
relation_totals=RELATION_TOTALS,
180+
violation_details=VIOLATION_DETAILS,
181+
triggered_inv=TRIGGERED_INV,
182+
all_invs=ALL_INVS,
183+
current_step=CURRENT_STEP,
184+
current_stage=CURRENT_STAGE,
152185
)
153186
report_state = (NUM_VIOLATIONS, len(FAILED_INV))
154187
REPORTER.emit(report_data, force=force, report_state=report_state)
@@ -160,6 +193,11 @@ def check(
160193
global OBSERVER
161194
global NUM_VIOLATIONS
162195
global FAILED_INV
196+
global VIOLATION_DETAILS
197+
global TRIGGERED_INV
198+
global ALL_INVS
199+
global CURRENT_STEP
200+
global CURRENT_STAGE
163201
global TOTAL_INVARIANTS
164202
global RELATION_TOTALS
165203

@@ -171,6 +209,7 @@ def check(
171209

172210
invs, param_to_invs, vartype_to_invs, needed_data = sort_inv_file(invariants)
173211
TOTAL_INVARIANTS = len(invs)
212+
ALL_INVS = list(invs)
174213
RELATION_TOTALS = defaultdict(int)
175214
for inv in invs:
176215
RELATION_TOTALS[inv.relation.__name__] += 1
@@ -198,6 +237,13 @@ def check(
198237
else:
199238
break
200239

240+
step = trace_record.get("meta_vars.step")
241+
stage = trace_record.get("meta_vars.stage")
242+
if step is not None:
243+
CURRENT_STEP = step
244+
if stage is not None:
245+
CURRENT_STAGE = stage
246+
201247
if "var_name" in trace_record and trace_record["var_name"] is not None:
202248
varid = VarInstId(
203249
trace_record["process_id"],
@@ -216,6 +262,7 @@ def check(
216262
result = inv.online_check(
217263
trace_record, checker_data, check_relation_first
218264
)
265+
TRIGGERED_INV.add(inv)
219266
if not result.check_passed:
220267
violated_pair = get_violated_pair_hash(result.trace)
221268
if inv not in violated_pairs:
@@ -227,6 +274,9 @@ def check(
227274
if inv not in FAILED_INV:
228275
FAILED_INV[inv] = 0
229276
FAILED_INV[inv] += 1
277+
_record_violation_details(
278+
inv, result, VIOLATION_DETAILS
279+
)
230280
NUM_VIOLATIONS += 1
231281
result.set_id_and_detection_time(
232282
NUM_VIOLATIONS, time.monotonic_ns()
@@ -258,16 +308,18 @@ def check(
258308
result = inv.online_check(
259309
trace_record, checker_data, check_relation_first
260310
)
311+
TRIGGERED_INV.add(inv)
261312
if not result.check_passed:
262313
if inv not in FAILED_INV:
263314
FAILED_INV[inv] = 0
264315
FAILED_INV[inv] += 1
316+
_record_violation_details(inv, result, VIOLATION_DETAILS)
265317
NUM_VIOLATIONS += 1
266318
result.set_id_and_detection_time(
267319
NUM_VIOLATIONS, time.monotonic_ns()
268320
)
269321
logger.error(
270-
f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv} violated near time {trace_record['time']}"
322+
f"Violated id {NUM_VIOLATIONS}:\nInvariant {inv.text_description} violated near time {trace_record['time'], trace_record['meta_vars.step']}"
271323
)
272324
with open(output_file, "a") as f:
273325
json.dump(
@@ -279,6 +331,7 @@ def check(
279331
logger.error(
280332
f"Error when checking invariant {inv.text_description} with trace {trace_record}: {e}"
281333
)
334+
raise e
282335

283336
_emit_report()
284337

0 commit comments

Comments
 (0)