From 93546d8a3d9d3fdaa51994302afb92b192b80e0d Mon Sep 17 00:00:00 2001 From: mloubout Date: Tue, 14 Apr 2026 14:42:19 -0400 Subject: [PATCH] api: add support for no interp (interp_order=0) --- devito/types/basic.py | 5 +++++ devito/types/dense.py | 4 +--- tests/test_differentiable.py | 19 ++++++++++++++++--- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/devito/types/basic.py b/devito/types/basic.py index 3239638a98..e6dea0e8e0 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1074,6 +1074,11 @@ def _evaluate(self, **kwargs): io = self.interp_order retval = self.subs({i.subs(subs): self.indices_ref[d] for d, i in mapper.items()}) + + if io == 0: + # No interpolation, just substitution (e.g nearest grid point) + return retval + if self.is_harmonic: retval = retval._inv(retval, safe=self.is_harmonic_safe) diff --git a/devito/types/dense.py b/devito/types/dense.py index 5fcbe8d1d7..30205a9712 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -1095,8 +1095,6 @@ def __init_finalize__(self, *args, **kwargs): interp_order = kwargs.get('interp_order', 2) if not is_integer(interp_order): raise TypeError("`interp_order` must be an integer") - elif interp_order < 1: - raise ValueError("`interp_order` must be >= 2") elif interp_order > self._space_order and self._space_order > 1: raise ValueError("`interp_order` must be <= `space_order`") self._interp_order = interp_order @@ -1121,7 +1119,7 @@ def _fd_priority(self): return 1 if self.staggered.on_node else 2 def _eval_at(self, func): - if self.staggered == func.staggered: + if self.staggered == func.staggered or self.interp_order == 0: return self mapper = {} diff --git a/tests/test_differentiable.py b/tests/test_differentiable.py index 65554d4487..ab8d724b60 100644 --- a/tests/test_differentiable.py +++ b/tests/test_differentiable.py @@ -111,9 +111,6 @@ def test_avg_mode(ndim, io): with pytest.raises(ValueError): # interp_order > space_order Function(name="a", grid=grid, interp_order=8, space_order=4) - with pytest.raises(ValueError): - # interp_order < 1 - Function(name="a", grid=grid, interp_order=0, space_order=4) with pytest.raises(TypeError): # interp_order not int Function(name="a", grid=grid, interp_order=2.5, space_order=4) @@ -152,3 +149,19 @@ def test_avg_mode(ndim, io): assert sympy.simplify(b_avg.args[0] - expected) == 0 assert isinstance(b_avg, SafeInv) assert b_avg.base == b + + +def test_no_interp(): + grid = Grid((10, 10)) + x = grid.dimensions[0] + a = Function(name="a", grid=grid, staggered=NODE, interp_order=0) + sa = Function(name="as", grid=grid, staggered=x) + + assert a._eval_at(sa) == a + assert sa._eval_at(a) == sa._subs(x, x - x.spacing/2) + assert (a*sa)._eval_at(sa) == a*sa + assert (a + sa)._eval_at(sa) == a + sa + + a_shift = a._subs(x, x + x.spacing / 2) + # Should just do nearest grid point, so shift back to original + assert a_shift.evaluate == a