From 3ad10d9e4319f91e169d8578307085f19a7cae46 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 18:57:36 +0100 Subject: [PATCH 01/31] Use broadcast_like for 2d plot coordinates Use broadcast_like if either `x` or `y` inputs are 2d to ensure that both have dimensions in the same order as the DataArray being plotted. Convert to numpy arrays after possibly using broadcast_like. Simplifies code, and fixes #5097 (bug when dimensions have the same size). --- xarray/plot/plot.py | 35 ++++++++++++++--------------------- 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9c7323fc73e..92e4f2fb797 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -671,28 +671,21 @@ def newplotfunc( darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb ) - # better to pass the ndarrays directly to plotting functions - xval = darray[xlab].values - yval = darray[ylab].values - - # check if we need to broadcast one dimension - if xval.ndim < yval.ndim: - dims = darray[ylab].dims - if xval.shape[0] == yval.shape[0]: - xval = np.broadcast_to(xval[:, np.newaxis], yval.shape) - else: - xval = np.broadcast_to(xval[np.newaxis, :], yval.shape) - - elif yval.ndim < xval.ndim: - dims = darray[xlab].dims - if yval.shape[0] == xval.shape[0]: - yval = np.broadcast_to(yval[:, np.newaxis], xval.shape) - else: - yval = np.broadcast_to(yval[np.newaxis, :], xval.shape) - elif xval.ndim == 2: - dims = darray[xlab].dims + xval = darray[xlab] + yval = darray[ylab] + + if xval.ndim > 1 or yval.ndim > 1: + # Passing 2d coordinate values, need to ensure they are transposed the same + # way as darray + xval = xval.broadcast_like(darray) + yval = yval.broadcast_like(darray) + dims = darray.dims else: - dims = (darray[ylab].dims[0], darray[xlab].dims[0]) + dims = (yval.dims[0], xval.dims[0]) + + # better to pass the ndarrays directly to plotting functions + xval = xval.values + yval = yval.values # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names From 17151d1cfb4681f01119f24a604c0de15d0bef6d Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 20:37:31 +0100 Subject: [PATCH 02/31] Update whats-new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 408b59d3c6a..0d7e69451c3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -78,6 +78,9 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is + 2d (:issue:`5079`, :pull:`5099`). + By `John Omotani `_ - Ensure standard calendar times encoded with large values (i.e. greater than approximately 292 years), can be decoded correctly without silently overflowing (:pull:`5050`). This was a regression in xarray 0.17.0. By `Zeb Nicholls `_. - Added support for `numpy.bool_` attributes in roundtrips using `h5netcdf` engine with `invalid_netcdf=True` [which casts `bool`s to `numpy.bool_`] (:issue:`4981`, :pull:`4986`). By `Victor Negîrneac `_. From 38220a6e73122edd8ed9d2ddfa81cec97256cc3a Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 18:47:11 +0100 Subject: [PATCH 03/31] Implement 'surface()' plot function Wraps mpl_toolkits.mplot3d.axes3d.plot_surface --- xarray/plot/__init__.py | 3 ++- xarray/plot/plot.py | 27 ++++++++++++++++++++++++--- xarray/plot/utils.py | 8 ++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 86a09506824..28ae0cf32e7 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,6 +1,6 @@ from .dataset_plot import scatter from .facetgrid import FacetGrid -from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step +from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface __all__ = [ "plot", @@ -13,4 +13,5 @@ "pcolormesh", "FacetGrid", "scatter", + "surface", ] diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 92e4f2fb797..f72ab5fb0e9 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -633,7 +633,9 @@ def newplotfunc( # Decide on a default for the colorbar before facetgrids if add_colorbar is None: - add_colorbar = plotfunc.__name__ != "contour" + add_colorbar = plotfunc.__name__ != "contour" and not ( + plotfunc.__name__ == "surface" and cmap is None + ) imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 3 + (row is not None) + (col is not None) ) @@ -674,9 +676,10 @@ def newplotfunc( xval = darray[xlab] yval = darray[ylab] - if xval.ndim > 1 or yval.ndim > 1: + if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": # Passing 2d coordinate values, need to ensure they are transposed the same - # way as darray + # way as darray. + # Also surface plots always need 2d coordinates xval = xval.broadcast_like(darray) yval = yval.broadcast_like(darray) dims = darray.dims @@ -736,6 +739,11 @@ def newplotfunc( if subplot_kws is None: subplot_kws = dict() + + if "surface" == plotfunc.__name__: + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) primitive = plotfunc( @@ -755,6 +763,8 @@ def newplotfunc( ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) ax.set_title(darray._title_for_slice()) + if plotfunc.__name__ == "surface": + ax.set_zlabel(label_from_attrs(darray)) if add_colorbar: if add_labels and "label" not in cbar_kwargs: @@ -987,3 +997,14 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): ax.set_ylim(y[0], y[-1]) return primitive + + +@_plot2d +def surface(x, y, z, ax, **kwargs): + """ + Surface plot of 2d DataArray + + Wraps :func:`matplotlib:mpl_toolkits.mplot3d.axes3d.plot_surface` + """ + primitive = ax.plot_surface(x, y, z, **kwargs) + return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a83bc28e273..325ea799f28 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -804,6 +804,14 @@ def _process_cmap_cbar_kwargs( cmap_params cbar_kwargs """ + if func.__name__ == "surface": + # Leave user to specify cmap settings for surface plots + kwargs["cmap"] = cmap + return { + k: kwargs.get(k, None) + for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"] + }, {} + cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs) if "contour" in func.__name__ and levels is None: From 0ce6941ea29c6c2fc9db16e73a2834fe5017d0f6 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 22:21:23 +0100 Subject: [PATCH 04/31] Make surface plots work with facet grids --- xarray/plot/facetgrid.py | 4 +++- xarray/plot/plot.py | 18 +++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 2d3c0595026..781cdcd6ea7 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -263,7 +263,9 @@ def map_dataarray(self, func, x, y, **kwargs): if k not in {"cmap", "colors", "cbar_kwargs", "levels"} } func_kwargs.update(cmap_params) - func_kwargs.update({"add_colorbar": False, "add_labels": False}) + func_kwargs["add_colorbar"] = False + if func.__name__ != "surface": + func_kwargs["add_labels"] = False # Get x, y labels for the first subplot x, y = _infer_xy_labels( diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f72ab5fb0e9..457039aee31 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -648,6 +648,17 @@ def newplotfunc( darray = _rescale_imshow_rgb(darray, vmin, vmax, robust) vmin, vmax, robust = None, None, False + if subplot_kws is None: + subplot_kws = dict() + + if "surface" == plotfunc.__name__ and not kwargs.get("_is_facetgrid", False): + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + + # In facet grids, shared axis labels don't make sense for surface plots + sharex = False + sharey = False + # Handle facetgrids first if row or col: allargs = locals().copy() @@ -737,13 +748,6 @@ def newplotfunc( # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") - if subplot_kws is None: - subplot_kws = dict() - - if "surface" == plotfunc.__name__: - # Need to create a "3d" Axes instance for surface plots - subplot_kws["projection"] = "3d" - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) primitive = plotfunc( From c7dbdf181baec474fe45d1f3b0192a08e371e928 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 22:50:42 +0100 Subject: [PATCH 05/31] Unit tests for surface plot --- xarray/tests/test_plot.py | 47 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c0d3712dfa0..0f052bb64d7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -36,6 +36,7 @@ try: import matplotlib as mpl import matplotlib.pyplot as plt + import mpl_toolkits except ImportError: pass @@ -132,8 +133,8 @@ def setup(self): # Remove all matplotlib figures plt.close("all") - def pass_in_axis(self, plotmethod): - fig, axes = plt.subplots(ncols=2) + def pass_in_axis(self, plotmethod, subplot_kw={}): + fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw) plotmethod(ax=axes[0]) assert axes[0].has_data() @@ -1794,6 +1795,43 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 +class TestSurface(Common2dMixin, PlotTestCase): + + plotfunc = staticmethod(xplt.surface) + + def test_primitive_artist_returned(self): + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + + def test_everything_plotted(self): + artist = self.plotmethod() + assert artist.get_array().size == self.darray.size + + @pytest.mark.slow + def test_2d_coord_names(self): + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert self.name == ax.get_zlabel() + + def test_xyincrease_false_changes_axes(self): + # Does not make sense for surface plots + pass + + def test_xyincrease_true_changes_axes(self): + # Does not make sense for surface plots + pass + + def test_can_pass_in_axis(self): + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + + def test_default_cmap(self): + # Does not make sense for surface plots with default arguments + pass + + class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): @@ -2485,6 +2523,9 @@ def test_cfdatetime_pcolormesh_plot(self): def test_cfdatetime_contour_plot(self): self.darray.plot.contour() + def test_cfdatetime_surface_plot(self): + self.darray.plot.surface() + @requires_cftime @pytest.mark.skipif(has_nc_time_axis, reason="nc_time_axis is installed") @@ -2578,7 +2619,7 @@ def test_yticks_kwarg(self, da): @requires_matplotlib -@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour", "surface"]) def test_plot_transposed_nondim_coord(plotfunc): x = np.linspace(0, 10, 101) h = np.linspace(3, 7, 101) From bc0c85a87701a1a6880e359e1c87dfba97406c53 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 22:51:48 +0100 Subject: [PATCH 06/31] Minor fixes for surface plots --- xarray/plot/plot.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 457039aee31..695a39cbab6 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -652,8 +652,9 @@ def newplotfunc( subplot_kws = dict() if "surface" == plotfunc.__name__ and not kwargs.get("_is_facetgrid", False): - # Need to create a "3d" Axes instance for surface plots - subplot_kws["projection"] = "3d" + if ax is None: + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" # In facet grids, shared axis labels don't make sense for surface plots sharex = False @@ -671,6 +672,19 @@ def newplotfunc( plt = import_matplotlib_pyplot() + if ( + "surface" == plotfunc.__name__ + and not kwargs.get("_is_facetgrid", False) + and ax is not None + ): + import mpl_toolkits + + if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): + raise ValueError( + "If ax is passed to surface(), it must be created with " + 'projection="3d"' + ) + rgb = kwargs.pop("rgb", None) if rgb is not None and plotfunc.__name__ != "imshow": raise ValueError('The "rgb" keyword is only valid for imshow()') From d31e1938dc3a2f7a97703856220564f80da0a84a Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 31 Mar 2021 23:39:40 +0100 Subject: [PATCH 07/31] Add surface plots to api.rst and api-hidden.rst --- doc/api-hidden.rst | 1 + doc/api.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index f5e9348d4eb..7209e9b0fe7 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -597,6 +597,7 @@ plot.imshow plot.pcolormesh plot.scatter + plot.surface plot.FacetGrid.map_dataarray plot.FacetGrid.set_titles diff --git a/doc/api.rst b/doc/api.rst index a140d9e2b81..271edd9db11 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -587,6 +587,7 @@ Plotting DataArray.plot.line DataArray.plot.pcolormesh DataArray.plot.step + DataArray.plot.surface .. _api.ufuncs: From 7acce7efe064dbfe203bc02d93d0b28ec5e65b21 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 00:03:02 +0100 Subject: [PATCH 08/31] Update whats-new --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0d7e69451c3..9c05fab3aae 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v0.17.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make + surface plots (:issue:`#2235` :issue:`#5084` :pull:`5101`). + By `John Omotani `_. - Add :py:meth:`Dataset.query` and :py:meth:`DataArray.query` which enable indexing of datasets and data arrays by evaluating query expressions against the values of the data variables (:pull:`4984`). By `Alistair Miles `_. From 1e4ff18a75fcdd7d520e05d0d9ef57c61fcf253e Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 10:40:38 +0100 Subject: [PATCH 09/31] Fix tests --- xarray/tests/test_plot.py | 46 ++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 0f052bb64d7..58c56f17615 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1080,6 +1080,9 @@ class Common2dMixin: Should have the same name as the method. """ + # Needs to be overridden in TestSurface for facet grid plots + subplot_kws = {} + @pytest.fixture(autouse=True) def setUp(self): da = DataArray( @@ -1395,7 +1398,7 @@ def test_colorbar_kwargs(self): def test_verbose_facetgrid(self): a = easy_array((10, 15, 3)) d = DataArray(a, dims=["y", "x", "z"]) - g = xplt.FacetGrid(d, col="z") + g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws) g.map_dataarray(self.plotfunc, "x", "y") for ax in g.axes.flat: assert ax.has_data() @@ -1798,15 +1801,12 @@ def test_origin_overrides_xyincrease(self): class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} def test_primitive_artist_returned(self): artist = self.plotmethod() assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) - def test_everything_plotted(self): - artist = self.plotmethod() - assert artist.get_array().size == self.darray.size - @pytest.mark.slow def test_2d_coord_names(self): self.plotmethod(x="x2d", y="y2d") @@ -1814,7 +1814,7 @@ def test_2d_coord_names(self): ax = plt.gca() assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() - assert self.name == ax.get_zlabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() def test_xyincrease_false_changes_axes(self): # Does not make sense for surface plots @@ -1831,6 +1831,40 @@ def test_default_cmap(self): # Does not make sense for surface plots with default arguments pass + def test_diverging_color_limits(self): + # Does not make sense for surface plots with default arguments + pass + + def test_colorbar_kwargs(self): + # Does not make sense for surface plots with default arguments + pass + + def test_cmap_and_color_both(self): + # Does not make sense for surface plots with default arguments + pass + + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self): + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + + # Infering labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() + class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) From e12b7ce64811c936aa8115f8625b7438eb17d4f6 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 11:08:35 +0100 Subject: [PATCH 10/31] mypy fix --- xarray/tests/test_plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 58c56f17615..e497231136a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,6 +2,7 @@ import inspect from copy import copy from datetime import datetime +from typing import Any, Dict import numpy as np import pandas as pd @@ -1081,7 +1082,7 @@ class Common2dMixin: """ # Needs to be overridden in TestSurface for facet grid plots - subplot_kws = {} + subplot_kws: Dict[Any, Any] = {} @pytest.fixture(autouse=True) def setUp(self): From 266bd4abbe1b69cad03ada68fa8b1005f6f1dc72 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 14:08:29 +0100 Subject: [PATCH 11/31] seaborn doesn't work with matplotlib 3d toolkit --- xarray/tests/test_plot.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index e497231136a..52128be6c5a 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1844,6 +1844,11 @@ def test_cmap_and_color_both(self): # Does not make sense for surface plots with default arguments pass + def test_seaborn_palette_as_cmap(self): + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() + # Need to modify this test for surface(), because all subplots should have labels, # not just left and bottom @pytest.mark.filterwarnings("ignore:tight_layout cannot") From e3de64f05e275e083d0e54698abb993956ebf3a1 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 14:33:15 +0100 Subject: [PATCH 12/31] Remove cfdatetime surface plot test Does not work because the datetime.timedelta does not work with surface's 'shading'. --- xarray/tests/test_plot.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 52128be6c5a..05ec40e9467 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2563,9 +2563,6 @@ def test_cfdatetime_pcolormesh_plot(self): def test_cfdatetime_contour_plot(self): self.darray.plot.contour() - def test_cfdatetime_surface_plot(self): - self.darray.plot.surface() - @requires_cftime @pytest.mark.skipif(has_nc_time_axis, reason="nc_time_axis is installed") From 82c708ec7a77240c5c906e797cecb0fed414b876 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 14:47:59 +0100 Subject: [PATCH 13/31] Ignore type checks for mpl_toolkits module --- xarray/plot/plot.py | 2 +- xarray/tests/test_plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 695a39cbab6..7ddcdf1e220 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -677,7 +677,7 @@ def newplotfunc( and not kwargs.get("_is_facetgrid", False) and ax is not None ): - import mpl_toolkits + import mpl_toolkits # type: ignore if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): raise ValueError( diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 05ec40e9467..3593d02beef 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -37,7 +37,7 @@ try: import matplotlib as mpl import matplotlib.pyplot as plt - import mpl_toolkits + import mpl_toolkits # type: ignore except ImportError: pass From f27aa45bf04278b08f30ecd04666ea56c2ba38fc Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 15:26:15 +0100 Subject: [PATCH 14/31] Check matplotlib version is new enough for surface plots --- xarray/plot/plot.py | 10 ++++++++++ xarray/tests/test_plot.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 7ddcdf1e220..c8699323597 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -652,6 +652,16 @@ def newplotfunc( subplot_kws = dict() if "surface" == plotfunc.__name__ and not kwargs.get("_is_facetgrid", False): + # Check we have new enough version of matplotlib + from distutils.version import LooseVersion + + import matplotlib as mpl + + if LooseVersion(mpl.__version__) < "3.2.0": + raise ValueError("surface plot requires at least matplotlib-3.2.0") + del LooseVersion + del mpl + if ax is None: # Need to create a "3d" Axes instance for surface plots subplot_kws["projection"] = "3d" diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 3593d02beef..5824a96409f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,6 +2,7 @@ import inspect from copy import copy from datetime import datetime +from distutils.version import LooseVersion from typing import Any, Dict import numpy as np @@ -1799,6 +1800,10 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 +@pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.2.0", + reason="surface plot requires newer matplotlib", +) class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) @@ -1871,6 +1876,34 @@ def test_convenient_facetgrid(self): assert "y" == ax.get_ylabel() assert "x" == ax.get_xlabel() + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_viridis_cmap(self): + return super().test_viridis_cmap() + + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_can_change_default_cmap(self): + return super().test_can_change_default_cmap() + + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_colorbar_default_label(self): + return super().test_colorbar_default_label() + + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_facetgrid_map_only_appends_mappables(self): + return super().test_facetgrid_map_only_appends_mappables() + class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) @@ -2657,6 +2690,10 @@ def test_yticks_kwarg(self, da): @requires_matplotlib @pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour", "surface"]) +@pytest.mark.skipif( + "plotfunc" == "surface" and LooseVersion(mpl.__version__) <= "3.2.0", + reason="surface plot requires newer matplotlib", +) def test_plot_transposed_nondim_coord(plotfunc): x = np.linspace(0, 10, 101) h = np.linspace(3, 7, 101) From b0a1f406e7bc5df90f92a842fede17d389dcd020 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 15:38:00 +0100 Subject: [PATCH 15/31] version check requires matplotlib --- xarray/tests/test_plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 5824a96409f..c2592471b43 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1800,6 +1800,7 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 +@requires_matplotlib @pytest.mark.skipif( LooseVersion(mpl.__version__) < "3.2.0", reason="surface plot requires newer matplotlib", From e592e5e0528c596a62bf6a5b20c464c8e568cb49 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 15:45:56 +0100 Subject: [PATCH 16/31] Handle matplotlib not installed for TestSurface version check --- xarray/tests/test_plot.py | 183 ++++++++++++++++++++------------------ 1 file changed, 95 insertions(+), 88 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index c2592471b43..4f3fb02d23c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1800,110 +1800,117 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 -@requires_matplotlib -@pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.2.0", - reason="surface plot requires newer matplotlib", -) -class TestSurface(Common2dMixin, PlotTestCase): +# The try/except/else is needed for the matplotlib version check, to handle the case +# when matplotlib is not installed. It should be possible to remove it once we require +# matplotlib>=3.2.0 +try: + import matplotlib as mpl +except ImportError: + pass +else: + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.2.0", + reason="surface plot requires newer matplotlib", + ) + class TestSurface(Common2dMixin, PlotTestCase): - plotfunc = staticmethod(xplt.surface) - subplot_kws = {"projection": "3d"} + plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} - def test_primitive_artist_returned(self): - artist = self.plotmethod() - assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + def test_primitive_artist_returned(self): + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) - @pytest.mark.slow - def test_2d_coord_names(self): - self.plotmethod(x="x2d", y="y2d") - # make sure labels came out ok - ax = plt.gca() - assert "x2d" == ax.get_xlabel() - assert "y2d" == ax.get_ylabel() - assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() - - def test_xyincrease_false_changes_axes(self): - # Does not make sense for surface plots - pass - - def test_xyincrease_true_changes_axes(self): - # Does not make sense for surface plots - pass + @pytest.mark.slow + def test_2d_coord_names(self): + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() - def test_can_pass_in_axis(self): - self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + def test_xyincrease_false_changes_axes(self): + # Does not make sense for surface plots + pass - def test_default_cmap(self): - # Does not make sense for surface plots with default arguments - pass + def test_xyincrease_true_changes_axes(self): + # Does not make sense for surface plots + pass - def test_diverging_color_limits(self): - # Does not make sense for surface plots with default arguments - pass + def test_can_pass_in_axis(self): + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) - def test_colorbar_kwargs(self): - # Does not make sense for surface plots with default arguments - pass + def test_default_cmap(self): + # Does not make sense for surface plots with default arguments + pass - def test_cmap_and_color_both(self): - # Does not make sense for surface plots with default arguments - pass + def test_diverging_color_limits(self): + # Does not make sense for surface plots with default arguments + pass - def test_seaborn_palette_as_cmap(self): - # seaborn does not work with mpl_toolkits.mplot3d - with pytest.raises(ValueError): - super().test_seaborn_palette_as_cmap() + def test_colorbar_kwargs(self): + # Does not make sense for surface plots with default arguments + pass - # Need to modify this test for surface(), because all subplots should have labels, - # not just left and bottom - @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid(self): - a = easy_array((10, 15, 4)) - d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + def test_cmap_and_color_both(self): + # Does not make sense for surface plots with default arguments + pass - assert_array_equal(g.axes.shape, [2, 2]) - for (y, x), ax in np.ndenumerate(g.axes): - assert ax.has_data() - assert "y" == ax.get_ylabel() - assert "x" == ax.get_xlabel() + def test_seaborn_palette_as_cmap(self): + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() + + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self): + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) + + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() - # Infering labels - g = self.plotfunc(d, col="z", col_wrap=2) - assert_array_equal(g.axes.shape, [2, 2]) - for (y, x), ax in np.ndenumerate(g.axes): - assert ax.has_data() - assert "y" == ax.get_ylabel() - assert "x" == ax.get_xlabel() + # Infering labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) - def test_viridis_cmap(self): - return super().test_viridis_cmap() + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_viridis_cmap(self): + return super().test_viridis_cmap() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) - def test_can_change_default_cmap(self): - return super().test_can_change_default_cmap() + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_can_change_default_cmap(self): + return super().test_can_change_default_cmap() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) - def test_colorbar_default_label(self): - return super().test_colorbar_default_label() + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_colorbar_default_label(self): + return super().test_colorbar_default_label() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) - def test_facetgrid_map_only_appends_mappables(self): - return super().test_facetgrid_map_only_appends_mappables() + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.3.0", + reason="this feature of surface plot requires newer matplotlib", + ) + def test_facetgrid_map_only_appends_mappables(self): + return super().test_facetgrid_map_only_appends_mappables() class TestFacetGrid(PlotTestCase): From 43a51e9bd5d5dbee2069cfac1c0d6150f2c35dea Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 15:51:59 +0100 Subject: [PATCH 17/31] fix flake8 error --- xarray/tests/test_plot.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4f3fb02d23c..0b56797668f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1802,14 +1802,16 @@ def test_origin_overrides_xyincrease(self): # The try/except/else is needed for the matplotlib version check, to handle the case # when matplotlib is not installed. It should be possible to remove it once we require -# matplotlib>=3.2.0 +# matplotlib>=3.2.0. +# Note, importing as mpl2 to avoid redefining mpl, which is a flake8 error. try: - import matplotlib as mpl + import matplotlib as mpl2 except ImportError: pass else: + @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.2.0", + LooseVersion(mpl2.__version__) < "3.2.0", reason="surface plot requires newer matplotlib", ) class TestSurface(Common2dMixin, PlotTestCase): From ea431773697f93dddb59b47be06859773fd91cea Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 1 Apr 2021 16:12:16 +0100 Subject: [PATCH 18/31] Don't run test_plot_transposed_nondim_coord for surface plots Too complicated to check matplotlib version is high enough just for surface plots. --- xarray/tests/test_plot.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 0b56797668f..96fc7a279dd 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2699,11 +2699,7 @@ def test_yticks_kwarg(self, da): @requires_matplotlib -@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour", "surface"]) -@pytest.mark.skipif( - "plotfunc" == "surface" and LooseVersion(mpl.__version__) <= "3.2.0", - reason="surface plot requires newer matplotlib", -) +@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) def test_plot_transposed_nondim_coord(plotfunc): x = np.linspace(0, 10, 101) h = np.linspace(3, 7, 101) From 648e13b0852fc89eb4d312a2b15d303f3f7a9ba7 Mon Sep 17 00:00:00 2001 From: johnomotani Date: Tue, 20 Apr 2021 20:40:06 +0100 Subject: [PATCH 19/31] Apply suggestions from code review Co-authored-by: Mathias Hauser --- xarray/plot/plot.py | 6 ++---- xarray/tests/test_plot.py | 3 ++- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index c8699323597..e75cefc2450 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -651,7 +651,7 @@ def newplotfunc( if subplot_kws is None: subplot_kws = dict() - if "surface" == plotfunc.__name__ and not kwargs.get("_is_facetgrid", False): + if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): # Check we have new enough version of matplotlib from distutils.version import LooseVersion @@ -659,8 +659,6 @@ def newplotfunc( if LooseVersion(mpl.__version__) < "3.2.0": raise ValueError("surface plot requires at least matplotlib-3.2.0") - del LooseVersion - del mpl if ax is None: # Need to create a "3d" Axes instance for surface plots @@ -683,7 +681,7 @@ def newplotfunc( plt = import_matplotlib_pyplot() if ( - "surface" == plotfunc.__name__ + plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False) and ax is not None ): diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 96fc7a279dd..b86e89fc94d 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -135,7 +135,8 @@ def setup(self): # Remove all matplotlib figures plt.close("all") - def pass_in_axis(self, plotmethod, subplot_kw={}): + def pass_in_axis(self, plotmethod, subplot_kw=None): + subplot_kw = {} if subplot_kw is None else subplot_kw fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw) plotmethod(ax=axes[0]) assert axes[0].has_data() From 313daf09350bd9b469063d9e53f0a1c8eac16c3f Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 20 Apr 2021 21:02:00 +0100 Subject: [PATCH 20/31] More suggestions from code review --- xarray/plot/plot.py | 4 +++- xarray/tests/__init__.py | 4 ++++ xarray/tests/test_plot.py | 40 +++++++++++++-------------------------- 3 files changed, 20 insertions(+), 28 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index e75cefc2450..78f509320ce 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -8,6 +8,7 @@ """ import functools +from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -653,13 +654,14 @@ def newplotfunc( if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): # Check we have new enough version of matplotlib - from distutils.version import LooseVersion import matplotlib as mpl if LooseVersion(mpl.__version__) < "3.2.0": raise ValueError("surface plot requires at least matplotlib-3.2.0") + del mpl + if ax is None: # Need to create a "3d" Axes instance for surface plots subplot_kws["projection"] = "3d" diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index aebcb0f2b8d..04062f21662 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -60,6 +60,10 @@ def LooseVersion(vstring): has_matplotlib, requires_matplotlib = _importorskip("matplotlib") +has_matplotlib_3_2_0, requires_matplotlib_3_2_0 = _importorskip("matplotlib", + minversion="3.2.0") +has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip("matplotlib", + minversion="3.3.0") has_scipy, requires_scipy = _importorskip("scipy") has_pydap, requires_pydap = _importorskip("pydap.client") has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b86e89fc94d..9c1154b0e5b 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,7 +2,6 @@ import inspect from copy import copy from datetime import datetime -from distutils.version import LooseVersion from typing import Any, Dict import numpy as np @@ -30,6 +29,8 @@ requires_cartopy, requires_cftime, requires_matplotlib, + requires_matplotlib_3_2_0, + requires_matplotlib_3_3_0, requires_nc_time_axis, requires_seaborn, ) @@ -1811,10 +1812,7 @@ def test_origin_overrides_xyincrease(self): pass else: - @pytest.mark.skipif( - LooseVersion(mpl2.__version__) < "3.2.0", - reason="surface plot requires newer matplotlib", - ) + @requires_matplotlib_3_2_0 class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) @@ -1835,30 +1833,30 @@ def test_2d_coord_names(self): def test_xyincrease_false_changes_axes(self): # Does not make sense for surface plots - pass + pytest.skip("does not make sense for surface plots") def test_xyincrease_true_changes_axes(self): # Does not make sense for surface plots - pass + pytest.skip("does not make sense for surface plots") def test_can_pass_in_axis(self): self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) def test_default_cmap(self): # Does not make sense for surface plots with default arguments - pass + pytest.skip("does not make sense for surface plots") def test_diverging_color_limits(self): # Does not make sense for surface plots with default arguments - pass + pytest.skip("does not make sense for surface plots") def test_colorbar_kwargs(self): # Does not make sense for surface plots with default arguments - pass + pytest.skip("does not make sense for surface plots") def test_cmap_and_color_both(self): # Does not make sense for surface plots with default arguments - pass + pytest.skip("does not make sense for surface plots") def test_seaborn_palette_as_cmap(self): # seaborn does not work with mpl_toolkits.mplot3d @@ -1887,31 +1885,19 @@ def test_convenient_facetgrid(self): assert "y" == ax.get_ylabel() assert "x" == ax.get_xlabel() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) + @requires_matplotlib_3_3_0 def test_viridis_cmap(self): return super().test_viridis_cmap() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) + @requires_matplotlib_3_3_0 def test_can_change_default_cmap(self): return super().test_can_change_default_cmap() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) + @requires_matplotlib_3_3_0 def test_colorbar_default_label(self): return super().test_colorbar_default_label() - @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.3.0", - reason="this feature of surface plot requires newer matplotlib", - ) + @requires_matplotlib_3_3_0 def test_facetgrid_map_only_appends_mappables(self): return super().test_facetgrid_map_only_appends_mappables() From a566744cf071384b2786870390bc21967d714ec1 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 20 Apr 2021 21:02:37 +0100 Subject: [PATCH 21/31] black --- xarray/tests/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 04062f21662..94c4604834a 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -60,10 +60,12 @@ def LooseVersion(vstring): has_matplotlib, requires_matplotlib = _importorskip("matplotlib") -has_matplotlib_3_2_0, requires_matplotlib_3_2_0 = _importorskip("matplotlib", - minversion="3.2.0") -has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip("matplotlib", - minversion="3.3.0") +has_matplotlib_3_2_0, requires_matplotlib_3_2_0 = _importorskip( + "matplotlib", minversion="3.2.0" +) +has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip( + "matplotlib", minversion="3.3.0" +) has_scipy, requires_scipy = _importorskip("scipy") has_pydap, requires_pydap = _importorskip("pydap.client") has_netCDF4, requires_netCDF4 = _importorskip("netCDF4") From 817d30547ba5c7ce880fb9f6857e57a2b98f5ccd Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 20 Apr 2021 21:06:47 +0100 Subject: [PATCH 22/31] isort and flake8 --- xarray/plot/plot.py | 2 +- xarray/tests/test_plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 78f509320ce..66ffa28453e 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -7,8 +7,8 @@ Dataset.plot._____ """ import functools - from distutils.version import LooseVersion + import numpy as np import pandas as pd diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 9c1154b0e5b..4f1d546c7c7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1807,7 +1807,7 @@ def test_origin_overrides_xyincrease(self): # matplotlib>=3.2.0. # Note, importing as mpl2 to avoid redefining mpl, which is a flake8 error. try: - import matplotlib as mpl2 + import matplotlib as mpl2 # noqa: F401 except ImportError: pass else: From 99459cc8c15bde8495756b94f59e0ba744181788 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Tue, 20 Apr 2021 23:46:47 +0100 Subject: [PATCH 23/31] Make surface plots more backward compatible Following suggestion from Illviljan --- xarray/plot/plot.py | 14 +++++--------- xarray/tests/__init__.py | 3 --- xarray/tests/test_plot.py | 2 -- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 66ffa28453e..e132ee0f357 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -7,7 +7,6 @@ Dataset.plot._____ """ import functools -from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -653,16 +652,13 @@ def newplotfunc( subplot_kws = dict() if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): - # Check we have new enough version of matplotlib - - import matplotlib as mpl - - if LooseVersion(mpl.__version__) < "3.2.0": - raise ValueError("surface plot requires at least matplotlib-3.2.0") + if ax is None: + # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 - del mpl + del Axes3D - if ax is None: # Need to create a "3d" Axes instance for surface plots subplot_kws["projection"] = "3d" diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 94c4604834a..5061c7dc216 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -60,9 +60,6 @@ def LooseVersion(vstring): has_matplotlib, requires_matplotlib = _importorskip("matplotlib") -has_matplotlib_3_2_0, requires_matplotlib_3_2_0 = _importorskip( - "matplotlib", minversion="3.2.0" -) has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip( "matplotlib", minversion="3.3.0" ) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 4f1d546c7c7..0f8046f69a7 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -29,7 +29,6 @@ requires_cartopy, requires_cftime, requires_matplotlib, - requires_matplotlib_3_2_0, requires_matplotlib_3_3_0, requires_nc_time_axis, requires_seaborn, @@ -1812,7 +1811,6 @@ def test_origin_overrides_xyincrease(self): pass else: - @requires_matplotlib_3_2_0 class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) From f86f76def777de718a9456f07cf8b850236f5b2d Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 21 Apr 2021 09:38:20 +0100 Subject: [PATCH 24/31] Clean up matplotlib requirement --- xarray/tests/test_plot.py | 150 ++++++++++++++++++-------------------- 1 file changed, 70 insertions(+), 80 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 0f8046f69a7..f9ed07bb048 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1801,103 +1801,93 @@ def test_origin_overrides_xyincrease(self): assert plt.ylim()[0] < 0 -# The try/except/else is needed for the matplotlib version check, to handle the case -# when matplotlib is not installed. It should be possible to remove it once we require -# matplotlib>=3.2.0. -# Note, importing as mpl2 to avoid redefining mpl, which is a flake8 error. -try: - import matplotlib as mpl2 # noqa: F401 -except ImportError: - pass -else: +class TestSurface(Common2dMixin, PlotTestCase): - class TestSurface(Common2dMixin, PlotTestCase): + plotfunc = staticmethod(xplt.surface) + subplot_kws = {"projection": "3d"} - plotfunc = staticmethod(xplt.surface) - subplot_kws = {"projection": "3d"} + def test_primitive_artist_returned(self): + artist = self.plotmethod() + assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) - def test_primitive_artist_returned(self): - artist = self.plotmethod() - assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) + @pytest.mark.slow + def test_2d_coord_names(self): + self.plotmethod(x="x2d", y="y2d") + # make sure labels came out ok + ax = plt.gca() + assert "x2d" == ax.get_xlabel() + assert "y2d" == ax.get_ylabel() + assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() - @pytest.mark.slow - def test_2d_coord_names(self): - self.plotmethod(x="x2d", y="y2d") - # make sure labels came out ok - ax = plt.gca() - assert "x2d" == ax.get_xlabel() - assert "y2d" == ax.get_ylabel() - assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() + def test_xyincrease_false_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") - def test_xyincrease_false_changes_axes(self): - # Does not make sense for surface plots - pytest.skip("does not make sense for surface plots") + def test_xyincrease_true_changes_axes(self): + # Does not make sense for surface plots + pytest.skip("does not make sense for surface plots") - def test_xyincrease_true_changes_axes(self): - # Does not make sense for surface plots - pytest.skip("does not make sense for surface plots") + def test_can_pass_in_axis(self): + self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) - def test_can_pass_in_axis(self): - self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) + def test_default_cmap(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") - def test_default_cmap(self): - # Does not make sense for surface plots with default arguments - pytest.skip("does not make sense for surface plots") + def test_diverging_color_limits(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") - def test_diverging_color_limits(self): - # Does not make sense for surface plots with default arguments - pytest.skip("does not make sense for surface plots") + def test_colorbar_kwargs(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") - def test_colorbar_kwargs(self): - # Does not make sense for surface plots with default arguments - pytest.skip("does not make sense for surface plots") + def test_cmap_and_color_both(self): + # Does not make sense for surface plots with default arguments + pytest.skip("does not make sense for surface plots") + + def test_seaborn_palette_as_cmap(self): + # seaborn does not work with mpl_toolkits.mplot3d + with pytest.raises(ValueError): + super().test_seaborn_palette_as_cmap() - def test_cmap_and_color_both(self): - # Does not make sense for surface plots with default arguments - pytest.skip("does not make sense for surface plots") + # Need to modify this test for surface(), because all subplots should have labels, + # not just left and bottom + @pytest.mark.filterwarnings("ignore:tight_layout cannot") + def test_convenient_facetgrid(self): + a = easy_array((10, 15, 4)) + d = DataArray(a, dims=["y", "x", "z"]) + g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) - def test_seaborn_palette_as_cmap(self): - # seaborn does not work with mpl_toolkits.mplot3d - with pytest.raises(ValueError): - super().test_seaborn_palette_as_cmap() - - # Need to modify this test for surface(), because all subplots should have labels, - # not just left and bottom - @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid(self): - a = easy_array((10, 15, 4)) - d = DataArray(a, dims=["y", "x", "z"]) - g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) - - assert_array_equal(g.axes.shape, [2, 2]) - for (y, x), ax in np.ndenumerate(g.axes): - assert ax.has_data() - assert "y" == ax.get_ylabel() - assert "x" == ax.get_xlabel() + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() - # Infering labels - g = self.plotfunc(d, col="z", col_wrap=2) - assert_array_equal(g.axes.shape, [2, 2]) - for (y, x), ax in np.ndenumerate(g.axes): - assert ax.has_data() - assert "y" == ax.get_ylabel() - assert "x" == ax.get_xlabel() + # Infering labels + g = self.plotfunc(d, col="z", col_wrap=2) + assert_array_equal(g.axes.shape, [2, 2]) + for (y, x), ax in np.ndenumerate(g.axes): + assert ax.has_data() + assert "y" == ax.get_ylabel() + assert "x" == ax.get_xlabel() - @requires_matplotlib_3_3_0 - def test_viridis_cmap(self): - return super().test_viridis_cmap() + @requires_matplotlib_3_3_0 + def test_viridis_cmap(self): + return super().test_viridis_cmap() - @requires_matplotlib_3_3_0 - def test_can_change_default_cmap(self): - return super().test_can_change_default_cmap() + @requires_matplotlib_3_3_0 + def test_can_change_default_cmap(self): + return super().test_can_change_default_cmap() - @requires_matplotlib_3_3_0 - def test_colorbar_default_label(self): - return super().test_colorbar_default_label() + @requires_matplotlib_3_3_0 + def test_colorbar_default_label(self): + return super().test_colorbar_default_label() - @requires_matplotlib_3_3_0 - def test_facetgrid_map_only_appends_mappables(self): - return super().test_facetgrid_map_only_appends_mappables() + @requires_matplotlib_3_3_0 + def test_facetgrid_map_only_appends_mappables(self): + return super().test_facetgrid_map_only_appends_mappables() class TestFacetGrid(PlotTestCase): From 7b6f470899403f977975bd047d8a3423592f18fe Mon Sep 17 00:00:00 2001 From: johnomotani Date: Wed, 28 Apr 2021 00:47:12 +0100 Subject: [PATCH 25/31] Update xarray/plot/plot.py Co-authored-by: Mathias Hauser --- xarray/plot/plot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index e132ee0f357..12660d31220 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -657,6 +657,7 @@ def newplotfunc( # Remove when minimum requirement of matplotlib is 3.2: from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 + # delete so it does not end up in locals() del Axes3D # Need to create a "3d" Axes instance for surface plots From 518110c30a26809d2f351243d39dacaec377fae7 Mon Sep 17 00:00:00 2001 From: johnomotani Date: Wed, 28 Apr 2021 17:10:23 +0100 Subject: [PATCH 26/31] Apply suggestions from code review Co-authored-by: Mathias Hauser --- doc/whats-new.rst | 3 --- xarray/plot/plot.py | 8 +++++--- xarray/tests/test_plot.py | 1 - 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cb3c8f88250..e25a7f6d99d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -113,9 +113,6 @@ Deprecations Bug fixes ~~~~~~~~~ -- Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is - 2d (:issue:`5079`, :pull:`5099`). - By `John Omotani `_ - Properly support :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill`, :py:meth:`Dataset.bfill` along chunked dimensions. (:issue:`2699`).By `Deepak Cherian `_. - Fix 2d plot failure for certain combinations of dimensions when `x` is 1d and `y` is diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 5a11327eceb..e14f1f24583 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -633,9 +633,11 @@ def newplotfunc( # Decide on a default for the colorbar before facetgrids if add_colorbar is None: - add_colorbar = plotfunc.__name__ != "contour" and not ( + add_colorbar = True + if plotfunc.__name__ == "contour" or ( plotfunc.__name__ == "surface" and cmap is None - ) + ): + add_colorbar = False imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 3 + (row is not None) + (col is not None) ) @@ -653,7 +655,7 @@ def newplotfunc( if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): if ax is None: - # TODO: Importing Axes3D is not necessary in matplotlib >= 3.2. + # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. # Remove when minimum requirement of matplotlib is 3.2: from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 7815d2f8342..b4eaff40e05 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -135,7 +135,6 @@ def setup(self): plt.close("all") def pass_in_axis(self, plotmethod, subplot_kw=None): - subplot_kw = {} if subplot_kw is None else subplot_kw fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw) plotmethod(ax=axes[0]) assert axes[0].has_data() From 84b3e6dd1cfae4cc9f07490573abfc6cf90418ce Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 28 Apr 2021 17:20:06 +0100 Subject: [PATCH 27/31] Use None as default value --- xarray/tests/test_plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index b4eaff40e05..e71bcaa359c 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2,7 +2,7 @@ import inspect from copy import copy from datetime import datetime -from typing import Any, Dict +from typing import Any, Dict, Union import numpy as np import pandas as pd @@ -1110,7 +1110,7 @@ class Common2dMixin: """ # Needs to be overridden in TestSurface for facet grid plots - subplot_kws: Dict[Any, Any] = {} + subplot_kws: Union[Dict[Any, Any], None] = None @pytest.fixture(autouse=True) def setUp(self): From 08b9117e24a330cdea30c2988f140ec40214243b Mon Sep 17 00:00:00 2001 From: John Omotani Date: Wed, 28 Apr 2021 17:25:12 +0100 Subject: [PATCH 28/31] black --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index e14f1f24583..e6eb7ecbe0b 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -637,7 +637,7 @@ def newplotfunc( if plotfunc.__name__ == "contour" or ( plotfunc.__name__ == "surface" and cmap is None ): - add_colorbar = False + add_colorbar = False imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( 3 + (row is not None) + (col is not None) ) From c96484830ac2c5058c6b028ed269fda35cee223a Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 29 Apr 2021 13:55:43 +0100 Subject: [PATCH 29/31] More 2D plotting method examples in docs --- doc/user-guide/plotting.rst | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 098c63d0e40..3abcce85651 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -411,6 +411,36 @@ produce plots with nonuniform coordinates. @savefig plotting_nonuniform_coords.png width=4in b.plot() +==================== + Other types of plot +==================== + +There are several other options for plotting 2D data. + +Contour plot using `DataArray.plot.contour()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contour.png width=4in + air2d.plot.contour() + +Filled contour plot using `DataArray.plot.contourf()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contourf.png width=4in + air2d.plot.contourf() + +Surface plot using `DataArray.plot.surface()` + +.. ipython:: python + :okwarning: + + @savefig plotting_contourf.png width=4in + air2d.plot.surface() + ==================== Calling Matplotlib ==================== From 50152b3c8d0eedaedd6edb7f20ba6771dda3f7f2 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 29 Apr 2021 14:37:11 +0100 Subject: [PATCH 30/31] Fix docs --- doc/user-guide/plotting.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 3abcce85651..8e326c17da5 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -417,7 +417,7 @@ produce plots with nonuniform coordinates. There are several other options for plotting 2D data. -Contour plot using `DataArray.plot.contour()` +Contour plot using :py:meth:`DataArray.plot.contour()` .. ipython:: python :okwarning: @@ -425,7 +425,7 @@ Contour plot using `DataArray.plot.contour()` @savefig plotting_contour.png width=4in air2d.plot.contour() -Filled contour plot using `DataArray.plot.contourf()` +Filled contour plot using :py:meth:`DataArray.plot.contourf()` .. ipython:: python :okwarning: @@ -433,12 +433,12 @@ Filled contour plot using `DataArray.plot.contourf()` @savefig plotting_contourf.png width=4in air2d.plot.contourf() -Surface plot using `DataArray.plot.surface()` +Surface plot using :py:meth:`DataArray.plot.surface()` .. ipython:: python :okwarning: - @savefig plotting_contourf.png width=4in + @savefig plotting_surface.png width=4in air2d.plot.surface() ==================== From 4831b8bc9bac7b25b62730c9391e19459cf5205e Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 29 Apr 2021 15:04:08 +0100 Subject: [PATCH 31/31] [skip-ci] Make example surface plot look a bit nicer --- doc/user-guide/plotting.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 8e326c17da5..f1c76b21488 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -439,7 +439,8 @@ Surface plot using :py:meth:`DataArray.plot.surface()` :okwarning: @savefig plotting_surface.png width=4in - air2d.plot.surface() + # transpose just to make the example look a bit nicer + air2d.T.plot.surface() ==================== Calling Matplotlib