1717from devito .exceptions import CompilationError
1818from devito .ir .iet .nodes import (
1919 BlankLine , Call , Expression , ExpressionBundle , Iteration , Lambda , ListMajor , Node ,
20- Section
20+ MetaCall , Section
2121)
2222from devito .ir .support .space import Backward
2323from devito .symbolics import (
@@ -1508,7 +1508,8 @@ class Specializer(Uxreplace):
15081508 Note that the Operator is not re-optimized in response to this replacement - this
15091509 transformation could nominally result in expressions of the form `f + 0` in the
15101510 generated code. If one wants to construct an Operator where such expressions are
1511- considered, then use of `subs=...` is a better choice.
1511+ considered, then use of `subs=...` at construction time is a better choice. However,
1512+ it is likely that such expressions will be optimized away by the C-level compiler.
15121513 """
15131514
15141515 def __init__ (self , mapper , nested = False ):
@@ -1523,7 +1524,7 @@ def __init__(self, mapper, nested=False):
15231524 raise ValueError ("Only SymPy Numbers can used to replace values during "
15241525 f"specialization. Value { v } was supplied for symbol "
15251526 f"{ k } , but is of type { type (v )} ." )
1526-
1527+
15271528 def visit_KernelLaunch (self , o ):
15281529 # Remove kernel args if they are to be hardcoded
15291530 arguments = [i for i in o .arguments if i not in self .mapper ]
@@ -1553,6 +1554,23 @@ def visit_Operator(self, o, **kwargs):
15531554 state ['parameters' ] = parameters
15541555 state ['body' ] = body
15551556
1557+ # TODO: Also rebuild the _func_table for the Operator
1558+ # TODO: This is somewhat incongruent with the visitor and should be refactored
1559+
1560+ func_table = OrderedDict ()
1561+ for k , v in o ._func_table .items ():
1562+ root = v .root
1563+ local = v .local
1564+
1565+ body = self ._visit (root .body )
1566+ parameters = tuple (i for i in root .parameters if i not in self .mapper )
1567+
1568+ new_root = root ._rebuild (body = body , parameters = parameters )
1569+
1570+ func_table [k ] = MetaCall (root = new_root , local = local )
1571+
1572+ state ['_func_table' ] = func_table
1573+
15561574 try :
15571575 state .pop ('ccode' )
15581576 except KeyError :
0 commit comments