|
22 | 22 | from devito.ir.equations import LoweredEq, concretize_subdims, lower_exprs |
23 | 23 | from devito.ir.iet import ( |
24 | 24 | Callable, CInterface, DeviceFunction, EntryFunction, FindSymbols, MetaCall, |
25 | | - derive_parameters, iet_build |
| 25 | + Specializer, derive_parameters, iet_build |
26 | 26 | ) |
27 | 27 | from devito.ir.stree import stree_build |
28 | 28 | from devito.ir.support import AccessMode, SymbolRegistry |
@@ -986,16 +986,34 @@ def apply(self, **kwargs): |
986 | 986 | >>> op = Operator(Eq(u3.forward, u3 + 1)) |
987 | 987 | >>> summary = op.apply(time_M=10) |
988 | 988 | """ |
989 | | - # Compile the operator before building the arguments list |
990 | | - # to avoid out of memory with greedy compilers |
991 | | - cfunction = self.cfunction |
| 989 | + # Get items expected to be specialized |
| 990 | + specialize = as_tuple(kwargs.pop('specialize', [])) |
| 991 | + |
| 992 | + if not specialize: |
| 993 | + # Compile the operator before building the arguments list |
| 994 | + # to avoid out of memory with greedy compilers |
| 995 | + cfunction = self.cfunction |
992 | 996 |
|
993 | 997 | # Build the arguments list to invoke the kernel function |
994 | 998 | with self._profiler.timer_on('arguments-preprocess'): |
995 | 999 | args = self.arguments(**kwargs) |
996 | 1000 | with switch_log_level(comm=args.comm): |
997 | 1001 | self._emit_args_profiling('arguments-preprocess') |
998 | 1002 |
|
| 1003 | + # In the case of specialization, arguments must be processed before |
| 1004 | + # the operator can be compiled |
| 1005 | + if specialize: |
| 1006 | + specialized_args = {p: sympify(args.pop(p.name)) |
| 1007 | + for p in self.parameters if p.name in specialize} |
| 1008 | + |
| 1009 | + op = Specializer(specialized_args).visit(self) |
| 1010 | + else: |
| 1011 | + op = self |
| 1012 | + |
| 1013 | + from IPython import embed; embed() |
| 1014 | + |
| 1015 | + # TODO: Whose profiler should get used here? |
| 1016 | + |
999 | 1017 | # Invoke kernel function with args |
1000 | 1018 | arg_values = [args[p.name] for p in self.parameters] |
1001 | 1019 | try: |
|
0 commit comments