Skip to content

Commit 019c837

Browse files
Merge pull request #2575 from devitocodes/fix-subsampling-rebuild
api: revamp subsampling factors to avoid duplicates
2 parents 4196aa4 + 23f9eaa commit 019c837

8 files changed

Lines changed: 106 additions & 32 deletions

File tree

devito/deprecations.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,12 @@ def subdomain_warn(self):
2727
DeprecationWarning, stacklevel=2)
2828
return
2929

30+
@cached_property
31+
def constant_factor_warn(self):
32+
warn("Using a `Constant` as a factor when creating a ConditionalDimension"
33+
" is deprecated. Use an integer instead.",
34+
DeprecationWarning, stacklevel=2)
35+
return
36+
3037

3138
deprecations = DevitoDeprecation()

devito/ir/equations/equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def __new__(cls, *args, **kwargs):
220220
index = d.index
221221
if d.condition is not None and d in expr.free_symbols:
222222
index = index - relational_min(d.condition, d.parent)
223-
expr = uxreplace(expr, {d: IntDiv(index, d.factor)})
223+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})
224224

225225
conditionals = frozendict(conditionals)
226226

devito/ir/support/guards.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class GuardFactor(Guard, CondEq, Pickable):
4747
def __new__(cls, d, **kwargs):
4848
assert d.is_Conditional
4949

50-
obj = super().__new__(cls, d.parent % d.factor, 0)
50+
obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
5151
obj.d = d
5252

5353
return obj
@@ -129,7 +129,7 @@ def __new__(cls, d, direction, **kwargs):
129129
p1 = d.root.symbolic_max
130130

131131
if d.is_Conditional:
132-
v = d.factor
132+
v = d.symbolic_factor
133133
# Round `p0 + 1` up to the nearest multiple of `v`
134134
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
135135
else:
@@ -140,7 +140,7 @@ def __new__(cls, d, direction, **kwargs):
140140
p1 = d.root
141141

142142
if d.is_Conditional:
143-
v = d.factor
143+
v = d.symbolic_factor
144144
# Round `p1 - 1` down to the nearest sub-multiple of `v`
145145
# NOTE: we use ABS to make sure we handle negative values properly.
146146
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),

devito/symbolics/extended_sympy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class IntDiv(sympy.Expr):
8484
def __new__(cls, lhs, rhs, params=None):
8585
if rhs == 0:
8686
raise ValueError("Cannot divide by 0")
87-
elif rhs == 1:
87+
elif rhs == 1 or rhs is None:
8888
return lhs
8989

9090
if not is_integer(rhs):

devito/types/dimension.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
import numpy as np
88

99
from devito.data import LEFT, RIGHT
10+
from devito.deprecations import deprecations
1011
from devito.exceptions import InvalidArgument
1112
from devito.logger import debug
12-
from devito.tools import Pickable, is_integer, memoized_meth
13+
from devito.tools import Pickable, is_integer, is_number, memoized_meth
1314
from devito.types.args import ArgProvider
1415
from devito.types.basic import Symbol, DataSymbol, Scalar
1516
from devito.types.constant import Constant
@@ -822,6 +823,10 @@ def bound_symbols(self):
822823
return self.parent.bound_symbols
823824

824825

826+
class SubsamplingFactor(Scalar):
827+
pass
828+
829+
825830
class ConditionalDimension(DerivedDimension):
826831

827832
"""
@@ -898,7 +903,8 @@ class ConditionalDimension(DerivedDimension):
898903
is_NonlinearDerived = True
899904
is_Conditional = True
900905

901-
__rkwargs__ = DerivedDimension.__rkwargs__ + ('factor', 'condition', 'indirect')
906+
__rkwargs__ = DerivedDimension.__rkwargs__ + \
907+
('factor', 'condition', 'indirect')
902908

903909
def __init_finalize__(self, name, parent=None, factor=None, condition=None,
904910
indirect=False, **kwargs):
@@ -909,27 +915,49 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
909915

910916
super().__init_finalize__(name, parent)
911917

912-
# Always make the factor symbolic to allow overrides with different factor.
918+
# Process subsampling factor
913919
if factor is None or factor == 1:
914920
self._factor = None
915-
elif is_integer(factor):
916-
self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32)
917-
elif factor.is_Constant and is_integer(factor.data):
921+
elif is_number(factor):
922+
self._factor = int(factor)
923+
elif factor.is_Constant:
924+
deprecations.constant_factor_warn
918925
self._factor = factor
919926
else:
920-
raise ValueError("factor must be an integer or integer Constant")
927+
raise ValueError("factor must be an integer")
921928

922929
self._condition = condition
923930
self._indirect = indirect
924931

932+
@property
933+
def uses_symbolic_factor(self):
934+
return self._factor is not None
935+
936+
@property
937+
def factor_data(self):
938+
if isinstance(self.factor, Constant):
939+
return self.factor.data
940+
else:
941+
return self.factor
942+
925943
@property
926944
def spacing(self):
927-
s = self._factor.data if self._factor is not None else 1
928-
return s * self.parent.spacing
945+
return self.factor_data * self.parent.spacing
929946

930947
@property
931948
def factor(self):
932-
return self._factor if self._factor is not None else 1
949+
return self._factor if self.uses_symbolic_factor else 1
950+
951+
@cached_property
952+
def symbolic_factor(self):
953+
if not self.uses_symbolic_factor:
954+
return None
955+
elif isinstance(self.factor, Constant):
956+
return self.factor
957+
else:
958+
return SubsamplingFactor(
959+
name=f'{self.name}f', dtype=np.int32, is_const=True
960+
)
933961

934962
@property
935963
def condition(self):
@@ -950,9 +978,14 @@ def free_symbols(self):
950978
pass
951979
return retval
952980

953-
def _arg_values(self, interval, grid=None, **kwargs):
954-
# Parent dimension define the interval
955-
fact = self._factor.data if self._factor is not None else 1
981+
def _arg_values(self, interval, grid=None, args=None, **kwargs):
982+
if not self.uses_symbolic_factor:
983+
return {}
984+
985+
args = args or {}
986+
fname = self.symbolic_factor.name
987+
fact = kwargs.get(fname, args.get(fname, self.factor_data))
988+
956989
toint = lambda x: math.ceil(x / fact)
957990
vals = {}
958991
try:
@@ -965,6 +998,8 @@ def _arg_values(self, interval, grid=None, **kwargs):
965998
except (KeyError, TypeError):
966999
pass
9671000

1001+
vals[self.symbolic_factor.name] = fact
1002+
9681003
return vals
9691004

9701005
def _arg_defaults(self, _min=None, size=None, alias=None):
@@ -974,15 +1009,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
9741009
# `factor` endpoints are legal, so we return them all. It's then
9751010
# up to the caller to decide which one to pick upon reduction
9761011
dim = alias or self
977-
if dim.condition is not None or size is None or dim._factor is None:
978-
return defaults
979-
try:
980-
# Is it a symbolic factor?
981-
factor = defaults[dim._factor.name] = self._factor.data
982-
except AttributeError:
983-
factor = dim._factor
984-
985-
defaults[dim.parent.max_name] = range(0, factor*size - 1)
1012+
if dim.uses_symbolic_factor:
1013+
factor = defaults[dim.symbolic_factor.name] = self.factor_data
1014+
defaults[dim.parent.max_name] = range(0, factor*size - 1)
9861015

9871016
return defaults
9881017

devito/types/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def spacing_map(self):
305305
# Special case subsampling: `Grid.dimensions` -> (xb, yb, zb)`
306306
# where `xb, yb, zb` are ConditionalDimensions whose parents
307307
# are SpaceDimensions
308-
mapper[d.root.spacing] = s/self.dtype(d.factor.data)
308+
mapper[d.root.spacing] = s/self.dtype(d.factor)
309309
elif d.is_Space:
310310
# Typical case: `Grid.dimensions` -> (x, y, z)` where `x, y, z` are
311311
# the SpaceDimensions

tests/test_dimension.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import product
2+
from copy import deepcopy
23

34
import numpy as np
45
from sympy import And, Or
@@ -16,6 +17,7 @@
1617
from devito.ir import SymbolRegistry
1718
from devito.symbolics import indexify, retrieve_functions, IntDiv, INT
1819
from devito.types import Array, StencilDimension, Symbol
20+
from devito.types.basic import Scalar
1921
from devito.types.dimension import AffineIndexAccessFunction, Thickness
2022

2123

@@ -1011,9 +1013,9 @@ def test_issue_1592(self):
10111013
op = Operator(Eq(v.forward, v.dx))
10121014
op.apply(time=6)
10131015
exprs = FindNodes(Expression).visit(op)
1014-
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.factor) + 1
1015-
assert time_sub.factor.data == 2
1016-
assert time_sub.factor.is_Constant
1016+
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.symbolic_factor) + 1
1017+
assert time_sub.factor == 2
1018+
assert isinstance(time_sub.symbolic_factor, Scalar)
10171019

10181020
def test_issue_1753(self):
10191021
grid = Grid(shape=(3, 3, 3))
@@ -1896,6 +1898,41 @@ def test_cond_notime(self):
18961898
op(time_m=1, time_M=nt-1, dt=1)
18971899
assert norm(g, order=1) == norm(sum(usaved, dims=time_under), order=1)
18981900

1901+
def test_cond_copy(self):
1902+
grid = Grid((11, 11, 11))
1903+
time = grid.time_dim
1904+
1905+
cd = ConditionalDimension(name='tsub', parent=time, factor=5)
1906+
u = TimeFunction(name='u', grid=grid, space_order=4, time_order=2, save=Buffer(2))
1907+
u1 = TimeFunction(name='u1', grid=grid, space_order=0,
1908+
time_order=0, save=5, time_dim=cd)
1909+
u2 = TimeFunction(name='u2', grid=grid, space_order=0,
1910+
time_order=0, save=5, time_dim=cd)
1911+
1912+
# Mimic what happens when an operator is copied
1913+
u12 = deepcopy(u1)
1914+
u22 = deepcopy(u2)
1915+
1916+
op = Operator([Eq(u.forward, u.laplace), Eq(u12, u), Eq(u22, u)])
1917+
assert len([p for p in op.parameters if p.name == 'tsubf']) == 1
1918+
1919+
def test_const_factor(self):
1920+
grid = Grid(shape=(4, 4))
1921+
time = grid.time_dim
1922+
1923+
f1 = 4
1924+
f2 = Constant(name='f2', dtype=np.int32, value=4)
1925+
t1 = ConditionalDimension('t_sub', parent=time, factor=f1)
1926+
t2 = ConditionalDimension('t_sub2', parent=time, factor=f2)
1927+
1928+
assert isinstance(t1.symbolic_factor, Scalar)
1929+
assert t1.factor == f1
1930+
1931+
assert t2.symbolic_factor.is_Constant
1932+
assert t2.factor == f2
1933+
assert t2.factor.data == f1
1934+
assert t2.spacing == t1.spacing
1935+
18991936

19001937
class TestCustomDimension:
19011938

tests/test_pickle.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ def test_conditional_dimension(self, pickle):
320320
assert cd.name == new_cd.name
321321
assert cd.parent.name == new_cd.parent.name
322322
assert cd.factor == new_cd.factor
323+
assert cd.symbolic_factor == new_cd.symbolic_factor
323324
assert cd.condition == new_cd.condition
324325

325326
def test_incr_dimension(self, pickle):
@@ -603,7 +604,7 @@ def test_equation(self, pickle):
603604
assert new_eq.lhs.name == f.name
604605
assert str(new_eq.rhs) == 'f(x) + 1'
605606
assert new_eq.implicit_dims[0].name == 'xs'
606-
assert new_eq.implicit_dims[0].factor.data == 4
607+
assert new_eq.implicit_dims[0].factor == 4
607608

608609
@pytest.mark.parametrize('typ', [ctypes.c_float, 'struct truct'])
609610
def test_Cast(self, pickle, typ):

0 commit comments

Comments
 (0)