Skip to content

Commit 51d4d91

Browse files
committed
compiler: Call set_device in pthreads
1 parent b54ca3f commit 51d4d91

2 files changed

Lines changed: 53 additions & 47 deletions

File tree

devito/passes/iet/asynchrony.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ def _(iet, key=None, tracker=None, sregistry=None, **kwargs):
170170
DummyExpr(i._C_symbol, FieldFromPointer(i._C_symbol, sbase))
171171
)
172172
else:
173-
unpacks.append(DummyExpr(i, FieldFromPointer(i.base, sbase)))
173+
unpacks.append(DummyExpr(i, FieldFromPointer(i.base, sbase), init=True))
174174

175175
body = iet.body._rebuild(body=[wrap, Return(Null)], unpacks=unpacks)
176176
iet = ThreadCallable(iet.name, body, tparameter)

devito/passes/iet/langbase.py

Lines changed: 52 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -407,67 +407,74 @@ def initialize(self, iet, options=None):
407407
is sufficient reuse to implement the logic as a single method.
408408
"""
409409

410-
@singledispatch
411-
def _initialize(iet):
412-
return iet, {}
413-
414-
@_initialize.register(EntryFunction)
415-
def _(iet):
416-
assert iet.body.is_CallableBody
417-
418-
# TODO: we need to pick the rank from `comm_shm`, not `comm`,
419-
# so that we have nranks == ngpus (as long as the user has launched
420-
# the right number of MPI processes per node given the available
421-
# number of GPUs per node)
422-
423-
objcomm = None
410+
def _extract_objcomm(iet):
424411
for i in iet.parameters:
425412
if isinstance(i, MPICommObject):
426-
objcomm = i
427-
break
428-
if objcomm is None and options['mpi']:
429-
# Time to inject `objcomm`. If it's not here, it simply means
430-
# there's no halo exchanges in the Operator, but we now need it
431-
# nonetheless to perform the rank-GPU assignment
413+
return i
414+
415+
# Fallback -- might end up here because the Operator has no
416+
# halo exchanges, but we now need it nonetheless to perform
417+
# the rank-GPU assignment
418+
if options['mpi']:
432419
for i in iet.parameters:
433420
try:
434-
objcomm = i.grid.distributor._obj_comm
435-
break
421+
return i.grid.distributor._obj_comm
436422
except AttributeError:
437423
pass
438424

425+
def _make_setdevice_seq(iet, nodes=()):
439426
devicetype = as_list(self.langbb[self.platform])
440427
deviceid = self.deviceid
441428

429+
return list(nodes) + [Conditional(
430+
CondNe(deviceid, -1),
431+
self.langbb['set-device']([deviceid] + devicetype)
432+
)]
433+
434+
def _make_setdevice_mpi(iet, objcomm, nodes=()):
435+
devicetype = as_list(self.langbb[self.platform])
436+
deviceid = self.deviceid
437+
438+
rank = Symbol(name='rank')
439+
rank_decl = DummyExpr(rank, 0)
440+
rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])
441+
442+
ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
443+
444+
osdd_then = self.langbb['set-device']([deviceid] + devicetype)
445+
osdd_else = self.langbb['set-device']([rank % ngpus] + devicetype)
446+
447+
return list(nodes) + [Conditional(
448+
CondNe(deviceid, -1),
449+
osdd_then,
450+
List(body=[rank_decl, rank_init, call_ngpus, osdd_else]),
451+
)]
452+
453+
@singledispatch
454+
def _initialize(iet):
455+
return iet, {}
456+
457+
@_initialize.register(EntryFunction)
458+
def _(iet):
459+
assert iet.body.is_CallableBody
460+
461+
devicetype = as_list(self.langbb[self.platform])
462+
442463
try:
443464
lang_init = [self.langbb['init'](devicetype)]
444465
except TypeError:
445466
# Not all target languages need to be explicitly initialized
446467
lang_init = []
447468

448-
if objcomm is not None:
449-
rank = Symbol(name='rank')
450-
rank_decl = DummyExpr(rank, 0)
451-
rank_init = Call('MPI_Comm_rank', [objcomm, Byref(rank)])
452-
453-
ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
454-
455-
osdd_then = self.langbb['set-device']([deviceid] + devicetype)
456-
osdd_else = self.langbb['set-device']([rank % ngpus] + devicetype)
469+
objcomm = _extract_objcomm(iet)
457470

458-
body = lang_init + [Conditional(
459-
CondNe(deviceid, -1),
460-
osdd_then,
461-
List(body=[rank_decl, rank_init, call_ngpus, osdd_else]),
462-
)]
471+
if objcomm is not None:
472+
body = _make_setdevice_mpi(iet, objcomm, nodes=lang_init)
463473

464474
header = c.Comment('Beginning of %s+MPI setup' % self.langbb['name'])
465475
footer = c.Comment('End of %s+MPI setup' % self.langbb['name'])
466476
else:
467-
body = lang_init + [Conditional(
468-
CondNe(deviceid, -1),
469-
self.langbb['set-device']([deviceid] + devicetype)
470-
)]
477+
body = _make_setdevice_seq(iet, nodes=lang_init)
471478

472479
header = c.Comment('Beginning of %s setup' % self.langbb['name'])
473480
footer = c.Comment('End of %s setup' % self.langbb['name'])
@@ -479,13 +486,12 @@ def _(iet):
479486

480487
@_initialize.register(AsyncCallable)
481488
def _(iet):
482-
devicetype = as_list(self.langbb[self.platform])
483-
deviceid = self.deviceid
489+
objcomm = _extract_objcomm(iet)
490+
if objcomm is not None:
491+
init = _make_setdevice_mpi(iet, objcomm)
492+
else:
493+
init = _make_setdevice_seq(iet)
484494

485-
init = Conditional(
486-
CondNe(deviceid, -1),
487-
self.langbb['set-device']([deviceid] + devicetype)
488-
)
489495
iet = iet._rebuild(body=iet.body._rebuild(init=init))
490496

491497
return iet, {}

0 commit comments

Comments
 (0)