@@ -399,6 +399,81 @@ class HaloComms(Queue):
399399 def process (self , clusters ):
400400 return self ._process_fatd (clusters , 1 , seen = set ())
401401
402+ def _derive_halo_schemes (self , c ):
403+ hs = HaloScheme (c .exprs , c .ispace )
404+
405+ # 95% of the times we will just return `hs` as is as there are no guards
406+ if not c .guards :
407+ yield hs , c
408+ return
409+
410+ # This is a more contrived situation in which we might need halo exchanges
411+ # from multiple so called loc-indices -- let's check this out
412+ candidates = []
413+ for f , hse in hs .fmapper .items ():
414+ reads = c .scope .reads [f ]
415+
416+ for d in hse .loc_indices :
417+ if not d ._defines & set (c .guards ):
418+ continue
419+
420+ candidates .append (as_mapper (reads , key = lambda i : i [d ]).values ())
421+
422+ # 4% of the times we will just return `hs` as is
423+ # E.g., we end up here when taking space derivatives of one or more saved
424+ # TimeFunctions in equations evaluating gradients that are controlled by
425+ # a ConditionalDimension (otherwise we would have exited earlier)
426+ if any (len (g ) <= 1 for g in candidates ):
427+ yield hs , c
428+ return
429+
430+ # 1% of the times, finally, we end up here...
431+ # At this point we have to create a mock Cluster for each loc-index,
432+ # containing all and only the accesses to `f` at a given loc-index
433+ # E.g., a mock Cluster at `loc_index=t0` containing the accesses
434+ # `[u(t0, x + 8, ...), u(t0, x + 7, ...)], another mock Cluster at
435+ # `loc_index=t1` containing the accesses `[u(t1, x + 5, ...),
436+ # u(t1, x + 6, ...)]`, and so on
437+ for groups in candidates :
438+ for group in groups :
439+ pointset = sympy .Function ('pointset' )
440+ v = pointset (* [i .access for i in group ])
441+ exprs = [e .func (rhs = v ) for e in c .exprs ]
442+
443+ c1 = c .rebuild (exprs = exprs )
444+
445+ hs = HaloScheme (c1 .exprs , c .ispace )
446+
447+ yield hs , c1
448+
449+ def _make_halo_touch (self , hs , c , prefix ):
450+ points = set ()
451+ for f in hs .fmapper :
452+ for a in c .scope .getreads (f ):
453+ points .add (a .access )
454+
455+ # We also add all written symbols to ultimately create mock WARs
456+ # with `c`, which will prevent the newly created HaloTouch from
457+ # ever being rescheduled
458+ points .update (a .access for a in c .scope .accesses if a .is_write )
459+
460+ # Sort for determinism
461+ # NOTE: not sorting might impact code generation. The order of
462+ # the args is important because that's what search functions honor!
463+ points = sorted (points , key = str )
464+
465+ # Construct the HaloTouch Cluster
466+ expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
467+
468+ key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
469+ ispace = c .ispace .project (key )
470+ # HaloTouches are not parallel
471+ properties = c .properties .sequentialize ()
472+
473+ halo_touch = c .rebuild (exprs = expr , ispace = ispace , properties = properties )
474+
475+ return halo_touch
476+
402477 def callback (self , clusters , prefix , seen = None ):
403478 if not prefix :
404479 return clusters
@@ -412,38 +487,18 @@ def callback(self, clusters, prefix, seen=None):
412487 c in seen :
413488 continue
414489
415- hs = HaloScheme (c .exprs , c .ispace )
416- if hs .is_void or \
417- not d ._defines & hs .distributed_aindices :
418- continue
419-
420- points = set ()
421- for f in hs .fmapper :
422- for a in c .scope .getreads (f ):
423- points .add (a .access )
424-
425- # We also add all written symbols to ultimately create mock WARs
426- # with `c`, which will prevent the newly created HaloTouch to ever
427- # be rescheduled after `c` upon topological sorting
428- points .update (a .access for a in c .scope .accesses if a .is_write )
490+ seen .add (c )
429491
430- # Sort for determinism
431- # NOTE: not sorting might impact code generation. The order of
432- # the args is important because that's what search functions honor!
433- points = sorted (points , key = str )
434-
435- # Construct the HaloTouch Cluster
436- expr = Eq (self .B , HaloTouch (* points , halo_scheme = hs ))
492+ for hs , c1 in self ._derive_halo_schemes (c ):
493+ if hs .is_void or \
494+ not d ._defines & hs .distributed_aindices :
495+ continue
437496
438- key = lambda i : i in prefix [:- 1 ] or i in hs .loc_indices
439- ispace = c .ispace .project (key )
440- # HaloTouches are not parallel
441- properties = c .properties .sequentialize ()
497+ halo_touch = self ._make_halo_touch (hs , c1 , prefix )
442498
443- halo_touch = c . rebuild ( exprs = expr , ispace = ispace , properties = properties )
499+ processed . append ( halo_touch )
444500
445- processed .append (halo_touch )
446- seen .update ({halo_touch , c })
501+ seen .add (halo_touch )
447502
448503 processed .extend (clusters )
449504
0 commit comments