@@ -139,6 +139,10 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
139139 # [Schedule]_m -> Schedule (s.t. best memory/flops trade-off)
140140 schedule , exprs = self ._select (variants )
141141
142+ # Schedule -> Schedule (optimization)
143+ if self .opt_maxpar :
144+ schedule = optimize_schedule_maxpar (schedule )
145+
142146 # Schedule -> Schedule (optimization)
143147 if self .opt_rotate :
144148 schedule = optimize_schedule_rotations (schedule , self .sregistry )
@@ -664,7 +668,6 @@ def lower_aliases(aliases, meta, maxpar):
664668 """
665669 Create a Schedule from an AliasList.
666670 """
667- stampcache = {}
668671 dmapper = {}
669672 processed = []
670673 for a in aliases :
@@ -704,12 +707,6 @@ def lower_aliases(aliases, meta, maxpar):
704707 # use `<1>` as stamp, which is what appears in `ispace`
705708 interval = interval .lift (i .stamp )
706709
707- # We further bump the interval stamp if we were requested to trade
708- # fusion for more collapse-parallelism
709- if maxpar :
710- stamp = stampcache .setdefault (interval .dim , Stamp ())
711- interval = interval .lift (stamp )
712-
713710 writeto .append (interval )
714711 intervals .append (interval )
715712
@@ -853,6 +850,30 @@ def optimize_schedule_rotations(schedule, sregistry):
853850 return schedule .rebuild (* processed , rmapper = rmapper )
854851
855852
def optimize_schedule_maxpar(schedule):
    """
    Bump the stamp of each ScheduledAlias' `writeto` and `ispace`, trading
    fusion for more collapse-parallelism.

    The Schedule is scanned in order; each run of consecutive entries sharing
    the same `(writeto, ispace)` pair gets one freshly created Stamp, which is
    applied via `lift` along `writeto.itdims` to both the `writeto` and the
    `ispace` of every entry in that run. The remaining fields of each entry
    (pivot, aliaseds, indicess) are carried over unchanged.

    Returns the Schedule rebuilt from the updated entries.
    """
    def key(entry):
        return (entry.writeto, entry.ispace)

    lifted = []
    for (writeto0, ispace0), group in groupby(schedule, key=key):
        # Materialize the run before doing anything else with it
        entries = list(group)

        # A single new Stamp is shared by all entries within this run
        stamp = Stamp()
        dims = writeto0.itdims

        writeto = writeto0.lift(dims, stamp)
        ispace = ispace0.lift(dims, stamp)

        for pivot, _, _, aliaseds, indicess in entries:
            lifted.append(
                ScheduledAlias(pivot, writeto, ispace, aliaseds, indicess)
            )

    return schedule.rebuild(*lifted)
875+
876+
856877def lower_schedule (schedule , meta , sregistry , opt_ftemps , opt_min_dtype ,
857878 opt_minmem ):
858879 """
0 commit comments