|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import pickle |
| 4 | +from unittest.mock import patch |
| 5 | + |
3 | 6 | import numpy as np |
4 | 7 | import pytest |
5 | 8 |
|
6 | 9 | import xarray as xr |
7 | | -from xarray.tests import assert_allclose, assert_array_equal, mock |
| 10 | +import xarray.ufuncs as xu |
| 11 | +from xarray.tests import assert_allclose, assert_array_equal, mock, requires_dask |
8 | 12 | from xarray.tests import assert_identical as assert_identical_ |
9 | 13 |
|
10 | 14 |
|
@@ -155,3 +159,108 @@ def test_gufuncs(): |
155 | 159 | fake_gufunc = mock.Mock(signature="(n)->()", autospec=np.sin) |
156 | 160 | with pytest.raises(NotImplementedError, match=r"generalized ufuncs"): |
157 | 161 | xarray_obj.__array_ufunc__(fake_gufunc, "__call__", xarray_obj) |
| 162 | + |
| 163 | + |
| 164 | +class DuckArray(np.ndarray): |
| 165 | + # Minimal subclassed duck array with its own self-contained namespace, |
| 166 | + # which implements a few ufuncs |
| 167 | + def __new__(cls, array): |
| 168 | + obj = np.asarray(array).view(cls) |
| 169 | + return obj |
| 170 | + |
| 171 | + def __array_namespace__(self): |
| 172 | + return DuckArray |
| 173 | + |
| 174 | + @staticmethod |
| 175 | + def sin(x): |
| 176 | + return np.sin(x) |
| 177 | + |
| 178 | + @staticmethod |
| 179 | + def add(x, y): |
| 180 | + return x + y |
| 181 | + |
| 182 | + |
| 183 | +class DuckArray2(DuckArray): |
| 184 | + def __array_namespace__(self): |
| 185 | + return DuckArray2 |
| 186 | + |
| 187 | + |
| 188 | +class TestXarrayUfuncs: |
| 189 | + @pytest.fixture(autouse=True) |
| 190 | + def setUp(self): |
| 191 | + self.x = xr.DataArray([1, 2, 3]) |
| 192 | + self.xd = xr.DataArray(DuckArray([1, 2, 3])) |
| 193 | + self.xd2 = xr.DataArray(DuckArray2([1, 2, 3])) |
| 194 | + self.xt = xr.DataArray(np.datetime64("2021-01-01", "ns")) |
| 195 | + |
| 196 | + @pytest.mark.filterwarnings("ignore::RuntimeWarning") |
| 197 | + @pytest.mark.parametrize("name", xu.__all__) |
| 198 | + def test_ufuncs(self, name, request): |
| 199 | + xu_func = getattr(xu, name) |
| 200 | + np_func = getattr(np, name, None) |
| 201 | + if np_func is None and np.lib.NumpyVersion(np.__version__) < "2.0.0": |
| 202 | + pytest.skip(f"Ufunc {name} is not available in numpy {np.__version__}.") |
| 203 | + |
| 204 | + if name == "isnat": |
| 205 | + args = (self.xt,) |
| 206 | + elif hasattr(np_func, "nin") and np_func.nin == 2: |
| 207 | + args = (self.x, self.x) |
| 208 | + else: |
| 209 | + args = (self.x,) |
| 210 | + |
| 211 | + expected = np_func(*args) |
| 212 | + actual = xu_func(*args) |
| 213 | + |
| 214 | + if name in ["angle", "iscomplex"]: |
| 215 | + np.testing.assert_equal(expected, actual.values) |
| 216 | + else: |
| 217 | + assert_identical(actual, expected) |
| 218 | + |
| 219 | + def test_ufunc_pickle(self): |
| 220 | + a = 1.0 |
| 221 | + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) |
| 222 | + assert_identical(cos_pickled(a), xu.cos(a)) |
| 223 | + |
| 224 | + def test_ufunc_scalar(self): |
| 225 | + actual = xu.sin(1) |
| 226 | + assert isinstance(actual, float) |
| 227 | + |
| 228 | + def test_ufunc_duck_array_dataarray(self): |
| 229 | + actual = xu.sin(self.xd) |
| 230 | + assert isinstance(actual.data, DuckArray) |
| 231 | + |
| 232 | + def test_ufunc_duck_array_variable(self): |
| 233 | + actual = xu.sin(self.xd.variable) |
| 234 | + assert isinstance(actual.data, DuckArray) |
| 235 | + |
| 236 | + def test_ufunc_duck_array_dataset(self): |
| 237 | + ds = xr.Dataset({"a": self.xd}) |
| 238 | + actual = xu.sin(ds) |
| 239 | + assert isinstance(actual.a.data, DuckArray) |
| 240 | + |
| 241 | + @requires_dask |
| 242 | + def test_ufunc_duck_dask(self): |
| 243 | + import dask.array as da |
| 244 | + |
| 245 | + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) |
| 246 | + actual = xu.sin(x) |
| 247 | + assert isinstance(actual.data._meta, DuckArray) |
| 248 | + |
| 249 | + @requires_dask |
| 250 | + @pytest.mark.xfail(reason="dask ufuncs currently dispatch to numpy") |
| 251 | + def test_ufunc_duck_dask_no_array_ufunc(self): |
| 252 | + import dask.array as da |
| 253 | + |
| 254 | + # dask ufuncs currently only preserve duck arrays that implement __array_ufunc__ |
| 255 | + with patch.object(DuckArray, "__array_ufunc__", new=None, create=True): |
| 256 | + x = xr.DataArray(da.from_array(DuckArray(np.array([1, 2, 3])))) |
| 257 | + actual = xu.sin(x) |
| 258 | + assert isinstance(actual.data._meta, DuckArray) |
| 259 | + |
| 260 | + def test_ufunc_mixed_arrays_compatible(self): |
| 261 | + actual = xu.add(self.xd, self.x) |
| 262 | + assert isinstance(actual.data, DuckArray) |
| 263 | + |
| 264 | + def test_ufunc_mixed_arrays_incompatible(self): |
| 265 | + with pytest.raises(ValueError, match=r"Mixed array types"): |
| 266 | + xu.add(self.xd, self.xd2) |
0 commit comments