diff --git a/autolens/analysis/latent.py b/autolens/analysis/latent.py index e1c314d29..8dbad3b6e 100644 --- a/autolens/analysis/latent.py +++ b/autolens/analysis/latent.py @@ -17,6 +17,7 @@ which would kill the post-fit metric write of an otherwise-converged search). """ +import importlib import logging from typing import Callable, Dict, List, Optional @@ -35,6 +36,33 @@ # the many fit evaluations a single search performs. _MAGZERO_WARNED: set = set() +# Set to True the first time ``effective_einstein_radius`` falls back from +# the JAX path to the NumPy path because ``jax_zero_contour`` is missing. +# Deduplicates the fallback warning across the many fit evaluations a +# single search performs. +_JAX_ZERO_CONTOUR_FALLBACK_WARNED: bool = False + + +def _jax_zero_contour_available() -> bool: + """ + Return True if ``jax_zero_contour`` can be imported; False otherwise. + The first False return emits one warning per process noting that + ``effective_einstein_radius`` will use the slower NumPy path. + """ + global _JAX_ZERO_CONTOUR_FALLBACK_WARNED + try: + importlib.import_module("jax_zero_contour") + return True + except ModuleNotFoundError: + if not _JAX_ZERO_CONTOUR_FALLBACK_WARNED: + logger.warning( + "jax_zero_contour not installed; effective_einstein_radius " + "falling back to NumPy path (slower). " + "pip install jax_zero_contour to enable the JIT path." + ) + _JAX_ZERO_CONTOUR_FALLBACK_WARNED = True + return False + def _maybe_magzero_warn(magzero, name) -> bool: """ @@ -217,7 +245,7 @@ def effective_einstein_radius(fit, magzero, xp=np): try: lens_calc = LensCalc.from_mass_obj(fit.tracer) - if xp is not np: + if xp is not np and _jax_zero_contour_available(): import jax.numpy as jnp init_guess = jnp.array( [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]] diff --git a/test_autolens/analysis/test_latent.py b/test_autolens/analysis/test_latent.py index 73f6d2951..df4d254a5 100644 --- a/test_autolens/analysis/test_latent.py +++ b/test_autolens/analysis/test_latent.py @@ -1,3 +1,4 @@ +import importlib import logging from types import SimpleNamespace from unittest.mock import MagicMock @@ -283,6 +284,61 @@ def einstein_radius_jit_from(self, init_guess): assert calls["grid"] == "sentinel_grid" +def test_effective_einstein_radius_jax_path_falls_back_to_numpy_when_dep_missing( + monkeypatch, caplog +): + """ + When ``xp is not np`` but ``jax_zero_contour`` isn't installed, the + function must fall through to ``einstein_radius_from`` (the NumPy path) + instead of crashing or returning NaN — caller-side fallback yields a + real Einstein radius value. One warning is emitted per process. + """ + _latent_module._JAX_ZERO_CONTOUR_FALLBACK_WARNED = False + + real_import = importlib.import_module + + def fake_import(name, *args, **kwargs): + if name == "jax_zero_contour": + raise ModuleNotFoundError(f"No module named '{name}'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(_latent_module.importlib, "import_module", fake_import) + + calls = {} + + class _SpyLensCalc: + def einstein_radius_from(self, grid): + calls["grid"] = grid + return 5.678 + + def einstein_radius_jit_from(self, init_guess): + raise AssertionError( + "jit path must not run when jax_zero_contour is missing" + ) + + monkeypatch.setattr( + "autogalaxy.operate.lens_calc.LensCalc.from_mass_obj", + classmethod(lambda cls, tracer: _SpyLensCalc()), + ) + fit = SimpleNamespace( + tracer=object(), + dataset=SimpleNamespace(grids=SimpleNamespace(lp="sentinel_grid")), + ) + + sentinel_xp = MagicMock() # truthy `xp is not np` + with caplog.at_level(logging.WARNING, logger=_latent_module.__name__): + value = effective_einstein_radius( + fit=fit, magzero=None, xp=sentinel_xp + ) + + assert value == pytest.approx(5.678) + assert calls["grid"] == "sentinel_grid" + fallback_warnings = [ + r for r in caplog.records if "falling back to NumPy" in r.message + ] + assert len(fallback_warnings) == 1 + + def test_effective_einstein_radius_returns_nan_on_value_error(monkeypatch): def _raise(cls, tracer): raise ValueError("singular mass model")