Skip to content

Commit fe0835b

Browse files
committed
compiler: Emit arguments used to invoke kernels and add test for specialization with MPI
1 parent ca2e4af commit fe0835b

2 files changed

Lines changed: 36 additions & 18 deletions

File tree

devito/operator/operator.py

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -994,27 +994,22 @@ def apply(self, **kwargs):
994994
# Get items expected to be specialized
995995
specialize = as_tuple(kwargs.pop('specialize', []))
996996

997-
# In the case of specialization, arguments must be processed before
998-
# the operator can be compiled
999997
if specialize:
1000998
# FIXME: Cannot cope with things like sizes/strides yet since it only
1001999
# looks at the parameters
10021000

10031001
# Build the arguments list for specialization
1004-
with self._profiler.timer_on('specialized-arguments-preprocess'):
1002+
with self._profiler.timer_on('specialization'):
10051003
args = self.arguments(**kwargs)
1006-
with switch_log_level(comm=args.comm):
1007-
self._emit_args_profiling('specialized-arguments-preprocess')
1004+
# Uses parameters here since Specializer needs {symbol: sympy value}
1005+
specialized_values = {p: sympify(args[p.name])
1006+
for p in self.parameters if p.name in specialize}
10081007

1009-
# Uses parameters here since Specializer needs {symbol: sympy value} mapper
1010-
specialized_values = {p: sympify(args[p.name])
1011-
for p in self.parameters if p.name in specialize}
1008+
op = Specializer(specialized_values).visit(self)
10121009

1013-
op = Specializer(specialized_values).visit(self)
1010+
with switch_log_level(comm=args.comm):
1011+
self._emit_args_profiling('specialization')
10141012

1015-
# TODO: Does this cause problems for profilers?
1016-
# FIXME: Need some way to inspect this Operator for testing
1017-
# FIXME: Perhaps this should use some separate method
10181013
unspecialized_kwargs = {k: v for k, v in kwargs.items()
10191014
if k not in specialize}
10201015

@@ -1030,9 +1025,7 @@ def apply(self, **kwargs):
10301025
with switch_log_level(comm=args.comm):
10311026
self._emit_args_profiling('arguments-preprocess')
10321027

1033-
args_string = ", ".join([f"{p.name}={args[p.name]}"
1034-
for p in self.parameters if p.is_Symbol])
1035-
debug(f"Invoking `{self.name}` with scalar arguments: {args_string}")
1028+
self._emit_arguments(args)
10361029

10371030
# Invoke kernel function with args
10381031
arg_values = [args[p.name] for p in self.parameters]
@@ -1069,6 +1062,28 @@ def _emit_args_profiling(self, tag=''):
10691062
tagstr = ' '.join(tag.split('-'))
10701063
debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s")
10711064

1065+
def _emit_arguments(self, args):
1066+
comm = args.comm
1067+
scalar_args = ", ".join([f"{p.name}={args[p.name]}"
1068+
for p in self.parameters
1069+
if p.is_Symbol])
1070+
1071+
rank = f"[rank{args.comm.Get_rank()}] " if comm is not MPI.COMM_NULL else ""
1072+
1073+
msg = f"* {rank}{scalar_args}"
1074+
1075+
with switch_log_level(comm=comm):
1076+
debug(f"Scalar arguments used to invoke `{self.name}`")
1077+
1078+
if comm is not MPI.COMM_NULL:
1079+
# With MPI enabled, we add one entry per rank
1080+
allmsg = comm.allgather(msg)
1081+
if comm.Get_rank() == 0:
1082+
for m in allmsg:
1083+
debug(m)
1084+
else:
1085+
debug(msg)
1086+
10721087
def _emit_build_profiling(self):
10731088
if not is_log_enabled_for('PERF'):
10741089
return

tests/test_specialization.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -228,12 +228,15 @@ def test_basic(self, caplog, override):
228228

229229
# Ensure that the specialized operator was run
230230
assert all(s not in caplog.text for s in specialize)
231-
assert "specialized arguments preprocess" in caplog.text
231+
assert "specialization" in caplog.text
232232

233233
check = np.array(f.data[:])
234234
f.data[:] = 0
235235
op.apply(**kwargs)
236236

237-
assert np.all(check == f.data)
237+
assert np.all(check == f.data[:])
238238

239-
# Need to test specialization with MPI (both at)
239+
@pytest.mark.parallel(mode=[2, 4])
240+
@pytest.mark.parametrize('override', [False, True])
241+
def test_basic_mpi(self, caplog, mode, override):
242+
self.test_basic(caplog, override)

0 commit comments

Comments
 (0)