Skip to content

Commit 6c4caf4

Browse files
Essozclaude
andcommitted
feat: log per-step violation counts to W&B and MLflow
- Add _build_violation_steps_map() helper: step → count of distinct invariants violated at that step (across all failed CheckerResults) - Propagate violation_steps_map through build_offline_report_data and build_online_report_data so downstream loggers can consume it - W&B: log traincheck/violations as a metric at each step via wandb.log({...}, step=N) so violations appear on the same x-axis as training loss; add --wandb-run-id CLI arg to attach to an existing run - MLflow: log traincheck_violations per step via mlflow.log_metric(step=N); switch violations table from log_dict to log_table() for proper UI Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 9cf8064 commit 6c4caf4

3 files changed

Lines changed: 61 additions & 4 deletions

File tree

traincheck/checker.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ def main():
170170
default=None,
171171
help="Weights & Biases tags.",
172172
)
173+
parser.add_argument(
174+
"--wandb-run-id",
175+
type=str,
176+
default=None,
177+
help="Attach to an existing Weights & Biases run ID (e.g. to overlay violation metrics on a training run).",
178+
)
173179
parser.add_argument(
174180
"--report-mlflow",
175181
action="store_true",

traincheck/checker_online.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,12 @@ def main():
477477
default=None,
478478
help="Weights & Biases tags.",
479479
)
480+
parser.add_argument(
481+
"--wandb-run-id",
482+
type=str,
483+
default=None,
484+
help="Attach to an existing Weights & Biases run ID (e.g. to overlay violation metrics on a training run).",
485+
)
480486
parser.add_argument(
481487
"--report-mlflow",
482488
action="store_true",

traincheck/reporting/checker_report.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ def _build_violation_entry(result: CheckerResult) -> dict:
4141
}
4242

4343

44+
def _build_violation_steps_map(results: list[CheckerResult]) -> dict[int, int]:
45+
"""Map step → count of distinct invariants violated at that step."""
46+
step_to_invs: dict[int, set[str]] = defaultdict(set)
47+
for res in results:
48+
if not res.check_passed:
49+
label = _format_invariant_label(res.invariant)
50+
for step in _extract_violation_steps(res.trace):
51+
step_to_invs[step].add(label)
52+
return {step: len(invs) for step, invs in step_to_invs.items()}
53+
54+
4455
def build_violations_summary(results: list[CheckerResult]) -> dict:
4556
"""Build a pre-digested summary of all violations for machine and human consumption."""
4657
failed = [r for r in results if not r.check_passed]
@@ -191,6 +202,7 @@ def build_offline_report_data(
191202
all_failed_invariants.extend([res for res in results if not res.check_passed])
192203

193204
top_violations = _count_failed_invariants(all_failed_invariants)
205+
violation_steps_map = _build_violation_steps_map(all_failed_invariants)
194206

195207
return {
196208
"mode": "offline",
@@ -200,6 +212,7 @@ def build_offline_report_data(
200212
"relations": dict(overall_relation_counts),
201213
"traces": trace_sections,
202214
"top_violations": top_violations,
215+
"violation_steps_map": violation_steps_map,
203216
}
204217

205218

@@ -341,6 +354,14 @@ def _make_entry(inv: Invariant, count: int) -> dict:
341354
"violation_rate": viol_rate,
342355
}
343356

357+
# Build step → distinct-invariant count map from all violation_details
358+
step_to_invs: dict[int, set[str]] = defaultdict(set)
359+
for inv, detail in violation_details.items():
360+
lbl = _format_invariant_label(inv)
361+
for step, _ in detail.get("step_stages") or []:
362+
step_to_invs[step].add(lbl)
363+
violation_steps_map = {step: len(invs) for step, invs in step_to_invs.items()}
364+
344365
# Sort by first violation step (earliest first), then by count descending.
345366
def _sort_key(item):
346367
inv, count = item
@@ -404,6 +425,7 @@ def _sort_key(item):
404425
"sampling_interval": sampling_interval,
405426
"warm_up_steps": warm_up_steps,
406427
"checked_steps": checked_steps,
428+
"violation_steps_map": violation_steps_map,
407429
}
408430

409431

@@ -1397,14 +1419,19 @@ def _log_wandb(
13971419
return
13981420

13991421
if self._wandb_run is None:
1400-
self._wandb_run = wandb.init(
1422+
init_kwargs: dict = dict(
14011423
project=args.wandb_project,
14021424
entity=args.wandb_entity,
14031425
name=args.wandb_run_name,
14041426
group=args.wandb_group,
14051427
tags=args.wandb_tags,
14061428
job_type="checker",
14071429
)
1430+
run_id = getattr(args, "wandb_run_id", None)
1431+
if run_id:
1432+
init_kwargs["id"] = run_id
1433+
init_kwargs["resume"] = "allow"
1434+
self._wandb_run = wandb.init(**init_kwargs) # type: ignore[assignment]
14081435
run = self._wandb_run
14091436
if run is None:
14101437
logging.getLogger(__name__).warning("wandb.init() returned None; skipping.")
@@ -1530,6 +1557,11 @@ def _log_wandb(
15301557
"Failed to attach HTML report to wandb run."
15311558
)
15321559

1560+
# --- per-step violation time-series (overlays with training loss curve) ---
1561+
violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {})
1562+
for step, count in sorted(violation_steps_map.items()):
1563+
wandb.log({"traincheck/violations": count}, step=step)
1564+
15331565
def _log_mlflow(
15341566
self,
15351567
report_data: dict,
@@ -1601,11 +1633,24 @@ def _log_mlflow(
16011633
mlflow.log_metric("violations_last_step", max(last_steps))
16021634
mlflow.log_metric("violations_distinct_invariants", len(top_violations))
16031635

1604-
# --- violations table as JSON artifact ---
1636+
# --- per-step violation time-series (overlays with training loss curve) ---
1637+
violation_steps_map: dict[int, int] = report_data.get("violation_steps_map", {})
1638+
for step, count in sorted(violation_steps_map.items()):
1639+
mlflow.log_metric("traincheck_violations", count, step=step)
1640+
1641+
# --- violations table (mlflow.log_table for proper UI rendering) ---
16051642
if top_violations:
16061643
try:
1607-
mlflow.log_dict(
1608-
{"violations": top_violations},
1644+
mlflow.log_table(
1645+
data={
1646+
"invariant": [v.get("label", "") for v in top_violations],
1647+
"relation_type": [
1648+
v.get("relation", "") for v in top_violations
1649+
],
1650+
"occurrences": [v.get("count", 0) for v in top_violations],
1651+
"first_step": [v.get("first_step") for v in top_violations],
1652+
"last_step": [v.get("last_step") for v in top_violations],
1653+
},
16091654
artifact_file="violations.json",
16101655
)
16111656
except Exception:

0 commit comments

Comments
 (0)