@@ -994,27 +994,22 @@ def apply(self, **kwargs):
994994 # Get items expected to be specialized
995995 specialize = as_tuple (kwargs .pop ('specialize' , []))
996996
997- # In the case of specialization, arguments must be processed before
998- # the operator can be compiled
999997 if specialize :
1000998 # FIXME: Cannot cope with things like sizes/strides yet since it only
1001999 # looks at the parameters
10021000
10031001 # Build the arguments list for specialization
1004- with self ._profiler .timer_on ('specialized-arguments-preprocess ' ):
1002+ with self ._profiler .timer_on ('specialization ' ):
10051003 args = self .arguments (** kwargs )
1006- with switch_log_level (comm = args .comm ):
1007- self ._emit_args_profiling ('specialized-arguments-preprocess' )
1004+ # Uses parameters here since Specializer needs {symbol: sympy value}
1005+ specialized_values = {p : sympify (args [p .name ])
1006+ for p in self .parameters if p .name in specialize }
10081007
1009- # Uses parameters here since Specializer needs {symbol: sympy value} mapper
1010- specialized_values = {p : sympify (args [p .name ])
1011- for p in self .parameters if p .name in specialize }
1008+ op = Specializer (specialized_values ).visit (self )
10121009
1013- op = Specializer (specialized_values ).visit (self )
1010+ with switch_log_level (comm = args .comm ):
1011+ self ._emit_args_profiling ('specialization' )
10141012
1015- # TODO: Does this cause problems for profilers?
1016- # FIXME: Need some way to inspect this Operator for testing
1017- # FIXME: Perhaps this should use some separate method
10181013 unspecialized_kwargs = {k : v for k , v in kwargs .items ()
10191014 if k not in specialize }
10201015
@@ -1030,9 +1025,7 @@ def apply(self, **kwargs):
10301025 with switch_log_level (comm = args .comm ):
10311026 self ._emit_args_profiling ('arguments-preprocess' )
10321027
1033- args_string = ", " .join ([f"{ p .name } ={ args [p .name ]} "
1034- for p in self .parameters if p .is_Symbol ])
1035- debug (f"Invoking `{ self .name } ` with scalar arguments: { args_string } " )
1028+ self ._emit_arguments (args )
10361029
10371030 # Invoke kernel function with args
10381031 arg_values = [args [p .name ] for p in self .parameters ]
@@ -1069,6 +1062,28 @@ def _emit_args_profiling(self, tag=''):
10691062 tagstr = ' ' .join (tag .split ('-' ))
10701063 debug (f"Operator `{ self .name } ` { tagstr } : { elapsed :.2f} s" )
10711064
1065+ def _emit_arguments (self , args ):
1066+ comm = args .comm
1067+ scalar_args = ", " .join ([f"{ p .name } ={ args [p .name ]} "
1068+ for p in self .parameters
1069+ if p .is_Symbol ])
1070+
1071+ rank = f"[rank{ args .comm .Get_rank ()} ] " if comm is not MPI .COMM_NULL else ""
1072+
1073+ msg = f"* { rank } { scalar_args } "
1074+
1075+ with switch_log_level (comm = comm ):
1076+ debug (f"Scalar arguments used to invoke `{ self .name } `" )
1077+
1078+ if comm is not MPI .COMM_NULL :
1079+ # With MPI enabled, we add one entry per rank
1080+ allmsg = comm .allgather (msg )
1081+ if comm .Get_rank () == 0 :
1082+ for m in allmsg :
1083+ debug (m )
1084+ else :
1085+ debug (msg )
1086+
10721087 def _emit_build_profiling (self ):
10731088 if not is_log_enabled_for ('PERF' ):
10741089 return
0 commit comments