Skip to content
Merged
20 changes: 12 additions & 8 deletions ultraplot/axes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import matplotlib.colors as mcolors
import matplotlib.container as mcontainer
import matplotlib.contour as mcontour
import matplotlib.legend as mlegend
import matplotlib.offsetbox as moffsetbox
import matplotlib.patches as mpatches
import matplotlib.projections as mproj
Expand All @@ -35,6 +34,7 @@
from matplotlib import cbook
from packaging import version

from .. import legend as plegend
from .. import colors as pcolors
from .. import constructor
from .. import ticker as pticker
Expand Down Expand Up @@ -925,7 +925,7 @@ def _add_queued_guides(self):
# _parse_legend_args to search for everything. Ensure None if empty.
for (loc, align), legend in tuple(self._legend_dict.items()):
if not isinstance(legend, tuple) or any(
isinstance(_, mlegend.Legend) for _ in legend
isinstance(_, plegend.Legend) for _ in legend
): # noqa: E501
continue
handles, labels, kwargs = legend
Expand Down Expand Up @@ -1131,7 +1131,11 @@ def _add_colorbar(
kwargs.update({"label": label, "length": length, "width": width})
extendsize = _not_none(extendsize, rc["colorbar.insetextend"])
cax, kwargs = self._parse_colorbar_inset(
loc=loc, labelloc=labelloc, labelrotation = labelrotation, pad=pad, **kwargs
loc=loc,
labelloc=labelloc,
labelrotation=labelrotation,
pad=pad,
**kwargs,
) # noqa: E501

# Parse the colorbar mappable
Expand Down Expand Up @@ -1681,14 +1685,14 @@ def _get_legend_handles(self, handler_map=None):
else: # this is a figure-wide legend
axs = list(self.figure._iter_axes(hidden=False, children=True))
handles = []
handler_map_full = mlegend.Legend.get_default_handler_map()
handler_map_full = plegend.Legend.get_default_handler_map()
handler_map_full = handler_map_full.copy()
handler_map_full.update(handler_map or {})
for ax in axs:
for attr in ("lines", "patches", "collections", "containers"):
for handle in getattr(ax, attr, []): # guard against API changes
label = handle.get_label()
handler = mlegend.Legend.get_legend_handler(
handler = plegend.Legend.get_legend_handler(
handler_map_full, handle
) # noqa: E501
if handler and label and label[0] != "_":
Expand Down Expand Up @@ -1802,7 +1806,7 @@ def _register_guide(self, guide, obj, key, **kwargs):
# Replace with instance or update the queue
# NOTE: This is valid for both mappable-values pairs and handles-labels pairs
if not isinstance(obj, tuple) or any(
isinstance(_, mlegend.Legend) for _ in obj
isinstance(_, plegend.Legend) for _ in obj
): # noqa: E501
dict_[key] = obj
else:
Expand Down Expand Up @@ -2273,7 +2277,7 @@ def _parse_legend_aligned(self, pairs, ncol=None, order=None, **kwargs):
# NOTE: Permit drawing empty legend to catch edge cases
pairs = [pair for pair in array.flat if isinstance(pair, tuple)]
args = tuple(zip(*pairs)) or ([], [])
return mlegend.Legend(self, *args, ncol=ncol, **kwargs)
return plegend.Legend(self, *args, ncol=ncol, **kwargs)

def _parse_legend_centered(
self,
Expand Down Expand Up @@ -2328,7 +2332,7 @@ def _parse_legend_centered(
base, offset = 0.5, 0.5 * (len(pairs) - extra)
y0, y1 = base + (offset - np.array([i + 1, i])) * height
bb = mtransforms.Bbox([[0, y0], [1, y1]])
leg = mlegend.Legend(
leg = plegend.Legend(
self,
*zip(*ipairs),
bbox_to_anchor=bb,
Expand Down
33 changes: 33 additions & 0 deletions ultraplot/legend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from matplotlib import legend as mlegend

try:
from typing import override
except ImportError:
from typing_extensions import override


class Legend(mlegend.Legend):
# Soft wrapper of matplotlib legend's class.
# Currently we only override the syncing of the location.
# The user may change the location and the legend_dict should
# be updated accordingly. This caused an issue where
# a legend format was not behaving according to the docs
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@override
def set_loc(self, loc=None):
# Sync location setting with the move
old_loc = None
if self.axes is not None:
# Get old location which is a tuple of location and
# legend type
for k, v in self.axes._legend_dict.items():
if v is self:
old_loc = k
break
super().set_loc(loc)
if old_loc is not None:
value = self.axes._legend_dict.pop(old_loc, None)
where, type = old_loc
self.axes._legend_dict[(loc, type)] = value
22 changes: 22 additions & 0 deletions ultraplot/tests/test_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,25 @@ def test_legend_col_spacing(rng):
with pytest.raises(ValueError):
ax.legend(loc="bottom", ncol=3, columnspacing="15x")
return fig


def test_sync_label_dict(rng):
"""
Legends are held within _legend_dict for which the key is a tuple of location and alignment.

We need to ensure that the legend is updated in the dictionary when its location is changed.
"""
data = rng.random((2, 100))
fig, ax = uplt.subplots()
ax.plot(*data, label="test")
leg = ax.legend(loc="lower right")
assert ("lower right", "center") in ax[0]._legend_dict, "Legend not found in dict"
leg.set_loc("upper left")
assert ("upper left", "center") in ax[
0
]._legend_dict, "Legend not found in dict after update"
assert leg is ax[0]._legend_dict[("upper left", "center")]
assert ("lower right", "center") not in ax[
0
]._legend_dict, "Old legend not removed from dict"
uplt.close(fig)