Skip to content

Commit 18906bd

Browse files
committed
compiler: Start adding machinery to specialise operators with hardcoded values
1 parent e147aaa commit 18906bd

1 file changed

Lines changed: 35 additions & 0 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,41 @@ def visit_KernelLaunch(self, o):
14981498
arguments=arguments)
14991499

15001500

1501+
class Specializer(Uxreplace):
1502+
"""
1503+
A Transformer to "specialize" a pre-built Operator - that is to replace a given
1504+
set of (scalar) symbols with hard-coded values to free up registers. This will
1505+
yield a "specialized" version of the Operator, specific to a particular setup.
1506+
"""
1507+
1508+
def __init__(self, mapper, nested=False):
1509+
super().__init__(mapper, nested=nested)
1510+
1511+
# Sanity check
1512+
for k in self.mapper.keys():
1513+
if not isinstance(k, AbstractSymbol):
1514+
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
1515+
1516+
def visit_Operator(self, o, **kwargs):
1517+
# Entirely fine to apply this to an Operator
1518+
body = self._visit(o.body)
1519+
parameters = tuple(i for i in o.parameters if i not in self.mapper)
1520+
1521+
# Note: the following is not dissimilar to unpickling an Operator
1522+
state = o.__getstate__()
1523+
state['parameters'] = parameters
1524+
state['body'] = body
1525+
state.pop('ccode')
1526+
1527+
# FIXME: These names aren't great
1528+
newargs, newkwargs = o.__getnewargs_ex__()
1529+
newop = o.__class__(*newargs, **newkwargs)
1530+
1531+
newop.__setstate__(state)
1532+
1533+
return newop
1534+
1535+
15011536
# Utils
15021537

15031538
blankline = c.Line("")

0 commit comments

Comments
 (0)