Skip to content

Commit 305d364

Browse files
committed
tests: Add tests for coordinate and point symbol types when using complex SparseFunctions
1 parent ecf9720 commit 305d364

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

tests/test_interpolation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,18 @@ def test_sinc_accuracy(r, tol):
841841
assert err_lin > 0.01
842842

843843

844+
@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32),
845+
(np.complex128, np.float64)])
846+
def test_point_symbol_types(dtype, expected):
847+
"""Test that positions are always real"""
848+
grid = Grid(shape=(11,))
849+
s = SparseFunction(name='src', npoint=1,
850+
grid=grid, dtype=dtype)
851+
point_symbol = s.interpolator._point_symbols[0]
852+
853+
assert point_symbol.dtype is expected
854+
855+
844856
class SD0(SubDomain):
845857
name = 'sd0'
846858

tests/test_sparse.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,16 @@ def test_mpi_no_data(self, mode):
497497
ftest.data[:] = expected
498498
assert np.all(m.data[0, :, :] == ftest.data[:])
499499

500+
@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32),
501+
(np.complex128, np.float64)])
502+
def test_coordinate_type(self, dtype, expected):
503+
"""Test that coordinates are always real"""
504+
grid = Grid(shape=(11,))
505+
s = SparseFunction(name='src', npoint=1,
506+
grid=grid, dtype=dtype)
507+
508+
assert s.coordinates.dtype is expected
509+
500510

501511
if __name__ == "__main__":
502512
TestMatrixSparseTimeFunction().test_mpi_no_data()

0 commit comments

Comments
 (0)