@@ -407,67 +407,74 @@ def initialize(self, iet, options=None):
407407 is sufficient reuse to implement the logic as a single method.
408408 """
409409
410- @singledispatch
411- def _initialize (iet ):
412- return iet , {}
413-
414- @_initialize .register (EntryFunction )
415- def _ (iet ):
416- assert iet .body .is_CallableBody
417-
418- # TODO: we need to pick the rank from `comm_shm`, not `comm`,
419- # so that we have nranks == ngpus (as long as the user has launched
420- # the right number of MPI processes per node given the available
421- # number of GPUs per node)
422-
423- objcomm = None
410+ def _extract_objcomm (iet ):
424411 for i in iet .parameters :
425412 if isinstance (i , MPICommObject ):
426- objcomm = i
427- break
428- if objcomm is None and options [ 'mpi' ]:
429- # Time to inject `objcomm`. If it's not here, it simply means
430- # there's no halo exchanges in the Operator, but we now need it
431- # nonetheless to perform the rank-GPU assignment
413+ return i
414+
415+ # Fallback -- might end up here because the Operator has no
416+ # halo exchanges, but we now need it nonetheless to perform
417+ # the rank-GPU assignment
418+ if options [ 'mpi' ]:
432419 for i in iet .parameters :
433420 try :
434- objcomm = i .grid .distributor ._obj_comm
435- break
421+ return i .grid .distributor ._obj_comm
436422 except AttributeError :
437423 pass
438424
425+ def _make_setdevice_seq (iet , nodes = ()):
439426 devicetype = as_list (self .langbb [self .platform ])
440427 deviceid = self .deviceid
441428
429+ return list (nodes ) + [Conditional (
430+ CondNe (deviceid , - 1 ),
431+ self .langbb ['set-device' ]([deviceid ] + devicetype )
432+ )]
433+
434+ def _make_setdevice_mpi (iet , objcomm , nodes = ()):
435+ devicetype = as_list (self .langbb [self .platform ])
436+ deviceid = self .deviceid
437+
438+ rank = Symbol (name = 'rank' )
439+ rank_decl = DummyExpr (rank , 0 )
440+ rank_init = Call ('MPI_Comm_rank' , [objcomm , Byref (rank )])
441+
442+ ngpus , call_ngpus = self .langbb ._get_num_devices (self .platform )
443+
444+ osdd_then = self .langbb ['set-device' ]([deviceid ] + devicetype )
445+ osdd_else = self .langbb ['set-device' ]([rank % ngpus ] + devicetype )
446+
447+ return list (nodes ) + [Conditional (
448+ CondNe (deviceid , - 1 ),
449+ osdd_then ,
450+ List (body = [rank_decl , rank_init , call_ngpus , osdd_else ]),
451+ )]
452+
453+ @singledispatch
454+ def _initialize (iet ):
455+ return iet , {}
456+
457+ @_initialize .register (EntryFunction )
458+ def _ (iet ):
459+ assert iet .body .is_CallableBody
460+
461+ devicetype = as_list (self .langbb [self .platform ])
462+
442463 try :
443464 lang_init = [self .langbb ['init' ](devicetype )]
444465 except TypeError :
445466 # Not all target languages need to be explicitly initialized
446467 lang_init = []
447468
448- if objcomm is not None :
449- rank = Symbol (name = 'rank' )
450- rank_decl = DummyExpr (rank , 0 )
451- rank_init = Call ('MPI_Comm_rank' , [objcomm , Byref (rank )])
452-
453- ngpus , call_ngpus = self .langbb ._get_num_devices (self .platform )
454-
455- osdd_then = self .langbb ['set-device' ]([deviceid ] + devicetype )
456- osdd_else = self .langbb ['set-device' ]([rank % ngpus ] + devicetype )
469+ objcomm = _extract_objcomm (iet )
457470
458- body = lang_init + [Conditional (
459- CondNe (deviceid , - 1 ),
460- osdd_then ,
461- List (body = [rank_decl , rank_init , call_ngpus , osdd_else ]),
462- )]
471+ if objcomm is not None :
472+ body = _make_setdevice_mpi (iet , objcomm , nodes = lang_init )
463473
464474 header = c .Comment ('Beginning of %s+MPI setup' % self .langbb ['name' ])
465475 footer = c .Comment ('End of %s+MPI setup' % self .langbb ['name' ])
466476 else :
467- body = lang_init + [Conditional (
468- CondNe (deviceid , - 1 ),
469- self .langbb ['set-device' ]([deviceid ] + devicetype )
470- )]
477+ body = _make_setdevice_seq (iet , nodes = lang_init )
471478
472479 header = c .Comment ('Beginning of %s setup' % self .langbb ['name' ])
473480 footer = c .Comment ('End of %s setup' % self .langbb ['name' ])
@@ -479,13 +486,12 @@ def _(iet):
479486
480487 @_initialize .register (AsyncCallable )
481488 def _ (iet ):
482- devicetype = as_list (self .langbb [self .platform ])
483- deviceid = self .deviceid
489+ objcomm = _extract_objcomm (iet )
490+ if objcomm is not None :
491+ init = _make_setdevice_mpi (iet , objcomm )
492+ else :
493+ init = _make_setdevice_seq (iet )
484494
485- init = Conditional (
486- CondNe (deviceid , - 1 ),
487- self .langbb ['set-device' ]([deviceid ] + devicetype )
488- )
489495 iet = iet ._rebuild (body = iet .body ._rebuild (init = init ))
490496
491497 return iet , {}
0 commit comments