Skip to content
75 changes: 66 additions & 9 deletions ultraplot/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@
)
from .utils import set_alpha, to_hex, to_rgb, to_rgba, to_xyz, to_xyza

try:
from typing import override
except:
from typing_extensions import override

__all__ = [
"DiscreteColormap",
"ContinuousColormap",
Expand Down Expand Up @@ -3090,13 +3095,10 @@ def __init__(self, kwargs):
kwargs : dict-like
The source dictionary.
"""
super().__init__(kwargs)
# The colormap is initialized with all the base colormaps
# We have to change the classes internally to Perceptual, Continuous or Discrete
# such that ultraplot knows what these objects are. We piggy back on the registering mechanism
# by overriding matplotlib's behavior
for name in tuple(self._cmaps.keys()):
self.register(self._cmaps[name], name=name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we are no longer using the register method?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch. We do need it as we override the ColormapRegistery. Moved the lazy loading from file to a private function to make this distinction clearer and still rely on register

super().__init__({k.lower(): v for k, v in kwargs.items()})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why switch to k.lower() here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We perform a lower command internally later when the colormap is registered.

# The colormap is initialized with all the base colormaps.
# These are converted to ultraplot's own colormap objects
# on the fly when they are first accessed.

def _translate_deprecated(self, key):
"""
Expand Down Expand Up @@ -3171,6 +3173,34 @@ def _translate_key(self, original_key, mirror=True):
def _has_item(self, key):
return key in self._cmaps

def _load_and_register_cmap(self, key, value):
"""
Load a colormap from a file and register it.
"""
path = value["path"]
type = value["type"]
is_default = value.get("is_default", False)
if type == "continuous":
cmap = ContinuousColormap.from_file(path, warn_on_failure=True)
elif type == "discrete":
cmap = DiscreteColormap.from_file(path, warn_on_failure=True)
else:
raise ValueError(
f"Invalid colormap type {type!r} for key {key!r} in file {path!r}. "
"Expected 'continuous' or 'discrete'."
)

if cmap:
if is_default and cmap.name.lower() in CMAPS_CYCLIC:
cmap.set_cyclic(True)
self.register(cmap, name=key)
return self._cmaps[key]
else: # failed to load
# remove from registry to avoid trying again
del self._cmaps[key]
warnings._warn_ultraplot(f"Failed to load colormap {key!r} from {path!r}")
return None

def get_cmap(self, cmap):
return self.__getitem__(cmap)

Expand All @@ -3188,9 +3218,23 @@ def __getitem__(self, key):

if reverse:
key = key.removesuffix("_r")

# Retrieve colormap
if self._has_item(key):
value = self._cmaps[key].copy()
value = self._cmaps[key]

# Lazy loading from file
if isinstance(value, dict) and value.get("is_lazy"):
value = self._load_and_register_cmap(key, value)
if not value:
raise KeyError(f"Failed to load colormap {key!r} from file.")

# Lazy loading for builtin matplotlib cmaps
if not isinstance(value, (ContinuousColormap, DiscreteColormap)):
value = _translate_cmap(value)
self._cmaps[key] = value

value = value.copy()
else:
raise KeyError(
f"Invalid colormap or color cycle name {key!r}. Options are: "
Expand All @@ -3204,6 +3248,7 @@ def __getitem__(self, key):
value = value.shifted(180)
return value

@override
def register(self, cmap, *, name=None, force=False):
"""
Add the colormap after validating and converting.
Expand All @@ -3220,7 +3265,19 @@ def register(self, cmap, *, name=None, force=False):
# surpress warning if the colormap is not generate by ultraplot
if name not in self._builtin_cmaps:
print(f"Overwriting {name!r} that was already registered")
self._cmaps[name] = cmap.copy()
self._cmaps[name] = cmap.copy(name=name)

def register_lazy(self, name, path, type, is_default=False):
"""
Register a colormap to be loaded lazily from a file.
"""
name = self._translate_key(name, mirror=False)
self._cmaps[name] = {
"path": path,
"type": type,
"is_default": is_default,
"is_lazy": True,
}


# Initialize databases
Expand Down
18 changes: 8 additions & 10 deletions ultraplot/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,12 +497,11 @@ def register_cmaps(*args, user=None, local=None, default=False):
for i, path in _iter_data_objects(
"cmaps", *paths, user=user, local=local, default=default
):
cmap = pcolors.ContinuousColormap.from_file(path, warn_on_failure=True)
if not cmap:
continue
if i == 0 and cmap.name.lower() in pcolors.CMAPS_CYCLIC:
cmap.set_cyclic(True)
pcolors._cmap_database.register(cmap, name=cmap.name)
name, ext = os.path.splitext(os.path.basename(path))
if ext and ext[1:] in ("json", "txt", "rgb", "xml", "hex"):
pcolors._cmap_database.register_lazy(
name, path, "continuous", is_default=(i == 0)
)


@docstring._snippet_manager
Expand Down Expand Up @@ -541,10 +540,9 @@ def register_cycles(*args, user=None, local=None, default=False):
for _, path in _iter_data_objects(
"cycles", *paths, user=user, local=local, default=default
):
cmap = pcolors.DiscreteColormap.from_file(path, warn_on_failure=True)
if not cmap:
continue
pcolors._cmap_database.register(cmap, name=cmap.name)
name, ext = os.path.splitext(os.path.basename(path))
if ext and ext[1:] in ("hex", "rgb", "txt"):
pcolors._cmap_database.register_lazy(name, path, "discrete")


@docstring._snippet_manager
Expand Down
129 changes: 129 additions & 0 deletions ultraplot/tests/test_colors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import os
import pytest
import numpy as np
import matplotlib.colors as mcolors

from ultraplot import colors as pcolors
from ultraplot import config


@pytest.fixture(autouse=True)
def setup_teardown():
"""
Reset the colormap database before and after each test.
"""
# This ensures a clean state for each test.
# The singleton instance is replaced with a new one.
pcolors._cmap_database = pcolors._init_cmap_database()
config.register_cmaps(default=True)
config.register_cycles(default=True)
yield


def test_lazy_loading_builtin():
"""
Test that built-in colormaps are lazy-loaded.
"""
# Before access, it should be a matplotlib colormap
cmap_raw = pcolors._cmap_database._cmaps["viridis"]
assert isinstance(
cmap_raw,
(
pcolors.ContinuousColormap,
pcolors.DiscreteColormap,
mcolors.ListedColormap,
),
)

# After access, it should be an ultraplot colormap
cmap_get = pcolors._cmap_database.get_cmap("viridis")
assert isinstance(cmap_get, pcolors.ContinuousColormap)

# The internal representation should also be updated
cmap_raw_after = pcolors._cmap_database._cmaps["viridis"]
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)


def test_case_insensitivity():
"""
Test that colormap lookup is case-insensitive.
"""
cmap1 = pcolors._cmap_database.get_cmap("ViRiDiS")
cmap2 = pcolors._cmap_database.get_cmap("viridis")
assert cmap1.name.lower().startswith("_viridis")
assert cmap2.name.lower().startswith("_viridis")


def test_reversed_shifted():
"""
Test reversed and shifted colormaps.
"""
# Create a simple colormap to test the reversal logic
# This avoids dependency on the exact definition of 'viridis' in matplotlib
colors_list = [(1, 0, 0), (0, 1, 0), (0, 0, 1)] # Red, Green, Blue
test_cmap = pcolors.ContinuousColormap.from_list("test_cmap", colors_list)
pcolors._cmap_database.register(test_cmap)

cmap = pcolors._cmap_database.get_cmap("test_cmap")
cmap_r = pcolors._cmap_database.get_cmap("test_cmap_r")

# Check name
assert cmap_r.name == "_test_cmap_copy_r"
# Check colors
# Start of original should be end of reversed
assert np.allclose(cmap(0.0), cmap_r(1.0))
# End of original should be start of reversed
assert np.allclose(cmap(1.0), cmap_r(0.0))
# Middle should be the same
assert np.allclose(cmap(0.5)[:3], cmap_r(0.5)[:3][::-1])


def test_grays_translation():
"""
Test that 'Grays' is translated to 'greys'.
"""
cmap_grays = pcolors._cmap_database.get_cmap("Grays")
assert cmap_grays.name.lower().startswith("_greys")


def test_lazy_loading_file(tmp_path):
"""
Test that colormaps from files are lazy-loaded.
"""
# Create a dummy colormap file
cmap_data = "1, 0, 0\n0, 1, 0\n0, 0, 1"
cmap_file = tmp_path / "my_test_cmap.rgb"
cmap_file.write_text(cmap_data)

# Register it lazily
pcolors._cmap_database.register_lazy("my_test_cmap", str(cmap_file), "continuous")

# Before access, it should be a lazy-load dict
cmap_raw = pcolors._cmap_database._cmaps["my_test_cmap"]
assert isinstance(cmap_raw, dict)
assert cmap_raw["is_lazy"]

# After access, it should be an ultraplot colormap
cmap_get = pcolors._cmap_database.get_cmap("my_test_cmap")
assert isinstance(cmap_get, pcolors.ContinuousColormap)
assert cmap_get.name.lower().startswith("_my_test_cmap")

# The internal representation should also be updated
cmap_raw_after = pcolors._cmap_database._cmaps["my_test_cmap"]
assert isinstance(cmap_raw_after, pcolors.ContinuousColormap)


def test_register_new():
"""
Test registering a new colormap.
"""
colors_list = [(0, 0, 0), (1, 1, 1)]
new_cmap = pcolors.DiscreteColormap(colors_list, name="my_new_cmap")
pcolors._cmap_database.register(new_cmap)

# Check it was registered
cmap_get = pcolors._cmap_database.get_cmap("my_new_cmap")
assert cmap_get.name.lower().startswith(
"_my_new_cmap"
), f"Received {cmap_get.name.lower()} expected _my_new_cmap"
assert len(cmap_get.colors) == 2