Skip to content

Commit e21189a

Browse files
committed
tests: Ensure float type SparseFunctions have matching coordinate dtype
1 parent e68992d commit e21189a

2 files changed

Lines changed: 7 additions & 5 deletions

File tree

devito/types/sparse.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,7 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
204204

205205
# Complex coordinates are not valid, so fall back to corresponding
206206
# real floating point type if dtype is complex.
207-
if issubclass(dtype, np.complexfloating):
208-
dtype = {np.complex64: np.float32,
209-
np.complex128: np.float64}.get(dtype, np.float32)
207+
dtype = dtype(0).real.__class__
210208

211209
sf = SparseSubFunction(
212210
name=name, dtype=dtype, dimensions=dimensions,

tests/test_sparse.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,9 +498,13 @@ def test_mpi_no_data(self, mode):
498498
assert np.all(m.data[0, :, :] == ftest.data[:])
499499

500500
@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32),
501-
(np.complex128, np.float64)])
501+
(np.complex128, np.float64),
502+
(np.float16, np.float16)])
502503
def test_coordinate_type(self, dtype, expected):
503-
"""Test that coordinates are always real"""
504+
"""
505+
Test that coordinates are always real and SparseFunction dtype is
506+
otherwise preserved.
507+
"""
504508
grid = Grid(shape=(11,))
505509
s = SparseFunction(name='src', npoint=1,
506510
grid=grid, dtype=dtype)

0 commit comments

Comments
 (0)