Skip to content

Commit 8b1c74d

Browse files
committed
tests: Add tests for specialising ConditionalDimension factors
1 parent a60a0de commit 8b1c74d

7 files changed

Lines changed: 52 additions & 47 deletions

File tree

devito/ir/equations/equation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from devito.finite_differences.differentiable import diff2sympy
77
from devito.ir.equations.algorithms import dimension_sort, lower_exprs
88
from devito.ir.support import (
9-
GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
9+
Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses,
1010
detect_io
1111
)
12+
from devito.ir.support.guards import GuardFactorEq
1213
from devito.symbolics import IntDiv, limits_mapper, uxreplace
1314
from devito.tools import Pickable, Tag, frozendict
1415
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min
@@ -221,11 +222,11 @@ def __new__(cls, *args, **kwargs):
221222
if not d.is_Conditional:
222223
continue
223224
if d.condition is None:
224-
conditionals[d] = GuardFactor(d)
225+
conditionals[d] = GuardFactorEq.new_from_dim(d)
225226
else:
226227
cond = diff2sympy(lower_exprs(d.condition))
227228
if d._factor is not None:
228-
cond = d.relation(cond, GuardFactor(d))
229+
cond = d.relation(cond, GuardFactorEq.new_from_dim(d))
229230
conditionals[d] = cond
230231
# Replace dimension with index
231232
index = d.index

devito/ir/iet/visitors.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
)
2222
from devito.ir.support.space import Backward
2323
from devito.symbolics import (
24-
FieldFromComposite, FieldFromPointer, ListInitializer, uxreplace
24+
FieldFromComposite, FieldFromPointer, IndexedPointer, ListInitializer, uxreplace
2525
)
2626
from devito.symbolics.extended_dtypes import NoDeclStruct
2727
from devito.tools import (
@@ -1515,9 +1515,7 @@ def __init__(self, mapper, nested=False):
15151515

15161516
# Sanity check
15171517
for k, v in self.mapper.items():
1518-
# FIXME: Erronously blocks f_vec->size[1]
1519-
# Apparently this is an IndexedPointer
1520-
if not isinstance(k, AbstractSymbol):
1518+
if not isinstance(k, (AbstractSymbol, IndexedPointer)):
15211519
raise ValueError(f"Attempted to specialize non-scalar symbol: {k}")
15221520

15231521
if not isinstance(v, Number):
@@ -1530,7 +1528,10 @@ def visit_Operator(self, o, **kwargs):
15301528
# is the intended use case
15311529
body = self._visit(o.body)
15321530

1533-
not_params = tuple(i for i in self.mapper if i not in o.parameters)
1531+
# NOTE: IndexedPointers that want replacing with a hardcoded value won't appear in
1532+
# the Operator parameters. Perhaps this check wants relaxing.
1533+
not_params = tuple(i for i in self.mapper
1534+
if i not in o.parameters and isinstance(i, AbstractSymbol))
15341535
if not_params:
15351536
raise ValueError(f"Attempted to specialize symbols {not_params} which are not"
15361537
" found in the Operator parameters")

devito/ir/support/guards.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,35 +47,34 @@ def canonical(self):
4747

4848
@property
4949
def negated(self):
50-
return negations[self.__class__](*self._args_rebuild, evaluate=False)
50+
try:
51+
return negations[self.__class__](*self._args_rebuild, evaluate=False)
52+
except KeyError:
53+
raise ValueError(f"Class {self.__class__.__name__} does not have a negation")
5154

5255

5356
# *** GuardFactor
5457

5558

56-
class GuardFactor(Guard, CondEq, Pickable):
59+
class GuardFactor(Guard, Pickable):
5760

5861
"""
5962
A guard for factor-based ConditionalDimensions.
6063
61-
Given the ConditionalDimension `d` with factor `k`, create the
62-
symbolic relational `d.parent % k == 0`.
64+
Introduces a constructor where, given the ConditionalDimension `d` with factor `k`,
65+
the symbolic relational `d.parent % k == 0` is created.
6366
"""
6467

65-
__rargs__ = ('d',)
68+
__rargs__ = ('lhs', 'rhs')
6669

67-
def __new__(cls, d, **kwargs):
70+
@classmethod
71+
def new_from_dim(cls, d, **kwargs):
6872
assert d.is_Conditional
6973

7074
obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
71-
obj.d = d
7275

7376
return obj
7477

75-
@property
76-
def _args_rebuild(self):
77-
return (self.d,)
78-
7978

8079
class GuardFactorEq(GuardFactor, CondEq):
8180
pass
@@ -85,9 +84,6 @@ class GuardFactorNe(GuardFactor, CondNe):
8584
pass
8685

8786

88-
GuardFactor = GuardFactorEq
89-
90-
9187
# *** GuardBound
9288

9389

devito/symbolics/extended_sympy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class CondEq(sympy.Eq):
3636
"""
3737

3838
def __new__(cls, *args, **kwargs):
39-
return sympy.Eq.__new__(cls, *args, evaluate=False)
39+
kwargs['evaluate'] = False
40+
return sympy.Eq.__new__(cls, *args, **kwargs)
4041

4142
@property
4243
def canonical(self):

devito/symbolics/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,9 @@ def _(mapper, rule):
152152
@singledispatch
153153
def _uxreplace_handle(expr, args, kwargs):
154154
try:
155-
return expr.func(*args, evaluate=False)
155+
return expr.func(*args, evaluate=False, **kwargs)
156156
except TypeError:
157-
return expr.func(*args)
157+
return expr.func(*args, **kwargs)
158158

159159

160160
@_uxreplace_handle.register(Min)

tests/test_pickle.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
)
1414
from devito.data import LEFT, OWNED
1515
from devito.finite_differences.tools import centered, direct, left, right, transpose
16-
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext, GuardFactor
16+
from devito.ir import Backward, Forward, GuardBound, GuardBoundNext
17+
from devito.ir.support.guards import GuardFactorEq
1718
from devito.mpi.halo_scheme import Halo
1819
from devito.mpi.routines import (
1920
MPIMsgEnriched, MPIRegion, MPIRequestObject, MPIStatusObject
@@ -503,7 +504,7 @@ def test_guard_factor(self, pickle):
503504
d = Dimension(name='d')
504505
cd = ConditionalDimension(name='cd', parent=d, factor=4)
505506

506-
gf = GuardFactor(cd)
507+
gf = GuardFactorEq.new_from_dim(cd)
507508

508509
pkl_gf = pickle.dumps(gf)
509510
new_gf = pickle.loads(pkl_gf)

tests/test_specialization.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -120,27 +120,25 @@ def define(self, dimensions):
120120
for m, o in zip(mappers, ops):
121121
check_op(m, o)
122122

123-
# FIXME: Currently throws an error - probably a missing handler for GuardFactor
124-
# in Uxreplace
125-
# def test_factor(self):
126-
# """Test that ConditionalDimensions can have their symbolic factors specialized"""
127-
# size = 16
128-
# factor = 4
129-
# i = Dimension(name='i')
130-
# ci = ConditionalDimension(name='ci', parent=i, factor=factor)
123+
def test_factor(self):
124+
"""Test that ConditionalDimensions can have their symbolic factors specialized"""
125+
size = 16
126+
factor = 4
127+
i = Dimension(name='i')
128+
ci = ConditionalDimension(name='ci', parent=i, factor=factor)
131129

132-
# g = Function(name='g', shape=(size,), dimensions=(i,))
133-
# f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
130+
g = Function(name='g', shape=(size,), dimensions=(i,))
131+
f = Function(name='f', shape=(int(size/factor),), dimensions=(ci,))
134132

135-
# op0 = Operator([Eq(f, g)])
133+
op0 = Operator([Eq(f, g)])
136134

137-
# mapper = {ci.symbolic_factor: sympy.Integer(factor)}
135+
mapper = {ci.symbolic_factor: sympy.Integer(factor)}
138136

139-
# op1 = Specializer(mapper).visit(op0)
137+
op1 = Specializer(mapper).visit(op0)
140138

141-
# assert ci.symbolic_factor not in op1.parameters
142-
# assert ci.symbolic_factor.name not in str(op1.ccode)
143-
# assert "if ((i)%(4) == 0)" in str(op1.ccode)
139+
assert ci.symbolic_factor not in op1.parameters
140+
assert ci.symbolic_factor.name not in str(op1.ccode)
141+
assert "if ((i)%(4) == 0)" in str(op1.ccode)
144142

145143
def test_spacing(self):
146144
"""Test that grid spacings can be specialized"""
@@ -156,14 +154,21 @@ def test_spacing(self):
156154
assert grid.dimensions[0].spacing.name not in str(op1.ccode)
157155
assert "/1.0e-1F;" in str(op1.ccode)
158156

159-
# Strides/sizes
160-
def test_strides(self):
161-
"""Test that strides and sizes generated for linearization can be specialized"""
157+
def test_sizes(self):
158+
"""Test that strides generated for linearization can be specialized"""
162159
grid = Grid(shape=(11, 11))
163160

164161
f = TimeFunction(name='f', grid=grid, space_order=2)
165162

166163
op0 = Operator(Eq(f.forward, f.dx2),
167164
opt=('advanced', {'expand': True, 'linearize': True}))
168165

169-
from IPython import embed; embed()
166+
mapper = {f.symbolic_shape[1]: sympy.Integer(11),
167+
f.symbolic_shape[2]: sympy.Integer(11)}
168+
169+
op1 = Specializer(mapper).visit(op0)
170+
171+
assert "const int x_fsz0 = 11;" in str(op1.ccode)
172+
assert "const int y_fsz0 = 11;" in str(op1.ccode)
173+
174+
# TODO: Should strides get linearized? If so, how?

0 commit comments

Comments
 (0)