Skip to content

Commit 3f7f6dc

Browse files
committed
compiler: Tweak CIRE max-par lowering
1 parent c65c28a commit 3f7f6dc

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

devito/passes/clusters/aliases.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
139139
# [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140140
schedule, exprs = self._select(variants)
141141

142+
# Schedule -> Schedule (optimization)
143+
if self.opt_maxpar:
144+
schedule = optimize_schedule_maxpar(schedule)
145+
142146
# Schedule -> Schedule (optimization)
143147
if self.opt_rotate:
144148
schedule = optimize_schedule_rotations(schedule, self.sregistry)
@@ -664,7 +668,6 @@ def lower_aliases(aliases, meta, maxpar):
664668
"""
665669
Create a Schedule from an AliasList.
666670
"""
667-
stampcache = {}
668671
dmapper = {}
669672
processed = []
670673
for a in aliases:
@@ -704,12 +707,6 @@ def lower_aliases(aliases, meta, maxpar):
704707
# use `<1>` as stamp, which is what appears in `ispace`
705708
interval = interval.lift(i.stamp)
706709

707-
# We further bump the interval stamp if we were requested to trade
708-
# fusion for more collapse-parallelism
709-
if maxpar:
710-
stamp = stampcache.setdefault(interval.dim, Stamp())
711-
interval = interval.lift(stamp)
712-
713710
writeto.append(interval)
714711
intervals.append(interval)
715712

@@ -853,6 +850,30 @@ def optimize_schedule_rotations(schedule, sregistry):
853850
return schedule.rebuild(*processed, rmapper=rmapper)
854851

855852

853+
def optimize_schedule_maxpar(schedule):
    """
    Bump the IterationSpace's stamp, trading fusion for more
    collapse-parallelism.

    The entries of `schedule` are visited in order and split into runs of
    consecutive entries sharing the same `(writeto, ispace)` pair. For each
    run, a new Stamp is created and both the `writeto` and the `ispace` are
    lifted over the `writeto` iteration Dimensions using that Stamp. Every
    entry in the run is then recreated with the lifted spaces, while its
    pivot, aliaseds and indicess are carried over untouched.

    Returns the Schedule rebuilt from the recreated entries.
    """
    def spaces(item):
        return (item.writeto, item.ispace)

    lifted = []
    for (old_writeto, old_ispace), run in groupby(schedule, key=spaces):
        members = tuple(run)

        # One Stamp per run, shared by the write-to and iteration spaces
        fresh = Stamp()
        itdims = old_writeto.itdims

        new_writeto = old_writeto.lift(itdims, fresh)
        new_ispace = old_ispace.lift(itdims, fresh)

        for pivot, _, _, aliaseds, indicess in members:
            entry = ScheduledAlias(pivot, new_writeto, new_ispace, aliaseds, indicess)
            lifted.append(entry)

    return schedule.rebuild(*lifted)
875+
876+
856877
def lower_schedule(schedule, meta, sregistry, opt_ftemps, opt_min_dtype,
857878
opt_minmem):
858879
"""

0 commit comments

Comments
 (0)