@@ -47,17 +47,18 @@ def __hash__(self):
4747 def loc_values (self ):
4848 return frozenset (self .loc_indices .values ())
4949
50- def union (self , other ):
50+ def merge (self , other ):
5151 """
52- Return a new HaloSchemeEntry that is the union of this and `other`.
53- The `loc_indices` and `loc_dirs` must be the same, otherwise an
54- exception is raised.
52+ Return a new HaloSchemeEntry that is the result of merging `other`
53+ into `self`, irrespective of whether the `loc_indices` and `loc_dirs`
54+ are the same or not. Thus, the returned HaloSchemeEntry will have
55+ in general more halo exchanges than `self`, or exactly the same halo
56+ exchanges in the worst case).
5557 """
56- if self .loc_indices != other .loc_indices or \
57- self .loc_dirs != other .loc_dirs or \
58+ if self .loc_dirs != other .loc_dirs or \
5859 self .bundle is not other .bundle :
5960 raise HaloSchemeException (
60- "Inconsistency found while building a HaloScheme "
61+ "Inconsistency found while merging HaloSchemeEntries "
6162 )
6263
6364 halos = self .halos | other .halos
@@ -66,6 +67,22 @@ def union(self, other):
6667 return HaloSchemeEntry (self .loc_indices , self .loc_dirs , halos , dims ,
6768 bundle = self .bundle , getters = self .getters )
6869
70+ def union (self , other ):
71+ """
72+ Return a new HaloSchemeEntry that is the union of this and `other`.
73+ The `loc_indices` and `loc_dirs` must be the same, otherwise an
74+ exception is raised. This is a more restrictive version of `merge`,
75+ which is used when we want to ensure that the halo exchanges are
76+ performed at the same time index, i.e., the same `loc_indices` are
77+ are expected.
78+ """
79+ if self .loc_indices != other .loc_indices :
80+ raise HaloSchemeException (
81+ "Inconsistency found while taking the union of HaloSchemeEntries"
82+ )
83+
84+ return self .merge (other )
85+
6986
7087Halo = namedtuple ('Halo' , 'dim side' )
7188
@@ -481,6 +498,15 @@ def add(self, f, hse):
481498 fmapper [f ] = hse
482499 return HaloScheme .build (fmapper , self .honored )
483500
501+ def merge (self , hs ):
502+ """
503+ Create a new HaloScheme that is the result of merging `hs` into `self`.
504+ """
505+ fmapper = dict (self .fmapper )
506+ for f , hse in hs .fmapper .items ():
507+ fmapper [f ] = fmapper .get (f , hse ).merge (hse )
508+ return HaloScheme .build (fmapper , self .honored )
509+
484510
485511def classify (exprs , ispace ):
486512 """
0 commit comments