diff --git a/.gitignore b/.gitignore index bbd6bf100..0bf4ea4e1 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ sources # Python extras .ipynb_checkpoints *.log +*.ipnyb *.pyc .*.pyc __pycache__ diff --git a/pyproject.toml b/pyproject.toml index 2e0aee22b..1ad8b9c37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,3 +50,6 @@ include-package-data = true [tool.setuptools_scm] write_to = "ultraplot/_version.py" write_to_template = "__version__ = '{version}'\n" + +[tool.pytest.ini_options] +addopts = "--ignore=dbt_packages" diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 22e489d18..c140b0161 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -3261,6 +3261,26 @@ def _is_panel_group_member(self, other: "Axes") -> bool: # Not in the same panel group return False + def _label_key(self, side: str) -> str: + """ + Map requested side name to the correct tick_params key across mpl versions. + + This accounts for the API change around Matplotlib 3.10 where labeltop/labelbottom + became first-class tick parameter keys. For older versions, these map to + labelright/labelleft respectively. + """ + from packaging import version + from ..internals import _version_mpl + #TODO: internal deprecation warning when we drop 3.9, we need to remove this + + use_new = version.parse(str(_version_mpl)) >= version.parse("3.10") + if side == "labeltop": + return "labeltop" if use_new else "labelright" + if side == "labelbottom": + return "labelbottom" if use_new else "labelleft" + # "labelleft" and "labelright" are stable across versions + return side + def _is_ticklabel_on(self, side: str) -> bool: """ Check if tick labels are on for the specified sides. @@ -3274,10 +3294,8 @@ def _is_ticklabel_on(self, side: str) -> bool: label = "label1" if side in ["labelright", "labeltop"]: label = "label2" - for tick in axis.get_major_ticks(): - if getattr(tick, label).get_visible(): - return True - return False + + return axis.get_tick_params().get(self._label_key(side), False) @docstring._snippet_manager def inset(self, *args, **kwargs): diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5df..0678a2a8e 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -386,9 +386,7 @@ def _apply_axis_sharing(self): # bottommost or to the *right* of the leftmost panel. But the sharing level # used for the leftmost and bottommost is the *figure* sharing level. - # Get border axes once for efficiency border_axes = self.figure._get_border_axes() - # Apply X axis sharing self._apply_axis_sharing_for_axis("x", border_axes) @@ -412,128 +410,31 @@ def _apply_axis_sharing_for_axis( """ if axis_name == "x": axis = self.xaxis - shared_axis = self._sharex - panel_group = self._panel_sharex_group + shared_axis = self._sharex # do we share the xaxis? + panel_group = self._panel_sharex_group # do we have a panel? sharing_level = self.figure._sharex - label_params = ["labeltop", "labelbottom"] - border_sides = ["top", "bottom"] else: # axis_name == 'y' axis = self.yaxis shared_axis = self._sharey panel_group = self._panel_sharey_group sharing_level = self.figure._sharey - label_params = ["labelleft", "labelright"] - border_sides = ["left", "right"] - if shared_axis is None or not axis.get_visible(): + if not axis.get_visible(): return level = 3 if panel_group else sharing_level # Handle axis label sharing (level > 0) - if level > 0: + # If we are a border axis, @shared_axis may be None + # We propagate this through the _determine_tick_label_visiblity() logic + if level > 0 and shared_axis: shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") labels._transfer_label(axis.label, shared_axis_obj.label) axis.label.set_visible(False) - # Handle tick label sharing (level > 2) - if level > 2: - label_visibility = self._determine_tick_label_visibility( - axis, - shared_axis, - axis_name, - label_params, - border_sides, - border_axes, - ) - axis.set_tick_params(which="both", **label_visibility) # Turn minor ticks off axis.set_minor_formatter(mticker.NullFormatter()) - def _determine_tick_label_visibility( - self, - axis: maxis.Axis, - shared_axis: maxis.Axis, - axis_name: str, - label_params: list[str], - border_sides: list[str], - border_axes: dict[str, list[plot.PlotAxes]], - ) -> dict[str, bool]: - """ - Determine which tick labels should be visible based on sharing rules and borders. - - Parameters - ---------- - axis : matplotlib axis - The current axis object - shared_axis : Axes - The axes this one shares with - axis_name : str - Either 'x' or 'y' - label_params : list - List of label parameter names (e.g., ['labeltop', 'labelbottom']) - border_sides : list - List of border side names (e.g., ['top', 'bottom']) - border_axes : dict - Dictionary from _get_border_axes() - - Returns - ------- - dict - Dictionary of label visibility parameters - """ - ticks = axis.get_tick_params() - shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") - sharing_ticks = shared_axis_obj.get_tick_params() - - label_visibility = {} - - def _convert_label_param(label_param: str) -> str: - # Deal with logic not being consistent - # in prior mpl versions - if version.parse(str(_version_mpl)) <= version.parse("3.9"): - if label_param == "labeltop" and axis_name == "x": - label_param = "labelright" - elif label_param == "labelbottom" and axis_name == "x": - label_param = "labelleft" - return label_param - - for label_param, border_side in zip(label_params, border_sides): - # Check if user has explicitly set label location via format() - label_visibility[label_param] = False - has_panel = False - for panel in self._panel_dict[border_side]: - # Check if the panel is a colorbar - colorbars = [ - values - for key, values in self._colorbar_dict.items() - if border_side in key # key is tuple (side, top | center | lower) - ] - if not panel in colorbars: - # Skip colorbar as their - # yaxis is not shared - has_panel = True - break - # When we have a panel, let the panel have - # the labels and turn-off for this axis + side. - if has_panel: - continue - is_border = self in border_axes.get(border_side, []) - is_panel = ( - self in shared_axis._panel_dict[border_side] - and self == shared_axis._panel_dict[border_side][-1] - ) - # Use automatic border detection logic - # if we are a panel we "push" the labels outwards - label_param_trans = _convert_label_param(label_param) - is_this_tick_on = ticks[label_param_trans] - is_parent_tick_on = sharing_ticks[label_param_trans] - if is_panel: - label_visibility[label_param] = is_parent_tick_on - elif is_border: - label_visibility[label_param] = is_this_tick_on - return label_visibility - def _add_alt(self, sx, **kwargs): """ Add an alternate axes. diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 896bc0a6d..15c5f9a43 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -652,27 +652,16 @@ def _apply_axis_sharing(self): or to the *right* of the leftmost panel. But the sharing level used for the leftmost and bottommost is the *figure* sharing level. """ - # Handle X axis sharing - if self._sharex: - self._handle_axis_sharing( - source_axis=self._sharex._lonaxis, - target_axis=self._lonaxis, - ) - # Handle Y axis sharing - if self._sharey: - self._handle_axis_sharing( - source_axis=self._sharey._lataxis, - target_axis=self._lataxis, - ) + # Share interval x + if self._sharex and self.figure._sharex >= 2: + self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) + self._lonaxis.set_minor_locator(self._sharex._lonaxis.get_minor_locator()) - # This block is apart of the draw sequence as the - # gridliner object is created late in the - # build chain. - if not self.stale: - return - if self.figure._get_sharing_level() == 0: - return + # Share interval y + if self._sharey and self.figure._sharey >= 2: + self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) + self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) def _get_gridliner_labels( self, @@ -691,38 +680,36 @@ def _toggle_gridliner_labels( labelright=None, geo=None, ): - # For BasemapAxes the gridlines are dicts with key as the coordinate and keys the line and label - # We override the dict here assuming the labels are mut excl due to the N S E W extra chars + """ + Toggle visibility of gridliner labels for each direction. + + Parameters + ---------- + labeltop, labelbottom, labelleft, labelright : bool or None + Whether to show labels on each side. If None, do not change. + geo : optional + Not used in this method. + """ + # Ensure gridlines_major is fully initialized if any(i is None for i in self.gridlines_major): return + gridlabels = self._get_gridliner_labels( bottom=labelbottom, top=labeltop, left=labelleft, right=labelright ) - bools = [labelbottom, labeltop, labelleft, labelright] - directions = "bottom top left right".split() - for direction, toggle in zip(directions, bools): + + toggles = { + "bottom": labelbottom, + "top": labeltop, + "left": labelleft, + "right": labelright, + } + + for direction, toggle in toggles.items(): if toggle is None: continue for label in gridlabels.get(direction, []): - label.set_visible(toggle) - - def _handle_axis_sharing( - self, - source_axis: "GeoAxes", - target_axis: "GeoAxes", - ): - """ - Helper method to handle axis sharing for both X and Y axes. - - Args: - source_axis: The source axis to share from - target_axis: The target axis to apply sharing to - """ - # Copy view interval and minor locator from source to target - - if self.figure._get_sharing_level() >= 2: - target_axis.set_view_interval(*source_axis.get_view_interval()) - target_axis.set_minor_locator(source_axis.get_minor_locator()) + label.set_visible(bool(toggle) or toggle in ("x", "y")) @override def draw(self, renderer=None, *args, **kwargs): @@ -1441,6 +1428,7 @@ def _is_ticklabel_on(self, side: str) -> bool: """ # Deal with different cartopy versions left_labels, right_labels, bottom_labels, top_labels = self._get_side_labels() + if self.gridlines_major is None: return False elif side == "labelleft": diff --git a/ultraplot/axes/polar.py b/ultraplot/axes/polar.py index d66e3e2ea..94950179d 100644 --- a/ultraplot/axes/polar.py +++ b/ultraplot/axes/polar.py @@ -4,6 +4,11 @@ """ import inspect +try: + from typing import override +except: + from typing_extensions import override + import matplotlib.projections.polar as mpolar import numpy as np @@ -138,6 +143,11 @@ def __init__(self, *args, **kwargs): for axis in (self.xaxis, self.yaxis): axis.set_tick_params(which="both", size=0) + @override + def _apply_axis_sharing(self): + # Not implemented. Silently pass + return + def _update_formatter(self, x, *, formatter=None, formatter_kw=None): """ Update the gridline label formatter. diff --git a/ultraplot/figure.py b/ultraplot/figure.py index d44f31e61..7a9410c73 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,6 +6,7 @@ import inspect import os from numbers import Integral +from packaging import version try: from typing import List @@ -20,6 +21,11 @@ import matplotlib.transforms as mtransforms import numpy as np +try: + from typing import override +except: + from typing_extensions import override + from . import axes as paxes from . import constructor from . import gridspec as pgridspec @@ -477,6 +483,21 @@ def _canvas_preprocess(self, *args, **kwargs): return canvas +def _clear_border_cache(func): + """ + Decorator that clears the border cache after function execution. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if hasattr(self, "_cache_border_axes"): + delattr(self, "_cache_border_axes") + return result + + return wrapper + + class Figure(mfigure.Figure): """ The `~matplotlib.figure.Figure` subclass used by ultraplot. @@ -801,6 +822,172 @@ def __init__( # NOTE: This ignores user-input rc_mode. self.format(rc_kw=rc_kw, rc_mode=1, skip_axes=True, **kw_format) + @override + def draw(self, renderer): + # implement the tick sharing here + # should be shareable --> either all cartesian or all geographic + # but no mixing (panels can be mixed) + # check which ticks are on for x or y and push the labels to the + # outer most on a given column or row. + # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars + self._share_ticklabels(axis="x") + self._share_ticklabels(axis="y") + super().draw(renderer) + + def _share_ticklabels(self, *, axis: str) -> None: + """ + Tick label sharing is determined at the figure level. While + each subplot controls the limits, we are dealing with the ticklabels + here as the complexity is easier to deal with. + axis: str 'x' or 'y', row or columns to update + """ + if not self.stale: + return + + outer_axes = self._get_border_axes() + + sides = ("top", "bottom") if axis == "x" else ("left", "right") + + # Version-dependent label name mapping for reading back params + first_axi = next(self._iter_axes(panels=True), None) + if first_axi is None: + labelleft = "labelleft" + labelright = "labelright" + labeltop = "labeltop" + labelbottom = "labelbottom" + else: + labelleft = first_axi._label_key("labelleft") + labelright = first_axi._label_key("labelright") + labeltop = first_axi._label_key("labeltop") + labelbottom = first_axi._label_key("labelbottom") + + # Group axes by row (for x) or column (for y) + def _group_key(ax): + ss = ax.get_subplotspec() + return ss.rowspan.start if axis == "x" else ss.colspan.start + + axes = list(self._iter_axes(panels=True, hidden=False)) + groups = {} + for axi in axes: + try: + key = _group_key(axi) + except Exception: + # If we can't get a subplotspec, skip grouping for this axes + continue + groups.setdefault(key, []).append(axi) + + # Process each group independently + for key, group_axes in groups.items(): + # Build baseline from MAIN axes only (exclude panels) + tick_params_group = {} + subplot_types_group = set() + unsupported_found = False + + for axi in group_axes: + # Only main axes "vote" for baseline + if getattr(axi, "_panel_side", None): + continue + # Supported axes types + if not isinstance( + axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) + ): + warnings._warn_ultraplot( + f"Tick label sharing not implemented for {type(axi)} subplots." + ) + unsupported_found = True + break + subplot_types_group.add(type(axi)) + match axis: + # Handle x + case "x" if isinstance(axi, paxes.CartesianAxes): + tmp = axi.xaxis.get_tick_params() + if tmp.get(labeltop): + tick_params_group[labeltop] = tmp[labeltop] + if tmp.get(labelbottom): + tick_params_group[labelbottom] = tmp[labelbottom] + case "x" if isinstance(axi, paxes.GeoAxes): + if axi._is_ticklabel_on("labeltop"): + tick_params_group["labeltop"] = axi._is_ticklabel_on( + "labeltop" + ) + if axi._is_ticklabel_on("labelbottom"): + tick_params_group["labelbottom"] = axi._is_ticklabel_on( + "labelbottom" + ) + + # Handle y + case "y" if isinstance(axi, paxes.CartesianAxes): + tmp = axi.yaxis.get_tick_params() + if tmp.get(labelleft): + tick_params_group[labelleft] = tmp[labelleft] + if tmp.get(labelright): + tick_params_group[labelright] = tmp[labelright] + case "y" if isinstance(axi, paxes.GeoAxes): + if axi._is_ticklabel_on("labelleft"): + tick_params_group["labelleft"] = axi._is_ticklabel_on( + "labelleft" + ) + if axi._is_ticklabel_on("labelright"): + tick_params_group["labelright"] = axi._is_ticklabel_on( + "labelright" + ) + + # Skip group if unsupported axes were found + if unsupported_found: + continue + + # We cannot mix types (yet) within a group + if len(subplot_types_group) > 1: + warnings._warn_ultraplot( + "Tick label sharing not implemented for mixed subplot types." + ) + continue + + # Apply baseline to all axes in the group (including panels) + for axi in group_axes: + tmp = tick_params_group.copy() + + # Respect figure border sides: only keep labels on true borders + for side in sides: + label = f"label{side}" + if isinstance(axi, paxes.CartesianAxes): + # For cartesian, use version-mapped key when reading/writing + label = axi._label_key(label) + if axi not in outer_axes[side]: + tmp[label] = False + from .axes.cartesian import OPPOSITE_SIDE + + if axi._panel_side and OPPOSITE_SIDE[axi._panel_side] == side: + tmp[label] = False + + # Determine sharing level for this axes + level = getattr(self, f"_share{axis}") + if axis == "y": + if hasattr(axi, "_panel_sharey_group") and axi._panel_sharey_group: + level = 3 + elif getattr(axi, "_panel_side", None) and getattr( + axi, "_sharey", None + ): + level = 3 + else: # x-axis + if hasattr(axi, "_panel_sharex_group") and axi._panel_sharex_group: + level = 3 + elif getattr(axi, "_panel_side", None) and getattr( + axi, "_sharex", None + ): + level = 3 + + if level < 3: + continue + + # Apply to geo/cartesian appropriately + if isinstance(axi, paxes.GeoAxes): + axi._toggle_gridliner_labels(**tmp) + elif tmp: + getattr(axi, f"{axis}axis").set_tick_params(**tmp) + + self.stale = True + def _context_adjusting(self, cache=True): """ Prevent re-running auto layout steps due to draws triggered by figure @@ -928,8 +1115,9 @@ def _get_border_axes( if gs is None: return border_axes - # Skip colorbars or panels etc - all_axes = [axi for axi in self.axes if axi.number is not None] + all_axes = [] + for axi in self._iter_axes(panels=True): + all_axes.append(axi) # Handle empty cases nrows, ncols = gs.nrows, gs.ncols @@ -941,26 +1129,52 @@ def _get_border_axes( # Reconstruct the grid based on axis locations. Note that # spanning axes will fit into one of the boxes. Check # this with unittest to see how empty axes are handles - grid, grid_axis_type, seen_axis_type = _get_subplot_layout( - gs, - all_axes, - same_type=same_type, - ) + + gs = self.axes[0].get_gridspec() + shape = (gs.nrows_total, gs.ncols_total) + grid = np.zeros(shape, dtype=object) + grid.fill(None) + grid_axis_type = np.zeros(shape, dtype=int) + seen_axis_type = dict() + ax_type_mapping = dict() + for axi in self._iter_axes(panels=True, hidden=True): + gs = axi.get_subplotspec() + x, y = np.unravel_index(gs.num1, shape) + span = gs._get_rows_columns() + + xleft, xright, yleft, yright = span + xspan = xright - xleft + 1 + yspan = yright - yleft + 1 + number = axi.number + axis_type = type(axi) + if isinstance(axi, (paxes.GeoAxes)): + axis_type = axi.projection + if axis_type not in seen_axis_type: + seen_axis_type[axis_type] = len(seen_axis_type) + type_number = seen_axis_type[axis_type] + ax_type_mapping[axi] = type_number + if axi.get_visible(): + grid[x : x + xspan, y : y + yspan] = axi + grid_axis_type[x : x + xspan, y : y + yspan] = type_number # We check for all axes is they are a border or not # Note we could also write the crawler in a way where # it find the borders by moving around in the grid, without spawning on each axis point. We may change # this in the future for axi in all_axes: - axis_type = seen_axis_type.get(type(axi), 1) + axis_type = ax_type_mapping[axi] + number = axi.number + if axi.number is None: + number = -axi._panel_parent.number crawler = _Crawler( ax=axi, grid=grid, - target=axi.number, + target=number, axis_type=axis_type, grid_axis_type=grid_axis_type, ) for direction, is_border in crawler.find_edges(): - if is_border: + # print(">>", is_border, direction, axi.number) + if is_border and axi not in border_axes[direction]: border_axes[direction].append(axi) self._cached_border_axes = border_axes return border_axes @@ -1054,12 +1268,7 @@ def _get_renderer(self): renderer = canvas.get_renderer() return renderer - def _get_sharing_level(self): - """ - We take the average here as the sharex and sharey should be the same value. In case this changes in the future we can track down the error easily - """ - return 0.5 * (self.figure._sharex + self.figure._sharey) - + @_clear_border_cache def _add_axes_panel(self, ax, side=None, **kwargs): """ Add an axes panel. @@ -1102,8 +1311,66 @@ def _add_axes_panel(self, ax, side=None, **kwargs): axis = pax.yaxis if side in ("left", "right") else pax.xaxis getattr(axis, "tick_" + side)() # set tick and tick label position axis.set_label_position(side) # set label position + # Sync limits and formatters with parent when sharing to ensure consistent ticks + if share: + # Copy limits for the shared axis + if side in ("left", "right"): + try: + pax.set_ylim(ax.get_ylim()) + except Exception: + pass + else: + try: + pax.set_xlim(ax.get_xlim()) + except Exception: + pass + # Align with backend: for GeoAxes, use lon/lat degree formatters on panels. + # Otherwise, copy the parent's axis formatters. + if isinstance(ax, paxes.GeoAxes): + fmt_key = "deglat" if side in ("left", "right") else "deglon" + axis.set_major_formatter(constructor.Formatter(fmt_key)) + else: + paxis = ax.yaxis if side in ("left", "right") else ax.xaxis + axis.set_major_formatter(paxis.get_major_formatter()) + axis.set_minor_formatter(paxis.get_minor_formatter()) + # Push main axes tick labels to the outside relative to the added panel + # Skip this for filled panels (colorbars/legends) + if not kw.get("filled", False): + if isinstance(ax, paxes.GeoAxes): + if side == "top": + ax._toggle_gridliner_labels(labeltop=False) + elif side == "bottom": + ax._toggle_gridliner_labels(labelbottom=False) + elif side == "left": + ax._toggle_gridliner_labels(labelleft=False) + elif side == "right": + ax._toggle_gridliner_labels(labelright=False) + else: + if side == "top": + ax.xaxis.set_tick_params(labeltop=False) + elif side == "bottom": + ax.xaxis.set_tick_params(labelbottom=False) + elif side == "left": + ax.yaxis.set_tick_params(labelleft=False) + elif side == "right": + ax.yaxis.set_tick_params(labelright=False) + + # Panel labels: prefer outside only for non-sharing top/right; otherwise keep off + if side == "top": + if not share: + pax.xaxis.set_tick_params(labeltop=True, labelbottom=False) + else: + pax.xaxis.set_tick_params(labeltop=False) + elif side == "right": + if not share: + pax.yaxis.set_tick_params(labelright=True, labelleft=False) + else: + pax.yaxis.set_tick_params(labelright=False) + ax.yaxis.set_tick_params(labelright=False) + return pax + @_clear_border_cache def _add_figure_panel( self, side=None, span=None, row=None, col=None, rows=None, cols=None, **kwargs ): @@ -1138,6 +1405,7 @@ def _add_figure_panel( pax._panel_parent = None return pax + @_clear_border_cache def _add_subplot(self, *args, **kwargs): """ The driver function for adding single subplots. @@ -1246,9 +1514,6 @@ def _add_subplot(self, *args, **kwargs): if ax.number: self._subplot_dict[ax.number] = ax - # Invalidate border axes cache - if hasattr(self, "_cached_border_axes"): - delattr(self, "_cached_border_axes") return ax def _unshare_axes(self): @@ -1263,56 +1528,6 @@ def _unshare_axes(self): if isinstance(ax, paxes.GeoAxes) and hasattr(ax, "set_global"): ax.set_global() - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Only apply sharing of labels when we are - # actually sharing labels. - if self._get_sharing_level() == 0: - return - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - border_axes = self._get_border_axes() - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi._apply_axis_sharing() - def _toggle_axis_sharing( self, *, @@ -1728,6 +1943,7 @@ def _update_super_title(self, title, **kwargs): if title is not None: self._suptitle.set_text(title) + @_clear_border_cache @docstring._concatenate_inherited @docstring._snippet_manager def add_axes(self, rect, **kwargs): @@ -1822,7 +2038,6 @@ def _align_content(): # noqa: E306 # subsequent tight layout really weird. Have to resize twice. _draw_content() if not gs: - print("hello") return if aspect: gs._auto_layout_aspect() @@ -1968,12 +2183,6 @@ def format( } ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) ax.number = store_old_number - # When we apply formatting to all axes, we need - # to potentially adjust the labels. - - if len(axs) == len(self.axes) and self._get_sharing_level() > 0: - self._share_labels_with_others() - # Warn unused keyword argument(s) kw = { key: value @@ -1985,53 +2194,6 @@ def format( f"Ignoring unused projection-specific format() keyword argument(s): {kw}" # noqa: E501 ) - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - border_axes = self._get_border_axes(same_type=False) - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - # We turn off the tick labels when the scale and - # ticks are shared (level > 0) - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi.tick_params(which=which, **turn_on_or_off) - @docstring._concatenate_inherited @docstring._snippet_manager def colorbar( diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 159cac2c5..029b61c1e 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -195,7 +195,7 @@ def _get_rows_columns(self, ncols=None): row2, col2 = divmod(self.num2, ncols) return row1, row2, col1, col2 - def _get_grid_span(self, hidden=False) -> (int, int, int, int): + def _get_grid_span(self, hidden=True) -> (int, int, int, int): """ Retrieve the location of the subplot within the gridspec. When hidden is False we only consider @@ -203,11 +203,12 @@ def _get_grid_span(self, hidden=False) -> (int, int, int, int): """ gs = self.get_gridspec() nrows, ncols = gs.nrows_total, gs.ncols_total - if not hidden: + if hidden: + x, y = np.unravel_index(self.num1, (nrows, ncols)) + else: nrows, ncols = gs.nrows, gs.ncols - # Use num1 or num2 - decoded = gs._decode_indices(self.num1) - x, y = np.unravel_index(decoded, (nrows, ncols)) + decoded = gs._decode_indices(self.num1) + x, y = np.unravel_index(decoded, (nrows, ncols)) span = self._get_rows_columns() xspan = span[1] - span[0] + 1 # inclusive diff --git a/ultraplot/tests/conftest.py b/ultraplot/tests/conftest.py index e6848abaa..db2482d90 100644 --- a/ultraplot/tests/conftest.py +++ b/ultraplot/tests/conftest.py @@ -3,7 +3,6 @@ import warnings, logging logging.getLogger("matplotlib").setLevel(logging.ERROR) - SEED = 51423 diff --git a/ultraplot/tests/test_2dplots.py b/ultraplot/tests/test_2dplots.py index 13f084c64..a2b75319d 100644 --- a/ultraplot/tests/test_2dplots.py +++ b/ultraplot/tests/test_2dplots.py @@ -30,12 +30,12 @@ def test_auto_diverging1(rng): """ # Test with basic data fig = uplt.figure() - # fig.format(collabels=('Auto sequential', 'Auto diverging'), suptitle='Default') ax = fig.subplot(121) ax.pcolor(rng.random((10, 10)) * 5, colorbar="b") ax = fig.subplot(122) ax.pcolor(rng.random((10, 10)) * 5 - 3.5, colorbar="b") fig.format(toplabels=("Sequential", "Diverging")) + fig.canvas.draw() return fig diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index a04c2233a..75ccb3aa3 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -352,7 +352,7 @@ def test_sharing_labels_top_right(): [3, 4, 5], [3, 4, 0], ], - 3, # default sharing level + True, # default sharing level {"xticklabelloc": "t", "yticklabelloc": "r"}, [1, 3, 4], # y-axis labels visible indices [0, 1, 4], # x-axis labels visible indices @@ -405,6 +405,7 @@ def check_state(ax, numbers, state, which): # Format axes with the specified tick label locations ax.format(**tick_loc) + fig.canvas.draw() # needed for sharing labels # Calculate the indices where labels should be hidden all_indices = list(range(len(ax))) diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index 0e92f8f2f..cffa3c7f6 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -58,7 +58,17 @@ def test_unsharing_different_rectilinear(): """ with pytest.warns(uplt.internals.warnings.UltraPlotWarning): fig, ax = uplt.subplots(ncols=2, proj=("cyl", "merc"), share="all") - uplt.close(fig) + + +def test_get_renderer_basic(): + """ + Test that _get_renderer returns a renderer object. + """ + fig, ax = uplt.subplots() + renderer = fig._get_renderer() + # Renderer should not be None and should have draw_path method + assert renderer is not None + assert hasattr(renderer, "draw_path") def test_figure_sharing_toggle(): diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 35789a54d..c94b0adf9 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -296,6 +296,7 @@ def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]: settings = dict(land=True, ocean=True, labels="both") fig, ax = uplt.subplots(layout, share="all", proj="cyl") ax.format(**settings) + fig.canvas.draw() # needed for sharing labels for axi in ax: state = are_labels_on(axi) expectation = expectations[axi.number - 1] @@ -491,7 +492,8 @@ def test_get_gridliner_labels_cartopy(): uplt.close(fig) -def test_sharing_levels(): +@pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) +def test_sharing_levels(level): """ We can share limits or labels. We check if we can do both for the GeoAxes. @@ -515,7 +517,6 @@ def test_sharing_levels(): x = np.array([0, 10]) y = np.array([0, 10]) - sharing_levels = [0, 1, 2, 3, 4] lonlim = latlim = np.array((-10, 10)) def assert_views_are_sharing(ax): @@ -551,46 +552,42 @@ def assert_views_are_sharing(ax): l2 = np.linalg.norm( np.asarray(latview) - np.asarray(target_lat), ) - level = ax.figure._get_sharing_level() + level = ax.figure._sharex if level <= 1: share_x = share_y = False assert np.allclose(l1, 0) == share_x assert np.allclose(l2, 0) == share_y - for level in sharing_levels: - fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) - ax.format(labels="both") - for axi in ax: - axi.format( - lonlim=lonlim * axi.number, - latlim=latlim * axi.number, - ) + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) + ax.format(labels="both") + for axi in ax: + axi.format( + lonlim=lonlim * axi.number, + latlim=latlim * axi.number, + ) - fig.canvas.draw() - for idx, axi in enumerate(ax): - axi.plot(x * (idx + 1), y * (idx + 1)) - - fig.canvas.draw() # need this to update the labels - # All the labels should be on - for axi in ax: - side_labels = axi._get_gridliner_labels( - left=True, - right=True, - top=True, - bottom=True, - ) - s = 0 - for dir, labels in side_labels.items(): - s += any([label.get_visible() for label in labels]) - - assert_views_are_sharing(axi) - # When we share the labels but not the limits, - # we expect all ticks to be on - if level == 0: - assert s == 4 - else: - assert s == 2 - uplt.close(fig) + fig.canvas.draw() + for idx, axi in enumerate(ax): + axi.plot(x * (idx + 1), y * (idx + 1)) + + # All the labels should be on + for axi in ax: + + s = sum( + [ + 1 if axi._is_ticklabel_on(side) else 0 + for side in "labeltop labelbottom labelleft labelright".split() + ] + ) + + assert_views_are_sharing(axi) + # When we share the labels but not the limits, + # we expect all ticks to be on + if level > 2: + assert s == 2 + else: + assert s == 4 + uplt.close(fig) @pytest.mark.mpl_image_compare @@ -616,8 +613,10 @@ def test_cartesian_and_geo(rng): ax.format(land=True, lonlim=(-10, 10), latlim=(-10, 10)) ax[0].pcolormesh(rng.random((10, 10))) ax[1].scatter(*rng.random((2, 100))) - ax[0]._apply_axis_sharing() - assert mocked.call_count == 2 + fig.canvas.draw() + assert ( + mocked.call_count > 2 + ) # needs to be called at least twice; one for each axis return fig @@ -676,21 +675,38 @@ def test_check_tricontourf(): def test_panels_geo(): fig, ax = uplt.subplots(proj="cyl") ax.format(labels=True) - for dir in "top bottom right left".split(): + dirs = "top bottom right left".split() + for dir in dirs: pax = ax.panel_axes(dir) - match dir: - case "top": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "bottom": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "left": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "right": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 + fig.canvas.draw() + pax = ax[0]._panel_dict["left"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["top"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert not pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["bottom"][-1] + assert pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + pax = ax[0]._panel_dict["right"][-1] + assert not pax._is_ticklabel_on("labelleft") # should not error + assert not pax._is_ticklabel_on("labelright") + assert not pax._is_ticklabel_on("labeltop") + assert pax._is_ticklabel_on("labelbottom") + + for dir in dirs: + not ax[0]._is_ticklabel_on(f"label{dir}") + + return fig @pytest.mark.mpl_image_compare @@ -807,6 +823,7 @@ def are_labels_on(ax, which=("top", "bottom", "right", "left")) -> tuple[bool]: h = ax.imshow(data)[0] ax.format(land=True, labels="both") # need this otherwise no labels are printed fig.colorbar(h, loc="r") + fig.canvas.draw() # needed to invoke axis sharing expectations = ( [True, False, False, True], diff --git a/ultraplot/tests/test_inset.py b/ultraplot/tests/test_inset.py index ea1bf76af..9a1dfc611 100644 --- a/ultraplot/tests/test_inset.py +++ b/ultraplot/tests/test_inset.py @@ -7,6 +7,7 @@ def test_inset_basic(): # spacing, aspect ratios, and axis sharing gs = uplt.GridSpec(nrows=2, ncols=2) fig = uplt.figure(refwidth=1.5, share=False) + fig.canvas.draw() for ss, side in zip(gs, "tlbr"): ax = fig.add_subplot(ss) px = ax.panel_axes(side, width="3em") diff --git a/ultraplot/tests/test_sharing.py b/ultraplot/tests/test_sharing.py new file mode 100644 index 000000000..620e879f8 --- /dev/null +++ b/ultraplot/tests/test_sharing.py @@ -0,0 +1,98 @@ +import pytest, ultraplot as uplt + +""" +Sharing levels for subplots determine the visibility of the axis labels and tick labels. + +Axis labels are pushed to the border subplots when the sharing level is greater than 1. + +Ticks are visible only on the border plots when the sharing level is greater than 2. + +Or more verbosely: + sharey = 0: no sharing, all labels and ticks visible + sharey = 1: share axis labels, tick labels are still independent + sharey = 2: share data limits + sharey = 3 or True, share both ticks and labels +A similar story holds for sharex. +""" + + +@pytest.mark.parametrize("share_level", [0, "labels", "labs", 1, True]) +@pytest.mark.mpl_image_compare +def test_sharing_levels_y(share_level): + """ + Test sharing levels for y-axis: left and right ticks/labels. + """ + fig, axs = uplt.subplots(None, 2, 3, sharey=share_level) + axs.format(ylabel="Y") + axs.format(title=f"sharey = {share_level}") + fig.canvas.draw() # needed for checks + + if fig._sharey < 3: + border_axes = set(axs) + else: + # Reduce border_axes to a set of axes for left and right + border_axes = set() + for direction in ["left", "right"]: + axes = fig._get_border_axes().get(direction, []) + if isinstance(axes, (list, tuple, set)): + border_axes.update(axes) + else: + border_axes.add(axes) + for axi in axs: + tick_params = axi.yaxis.get_tick_params() + for direction in ["left", "right"]: + label_key = f"label{direction}" + visible = tick_params.get(label_key, False) + is_border = axi in fig._get_border_axes().get(direction, []) + if direction == "left" and (fig._sharey < 3 or is_border): + assert visible + else: + assert not visible + return fig + + +@pytest.mark.parametrize("share_level", [0, "labels", "labs", 1, True]) +@pytest.mark.mpl_image_compare +def test_sharing_levels_x(share_level): + """ + Test sharing levels for x-axis: top and bottom ticks/labels. + """ + fig, axs = uplt.subplots(None, 2, 3, sharex=share_level) + axs.format(xlabel="X") + axs.format(title=f"sharex = {share_level}") + fig.canvas.draw() # needed for checks + + # Get the border axes + if fig._sharex < 3: + border_axes = set(axs) + else: + # Reduce border_axes to a set of axes for top and bottom + border_axes = set() + for direction in ["top", "bottom"]: + axes = fig._get_border_axes().get(direction, []) + if isinstance(axes, (list, tuple, set)): + border_axes.update(axes) + else: + border_axes.add(axes) + + # Run tests + for axi in axs: + tick_params = axi.xaxis.get_tick_params() + # Get correct directions depending on mpl version + from ultraplot.internals.versions import _version_mpl + from packaging import version + + if version.parse(str(_version_mpl)) >= version.parse("3.10"): + direction_label_map = {"top": "labeltop", "bottom": "labelbottom"} + else: + direction_label_map = {"top": "labelright", "bottom": "labelleft"} + + for direction in ["top", "bottom"]: + label_key = direction_label_map[direction] + visible = tick_params.get(label_key, False) + is_border = axi in fig._get_border_axes().get(direction, []) + if direction == "bottom" and (fig._sharex < 3 or is_border): + assert visible + else: + assert not visible + return fig diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index e215a90ee..d2379ad73 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -290,29 +290,53 @@ def test_panel_sharing_top_right(layout): for dir in "left right top bottom".split(): pax = ax[0].panel(dir) fig.canvas.draw() # force redraw tick labels - for dir, paxs in ax[0]._panel_dict.items(): - # Since we are sharing some of the ticks - # should be hidden depending on where the panel is - # in the grid - for pax in paxs: - match dir: - case "left": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") - case "top": - assert pax._is_ticklabel_on("labeltop") == False - assert pax._is_ticklabel_on("labelbottom") == False - assert pax._is_ticklabel_on("labelleft") - case "right": - print(pax._is_ticklabel_on("labelright")) - assert pax._is_ticklabel_on("labelright") == False - assert pax._is_ticklabel_on("labelbottom") - case "bottom": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") == False - - # The sharing axis is not showing any ticks - assert ax[0]._is_ticklabel_on(dir) == False + + # Main panel: ticks are off + assert not ax[0]._is_ticklabel_on("labelleft") + assert not ax[0]._is_ticklabel_on("labelright") + assert not ax[0]._is_ticklabel_on("labeltop") + assert not ax[0]._is_ticklabel_on("labelbottom") + + # For panels the inside ticks are off + panel = ax[0]._panel_dict["left"][-1] + assert panel._is_ticklabel_on("labelleft") + assert panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["top"][-1] + assert panel._is_ticklabel_on("labelleft") + assert not panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["right"][-1] + assert not panel._is_ticklabel_on("labelleft") + assert panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + panel = ax[0]._panel_dict["bottom"][-1] + assert panel._is_ticklabel_on("labelleft") + assert not panel._is_ticklabel_on("labelbottom") + assert not panel._is_ticklabel_on("labelright") + assert not panel._is_ticklabel_on("labeltop") + + assert not ax[1]._is_ticklabel_on("labelleft") + assert not ax[1]._is_ticklabel_on("labelright") + assert not ax[1]._is_ticklabel_on("labeltop") + assert not ax[1]._is_ticklabel_on("labelbottom") + + assert ax[2]._is_ticklabel_on("labelleft") + assert not ax[2]._is_ticklabel_on("labelright") + assert not ax[2]._is_ticklabel_on("labeltop") + assert ax[2]._is_ticklabel_on("labelbottom") + + assert not ax[3]._is_ticklabel_on("labelleft") + assert not ax[3]._is_ticklabel_on("labelright") + assert not ax[3]._is_ticklabel_on("labeltop") + assert ax[3]._is_ticklabel_on("labelbottom") + return fig @@ -327,3 +351,67 @@ def test_uneven_span_subplots(rng): axs[-1, -1].format(fc="gray4", grid=False) axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) return fig + + +@pytest.mark.mpl_image_compare +def test_uneven_span_subplots(rng): + fig = uplt.figure(refwidth=1, refnum=5, span=False) + axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1]) + axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid") + axs[0].format(ec="black", fc="gray1", lw=1.4) + axs[1, 1:].format(fc="blush") + axs[1, :1].format(fc="sky blue") + axs[-1, -1].format(fc="gray4", grid=False) + axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) + return fig + + +@pytest.mark.parametrize("share_panels", [True, False]) +def test_panel_ticklabels_all_sides_share_and_no_share(share_panels): + # 2x2 grid; add panels on all sides of the first axes + fig, ax = uplt.subplots(nrows=2, ncols=2) + axi = ax[0] + + # Create panels on all sides with configurable sharing + pax_left = axi.panel("left", share=share_panels) + pax_right = axi.panel("right", share=share_panels) + pax_top = axi.panel("top", share=share_panels) + pax_bottom = axi.panel("bottom", share=share_panels) + + # Force draw so ticklabel state is resolved + fig.canvas.draw() + + def assert_panel(axi_panel, side, share_flag): + on_left = axi_panel._is_ticklabel_on("labelleft") + on_right = axi_panel._is_ticklabel_on("labelright") + on_top = axi_panel._is_ticklabel_on("labeltop") + on_bottom = axi_panel._is_ticklabel_on("labelbottom") + + # Inside (toward the main) must be off in all cases + if side == "left": + # Inside is right + assert not on_right + elif side == "right": + # Inside is left + assert not on_left + elif side == "top": + # Inside is bottom + assert not on_bottom + elif side == "bottom": + # Inside is top + assert not on_top + + if not share_flag: + # For non-sharing panels, prefer outside labels on for top/right + if side == "right": + assert on_right + if side == "top": + assert on_top + # For left/bottom non-sharing, we don't enforce outside on here + # (baseline may keep left/bottom on the main) + + # Check each panel side + assert_panel(pax_left, "left", share_panels) + assert_panel(pax_right, "right", share_panels) + assert_panel(pax_top, "top", share_panels) + assert_panel(pax_bottom, "bottom", share_panels) diff --git a/ultraplot/utils.py b/ultraplot/utils.py index 1b1b97a95..621127982 100644 --- a/ultraplot/utils.py +++ b/ultraplot/utils.py @@ -918,7 +918,8 @@ def _get_subplot_layout( axis types. This function is used internally to determine the layout of axes in a GridSpec. """ - grid = np.zeros((gs.nrows, gs.ncols)) + grid = np.zeros((gs.nrows_total, gs.ncols_total), dtype=object) + grid.fill(None) grid_axis_type = np.zeros((gs.nrows, gs.ncols)) # Collect grouper based on kinds of axes. This # would allow us to share labels across types @@ -936,7 +937,7 @@ def _get_subplot_layout( grid[ slice(*rowspan), slice(*colspan), - ] = axi.number + ] = axi # Allow grouping of mixed types axis_type = 1 @@ -996,22 +997,28 @@ def find_edge_for( direction: str, d: tuple[int, int], ) -> tuple[str, bool]: - from itertools import product - """ Setup search for a specific direction. """ + from itertools import product + # Retrieve where the axis is in the grid spec = self.ax.get_subplotspec() - spans = spec._get_grid_span() + shape = (spec.get_gridspec().nrows_total, spec.get_gridspec().ncols_total) + x, y = np.unravel_index(spec.num1, shape) + spans = spec._get_rows_columns() rowspan = spans[:2] colspan = spans[-2:] - xs = range(*rowspan) - ys = range(*colspan) + + a = rowspan[1] - rowspan[0] + b = colspan[1] - colspan[0] + xs = range(x, x + a + 1) + ys = range(y, y + b + 1) + is_border = False - for x, y in product(xs, ys): - pos = (x, y) + for xl, yl in product(xs, ys): + pos = (xl, yl) if self.is_border(pos, d): is_border = True break @@ -1026,27 +1033,31 @@ def is_border( Recursively move over the grid by following the direction. """ x, y = pos - # Check if we are at an edge of the grid (out-of-bounds). - if x < 0: - return True - elif x > self.grid.shape[0] - 1: + # Edge of grid (out-of-bounds) + if not (0 <= x < self.grid.shape[0] and 0 <= y < self.grid.shape[1]): return True - if y < 0: - return True - elif y > self.grid.shape[1] - 1: - return True + cell = self.grid[x, y] + dx, dy = direction + if cell is None: + return self.is_border((x + dx, y + dy), direction) - if self.grid[x, y] == 0 or self.grid_axis_type[x, y] != self.axis_type: - return True + if hasattr(cell, "_panel_hidden") and cell._panel_hidden: + return self.is_border((x + dx, y + dy), direction) - # Check if we reached a plot or an internal edge - if self.grid[x, y] != self.target and self.grid[x, y] > 0: - return self._check_ranges(direction, other=self.grid[x, y]) + if self.grid_axis_type[x, y] != self.axis_type: + # Allow traversing across the parent<->panel interface even when types differ + # e.g., GeoAxes main with cartesian panel or vice versa + if getattr(self.ax, "_panel_parent", None) is cell: + return self.is_border((x + dx, y + dy), direction) + if getattr(cell, "_panel_parent", None) is self.ax: + return self.is_border((x + dx, y + dy), direction) - dx, dy = direction - pos = (x + dx, y + dy) - return self.is_border(pos, direction) + # Internal edge or plot reached + if cell != self.ax: + return self._check_ranges(direction, other=cell) + + return self.is_border((x + dx, y + dy), direction) def _check_ranges( self, @@ -1065,14 +1076,15 @@ def _check_ranges( can share x. """ this_spec = self.ax.get_subplotspec() - other_spec = self.ax.figure._subplot_dict[other].get_subplotspec() + other_spec = other.get_subplotspec() # Get the row and column spans of both axes - this_span = this_spec._get_grid_span() + this_span = this_spec._get_rows_columns() this_rowspan = this_span[:2] this_colspan = this_span[-2:] other_span = other_spec._get_grid_span() + other_span = other_spec._get_rows_columns() other_rowspan = other_span[:2] other_colspan = other_span[-2:] @@ -1089,7 +1101,28 @@ def _check_ranges( other_start, other_stop = other_rowspan if this_start == other_start and this_stop == other_stop: - return False # not a border + # We may hit an internal border if we are at + # the interface with a panel that is not sharing + dmap = { + (-1, 0): "bottom", + (1, 0): "top", + (0, -1): "left", + (0, 1): "right", + } + side = dmap[direction] + if self.ax.number is None: # panel + panel_side = getattr(self.ax, "_panel_side", None) + # Non-sharing panels: border only on their outward side + if not getattr(self.ax, "_panel_share", False): + return side == panel_side + # Sharing panels: border only if this is the outward side and this + # panel is the outer-most panel for that side relative to its parent. + parent = self.ax._panel_parent + panels = parent._panel_dict.get(panel_side, []) + if side == panel_side and panels and panels[-1] is self.ax: + return True + return False + return False return True