Skip to content

Commit 3c820a3

Browse files
Essoz and claude committed
fix: demote instrumentor internal prints to logger.debug
Eliminates per-step stdout noise from control.py (Warmup/Interval/Skipping step printed every training step), shutdown messages from dumper.py, AST loop/model detection messages from source_file.py, and proxy parameter setup messages from proxy.py. All demoted to logger.debug() so they remain accessible with the -d flag but don't clutter normal traincheck-collect output. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 242959c commit 3c820a3

4 files changed

Lines changed: 28 additions & 34 deletions

File tree

traincheck/instrumentor/control.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@ def start_step():
3232
config.DISABLE_WRAPPER = False
3333

3434
if current_step < warm_up:
35-
print(f"Warmup step {current_step}")
35+
logger.debug(f"Warmup step {current_step}")
3636
config.DISABLE_WRAPPER = False
3737
elif (current_step - warm_up) % interval == 0:
38-
print(f"Interval step {current_step}")
38+
logger.debug(f"Interval step {current_step}")
3939
config.DISABLE_WRAPPER = False
4040
else:
41-
print(f"Skipping step {current_step}")
41+
logger.debug(f"Skipping step {current_step}")
4242
config.DISABLE_WRAPPER = True
4343
else:
4444
# No policy, always enable
45-
print("No policy, always enable")
45+
logger.debug("No policy, always enable")
4646
config.DISABLE_WRAPPER = False
4747

4848

@@ -65,13 +65,13 @@ def start_eval_step():
6565
config.DISABLE_WRAPPER = False
6666

6767
if current_step < warm_up:
68-
print(f"Eval: Warmup step {current_step}")
68+
logger.debug(f"Eval: Warmup step {current_step}")
6969
config.DISABLE_WRAPPER = False
7070
elif (current_step - warm_up) % interval == 0:
71-
print(f"Eval: Interval step {current_step}")
71+
logger.debug(f"Eval: Interval step {current_step}")
7272
config.DISABLE_WRAPPER = False
7373
else:
74-
print(f"Eval: Skipping step {current_step}")
74+
logger.debug(f"Eval: Skipping step {current_step}")
7575
config.DISABLE_WRAPPER = True
7676
else:
7777
config.DISABLE_WRAPPER = False

traincheck/instrumentor/dumper.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ def serialize(obj_dict: dict[str, object | str]) -> str:
9292

9393
def monitor_main_thread(main_thread, stop_event):
9494
main_thread.join() # Wait for the main thread to finish
95-
print("Main thread has finished or encountered an exception")
96-
print("Flushing all buffers to the trace log file")
95+
logger.debug("Main thread has finished or encountered an exception")
96+
logger.debug("Flushing all buffers to the trace log file")
9797
stop_event.set() # Signal the logging threads to stop
9898

9999

@@ -106,12 +106,12 @@ def trace_dumper(task_queue: Queue, trace_file_name: str, stop_event: threading.
106106
) # wait for 2x the flush interval, this is an arbitrary number, as long as it is larger than the flush interval, it should be fine.
107107
except Empty:
108108
if stop_event.is_set():
109-
print("Trace dumper thread has stopped.")
109+
logger.debug("Trace dumper thread has stopped.")
110110
break
111111
continue
112112
f.write(f"{trace}\n")
113113
task_queue.task_done()
114-
print("Trace dumper thread has finished normally...")
114+
logger.debug("Trace dumper thread has finished normally.")
115115

116116

117117
def get_trace_API_dumper_queue():

traincheck/instrumentor/proxy_wrapper/proxy.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from .proxy_registry import get_global_registry
2222
from .utils import print_debug
2323

24+
logger = logging.getLogger(__name__)
25+
2426

2527
class ProxyObjInfo:
2628
def __init__(self, var_name: str, last_update_timestamp: int, version: int | None):
@@ -118,9 +120,8 @@ def proxy_parameters(module: torch.nn.Module, parent_name="", from_iter=False):
118120

119121
time_end = time.perf_counter()
120122
if num_params != 0:
121-
print(
122-
"logger_proxy: "
123-
+ f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time} seconds"
123+
logger.debug(
124+
f"Proxied {num_params} parameters of '{parent_name + module.__class__.__name__}', duration: {time_end - start_time:.3f}s"
124125
)
125126

126127
def update_timestamp(self):

traincheck/instrumentor/source_file.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,10 @@ def _get_loop_context(self, node):
125125

126126
if iter_name:
127127
if "train" in iter_name:
128-
print(f"Found training loop based on iterator: {iter_name}")
128+
logger.debug(f"Found training loop based on iterator: {iter_name}")
129129
return "training"
130130
elif any(x in iter_name for x in ["val", "eval", "test"]):
131-
print(f"Found eval loop based on iterator: {iter_name}")
131+
logger.debug(f"Found eval loop based on iterator: {iter_name}")
132132
return "eval"
133133

134134
# Heuristic 2: Check for calls to .step() or .backward() or .eval()
@@ -187,25 +187,25 @@ def _get_loop_context(self, node):
187187
if isinstance(expr.func, ast.Attribute):
188188
if expr.func.attr == "no_grad":
189189
has_eval_signal = True
190-
print(f"Found no_grad context in loop {node}.")
190+
logger.debug(f"Found no_grad context in loop {node}.")
191191

192192
if has_training_signal:
193-
print(f"Found training signal in loop {node}.")
193+
logger.debug(f"Found training signal in loop {node}.")
194194
return "training"
195195

196196
if has_eval_signal:
197-
print(f"Found eval signal in loop {node}.")
197+
logger.debug(f"Found eval signal in loop {node}.")
198198
return "eval"
199199

200200
# if the number of lines are too few and the function calls do not involve "eval", "train", we omit the loop context
201201
# We use statement_count calculated recursively
202202
if statement_count < 3:
203-
print(
203+
logger.debug(
204204
f"Skipping loop {node} as it is too short ({statement_count} statements) and does not contain eval/train/step/backward signal."
205205
)
206206
return None
207207

208-
print(f"Found eval signal in loop {node} (fallback).")
208+
logger.debug(f"Found eval signal in loop {node} (fallback).")
209209
return "eval"
210210

211211
def _inject_call(self, node, func_name):
@@ -465,7 +465,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]:
465465
for node in ast.walk(root):
466466
for child in ast.iter_child_nodes(node):
467467
if child in parent_map and not ast.unparse(child).strip() == "":
468-
print(
468+
logger.debug(
469469
f"Node {ast.unparse(child)} already has a parent, {ast.unparse(parent_map[child])}"
470470
)
471471
parent_map[child] = node
@@ -480,7 +480,7 @@ def instrument_all_model_assignments(
480480
Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement
481481
after each assignment, depending on the mode.
482482
"""
483-
print(
483+
logger.debug(
484484
f"Instrumenting model: {model_name}, mode: {mode}, scanning for assignments to {model_name}"
485485
)
486486

@@ -529,10 +529,10 @@ def instrument_all_model_assignments(
529529
if node in parent_map:
530530
parent = parent_map[node]
531531
# print(f"Parent node: {ast.unparse(parent)}")
532-
print("\tInstrumenting: ", ast.unparse(node))
532+
logger.debug("Instrumenting: %s", ast.unparse(node))
533533
if isinstance(parent, ast.For):
534-
print(
535-
"\t\t⬆️ Parent is a for loop, cowardly skipping instrumentation in fear of multiple models with the same 'var_name'"
534+
logger.debug(
535+
"Parent is a for loop, skipping instrumentation to avoid multiple models with the same 'var_name'"
536536
)
537537
continue
538538
if node in parent.body: # type: ignore
@@ -601,25 +601,18 @@ def instrument_model_tracker_proxy(
601601
spec = importlib.util.find_spec('traincheck')
602602
if spec and spec.origin:
603603
traincheck_folder = os.path.dirname(spec.origin)
604-
print("traincheck folder: ", traincheck_folder)
605604
else:
606605
raise Exception("traincheck is not installed properly")
607-
print("auto observer enabled with observing depth: ", auto_observer_config["enable_auto_observer_depth"])
608606
enable_auto_observer_depth = auto_observer_config["enable_auto_observer_depth"]
609607
neglect_hidden_func = auto_observer_config["neglect_hidden_func"]
610608
neglect_hidden_module = auto_observer_config["neglect_hidden_module"]
611609
observe_then_unproxy = auto_observer_config["observe_then_unproxy"]
612610
observe_up_to_depth = auto_observer_config["observe_up_to_depth"]
613-
if observe_up_to_depth:
614-
print("observe up to the depth of the function call")
615-
else:
616-
print("observe only the function call at the depth")
617611
from traincheck.static_analyzer.graph_generator.call_graph_parser import add_observer_given_call_graph
618612
619613
log_files = glob.glob(
620614
os.path.join(traincheck_folder, "static_analyzer", "func_level", "*.log")
621615
)
622-
print("log_files: ", log_files)
623616
for log_file in log_files:
624617
add_observer_given_call_graph(
625618
log_file,
@@ -1072,7 +1065,7 @@ def instrument_file(
10721065
if model_tracker_style == "proxy" or model_tracker_style == "subclass":
10731066
if model_tracker_style == "subclass":
10741067
# adjust the proxy config to disable the proxy-specific configs
1075-
print(
1068+
logger.debug(
10761069
"Using subclass model tracker, overriding observe_then_unproxy to False"
10771070
)
10781071
adjusted_proxy_config[0]["observe_then_unproxy"] = False

0 commit comments

Comments (0)