@@ -20,8 +20,13 @@ def test_init_omp_env(self):
2020
2121 op = Operator (Eq (u .forward , u .dx + 1 ), language = 'openmp' )
2222
23- assert str (op .body .init [0 ].body [0 ]) == \
24- 'if (deviceid != -1)\n {\n omp_set_default_device(deviceid);\n }'
23+ # With device validation, the generated code now includes validation logic
24+ init_code = str (op .body .init [0 ].body [0 ])
25+ assert 'if (deviceid != -1)' in init_code
26+ assert 'int ngpus = omp_get_num_devices()' in init_code
27+ assert 'if (deviceid >= ngpus)' in init_code
28+ assert 'does not exist' in init_code
29+ assert 'omp_set_default_device(deviceid)' in init_code
2530
2631 @pytest .mark .parallel (mode = 1 )
2732 def test_init_omp_env_w_mpi (self , mode ):
@@ -31,14 +36,42 @@ def test_init_omp_env_w_mpi(self, mode):
3136
3237 op = Operator (Eq (u .forward , u .dx + 1 ), language = 'openmp' )
3338
34- assert str (op .body .init [0 ].body [0 ]) == \
35- ('if (deviceid != -1)\n '
36- '{\n omp_set_default_device(deviceid);\n }\n '
37- 'else\n '
38- '{\n int rank = 0;\n '
39- ' MPI_Comm_rank(comm,&rank);\n '
40- ' int ngpus = omp_get_num_devices();\n '
41- ' omp_set_default_device((rank)%(ngpus));\n }' )
39+ # With device validation, the MPI case also includes validation for explicit deviceid
40+ init_code = str (op .body .init [0 ].body [0 ])
41+ assert 'if (deviceid != -1)' in init_code
42+ assert 'int ngpus = omp_get_num_devices()' in init_code
43+ # For MPI case with explicit deviceid, should have validation
44+ assert 'if (deviceid >= ngpus)' in init_code
45+ assert 'does not exist' in init_code
46+ # Should still have MPI rank-based assignment in else clause
47+ assert 'int rank = 0' in init_code
48+ assert 'MPI_Comm_rank(comm,&rank)' in init_code
49+ assert '(rank)%(ngpus)' in init_code
50+
51+ def test_device_validation_error_message (self ):
52+ """Test that device validation includes helpful error messages."""
53+ grid = Grid (shape = (3 , 3 , 3 ))
54+
55+ u = TimeFunction (name = 'u' , grid = grid )
56+
57+ op = Operator (Eq (u .forward , u .dx + 1 ), language = 'openmp' )
58+
59+ # Check that the generated code contains device validation with informative error
60+ code = str (op )
61+
62+ # Should contain device count check
63+ assert 'omp_get_num_devices()' in code , "Missing device count check"
64+
65+ # Should contain validation condition
66+ assert 'deviceid >= ngpus' in code , "Missing device ID validation condition"
67+
68+ # Should contain helpful error message components
69+ assert 'does not exist' in code , "Missing 'does not exist' error message"
70+ assert 'CUDA_VISIBLE_DEVICES' in code , "Missing CUDA_VISIBLE_DEVICES guidance"
71+ assert 'container GPU configuration' in code , "Missing container guidance"
72+
73+ # Should contain exit call to prevent undefined behavior
74+ assert 'exit(1)' in code , "Missing exit call on validation failure"
4275
4376 def test_basic (self ):
4477 grid = Grid (shape = (3 , 3 , 3 ))
0 commit comments