Skip to content

Commit adbf3bf

Browse files
authored
Merge pull request #2557 from devitocodes/coordinates
dsl: ensure SparseFunction coordinates and point symbols are always real
2 parents 2247707 + e21189a commit adbf3bf

4 files changed

Lines changed: 32 additions & 1 deletion

File tree

devito/operations/interpolators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,8 @@ def _weights(self, subdomain=None):
472472
@cached_property
473473
def _point_symbols(self):
474474
"""Symbol for coordinate value in each Dimension of the point."""
475-
return DimensionTuple(*(Symbol(name='p%s' % d, dtype=self.sfunction.dtype)
475+
dtype = self.sfunction.coordinates.dtype
476+
return DimensionTuple(*(Symbol(name=f'p{d}', dtype=dtype)
476477
for d in self.grid.dimensions),
477478
getters=self.grid.dimensions)
478479

devito/types/sparse.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
202202
else:
203203
dtype = dtype or self.dtype
204204

205+
# Complex coordinates are not valid, so fall back to corresponding
206+
# real floating point type if dtype is complex.
207+
dtype = dtype(0).real.__class__
208+
205209
sf = SparseSubFunction(
206210
name=name, dtype=dtype, dimensions=dimensions,
207211
shape=shape, space_order=0, initializer=key, alias=self.alias,

tests/test_interpolation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,18 @@ def test_sinc_accuracy(r, tol):
841841
assert err_lin > 0.01
842842

843843

844+
@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32),
845+
(np.complex128, np.float64)])
846+
def test_point_symbol_types(dtype, expected):
847+
"""Test that positions are always real"""
848+
grid = Grid(shape=(11,))
849+
s = SparseFunction(name='src', npoint=1,
850+
grid=grid, dtype=dtype)
851+
point_symbol = s.interpolator._point_symbols[0]
852+
853+
assert point_symbol.dtype is expected
854+
855+
844856
class SD0(SubDomain):
845857
name = 'sd0'
846858

tests/test_sparse.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,20 @@ def test_mpi_no_data(self, mode):
497497
ftest.data[:] = expected
498498
assert np.all(m.data[0, :, :] == ftest.data[:])
499499

500+
@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32),
501+
(np.complex128, np.float64),
502+
(np.float16, np.float16)])
503+
def test_coordinate_type(self, dtype, expected):
504+
"""
505+
Test that coordinates are always real and SparseFunction dtype is
506+
otherwise preserved.
507+
"""
508+
grid = Grid(shape=(11,))
509+
s = SparseFunction(name='src', npoint=1,
510+
grid=grid, dtype=dtype)
511+
512+
assert s.coordinates.dtype is expected
513+
500514

501515
if __name__ == "__main__":
502516
TestMatrixSparseTimeFunction().test_mpi_no_data()

0 commit comments

Comments
 (0)