Skip to content

Commit 7b48777

Browse files
committed
api: fix arg processing for subsampling factor
1 parent 38c673b commit 7b48777

3 files changed

Lines changed: 22 additions & 13 deletions

File tree

devito/types/dimension.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
915915

916916
# Process subsampling factor
917917
fname = f"{name}f"
918-
if factor is None:
918+
if factor is None or factor == 1:
919919
self._factor = None
920920
elif is_number(factor):
921921
self._factor = int(factor)
@@ -966,9 +966,15 @@ def free_symbols(self):
966966
pass
967967
return retval
968968

969-
def _arg_values(self, interval, grid=None, **kwargs):
970-
# Parent dimension define the interval
971-
fact = self.factor
969+
def _arg_values(self, interval, grid=None, args=None, **kwargs):
970+
if self.symbolic_factor is not None:
971+
fname = self.symbolic_factor.name
972+
args = args or {}
973+
fact = kwargs.get(fname, args.get(fname, self.factor))
974+
else:
975+
# No factor
976+
return {}
977+
972978
toint = lambda x: math.ceil(x / fact)
973979
vals = {}
974980
try:
@@ -981,6 +987,9 @@ def _arg_values(self, interval, grid=None, **kwargs):
981987
except (KeyError, TypeError):
982988
pass
983989

990+
if self.symbolic_factor is not None:
991+
vals[self.symbolic_factor.name] = fact
992+
984993
return vals
985994

986995
def _arg_defaults(self, _min=None, size=None, alias=None):
@@ -990,11 +999,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
990999
# `factor` endpoints are legal, so we return them all. It's then
9911000
# up to the caller to decide which one to pick upon reduction
9921001
dim = alias or self
993-
if dim.condition is not None or size is None or dim._factor is None:
994-
return defaults
995-
996-
factor = defaults[dim.symbolic_factor.name] = self.factor
997-
defaults[dim.parent.max_name] = range(0, factor*size - 1)
1002+
if dim.symbolic_factor is not None:
1003+
factor = defaults[dim.symbolic_factor.name] = self.factor
1004+
defaults[dim.parent.max_name] = range(0, factor*size - 1)
9981005

9991006
return defaults
10001007

tests/test_dimension.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir import SymbolRegistry
1818
from devito.symbolics import indexify, retrieve_functions, IntDiv, INT
1919
from devito.types import Array, StencilDimension, Symbol
20+
from devito.types.basic import Scalar
2021
from devito.types.dimension import AffineIndexAccessFunction, Thickness
2122

2223

@@ -1012,9 +1013,9 @@ def test_issue_1592(self):
10121013
op = Operator(Eq(v.forward, v.dx))
10131014
op.apply(time=6)
10141015
exprs = FindNodes(Expression).visit(op)
1015-
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.factor) + 1
1016-
assert time_sub.factor.data == 2
1017-
assert time_sub.factor.is_Constant
1016+
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.symbolic_factor) + 1
1017+
assert time_sub.factor == 2
1018+
assert isinstance(time_sub.symbolic_factor, Scalar)
10181019

10191020
def test_issue_1753(self):
10201021
grid = Grid(shape=(3, 3, 3))

tests/test_pickle.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def test_conditional_dimension(self, pickle):
320320
assert cd.name == new_cd.name
321321
assert cd.parent.name == new_cd.parent.name
322322
assert cd.factor == new_cd.factor
323+
assert cd.symbolic_factor == new_cd.symbolic_factor
323324
assert cd.condition == new_cd.condition
324325

325326
def test_incr_dimension(self, pickle):
@@ -603,7 +604,7 @@ def test_equation(self, pickle):
603604
assert new_eq.lhs.name == f.name
604605
assert str(new_eq.rhs) == 'f(x) + 1'
605606
assert new_eq.implicit_dims[0].name == 'xs'
606-
assert new_eq.implicit_dims[0].factor.data == 4
607+
assert new_eq.implicit_dims[0].factor == 4
607608

608609
@pytest.mark.parametrize('typ', [ctypes.c_float, 'struct truct'])
609610
def test_Cast(self, pickle, typ):

0 commit comments

Comments
 (0)