Skip to content

Commit 6ab7434

Browse files
authored
Lazy loading colormaps (#343)
1 parent 03c655a commit 6ab7434

File tree

3 files changed

+203
-19
lines changed

3 files changed

+203
-19
lines changed

ultraplot/colors.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@
4141
)
4242
from .utils import set_alpha, to_hex, to_rgb, to_rgba, to_xyz, to_xyza
4343

44+
try:
45+
from typing import override
46+
except:
47+
from typing_extensions import override
48+
4449
__all__ = [
4550
"DiscreteColormap",
4651
"ContinuousColormap",
@@ -3090,13 +3095,10 @@ def __init__(self, kwargs):
30903095
kwargs : dict-like
30913096
The source dictionary.
30923097
"""
3093-
super().__init__(kwargs)
3094-
# The colormap is initialized with all the base colormaps
3095-
# We have to change the classes internally to Perceptual, Continuous or Discrete
3096-
# such that ultraplot knows what these objects are. We piggy back on the registering mechanism
3097-
# by overriding matplotlib's behavior
3098-
for name in tuple(self._cmaps.keys()):
3099-
self.register(self._cmaps[name], name=name)
3098+
super().__init__({k.lower(): v for k, v in kwargs.items()})
3099+
# The colormap is initialized with all the base colormaps.
3100+
# These are converted to ultraplot's own colormap objects
3101+
# on the fly when they are first accessed.
31003102

31013103
def _translate_deprecated(self, key):
31023104
"""
@@ -3171,6 +3173,34 @@ def _translate_key(self, original_key, mirror=True):
31713173
def _has_item(self, key):
31723174
return key in self._cmaps
31733175

3176+
def _load_and_register_cmap(self, key, value):
3177+
"""
3178+
Load a colormap from a file and register it.
3179+
"""
3180+
path = value["path"]
3181+
type = value["type"]
3182+
is_default = value.get("is_default", False)
3183+
if type == "continuous":
3184+
cmap = ContinuousColormap.from_file(path, warn_on_failure=True)
3185+
elif type == "discrete":
3186+
cmap = DiscreteColormap.from_file(path, warn_on_failure=True)
3187+
else:
3188+
raise ValueError(
3189+
f"Invalid colormap type {type!r} for key {key!r} in file {path!r}. "
3190+
"Expected 'continuous' or 'discrete'."
3191+
)
3192+
3193+
if cmap:
3194+
if is_default and cmap.name.lower() in CMAPS_CYCLIC:
3195+
cmap.set_cyclic(True)
3196+
self.register(cmap, name=key)
3197+
return self._cmaps[key]
3198+
else: # failed to load
3199+
# remove from registry to avoid trying again
3200+
del self._cmaps[key]
3201+
warnings._warn_ultraplot(f"Failed to load colormap {key!r} from {path!r}")
3202+
return None
3203+
31743204
def get_cmap(self, cmap):
31753205
return self.__getitem__(cmap)
31763206

@@ -3188,9 +3218,23 @@ def __getitem__(self, key):
31883218

31893219
if reverse:
31903220
key = key.removesuffix("_r")
3221+
31913222
# Retrieve colormap
31923223
if self._has_item(key):
3193-
value = self._cmaps[key].copy()
3224+
value = self._cmaps[key]
3225+
3226+
# Lazy loading from file
3227+
if isinstance(value, dict) and value.get("is_lazy"):
3228+
value = self._load_and_register_cmap(key, value)
3229+
if not value:
3230+
raise KeyError(f"Failed to load colormap {key!r} from file.")
3231+
3232+
# Lazy loading for builtin matplotlib cmaps
3233+
if not isinstance(value, (ContinuousColormap, DiscreteColormap)):
3234+
value = _translate_cmap(value)
3235+
self._cmaps[key] = value
3236+
3237+
value = value.copy()
31943238
else:
31953239
raise KeyError(
31963240
f"Invalid colormap or color cycle name {key!r}. Options are: "
@@ -3204,6 +3248,7 @@ def __getitem__(self, key):
32043248
value = value.shifted(180)
32053249
return value
32063250

3251+
@override
32073252
def register(self, cmap, *, name=None, force=False):
32083253
"""
32093254
Add the colormap after validating and converting.
@@ -3220,7 +3265,19 @@ def register(self, cmap, *, name=None, force=False):
32203265
# surpress warning if the colormap is not generate by ultraplot
32213266
if name not in self._builtin_cmaps:
32223267
print(f"Overwriting {name!r} that was already registered")
3223-
self._cmaps[name] = cmap.copy()
3268+
self._cmaps[name] = cmap.copy(name=name)
3269+
3270+
def register_lazy(self, name, path, type, is_default=False):
3271+
"""
3272+
Register a colormap to be loaded lazily from a file.
3273+
"""
3274+
name = self._translate_key(name, mirror=False)
3275+
self._cmaps[name] = {
3276+
"path": path,
3277+
"type": type,
3278+
"is_default": is_default,
3279+
"is_lazy": True,
3280+
}
32243281

32253282

32263283
# Initialize databases

ultraplot/config.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -497,12 +497,11 @@ def register_cmaps(*args, user=None, local=None, default=False):
497497
for i, path in _iter_data_objects(
498498
"cmaps", *paths, user=user, local=local, default=default
499499
):
500-
cmap = pcolors.ContinuousColormap.from_file(path, warn_on_failure=True)
501-
if not cmap:
502-
continue
503-
if i == 0 and cmap.name.lower() in pcolors.CMAPS_CYCLIC:
504-
cmap.set_cyclic(True)
505-
pcolors._cmap_database.register(cmap, name=cmap.name)
500+
name, ext = os.path.splitext(os.path.basename(path))
501+
if ext and ext[1:] in ("json", "txt", "rgb", "xml", "hex"):
502+
pcolors._cmap_database.register_lazy(
503+
name, path, "continuous", is_default=(i == 0)
504+
)
506505

507506

508507
@docstring._snippet_manager
@@ -541,10 +540,9 @@ def register_cycles(*args, user=None, local=None, default=False):
541540
for _, path in _iter_data_objects(
542541
"cycles", *paths, user=user, local=local, default=default
543542
):
544-
cmap = pcolors.DiscreteColormap.from_file(path, warn_on_failure=True)
545-
if not cmap:
546-
continue
547-
pcolors._cmap_database.register(cmap, name=cmap.name)
543+
name, ext = os.path.splitext(os.path.basename(path))
544+
if ext and ext[1:] in ("hex", "rgb", "txt"):
545+
pcolors._cmap_database.register_lazy(name, path, "discrete")
548546

549547

550548
@docstring._snippet_manager

ultraplot/tests/test_colors.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os
2+
import pytest
3+
import numpy as np
4+
import matplotlib.colors as mcolors
5+
6+
from ultraplot import colors as pcolors
7+
from ultraplot import config
8+
9+
10+
@pytest.fixture(autouse=True)
11+
def setup_teardown():
12+
"""
13+
Reset the colormap database before and after each test.
14+
"""
15+
# This ensures a clean state for each test.
16+
# The singleton instance is replaced with a new one.
17+
pcolors._cmap_database = pcolors._init_cmap_database()
18+
config.register_cmaps(default=True)
19+
config.register_cycles(default=True)
20+
yield
21+
22+
23+
def test_lazy_loading_builtin():
24+
"""
25+
Test that built-in colormaps are lazy-loaded.
26+
"""
27+
# Before access, it should be a matplotlib colormap
28+
cmap_raw = pcolors._cmap_database._cmaps["viridis"]
29+
assert isinstance(
30+
cmap_raw,
31+
(
32+
pcolors.ContinuousColormap,
33+
pcolors.DiscreteColormap,
34+
mcolors.ListedColormap,
35+
),
36+
)
37+
38+
# After access, it should be an ultraplot colormap
39+
cmap_get = pcolors._cmap_database.get_cmap("viridis")
40+
assert isinstance(cmap_get, pcolors.ContinuousColormap)
41+
42+
# The internal representation should also be updated
43+
cmap_raw_after = pcolors._cmap_database._cmaps["viridis"]
44+
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)
45+
46+
47+
def test_case_insensitivity():
48+
"""
49+
Test that colormap lookup is case-insensitive.
50+
"""
51+
cmap1 = pcolors._cmap_database.get_cmap("ViRiDiS")
52+
cmap2 = pcolors._cmap_database.get_cmap("viridis")
53+
assert cmap1.name.lower().startswith("_viridis")
54+
assert cmap2.name.lower().startswith("_viridis")
55+
56+
57+
def test_reversed_shifted():
58+
"""
59+
Test reversed and shifted colormaps.
60+
"""
61+
# Create a simple colormap to test the reversal logic
62+
# This avoids dependency on the exact definition of 'viridis' in matplotlib
63+
colors_list = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] # Red, Green, Blue
64+
test_cmap = pcolors.ContinuousColormap.from_list("test_cmap", colors_list)
65+
pcolors._cmap_database.register(test_cmap)
66+
67+
cmap = pcolors._cmap_database.get_cmap("test_cmap")
68+
cmap_r = pcolors._cmap_database.get_cmap("test_cmap_r")
69+
70+
# Check name
71+
assert cmap_r.name == "_test_cmap_copy_r"
72+
# Check colors
73+
# Start of original should be end of reversed
74+
assert np.allclose(cmap(0.0), cmap_r(1.0))
75+
# End of original should be start of reversed
76+
assert np.allclose(cmap(1.0), cmap_r(0.0))
77+
# Middle should be the same
78+
assert np.allclose(cmap(0.5)[:3], cmap_r(0.5)[:3][::-1])
79+
80+
81+
def test_grays_translation():
82+
"""
83+
Test that 'Grays' is translated to 'greys'.
84+
"""
85+
cmap_grays = pcolors._cmap_database.get_cmap("Grays")
86+
assert cmap_grays.name.lower().startswith("_greys")
87+
88+
89+
def test_lazy_loading_file(tmp_path):
90+
"""
91+
Test that colormaps from files are lazy-loaded.
92+
"""
93+
# Create a dummy colormap file
94+
cmap_data = "1, 0, 0\n0, 1, 0\n0, 0, 1"
95+
cmap_file = tmp_path / "my_test_cmap.rgb"
96+
cmap_file.write_text(cmap_data)
97+
98+
# Register it lazily
99+
pcolors._cmap_database.register_lazy("my_test_cmap", str(cmap_file), "continuous")
100+
101+
# Before access, it should be a lazy-load dict
102+
cmap_raw = pcolors._cmap_database._cmaps["my_test_cmap"]
103+
assert isinstance(cmap_raw, dict)
104+
assert cmap_raw["is_lazy"]
105+
106+
# After access, it should be an ultraplot colormap
107+
cmap_get = pcolors._cmap_database.get_cmap("my_test_cmap")
108+
assert isinstance(cmap_get, pcolors.ContinuousColormap)
109+
assert cmap_get.name.lower().startswith("_my_test_cmap")
110+
111+
# The internal representation should also be updated
112+
cmap_raw_after = pcolors._cmap_database._cmaps["my_test_cmap"]
113+
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)
114+
115+
116+
def test_register_new():
117+
"""
118+
Test registering a new colormap.
119+
"""
120+
colors_list = [(0, 0, 0), (1, 1, 1)]
121+
new_cmap = pcolors.DiscreteColormap(colors_list, name="my_new_cmap")
122+
pcolors._cmap_database.register(new_cmap)
123+
124+
# Check it was registered
125+
cmap_get = pcolors._cmap_database.get_cmap("my_new_cmap")
126+
assert cmap_get.name.lower().startswith(
127+
"_my_new_cmap"
128+
), f"Received {cmap_get.name.lower()} expected _my_new_cmap"
129+
assert len(cmap_get.colors) == 2

0 commit comments

Comments
 (0)