Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion autolens/analysis/latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
which would kill the post-fit metric write of an otherwise-converged
search).
"""
import importlib
import logging
from typing import Callable, Dict, List, Optional

Expand All @@ -35,6 +36,33 @@
# the many fit evaluations a single search performs.
_MAGZERO_WARNED: set = set()

# Set to True the first time ``effective_einstein_radius`` falls back from
# the JAX path to the NumPy path because ``jax_zero_contour`` is missing.
# Deduplicates the fallback warning across the many fit evaluations a
# single search performs.
_JAX_ZERO_CONTOUR_FALLBACK_WARNED: bool = False


def _jax_zero_contour_available() -> bool:
"""
Return True if ``jax_zero_contour`` can be imported; False otherwise.
The first False return emits one warning per process noting that
``effective_einstein_radius`` will use the slower NumPy path.
"""
global _JAX_ZERO_CONTOUR_FALLBACK_WARNED
try:
importlib.import_module("jax_zero_contour")
return True
except ModuleNotFoundError:
if not _JAX_ZERO_CONTOUR_FALLBACK_WARNED:
logger.warning(
"jax_zero_contour not installed; effective_einstein_radius "
"falling back to NumPy path (slower). "
"pip install jax_zero_contour to enable the JIT path."
)
_JAX_ZERO_CONTOUR_FALLBACK_WARNED = True
return False


def _maybe_magzero_warn(magzero, name) -> bool:
"""
Expand Down Expand Up @@ -217,7 +245,7 @@ def effective_einstein_radius(fit, magzero, xp=np):

try:
lens_calc = LensCalc.from_mass_obj(fit.tracer)
if xp is not np:
if xp is not np and _jax_zero_contour_available():
import jax.numpy as jnp
init_guess = jnp.array(
[[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
Expand Down
56 changes: 56 additions & 0 deletions test_autolens/analysis/test_latent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import logging
from types import SimpleNamespace
from unittest.mock import MagicMock
Expand Down Expand Up @@ -283,6 +284,61 @@ def einstein_radius_jit_from(self, init_guess):
assert calls["grid"] == "sentinel_grid"


def test_effective_einstein_radius_jax_path_falls_back_to_numpy_when_dep_missing(
monkeypatch, caplog
):
"""
When ``xp is not np`` but ``jax_zero_contour`` isn't installed, the
function must fall through to ``einstein_radius_from`` (the NumPy path)
instead of crashing or returning NaN — caller-side fallback yields a
real Einstein radius value. One warning is emitted per process.
"""
_latent_module._JAX_ZERO_CONTOUR_FALLBACK_WARNED = False

real_import = importlib.import_module

def fake_import(name, *args, **kwargs):
if name == "jax_zero_contour":
raise ModuleNotFoundError(f"No module named '{name}'")
return real_import(name, *args, **kwargs)

monkeypatch.setattr(_latent_module.importlib, "import_module", fake_import)

calls = {}

class _SpyLensCalc:
def einstein_radius_from(self, grid):
calls["grid"] = grid
return 5.678

def einstein_radius_jit_from(self, init_guess):
raise AssertionError(
"jit path must not run when jax_zero_contour is missing"
)

monkeypatch.setattr(
"autogalaxy.operate.lens_calc.LensCalc.from_mass_obj",
classmethod(lambda cls, tracer: _SpyLensCalc()),
)
fit = SimpleNamespace(
tracer=object(),
dataset=SimpleNamespace(grids=SimpleNamespace(lp="sentinel_grid")),
)

sentinel_xp = MagicMock() # truthy `xp is not np`
with caplog.at_level(logging.WARNING, logger=_latent_module.__name__):
value = effective_einstein_radius(
fit=fit, magzero=None, xp=sentinel_xp
)

assert value == pytest.approx(5.678)
assert calls["grid"] == "sentinel_grid"
fallback_warnings = [
r for r in caplog.records if "falling back to NumPy" in r.message
]
assert len(fallback_warnings) == 1


def test_effective_einstein_radius_returns_nan_on_value_error(monkeypatch):
def _raise(cls, tracer):
raise ValueError("singular mass model")
Expand Down
Loading