diff --git a/tests/unit/mathutils/test_function_grid.py b/tests/unit/mathutils/test_function_grid.py index 994689544..b9ad94d04 100644 --- a/tests/unit/mathutils/test_function_grid.py +++ b/tests/unit/mathutils/test_function_grid.py @@ -1,5 +1,7 @@ """Unit tests for Function.from_grid() method and grid interpolation.""" +import warnings + import numpy as np import pytest @@ -137,3 +139,141 @@ def test_from_grid_backward_compatibility(): # Test callable function func3 = Function(lambda x: x**2) assert func3(2) == 4 + + +def test_shepard_fallback_warning(): + """Test that shepard_fallback is triggered and emits a warning. + + When linear_grid interpolation is set but no grid interpolator is available, + the Function class should fall back to shepard interpolation and emit a warning. + """ + # Create a 2D function with scattered points (not structured grid) + source = [(0, 0, 0), (1, 0, 1), (0, 1, 2), (1, 1, 3)] + func = Function( + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" + ) + + # Now manually change interpolation to linear_grid without setting up the grid + # This simulates the fallback scenario + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + func.set_interpolation("linear_grid") + + # Check that a warning was issued + assert len(w) == 1 + assert "falling back to shepard interpolation" in str(w[0].message) + + +def test_shepard_fallback_2d_interpolation(): + """Test that shepard_fallback produces correct interpolation for 2D data. + + This test verifies the fallback interpolation works correctly when + linear_grid is set without a grid interpolator. + """ + # Create a 2D function: z = x + y + source = [ + (0, 0, 0), # f(0, 0) = 0 + (1, 0, 1), # f(1, 0) = 1 + (0, 1, 1), # f(0, 1) = 1 + (1, 1, 2), # f(1, 1) = 2 + ] + + # First, create with shepard to get baseline results + func_shepard = Function( + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" + ) + + # Create another function and trigger the fallback + func_fallback = Function( + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" + ) + + # Trigger fallback + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # Suppress warnings for this test + func_fallback.set_interpolation("linear_grid") + + # Test that both produce the same results at exact points + assert func_fallback(0, 0) == func_shepard(0, 0) + assert func_fallback(1, 1) == func_shepard(1, 1) + + # Test interpolation at an intermediate point + result_fallback = func_fallback(0.5, 0.5) + result_shepard = func_shepard(0.5, 0.5) + assert np.isclose(result_fallback, result_shepard, atol=1e-6) + + +def test_shepard_fallback_3d_interpolation(): + """Test that shepard_fallback produces correct interpolation for 3D data. + + This test verifies the fallback interpolation works correctly for + 3-dimensional input data. + """ + # Create a 3D function: w = x + y + z + source = [ + (0, 0, 0, 0), # f(0, 0, 0) = 0 + (1, 0, 0, 1), # f(1, 0, 0) = 1 + (0, 1, 0, 1), # f(0, 1, 0) = 1 + (0, 0, 1, 1), # f(0, 0, 1) = 1 + (1, 1, 1, 3), # f(1, 1, 1) = 3 + ] + + # Create with shepard to get baseline results + func_shepard = Function( + source=source, + inputs=["x", "y", "z"], + outputs="w", + interpolation="shepard", + ) + + # Create another function and trigger the fallback + func_fallback = Function( + source=source, + inputs=["x", "y", "z"], + outputs="w", + interpolation="shepard", + ) + + # Trigger fallback + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + func_fallback.set_interpolation("linear_grid") + + # Test that both produce the same results at exact points + assert func_fallback(0, 0, 0) == func_shepard(0, 0, 0) + assert func_fallback(1, 1, 1) == func_shepard(1, 1, 1) + + # Test interpolation at an intermediate point + result_fallback = func_fallback(0.5, 0.5, 0.5) + result_shepard = func_shepard(0.5, 0.5, 0.5) + assert np.isclose(result_fallback, result_shepard, atol=1e-6) + + +def test_shepard_fallback_at_exact_data_points(): + """Test that shepard_fallback returns exact values at data points. + + When querying at exact data points, the fallback should return the + exact value stored at that point. + """ + # Create a 2D function + source = [ + (0, 0, 10), + (1, 0, 20), + (0, 1, 30), + (1, 1, 40), + ] + + func = Function( + source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" + ) + + # Trigger fallback + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + func.set_interpolation("linear_grid") + + # Test exact data points - should return exact values + assert func(0, 0) == 10 + assert func(1, 0) == 20 + assert func(0, 1) == 30 + assert func(1, 1) == 40