Skip to content

Commit 893cb09

Browse files
committed
compiler: prevent loosing cluster propertieswit DistReduce
1 parent 7783a27 commit 893cb09

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,8 +457,8 @@ def reduction_comms(clusters):
457457
# if `c`'s IterationSpace is such that the reduction can be carried out
458458
found, fifo = split(fifo, lambda dr: dr.ispace.is_subset(c.ispace))
459459
for ispace, reds in groupby(found, key=lambda r: r.ispace):
460-
exprs = [Eq(dr.var, dr) for dr in reds]
461-
processed.append(Cluster(exprs=exprs, ispace=ispace))
460+
exprs = flatten([dr.exprs for dr in reds])
461+
processed.append(c.rebuild(exprs=exprs, ispace=ispace))
462462

463463
# Detect the global distributed reductions in `c`
464464
for e in c.exprs:
@@ -487,15 +487,16 @@ def reduction_comms(clusters):
487487
# The IterationSpace within which the global distributed reduction
488488
# must be carried out
489489
ispace = c.ispace.prefix(lambda d: d in var.free_symbols)
490-
491-
fifo.append(DistReduce(var, op=op, grid=grid, ispace=ispace))
490+
expr = [Eq(var, DistReduce(var, op=op, grid=grid, ispace=ispace))]
491+
fifo.append(c.rebuild(exprs=expr, ispace=ispace))
492492

493493
processed.append(c)
494494

495495
# Leftover reductions are placed at the very end
496496
for ispace, reds in groupby(fifo, key=lambda r: r.ispace):
497-
exprs = [Eq(dr.var, dr) for dr in reds]
498-
processed.append(Cluster(exprs=exprs, ispace=ispace))
497+
reds = list(reds)
498+
exprs = flatten([dr.exprs for dr in reds])
499+
processed.append(reds[0].rebuild(exprs=exprs, ispace=ispace))
499500

500501
return processed
501502

devito/mpi/reduction_scheme.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __new__(cls, var, op=None, grid=None, ispace=None, **kwargs):
1919
obj.op = op
2020
obj.grid = grid
2121
obj.ispace = ispace
22+
obj.guards = kwargs.get('guards', None)
2223
return obj
2324

2425
def __repr__(self):

0 commit comments

Comments
 (0)