diff --git a/xvec/plotting.py b/xvec/plotting.py index e769986..1a13412 100644 --- a/xvec/plotting.py +++ b/xvec/plotting.py @@ -281,8 +281,18 @@ def _plot( cmap_params = {} # Handle simple case - single geometry with no faceting - if not col and isinstance(arr, xr.DataArray) and arr.ndim == 1: - arr.xvec.to_geodataframe(geometry=geometry).plot(arr.values, ax=axs, **kwargs) + if not col and isinstance(arr, xr.DataArray) and n_cols == 1 and n_rows == 1: + if arr.ndim == 2: + arr = arr.squeeze() + arr.xvec.to_geodataframe(geometry=geometry, name="plotting").plot( + arr.values, + ax=axs, + vmin=cmap_params.get("vmin", None), + vmax=cmap_params.get("vmax", None), + cmap=cmap_params.get("cmap", None), + categories=cmap_params.get("categories", None), + **kwargs, + ) axs.set_xlabel(x_label, fontsize="small") axs.set_ylabel(y_label, fontsize="small") @@ -298,7 +308,7 @@ def _plot( return fig, axs if not col and geometry in arr.xvec._geom_coords_all: - arr[geometry].drop_vars([geometry]).xvec.to_geodataframe().plot( + arr[geometry].drop_vars([geometry]).xvec.to_geodataframe(name="plotting").plot( ax=axs, **kwargs ) axs.set_xlabel(x_label, fontsize="small") diff --git a/xvec/tests/baseline_images/test_plotting/unnamed.png b/xvec/tests/baseline_images/test_plotting/unnamed.png new file mode 100644 index 0000000..f1fe50a Binary files /dev/null and b/xvec/tests/baseline_images/test_plotting/unnamed.png differ diff --git a/xvec/tests/baseline_images/test_plotting/void_dimension.png b/xvec/tests/baseline_images/test_plotting/void_dimension.png new file mode 100644 index 0000000..f1fe50a Binary files /dev/null and b/xvec/tests/baseline_images/test_plotting/void_dimension.png differ diff --git a/xvec/tests/test_plotting.py b/xvec/tests/test_plotting.py index f10fc68..aeb39f6 100644 --- a/xvec/tests/test_plotting.py +++ b/xvec/tests/test_plotting.py @@ -78,6 +78,37 @@ def test_1d(aggregated): assert ax.get_ylabel() == "Geodetic latitude\n[degree]" +@image_comparison( + baseline_images=["void_dimension"], extensions=["png"], style=[], tol=0.01 +) +def test_void_dimension(): + ds = xr.tutorial.open_dataset("eraint_uvz").load() + counties = gpd.read_file(geodatasets.get_path("geoda natregimes")).to_crs(4326) + + ds.sel(month=1, level=[200]).z.xvec.zonal_stats( + counties.geometry, + x_coords="longitude", + y_coords="latitude", + all_touched=True, + ).xvec.plot() + + +@image_comparison(baseline_images=["unnamed"], extensions=["png"], style=[], tol=0.01) +def test_unnamed(): + ds = xr.tutorial.open_dataset("eraint_uvz").load() + counties = gpd.read_file(geodatasets.get_path("geoda natregimes")).to_crs(4326) + + arr = ds.sel(month=1, level=[200]).z + arr.name = None + + arr.xvec.zonal_stats( + counties.geometry, + x_coords="longitude", + y_coords="latitude", + all_touched=True, + ).sel(level=200).xvec.plot() + + @image_comparison(baseline_images=["var_geom"], extensions=["png"], style=[], tol=0.01) def test_var_geom(glaciers): f, ax = glaciers.geometry.xvec.plot(col="year")