Skip to content

Commit 041eb4d

Browse files
committed
tests: Add tests for specialising ConditionalDimension factors
1 parent 80a3d71 commit 041eb4d

7 files changed

Lines changed: 52 additions & 46 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
77
from devito.finite_differences.differentiable import diff2sympy
8-
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
8+
from devito.ir.support import (Interval, IntervalGroup, IterationSpace,
99
Stencil, detect_io, detect_accesses)
10+
from devito.ir.support.guards import GuardFactorEq
1011
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1112
from devito.tools import Pickable, Tag, frozendict
1213
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
@@ -210,11 +211,11 @@ def __new__(cls, *args, **kwargs):
210211
if not d.is_Conditional:
211212
continue
212213
if d.condition is None:
213-
conditionals[d] = GuardFactor(d)
214+
conditionals[d] = GuardFactorEq.new_from_dim(d)
214215
else:
215216
cond = diff2sympy(lower_exprs(d.condition))
216217
if d._factor is not None:
217-
cond = sympy.And(cond, GuardFactor(d))
218+
cond = sympy.And(cond, GuardFactorEq.new_from_dim(d))
218219
conditionals[d] = cond
219220
# Replace dimension with index
220221
index = d.index

devito/ir/iet/visitors.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sympy.core.function import Application
1616

1717
from devito.exceptions import CompilationError
18+
from devito.symbolics import IndexedPointer
1819
from devito.ir.iet.nodes import (Node, Iteration, Expression, ExpressionBundle,
1920
Call, Lambda, BlankLine, Section, ListMajor)
2021
from devito.ir.support.space import Backward
@@ -1516,9 +1517,7 @@ def __init__(self, mapper, nested=False):
15161517

15171518
# Sanity check
15181519
for k, v in self.mapper.items():
1519-
# FIXME: Erronously blocks f_vec->size[1]
1520-
# Apparently this is an IndexedPointer
1521-
if not isinstance(k, AbstractSymbol):
1520+
if not isinstance(k, (AbstractSymbol, IndexedPointer)):
15221521
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
15231522

15241523
if not isinstance(v, Number):
@@ -1531,7 +1530,10 @@ def visit_Operator(self, o, **kwargs):
15311530
# is the intended use case
15321531
body = self._visit(o.body)
15331532

1534-
not_params = tuple(i for i in self.mapper if i not in o.parameters)
1533+
# NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
1534+
# the Operator parameters. Perhaps this check wants relaxing.
1535+
not_params = tuple(i for i in self.mapper
1536+
if i not in o.parameters and isinstance(i, AbstractSymbol))
15351537
if not_params:
15361538
raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
15371539
" found in the Operator parameters")

devito/ir/support/guards.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,34 @@ def canonical(self):
3838

3939
@property
4040
def negated(self):
41-
return negations[self.__class__](*self._args_rebuild, evaluate=False)
41+
try:
42+
return negations[self.__class__](*self._args_rebuild, evaluate=False)
43+
except KeyError:
44+
raise ValueError(f"Class {self.__class__.__name__} does not have a negation")
4245

4346

4447
# *** GuardFactor
4548

4649

47-
class GuardFactor(Guard, CondEq, Pickable):
50+
class GuardFactor(Guard, Pickable):
4851

4952
"""
5053
A guard for factor-based ConditionalDimensions.
5154
52-
Given the ConditionalDimension `d` with factor `k`, create the
53-
symbolic relational `d.parent % k == 0`.
55+
Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
56+
the symbolic relational `d.parent % k == 0` is created.
5457
"""
5558

56-
__rargs__ = ('d',)
59+
__rargs__ = ('lhs', 'rhs')
5760

58-
def __new__(cls, d, **kwargs):
61+
@classmethod
62+
def new_from_dim(cls, d, **kwargs):
5963
assert d.is_Conditional
6064

6165
obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
62-
obj.d = d
6366

6467
return obj
6568

66-
@property
67-
def _args_rebuild(self):
68-
return (self.d,)
69-
7069

7170
class GuardFactorEq(GuardFactor, CondEq):
7271
pass
@@ -76,9 +75,6 @@ class GuardFactorNe(GuardFactor, CondNe):
7675
pass
7776

7877

79-
GuardFactor = GuardFactorEq
80-
81-
8278
# *** GuardBound
8379

8480

devito/symbolics/extended_sympy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ class CondEq(sympy.Eq):
3535
"""
3636

3737
def __new__(cls, *args, **kwargs):
38-
return sympy.Eq.__new__(cls, *args, evaluate=False)
38+
kwargs['evaluate'] = False
39+
return sympy.Eq.__new__(cls, *args, **kwargs)
3940

4041
@property
4142
def canonical(self):

devito/symbolics/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ def _(mapper, rule):
136136
@singledispatch
137137
def _uxreplace_handle(expr, args, kwargs):
138138
try:
139-
return expr.func(*args, evaluate=False)
139+
return expr.func(*args, evaluate=False, **kwargs)
140140
except TypeError:
141-
return expr.func(*args)
141+
return expr.func(*args, **kwargs)
142142

143143

144144
@_uxreplace_handle.register(Min)

tests/test_pickle.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
Dimension, SubDimension, ConditionalDimension, IncrDimension,
1111
TimeDimension, SteppingDimension, Operator, MPI, Min, solve,
1212
PrecomputedSparseTimeFunction, SubDomain)
13-
from devito.ir import Backward, Forward, GuardFactor, GuardBound, GuardBoundNext
13+
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext
14+
from devito.ir.support.guards import GuardFactorEq
1415
from devito.data import LEFT, OWNED
1516
from devito.finite_differences.tools import direct, transpose, left, right, centered
1617
from devito.mpi.halo_scheme import Halo
@@ -500,7 +501,7 @@ def test_guard_factor(self, pickle):
500501
d = Dimension(name='d')
501502
cd = ConditionalDimension(name='cd', parent=d, factor=4)
502503

503-
gf = GuardFactor(cd)
504+
gf = GuardFactorEq.new_from_dim(cd)
504505

505506
pkl_gf = pickle.dumps(gf)
506507
new_gf = pickle.loads(pkl_gf)

tests/test_specialization.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,25 @@ def define(self, dimensions):
120120
for m, o in zip(mappers, ops):
121121
check_op(m, o)
122122

123-
# FIXME: Currently throws an error - probably a missing handler for GuardFactor
124-
# in Uxreplace
125-
# def test_factor(self):
126-
# """Test that ConditionalDimensions can have their symbolic factors specialized"""
127-
# size = 16
128-
# factor = 4
129-
# i = Dimension(name='i')
130-
# ci = ConditionalDimension(name='ci', parent=i, factor=factor)
123+
def test_factor(self):
124+
"""Test that ConditionalDimensions can have their symbolic factors specialized"""
125+
size = 16
126+
factor = 4
127+
i = Dimension(name='i')
128+
ci = ConditionalDimension(name='ci', parent=i, factor=factor)
131129

132-
# g = Function(name='g', shape=(size,), dimensions=(i,))
133-
# f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
130+
g = Function(name='g', shape=(size,), dimensions=(i,))
131+
f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
134132

135-
# op0 = Operator([Eq(f, g)])
133+
op0 = Operator([Eq(f, g)])
136134

137-
# mapper = {ci.symbolic_factor: sympy.Integer(factor)}
135+
mapper = {ci.symbolic_factor: sympy.Integer(factor)}
138136

139-
# op1 = Specializer(mapper).visit(op0)
137+
op1 = Specializer(mapper).visit(op0)
140138

141-
# assert ci.symbolic_factor not in op1.parameters
142-
# assert ci.symbolic_factor.name not in str(op1.ccode)
143-
# assert "if ((i)%(4) == 0)" in str(op1.ccode)
139+
assert ci.symbolic_factor not in op1.parameters
140+
assert ci.symbolic_factor.name not in str(op1.ccode)
141+
assert "if ((i)%(4) == 0)" in str(op1.ccode)
144142

145143
def test_spacing(self):
146144
"""Test that grid spacings can be specialized"""
@@ -156,14 +154,21 @@ def test_spacing(self):
156154
assert grid.dimensions[0].spacing.name not in str(op1.ccode)
157155
assert "/1.0e-1F;" in str(op1.ccode)
158156

159-
# Strides/sizes
160-
def test_strides(self):
161-
"""Test that strides and sizes generated for linearization can be specialized"""
157+
def test_sizes(self):
158+
"""Test that strides generated for linearization can be specialized"""
162159
grid = Grid(shape=(11, 11))
163160

164161
f = TimeFunction(name='f', grid=grid, space_order=2)
165162

166163
op0 = Operator(Eq(f.forward, f.dx2),
167164
opt=('advanced', {'expand': True, 'linearize': True}))
168165

169-
from IPython import embed; embed()
166+
mapper = {f.symbolic_shape[1]: sympy.Integer(11),
167+
f.symbolic_shape[2]: sympy.Integer(11)}
168+
169+
op1 = Specializer(mapper).visit(op0)
170+
171+
assert "const int x_fsz0 = 11;" in str(op1.ccode)
172+
assert "const int y_fsz0 = 11;" in str(op1.ccode)
173+
174+
# TODO: Should strides get linearized? If so, how?

0 commit comments

Comments
 (0)