Skip to content

Commit 80a3d71

Browse files
committed
tests: Introduce further tests
1 parent 7b362eb commit 80a3d71

2 files changed

Lines changed: 34 additions & 4 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import ctypes
1212

1313
import cgen as c
14-
from sympy import IndexedBase
14+
from sympy import IndexedBase, Number
1515
from sympy.core.function import Application
1616

1717
from devito.exceptions import CompilationError
@@ -1515,10 +1515,17 @@ def __init__(self, mapper, nested=False):
15151515
super().__init__(mapper, nested=nested)
15161516

15171517
# Sanity check
1518-
for k in self.mapper.keys():
1518+
for k, v in self.mapper.items():
1519+
# FIXME: Erronously blocks f_vec->size[1]
1520+
# Apparently this is an IndexedPointer
15191521
if not isinstance(k, AbstractSymbol):
15201522
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
15211523

1524+
if not isinstance(v, Number):
1525+
raise ValueError("Only SymPy Numbers can used to replace values during "
1526+
f"specialization. Value {v} was supplied for symbol "
1527+
f"{k}, but is of type {type(v)}.")
1528+
15221529
def visit_Operator(self, o, **kwargs):
15231530
# Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
15241531
# is the intended use case

tests/test_specialization.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def define(self, dimensions):
120120
for m, o in zip(mappers, ops):
121121
check_op(m, o)
122122

123-
# FIXME: Currently throws an error
123+
# FIXME: Currently throws an error - probably a missing handler for GuardFactor
124+
# in Uxreplace
124125
# def test_factor(self):
125126
# """Test that ConditionalDimensions can have their symbolic factors specialized"""
126127
# size = 16
@@ -141,6 +142,28 @@ def define(self, dimensions):
141142
# assert ci.symbolic_factor.name not in str(op1.ccode)
142143
# assert "if ((i)%(4) == 0)" in str(op1.ccode)
143144

144-
# Spacings
145+
def test_spacing(self):
146+
"""Test that grid spacings can be specialized"""
147+
grid = Grid(shape=(11,))
148+
f = Function(name='f', grid=grid)
149+
150+
op0 = Operator(Eq(f, f.dx))
151+
152+
mapper = {grid.dimensions[0].spacing: sympy.Float(grid.spacing[0])}
153+
op1 = Specializer(mapper).visit(op0)
154+
155+
assert grid.dimensions[0].spacing not in op1.parameters
156+
assert grid.dimensions[0].spacing.name not in str(op1.ccode)
157+
assert "/1.0e-1F;" in str(op1.ccode)
145158

146159
# Strides/sizes
160+
def test_strides(self):
161+
"""Test that strides and sizes generated for linearization can be specialized"""
162+
grid = Grid(shape=(11, 11))
163+
164+
f = TimeFunction(name='f', grid=grid, space_order=2)
165+
166+
op0 = Operator(Eq(f.forward, f.dx2),
167+
opt=('advanced', {'expand': True, 'linearize': True}))
168+
169+
from IPython import embed; embed()

0 commit comments

Comments
 (0)