Skip to content

Commit 7afa536

Browse files
committed
tests: Expand specialization tests
1 parent 29a255e commit 7afa536

1 file changed

Lines changed: 21 additions & 15 deletions

File tree

tests/test_specialization.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import logging
22

3-
import sympy
4-
import pytest
5-
63
import numpy as np
4+
import pytest
5+
import sympy
76

8-
from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension,
9-
ConditionalDimension, switchconfig)
7+
from devito import (
8+
ConditionalDimension, Dimension, Eq, Function, Grid, Operator, SubDomain,
9+
TimeFunction, switchconfig
10+
)
1011
from devito.ir.iet.visitors import Specializer
1112

1213

@@ -66,14 +67,14 @@ def check_op(mapper, operator):
6667
mappers = (mapper0, mapper1, mapper2, mapper3)
6768
ops = tuple(Specializer(m).visit(op) for m in mappers)
6869

69-
for m, o in zip(mappers, ops):
70+
for m, o in zip(mappers, ops, strict=True):
7071
check_op(m, o)
7172

7273
def test_subdomain(self):
7374
"""Test that SubDomain thicknesses can be specialized"""
7475

7576
def check_op(mapper, operator):
76-
for k in mapper.keys():
77+
for k in mapper:
7778
assert k not in operator.parameters
7879
assert k.name not in str(operator.ccode)
7980

@@ -105,7 +106,7 @@ def define(self, dimensions):
105106
mappers = (mapper0, mapper1, mapper2)
106107
ops = tuple(Specializer(m).visit(op) for m in mappers)
107108

108-
for m, o in zip(mappers, ops):
109+
for m, o in zip(mappers, ops, strict=True):
109110
check_op(m, o)
110111

111112
def test_factor(self):
@@ -208,7 +209,7 @@ def test_basic(self, caplog, override):
208209
kwargs['x_m'] = 3
209210

210211
with switchconfig(log_level='DEBUG'), caplog.at_level(logging.DEBUG):
211-
op.apply(specialize=specialize, **kwargs)
212+
op.apply_specialize(specialize=specialize, **kwargs)
212213

213214
# Ensure that the specialized operator was run
214215
assert all(s not in caplog.text for s in specialize)
@@ -225,7 +226,14 @@ def test_basic(self, caplog, override):
225226
def test_basic_mpi(self, caplog, mode, override):
226227
self.test_basic(caplog, override)
227228

228-
def test_diffusion_like(self):
229+
@pytest.mark.parametrize('specialize',
230+
[('x_m',),
231+
('y_M',),
232+
('t_m',),
233+
('t_m', 't_M'),
234+
('x_m', 'y_M'),
235+
('x_m', 'x_M', 'y_m', 'y_M')])
236+
def test_diffusion_like(self, specialize):
229237
grid = Grid(shape=(11, 11))
230238

231239
dt = 2.5e-5
@@ -237,15 +245,13 @@ def test_diffusion_like(self):
237245

238246
op.apply(t_M=100, dt=dt)
239247

240-
check = np.array(f.data[0])
248+
check = np.array(f.data)
241249
f.data[:] = 0
242250
f.data[:, 4:-4, 4:-4] = 1
243251

244-
op.apply(t_M=100, dt=dt, specialize=tuple())
252+
op.apply_specialize(t_M=100, dt=dt, specialize=specialize)
245253

246-
print(f.data[0])
247-
print(check)
248-
assert False
254+
assert np.all(np.isclose(check, f.data))
249255

250256
# Diffusion-like test
251257
# Acoustic-like test (with and without source injection)

0 commit comments

Comments
 (0)