@@ -91,6 +91,9 @@ def _count_failed_invariants(
9191) -> list [dict [str , object ]]:
9292 counter : Counter [tuple [str , str ]] = Counter ()
9393 first_steps : dict [tuple [str , str ], int | None ] = {}
94+ last_steps : dict [tuple [str , str ], int | None ] = {}
95+ step_stage_maps : dict [tuple [str , str ], dict ] = defaultdict (dict )
96+ sample_traces : dict [tuple [str , str ], list ] = {}
9497 for res in results :
9598 if not res .check_passed :
9699 label = _format_invariant_label (res .invariant )
@@ -103,15 +106,36 @@ def _count_failed_invariants(
103106 first_steps [key ] = (
104107 min (steps ) if existing is None else min (existing , min (steps ))
105108 )
109+ existing_last = last_steps .get (key )
110+ last_steps [key ] = (
111+ max (steps )
112+ if existing_last is None
113+ else max (existing_last , max (steps ))
114+ )
106115 elif key not in first_steps :
107116 first_steps [key ] = None
117+ last_steps [key ] = None
118+ # Accumulate step → stage (first stage seen per step wins)
119+ for rec in res .trace or []:
120+ if not isinstance (rec , dict ):
121+ continue
122+ step = rec .get ("meta_vars.step" )
123+ stage = rec .get ("meta_vars.stage" )
124+ if step is not None and step not in step_stage_maps [key ]:
125+ step_stage_maps [key ][step ] = stage
126+ # One sample trace per invariant (first violation wins)
127+ if key not in sample_traces and res .trace :
128+ sample_traces [key ] = _summarize_trace_records (res .trace )
108129 top_pairs = counter .most_common (10 )
109130 return [
110131 {
111132 "label" : label ,
112133 "relation" : relation ,
113134 "count" : count ,
114135 "first_step" : first_steps .get ((label , relation )),
136+ "last_step" : last_steps .get ((label , relation )),
137+ "step_stages" : sorted (step_stage_maps [(label , relation )].items ()),
138+ "sample_trace" : sample_traces .get ((label , relation ), []),
115139 }
116140 for (label , relation ), count in top_pairs
117141 ]
@@ -434,11 +458,7 @@ def percent(part: int, total: int) -> float:
434458 traces = report_data .get ("traces" , [])
435459 top_violations = report_data .get ("top_violations" , [])
436460
437- # Build top violations HTML differently per mode.
438- # For online mode: step-sorted table with expandable trace rows.
439- # For offline mode: simple list (unchanged).
440- top_list = "" # used only in offline mode below
441- top_table_html = "" # used only in online mode below
461+ top_table_html = ""
442462
443463 if mode == "online" :
444464 sampling_interval = report_data .get ("sampling_interval" )
@@ -538,27 +558,77 @@ def _step_with_badge(step, stage) -> str:
538558 else "<p>No violations yet.</p>"
539559 )
540560 else :
541- top_items = []
561+ rows = []
542562 for entry in top_violations :
543563 label = esc (str (entry .get ("label" , "" )))
544564 relation = esc (str (entry .get ("relation" , "" )))
545- count = entry .get ("count" )
565+ count = entry .get ("count" , "" )
546566 first_step = entry .get ("first_step" )
547- count_html = f'<span class="inv-count">{ count } </span>' if count else ""
548- if first_step is not None :
549- step_note = f"first seen at step { first_step } "
550- if count and count > 1 :
551- step_note += f" · { count } occurrences"
552- detail = esc (f"{ entry .get ('relation' , '' )} — { step_note } " )
567+ last_step = entry .get ("last_step" )
568+ off_step_stages : list = entry .get ("step_stages" ) or []
569+ off_sample_trace = entry .get ("sample_trace" ) or []
570+
571+ def _step_cell (step , _ss = off_step_stages ) -> str :
572+ if step is None :
573+ return "—"
574+ stage = next ((s for st , s in _ss if st == step ), None )
575+ badge = _render_stage_badge (stage , esc )
576+ return f"{ badge } { step } "
577+
578+ first_step_html = _step_cell (first_step )
579+ last_step_html = _step_cell (last_step )
580+ steps_html = _render_step_stages_html (off_step_stages , esc )
581+
582+ if off_sample_trace :
583+ off_keys : list [str ] = []
584+ for rec in off_sample_trace :
585+ for k in rec :
586+ if k not in off_keys :
587+ off_keys .append (k )
588+ trace_head = "" .join (f"<th>{ esc (k )} </th>" for k in off_keys )
589+ trace_rows_html = []
590+ for rec in off_sample_trace :
591+ cells = []
592+ for k in off_keys :
593+ val = rec .get (k , "" )
594+ if k == "meta_vars.stage" and val :
595+ style = _stage_badge_style (val )
596+ cell = (
597+ f'<td><span class="stage-badge" style="{ style } ">'
598+ f"{ esc (val )} </span></td>"
599+ )
600+ else :
601+ cell = f"<td>{ esc (str (val ))} </td>"
602+ cells .append (cell )
603+ trace_rows_html .append (f"<tr>{ '' .join (cells )} </tr>" )
604+ trace_body = "\n " .join (trace_rows_html )
605+ expand_content = (
606+ f'<div class="trace-steps">Steps: { steps_html } </div>'
607+ f'<div class="trace-wrap"><table class="table trace-table">'
608+ f"<thead><tr>{ trace_head } </tr></thead>"
609+ f"<tbody>{ trace_body } </tbody></table></div>"
610+ )
553611 else :
554- detail = relation
555- if count and count > 1 :
556- detail = esc (f"{ entry .get ('relation' , '' )} — { count } occurrences" )
557- top_items .append (
558- f'<li><span class="inv-label">{ label } </span>'
559- f'<span class="inv-detail">{ detail } </span>{ count_html } </li>'
612+ expand_content = f'<div class="trace-steps">Steps: { steps_html } </div>'
613+
614+ rows .append (
615+ f"<tr>"
616+ f'<td><details><summary class="inv-label-summary">{ label } </summary>'
617+ f'<div class="expand-body">{ expand_content } </div></details>'
618+ f'<span class="inv-rel-tag">{ relation } </span></td>'
619+ f'<td class="step-cell">{ first_step_html } </td>'
620+ f'<td class="step-cell">{ last_step_html } </td>'
621+ f'<td class="freq-cell"><span class="freq-rate">{ count } </span></td>'
622+ f"</tr>"
560623 )
561- top_list = "" .join (top_items ) or "<li>None</li>"
624+ top_table_html = (
625+ f'<table class="table viol-table"><thead>'
626+ f"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th>"
627+ f"<th>Count</th></tr>"
628+ f"</thead><tbody>{ '' .join (rows )} </tbody></table>"
629+ if rows
630+ else "<p>No violations.</p>"
631+ )
562632
563633 trace_sections = []
564634 for trace in traces :
@@ -573,27 +643,74 @@ def _step_with_badge(step, stage) -> str:
573643 + _render_bar_segment (percent (not_triggered , total ), "bar-not-triggered" )
574644 )
575645
576- failed_list_items = []
646+ failed_rows = []
577647 for failed_item in trace ["failed_invariants" ][:10 ]:
578648 label = esc (str (failed_item .get ("label" , "" )))
579649 relation = esc (str (failed_item .get ("relation" , "" )))
580- count = failed_item .get ("count" )
650+ count = failed_item .get ("count" , "" )
581651 first_step = failed_item .get ("first_step" )
582- count_html = f'<span class="inv-count">{ count } </span>' if count else ""
583- if first_step is not None :
584- step_note = f"first seen at step { first_step } "
585- if count and count > 1 :
586- step_note += f" · { count } occurrences"
587- detail = esc (f"{ relation } — { step_note } " )
652+ last_step = failed_item .get ("last_step" )
653+ item_step_stages : list = failed_item .get ("step_stages" ) or []
654+ item_sample_trace = failed_item .get ("sample_trace" ) or []
655+
656+ def _step_cell_trace (step ) -> str :
657+ if step is None :
658+ return "—"
659+ stage = next ((s for st , s in item_step_stages if st == step ), None )
660+ badge = _render_stage_badge (stage , esc )
661+ return f"{ badge } { step } "
662+
663+ steps_html = _render_step_stages_html (item_step_stages , esc )
664+ if item_sample_trace :
665+ item_keys : list [str ] = []
666+ for rec in item_sample_trace :
667+ for k in rec :
668+ if k not in item_keys :
669+ item_keys .append (k )
670+ trace_head = "" .join (f"<th>{ esc (k )} </th>" for k in item_keys )
671+ trace_rows_html = []
672+ for rec in item_sample_trace :
673+ cells = []
674+ for k in item_keys :
675+ val = rec .get (k , "" )
676+ if k == "meta_vars.stage" and val :
677+ style = _stage_badge_style (val )
678+ cell = (
679+ f'<td><span class="stage-badge" style="{ style } ">'
680+ f"{ esc (val )} </span></td>"
681+ )
682+ else :
683+ cell = f"<td>{ esc (str (val ))} </td>"
684+ cells .append (cell )
685+ trace_rows_html .append (f"<tr>{ '' .join (cells )} </tr>" )
686+ trace_body = "\n " .join (trace_rows_html )
687+ expand_content = (
688+ f'<div class="trace-steps">Steps: { steps_html } </div>'
689+ f'<div class="trace-wrap"><table class="table trace-table">'
690+ f"<thead><tr>{ trace_head } </tr></thead>"
691+ f"<tbody>{ trace_body } </tbody></table></div>"
692+ )
588693 else :
589- detail = relation
590- if count and count > 1 :
591- detail = esc (f"{ relation } — { count } occurrences" )
592- failed_list_items .append (
593- f'<li><span class="inv-label">{ label } </span>'
594- f'<span class="inv-detail">{ detail } </span>{ count_html } </li>'
694+ expand_content = f'<div class="trace-steps">Steps: { steps_html } </div>'
695+
696+ failed_rows .append (
697+ f"<tr>"
698+ f'<td><details><summary class="inv-label-summary">{ label } </summary>'
699+ f'<div class="expand-body">{ expand_content } </div></details>'
700+ f'<span class="inv-rel-tag">{ relation } </span></td>'
701+ f'<td class="step-cell">{ _step_cell_trace (first_step )} </td>'
702+ f'<td class="step-cell">{ _step_cell_trace (last_step )} </td>'
703+ f'<td class="freq-cell"><span class="freq-rate">{ count } </span></td>'
704+ f"</tr>"
595705 )
596- failed_list_html = "" .join (failed_list_items ) or "<li>None</li>"
706+ failed_list_html = (
707+ f'<table class="table viol-table"><thead>'
708+ f"<tr><th>Invariant</th><th>First Step</th><th>Last Step</th>"
709+ f"<th>Count</th></tr>"
710+ f"</thead><tbody>{ '' .join (failed_rows )} </tbody></table>"
711+ if failed_rows
712+ else "<p>None</p>"
713+ )
597714
598715 relation_rows = []
599716 for relation_name , rel_counts in sorted (trace ["relations" ].items ()):
@@ -630,7 +747,7 @@ def _step_with_badge(step, stage) -> str:
630747 <div class="grid-two">
631748 <div>
632749 <h3>Failed invariants (top 10)</h3>
633- <ul class="inv-list"> { failed_list_html } </ul>
750+ { failed_list_html }
634751 </div>
635752 <div>
636753 <h3>Relation breakdown</h3>
@@ -837,8 +954,8 @@ def _step_with_badge(step, stage) -> str:
837954 )
838955 panel_content = top_table_html
839956 else :
840- panel_subtitle = "Most frequent violations observed "
841- panel_content = f'<ul class="inv-list"> { top_list } </ul>'
957+ panel_subtitle = "Sorted by first violation step — click to expand trace "
958+ panel_content = top_table_html
842959
843960 top_panel = f"""
844961 <section class="panel">
@@ -1353,23 +1470,35 @@ def _log_wandb(
13531470 top_violations = report_data .get ("top_violations" , [])
13541471 if top_violations :
13551472 vtable = wandb .Table (
1356- columns = ["invariant" , "relation_type" , "occurrences" , "first_step" ]
1473+ columns = [
1474+ "invariant" ,
1475+ "relation_type" ,
1476+ "occurrences" ,
1477+ "first_step" ,
1478+ "last_step" ,
1479+ ]
13571480 )
13581481 for v in top_violations :
13591482 vtable .add_data (
13601483 v .get ("label" , "" ),
13611484 v .get ("relation" , "" ),
13621485 v .get ("count" , 0 ),
13631486 v .get ("first_step" ),
1487+ v .get ("last_step" ),
13641488 )
13651489 wandb .log ({"violations" : vtable })
13661490
13671491 # --- summary metrics (shown in run comparison table) ---
13681492 first_steps = [
13691493 v ["first_step" ] for v in top_violations if v .get ("first_step" ) is not None
13701494 ]
1495+ last_steps_wandb = [
1496+ v ["last_step" ] for v in top_violations if v .get ("last_step" ) is not None
1497+ ]
13711498 if first_steps :
13721499 run .summary ["violations/first_step" ] = min (first_steps )
1500+ if last_steps_wandb :
1501+ run .summary ["violations/last_step" ] = max (last_steps_wandb )
13731502 run .summary ["violations/distinct_invariants" ] = len (top_violations )
13741503
13751504 # --- violations_summary.json as versioned artifact ---
@@ -1463,8 +1592,13 @@ def _log_mlflow(
14631592 first_steps = [
14641593 v ["first_step" ] for v in top_violations if v .get ("first_step" ) is not None
14651594 ]
1595+ last_steps = [
1596+ v ["last_step" ] for v in top_violations if v .get ("last_step" ) is not None
1597+ ]
14661598 if first_steps :
14671599 mlflow .log_metric ("violations_first_step" , min (first_steps ))
1600+ if last_steps :
1601+ mlflow .log_metric ("violations_last_step" , max (last_steps ))
14681602 mlflow .log_metric ("violations_distinct_invariants" , len (top_violations ))
14691603
14701604 # --- violations table as JSON artifact ---
0 commit comments