Skip to content

Commit dd983b2

Browse files
committed
api: Start enabling specialization at operator apply
1 parent c6885c1 commit dd983b2

2 files changed

Lines changed: 23 additions & 4 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
'MapExprStmts',
4646
'MapHaloSpots',
4747
'MapNodes',
48+
'Specializer',
4849
'Transformer',
4950
'Uxreplace',
5051
'printAST',

devito/operator/operator.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs
2323
from devito.ir.iet import (
2424
Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall,
25-
derive_parameters, iet_build
25+
Specializer, derive_parameters, iet_build
2626
)
2727
from devito.ir.stree import stree_build
2828
from devito.ir.support import AccessMode, SymbolRegistry
@@ -986,16 +986,34 @@ def apply(self, **kwargs):
986986
>>> op = Operator(Eq(u3.forward, u3 + 1))
987987
>>> summary = op.apply(time_M=10)
988988
"""
989-
# Compile the operator before building the arguments list
990-
# to avoid out of memory with greedy compilers
991-
cfunction = self.cfunction
989+
# Get items expected to be specialized
990+
specialize = as_tuple(kwargs.pop('specialize', []))
991+
992+
if not specialize:
993+
# Compile the operator before building the arguments list
994+
# to avoid out of memory with greedy compilers
995+
cfunction = self.cfunction
992996

993997
# Build the arguments list to invoke the kernel function
994998
with self._profiler.timer_on('arguments-preprocess'):
995999
args = self.arguments(**kwargs)
9961000
with switch_log_level(comm=args.comm):
9971001
self._emit_args_profiling('arguments-preprocess')
9981002

1003+
# In the case of specialization, arguments must be processed before
1004+
# the operator can be compiled
1005+
if specialize:
1006+
specialized_args = {p: sympify(args.pop(p.name))
1007+
for p in self.parameters if p.name in specialize}
1008+
1009+
op = Specializer(specialized_args).visit(self)
1010+
else:
1011+
op = self
1012+
1013+
from IPython import embed; embed()
1014+
1015+
# TODO: Whose profiler should get used here?
1016+
9991017
# Invoke kernel function with args
10001018
arg_values = [args[p.name] for p in self.parameters]
10011019
try:

0 commit comments

Comments
 (0)