diff --git a/xvec/accessor.py b/xvec/accessor.py index 917889b..8fa4dc0 100644 --- a/xvec/accessor.py +++ b/xvec/accessor.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Hashable, Mapping, Sequence -from typing import Any +from typing import Any, Callable import numpy as np import pandas as pd @@ -921,10 +921,10 @@ def to_geodataframe( def zonal_stats( self, - polygons: Sequence[shapely.Geometry], + geometry: Sequence[shapely.Geometry], x_coords: Hashable, y_coords: Hashable, - stats: str = "mean", + stats: str | Callable = "mean", name: Hashable = "geometry", index: bool = None, method: str = "rasterize", @@ -934,37 +934,43 @@ def zonal_stats( ): """Extract the values from a dataset indexed by a set of geometries - The CRS of the raster and that of polygons need to be equal. + The CRS of the raster and that of geometry need to be equal. Xvec does not verify their equality. Parameters ---------- - polygons : Sequence[shapely.Geometry] + geometry : Sequence[shapely.Geometry] An arrray-like (1-D) of shapely geometries, like a numpy array or - :class:`geopandas.GeoSeries`. + :class:`geopandas.GeoSeries`. Polygon and LineString geometry types are + supported. x_coords : Hashable name of the coordinates containing ``x`` coordinates (i.e. the first value in the coordinate pair encoding the vertex of the polygon) y_coords : Hashable name of the coordinates containing ``y`` coordinates (i.e. the second value in the coordinate pair encoding the vertex of the polygon) - stats : string - Spatial aggregation statistic method, by default "mean". It supports the - following statistcs: ['mean', 'median', 'min', 'max', 'sum'] + stats : string | Callable + Spatial aggregation statistic method, by default "mean". 
Any of the + aggregations available as :class:`xarray.DataArray` or + :class:`xarray.DataArrayGroupBy` methods like + :meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`, + :meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile` + are available. Alternatively, you can pass a ``Callable`` supported + by :meth:`~xarray.DataArray.reduce`. name : Hashable, optional - Name of the dimension that will hold the ``polygons``, by default "geometry" + Name of the dimension that will hold the ``geometry``, by default "geometry" index : bool, optional - If `polygons` is a GeoSeries, ``index=True`` will attach its index as another + If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another coordinate to the geometry dimension in the resulting object. If - ``index=None``, the index will be stored if the `polygons.index` is a named + ``index=None``, the index will be stored if the `geometry.index` is a named or non-default index. If ``index=False``, it will never be stored. This is useful as an attribute link between the resulting array and the GeoPandas - object from which the polygons are sourced. + object from which the geometry is sourced. method : str, optional The method of data extraction. The default is ``"rasterize"``, which uses :func:`rasterio.features.rasterize` and is faster, but can lead to loss - of information in case of small polygons. Other option is ``"iterate"``, which - iterates over polygons and uses :func:`rasterio.features.geometry_mask`. + of information in case of small polygons or lines. Other option is ``"iterate"``, which + iterates over geometries and uses :func:`rasterio.features.geometry_mask`. all_touched : bool, optional If True, all pixels touched by geometries will be considered. If False, only pixels whose center is within the polygon or that are selected by @@ -975,22 +981,21 @@ def zonal_stats( only if ``method="iterate"``. 
**kwargs : optional Keyword arguments to be passed to the aggregation function - (e.g., ``Dataset.mean(**kwargs)``). + (e.g., ``Dataset.quantile(**kwargs)``). Returns ------- - Dataset + Dataset or DataArray A subset of the original object with N-1 dimensions indexed by - the the GeometryIndex. + the :class:`GeometryIndex` of ``geometry``. """ # TODO: allow multiple stats at the same time (concat along a new axis), # TODO: possibly as a list of tuples to include names? - # TODO: allow callable in stat (via .reduce()) if method == "rasterize": result = _zonal_stats_rasterize( self, - polygons=polygons, + geometry=geometry, x_coords=x_coords, y_coords=y_coords, stats=stats, @@ -1001,7 +1006,7 @@ def zonal_stats( elif method == "iterate": result = _zonal_stats_iterative( self, - polygons=polygons, + geometry=geometry, x_coords=x_coords, y_coords=y_coords, stats=stats, @@ -1017,15 +1022,15 @@ def zonal_stats( ) # save the index as a data variable - if isinstance(polygons, pd.Series): + if isinstance(geometry, pd.Series): if index is None: - if polygons.index.name is not None or not polygons.index.equals( - pd.RangeIndex(0, len(polygons)) + if geometry.index.name is not None or not geometry.index.equals( + pd.RangeIndex(0, len(geometry)) ): index = True if index: - index_name = polygons.index.name if polygons.index.name else "index" - result = result.assign_coords({index_name: (name, polygons.index)}) + index_name = geometry.index.name if geometry.index.name else "index" + result = result.assign_coords({index_name: (name, geometry.index)}) # standardize the shape - each method comes with a different one return result.transpose( diff --git a/xvec/tests/test_zonal_stats.py b/xvec/tests/test_zonal_stats.py index 73116ed..c9c44a9 100644 --- a/xvec/tests/test_zonal_stats.py +++ b/xvec/tests/test_zonal_stats.py @@ -209,3 +209,29 @@ def test_crs(method): actual = da.xvec.zonal_stats(polygons, "x", "y", stats="sum", method=method) xr.testing.assert_identical(actual, expected) + 
+ +@pytest.mark.parametrize("method", ["rasterize", "iterate"]) +def test_callable(method): + ds = xr.tutorial.open_dataset("eraint_uvz") + world = gpd.read_file(geodatasets.get_path("naturalearth land")) + ds_agg = ds.xvec.zonal_stats( + world.geometry, "longitude", "latitude", method=method, stats=np.nanstd + ) + ds_std = ds.xvec.zonal_stats( + world.geometry, "longitude", "latitude", method=method, stats="std" + ) + xr.testing.assert_identical(ds_agg, ds_std) + + da_agg = ds.z.xvec.zonal_stats( + world.geometry, + "longitude", + "latitude", + method=method, + stats=np.nanstd, + n_jobs=1, + ) + da_std = ds.z.xvec.zonal_stats( + world.geometry, "longitude", "latitude", method=method, stats="std" + ) + xr.testing.assert_identical(da_agg, da_std) diff --git a/xvec/zonal.py b/xvec/zonal.py index d860fb1..0681c5e 100644 --- a/xvec/zonal.py +++ b/xvec/zonal.py @@ -2,6 +2,7 @@ import gc from collections.abc import Hashable, Sequence +from typing import Callable import numpy as np import shapely @@ -10,16 +11,16 @@ def _zonal_stats_rasterize( acc, - polygons: Sequence[shapely.Geometry], + geometry: Sequence[shapely.Geometry], x_coords: Hashable, y_coords: Hashable, - stats: str = "mean", + stats: str | Callable = "mean", name: str = "geometry", all_touched: bool = False, **kwargs, ): try: - import rasterio # noqa: F401 + import rasterio import rioxarray # noqa: F401 except ImportError as err: raise ImportError( @@ -28,15 +29,15 @@ def _zonal_stats_rasterize( "'pip install rioxarray'." 
) from err - if hasattr(polygons, "crs"): - crs = polygons.crs + if hasattr(geometry, "crs"): + crs = geometry.crs else: crs = None transform = acc._obj.rio.transform() labels = rasterio.features.rasterize( - zip(polygons, range(len(polygons))), + zip(geometry, range(len(geometry))), out_shape=( acc._obj[y_coords].shape[0], acc._obj[x_coords].shape[0], @@ -46,10 +47,13 @@ def _zonal_stats_rasterize( all_touched=all_touched, ) groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords))) - agg = getattr(groups, stats)(**kwargs) + if isinstance(stats, str): + agg = getattr(groups, stats)(**kwargs) + else: + agg = groups.reduce(stats, keep_attrs=True, **kwargs) vec_cube = ( - agg.reindex(group=range(len(polygons))) - .assign_coords(group=polygons) + agg.reindex(group=range(len(geometry))) + .assign_coords(group=geometry) .rename(group=name) ).xvec.set_geom_indexes(name, crs=crs) @@ -61,10 +65,10 @@ def _zonal_stats_rasterize( def _zonal_stats_iterative( acc, - polygons: Sequence[shapely.Geometry], + geometry: Sequence[shapely.Geometry], x_coords: Hashable, y_coords: Hashable, - stats: str = "mean", + stats: str | Callable = "mean", name: str = "geometry", all_touched: bool = False, n_jobs: int = -1, @@ -72,12 +76,12 @@ def _zonal_stats_iterative( ): """Extract the values from a dataset indexed by a set of geometries - The CRS of the raster and that of polygons need to be equal. + The CRS of the raster and that of geometry need to be equal. Xvec does not verify their equality. Parameters ---------- - polygons : Sequence[shapely.Geometry] + geometry : Sequence[shapely.Geometry] An arrray-like (1-D) of shapely geometries, like a numpy array or :class:`geopandas.GeoSeries`. x_coords : Hashable @@ -87,10 +91,14 @@ def _zonal_stats_iterative( name of the coordinates containing ``y`` coordinates (i.e. the second value in the coordinate pair encoding the vertex of the polygon) stats : Hashable - Spatial aggregation statistic method, by default "mean". 
It supports the - following statistcs: ['mean', 'median', 'min', 'max', 'sum'] + Spatial aggregation statistic method, by default "mean". Any of the + aggregations available as DataArray or DataArrayGroupBy methods, like + :meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`, + :meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`, + are available. Alternatively, you can pass a ``Callable`` supported + by :meth:`~xarray.DataArray.reduce`. name : Hashable, optional - Name of the dimension that will hold the ``polygons``, by default "geometry" + Name of the dimension that will hold the ``geometry``, by default "geometry" all_touched : bool, optional If True, all pixels touched by geometries will be considered. If False, only pixels whose center is within the polygon or that are selected by @@ -140,14 +148,14 @@ def _zonal_stats_iterative( all_touched=all_touched, **kwargs, ) - for geom in polygons + for geom in geometry ) - if hasattr(polygons, "crs"): - crs = polygons.crs + if hasattr(geometry, "crs"): + crs = geometry.crs else: crs = None vec_cube = xr.concat( - zonal, dim=xr.DataArray(polygons, name=name, dims=name) + zonal, dim=xr.DataArray(geometry, name=name, dims=name) ).xvec.set_geom_indexes(name, crs=crs) gc.collect() @@ -160,7 +168,7 @@ def _agg_geom( trans, x_coords: str = None, y_coords: str = None, - stats: str = "mean", + stats: str | Callable = "mean", all_touched=False, **kwargs, ): @@ -207,9 +215,15 @@ def _agg_geom( invert=True, all_touched=all_touched, ) - result = getattr( - acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))), stats - )(dim=(y_coords, x_coords), keep_attrs=True, **kwargs) + masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords))) + if isinstance(stats, str): + result = getattr(masked, stats)( + dim=(y_coords, x_coords), keep_attrs=True, **kwargs + ) + else: + result = masked.reduce( + stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs + ) del mask gc.collect()