Skip to content

Commit b92ff5f

Browse files
committed
compiler: Make Specializer visit _func_table of an Operator
1 parent 8697a96 commit b92ff5f

1 file changed

Lines changed: 21 additions & 3 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.exceptions import CompilationError
1818
from devito.ir.iet.nodes import (
1919
BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node,
20-
Section
20+
MetaCall, Section
2121
)
2222
from devito.ir.support.space import Backward
2323
from devito.symbolics import (
@@ -1508,7 +1508,8 @@ class Specializer(Uxreplace):
15081508
Note that the Operator is not re-optimized in response to this replacement - this
15091509
transformation could nominally result in expressions of the form `f + 0` in the
15101510
generated code. If one wants to construct an Operator where such expressions are
1511-
considered, then use of `subs=...` is a better choice.
1511+
considered, then use of `subs=...` at construction time is a better choice. However,
1512+
it is likely that such expressions will be optimized away by the C-level compiler.
15121513
"""
15131514

15141515
def __init__(self, mapper, nested=False):
@@ -1523,7 +1524,7 @@ def __init__(self, mapper, nested=False):
15231524
raise ValueError("Only SymPy Numbers can used to replace values during "
15241525
f"specialization. Value {v} was supplied for symbol "
15251526
f"{k}, but is of type {type(v)}.")
1526-
1527+
15271528
def visit_KernelLaunch(self, o):
15281529
# Remove kernel args if they are to be hardcoded
15291530
arguments = [i for i in o.arguments if i not in self.mapper]
@@ -1553,6 +1554,23 @@ def visit_Operator(self, o, **kwargs):
15531554
state['parameters'] = parameters
15541555
state['body'] = body
15551556

1557+
# TODO: Also rebuild the _func_table for the Operator
1558+
# TODO: This is somewhat incongruent with the visitor and should be refactored
1559+
1560+
func_table = OrderedDict()
1561+
for k, v in o._func_table.items():
1562+
root = v.root
1563+
local = v.local
1564+
1565+
body = self._visit(root.body)
1566+
parameters = tuple(i for i in root.parameters if i not in self.mapper)
1567+
1568+
new_root = root._rebuild(body=body, parameters=parameters)
1569+
1570+
func_table[k] = MetaCall(root=new_root, local=local)
1571+
1572+
state['_func_table'] = func_table
1573+
15561574
try:
15571575
state.pop('ccode')
15581576
except KeyError:

0 commit comments

Comments
 (0)