Skip to content

Commit 35304b9

Browse files
EdCauntmloubout
authored andcommitted
compiler: Tweak behaviour of deviceid to respect per-rank DeviceID specified by user
1 parent dc4d345 commit 35304b9

2 files changed

Lines changed: 30 additions & 6 deletions

File tree

devito/operator/operator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,13 +1393,16 @@ def _get_nbytes(self, i):
13931393
def _physical_deviceid(self):
13941394
if isinstance(self.platform, Device):
13951395
# Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set)
1396-
rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0
1396+
logical_deviceid = self.get('deviceid', -1)
1397+
if logical_deviceid < 0:
1398+
rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0
1399+
logical_deviceid = rank
13971400

1398-
logical_deviceid = max(self.get('deviceid', 0), 0) + rank
1399-
if self._visible_devices is None:
1401+
visible_devices = get_visible_devices()
1402+
if visible_devices is None:
14001403
return logical_deviceid
14011404
else:
1402-
return get_visible_devices()[logical_deviceid]
1405+
return visible_devices[logical_deviceid]
14031406
else:
14041407
return None
14051408

tests/test_gpu_common.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ def test_autopad_with_platform_switch(self):
7575
assert f.shape_allocated[1] == 64
7676

7777

78-
class TestEnvironmentVariables:
78+
class TestDeviceID:
7979
"""
80-
Test that environment variables are correctly handled.
80+
Test that device IDs and associated environment variables such as
81+
CUDA_VISIBLE_DEVICES are correctly handled.
8182
"""
8283

8384
@pytest.mark.parametrize('env_variables', [{"cuda_visible_devices": "1"},
@@ -158,6 +159,26 @@ def test_visible_devices_with_devito_deviceid(self):
158159
# So should be the second of the two visible devices specified (3)
159160
assert argmap._physical_deviceid == 3
160161

162+
@pytest.mark.parallel(mode=2)
163+
def test_deviceid_per_rank(self, mode):
164+
"""
165+
Test that Device IDs set by the user on a per-rank basis do not
166+
get modifed.
167+
"""
168+
# Reversed order to ensure it is different to default
169+
user_set_deviceids = (1, 0)
170+
171+
grid = Grid(shape=(10, 10))
172+
u = Function(name='u', grid=grid)
173+
174+
rank = grid.distributor.myrank
175+
deviceid = user_set_deviceids[rank]
176+
177+
op = Operator(Eq(u, u+1))
178+
179+
argmap = op.arguments(deviceid=deviceid)
180+
assert argmap._physical_deviceid == deviceid
181+
161182

162183
class TestCodeGeneration:
163184

0 commit comments

Comments
 (0)