Skip to content

Commit 2d3a7e5

Browse files
Copilot and ggorman committed
Implement device ID validation for GPU operators
Co-authored-by: ggorman <5394691+ggorman@users.noreply.github.com>
1 parent 4b5e53e commit 2d3a7e5

3 files changed

Lines changed: 108 additions & 13 deletions

File tree

devito/passes/iet/langbase.py

Lines changed: 40 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from devito.mpi.distributed import MPICommObject
1212
from devito.passes import is_on_device
1313
from devito.passes.iet.engine import iet_pass
14-
from devito.symbolics import Byref, CondNe, SizeOf
14+
from devito.symbolics import Byref, CondNe, CondGe, SizeOf
1515
from devito.tools import as_list, is_integer, prod
1616
from devito.types import Symbol, QueueID, Wildcard
1717

@@ -426,9 +426,30 @@ def _make_setdevice_seq(iet, nodes=()):
426426
devicetype = as_list(self.langbb[self.platform])
427427
deviceid = self.deviceid
428428

429+
# Add device validation check
430+
ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
431+
432+
# Create validation: if deviceid >= num_devices, print error and exit
433+
validation_check = Conditional(
434+
CondGe(deviceid, ngpus),
435+
List(body=[
436+
Call('printf', ['"%s: Error - Requested device ID %d does not exist. '
437+
'Only %d device(s) available. Check CUDA_VISIBLE_DEVICES '
438+
'and container GPU configuration.\\n"',
439+
self.langbb['name'], deviceid, ngpus]),
440+
Call('exit', [1])
441+
])
442+
)
443+
444+
device_setup = List(body=[
445+
call_ngpus,
446+
validation_check,
447+
self.langbb['set-device']([deviceid] + devicetype)
448+
])
449+
429450
return list(nodes) + [Conditional(
430451
CondNe(deviceid, -1),
431-
self.langbb['set-device']([deviceid] + devicetype)
452+
device_setup
432453
)]
433454

434455
def _make_setdevice_mpi(iet, objcomm, nodes=()):
@@ -441,7 +462,23 @@ def _make_setdevice_mpi(iet, objcomm, nodes=()):
441462

442463
ngpus, call_ngpus = self.langbb._get_num_devices(self.platform)
443464

444-
osdd_then = self.langbb['set-device']([deviceid] + devicetype)
465+
# Add device validation check for explicit device ID
466+
validation_check = Conditional(
467+
CondGe(deviceid, ngpus),
468+
List(body=[
469+
Call('printf', ['"%s: Error - Requested device ID %d does not exist. '
470+
'Only %d device(s) available. Check CUDA_VISIBLE_DEVICES '
471+
'and container GPU configuration.\\n"',
472+
self.langbb['name'], deviceid, ngpus]),
473+
Call('exit', [1])
474+
])
475+
)
476+
477+
osdd_then = List(body=[
478+
call_ngpus,
479+
validation_check,
480+
self.langbb['set-device']([deviceid] + devicetype)
481+
])
445482
osdd_else = self.langbb['set-device']([rank % ngpus] + devicetype)
446483

447484
return list(nodes) + [Conditional(

tests/test_gpu_openacc.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,31 @@ def test_op_apply(self):
200200

201201
assert np.all(np.array(u.data[0, :, :, :]) == time_steps)
202202

203+
def test_device_validation_error_message(self):
204+
"""Test that OpenACC device validation includes helpful error messages."""
205+
grid = Grid(shape=(3, 3, 3))
206+
207+
u = TimeFunction(name='u', grid=grid, dtype=np.int32)
208+
209+
op = Operator(Eq(u.forward, u + 1), platform='nvidiaX', language='openacc')
210+
211+
# Check that the generated code contains device validation with informative error
212+
code = str(op)
213+
214+
# Should contain device count check
215+
assert 'acc_get_num_devices' in code, "Missing OpenACC device count check"
216+
217+
# Should contain validation condition
218+
assert 'deviceid >= ngpus' in code, "Missing OpenACC device ID validation condition"
219+
220+
# Should contain helpful error message components
221+
assert 'does not exist' in code, "Missing 'does not exist' error message"
222+
assert 'CUDA_VISIBLE_DEVICES' in code, "Missing CUDA_VISIBLE_DEVICES guidance"
223+
assert 'container GPU configuration' in code, "Missing container guidance"
224+
225+
# Should contain exit call to prevent undefined behavior
226+
assert 'exit(1)' in code, "Missing exit call on validation failure"
227+
203228
def iso_acoustic(self, opt):
204229
shape = (101, 101)
205230
extent = (1000, 1000)

tests/test_gpu_openmp.py

Lines changed: 43 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,8 +20,13 @@ def test_init_omp_env(self):
2020

2121
op = Operator(Eq(u.forward, u.dx+1), language='openmp')
2222

23-
assert str(op.body.init[0].body[0]) ==\
24-
'if (deviceid != -1)\n{\n omp_set_default_device(deviceid);\n}'
23+
# With device validation, the generated code now includes validation logic
24+
init_code = str(op.body.init[0].body[0])
25+
assert 'if (deviceid != -1)' in init_code
26+
assert 'int ngpus = omp_get_num_devices()' in init_code
27+
assert 'if (deviceid >= ngpus)' in init_code
28+
assert 'does not exist' in init_code
29+
assert 'omp_set_default_device(deviceid)' in init_code
2530

2631
@pytest.mark.parallel(mode=1)
2732
def test_init_omp_env_w_mpi(self, mode):
@@ -31,14 +36,42 @@ def test_init_omp_env_w_mpi(self, mode):
3136

3237
op = Operator(Eq(u.forward, u.dx+1), language='openmp')
3338

34-
assert str(op.body.init[0].body[0]) ==\
35-
('if (deviceid != -1)\n'
36-
'{\n omp_set_default_device(deviceid);\n}\n'
37-
'else\n'
38-
'{\n int rank = 0;\n'
39-
' MPI_Comm_rank(comm,&rank);\n'
40-
' int ngpus = omp_get_num_devices();\n'
41-
' omp_set_default_device((rank)%(ngpus));\n}')
39+
# With device validation, the MPI case also includes validation for explicit deviceid
40+
init_code = str(op.body.init[0].body[0])
41+
assert 'if (deviceid != -1)' in init_code
42+
assert 'int ngpus = omp_get_num_devices()' in init_code
43+
# For MPI case with explicit deviceid, should have validation
44+
assert 'if (deviceid >= ngpus)' in init_code
45+
assert 'does not exist' in init_code
46+
# Should still have MPI rank-based assignment in else clause
47+
assert 'int rank = 0' in init_code
48+
assert 'MPI_Comm_rank(comm,&rank)' in init_code
49+
assert '(rank)%(ngpus)' in init_code
50+
51+
def test_device_validation_error_message(self):
52+
"""Test that device validation includes helpful error messages."""
53+
grid = Grid(shape=(3, 3, 3))
54+
55+
u = TimeFunction(name='u', grid=grid)
56+
57+
op = Operator(Eq(u.forward, u.dx+1), language='openmp')
58+
59+
# Check that the generated code contains device validation with informative error
60+
code = str(op)
61+
62+
# Should contain device count check
63+
assert 'omp_get_num_devices()' in code, "Missing device count check"
64+
65+
# Should contain validation condition
66+
assert 'deviceid >= ngpus' in code, "Missing device ID validation condition"
67+
68+
# Should contain helpful error message components
69+
assert 'does not exist' in code, "Missing 'does not exist' error message"
70+
assert 'CUDA_VISIBLE_DEVICES' in code, "Missing CUDA_VISIBLE_DEVICES guidance"
71+
assert 'container GPU configuration' in code, "Missing container guidance"
72+
73+
# Should contain exit call to prevent undefined behavior
74+
assert 'exit(1)' in code, "Missing exit call on validation failure"
4275

4376
def test_basic(self):
4477
grid = Grid(shape=(3, 3, 3))

0 commit comments

Comments
 (0)