@@ -125,10 +125,10 @@ def _get_loop_context(self, node):
125125
126126 if iter_name :
127127 if "train" in iter_name :
128- print (f"Found training loop based on iterator: { iter_name } " )
128+ logger . debug (f"Found training loop based on iterator: { iter_name } " )
129129 return "training"
130130 elif any (x in iter_name for x in ["val" , "eval" , "test" ]):
131- print (f"Found eval loop based on iterator: { iter_name } " )
131+ logger . debug (f"Found eval loop based on iterator: { iter_name } " )
132132 return "eval"
133133
134134 # Heuristic 2: Check for calls to .step() or .backward() or .eval()
@@ -187,25 +187,25 @@ def _get_loop_context(self, node):
187187 if isinstance (expr .func , ast .Attribute ):
188188 if expr .func .attr == "no_grad" :
189189 has_eval_signal = True
190- print (f"Found no_grad context in loop { node } ." )
190+ logger . debug (f"Found no_grad context in loop { node } ." )
191191
192192 if has_training_signal :
193- print (f"Found training signal in loop { node } ." )
193+ logger . debug (f"Found training signal in loop { node } ." )
194194 return "training"
195195
196196 if has_eval_signal :
197- print (f"Found eval signal in loop { node } ." )
197+ logger . debug (f"Found eval signal in loop { node } ." )
198198 return "eval"
199199
200200 # if the number of lines are too few and the function calls do not involve "eval", "train", we omit the loop context
201201 # We use statement_count calculated recursively
202202 if statement_count < 3 :
203- print (
203+ logger . debug (
204204 f"Skipping loop { node } as it is too short ({ statement_count } statements) and does not contain eval/train/step/backward signal."
205205 )
206206 return None
207207
208- print (f"Found eval signal in loop { node } (fallback)." )
208+ logger . debug (f"Found eval signal in loop { node } (fallback)." )
209209 return "eval"
210210
211211 def _inject_call (self , node , func_name ):
@@ -465,7 +465,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]:
465465 for node in ast .walk (root ):
466466 for child in ast .iter_child_nodes (node ):
467467 if child in parent_map and not ast .unparse (child ).strip () == "" :
468- print (
468+ logger . debug (
469469 f"Node { ast .unparse (child )} already has a parent, { ast .unparse (parent_map [child ])} "
470470 )
471471 parent_map [child ] = node
@@ -480,7 +480,7 @@ def instrument_all_model_assignments(
480480 Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement
481481 after each assignment, depending on the mode.
482482 """
483- print (
483+ logger . debug (
484484 f"Instrumenting model: { model_name } , mode: { mode } , scanning for assignments to { model_name } "
485485 )
486486
@@ -529,10 +529,10 @@ def instrument_all_model_assignments(
529529 if node in parent_map :
530530 parent = parent_map [node ]
531531 # print(f"Parent node: {ast.unparse(parent)}")
532- print ( " \t Instrumenting: " , ast .unparse (node ))
532+ logger . debug ( "Instrumenting: %s " , ast .unparse (node ))
533533 if isinstance (parent , ast .For ):
534- print (
535- "\t \t ⬆️ Parent is a for loop, cowardly skipping instrumentation in fear of multiple models with the same 'var_name'"
534+ logger . debug (
535+ "Parent is a for loop, skipping instrumentation to avoid multiple models with the same 'var_name'"
536536 )
537537 continue
538538 if node in parent .body : # type: ignore
@@ -601,25 +601,18 @@ def instrument_model_tracker_proxy(
601601spec = importlib.util.find_spec('traincheck')
602602if spec and spec.origin:
603603 traincheck_folder = os.path.dirname(spec.origin)
604- print("traincheck folder: ", traincheck_folder)
605604else:
606605 raise Exception("traincheck is not installed properly")
607- print("auto observer enabled with observing depth: ", auto_observer_config["enable_auto_observer_depth"])
608606enable_auto_observer_depth = auto_observer_config["enable_auto_observer_depth"]
609607neglect_hidden_func = auto_observer_config["neglect_hidden_func"]
610608neglect_hidden_module = auto_observer_config["neglect_hidden_module"]
611609observe_then_unproxy = auto_observer_config["observe_then_unproxy"]
612610observe_up_to_depth = auto_observer_config["observe_up_to_depth"]
613- if observe_up_to_depth:
614- print("observe up to the depth of the function call")
615- else:
616- print("observe only the function call at the depth")
617611from traincheck.static_analyzer.graph_generator.call_graph_parser import add_observer_given_call_graph
618612
619613log_files = glob.glob(
620614 os.path.join(traincheck_folder, "static_analyzer", "func_level", "*.log")
621615)
622- print("log_files: ", log_files)
623616for log_file in log_files:
624617 add_observer_given_call_graph(
625618 log_file,
@@ -1072,7 +1065,7 @@ def instrument_file(
10721065 if model_tracker_style == "proxy" or model_tracker_style == "subclass" :
10731066 if model_tracker_style == "subclass" :
10741067 # adjust the proxy config to disable the proxy-specific configs
1075- print (
1068+ logger . debug (
10761069 "Using subclass model tracker, overriding observe_then_unproxy to False"
10771070 )
10781071 adjusted_proxy_config [0 ]["observe_then_unproxy" ] = False
0 commit comments