diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index d19c33d0f..b7e6631be 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -8,7 +8,7 @@ import re import types from numbers import Integral, Number -from typing import Union, Iterable, MutableMapping +from typing import Union, Iterable, MutableMapping, Optional, Tuple from collections.abc import Iterable as IterableType try: @@ -276,6 +276,19 @@ pad : unit-spec, default: :rc:`subplots.panelpad` The :ref:`tight layout padding ` between the panel and the subplot. %(units.em)s +row, rows + Aliases for `span` for panels on the left or right side (vertical panels). +col, cols + Aliases for `span` for panels on the top or bottom side (horizontal panels). +span : int or 2-tuple of int, default: None + Integer(s) indicating the span of the panel across rows and columns of + subplots. For panels on the left or right side, use `rows` or `row` to + specify which rows the panel should span. For panels on the top or bottom + side, use `cols` or `col` to specify which columns the panel should span. + For example, ``ax.panel('b', col=1)`` draws a panel beneath only the + leftmost column, and ``ax.panel('b', cols=(1, 2))`` draws a panel beneath + the left two columns. By default the panel will span all rows or columns + aligned with the parent axes. share : bool, default: True Whether to enable axis sharing between the *x* and *y* axes of the main subplot and the panel long axes for each panel in the "stack". @@ -963,7 +976,18 @@ def _add_guide_frame( self.add_artist(patch) return patch - def _add_guide_panel(self, loc="fill", align="center", length=0, **kwargs): + def _add_guide_panel( + self, + loc: str = "fill", + align: str = "center", + length: Union[float, str] = 0, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ) -> "Axes": """ Add a panel to be filled by an "outer" colorbar or legend. """ @@ -984,7 +1008,16 @@ def _add_guide_panel(self, loc="fill", align="center", length=0, **kwargs): ax = pax break if ax is None: - ax = self.panel_axes(loc, filled=True, **kwargs) + ax = self.panel_axes( + loc, + filled=True, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, + ) else: raise ValueError(f"Invalid filled panel location {loc!r}.") for s in ax.spines.values(): @@ -1002,13 +1035,18 @@ def _add_colorbar( mappable, values=None, *, - loc=None, - align=None, - space=None, - pad=None, - width=None, - length=None, - shrink=None, + loc: Optional[str] = None, + align: Optional[str] = None, + space: Optional[Union[float, str]] = None, + pad: Optional[Union[float, str]] = None, + width: Optional[Union[float, str]] = None, + length: Optional[Union[float, str]] = None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + shrink: Optional[Union[float, str]] = None, label=None, title=None, reverse=False, @@ -1123,7 +1161,17 @@ def _add_colorbar( kwargs.update({"align": align, "length": length}) extendsize = _not_none(extendsize, rc["colorbar.extend"]) ax = self._add_guide_panel( - loc, align, length=length, width=width, space=space, pad=pad + loc, + align, + length=length, + width=width, + space=space, + pad=pad, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, ) # noqa: E501 cax, kwargs = ax._parse_colorbar_filled(**kwargs) else: diff --git a/ultraplot/figure.py b/ultraplot/figure.py index d28e929c8..7c2cd454b 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -9,9 +9,9 @@ from packaging import version try: - from typing import List + from typing import List, Optional, Union, Tuple except ImportError: - from typing_extensions import List + from typing_extensions import List, Optional, Union, Tuple import matplotlib.axes as maxes import matplotlib.figure as mfigure @@ -1338,7 +1338,17 @@ def _get_renderer(self): return renderer @_clear_border_cache - def _add_axes_panel(self, ax, side=None, **kwargs): + def _add_axes_panel( + self, + ax: "paxes.Axes", + side: Optional[str] = None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ) -> "paxes.Axes": """ Add an axes panel. """ @@ -1368,6 +1378,35 @@ def _add_axes_panel(self, ax, side=None, **kwargs): if not gs: raise RuntimeError("The gridspec must be active.") kw = _pop_params(kwargs, gs._insert_panel_slot) + + # Validate and determine span override from span/row/col/rows/cols parameters + span_override = None + if side in ("left", "right"): + # Vertical panels: should use rows parameter, not cols + if _not_none(cols, col) is not None and _not_none(rows, row) is None: + raise ValueError( + f"For {side!r} colorbars (vertical), use 'rows=' or 'row=' " + "to specify span, not 'cols=' or 'col='." + ) + if span is not None and _not_none(rows, row) is None: + warnings._warn_ultraplot( + f"For {side!r} colorbars (vertical), prefer 'rows=' over 'span=' " + "for clarity. Using 'span' as rows." + ) + span_override = _not_none(rows, row, span) + else: + # Horizontal panels: should use cols parameter, not rows + if _not_none(rows, row) is not None and _not_none(cols, col, span) is None: + raise ValueError( + f"For {side!r} colorbars (horizontal), use 'cols=' or 'span=' " + "to specify span, not 'rows=' or 'row='." + ) + span_override = _not_none(cols, col, span) + + # Pass span_override to gridspec if provided + if span_override is not None: + kw["span_override"] = span_override + ss, share = gs._insert_panel_slot(side, ax, **kw) # Guard: GeoAxes with non-rectilinear projections cannot share with panels if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear(): @@ -1452,8 +1491,15 @@ def _add_axes_panel(self, ax, side=None, **kwargs): @_clear_border_cache def _add_figure_panel( - self, side=None, span=None, row=None, col=None, rows=None, cols=None, **kwargs - ): + self, + side: Optional[str] = None, + span: Optional[Union[int, Tuple[int, int]]] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + **kwargs, + ) -> "paxes.Axes": """ Add a figure panel. """ @@ -2280,16 +2326,16 @@ def colorbar( self, mappable, values=None, - loc=None, - location=None, - row=None, - col=None, - rows=None, - cols=None, - span=None, - space=None, - pad=None, - width=None, + loc: Optional[str] = None, + location: Optional[str] = None, + row: Optional[int] = None, + col: Optional[int] = None, + rows: Optional[Union[int, Tuple[int, int]]] = None, + cols: Optional[Union[int, Tuple[int, int]]] = None, + span: Optional[Union[int, Tuple[int, int]]] = None, + space: Optional[Union[float, str]] = None, + pad: Optional[Union[float, str]] = None, + width: Optional[Union[float, str]] = None, **kwargs, ): """ @@ -2341,8 +2387,33 @@ def colorbar( cb = super().colorbar(mappable, cax=cax, **kwargs) # Axes panel colorbar elif ax is not None: - cb = ax.colorbar( - mappable, values, space=space, pad=pad, width=width, **kwargs + # Check if span parameters are provided + has_span = _not_none(span, row, col, rows, cols) is not None + + # Extract a single axes from array if span is provided + # Otherwise, pass the array as-is for normal colorbar behavior + if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)): + try: + ax_single = next(iter(ax)) + except (TypeError, StopIteration): + ax_single = ax + else: + ax_single = ax + + # Pass span parameters through to axes colorbar + cb = ax_single.colorbar( + mappable, + values, + space=space, + pad=pad, + width=width, + loc=loc, + span=span, + row=row, + col=col, + rows=rows, + cols=cols, + **kwargs, ) # Figure panel colorbar else: diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index df9a539f7..59de0f04c 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -12,7 +12,7 @@ import matplotlib.gridspec as mgridspec import matplotlib.transforms as mtransforms import numpy as np -from typing import List +from typing import List, Optional, Union, Tuple from functools import wraps from . import axes as paxes @@ -587,16 +587,79 @@ def _parse_panel_arg(self, side, arg): # NOTE: Convert using the lengthwise indices return slot, iratio, slice(start, stop + 1) + def _parse_panel_arg_with_span( + self, + side: str, + ax: "paxes.Axes", + span_override: Optional[Union[int, Tuple[int, int]]], + ) -> Tuple[str, int, slice]: + """ + Parse panel arg with span override. Uses ax for position, span for extent. + + Parameters + ---------- + side : str + Panel side ('left', 'right', 'top', 'bottom') + ax : Axes + The axes to position the panel relative to + span_override : int or tuple + The span extent (1-indexed like subplot numbers) + + Returns + ------- + slot : str + Panel slot identifier + iratio : int + Panel position index + span : slice + Encoded span slice for the panel extent + """ + # Get the axes position + ss = ax.get_subplotspec().get_topmost_subplotspec() + row1, row2, col1, col2 = ss._get_rows_columns() + + # Determine slot and index based on side + slot = side[0] + offset = len(ax._panel_dict[side]) + 1 + + if side in ("left", "right"): + # Panel is vertical, span controls rows + iratio = col1 - offset if side == "left" else col2 + offset + # Parse span as row specification (1-indexed input, convert to 0-indexed) + if isinstance(span_override, Integral): + span_start, span_stop = span_override - 1, span_override - 1 + else: + span_override = np.atleast_1d(span_override) + span_start, span_stop = span_override[0] - 1, span_override[-1] - 1 + else: + # Panel is horizontal, span controls columns + iratio = row1 - offset if side == "top" else row2 + offset + # Parse span as column specification (1-indexed input, convert to 0-indexed) + if isinstance(span_override, Integral): + span_start, span_stop = span_override - 1, span_override - 1 + else: + span_override = np.atleast_1d(span_override) + span_start, span_stop = span_override[0] - 1, span_override[-1] - 1 + + # Encode indices for gridspec + which = "h" if side in ("left", "right") else "w" + span_start_encoded, span_stop_encoded = self._encode_indices( + span_start, span_stop, which=which + ) + + return slot, iratio, slice(span_start_encoded, span_stop_encoded + 1) + def _insert_panel_slot( self, - side, + side: str, arg, *, - share=None, - width=None, - space=None, - pad=None, - filled=False, + share: Optional[bool] = None, + width: Optional[Union[float, str]] = None, + space: Optional[Union[float, str]] = None, + pad: Optional[Union[float, str]] = None, + filled: bool = False, + span_override: Optional[Union[int, Tuple[int, int]]] = None, ): """ Insert a panel slot into the existing gridspec. The `side` is the panel side @@ -608,7 +671,11 @@ def _insert_panel_slot( raise RuntimeError("Figure must be assigned to gridspec.") if side not in ("left", "right", "bottom", "top"): raise ValueError(f"Invalid side {side}.") - slot, idx, span = self._parse_panel_arg(side, arg) + # Use span override if provided + if span_override is not None: + slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override) + else: + slot, idx, span = self._parse_panel_arg(side, arg) pad = units(pad, "em", "in") space = units(space, "em", "in") width = units(width, "in") diff --git a/ultraplot/tests/test_colorbar.py b/ultraplot/tests/test_colorbar.py index 6781e3b81..f16a6f13a 100644 --- a/ultraplot/tests/test_colorbar.py +++ b/ultraplot/tests/test_colorbar.py @@ -571,3 +571,199 @@ def test_inset_colorbar_orientation(loc, orientation, labelloc): found = True break assert found, f"Colorbar not found for loc='{loc}' with orientation='{orientation}'" + + +def test_colorbar_span_bottom(): + """Test bottom colorbar with span parameter.""" + + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Colorbar below row 1, spanning columns 1-2 + cb = fig.colorbar(cm, ax=axs[0, :], span=(1, 2), loc="bottom") + + # Verify colorbar was created + assert cb is not None + + # Verify position (should span only columns 1-2) + pos = cb.ax.get_position() + col0_left = axs[0, 0].get_position().x0 + col1_right = axs[0, 1].get_position().x1 + assert abs(pos.x0 - col0_left) < 0.1 + assert abs(pos.x1 - col1_right) < 0.1 + + +def test_colorbar_span_top(): + """Test top colorbar with span parameter.""" + import numpy as np + + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Colorbar above row 2, spanning columns 2-3 + cb = fig.colorbar(cm, ax=axs[1, :], cols=(2, 3), loc="top") + + assert cb is not None + + +def test_colorbar_span_right(): + """Test right colorbar with rows parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Colorbar right of column 1, spanning rows 1-2 + cb = fig.colorbar(cm, ax=axs[:, 0], rows=(1, 2), loc="right") + + assert cb is not None + + +def test_colorbar_span_left(): + """Test left colorbar with rows parameter.""" + import numpy as np + + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Colorbar left of column 2, spanning rows 2-3 + cb = fig.colorbar(cm, ax=axs[:, 1], rows=(2, 3), loc="left") + + assert cb is not None + + +def test_colorbar_span_validation_left_with_cols_error(): + """Test that LEFT colorbar raises error with cols parameter.""" + + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.raises(ValueError, match="left.*vertical.*use 'rows='.*not 'cols='"): + fig.colorbar(cm, ax=axs[0, 0], cols=(1, 2), loc="left") + + +def test_colorbar_span_validation_right_with_cols_error(): + """Test that RIGHT colorbar raises error with cols parameter.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.raises(ValueError, match="right.*vertical.*use 'rows='.*not 'cols='"): + fig.colorbar(cm, ax=axs[0, 0], cols=(1, 2), loc="right") + + +def test_colorbar_span_validation_top_with_rows_error(): + """Test that TOP colorbar raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.raises(ValueError, match="top.*horizontal.*use 'cols='.*not 'rows='"): + fig.colorbar(cm, ax=axs[0, 0], rows=(1, 2), loc="top") + + +def test_colorbar_span_validation_bottom_with_rows_error(): + """Test that BOTTOM colorbar raises error with rows parameter.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.raises( + ValueError, match="bottom.*horizontal.*use 'cols='.*not 'rows='" + ): + fig.colorbar(cm, ax=axs[0, 0], rows=(1, 2), loc="bottom") + + +def test_colorbar_span_validation_left_with_span_warns(): + """Test that LEFT colorbar with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.warns(match="left.*vertical.*prefer 'rows='"): + cb = fig.colorbar(cm, ax=axs[0, 0], span=(1, 2), loc="left") + assert cb is not None + + +def test_colorbar_span_validation_right_with_span_warns(): + """Test that RIGHT colorbar with span parameter issues warning.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + with pytest.warns(match="right.*vertical.*prefer 'rows='"): + cb = fig.colorbar(cm, ax=axs[0, 0], span=(1, 2), loc="right") + assert cb is not None + + +def test_colorbar_array_without_span(): + """Test that colorbar on array without span preserves original behavior.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Should create colorbar for all axes in the array + cb = fig.colorbar(cm, ax=axs[:], loc="right") + assert cb is not None + + +def test_colorbar_array_with_span(): + """Test that colorbar on array with span uses first axis + span extent.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Should use first axis position with span extent + cb = fig.colorbar(cm, ax=axs[0, :], span=(1, 2), loc="bottom") + assert cb is not None + + # Verify it spans only columns 1-2 + pos = cb.ax.get_position() + col0_left = axs[0, 0].get_position().x0 + col1_right = axs[0, 1].get_position().x1 + assert abs(pos.x0 - col0_left) < 0.1 + assert abs(pos.x1 - col1_right) < 0.1 + + +def test_colorbar_row_without_span(): + """Test that colorbar on row without span spans entire row.""" + fig, axs = uplt.subplots(nrows=2, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Should span all 3 columns + cb = fig.colorbar(cm, ax=axs[0, :], loc="bottom") + assert cb is not None + + +def test_colorbar_column_without_span(): + """Test that colorbar on column without span spans entire column.""" + fig, axs = uplt.subplots(nrows=3, ncols=2) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Should span all 3 rows + cb = fig.colorbar(cm, ax=axs[:, 0], loc="right") + assert cb is not None + + +def test_colorbar_multiple_sides_with_span(): + """Test multiple colorbars on different sides with span control.""" + fig, axs = uplt.subplots(nrows=3, ncols=3) + data = np.random.random((10, 10)) + cm = axs[0, 0].pcolormesh(data) + + # Create colorbars on all 4 sides with different spans + cb_bottom = fig.colorbar(cm, ax=axs[0, 0], span=(1, 2), loc="bottom") + cb_top = fig.colorbar(cm, ax=axs[1, 0], span=(2, 3), loc="top") + cb_right = fig.colorbar(cm, ax=axs[0, 0], rows=(1, 2), loc="right") + cb_left = fig.colorbar(cm, ax=axs[0, 1], rows=(2, 3), loc="left") + + assert cb_bottom is not None + assert cb_top is not None + assert cb_right is not None + assert cb_left is not None