diff --git a/xvec/zonal.py b/xvec/zonal.py index 3e7fcf3..715ded1 100644 --- a/xvec/zonal.py +++ b/xvec/zonal.py @@ -8,6 +8,7 @@ import pandas as pd import shapely import xarray as xr +from xarray.groupers import UniqueGrouper def _agg_rasterize(groups, stats, **kwargs): @@ -50,19 +51,32 @@ def _zonal_stats_rasterize( crs = None transform = acc._obj.rio.transform() + length = len(geometry) + dtype = np.int16 if length < np.iinfo(np.int16).max else np.int32 labels = features.rasterize( - zip(geometry, range(len(geometry)), strict=False), + zip(geometry, range(length), strict=False), out_shape=( acc._obj[y_coords].shape[0], acc._obj[x_coords].shape[0], ), transform=transform, - fill=np.nan, # type: ignore + fill=length, # type: ignore all_touched=all_touched, - dtype=np.float32, + dtype=dtype, ) - groups = acc._obj.groupby(xr.DataArray(labels, dims=(y_coords, x_coords))) + + unique = np.unique(labels).tolist() + unique.remove(length) + + obj = acc._obj.copy() + if isinstance(obj, xr.Dataset): + obj = obj.assign_coords( + __labels__=xr.DataArray(labels, dims=(y_coords, x_coords)) + ) + else: + obj["__labels__"] = xr.DataArray(labels, dims=(y_coords, x_coords)) + groups = obj.groupby({"__labels__": UniqueGrouper(labels=unique)}) if pd.api.types.is_list_like(stats): agg = {} @@ -89,10 +103,11 @@ def _zonal_stats_rasterize( raise ValueError(f"{stats} is not a valid aggregation.") vec_cube = ( - agg_array.reindex(group=range(len(geometry))) - .assign_coords(group=geometry) - .rename(group=name) - ).xvec.set_geom_indexes(name, crs=crs) + agg_array.reindex(__labels__=range(length)) + .assign_coords(__labels__=geometry) + .rename(__labels__=name) + .xvec.set_geom_indexes(name, crs=crs) + ) del groups gc.collect()