|
21 | 21 | from devito.ir.clusters import ClusterGroup, clusterize |
22 | 22 | from devito.ir.iet import (Callable, CInterface, EntryFunction, DeviceFunction, |
23 | 23 | FindSymbols, MetaCall, derive_parameters, iet_build) |
| 24 | +from devito.ir.iet.visitors import Specializer |
24 | 25 | from devito.ir.support import AccessMode, SymbolRegistry |
25 | 26 | from devito.ir.stree import stree_build |
26 | 27 | from devito.operator.profiling import create_profile |
@@ -990,16 +991,34 @@ def apply(self, **kwargs): |
990 | 991 | >>> op = Operator(Eq(u3.forward, u3 + 1)) |
991 | 992 | >>> summary = op.apply(time_M=10) |
992 | 993 | """ |
993 | | - # Compile the operator before building the arguments list |
994 | | - # to avoid out of memory with greedy compilers |
995 | | - cfunction = self.cfunction |
| 994 | + # Get items expected to be specialized |
| 995 | + specialize = as_tuple(kwargs.pop('specialize', [])) |
| 996 | + |
| 997 | + if not specialize: |
| 998 | + # Compile the operator before building the arguments list |
| 999 | + # to avoid out of memory with greedy compilers |
| 1000 | + cfunction = self.cfunction |
996 | 1001 |
|
997 | 1002 | # Build the arguments list to invoke the kernel function |
998 | 1003 | with self._profiler.timer_on('arguments-preprocess'): |
999 | 1004 | args = self.arguments(**kwargs) |
1000 | 1005 | with switch_log_level(comm=args.comm): |
1001 | 1006 | self._emit_args_profiling('arguments-preprocess') |
1002 | 1007 |
|
| 1008 | + # In the case of specialization, arguments must be processed before |
| 1009 | + # the operator can be compiled |
| 1010 | + if specialize: |
| 1011 | + specialized_args = {p: sympify(args.pop(p.name)) |
| 1012 | + for p in self.parameters if p.name in specialize} |
| 1013 | + |
| 1014 | + op = Specializer(specialized_args).visit(self) |
| 1015 | + else: |
| 1016 | + op = self |
| 1017 | + |
| 1018 | + from IPython import embed; embed() |
| 1019 | + |
| 1020 | + # TODO: Whose profiler should get used here? |
| 1021 | + |
1003 | 1022 | # Invoke kernel function with args |
1004 | 1023 | arg_values = [args[p.name] for p in self.parameters] |
1005 | 1024 | try: |
|
0 commit comments