@@ -1499,6 +1499,41 @@ def visit_KernelLaunch(self, o):
14991499 arguments = arguments )
15001500
15011501
1502+ class Specializer (Uxreplace ):
1503+ """
1504+ A Transformer to "specialize" a pre-built Operator - that is to replace a given
1505+ set of (scalar) symbols with hard-coded values to free up registers. This will
1506+ yield a "specialized" version of the Operator, specific to a particular setup.
1507+ """
1508+
1509+ def __init__ (self , mapper , nested = False ):
1510+ super ().__init__ (mapper , nested = nested )
1511+
1512+ # Sanity check
1513+ for k in self .mapper .keys ():
1514+ if not isinstance (k , AbstractSymbol ):
1515+ raise ValueError (f"Attempted to specialize non-scalar symbol: { k } " )
1516+
1517+ def visit_Operator (self , o , ** kwargs ):
1518+ # Entirely fine to apply this to an Operator
1519+ body = self ._visit (o .body )
1520+ parameters = tuple (i for i in o .parameters if i not in self .mapper )
1521+
1522+ # Note: the following is not dissimilar to unpickling an Operator
1523+ state = o .__getstate__ ()
1524+ state ['parameters' ] = parameters
1525+ state ['body' ] = body
1526+ state .pop ('ccode' )
1527+
1528+ # FIXME: These names aren't great
1529+ newargs , newkwargs = o .__getnewargs_ex__ ()
1530+ newop = o .__class__ (* newargs , ** newkwargs )
1531+
1532+ newop .__setstate__ (state )
1533+
1534+ return newop
1535+
1536+
15021537# Utils
15031538
15041539blankline = c .Line ("" )
0 commit comments