@@ -1498,6 +1498,41 @@ def visit_KernelLaunch(self, o):
14981498 arguments = arguments )
14991499
15001500
1501+ class Specializer (Uxreplace ):
1502+ """
1503+ A Transformer to "specialize" a pre-built Operator - that is to replace a given
1504+ set of (scalar) symbols with hard-coded values to free up registers. This will
1505+ yield a "specialized" version of the Operator, specific to a particular setup.
1506+ """
1507+
1508+ def __init__ (self , mapper , nested = False ):
1509+ super ().__init__ (mapper , nested = nested )
1510+
1511+ # Sanity check
1512+ for k in self .mapper .keys ():
1513+ if not isinstance (k , AbstractSymbol ):
1514+ raise ValueError (f"Attempted to specialize non-scalar symbol: { k } " )
1515+
1516+ def visit_Operator (self , o , ** kwargs ):
1517+ # Entirely fine to apply this to an Operator
1518+ body = self ._visit (o .body )
1519+ parameters = tuple (i for i in o .parameters if i not in self .mapper )
1520+
1521+ # Note: the following is not dissimilar to unpickling an Operator
1522+ state = o .__getstate__ ()
1523+ state ['parameters' ] = parameters
1524+ state ['body' ] = body
1525+ state .pop ('ccode' )
1526+
1527+ # FIXME: These names aren't great
1528+ newargs , newkwargs = o .__getnewargs_ex__ ()
1529+ newop = o .__class__ (* newargs , ** newkwargs )
1530+
1531+ newop .__setstate__ (state )
1532+
1533+ return newop
1534+
1535+
15011536# Utils
15021537
15031538blankline = c .Line ("" )
0 commit comments