diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 697a0f14..b612b5b4 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -17,6 +17,7 @@ The class is constructed with `LensCalc.from_mass_obj(mass)` or `LensCalc.from_tracer(tracer)`. """ from functools import wraps +import importlib import logging import numpy as np from typing import List, Tuple, Union @@ -29,6 +30,32 @@ logger = logging.getLogger(__name__) +_OPTIONAL_DEP_WARNED: set = set() + + +def _maybe_optional_dep_warn(import_name: str, feature_name: str) -> bool: + """ + Return True (and warn once per process for ``feature_name``) if the + optional dependency ``import_name`` is not installed; False otherwise. + + Callers that get True must early-return a soft-fail value (NaN, empty + list, etc.) — a search-killing raise here would discard the post-fit + metric write of an otherwise-converged search. Mirrors PyAutoLens's + ``_maybe_magzero_warn`` for the same reason. + """ + try: + importlib.import_module(import_name) + return False + except ModuleNotFoundError: + if feature_name not in _OPTIONAL_DEP_WARNED: + logger.warning( + "Optional dependency '%s' not installed; '%s' returning " + "NaN/empty. pip install %s to enable it.", + import_name, feature_name, import_name, + ) + _OPTIONAL_DEP_WARNED.add(feature_name) + return True + def grid_scaled_2d_for_marching_squares_from( grid_pixels_2d: aa.Grid2D, @@ -1152,13 +1179,11 @@ def _critical_curve_list_via_zero_contour( max_newton Maximum Newton iterations per step. """ - try: - from jax_zero_contour import ZeroSolver - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "jax_zero_contour is required for zero-contour critical curve tracing. " - "Install it with: pip install jax_zero_contour" - ) from exc + if _maybe_optional_dep_warn( + "jax_zero_contour", "critical_curve_list_via_zero_contour" + ): + return [] + from jax_zero_contour import ZeroSolver import jax.numpy as jnp if init_guess is None: @@ -1563,13 +1588,11 @@ def einstein_radius_jit_from( ``area`` is the largest enclosed area across all seeds — robust to multiple seeds landing on the same critical curve). """ - try: - from jax_zero_contour import ZeroSolver - except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "jax_zero_contour is required for einstein_radius_jit_from. " - "Install it with: pip install jax_zero_contour" - ) from exc + if _maybe_optional_dep_warn( + "jax_zero_contour", "einstein_radius_jit_from" + ): + return float("nan") + from jax_zero_contour import ZeroSolver import jax.numpy as jnp init_guess = jnp.atleast_2d(jnp.asarray(init_guess)) diff --git a/test_autogalaxy/operate/test_deflections.py b/test_autogalaxy/operate/test_deflections.py index 2b4dfa94..2f053c6a 100644 --- a/test_autogalaxy/operate/test_deflections.py +++ b/test_autogalaxy/operate/test_deflections.py @@ -1,3 +1,7 @@ +import importlib +import logging +import math + import numpy as np import pytest @@ -5,6 +9,7 @@ import autogalaxy as ag +from autogalaxy.operate import lens_calc as _lens_calc_module from autogalaxy.operate.lens_calc import ( grid_scaled_2d_for_marching_squares_from, LensCalc, @@ -624,3 +629,93 @@ def test__zero_contour_cache__starts_empty_and_is_per_instance(): "solver-stand-in", ) assert od_b._zero_contour_cache == {} + + +def _patch_missing_jax_zero_contour(monkeypatch): + """ + Make ``importlib.import_module("jax_zero_contour")`` raise as if the + package were not installed, while leaving all other imports intact. + Used by the soft-fail tests below. + """ + real_import = importlib.import_module + + def fake_import(name, *args, **kwargs): + if name == "jax_zero_contour": + raise ModuleNotFoundError(f"No module named '{name}'") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(_lens_calc_module.importlib, "import_module", fake_import) + + +def test__maybe_optional_dep_warn__logs_only_once_per_name(monkeypatch, caplog): + _lens_calc_module._OPTIONAL_DEP_WARNED.discard("test_feature_once") + _patch_missing_jax_zero_contour(monkeypatch) + + with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__): + first = _lens_calc_module._maybe_optional_dep_warn( + "jax_zero_contour", "test_feature_once" + ) + second = _lens_calc_module._maybe_optional_dep_warn( + "jax_zero_contour", "test_feature_once" + ) + + assert first is True + assert second is True + matching = [r for r in caplog.records if "test_feature_once" in r.message] + assert len(matching) == 1 + + +def test__einstein_radius_jit_from__missing_jax_zero_contour__returns_nan_and_warns( + monkeypatch, caplog +): + """ + When ``jax_zero_contour`` isn't installed, ``einstein_radius_jit_from`` + must soft-fail to NaN with a single warning per process — not raise + ``ModuleNotFoundError``, which would kill the post-fit metric write of + an otherwise-converged search. + """ + _lens_calc_module._OPTIONAL_DEP_WARNED.discard("einstein_radius_jit_from") + _patch_missing_jax_zero_contour(monkeypatch) + + mp = ag.mp.IsothermalSph(centre=(0.0, 0.0), einstein_radius=2.0) + od = LensCalc.from_mass_obj(mp) + + with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__): + result = od.einstein_radius_jit_from(init_guess=[[1.0, 0.0]]) + + assert math.isnan(result) + matching = [ + r for r in caplog.records if "einstein_radius_jit_from" in r.message + ] + assert len(matching) == 1 + + +def test__tangential_critical_curve_list_via_zero_contour__missing_dep__returns_empty( + monkeypatch, caplog +): + """ + Parallel soft-fail check for the critical-curve helper. With + ``jax_zero_contour`` missing, ``_critical_curve_list_via_zero_contour`` + returns ``[]`` (matching the existing ``ValueError → []`` early-out at + line 1167) with a single warning per process. + """ + _lens_calc_module._OPTIONAL_DEP_WARNED.discard( + "critical_curve_list_via_zero_contour" + ) + _patch_missing_jax_zero_contour(monkeypatch) + + mp = ag.mp.IsothermalSph(centre=(0.0, 0.0), einstein_radius=2.0) + od = LensCalc.from_mass_obj(mp) + + with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__): + result = od.tangential_critical_curve_list_via_zero_contour_from( + init_guess=[[1.0, 0.0]] + ) + + assert result == [] + matching = [ + r + for r in caplog.records + if "critical_curve_list_via_zero_contour" in r.message + ] + assert len(matching) == 1