diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 924d92d83..89b1c4662 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3150,6 +3150,43 @@ def set_prop_cycle(self, *args, **kwargs): cycle = self._active_cycle = constructor.Cycle(*args, **kwargs) return super().set_prop_cycle(cycle) # set the property cycler after validation + def _is_panel_group_member(self, other: "Axes") -> bool: + """ + Determine if the current axes and another axes belong to the same panel group. + + Two axes belong to the same panel group if any of the following is true: + 1. One axis is the parent of the other + 2. Both axes are panels sharing the same parent + + Parameters + ---------- + other : Axes + The other axes to compare with + + Returns + ------- + bool + True if both axes belong to the same panel group, False otherwise + """ + # Case 1: self is a panel of other (other is the parent) + if self._panel_parent is other: + return True + + # Case 2: other is a panel of self (self is the parent) + if other._panel_parent is self: + return True + + # Case 3: both are panels of the same parent + if ( + self._panel_parent + and other._panel_parent + and self._panel_parent is other._panel_parent + ): + return True + + # Not in the same panel group + return False + @docstring._snippet_manager def inset(self, *args, **kwargs): """ diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index df6c1a6c9..b64ea84b3 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -551,18 +551,6 @@ def _get_spine_side(self, s, loc): ) return side - def _is_panel_group_member(self, other): - """ - Return whether the axes belong in a panel sharing stack.. - """ - return ( - self._panel_parent is other # other is child panel - or other._panel_parent is self # other is main subplot - or other._panel_parent - and self._panel_parent # ... - and other._panel_parent is self._panel_parent # other is sibling panel - ) - def _sharex_limits(self, sharex): """ Safely share limits and tickers without resetting things. diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 500565c9f..57d3bc892 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -566,7 +566,7 @@ def __share_axis_setup( limits: bool, ): level = getattr(self.figure, f"_share{which}") - if getattr(self, f"_panel_share{which}_group") and self.is_panel_group_member( + if getattr(self, f"_panel_share{which}_group") and self._is_panel_group_member( other ): level = 3 diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 7403b3785..4a737a013 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -665,3 +665,77 @@ def test_check_tricontourf(): assert "transform" in mocked.call_args.kwargs assert isinstance(mocked.call_args.kwargs["transform"], ccrs.PlateCarree) uplt.close(fig) + + +def test_panels_geo(): + fig, ax = uplt.subplots(proj="cyl") + ax.format(labels=True) + for dir in "top bottom right left".split(): + pax = ax.panel_axes(dir) + match dir: + case "top": + assert len(pax.get_xticklabels()) > 0 + assert len(pax.get_yticklabels()) > 0 + case "bottom": + assert len(pax.get_xticklabels()) > 0 + assert len(pax.get_yticklabels()) > 0 + case "left": + assert len(pax.get_xticklabels()) > 0 + assert len(pax.get_yticklabels()) > 0 + case "right": + assert len(pax.get_xticklabels()) > 0 + assert len(pax.get_yticklabels()) > 0 + + +@pytest.mark.mpl_image_compare +def test_geo_with_panels(): + """ + We are allowed to add panels in GeoPlots + """ + # Define coordinates + lat = np.linspace(-90, 90, 180) + lon = np.linspace(-180, 180, 360) + time = np.arange(2000, 2005) + lon_grid, lat_grid = np.meshgrid(lon, lat) + + # Zoomed region elevation (Asia region) + lat_zoom = np.linspace(0, 60, 60) + lon_zoom = np.linspace(60, 180, 120) + lz, lz_grid = np.meshgrid(lon_zoom, lat_zoom) + + elevation = ( + 2000 * np.exp(-((lz - 90) ** 2 + (lz_grid - 30) ** 2) / 400) + + 1000 * np.exp(-((lz - 120) ** 2 + (lz_grid - 45) ** 2) / 225) + + np.random.normal(0, 100, lz.shape) + ) + elevation = np.clip(elevation, 0, 4000) + + fig, ax = uplt.subplots(nrows=2, proj="cyl") + pax = ax[0].panel("r") + pax.barh(lat_zoom, elevation.sum(axis=1)) + pax = ax[1].panel("r") + pax.barh(lat_zoom - 30, elevation.sum(axis=1)) + ax[0].pcolormesh( + lon_zoom, + lat_zoom, + elevation, + cmap="bilbao", + colorbar="t", + colorbar_kw=dict( + align="l", + length=0.5, + ), + ) + ax[1].pcolormesh( + lon_zoom, + lat_zoom - 30, + elevation, + cmap="glacial", + colorbar="t", + colorbar_kw=dict( + align="r", + length=0.5, + ), + ) + ax.format(oceancolor="blue", coast=True) + return fig