From ecf972081ff8b76b74500788c2f4781d5155b05f Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 20 Mar 2025 16:53:12 +0000 Subject: [PATCH 1/4] dsl: SparseFunction coordinates now always real --- devito/operations/interpolators.py | 3 ++- devito/types/sparse.py | 6 ++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index cc3f882a3b..1f843f89ce 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -472,7 +472,8 @@ def _weights(self, subdomain=None): @cached_property def _point_symbols(self): """Symbol for coordinate value in each Dimension of the point.""" - return DimensionTuple(*(Symbol(name='p%s' % d, dtype=self.sfunction.dtype) + dtype = self.sfunction.coordinates.dtype + return DimensionTuple(*(Symbol(name='p%s' % d, dtype=dtype) for d in self.grid.dimensions), getters=self.grid.dimensions) diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 31c0adab4f..6af523d67d 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -202,6 +202,12 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs): else: dtype = dtype or self.dtype + # Complex coordinates are not valid, so fall back to corresponding + # real floating point type if dtype is complex. + if issubclass(dtype, np.complexfloating): + dtype = {np.complex64: np.float32, + np.complex128: np.float64}.get(dtype, np.float32) + sf = SparseSubFunction( name=name, dtype=dtype, dimensions=dimensions, shape=shape, space_order=0, initializer=key, alias=self.alias, From 305d364eebcf350c9d86e2bfd54d5bd878c8944b Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 20 Mar 2025 17:03:30 +0000 Subject: [PATCH 2/4] tests: Add tests for coordinate and point symbol types when using complex SparseFunctions --- tests/test_interpolation.py | 12 ++++++++++++ tests/test_sparse.py | 10 ++++++++++ 2 files changed, 22 insertions(+) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 4b85b57cf8..c885f2b768 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -841,6 +841,18 @@ def test_sinc_accuracy(r, tol): assert err_lin > 0.01 +@pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32), + (np.complex128, np.float64)]) +def test_point_symbol_types(dtype, expected): + """Test that positions are always real""" + grid = Grid(shape=(11,)) + s = SparseFunction(name='src', npoint=1, + grid=grid, dtype=dtype) + point_symbol = s.interpolator._point_symbols[0] + + assert point_symbol.dtype is expected + + class SD0(SubDomain): name = 'sd0' diff --git a/tests/test_sparse.py b/tests/test_sparse.py index c6c882c3f2..ad5776927e 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -497,6 +497,16 @@ def test_mpi_no_data(self, mode): ftest.data[:] = expected assert np.all(m.data[0, :, :] == ftest.data[:]) + @pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32), + (np.complex128, np.float64)]) + def test_coordinate_type(self, dtype, expected): + """Test that coordinates are always real""" + grid = Grid(shape=(11,)) + s = SparseFunction(name='src', npoint=1, + grid=grid, dtype=dtype) + + assert s.coordinates.dtype is expected + if __name__ == "__main__": TestMatrixSparseTimeFunction().test_mpi_no_data() From e68992d8bdf74ea39533f3c18e606a8db7d7eb57 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Thu, 20 Mar 2025 17:06:35 +0000 Subject: [PATCH 3/4] misc: Fstring --- devito/operations/interpolators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 1f843f89ce..4407f903dc 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -473,7 +473,7 @@ def _weights(self, subdomain=None): def _point_symbols(self): """Symbol for coordinate value in each Dimension of the point.""" dtype = self.sfunction.coordinates.dtype - return DimensionTuple(*(Symbol(name='p%s' % d, dtype=dtype) + return DimensionTuple(*(Symbol(name=f'p{d}', dtype=dtype) for d in self.grid.dimensions), getters=self.grid.dimensions) From e21189a228cfa0ff21f78755474660928453fc00 Mon Sep 17 00:00:00 2001 From: Edward Caunt Date: Fri, 21 Mar 2025 09:43:16 +0000 Subject: [PATCH 4/4] tests: Ensure float type SparseFunctions have matching coordinate dtype --- devito/types/sparse.py | 4 +--- tests/test_sparse.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 6af523d67d..676efe2b61 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -204,9 +204,7 @@ def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs): # Complex coordinates are not valid, so fall back to corresponding # real floating point type if dtype is complex. - if issubclass(dtype, np.complexfloating): - dtype = {np.complex64: np.float32, - np.complex128: np.float64}.get(dtype, np.float32) + dtype = dtype(0).real.__class__ sf = SparseSubFunction( name=name, dtype=dtype, dimensions=dimensions, diff --git a/tests/test_sparse.py b/tests/test_sparse.py index ad5776927e..14906d3c74 100644 --- a/tests/test_sparse.py +++ b/tests/test_sparse.py @@ -498,9 +498,13 @@ def test_mpi_no_data(self, mode): assert np.all(m.data[0, :, :] == ftest.data[:]) @pytest.mark.parametrize('dtype, expected', [(np.complex64, np.float32), - (np.complex128, np.float64)]) + (np.complex128, np.float64), + (np.float16, np.float16)]) def test_coordinate_type(self, dtype, expected): - """Test that coordinates are always real""" + """ + Test that coordinates are always real and SparseFunction dtype is + otherwise preserved. + """ grid = Grid(shape=(11,)) s = SparseFunction(name='src', npoint=1, grid=grid, dtype=dtype)