Skip to content
60 changes: 51 additions & 9 deletions ultraplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3090,13 +3090,10 @@ def __init__(self, kwargs):
kwargs : dict-like
The source dictionary.
"""
super().__init__(kwargs)
# The colormap is initialized with all the base colormaps
# We have to change the classes internally to Perceptual, Continuous or Discrete
# such that ultraplot knows what these objects are. We piggy back on the registering mechanism
# by overriding matplotlib's behavior
for name in tuple(self._cmaps.keys()):
self.register(self._cmaps[name], name=name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we are no longer using the register method?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch. We do need it as we override the ColormapRegistery. Moved the lazy loading from file to a private function to make this distinction clearer and still rely on register

super().__init__({k.lower(): v for k, v in kwargs.items()})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why switch to k.lower() here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We perform a lower command internally later when the colormap is registered.

# The colormap is initialized with all the base colormaps.
# These are converted to ultraplot's own colormap objects
# on the fly when they are first accessed.

def _translate_deprecated(self, key):
"""
Expand Down Expand Up @@ -3188,9 +3185,42 @@ def __getitem__(self, key):

if reverse:
key = key.removesuffix("_r")

# Retrieve colormap
if self._has_item(key):
value = self._cmaps[key].copy()
value = self._cmaps[key]

# Lazy loading from file
if isinstance(value, dict) and value.get("is_lazy"):
path = value["path"]
type = value["type"]
is_default = value.get("is_default", False)
if type == "continuous":
cmap = ContinuousColormap.from_file(path, warn_on_failure=True)
elif type == "discrete":
cmap = DiscreteColormap.from_file(path, warn_on_failure=True)
else:
raise ValueError(
f"Invalid colormap type {type!r} for key {key!r} in file {path!r}. "
"Expected 'continuous' or 'discrete'."
)

if cmap:
if is_default and cmap.name.lower() in CMAPS_CYCLIC:
cmap.set_cyclic(True)
self._cmaps[key] = cmap
value = cmap
else: # failed to load
# remove from registry to avoid trying again
del self._cmaps[key]
raise KeyError(f"Failed to load colormap {key!r} from {path!r}")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous code warned on failures at least according to the keyword. This code appears to raise an error?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change it back to a warn.


# Lazy loading for builtin matplotlib cmaps
if not isinstance(value, (ContinuousColormap, DiscreteColormap)):
value = _translate_cmap(value)
self._cmaps[key] = value

value = value.copy()
else:
raise KeyError(
f"Invalid colormap or color cycle name {key!r}. Options are: "
Expand Down Expand Up @@ -3220,7 +3250,19 @@ def register(self, cmap, *, name=None, force=False):
# surpress warning if the colormap is not generate by ultraplot
if name not in self._builtin_cmaps:
print(f"Overwriting {name!r} that was already registered")
self._cmaps[name] = cmap.copy()
self._cmaps[name] = cmap.copy(name=name)

def register_lazy(self, name, path, type, is_default=False):
"""
Register a colormap to be loaded lazily from a file.
"""
name = self._translate_key(name, mirror=False)
self._cmaps[name] = {
"path": path,
"type": type,
"is_default": is_default,
"is_lazy": True,
}


# Initialize databases
Expand Down
18 changes: 8 additions & 10 deletions ultraplot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,11 @@ def register_cmaps(*args, user=None, local=None, default=False):
for i, path in _iter_data_objects(
"cmaps", *paths, user=user, local=local, default=default
):
cmap = pcolors.ContinuousColormap.from_file(path, warn_on_failure=True)
if not cmap:
continue
if i == 0 and cmap.name.lower() in pcolors.CMAPS_CYCLIC:
cmap.set_cyclic(True)
pcolors._cmap_database.register(cmap, name=cmap.name)
name, ext = os.path.splitext(os.path.basename(path))
if ext and ext[1:] in ("json", "txt", "rgb", "xml", "hex"):
pcolors._cmap_database.register_lazy(
name, path, "continuous", is_default=(i == 0)
)


@docstring._snippet_manager
Expand Down Expand Up @@ -541,10 +540,9 @@ def register_cycles(*args, user=None, local=None, default=False):
for _, path in _iter_data_objects(
"cycles", *paths, user=user, local=local, default=default
):
cmap = pcolors.DiscreteColormap.from_file(path, warn_on_failure=True)
if not cmap:
continue
pcolors._cmap_database.register(cmap, name=cmap.name)
name, ext = os.path.splitext(os.path.basename(path))
if ext and ext[1:] in ("hex", "rgb", "txt"):
pcolors._cmap_database.register_lazy(name, path, "discrete")


@docstring._snippet_manager
Expand Down
129 changes: 129 additions & 0 deletions ultraplot/tests/test_colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os
import pytest
import numpy as np
import matplotlib.colors as mcolors

from ultraplot import colors as pcolors
from ultraplot import config


@pytest.fixture(autouse=True)
def setup_teardown():
"""
Reset the colormap database before and after each test.
"""
# This ensures a clean state for each test.
# The singleton instance is replaced with a new one.
pcolors._cmap_database = pcolors._init_cmap_database()
config.register_cmaps(default=True)
config.register_cycles(default=True)
yield


def test_lazy_loading_builtin():
"""
Test that built-in colormaps are lazy-loaded.
"""
# Before access, it should be a matplotlib colormap
cmap_raw = pcolors._cmap_database._cmaps["viridis"]
assert isinstance(
cmap_raw,
(
pcolors.ContinuousColormap,
pcolors.DiscreteColormap,
mcolors.ListedColormap,
),
)

# After access, it should be an ultraplot colormap
cmap_get = pcolors._cmap_database.get_cmap("viridis")
assert isinstance(cmap_get, pcolors.ContinuousColormap)

# The internal representation should also be updated
cmap_raw_after = pcolors._cmap_database._cmaps["viridis"]
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)


def test_case_insensitivity():
"""
Test that colormap lookup is case-insensitive.
"""
cmap1 = pcolors._cmap_database.get_cmap("ViRiDiS")
cmap2 = pcolors._cmap_database.get_cmap("viridis")
assert cmap1.name.lower().startswith("_viridis")
assert cmap2.name.lower().startswith("_viridis")


def test_reversed_shifted():
"""
Test reversed and shifted colormaps.
"""
# Create a simple colormap to test the reversal logic
# This avoids dependency on the exact definition of 'viridis' in matplotlib
colors_list = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] # Red, Green, Blue
test_cmap = pcolors.ContinuousColormap.from_list("test_cmap", colors_list)
pcolors._cmap_database.register(test_cmap)

cmap = pcolors._cmap_database.get_cmap("test_cmap")
cmap_r = pcolors._cmap_database.get_cmap("test_cmap_r")

# Check name
assert cmap_r.name == "_test_cmap_copy_r"
# Check colors
# Start of original should be end of reversed
assert np.allclose(cmap(0.0), cmap_r(1.0))
# End of original should be start of reversed
assert np.allclose(cmap(1.0), cmap_r(0.0))
# Middle should be the same
assert np.allclose(cmap(0.5)[:3], cmap_r(0.5)[:3][::-1])


def test_grays_translation():
"""
Test that 'Grays' is translated to 'greys'.
"""
cmap_grays = pcolors._cmap_database.get_cmap("Grays")
assert cmap_grays.name.lower().startswith("_greys")


def test_lazy_loading_file(tmp_path):
"""
Test that colormaps from files are lazy-loaded.
"""
# Create a dummy colormap file
cmap_data = "1, 0, 0\n0, 1, 0\n0, 0, 1"
cmap_file = tmp_path / "my_test_cmap.rgb"
cmap_file.write_text(cmap_data)

# Register it lazily
pcolors._cmap_database.register_lazy("my_test_cmap", str(cmap_file), "continuous")

# Before access, it should be a lazy-load dict
cmap_raw = pcolors._cmap_database._cmaps["my_test_cmap"]
assert isinstance(cmap_raw, dict)
assert cmap_raw["is_lazy"]

# After access, it should be an ultraplot colormap
cmap_get = pcolors._cmap_database.get_cmap("my_test_cmap")
assert isinstance(cmap_get, pcolors.ContinuousColormap)
assert cmap_get.name.lower().startswith("_my_test_cmap")

# The internal representation should also be updated
cmap_raw_after = pcolors._cmap_database._cmaps["my_test_cmap"]
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)


def test_register_new():
"""
Test registering a new colormap.
"""
colors_list = [(0, 0, 0), (1, 1, 1)]
new_cmap = pcolors.DiscreteColormap(colors_list, name="my_new_cmap")
pcolors._cmap_database.register(new_cmap)

# Check it was registered
cmap_get = pcolors._cmap_database.get_cmap("my_new_cmap")
assert cmap_get.name.lower().startswith(
"_my_new_cmap"
), f"Received {cmap_get.name.lower()} expected _my_new_cmap"
assert len(cmap_get.colors) == 2