Skip to content

Commit 99ccbe1

Browse files
committed
compiler: Start adding machinery to specialise operators with hardcoded values
1 parent 6061b76 commit 99ccbe1

1 file changed

Lines changed: 35 additions & 0 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,6 +1499,41 @@ def visit_KernelLaunch(self, o):
14991499
arguments=arguments)
15001500

15011501

1502+
class Specializer(Uxreplace):
1503+
"""
1504+
A Transformer to "specialize" a pre-built Operator - that is to replace a given
1505+
set of (scalar) symbols with hard-coded values to free up registers. This will
1506+
yield a "specialized" version of the Operator, specific to a particular setup.
1507+
"""
1508+
1509+
def __init__(self, mapper, nested=False):
1510+
super().__init__(mapper, nested=nested)
1511+
1512+
# Sanity check
1513+
for k in self.mapper.keys():
1514+
if not isinstance(k, AbstractSymbol):
1515+
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
1516+
1517+
def visit_Operator(self, o, **kwargs):
1518+
# Entirely fine to apply this to an Operator
1519+
body = self._visit(o.body)
1520+
parameters = tuple(i for i in o.parameters if i not in self.mapper)
1521+
1522+
# Note: the following is not dissimilar to unpickling an Operator
1523+
state = o.__getstate__()
1524+
state['parameters'] = parameters
1525+
state['body'] = body
1526+
state.pop('ccode')
1527+
1528+
# FIXME: These names aren't great
1529+
newargs, newkwargs = o.__getnewargs_ex__()
1530+
newop = o.__class__(*newargs, **newkwargs)
1531+
1532+
newop.__setstate__(state)
1533+
1534+
return newop
1535+
1536+
15021537
# Utils
15031538

15041539
blankline = c.Line("")

0 commit comments

Comments
 (0)