Skip to content

Commit 7b362eb

Browse files
committed
tests: Start adding tests for operator specialization
1 parent 99ccbe1 commit 7b362eb

3 files changed

Lines changed: 174 additions & 2 deletions

File tree

devito/ir/iet/visitors.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1504,6 +1504,11 @@ class Specializer(Uxreplace):
15041504
A Transformer to "specialize" a pre-built Operator - that is to replace a given
15051505
set of (scalar) symbols with hard-coded values to free up registers. This will
15061506
yield a "specialized" version of the Operator, specific to a particular setup.
1507+
1508+
Note that the Operator is not re-optimized in response to this replacement - this
1509+
transformation could nominally result in expressions of the form `f + 0` in the
1510+
generated code. If one wants to construct an Operator where such expressions are
1511+
considered, then use of `subs=...` is a better choice.
15071512
"""
15081513

15091514
def __init__(self, mapper, nested=False):
@@ -1515,15 +1520,31 @@ def __init__(self, mapper, nested=False):
15151520
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
15161521

15171522
def visit_Operator(self, o, **kwargs):
1518-
# Entirely fine to apply this to an Operator
1523+
# Entirely fine to apply this to an Operator (unlike Uxreplace) - indeed this
1524+
# is the intended use case
15191525
body = self._visit(o.body)
1526+
1527+
not_params = tuple(i for i in self.mapper if i not in o.parameters)
1528+
if not_params:
1529+
raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
1530+
" found in the Operator parameters")
1531+
1532+
# FIXME: Should also type-check the values supplied against the symbols they are
1533+
# replacing (and cast them if needed?) -> use a try-except on the cast in
1534+
# python-land
1535+
15201536
parameters = tuple(i for i in o.parameters if i not in self.mapper)
15211537

15221538
# Note: the following is not dissimilar to unpickling an Operator
15231539
state = o.__getstate__()
15241540
state['parameters'] = parameters
15251541
state['body'] = body
1526-
state.pop('ccode')
1542+
1543+
try:
1544+
state.pop('ccode')
1545+
except KeyError:
1546+
# C code has not previously been generated for this Operator
1547+
pass
15271548

15281549
# FIXME: These names aren't great
15291550
newargs, newkwargs = o.__getnewargs_ex__()

devito/types/dimension.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ def symbolic_max(self):
175175
"""Symbol defining the maximum point of the Dimension."""
176176
return Scalar(name=self.max_name, dtype=np.int32, is_const=True)
177177

178+
@property
179+
def symbolic_extrema(self):
180+
"""Symbols for the minimum and maximum points of the Dimension"""
181+
return (self.symbolic_min, self.symbolic_max)
182+
178183
@property
179184
def symbolic_incr(self):
180185
"""The increment value while iterating over the Dimension."""

tests/test_specialization.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import sympy
2+
import pytest
3+
4+
from devito import (Grid, Function, TimeFunction, Eq, Operator, SubDomain, Dimension,
5+
ConditionalDimension)
6+
from devito.ir.iet.visitors import Specializer
7+
8+
# Test that specializer replaces symbols as expected
9+
10+
# Create a couple of arbitrary operators
11+
# Reference bounds, subdomains, spacings, constants, conditionaldimensions with symbolic
12+
# factor
13+
# Create a couple of different substitution sets
14+
15+
# Check that all the instances in the kernel are replaced
16+
# Check that all the instances in the parameters are removed
17+
18+
# Check that sanity check catches attempts to specialize non-scalar types
19+
# Check that trying to specialize symbols not in the Operator parameters results
20+
# in an error being thrown
21+
22+
# Check that sizes and strides get specialized when using `linearize=True`
23+
24+
25+
class TestSpecializer:
26+
"""Tests for the Specializer transformer"""
27+
28+
@pytest.mark.parametrize('pre_gen', [True, False])
29+
@pytest.mark.parametrize('expand', [True, False])
30+
def test_bounds(self, pre_gen, expand):
31+
"""Test specialization of dimension bounds"""
32+
grid = Grid(shape=(11, 11))
33+
34+
((x_m, x_M), (y_m, y_M)) = [d.symbolic_extrema for d in grid.dimensions]
35+
time_m = grid.time_dim.symbolic_min
36+
minima = (x_m, y_m, time_m)
37+
maxima = (x_M, y_M)
38+
39+
def check_op(mapper, operator):
40+
for k, v in mapper.items():
41+
assert k not in operator.parameters
42+
assert k.name not in str(operator.ccode)
43+
# Check that the loop bounds are modified correctly
44+
if k in minima:
45+
assert f"{k.name.split('_')[0]} = {v}" in str(operator.ccode)
46+
elif k in maxima:
47+
assert f"{k.name.split('_')[0]} <= {v}" in str(operator.ccode)
48+
49+
f = Function(name='f', grid=grid)
50+
g = Function(name='g', grid=grid)
51+
h = TimeFunction(name='h', grid=grid)
52+
53+
eq0 = Eq(f, f + 1)
54+
eq1 = Eq(g, f.dx)
55+
eq2 = Eq(h.forward, (g + x_m).dy)
56+
eq3 = Eq(f, x_M)
57+
58+
# Check behaviour with expansion since we have a replaced symbol inside a
59+
# derivative
60+
if expand:
61+
kwargs = {'opt': ('advanced', {'expand': True})}
62+
else:
63+
kwargs = {'opt': ('advanced', {'expand': False})}
64+
65+
op = Operator([eq0, eq1, eq2, eq3], **kwargs)
66+
67+
if pre_gen:
68+
# Generate C code for the unspecialized Operator - the result should be
69+
# the same regardless, but it ensures that the old generated code is
70+
# purged and replaced in the specialized Operator
71+
_ = op.ccode
72+
73+
mapper0 = {x_m: sympy.S.Zero}
74+
mapper1 = {x_M: sympy.Integer(20), y_m: sympy.S.Zero}
75+
mapper2 = {**mapper0, **mapper1}
76+
mapper3 = {y_M: sympy.Integer(10), time_m: sympy.Integer(5)}
77+
78+
mappers = (mapper0, mapper1, mapper2, mapper3)
79+
ops = tuple(Specializer(m).visit(op) for m in mappers)
80+
81+
for m, o in zip(mappers, ops):
82+
check_op(m, o)
83+
84+
def test_subdomain(self):
85+
"""Test that SubDomain thicknesses can be specialized"""
86+
87+
def check_op(mapper, operator):
88+
for k in mapper.keys():
89+
assert k not in operator.parameters
90+
assert k.name not in str(operator.ccode)
91+
92+
class SD(SubDomain):
93+
name = 'sd'
94+
95+
def define(self, dimensions):
96+
x, y = dimensions
97+
return {x: ('middle', 1, 1), y: ('right', 2)}
98+
99+
grid = Grid(shape=(11, 11))
100+
sd = SD(grid=grid)
101+
102+
f = Function(name='f', grid=grid)
103+
g = Function(name='g', grid=sd)
104+
105+
eqs = [Eq(f, f+1, subdomain=sd),
106+
Eq(g, g+1, subdomain=sd)]
107+
108+
op = Operator(eqs)
109+
110+
subdims = [d for d in op.dimensions if d.is_Sub]
111+
((xltkn, xrtkn), (_, yrtkn)) = [d.thickness for d in subdims]
112+
113+
mapper0 = {xltkn: sympy.S.Zero}
114+
mapper1 = {xrtkn: sympy.Integer(2), yrtkn: sympy.S.Zero}
115+
mapper2 = {**mapper0, **mapper1}
116+
117+
mappers = (mapper0, mapper1, mapper2)
118+
ops = tuple(Specializer(m).visit(op) for m in mappers)
119+
120+
for m, o in zip(mappers, ops):
121+
check_op(m, o)
122+
123+
# FIXME: Currently throws an error
124+
# def test_factor(self):
125+
# """Test that ConditionalDimensions can have their symbolic factors specialized"""
126+
# size = 16
127+
# factor = 4
128+
# i = Dimension(name='i')
129+
# ci = ConditionalDimension(name='ci', parent=i, factor=factor)
130+
131+
# g = Function(name='g', shape=(size,), dimensions=(i,))
132+
# f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
133+
134+
# op0 = Operator([Eq(f, g)])
135+
136+
# mapper = {ci.symbolic_factor: sympy.Integer(factor)}
137+
138+
# op1 = Specializer(mapper).visit(op0)
139+
140+
# assert ci.symbolic_factor not in op1.parameters
141+
# assert ci.symbolic_factor.name not in str(op1.ccode)
142+
# assert "if ((i)%(4) == 0)" in str(op1.ccode)
143+
144+
# Spacings
145+
146+
# Strides/sizes

0 commit comments

Comments
 (0)