Skip to content

Commit ade5198

Browse files
committed
api: fix arg processing for conditon+factor
1 parent 21f1921 commit ade5198

3 files changed

Lines changed: 83 additions & 3 deletions

File tree

devito/types/dimension.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from devito.types.args import ArgProvider
1515
from devito.types.basic import Symbol, DataSymbol, Scalar
1616
from devito.types.constant import Constant
17+
from devito.types.relational import relational_min, relational_max
1718

1819

1920
__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
@@ -1015,7 +1016,20 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
10151016
dim = alias or self
10161017
if dim.uses_symbolic_factor:
10171018
factor = defaults[dim.symbolic_factor.name] = self.factor_data
1018-
defaults[dim.parent.max_name] = range(0, factor*size - 1)
1019+
if dim.condition is None:
1020+
d0 = 0
1021+
d1 = sympy.S.Infinity
1022+
else:
1023+
d0 = relational_min(dim.condition, dim.parent)
1024+
d1 = relational_max(dim.condition, dim.parent)
1025+
if d1 < sympy.S.Infinity:
1026+
# We make sure the condition size matches the input size
1027+
size0 = (d1 - d0 + factor) // factor
1028+
if size < size0:
1029+
raise ValueError(f"Incompatible size for ConditionalDimension "
1030+
f"{self.name}: {size} < {size0}")
1031+
else:
1032+
defaults[dim.parent.max_name] = range(d0, d0 + factor*size - 1)
10191033

10201034
return defaults
10211035

@@ -1482,7 +1496,7 @@ def _arg_defaults(self, **kwargs):
14821496
def _arg_values(self, *args, **kwargs):
14831497
return {}
14841498

1485-
def _arg_check(self, *args):
1499+
def _arg_check(self, *args, **kwargs):
14861500
"""A CustomDimension performs no runtime checks."""
14871501
return
14881502

devito/types/relational.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import sympy
55

6-
__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min']
6+
__all__ = ['Le', 'Lt', 'Ge', 'Gt', 'Ne', 'relational_min', 'relational_max']
77

88

99
class AbstractRel:
@@ -250,3 +250,44 @@ def _(expr, s):
250250
return expr.lts
251251
else:
252252
return 0
253+
254+
255+
def relational_max(expr, s):
256+
"""
257+
Infer the maximum valid value for symbol `s` in the expression `expr`.
258+
For example
259+
- if `expr` is `s < 10`, then the maximum valid value for `s` is 9
260+
- if `expr` is `s >= 10`, then the maximum valid value for `s` is Inf
261+
"""
262+
if not expr.has(s):
263+
return sympy.S.Infinity
264+
265+
return _relational_max(expr, s)
266+
267+
268+
@singledispatch
269+
def _relational_max(s, expr):
270+
return sympy.S.Infinity
271+
272+
273+
@_relational_max.register(sympy.And)
274+
def _(expr, s):
275+
return min([_relational_max(e, s) for e in expr.args])
276+
277+
278+
@_relational_max.register(Gt)
279+
@_relational_max.register(Lt)
280+
def _(expr, s):
281+
if s == expr.lts:
282+
return expr.gts - 1
283+
else:
284+
return sympy.S.Infinity
285+
286+
287+
@_relational_max.register(Ge)
288+
@_relational_max.register(Le)
289+
def _(expr, s):
290+
if s == expr.lts:
291+
return expr.gts
292+
else:
293+
return sympy.S.Infinity

tests/test_dimension.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,31 @@ def test_symbolic_factor_override(self):
19821982

19831983
assert all(np.all(usave.data[i] == 16 + i*8) for i in range(4))
19841984

1985+
def test_factor_and_condition(self):
1986+
grid = Grid(shape=(10, 10))
1987+
time = grid.time_dim
1988+
1989+
nt = 200
1990+
bounds = (10, 100)
1991+
factor = 5
1992+
1993+
condition = And(Ge(time, bounds[0]), Le(time, bounds[1]))
1994+
time_under = ConditionalDimension(name='timeu', parent=time,
1995+
factor=factor, condition=condition)
1996+
buffer_size = (bounds[1] - bounds[0] + factor) // factor
1997+
1998+
rec = SparseTimeFunction(name='rec', grid=grid, npoint=1, nt=nt)
1999+
u = TimeFunction(name='u', grid=grid, space_order=2)
2000+
usaved = TimeFunction(name='usaved', grid=grid, space_order=2,
2001+
time_dim=time_under, save=buffer_size)
2002+
2003+
eq = [Eq(u.forward, u+1), Eq(usaved, u)] + rec.interpolate(u)
2004+
2005+
op = Operator(eq)
2006+
op(time_m=1, time_M=nt-1, usaved=usaved, rec=rec)
2007+
for t in range(buffer_size):
2008+
assert np.all(usaved.data[t] == t*factor + bounds[0] - 1)
2009+
19852010

19862011
class TestCustomDimension:
19872012

0 commit comments

Comments
 (0)