Skip to content

Commit 29a255e

Browse files
committed
API: Refactor operator specialization API
1 parent 289676d commit 29a255e

1 file changed

Lines changed: 37 additions & 23 deletions

File tree

devito/operator/operator.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -924,6 +924,43 @@ def _enrich_memreport(self, args):
924924
# Hook for enriching memory report with additional metadata
925925
return {}
926926

927+
def specialize(self, **kwargs):
928+
"""
929+
"""
930+
931+
specialize = as_tuple(kwargs.pop('specialize', []))
932+
933+
if not specialize:
934+
return self, kwargs
935+
936+
# FIXME: Cannot cope with things like sizes/strides yet since it only
937+
# looks at the parameters
938+
939+
# Build the arguments list for specialization
940+
with self._profiler.timer_on('specialization'):
941+
args = self.arguments(**kwargs)
942+
# Uses parameters here since Specializer needs {symbol: sympy value}
943+
specialized_values = {p: sympify(args[p.name])
944+
for p in self.parameters
945+
if p.name in specialize}
946+
947+
op = Specializer(specialized_values).visit(self)
948+
949+
with switch_log_level(comm=args.comm):
950+
self._emit_args_profiling('specialization')
951+
952+
unspecialized_kwargs = {k: v for k, v in kwargs.items()
953+
if k not in specialize}
954+
955+
return op, unspecialized_kwargs
956+
957+
def apply_specialize(self, **kwargs):
958+
"""
959+
"""
960+
961+
op, unspecialized_kwargs = self.specialize(**kwargs)
962+
return op.apply(**unspecialized_kwargs)
963+
927964
def apply(self, **kwargs):
928965
"""
929966
Execute the Operator.
@@ -986,29 +1023,6 @@ def apply(self, **kwargs):
9861023
>>> op = Operator(Eq(u3.forward, u3 + 1))
9871024
>>> summary = op.apply(time_M=10)
9881025
"""
989-
# Get items expected to be specialized
990-
specialize = as_tuple(kwargs.pop('specialize', []))
991-
992-
if specialize:
993-
# FIXME: Cannot cope with things like sizes/strides yet since it only
994-
# looks at the parameters
995-
996-
# Build the arguments list for specialization
997-
with self._profiler.timer_on('specialization'):
998-
args = self.arguments(**kwargs)
999-
# Uses parameters here since Specializer needs {symbol: sympy value}
1000-
specialized_values = {p: sympify(args[p.name])
1001-
for p in self.parameters if p.name in specialize}
1002-
1003-
op = Specializer(specialized_values).visit(self)
1004-
1005-
with switch_log_level(comm=args.comm):
1006-
self._emit_args_profiling('specialization')
1007-
1008-
unspecialized_kwargs = {k: v for k, v in kwargs.items()
1009-
if k not in specialize}
1010-
1011-
return op.apply(**unspecialized_kwargs)
10121026

10131027
# Compile the operator before building the arguments list
10141028
# to avoid out of memory with greedy compilers

0 commit comments

Comments
 (0)