Skip to content

Commit 64d1e73

Browse files
EdCauntmloubout
authored andcommitted
compiler: Refine handling of device environment variables and associated tests
1 parent 61eb537 commit 64d1e73

3 files changed

Lines changed: 8 additions & 6 deletions

File tree

devito/operator/operator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1435,7 +1435,8 @@ def nbytes_avail_mapper(self):
14351435

14361436
# The amount of space available on the device
14371437
if isinstance(self.platform, Device):
1438-
mapper[device_layer] = self.platform.memavail(deviceid=self._physical_deviceid)
1438+
mapper[device_layer] = \
1439+
self.platform.memavail(deviceid=self._physical_deviceid)
14391440

14401441
# The amount of space available on the host
14411442
try:

devito/parameters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
265265
except ValueError:
266266
# E.g., `platform` and `compiler` will end up here
267267
super(Parameters, configuration).__setitem__(k, self.previous[k])
268-
268+
269269

270270
class switchenv(abstractswitch):
271271
"""
@@ -279,7 +279,7 @@ def __init__(self, condition=True, **params):
279279
if condition:
280280
# Environment variables are essentially always uppercase
281281
self.params = {k.upper(): v for k, v in params.items()}
282-
else:
282+
else:
283283
self.params = params
284284

285285
def __enter__(self, condition=True, **params):

tests/test_gpu_common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ def test_visible_devices(self, env_variables):
9595

9696
eq = Eq(u, u+1)
9797

98+
# Save previous environment to verify switchenv works as intended
99+
previous_environ = dict(os.environ)
100+
98101
with switchenv(**env_variables):
99102
op1 = Operator(eq)
100103

@@ -103,8 +106,7 @@ def test_visible_devices(self, env_variables):
103106
assert argmap1._physical_deviceid == 1
104107

105108
# Make sure the switchenv doesn't somehow persist
106-
for i in ("CUDA", "ROCR", "HIP"):
107-
assert f"{i}_VISIBLE_DEVICES" not in os.environ
109+
assert dict(os.environ) == previous_environ
108110

109111
# Check that physical deviceid is 0 when no environment variables set
110112
op2 = Operator(eq)
@@ -113,7 +115,6 @@ def test_visible_devices(self, env_variables):
113115
# Default physical deviceid expected to be 0
114116
assert argmap2._physical_deviceid == 0
115117

116-
117118
@pytest.mark.parallel(mode=2)
118119
@pytest.mark.parametrize('visible_devices', ["1,2", "1,0", "0,2,3"])
119120
def test_visible_devices_mpi(self, visible_devices, mode):

0 commit comments

Comments
 (0)