Skip to content

Commit 82704f3

Browse files
EdCaunt authored and mloubout committed
tests: Add tests for using CUDA_VISIBLE_DEVICES and similar
1 parent 212a045 commit 82704f3

3 files changed

Lines changed: 147 additions & 27 deletions

File tree

devito/operator/operator.py

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1390,21 +1390,36 @@ def _get_nbytes(self, i):
13901390
return nbytes
13911391

13921392
@cached_property
1393-
def visible_devices(self):
1393+
def _visible_devices(self):
13941394
device_vars = (
13951395
'CUDA_VISIBLE_DEVICES',
13961396
'ROCR_VISIBLE_DEVICES',
13971397
'HIP_VISIBLE_DEVICES'
13981398
)
13991399
for v in device_vars:
1400-
if v in os.environ:
1401-
try:
1402-
return tuple(int(i) for i in os.environ[v].split(','))
1403-
except ValueError:
1404-
# Visible devices set via UUIDs or other non-integer identifiers
1405-
continue
1400+
try:
1401+
return tuple(int(i) for i in os.environ[v].split(','))
1402+
except (ValueError, KeyError):
1403+
# Environment variable not set or visible devices set via UUIDs
1404+
# or other non-integer identifiers.
1405+
continue
14061406

14071407
return None
1408+
1409+
@cached_property
def _physical_deviceid(self):
    """
    The physical ID of the device targeted by this process.

    Returns
    -------
    int or None
        None if `self.platform` is not a `Device`. Otherwise the logical
        device ID, computed as the 'deviceid' entry (default 0, negative
        values clamped to 0) plus the MPI rank (0 if the communicator is
        `MPI.COMM_NULL`). If `self._visible_devices` is not None, the
        logical ID is used as an index into it and the entry found there
        is returned instead.
    """
    if not isinstance(self.platform, Device):
        return None

    # Without a valid communicator there is no rank offset to apply
    if self.comm != MPI.COMM_NULL:
        rank = self.comm.Get_rank()
    else:
        rank = 0

    logical = max(self.get('deviceid', 0), 0) + rank

    # Translate logical -> physical, as e.g. CUDA_VISIBLE_DEVICES may be set
    visible = self._visible_devices
    if visible is None:
        return logical
    return visible[logical]
14081423

14091424
@cached_property
14101425
def nbytes_avail_mapper(self):
@@ -1420,16 +1435,7 @@ def nbytes_avail_mapper(self):
14201435

14211436
# The amount of space available on the device
14221437
if isinstance(self.platform, Device):
1423-
# Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set)
1424-
rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0
1425-
1426-
logical_deviceid = max(self.get('deviceid', 0), 0) + rank
1427-
if self.visible_devices is not None:
1428-
physical_deviceid = self.visible_devices[logical_deviceid]
1429-
else:
1430-
physical_deviceid = logical_deviceid
1431-
1432-
mapper[device_layer] = self.platform.memavail(deviceid=physical_deviceid)
1438+
mapper[device_layer] = self.platform.memavail(deviceid=self._physical_deviceid)
14331439

14341440
# The amount of space available on the host
14351441
try:

devito/parameters.py

Lines changed: 39 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from devito.tools import Signer, filter_ordered
99

1010
__all__ = ['configuration', 'init_configuration', 'print_defaults', 'print_state',
11-
'switchconfig']
11+
'switchconfig', 'switchenv']
1212

1313
# Be EXTREMELY careful when writing to a Parameters dictionary
1414
# Read here for reference: http://wiki.c2.com/?GlobalVariablesAreBad
@@ -224,7 +224,22 @@ def init_configuration(configuration=configuration, env_vars_mapper=env_vars_map
224224
configuration.initialize()
225225

226226

227-
class switchconfig:
227+
class abstractswitch:

    """
    Base class for the `switch<something>` decorators.

    Calling an instance on a function returns a wrapper that runs the
    function inside `with self:`. The `__enter__` and `__exit__` methods
    used by that `with` statement are not defined here.
    """

    def __call__(self, func, *args, **kwargs):
        # NOTE: the `*args`/`**kwargs` accepted here are not used; only the
        # arguments supplied when the wrapper is invoked reach `func`
        @wraps(func)
        def decorated(*call_args, **call_kwargs):
            with self:
                outcome = func(*call_args, **call_kwargs)
            # Returned only after the context has been exited
            return outcome

        return decorated
240+
241+
242+
class switchconfig(abstractswitch):
228243

229244
"""
230245
Decorator or context manager to temporarily change `configuration` parameters.
@@ -250,14 +265,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
250265
except ValueError:
251266
# E.g., `platform` and `compiler` will end up here
252267
super(Parameters, configuration).__setitem__(k, self.previous[k])
268+
253269

254-
def __call__(self, func, *args, **kwargs):
255-
@wraps(func)
256-
def wrapper(*args, **kwargs):
257-
with self:
258-
result = func(*args, **kwargs)
259-
return result
260-
return wrapper
270+
class switchenv(abstractswitch):

    """
    Decorator or context manager to temporarily change environment variables.
    Adapted from https://stackoverflow.com/questions/2059482/

    Parameters
    ----------
    condition : bool, optional
        If True (default), the keyword names are upper-cased before being
        used as environment variable names. If False, they are used verbatim.
    **params
        The environment variables to set, as `name=value` pairs. Values are
        handed to `environ.update` unchanged.

    Notes
    -----
    The environment is snapshotted every time the context is entered and
    that snapshot is restored on exit. A snapshot taken once at
    construction time would be stale whenever the instance is used as a
    decorator, since it is then built long before the decorated function
    runs.
    """

    def __init__(self, condition=True, **params):
        # `previous` is still populated here so the attribute always exists;
        # it is refreshed on every `__enter__`
        self.previous = dict(environ)
        # One snapshot per active `with`, so that re-entering the same
        # instance (e.g. a decorated function that recurses) unwinds correctly
        self._snapshots = []

        if condition:
            # Environment variables are essentially always uppercase
            self.params = {k.upper(): v for k, v in params.items()}
        else:
            self.params = params

    def __enter__(self, condition=True, **params):
        # `condition` and `params` are retained in the signature for backward
        # compatibility only; the values given to `__init__` are the ones used
        self.previous = dict(environ)
        self._snapshots.append(self.previous)
        environ.update(self.params)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Fall back to the latest known snapshot if `__exit__` is invoked
        # without a matching `__enter__`
        snapshot = self._snapshots.pop() if self._snapshots else self.previous
        environ.clear()
        environ.update(snapshot)
261291

262292

263293
def print_defaults():

tests/test_gpu_common.py

Lines changed: 85 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import cloudpickle as pickle
22

33
import pytest
4+
import os
45
import numpy as np
56
import sympy
67
import scipy.sparse
@@ -10,7 +11,7 @@
1011
Dimension, MatrixSparseTimeFunction, SparseTimeFunction,
1112
SubDimension, SubDomain, SubDomainSet, TimeFunction, exp,
1213
Operator, configuration, switchconfig, TensorTimeFunction,
13-
Buffer, assign)
14+
Buffer, assign, switchenv)
1415
from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64
1516
from devito.exceptions import InvalidArgument
1617
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
@@ -74,6 +75,89 @@ def test_autopad_with_platform_switch(self):
7475
assert f.shape_allocated[1] == 64
7576

7677

78+
class TestEnvironmentVariables:

    """
    Check that visible-device environment variables are taken into account
    when determining the physical device ID.
    """

    @pytest.mark.parametrize('env_variables', [{"cuda_visible_devices": "1"},
                                               {"cuda_visible_devices": "1,2"},
                                               {"cuda_visible_devices": "1,0"},
                                               {"rocr_visible_devices": "1"},
                                               {"hip_visible_devices": " 1"}])
    def test_visible_devices(self, env_variables):
        """
        With any of the visible-device variables set, the physical device ID
        must be the first listed device; with none set, it must be 0.
        """
        grid = Grid(shape=(10, 10))
        u = Function(name='u', grid=grid)
        eq = Eq(u, u+1)

        with switchenv(**env_variables):
            op_env = Operator(eq)
            args_env = op_env.arguments()

            # Checked while the variables are still set. Every parametrised
            # case lists device 1 first
            assert args_env._physical_deviceid == 1

        # The temporary environment must not leak out of the context
        for prefix in ("CUDA", "ROCR", "HIP"):
            assert f"{prefix}_VISIBLE_DEVICES" not in os.environ

        # No visible-device variable set: expect the default device, 0
        op_default = Operator(eq)
        args_default = op_default.arguments()
        assert args_default._physical_deviceid == 0

    @pytest.mark.parallel(mode=2)
    @pytest.mark.parametrize('visible_devices', ["1,2", "1,0", "0,2,3"])
    def test_visible_devices_mpi(self, visible_devices, mode):
        """
        Under MPI, each rank must pick the entry of CUDA_VISIBLE_DEVICES at
        the position given by its rank; with the variable unset, the physical
        device ID must equal the rank.
        """
        grid = Grid(shape=(10, 10))
        rank = grid.distributor.myrank
        u = Function(name='u', grid=grid)
        eq = Eq(u, u+1)

        expected = [int(i) for i in visible_devices.split(',')]

        with switchenv(cuda_visible_devices=visible_devices):
            op_env = Operator(eq)
            args_env = op_env.arguments()

            assert args_env._physical_deviceid == expected[rank]

        # In the default case, the physical deviceid will equal the rank
        op_default = Operator(eq)
        args_default = op_default.arguments()
        assert args_default._physical_deviceid == rank

    def test_visible_devices_with_devito_deviceid(self):
        """Test interaction between CUDA_VISIBLE_DEVICES and DEVITO_DEVICEID"""
        grid = Grid(shape=(10, 10))
        u = Function(name='u', grid=grid)
        eq = Eq(u, u+1)

        with switchenv(cuda_visible_devices="1,3"), switchconfig(deviceid=1):
            op = Operator(eq)
            args = op.arguments()

            # The deviceid indexes into the visible devices, so deviceid=1
            # selects the second of the two listed devices, i.e. 3
            assert args._physical_deviceid == 3
159+
160+
77161
class TestCodeGeneration:
78162

79163
def test_maxpar_option(self):

0 commit comments

Comments
 (0)