Skip to content

Commit 9d2e0c4

Browse files
Merge pull request #2769 from devitocodes/JDBetteridge/remove_switchenv
misc: Un-cargo-cult switchenv
2 parents b4d8fe9 + ab31741 commit 9d2e0c4

3 files changed

Lines changed: 64 additions & 52 deletions

File tree

devito/parameters.py

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""The parameters dictionary contains global parameter settings."""
2-
2+
from abc import ABC, abstractmethod
33
from collections import OrderedDict
4-
from os import environ
4+
import os
55
from functools import wraps
66

77
from devito.logger import info, warning
@@ -170,24 +170,24 @@ def _signature_items(self):
170170
def init_configuration(configuration=configuration, env_vars_mapper=env_vars_mapper,
171171
env_vars_deprecated=env_vars_deprecated):
172172
# Populate `configuration` with user-provided options
173-
if environ.get('DEVITO_CONFIG') is None:
173+
if os.environ.get('DEVITO_CONFIG') is None:
174174
# It is important to configure `platform`, `compiler`, and the rest, in this order
175175
process_order = filter_ordered(['platform', 'compiler'] +
176176
list(env_vars_mapper.values()))
177177
queue = sorted(env_vars_mapper.items(), key=lambda i: process_order.index(i[1]))
178-
unprocessed = OrderedDict([(v, environ.get(k, configuration._defaults[v]))
178+
unprocessed = OrderedDict([(v, os.environ.get(k, configuration._defaults[v]))
179179
for k, v in queue])
180180

181181
# Handle deprecated env vars
182182
mapper = dict(queue)
183183
for k, (v, msg) in env_vars_deprecated.items():
184-
if environ.get(k):
184+
if os.environ.get(k):
185185
warning(f"`{k}` is deprecated. {msg}")
186-
if environ.get(v):
186+
if os.environ.get(v):
187187
warning(f"Both `{k}` and `{v}` set. Ignoring `{k}`")
188188
else:
189-
warning(f"Setting `{v}={environ[k]}`")
190-
unprocessed[mapper[v]] = environ[k]
189+
warning(f"Setting `{v}={os.environ[k]}`")
190+
unprocessed[mapper[v]] = os.environ[k]
191191
else:
192192
# Attempt reading from the specified configuration file
193193
raise NotImplementedError("Devito doesn't support configuration via file yet.")
@@ -224,13 +224,25 @@ def init_configuration(configuration=configuration, env_vars_mapper=env_vars_map
224224
configuration.initialize()
225225

226226

227-
class abstractswitch:
228-
227+
class SwitchDecorator(ABC):
229228
"""
230-
Abstract class for switch(whatever) decorators.
229+
Abstract base class that turns a context manager class into a decorator.
231230
"""
231+
@abstractmethod
232+
def __init__(self, *args, **kwargs):
233+
pass
234+
235+
@abstractmethod
236+
def __enter__(self):
237+
pass
238+
239+
@abstractmethod
240+
def __exit__(self, exc_type, exc_val, traceback):
241+
pass
232242

233243
def __call__(self, func, *args, **kwargs):
244+
""" The call method turns the context manager class into a decorator
245+
"""
234246
@wraps(func)
235247
def wrapper(*args, **kwargs):
236248
with self:
@@ -239,26 +251,24 @@ def wrapper(*args, **kwargs):
239251
return wrapper
240252

241253

242-
class switchconfig(abstractswitch):
243-
254+
class switchconfig(SwitchDecorator):
244255
"""
245256
Decorator or context manager to temporarily change `configuration` parameters.
246257
"""
247-
248258
def __init__(self, condition=True, **params):
249259
if condition:
250260
self.params = {k.replace('_', '-'): v for k, v in params.items()}
251261
else:
252262
self.params = {}
253263
self.previous = {}
254264

255-
def __enter__(self, condition=True, **params):
265+
def __enter__(self):
256266
self.previous = {}
257267
for k, v in self.params.items():
258268
self.previous[k] = configuration[k]
259269
configuration[k] = v
260270

261-
def __exit__(self, exc_type, exc_val, exc_tb):
271+
def __exit__(self, exc_type, exc_val, traceback):
262272
for k, v in self.params.items():
263273
try:
264274
configuration[k] = self.previous[k]
@@ -267,27 +277,23 @@ def __exit__(self, exc_type, exc_val, exc_tb):
267277
super(Parameters, configuration).__setitem__(k, self.previous[k])
268278

269279

270-
class switchenv(abstractswitch):
280+
class switchenv(SwitchDecorator):
271281
"""
272-
Decorator to temporarily change environment variables.
273-
Adapted from https://stackoverflow.com/questions/2059482/
274-
"""
275-
276-
def __init__(self, condition=True, **params):
277-
self.previous = dict(environ)
282+
Temporarily set environment variables from a dictionary
278283
279-
if condition:
280-
# Environment variables are essentially always uppercase
281-
self.params = {k.upper(): v for k, v in params.items()}
282-
else:
283-
self.params = params
284+
Note: This does not propagate any environment variables that change inside
285+
the context manager, so should be used cautiously.
286+
"""
287+
def __init__(self, params):
288+
self.previous = dict(os.environ)
289+
self.params = params
284290

285-
def __enter__(self, condition=True, **params):
286-
environ.update(self.params)
291+
def __enter__(self):
292+
os.environ.update(self.params)
287293

288-
def __exit__(self, exc_type, exc_val, exc_tb):
289-
environ.clear()
290-
environ.update(self.previous)
294+
def __exit__(self, exc_type, exc_val, traceback):
295+
os.environ.clear()
296+
os.environ.update(self.previous)
291297

292298

293299
def print_defaults():

tests/test_gpu_common.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import cloudpickle as pickle
22

33
import pytest
4-
import os
54
import numpy as np
65
import sympy
76
import scipy.sparse
@@ -81,11 +80,11 @@ class TestDeviceID:
8180
CUDA_VISIBLE_DEVICES are correctly handled.
8281
"""
8382

84-
@pytest.mark.parametrize('env_variables', [{"cuda_visible_devices": "1"},
85-
{"cuda_visible_devices": "1,2"},
86-
{"cuda_visible_devices": "1,0"},
87-
{"rocr_visible_devices": "1"},
88-
{"hip_visible_devices": " 1"}])
83+
@pytest.mark.parametrize('env_variables', [{"CUDA_VISIBLE_DEVICES": "1"},
84+
{"CUDA_VISIBLE_DEVICES": "1,2"},
85+
{"CUDA_VISIBLE_DEVICES": "1,0"},
86+
{"ROCR_VISIBLE_DEVICES": "1"},
87+
{"HIP_VISIBLE_DEVICES": " 1"}])
8988
def test_visible_devices(self, env_variables):
9089
"""
9190
Test that physical device IDs used for querying memory on a device via
@@ -96,19 +95,12 @@ def test_visible_devices(self, env_variables):
9695

9796
eq = Eq(u, u+1)
9897

99-
# Save previous environment to verify switchenv works as intended
100-
previous_environ = dict(os.environ)
101-
102-
with switchenv(**env_variables):
98+
with switchenv(env_variables):
10399
op1 = Operator(eq)
104-
105100
argmap1 = op1.arguments()
106101
# All variants in parameterisation should yield deviceid 1
107102
assert argmap1._physical_deviceid == 1
108103

109-
# Make sure the switchenv doesn't somehow persist
110-
assert dict(os.environ) == previous_environ
111-
112104
# Check that physical deviceid is 0 when no environment variables set
113105
op2 = Operator(eq)
114106

@@ -131,12 +123,10 @@ def test_visible_devices_mpi(self, visible_devices, mode):
131123

132124
eq = Eq(u, u+1)
133125

134-
with switchenv(cuda_visible_devices=visible_devices):
126+
with switchenv({'CUDA_VISIBLE_DEVICES': visible_devices}):
135127
op1 = Operator(eq)
136128
argmap1 = op1.arguments()
137-
138129
devices = [int(i) for i in visible_devices.split(',')]
139-
140130
assert argmap1._physical_deviceid == devices[rank]
141131

142132
# In default case, physical deviceid will equal rank
@@ -151,7 +141,7 @@ def test_visible_devices_with_devito_deviceid(self):
151141

152142
eq = Eq(u, u+1)
153143

154-
with switchenv(cuda_visible_devices="1,3"), switchconfig(deviceid=1):
144+
with switchenv({'CUDA_VISIBLE_DEVICES': "1,3"}), switchconfig(deviceid=1):
155145
op = Operator(eq)
156146

157147
argmap = op.arguments()

tests/test_tools.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import os
12
import numpy as np
23
import pytest
34
from sympy.abc import a, b, c, d, e
45

56
import time
67

7-
from devito import Operator, Eq
8+
from devito import Operator, Eq, switchenv
89
from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
910
filter_ordered, transitive_closure, UnboundTuple,
1011
CacheInstances)
@@ -209,3 +210,18 @@ def __init__(self, value: int):
209210
# Cache should be cleared after Operator construction
210211
cache_size = Object._instance_cache.cache_info()[-1]
211212
assert cache_size == 0
213+
214+
215+
def test_switchenv():
216+
# Save previous environment
217+
previous_environ = dict(os.environ)
218+
219+
# Check a temporary variable is set inside the context manager
220+
with switchenv({'TEST_VAR': 'foo'}):
221+
assert os.environ['TEST_VAR'] == 'foo'
222+
223+
# Check a temporary variable is unset inside the context manager
224+
assert os.environ.get('TEST_VAR') is None
225+
226+
# Make sure the switchenv does not persist to verify switchenv works as intended
227+
assert dict(os.environ) == previous_environ

0 commit comments

Comments
 (0)