Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 20 additions & 21 deletions xvec/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ def zonal_stats(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: Hashable = "geometry",
index: bool = None,
method: str = "rasterize",
Expand All @@ -949,36 +949,39 @@ def zonal_stats(
y_coords : Hashable
name of the coordinates containing ``y`` coordinates (i.e. the second value
in the coordinate pair encoding the vertex of the polygon)
stats : string | Callable
stats : string | Callable | Sequence[str | Callable | tuple]
Spatial aggregation statistic method, by default "mean". Any of the
aggregations available as :class:`xarray.DataArray` or
:class:`xarray.DataArrayGroupBy` methods like
:meth:`~xarray.DataArray.mean`, :meth:`~xarray.DataArray.min`,
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile`
are available. Alternatively, you can pass a ``Callable`` supported
by :meth:`~xarray.DataArray.reduce`.
:meth:`~xarray.DataArray.max`, or :meth:`~xarray.DataArray.quantile` are
available. Alternatively, you can pass a ``Callable`` supported by
:meth:`~xarray.DataArray.reduce` or a list with ``strings``, ``callables``
or ``tuples`` in a ``(name, func, {kwargs})`` format, where ``func`` can be
a string or a callable.
name : Hashable, optional
Name of the dimension that will hold the ``geometry``, by default "geometry"
index : bool, optional
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will attach its index as another
coordinate to the geometry dimension in the resulting object. If
``index=None``, the index will be stored if the `geometry.index` is a named
or non-default index. If ``index=False``, it will never be stored. This is
useful as an attribute link between the resulting array and the GeoPandas
object from which the geometry is sourced.
If ``geometry`` is a :class:`~geopandas.GeoSeries`, ``index=True`` will
attach its index as another coordinate to the geometry dimension in the
resulting object. If ``index=None``, the index will be stored if the
`geometry.index` is a named or non-default index. If ``index=False``, it
will never be stored. This is useful as an attribute link between the
resulting array and the GeoPandas object from which the geometry is sourced.
method : str, optional
The method of data extraction. The default is ``"rasterize"``, which uses
:func:`rasterio.features.rasterize` and is faster, but can lead to loss
of information in case of small polygons or lines. Other option is ``"iterate"``, which
iterates over geometries and uses :func:`rasterio.features.geometry_mask`.
:func:`rasterio.features.rasterize` and is faster, but can lead to loss of
information in case of small polygons or lines. Other option is
``"iterate"``, which iterates over geometries and uses
:func:`rasterio.features.geometry_mask`.
all_touched : bool, optional
If True, all pixels touched by geometries will be considered. If False, only
pixels whose center is within the polygon or that are selected by
Bresenham’s line algorithm will be considered.
n_jobs : int, optional
Number of parallel threads to use. It is recommended to set this to the
number of physical cores of the CPU. ``-1`` uses all available cores. Applies
only if ``method="iterate"``.
number of physical cores of the CPU. ``-1`` uses all available cores.
Applies only if ``method="iterate"``.
**kwargs : optional
Keyword arguments to be passed to the aggregation function
(e.g., ``Dataset.quantile(**kwargs)``).
Expand All @@ -990,8 +993,6 @@ def zonal_stats(
the :class:`GeometryIndex` of ``geometry``.

"""
# TODO: allow multiple stats at the same time (concat along a new axis),
# TODO: possibly as a list of tuples to include names?
if method == "rasterize":
result = _zonal_stats_rasterize(
self,
Expand Down Expand Up @@ -1033,9 +1034,7 @@ def zonal_stats(
result = result.assign_coords({index_name: (name, geometry.index)})

# standardize the shape - each method comes with a different one
return result.transpose(
name, *tuple(d for d in self._obj.dims if d not in [x_coords, y_coords])
)
return result.transpose(name, ...)

def extract_points(
self,
Expand Down
61 changes: 61 additions & 0 deletions xvec/tests/test_zonal_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,64 @@ def test_callable(method):
world.geometry, "longitude", "latitude", method=method, stats="std"
)
xr.testing.assert_identical(da_agg, da_std)


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_multiple(method):
    """Request several statistics in one call and check the stacked result.

    The ``stats`` list mixes every accepted entry form: plain strings, a
    ``(name, func, kwargs)`` tuple, a ``(name, callable)`` tuple and a bare
    callable. The result must expose a ``zonal_statistics`` dimension whose
    labels follow the order of the request.
    """
    dataset = xr.tutorial.open_dataset("eraint_uvz")
    land = gpd.read_file(geodatasets.get_path("naturalearth land"))
    boundaries = land.geometry[:10].boundary

    requested = [
        "mean",
        "sum",
        ("quantile", "quantile", {"q": [0.1, 0.2, 0.3]}),
        ("numpymean", np.nanmean),
        np.nanmean,
    ]
    result = dataset.xvec.zonal_stats(
        boundaries,
        "longitude",
        "latitude",
        stats=requested,
        method=method,
        n_jobs=1,
    )

    # Dimension order is not asserted here, only the set of names.
    expected_dims = ["level", "zonal_statistics", "geometry", "month", "quantile"]
    assert sorted(result.dims) == sorted(expected_dims)

    # A bare callable is labelled by its ``__name__`` ("nanmean").
    expected_labels = ["mean", "sum", "quantile", "numpymean", "nanmean"]
    assert (result.zonal_statistics == expected_labels).all()


@pytest.mark.parametrize("method", ["rasterize", "iterate"])
def test_invalid(method):
    """Unsupported ``stats`` values must raise ``ValueError``.

    Two cases are covered: an invalid entry nested inside a list of
    statistics, and a top-level value that is neither a string, a callable
    nor list-like.
    """
    dataset = xr.tutorial.open_dataset("eraint_uvz")
    land = gpd.read_file(geodatasets.get_path("naturalearth land"))
    boundaries = land.geometry[:10].boundary

    # (stats argument, regex the error message has to match)
    invalid_cases = [
        (["mean", ["gorilla"]], r"\['gorilla'\] is not a valid aggregation."),
        (3, "3 is not a valid aggregation."),
    ]

    for bad_stats, message in invalid_cases:
        with pytest.raises(ValueError, match=message):
            dataset.xvec.zonal_stats(
                boundaries,
                "longitude",
                "latitude",
                stats=bad_stats,
                method=method,
                n_jobs=1,
            )
77 changes: 65 additions & 12 deletions xvec/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,31 @@
from typing import Callable

import numpy as np
import pandas as pd
import shapely
import xarray as xr


def _agg_rasterize(groups, stats, **kwargs):
    """Apply a single aggregation to a grouped object.

    Parameters
    ----------
    groups : object
        Grouped object exposing named aggregation methods and ``reduce``.
    stats : str | Callable
        Name of the aggregation method to look up on ``groups``, or a
        callable handed over to ``groups.reduce``.
    **kwargs : optional
        Keyword arguments forwarded to the aggregation.

    Returns
    -------
    The value returned by the selected aggregation.
    """
    if not isinstance(stats, str):
        # Callables go through ``reduce``, which is asked to keep attributes.
        return groups.reduce(stats, keep_attrs=True, **kwargs)

    # Strings are resolved to the method of the same name.
    aggregate = getattr(groups, stats)
    return aggregate(**kwargs)


def _agg_iterate(masked, stats, x_coords, y_coords, **kwargs):
    """Apply a single aggregation over the two spatial dimensions.

    Parameters
    ----------
    masked : object
        Masked object exposing named aggregation methods and ``reduce``.
    stats : str | Callable
        Name of the aggregation method to look up on ``masked``, or a
        callable handed over to ``masked.reduce``.
    x_coords : Hashable
        Name of the ``x`` dimension to aggregate over.
    y_coords : Hashable
        Name of the ``y`` dimension to aggregate over.
    **kwargs : optional
        Keyword arguments forwarded to the aggregation.

    Returns
    -------
    The value returned by the selected aggregation.
    """
    # Both branches reduce over (y, x) and ask for attributes to be kept.
    spatial_dims = (y_coords, x_coords)

    if not isinstance(stats, str):
        return masked.reduce(stats, dim=spatial_dims, keep_attrs=True, **kwargs)

    reducer = getattr(masked, stats)
    return reducer(dim=spatial_dims, keep_attrs=True, **kwargs)


def _zonal_stats_rasterize(
acc,
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
all_touched: bool = False,
**kwargs,
Expand Down Expand Up @@ -47,10 +62,31 @@ def _zonal_stats_rasterize(
all_touched=all_touched,
)
groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords)))
if isinstance(stats, str):
agg = getattr(groups, stats)(**kwargs)

if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_rasterize(groups, stat, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_rasterize(groups, stat, **kwargs)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_rasterize(groups, stat[1], **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

agg = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
agg = _agg_rasterize(groups, stats, **kwargs)
else:
agg = groups.reduce(stats, keep_attrs=True, **kwargs)
raise ValueError(f"{stats} is not a valid aggregation.")

vec_cube = (
agg.reindex(group=range(len(geometry)))
.assign_coords(group=geometry)
Expand All @@ -68,7 +104,7 @@ def _zonal_stats_iterative(
geometry: Sequence[shapely.Geometry],
x_coords: Hashable,
y_coords: Hashable,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
name: str = "geometry",
all_touched: bool = False,
n_jobs: int = -1,
Expand Down Expand Up @@ -168,7 +204,7 @@ def _agg_geom(
trans,
x_coords: str = None,
y_coords: str = None,
stats: str | Callable = "mean",
stats: str | Callable | Sequence[str | Callable | tuple] = "mean",
all_touched=False,
**kwargs,
):
Expand Down Expand Up @@ -216,14 +252,31 @@ def _agg_geom(
all_touched=all_touched,
)
masked = acc._obj.where(xr.DataArray(mask, dims=(y_coords, x_coords)))
if isinstance(stats, str):
result = getattr(masked, stats)(
dim=(y_coords, x_coords), keep_attrs=True, **kwargs
if pd.api.types.is_list_like(stats):
agg = {}
for stat in stats:
if isinstance(stat, str):
agg[stat] = _agg_iterate(masked, stat, x_coords, y_coords, **kwargs)
elif callable(stat):
agg[stat.__name__] = _agg_iterate(
masked, stat, x_coords, y_coords, **kwargs
)
elif isinstance(stat, tuple):
kws = stat[2] if len(stat) == 3 else {}
agg[stat[0]] = _agg_iterate(masked, stat[1], x_coords, y_coords, **kws)
else:
raise ValueError(f"{stat} is not a valid aggregation.")

result = xr.concat(
agg.values(),
dim=xr.DataArray(
list(agg.keys()), name="zonal_statistics", dims="zonal_statistics"
),
)
elif isinstance(stats, str) or callable(stats):
result = _agg_iterate(masked, stats, x_coords, y_coords, **kwargs)
else:
result = masked.reduce(
stats, dim=(y_coords, x_coords), keep_attrs=True, **kwargs
)
raise ValueError(f"{stats} is not a valid aggregation.")

del mask
gc.collect()
Expand Down