@@ -1525,6 +1525,18 @@ def __init__(self, mapper, nested=False):
15251525 f"specialization. Value { v } was supplied for symbol "
15261526 f"{ k } , but is of type { type (v )} ." )
15271527
1528+ def visit_OrderedDict (self , o ):
1529+ return OrderedDict ((k , self ._visit (v )) for k , v in o .items ())
1530+
1531+ def visit_MetaCall (self , o ):
1532+ root = self ._visit (o .root )
1533+ return MetaCall (root = root , local = o .local )
1534+
1535+ def visit_Callable (self , o ):
1536+ body = self ._visit (o .body )
1537+ parameters = [i for i in o .parameters if i not in self .mapper ]
1538+ return o ._rebuild (body = body , parameters = parameters )
1539+
15281540 def visit_KernelLaunch (self , o ):
15291541 # Remove kernel args if they are to be hardcoded
15301542 arguments = [i for i in o .arguments if i not in self .mapper ]
@@ -1553,23 +1565,8 @@ def visit_Operator(self, o, **kwargs):
15531565 state = o .__getstate__ ()
15541566 state ['parameters' ] = parameters
15551567 state ['body' ] = body
1556-
1557- # TODO: Also rebuild the _func_table for the Operator
1558- # TODO: This is somewhat incongruent with the visitor and should be refactored
1559-
1560- func_table = OrderedDict ()
1561- for k , v in o ._func_table .items ():
1562- root = v .root
1563- local = v .local
1564-
1565- body = self ._visit (root .body )
1566- parameters = tuple (i for i in root .parameters if i not in self .mapper )
1567-
1568- new_root = root ._rebuild (body = body , parameters = parameters )
1569-
1570- func_table [k ] = MetaCall (root = new_root , local = local )
1571-
1572- state ['_func_table' ] = func_table
1568+ # Modify the _func_table to ensure callbacks are specialized
1569+ state ['_func_table' ] = self ._visit (o ._func_table )
15731570
15741571 try :
15751572 state .pop ('ccode' )
0 commit comments