Skip to content

Commit c91cee0

Browse files
authored
Merge pull request #2604 from devitocodes/tens-kwargs
api: fix kwargs processing for tensor functions
2 parents 3047085 + 90aeff2 commit c91cee0

6 files changed

Lines changed: 57 additions & 18 deletions

File tree

devito/builtins/initializers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
7979
eqs = [eq.xreplace(subs) for eq in eqs]
8080

8181
op = dv.Operator(eqs, name=name, **kwargs)
82+
8283
try:
8384
op()
8485
except ValueError:

devito/types/basic.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,15 @@ def __subfunc_setup__(cls, *args, **kwargs):
14891489
"""Setup each component of the tensor as a Devito type."""
14901490
return []
14911491

1492+
@classmethod
1493+
def _sympify(self, arg):
1494+
# This is used internally by sympy to process arguments at rebuilt. And since
1495+
# some of our properties are non-sympyfiable we need to have a fallback
1496+
try:
1497+
return super()._sympify(arg)
1498+
except sympy.SympifyError:
1499+
return arg
1500+
14921501
@property
14931502
def grid(self):
14941503
"""

devito/types/dense.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ def __init_finalize__(self, *args, **kwargs):
10191019

10201020
# Space order
10211021
space_order = kwargs.get('space_order', 1)
1022-
if isinstance(space_order, int):
1022+
if is_integer(space_order):
10231023
self._space_order = space_order
10241024
elif isinstance(space_order, tuple) and len(space_order) >= 2:
10251025
self._space_order = space_order[0]
@@ -1175,7 +1175,7 @@ def __halo_setup__(self, **kwargs):
11751175
halo = tuple(halo[d] for d in self.dimensions)
11761176
else:
11771177
space_order = kwargs.get('space_order', 1)
1178-
if isinstance(space_order, int):
1178+
if is_integer(space_order):
11791179
v = (space_order, space_order)
11801180
halo = [v if i.is_Space else (0, 0) for i in self.dimensions]
11811181

@@ -1208,12 +1208,12 @@ def __padding_setup__(self, **kwargs):
12081208
elif isinstance(padding, DimensionTuple):
12091209
padding = tuple(padding[d] for d in self.dimensions)
12101210

1211-
elif isinstance(padding, int):
1211+
elif is_integer(padding):
12121212
padding = tuple((0, padding) if d.is_Space else (0, 0)
12131213
for d in self.dimensions)
12141214

12151215
elif isinstance(padding, tuple) and len(padding) == self.ndim:
1216-
padding = tuple((0, i) if isinstance(i, int) else i for i in padding)
1216+
padding = tuple((0, i) if is_integer(i) else i for i in padding)
12171217

12181218
else:
12191219
raise TypeError("`padding` must be int or %d-tuple of ints" % self.ndim)
@@ -1398,7 +1398,7 @@ def __init_finalize__(self, *args, **kwargs):
13981398
self._time_order = kwargs.get('time_order', 1)
13991399
super().__init_finalize__(*args, **kwargs)
14001400

1401-
if not isinstance(self.time_order, int):
1401+
if not is_integer(self.time_order):
14021402
raise TypeError("`time_order` must be int")
14031403

14041404
self.save = kwargs.get('save')
@@ -1420,7 +1420,7 @@ def __indices_setup__(cls, *args, **kwargs):
14201420
time_dim = kwargs.get('time_dim')
14211421

14221422
if time_dim is None:
1423-
time_dim = grid.time_dim if isinstance(save, int) else grid.stepping_dim
1423+
time_dim = grid.time_dim if is_integer(save) else grid.stepping_dim
14241424
elif not (isinstance(time_dim, Dimension) and time_dim.is_Time):
14251425
raise TypeError("`time_dim` must be a time dimension")
14261426
dimensions = list(Function.__indices_setup__(**kwargs)[0])
@@ -1450,7 +1450,7 @@ def __shape_setup__(cls, **kwargs):
14501450
shape.insert(cls._time_position, time_order + 1)
14511451
elif isinstance(save, Buffer):
14521452
shape.insert(cls._time_position, save.val)
1453-
elif isinstance(save, int):
1453+
elif is_integer(save):
14541454
shape.insert(cls._time_position, save)
14551455
else:
14561456
raise TypeError("`save` can be None, int or Buffer, not %s" % type(save))

devito/types/tensor.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from functools import cached_property
33

44
import numpy as np
5+
from sympy.matrices.matrixbase import MatrixBase
56
from sympy.core.sympify import converter as sympify_converter
67

78
from devito.finite_differences import Differentiable
@@ -78,6 +79,20 @@ def __init_finalize__(self, *args, **kwargs):
7879
inds, _ = Function.__indices_setup__(grid=grid, dimensions=dimensions)
7980
self._space_dimensions = inds
8081

82+
@classmethod
83+
def _component_kwargs(cls, inds, **kwargs):
84+
"""
85+
Get the kwargs for a single component
86+
from the kwargs of the TensorFunction.
87+
"""
88+
kw = {}
89+
for k, v in kwargs.items():
90+
if isinstance(v, MatrixBase):
91+
kw[k] = v[inds]
92+
else:
93+
kw[k] = v
94+
return kw
95+
8196
@classmethod
8297
def __subfunc_setup__(cls, *args, **kwargs):
8398
"""
@@ -107,10 +122,12 @@ def __subfunc_setup__(cls, *args, **kwargs):
107122
start = i if (symm or diag) else 0
108123
stop = i + 1 if diag else len(dims)
109124
for j in range(start, stop):
110-
kwargs["name"] = "%s_%s%s" % (name, d.name, dims[j].name)
111-
kwargs["staggered"] = (stagg[i][j] if stagg is not None
112-
else (NODE if i == j else (d, dims[j])))
113-
funcs2[j] = cls._sub_type(**kwargs)
125+
staggj = (stagg[i][j] if stagg is not None
126+
else (NODE if i == j else (d, dims[j])))
127+
sub_kwargs = cls._component_kwargs((i, j), **kwargs)
128+
sub_kwargs.update({'name': f"{name}_{d.name}{dims[j].name}",
129+
'staggered': staggj})
130+
funcs2[j] = cls._sub_type(**sub_kwargs)
114131
funcs.append(funcs2)
115132

116133
# Symmetrize and fill diagonal if symmetric
@@ -169,7 +186,11 @@ def root_dimensions(self):
169186
@cached_property
170187
def space_order(self):
171188
"""The space order for all components."""
172-
return ({a.space_order for a in self} - {None}).pop()
189+
orders = self.applyfunc(lambda x: x.space_order)
190+
if len(set(orders)) > 1:
191+
return orders
192+
else:
193+
return orders[0]
173194

174195
@property
175196
def is_diagonal(self):
@@ -319,9 +340,10 @@ def __subfunc_setup__(cls, *args, **kwargs):
319340
stagg = kwargs.get("staggered", None)
320341
name = kwargs.get("name")
321342
for i, d in enumerate(dims):
322-
kwargs["name"] = "%s_%s" % (name, d.name)
323-
kwargs["staggered"] = stagg[i] if stagg is not None else d
324-
funcs.append(cls._sub_type(**kwargs))
343+
sub_kwargs = cls._component_kwargs(i, **kwargs)
344+
sub_kwargs.update({'name': f"{name}_{d.name}",
345+
'staggered': stagg[i] if stagg is not None else d})
346+
funcs.append(cls._sub_type(**sub_kwargs))
325347

326348
return funcs
327349

docker/Dockerfile.devito

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ ARG GROUP_ID=1000
1313

1414
################## Install devito ############################################
1515

16+
ENV PIP_USE_PEP517=1
1617
# Install pip dependencies
1718
RUN python3 -m venv /venv && \
18-
/venv/bin/pip install --no-cache-dir --upgrade pip && \
19+
/venv/bin/pip install --no-cache-dir --upgrade pip wheel setuptools && \
1920
/venv/bin/pip install --no-cache-dir jupyter && \
20-
/venv/bin/pip install --no-cache-dir --upgrade wheel setuptools && \
2121
ln -fs /app/nvtop/build/src/nvtop /venv/bin/nvtop
2222

2323
# Copy Devito

tests/test_tensors.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
import sympy
3-
from sympy import Rational
3+
from sympy import Rational, Matrix
44

55
import pytest
66

@@ -494,3 +494,10 @@ def test_diag(func1):
494494
else:
495495
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
496496
assert all(f2[i, i] == f1 for i in range(3))
497+
498+
499+
@pytest.mark.parametrize('func1', [TensorFunction, VectorFunction])
500+
def test_kwargs(func1):
501+
orders = Matrix([[1, 2], [3, 4]]) if func1 is TensorFunction else Matrix([1, 2])
502+
f = func1(name="f", grid=Grid((5, 5)), space_order=orders, symmetric=False)
503+
assert f.space_order == orders

0 commit comments

Comments
 (0)