diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b2045ec9b72..237f45d3e05 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -33,6 +33,9 @@ Deprecations Bug fixes ~~~~~~~~~ + +- Support for recursively defined Arrays. Fixes repr and deepcopy. (:issue:`7111`, :pull:`7112`) + By `Michael Niklas `_. - Fixed :py:meth:`Dataset.transpose` to raise a more informative error. (:issue:`6502`, :pull:`7120`) By `Patrick Naylor `_ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 794984b7a1b..2deeba31eaa 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -516,7 +516,7 @@ def _overwrite_indexes( new_indexes.pop(name) if rename_dims: - new_variable.dims = [rename_dims.get(d, d) for d in new_variable.dims] + new_variable.dims = tuple(rename_dims.get(d, d) for d in new_variable.dims) return self._replace( variable=new_variable, coords=new_coords, indexes=new_indexes @@ -1169,7 +1169,15 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: -------- pandas.DataFrame.copy """ - variable = self.variable.copy(deep=deep, data=data) + return self._copy(deep=deep, data=data) + + def _copy( + self: T_DataArray, + deep: bool = True, + data: Any = None, + memo: dict[int, Any] | None = None, + ) -> T_DataArray: + variable = self.variable._copy(deep=deep, data=data, memo=memo) indexes, index_vars = self.xindexes.copy_indexes(deep=deep) coords = {} @@ -1177,17 +1185,17 @@ def copy(self: T_DataArray, deep: bool = True, data: Any = None) -> T_DataArray: if k in index_vars: coords[k] = index_vars[k] else: - coords[k] = v.copy(deep=deep) + coords[k] = v._copy(deep=deep, memo=memo) return self._replace(variable, coords, indexes=indexes) def __copy__(self: T_DataArray) -> T_DataArray: - return self.copy(deep=False) + return self._copy(deep=False) - def __deepcopy__(self: T_DataArray, memo=None) -> T_DataArray: - # memo does nothing but is required for compatibility with - # copy.deepcopy - return self.copy(deep=True) + def __deepcopy__( + self: T_DataArray, memo: dict[int, Any] | None = None + ) -> T_DataArray: + return self._copy(deep=True, memo=memo) # mutable objects should not be Hashable # https://github.com/python/mypy/issues/4266 diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ad5eeb6f97f..74fdcb94ce1 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1221,6 +1221,14 @@ def copy( -------- pandas.DataFrame.copy """ + return self._copy(deep=deep, data=data) + + def _copy( + self: T_Dataset, + deep: bool = False, + data: Mapping[Any, ArrayLike] | None = None, + memo: dict[int, Any] | None = None, + ) -> T_Dataset: if data is None: data = {} elif not utils.is_dict_like(data): @@ -1249,13 +1257,21 @@ def copy( if k in index_vars: variables[k] = index_vars[k] else: - variables[k] = v.copy(deep=deep, data=data.get(k)) + variables[k] = v._copy(deep=deep, data=data.get(k), memo=memo) - attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) - encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding) + attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) + encoding = ( + copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding) + ) return self._replace(variables, indexes=indexes, attrs=attrs, encoding=encoding) + def __copy__(self: T_Dataset) -> T_Dataset: + return self._copy(deep=False) + + def __deepcopy__(self: T_Dataset, memo: dict[int, Any] | None = None) -> T_Dataset: + return self._copy(deep=True, memo=memo) + def as_numpy(self: T_Dataset) -> T_Dataset: """ Coerces wrapped data and coordinates into numpy arrays, returning a Dataset. @@ -1332,14 +1348,6 @@ def _construct_dataarray(self, name: Hashable) -> DataArray: return DataArray(variable, coords, name=name, indexes=indexes, fastpath=True) - def __copy__(self: T_Dataset) -> T_Dataset: - return self.copy(deep=False) - - def __deepcopy__(self: T_Dataset, memo=None) -> T_Dataset: - # memo does nothing but is required for compatibility with - # copy.deepcopy - return self.copy(deep=True) - @property def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: """Places to look-up items for attribute-style access""" diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index be5e06becdf..37ad11b266d 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -8,6 +8,7 @@ from collections import defaultdict from datetime import datetime, timedelta from itertools import chain, zip_longest +from reprlib import recursive_repr from typing import Collection, Hashable import numpy as np @@ -385,7 +386,6 @@ def _mapping_repr( expand_option_name="display_expand_data_vars", ) - attrs_repr = functools.partial( _mapping_repr, title="Attributes", @@ -551,6 +551,7 @@ def short_data_repr(array): return f"[{array.size} values with dtype={array.dtype}]" +@recursive_repr("") def array_repr(arr): from .variable import Variable @@ -592,11 +593,12 @@ def array_repr(arr): summary.append(unindexed_dims_str) if arr.attrs: - summary.append(attrs_repr(arr.attrs)) + summary.append(attrs_repr(arr.attrs, max_rows=max_rows)) return "\n".join(summary) +@recursive_repr("") def dataset_repr(ds): summary = [f""] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 0320ea81052..eac705f8ba1 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -161,16 +161,13 @@ def equivalent(first: T, second: T) -> bool: # TODO: refactor to avoid circular import from . import duck_array_ops + if first is second: + return True if isinstance(first, np.ndarray) or isinstance(second, np.ndarray): return duck_array_ops.array_equiv(first, second) - elif isinstance(first, list) or isinstance(second, list): + if isinstance(first, list) or isinstance(second, list): return list_equiv(first, second) - else: - return ( - (first is second) - or (first == second) - or (pd.isnull(first) and pd.isnull(second)) - ) + return (first == second) or (pd.isnull(first) and pd.isnull(second)) def list_equiv(first, second): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 0733d0d5236..d99891aab98 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -918,7 +918,9 @@ def encoding(self, value): except ValueError: raise ValueError("encoding must be castable to a dictionary") - def copy(self, deep: bool = True, data: ArrayLike | None = None): + def copy( + self: T_Variable, deep: bool = True, data: ArrayLike | None = None + ) -> T_Variable: """Returns a copy of this object. If `deep=True`, the data array is loaded into memory and copied onto @@ -974,6 +976,14 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None): -------- pandas.DataFrame.copy """ + return self._copy(deep=deep, data=data) + + def _copy( + self: T_Variable, + deep: bool = True, + data: ArrayLike | None = None, + memo: dict[int, Any] | None = None, + ) -> T_Variable: if data is None: ndata = self._data @@ -982,7 +992,7 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None): ndata = indexing.MemoryCachedArray(ndata.array) if deep: - ndata = copy.deepcopy(ndata) + ndata = copy.deepcopy(ndata, memo) else: ndata = as_compatible_data(data) @@ -993,8 +1003,10 @@ def copy(self, deep: bool = True, data: ArrayLike | None = None): ) ) - attrs = copy.deepcopy(self._attrs) if deep else copy.copy(self._attrs) - encoding = copy.deepcopy(self._encoding) if deep else copy.copy(self._encoding) + attrs = copy.deepcopy(self._attrs, memo) if deep else copy.copy(self._attrs) + encoding = ( + copy.deepcopy(self._encoding, memo) if deep else copy.copy(self._encoding) + ) # note: dims is already an immutable tuple return self._replace(data=ndata, attrs=attrs, encoding=encoding) @@ -1016,13 +1028,13 @@ def _replace( encoding = copy.copy(self._encoding) return type(self)(dims, data, attrs, encoding, fastpath=True) - def __copy__(self): - return self.copy(deep=False) + def __copy__(self: T_Variable) -> T_Variable: + return self._copy(deep=False) - def __deepcopy__(self, memo=None): - # memo does nothing but is required for compatibility with - # copy.deepcopy - return self.copy(deep=True) + def __deepcopy__( + self: T_Variable, memo: dict[int, Any] | None = None + ) -> T_Variable: + return self._copy(deep=True, memo=memo) # mutable objects should not be hashable # https://github.com/python/mypy/issues/4266 diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index 3e32d0e366d..b199c697b21 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -219,7 +219,9 @@ def test_concat_errors(self): concat([data, data], "new_dim", coords=["not_found"]) with pytest.raises(ValueError, match=r"global attributes not"): - data0, data1 = deepcopy(split_data) + # call deepcopy seperately to get unique attrs + data0 = deepcopy(split_data[0]) + data1 = deepcopy(split_data[1]) data1.attrs["foo"] = "bar" concat([data0, data1], "dim1", compat="identical") assert_identical(data, concat([data0, data1], "dim1", compat="equals")) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 2536ada1155..ac6049872b8 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6488,6 +6488,28 @@ def test_deepcopy_obj_array() -> None: assert x0.values[0] is not x1.values[0] +def test_deepcopy_recursive() -> None: + # GH:issue:7111 + + # direct recursion + da = xr.DataArray([1, 2], dims=["x"]) + da.attrs["other"] = da + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + da.copy(deep=True) + + # indirect recursion + da2 = xr.DataArray([5, 6], dims=["y"]) + da.attrs["other"] = da2 + da2.attrs["other"] = da + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + da.copy(deep=True) + da2.copy(deep=True) + + def test_clip(da: DataArray) -> None: with raise_if_dask_computes(): result = da.clip(min=0.5) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 8c393009c1a..49090c2f6db 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6687,6 +6687,28 @@ def test_deepcopy_obj_array() -> None: assert x0["foo"].values[0] is not x1["foo"].values[0] +def test_deepcopy_recursive() -> None: + # GH:issue:7111 + + # direct recursion + ds = xr.Dataset({"a": (["x"], [1, 2])}) + ds.attrs["other"] = ds + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + ds.copy(deep=True) + + # indirect recursion + ds2 = xr.Dataset({"b": (["y"], [3, 4])}) + ds.attrs["other"] = ds2 + ds2.attrs["other"] = ds + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + ds.copy(deep=True) + ds2.copy(deep=True) + + def test_clip(ds) -> None: result = ds.clip(min=0.5) assert all((result.min(...) >= 0.5).values()) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 96800c5428a..7c6e6ae1489 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -431,6 +431,24 @@ def test_array_repr_variable(self) -> None: with xr.set_options(display_expand_data=False): formatting.array_repr(var) + def test_array_repr_recursive(self) -> None: + # GH:issue:7111 + + # direct recurion + var = xr.Variable("x", [0, 1]) + var.attrs["x"] = var + formatting.array_repr(var) + + da = xr.DataArray([0, 1], dims=["x"]) + da.attrs["x"] = da + formatting.array_repr(da) + + # indirect recursion + var.attrs["x"] = da + da.attrs["x"] = var + formatting.array_repr(var) + formatting.array_repr(da) + @requires_dask def test_array_scalar_format(self) -> None: # Test numpy scalars: @@ -615,6 +633,21 @@ def test__mapping_repr(display_max_rows, n_vars, n_attr) -> None: assert actual == expected +def test__mapping_repr_recursive() -> None: + # GH:issue:7111 + + # direct recursion + ds = xr.Dataset({"a": [["x"], [1, 2, 3]]}) + ds.attrs["ds"] = ds + formatting.dataset_repr(ds) + + # indirect recursion + ds2 = xr.Dataset({"b": [["y"], [1, 2, 3]]}) + ds.attrs["ds"] = ds2 + ds2.attrs["ds"] = ds + formatting.dataset_repr(ds2) + + def test__element_formatter(n_elements: int = 100) -> None: expected = """\ Dimensions without coordinates: dim_0: 3, dim_1: 3, dim_2: 3, dim_3: 3, diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index c12ae4b05a9..4e8fa4c8268 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -59,6 +59,8 @@ def var(): class VariableSubclassobjects: + cls: staticmethod[Variable] + def test_properties(self): data = 0.5 * np.arange(10) v = self.cls(["time"], data, {"foo": "bar"}) @@ -521,7 +523,7 @@ def test_concat_mixed_dtypes(self): @pytest.mark.parametrize("deep", [True, False]) @pytest.mark.parametrize("astype", [float, int, str]) - def test_copy(self, deep, astype): + def test_copy(self, deep: bool, astype: type[object]) -> None: v = self.cls("x", (0.5 * np.arange(10)).astype(astype), {"foo": "bar"}) w = v.copy(deep=deep) assert type(v) is type(w) @@ -534,6 +536,27 @@ def test_copy(self, deep, astype): assert source_ndarray(v.values) is source_ndarray(w.values) assert_identical(v, copy(v)) + def test_copy_deep_recursive(self) -> None: + # GH:issue:7111 + + # direct recursion + v = self.cls("x", [0, 1]) + v.attrs["other"] = v + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + v.copy(deep=True) + + # indirect recusrion + v2 = self.cls("y", [2, 3]) + v.attrs["other"] = v2 + v2.attrs["other"] = v + + # TODO: cannot use assert_identical on recursive Vars yet... + # lets just ensure that deep copy works without RecursionError + v.copy(deep=True) + v2.copy(deep=True) + def test_copy_index(self): midx = pd.MultiIndex.from_product( [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") @@ -545,7 +568,7 @@ def test_copy_index(self): assert isinstance(w.to_index(), pd.MultiIndex) assert_array_equal(v._data.array, w._data.array) - def test_copy_with_data(self): + def test_copy_with_data(self) -> None: orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = np.array([[2.5, 5.0], [7.1, 43]]) actual = orig.copy(data=new_data) @@ -553,20 +576,20 @@ def test_copy_with_data(self): expected.data = new_data assert_identical(expected, actual) - def test_copy_with_data_errors(self): + def test_copy_with_data_errors(self) -> None: orig = Variable(("x", "y"), [[1.5, 2.0], [3.1, 4.3]], {"foo": "bar"}) new_data = [2.5, 5.0] with pytest.raises(ValueError, match=r"must match shape of object"): orig.copy(data=new_data) - def test_copy_index_with_data(self): + def test_copy_index_with_data(self) -> None: orig = IndexVariable("x", np.arange(5)) new_data = np.arange(5, 10) actual = orig.copy(data=new_data) expected = IndexVariable("x", np.arange(5, 10)) assert_identical(expected, actual) - def test_copy_index_with_data_errors(self): + def test_copy_index_with_data_errors(self) -> None: orig = IndexVariable("x", np.arange(5)) new_data = np.arange(5, 20) with pytest.raises(ValueError, match=r"must match shape of object"):