Skip to content

Commit b4d8fe9

Browse files
authored
Merge pull request #2768 from devitocodes/stagg-param
Stagg param
2 parents 3169bb5 + 5607c73 commit b4d8fe9

3 files changed

Lines changed: 25 additions & 4 deletions

File tree

devito/finite_differences/differentiable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,9 @@ def _(expr, x0, **kwargs):
11331133
def _(expr, x0, **kwargs):
11341134
from devito.finite_differences.derivative import Derivative
11351135
x0_expr = {d: v for d, v in x0.items() if v is not expr.indices_ref[d]}
1136-
if x0_expr:
1136+
if expr.is_parameter:
1137+
return expr
1138+
elif x0_expr:
11371139
dims = tuple((d, 0) for d in x0_expr)
11381140
fd_o = tuple([expr.interp_order]*len(dims))
11391141
return Derivative(expr, *dims, fd_order=fd_o, x0=x0_expr)

devito/types/dense.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ def shape_global(self):
301301

302302
@property
303303
def symbolic_shape(self):
304-
return tuple(self._C_get_field(FULL, d).size for d in self.dimensions)
304+
return DimensionTuple(*[self._C_get_field(FULL, d).size for d in self.dimensions],
305+
getters=self.dimensions)
305306

306307
@property
307308
def size_global(self):
@@ -1018,7 +1019,7 @@ class Function(DiscreteFunction):
10181019
is_autopaddable = True
10191020

10201021
__rkwargs__ = (DiscreteFunction.__rkwargs__ +
1021-
('space_order', 'interp_order', 'dimensions'))
1022+
('space_order', 'interp_order', 'dimensions', 'is_parameter'))
10221023

10231024
def _cache_meta(self):
10241025
# Attach additional metadata to self's cache entry
@@ -1063,7 +1064,7 @@ def __init_finalize__(self, *args, **kwargs):
10631064
# Used at operator evaluation to evaluate the Function at the
10641065
# variable location (i.e. if the variable is staggered in x the
10651066
# parameter has to be computed at x + hx/2)
1066-
self._is_parameter = kwargs.get('parameter', False)
1067+
self._is_parameter = kwargs.get('parameter', kwargs.get('is_parameter', False))
10671068

10681069
def __fd_setup__(self):
10691070
"""

tests/test_derivatives.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,24 @@ def test_nested_call(self):
762762
# Should be commutative
763763
assert simplify(deriv.evaluate - derivc.evaluate) == 0
764764

765+
def test_param_stagg_inner(self):
766+
space_order = 2
767+
nx, ny = 5, 5
768+
769+
grid = Grid((nx, ny))
770+
771+
x, y = grid.dimensions
772+
yp = y + y.spacing / 2
773+
xp = x + x.spacing / 2
774+
775+
f = TimeFunction(name="f", grid=grid, space_order=space_order, staggered=y)
776+
p = Function(name="p", grid=grid, space_order=space_order, parameter=True)
777+
g = TimeFunction(name="g", grid=grid, space_order=space_order, staggered=(x, y))
778+
779+
eqn = Eq(g, (p * f).dx)
780+
eqne = eqn.evaluate.rhs
781+
assert simplify(eqne - (p._subs(y, yp).evaluate * f).dx(x0=xp).evaluate) == 0
782+
765783

766784
class TestTwoStageEvaluation:
767785

0 commit comments

Comments
 (0)