@@ -30,12 +30,14 @@ class HaloLabel(Tag):
3030class HaloSchemeEntry (EnrichedTuple ):
3131
3232 __rargs__ = ('loc_indices' , 'loc_dirs' , 'halos' , 'dims' )
33+ __rkwargs__ = ('bundle' ,)
3334
34- def __new__ (cls , loc_indices , loc_dirs , halos , dims , getters = None ):
35+ def __new__ (cls , loc_indices , loc_dirs , halos , dims , bundle = None , getters = None ):
36+ getters = cls .__rargs__ + cls .__rkwargs__
3537 items = [frozendict (loc_indices ), frozendict (loc_dirs ),
36- frozenset (halos ), frozenset (dims )]
37- kwargs = dict (zip (cls . __rargs__ , items ))
38- return super ().__new__ (cls , * items , getters = cls . __rargs__ , ** kwargs )
38+ frozenset (halos ), frozenset (dims ), bundle ]
39+ kwargs = dict (zip (getters , items ))
40+ return super ().__new__ (cls , * items , getters = getters , ** kwargs )
3941
4042 def __hash__ (self ):
4143 return hash ((self .loc_indices , self .loc_dirs , self .halos , self .dims ))
@@ -47,7 +49,8 @@ def union(self, other):
4749 exception is raised.
4850 """
4951 if self .loc_indices != other .loc_indices or \
50- self .loc_dirs != other .loc_dirs :
52+ self .loc_dirs != other .loc_dirs or \
53+ self .bundle is not other .bundle :
5154 raise HaloSchemeException (
5255 "Inconsistency found while building a HaloScheme"
5356 )
@@ -56,7 +59,7 @@ def union(self, other):
5659 dims = self .dims | other .dims
5760
5861 return HaloSchemeEntry (self .loc_indices , self .loc_dirs , halos , dims ,
59- getters = self .getters )
62+ bundle = self . bundle , getters = self .getters )
6063
6164
6265Halo = namedtuple ('Halo' , 'dim side' )
@@ -168,7 +171,7 @@ def union(self, halo_schemes):
168171 elif not v .loc_indices or hse .loc_indices == v .loc_indices :
169172 loc_indices , loc_dirs = hse .loc_indices , hse .loc_dirs
170173 else :
171- # The `loc_dirs` must match otherwise it'd be a symptom there's
174+ # These must match otherwise it'd be a symptom there's
172175 # something horribly broken elsewhere!
173176 assert hse .loc_dirs == v .loc_dirs
174177 assert list (hse .loc_indices ) == list (v .loc_indices )
@@ -185,7 +188,11 @@ def union(self, halo_schemes):
185188 halos = hse .halos | v .halos
186189 dims = hse .dims | v .dims
187190
188- fmapper [k ] = HaloSchemeEntry (loc_indices , loc_dirs , halos , dims )
191+ assert hse .bundle is v .bundle
192+
193+ fmapper [k ] = HaloSchemeEntry (
194+ loc_indices , loc_dirs , halos , dims , bundle = hse .bundle
195+ )
189196
190197 # Compute the `honored` union
191198 for d , v in i .honored .items ():
@@ -641,8 +648,12 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
641648 for i , v in rule .items ():
642649 if i is f :
643650 # Yes!
644- g = v
645- hse = hse0
651+ if v .is_Bundle :
652+ g = f
653+ hse = hse0 ._rebuild (bundle = v )
654+ else :
655+ g = v
656+ hse = hse0
646657
647658 elif i .is_Indexed and i .function is f and v .is_Indexed :
648659 # Yes, but through an Indexed, hence the `loc_indices` may now
0 commit comments