Skip to content

Commit 02a3a33

Browse files
Essozclaude
andcommitted
fix: strip _TRAINCHECK_ prefix when displaying internal tensor-tracking attrs
Add _display_attr_name() helper that maps '_TRAINCHECK_grad_ID' -> 'grad' etc. Use it in APIContainRelation.to_display_name (removing the return-None guard) and ConsistencyRelation.to_display_name. Remove the now-unnecessary [internal tracking] fallback from _format_invariant_label. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 95fb115 commit 02a3a33

4 files changed

Lines changed: 18 additions & 22 deletions

File tree

traincheck/invariant/base_cls.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,6 +1944,17 @@ def _short_api_name(full_name: str) -> str:
19441944
return ".".join(parts[-2:]) if len(parts) >= 2 else full_name
19451945

19461946

1947+
def _display_attr_name(attr_name: str) -> str:
1948+
"""Strip TrainCheck-internal proxy bookkeeping prefix/suffix for display.
1949+
1950+
'_TRAINCHECK_grad_ID' → 'grad'
1951+
'dtype' → 'dtype' (unchanged)
1952+
"""
1953+
if attr_name.startswith("_TRAINCHECK_") and attr_name.endswith("_ID"):
1954+
return attr_name[len("_TRAINCHECK_") : -len("_ID")]
1955+
return attr_name
1956+
1957+
19471958
def read_inv_file(file_path: str | list[str]) -> list[Invariant]:
19481959
if isinstance(file_path, str):
19491960
file_path = [file_path]

traincheck/invariant/consistency_relation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Param,
1515
Relation,
1616
VarTypeParam,
17+
_display_attr_name,
1718
)
1819
from traincheck.invariant.precondition import find_precondition
1920
from traincheck.onlinechecker.utils import Checker_data, set_meta_vars_online
@@ -118,7 +119,7 @@ def to_display_name(params: list[Param]) -> str | None:
118119
if not isinstance(p, VarTypeParam):
119120
return None
120121
var_short = p.var_type.split(".")[-1]
121-
attr = p.attr_name
122+
attr = _display_attr_name(p.attr_name)
122123
return f"{var_short}.{attr} stays consistent across training steps"
123124

124125
@staticmethod

traincheck/invariant/contain_relation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Relation,
2424
VarNameParam,
2525
VarTypeParam,
26+
_display_attr_name,
2627
_short_api_name,
2728
calc_likelihood,
2829
construct_api_param,
@@ -358,10 +359,7 @@ def to_display_name(params: list[Param]) -> str | None:
358359
child_short = _short_api_name(child.api_full_name)
359360
return f"{parent_short}() always calls {child_short}()"
360361
if isinstance(child, (VarTypeParam, VarNameParam)):
361-
attr = child.attr_name
362-
# Skip internal TrainCheck proxy bookkeeping attributes
363-
if attr.startswith("_TRAINCHECK_"):
364-
return None
362+
attr = _display_attr_name(child.attr_name)
365363
var_short = child.var_type.split(".")[-1]
366364
pre = child.pre_value
367365
post = child.post_value

traincheck/reporting/checker_report.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,12 @@
77
from typing import Iterable
88

99
from traincheck.invariant import CheckerResult, Invariant
10-
from traincheck.invariant.base_cls import APIParam, VarNameParam, VarTypeParam
1110

1211

1312
def _format_invariant_label(invariant: Invariant) -> str:
1413
display = invariant.relation.to_display_name(invariant.params)
1514
if display:
1615
return display
17-
# When to_display_name returns None, fall back — but sanitize params that
18-
# contain internal TrainCheck proxy bookkeeping names (_TRAINCHECK_*) so
19-
# they never surface raw in the UI.
20-
for p in invariant.params:
21-
if isinstance(p, (VarTypeParam, VarNameParam)) and p.attr_name.startswith(
22-
"_TRAINCHECK_"
23-
):
24-
# Build a minimal label using only the API param, hiding internals
25-
api_parts = [q for q in invariant.params if isinstance(q, APIParam)]
26-
from traincheck.invariant.base_cls import _short_api_name
27-
28-
if api_parts:
29-
func = _short_api_name(api_parts[0].api_full_name)
30-
return f"{func}() [internal tracking]"
31-
return f"{invariant.relation.__name__} [internal tracking]"
3216
if invariant.text_description:
3317
return invariant.text_description
3418
params = ", ".join(str(param) for param in invariant.params)
@@ -290,7 +274,9 @@ def build_online_report_data(
290274
and warm_up_steps is not None
291275
and sampling_interval > 0
292276
):
293-
checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + 1
277+
checked_steps = max(0, current_step - warm_up_steps) // sampling_interval + min(
278+
current_step, warm_up_steps
279+
)
294280

295281
def _make_entry(inv: Invariant, count: int) -> dict:
296282
detail = violation_details.get(inv, {})

0 commit comments

Comments
 (0)