Skip to content

Commit 715c3bb

Browse files
committed
api: update throughout for staggered setup
1 parent 2598976 commit 715c3bb

5 files changed

Lines changed: 47 additions & 26 deletions

File tree

devito/finite_differences/derivative.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ def _eval_at(self, func):
405405
return self
406406
# For basic equation of the form f = Derivative(g, ...) we can just
407407
# compare staggering
408-
if self.expr.staggered == func.staggered:
408+
if self.expr.staggered == func.staggered and \
409+
not (self.expr.is_Add or self.expr.is_Mul):
409410
return self
410411

411412
x0 = func.indices_ref.getters

devito/finite_differences/rsfd.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import wraps
22

3-
from devito.types import NODE
43
from devito.types.dimension import StencilDimension
54
from .differentiable import Weights, DiffDerivative
65
from .tools import generate_indices, fd_weights_registry
@@ -101,12 +100,7 @@ def check_staggering(func):
101100
def wrapper(expr, dim, x0=None, expand=True):
102101
grid = expr.grid
103102
x0 = {k: v for k, v in x0.items() if k.is_Space}
104-
if expr.staggered is NODE or expr.staggered is None:
105-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
106-
elif expr.staggered == grid.dimensions:
107-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
108-
else:
109-
cond = False
103+
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
110104
if cond:
111105
return func(expr, dim, x0=x0, expand=expand)
112106
else:
@@ -117,7 +111,8 @@ def wrapper(expr, dim, x0=None, expand=True):
117111
@check_staggering
118112
def d45(expr, dim, x0=None, expand=True):
119113
"""
120-
RSFD approximation of the derivative of `expr` along `dim` at point `x0`.
114+
Rotated staggered grid finite-differences (RSFD) discretization
115+
of the derivative of `expr` along `dim` at point `x0`.
121116
122117
Parameters
123118
----------
@@ -132,7 +127,8 @@ def d45(expr, dim, x0=None, expand=True):
132127
"""
133128
# Make sure the grid supports RSFD
134129
if expr.grid.dim not in [2, 3]:
135-
raise ValueError('RSFD only supported in 2D and 3D')
130+
raise ValueError('Rotated staggered grid finite-differences (RSFD)'
131+
' only supported in 2D and 3D')
136132

137133
# Diagonals weights
138134
w = dir_weights[(dim.name, expr.grid.dim)]

devito/types/basic.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,12 @@ def __new__(cls, *args, **kwargs):
699699
args, kwargs = cls.__args_setup__(*args, **kwargs)
700700

701701
# Extract the `indices`, as perhaps they're explicitly provided
702-
dimensions, indices, staggered = cls.__indices_setup__(*args, **kwargs)
702+
dim_args = cls.__indices_setup__(*args, **kwargs)
703+
try:
704+
dimensions, indices, staggered = dim_args
705+
except ValueError:
706+
dimensions, indices = dim_args
707+
staggered = (sympy.S.Zero,)*len(dimensions)
703708

704709
# If it's an alias or simply has a different name, ignore `function`.
705710
# These cases imply the construction of a new AbstractFunction off
@@ -929,7 +934,10 @@ def indices(self):
929934
@property
930935
def staggered(self):
931936
"""The staggered indices of the object."""
932-
return DimensionTuple(*self._staggered, getters=self.dimensions)
937+
if self._staggered:
938+
return DimensionTuple(*self._staggered, getters=self.dimensions)
939+
else:
940+
return None
933941

934942
@property
935943
def indices_ref(self):
@@ -1496,6 +1504,10 @@ def name(self):
14961504
return self.__class__.__name__
14971505

14981506
def _rebuild(self, *args, **kwargs):
1507+
# Plain `func` call (row, col, comps)
1508+
if not kwargs.keys() & self.__rkwargs__:
1509+
assert len(args) == 3
1510+
return self._new(*args, **kwargs)
14991511
# We need to rebuild the components with the new name then
15001512
# rebuild the matrix
15011513
newname = kwargs.pop('name', self.name)

devito/types/dense.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ def __fd_setup__(self):
10421042

10431043
@cached_property
10441044
def _fd_priority(self):
1045-
return 1 if self.staggered in [NODE, None] else 2
1045+
return 1 if not self.staggered or all(s == 0 for s in self.staggered) else 2
10461046

10471047
@property
10481048
def is_parameter(self):
@@ -1066,18 +1066,26 @@ def __staggered_setup__(cls, dimensions, **kwargs):
10661066
* 0 to non-staggered dimensions;
10671067
* 1 to staggered dimensions.
10681068
"""
1069-
stagg = kwargs.get('staggered', None)
1070-
if stagg is CELL:
1069+
stagg = kwargs.get('staggered')
1070+
if stagg is None:
1071+
return tuple()
1072+
elif stagg is CELL:
10711073
staggered = (sympy.S.One for d in dimensions)
1072-
elif stagg in [None, NODE]:
1074+
elif stagg in [NODE]:
10731075
staggered = (sympy.S.Zero for d in dimensions)
10741076
elif all(is_integer(s) for s in as_tuple(stagg)):
10751077
# Staggering is already a tuple likely from rebuild
10761078
assert len(stagg) == len(dimensions)
1077-
return tuple(stagg)
1079+
staggered = stagg
10781080
else:
1079-
staggered = (sympy.S.One if d in as_tuple(stagg) else sympy.S.Zero
1080-
for d in dimensions)
1081+
staggered = []
1082+
for d in dimensions:
1083+
if d in as_tuple(stagg):
1084+
staggered.append(sympy.S.One)
1085+
elif -d in as_tuple(stagg):
1086+
staggered.append(sympy.S.NegativeOne)
1087+
else:
1088+
staggered.append(sympy.S.Zero)
10811089
return tuple(staggered)
10821090

10831091
@classmethod
@@ -1097,9 +1105,12 @@ def __indices_setup__(cls, *args, **kwargs):
10971105
assert len(args) == len(dimensions)
10981106
staggered_indices = tuple(args)
10991107
else:
1100-
# Staggered indices
1101-
staggered_indices = (d + i * d.spacing / 2
1102-
for d, i in zip(dimensions, staggered))
1108+
if not staggered:
1109+
staggered_indices = (d for d in dimensions)
1110+
else:
1111+
# Staggered indices
1112+
staggered_indices = (d + i * d.spacing / 2
1113+
for d, i in zip(dimensions, staggered))
11031114
return tuple(dimensions), tuple(staggered_indices), staggered
11041115

11051116
@property
@@ -1392,7 +1403,6 @@ def __fd_setup__(self):
13921403
@classmethod
13931404
def __indices_setup__(cls, *args, **kwargs):
13941405
dimensions = kwargs.get('dimensions')
1395-
staggered = kwargs.get('staggered')
13961406

13971407
if dimensions is None:
13981408
save = kwargs.get('save')
@@ -1407,7 +1417,7 @@ def __indices_setup__(cls, *args, **kwargs):
14071417
dimensions.insert(cls._time_position, time_dim)
14081418

14091419
return Function.__indices_setup__(
1410-
*args, dimensions=dimensions, staggered=staggered
1420+
*args, dimensions=dimensions, staggered=kwargs.get('staggered')
14111421
)
14121422

14131423
@classmethod
@@ -1446,7 +1456,7 @@ def __shape_setup__(cls, **kwargs):
14461456

14471457
@cached_property
14481458
def _fd_priority(self):
1449-
return 2.1 if self.staggered in [NODE, None] else 2.2
1459+
return 2.1 if not self.staggered or all(s == 0 for s in self.staggered) else 2.2
14501460

14511461
@property
14521462
def time_order(self):
@@ -1600,7 +1610,7 @@ def __indices_setup__(cls, **kwargs):
16001610
# Sanity check
16011611
assert not any(d.is_NonlinearDerived for d in dimensions)
16021612

1603-
return dimensions, dimensions, (sympy.S.Zero for _ in dimensions)
1613+
return dimensions, dimensions
16041614

16051615
def __halo_setup__(self, **kwargs):
16061616
pointer_dim = kwargs.get('pointer_dim')

devito/types/tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ class TensorFunction(AbstractTensor):
6969
_class_priority = 10
7070
_op_priority = Differentiable._op_priority + 1.
7171

72+
__rkwargs__ = AbstractTensor.__rkwargs__ + ('dimensions', 'space_order')
73+
7274
def __init_finalize__(self, *args, **kwargs):
7375
super().__init_finalize__(*args, **kwargs)
7476
grid = kwargs.get('grid')

0 commit comments

Comments
 (0)