@@ -41,6 +41,17 @@ def _build_violation_entry(result: CheckerResult) -> dict:
4141 }
4242
4343
44+ def _build_violation_steps_map (results : list [CheckerResult ]) -> dict [int , int ]:
45+ """Map step → count of distinct invariants violated at that step."""
46+ step_to_invs : dict [int , set [str ]] = defaultdict (set )
47+ for res in results :
48+ if not res .check_passed :
49+ label = _format_invariant_label (res .invariant )
50+ for step in _extract_violation_steps (res .trace ):
51+ step_to_invs [step ].add (label )
52+ return {step : len (invs ) for step , invs in step_to_invs .items ()}
53+
54+
4455def build_violations_summary (results : list [CheckerResult ]) -> dict :
4556 """Build a pre-digested summary of all violations for machine and human consumption."""
4657 failed = [r for r in results if not r .check_passed ]
@@ -191,6 +202,7 @@ def build_offline_report_data(
191202 all_failed_invariants .extend ([res for res in results if not res .check_passed ])
192203
193204 top_violations = _count_failed_invariants (all_failed_invariants )
205+ violation_steps_map = _build_violation_steps_map (all_failed_invariants )
194206
195207 return {
196208 "mode" : "offline" ,
@@ -200,6 +212,7 @@ def build_offline_report_data(
200212 "relations" : dict (overall_relation_counts ),
201213 "traces" : trace_sections ,
202214 "top_violations" : top_violations ,
215+ "violation_steps_map" : violation_steps_map ,
203216 }
204217
205218
@@ -341,6 +354,14 @@ def _make_entry(inv: Invariant, count: int) -> dict:
341354 "violation_rate" : viol_rate ,
342355 }
343356
357+ # Build step → distinct-invariant count map from all violation_details
358+ step_to_invs : dict [int , set [str ]] = defaultdict (set )
359+ for inv , detail in violation_details .items ():
360+ lbl = _format_invariant_label (inv )
361+ for step , _ in detail .get ("step_stages" ) or []:
362+ step_to_invs [step ].add (lbl )
363+ violation_steps_map = {step : len (invs ) for step , invs in step_to_invs .items ()}
364+
344365 # Sort by first violation step (earliest first), then by count descending.
345366 def _sort_key (item ):
346367 inv , count = item
@@ -404,6 +425,7 @@ def _sort_key(item):
404425 "sampling_interval" : sampling_interval ,
405426 "warm_up_steps" : warm_up_steps ,
406427 "checked_steps" : checked_steps ,
428+ "violation_steps_map" : violation_steps_map ,
407429 }
408430
409431
@@ -1397,14 +1419,19 @@ def _log_wandb(
13971419 return
13981420
13991421 if self ._wandb_run is None :
1400- self . _wandb_run = wandb . init (
1422+ init_kwargs : dict = dict (
14011423 project = args .wandb_project ,
14021424 entity = args .wandb_entity ,
14031425 name = args .wandb_run_name ,
14041426 group = args .wandb_group ,
14051427 tags = args .wandb_tags ,
14061428 job_type = "checker" ,
14071429 )
1430+ run_id = getattr (args , "wandb_run_id" , None )
1431+ if run_id :
1432+ init_kwargs ["id" ] = run_id
1433+ init_kwargs ["resume" ] = "allow"
1434+ self ._wandb_run = wandb .init (** init_kwargs ) # type: ignore[assignment]
14081435 run = self ._wandb_run
14091436 if run is None :
14101437 logging .getLogger (__name__ ).warning ("wandb.init() returned None; skipping." )
@@ -1530,6 +1557,11 @@ def _log_wandb(
15301557 "Failed to attach HTML report to wandb run."
15311558 )
15321559
1560+ # --- per-step violation time-series (overlays with training loss curve) ---
1561+ violation_steps_map : dict [int , int ] = report_data .get ("violation_steps_map" , {})
1562+ for step , count in sorted (violation_steps_map .items ()):
1563+ wandb .log ({"traincheck/violations" : count }, step = step )
1564+
15331565 def _log_mlflow (
15341566 self ,
15351567 report_data : dict ,
@@ -1601,11 +1633,24 @@ def _log_mlflow(
16011633 mlflow .log_metric ("violations_last_step" , max (last_steps ))
16021634 mlflow .log_metric ("violations_distinct_invariants" , len (top_violations ))
16031635
1604- # --- violations table as JSON artifact ---
1636+ # --- per-step violation time-series (overlays with training loss curve) ---
1637+ violation_steps_map : dict [int , int ] = report_data .get ("violation_steps_map" , {})
1638+ for step , count in sorted (violation_steps_map .items ()):
1639+ mlflow .log_metric ("traincheck_violations" , count , step = step )
1640+
1641+ # --- violations table (mlflow.log_table for proper UI rendering) ---
16051642 if top_violations :
16061643 try :
1607- mlflow .log_dict (
1608- {"violations" : top_violations },
1644+ mlflow .log_table (
1645+ data = {
1646+ "invariant" : [v .get ("label" , "" ) for v in top_violations ],
1647+ "relation_type" : [
1648+ v .get ("relation" , "" ) for v in top_violations
1649+ ],
1650+ "occurrences" : [v .get ("count" , 0 ) for v in top_violations ],
1651+ "first_step" : [v .get ("first_step" ) for v in top_violations ],
1652+ "last_step" : [v .get ("last_step" ) for v in top_violations ],
1653+ },
16091654 artifact_file = "violations.json" ,
16101655 )
16111656 except Exception :
0 commit comments