@@ -457,8 +457,8 @@ def reduction_comms(clusters):
457457 # if `c`'s IterationSpace is such that the reduction can be carried out
458458 found , fifo = split (fifo , lambda dr : dr .ispace .is_subset (c .ispace ))
459459 for ispace , reds in groupby (found , key = lambda r : r .ispace ):
460- exprs = [ Eq ( dr .var , dr ) for dr in reds ]
461- processed .append (Cluster (exprs = exprs , ispace = ispace ))
460+ exprs = flatten ([ dr .exprs for dr in reds ])
461+ processed .append (c . rebuild (exprs = exprs , ispace = ispace ))
462462
463463 # Detect the global distributed reductions in `c`
464464 for e in c .exprs :
@@ -487,15 +487,16 @@ def reduction_comms(clusters):
487487 # The IterationSpace within which the global distributed reduction
488488 # must be carried out
489489 ispace = c .ispace .prefix (lambda d : d in var .free_symbols )
490-
491- fifo .append (DistReduce ( var , op = op , grid = grid , ispace = ispace ))
490+ expr = [ Eq ( var , DistReduce ( var , op = op , grid = grid , ispace = ispace ))]
491+ fifo .append (c . rebuild ( exprs = expr , ispace = ispace ))
492492
493493 processed .append (c )
494494
495495 # Leftover reductions are placed at the very end
496496 for ispace , reds in groupby (fifo , key = lambda r : r .ispace ):
497- exprs = [Eq (dr .var , dr ) for dr in reds ]
498- processed .append (Cluster (exprs = exprs , ispace = ispace ))
497+ reds = list (reds )
498+ exprs = flatten ([dr .exprs for dr in reds ])
499+ processed .append (reds [0 ].rebuild (exprs = exprs , ispace = ispace ))
499500
500501 return processed
501502
0 commit comments