Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions autogalaxy/operate/lens_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
The class is constructed with `LensCalc.from_mass_obj(mass)` or `LensCalc.from_tracer(tracer)`.
"""
from functools import wraps
import importlib
import logging
import numpy as np
from typing import List, Tuple, Union
Expand All @@ -29,6 +30,32 @@

logger = logging.getLogger(__name__)

_OPTIONAL_DEP_WARNED: set = set()


def _maybe_optional_dep_warn(import_name: str, feature_name: str) -> bool:
"""
Return True (and warn once per process for ``feature_name``) if the
optional dependency ``import_name`` is not installed; False otherwise.

Callers that get True must early-return a soft-fail value (NaN, empty
list, etc.) — a search-killing raise here would discard the post-fit
metric write of an otherwise-converged search. Mirrors PyAutoLens's
``_maybe_magzero_warn`` for the same reason.
"""
try:
importlib.import_module(import_name)
return False
except ModuleNotFoundError:
if feature_name not in _OPTIONAL_DEP_WARNED:
logger.warning(
"Optional dependency '%s' not installed; '%s' returning "
"NaN/empty. pip install %s to enable it.",
import_name, feature_name, import_name,
)
_OPTIONAL_DEP_WARNED.add(feature_name)
return True


def grid_scaled_2d_for_marching_squares_from(
grid_pixels_2d: aa.Grid2D,
Expand Down Expand Up @@ -1152,13 +1179,11 @@ def _critical_curve_list_via_zero_contour(
max_newton
Maximum Newton iterations per step.
"""
try:
from jax_zero_contour import ZeroSolver
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"jax_zero_contour is required for zero-contour critical curve tracing. "
"Install it with: pip install jax_zero_contour"
) from exc
if _maybe_optional_dep_warn(
"jax_zero_contour", "critical_curve_list_via_zero_contour"
):
return []
from jax_zero_contour import ZeroSolver
import jax.numpy as jnp

if init_guess is None:
Expand Down Expand Up @@ -1563,13 +1588,11 @@ def einstein_radius_jit_from(
``area`` is the largest enclosed area across all seeds — robust
to multiple seeds landing on the same critical curve).
"""
try:
from jax_zero_contour import ZeroSolver
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"jax_zero_contour is required for einstein_radius_jit_from. "
"Install it with: pip install jax_zero_contour"
) from exc
if _maybe_optional_dep_warn(
"jax_zero_contour", "einstein_radius_jit_from"
):
return float("nan")
from jax_zero_contour import ZeroSolver
import jax.numpy as jnp

init_guess = jnp.atleast_2d(jnp.asarray(init_guess))
Expand Down
95 changes: 95 additions & 0 deletions test_autogalaxy/operate/test_deflections.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import importlib
import logging
import math

import numpy as np
import pytest

from skimage import measure

import autogalaxy as ag

from autogalaxy.operate import lens_calc as _lens_calc_module
from autogalaxy.operate.lens_calc import (
grid_scaled_2d_for_marching_squares_from,
LensCalc,
Expand Down Expand Up @@ -624,3 +629,93 @@ def test__zero_contour_cache__starts_empty_and_is_per_instance():
"solver-stand-in",
)
assert od_b._zero_contour_cache == {}


def _patch_missing_jax_zero_contour(monkeypatch):
"""
Make ``importlib.import_module("jax_zero_contour")`` raise as if the
package were not installed, while leaving all other imports intact.
Used by the soft-fail tests below.
"""
real_import = importlib.import_module

def fake_import(name, *args, **kwargs):
if name == "jax_zero_contour":
raise ModuleNotFoundError(f"No module named '{name}'")
return real_import(name, *args, **kwargs)

monkeypatch.setattr(_lens_calc_module.importlib, "import_module", fake_import)


def test__maybe_optional_dep_warn__logs_only_once_per_name(monkeypatch, caplog):
_lens_calc_module._OPTIONAL_DEP_WARNED.discard("test_feature_once")
_patch_missing_jax_zero_contour(monkeypatch)

with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__):
first = _lens_calc_module._maybe_optional_dep_warn(
"jax_zero_contour", "test_feature_once"
)
second = _lens_calc_module._maybe_optional_dep_warn(
"jax_zero_contour", "test_feature_once"
)

assert first is True
assert second is True
matching = [r for r in caplog.records if "test_feature_once" in r.message]
assert len(matching) == 1


def test__einstein_radius_jit_from__missing_jax_zero_contour__returns_nan_and_warns(
monkeypatch, caplog
):
"""
When ``jax_zero_contour`` isn't installed, ``einstein_radius_jit_from``
must soft-fail to NaN with a single warning per process — not raise
``ModuleNotFoundError``, which would kill the post-fit metric write of
an otherwise-converged search.
"""
_lens_calc_module._OPTIONAL_DEP_WARNED.discard("einstein_radius_jit_from")
_patch_missing_jax_zero_contour(monkeypatch)

mp = ag.mp.IsothermalSph(centre=(0.0, 0.0), einstein_radius=2.0)
od = LensCalc.from_mass_obj(mp)

with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__):
result = od.einstein_radius_jit_from(init_guess=[[1.0, 0.0]])

assert math.isnan(result)
matching = [
r for r in caplog.records if "einstein_radius_jit_from" in r.message
]
assert len(matching) == 1


def test__tangential_critical_curve_list_via_zero_contour__missing_dep__returns_empty(
monkeypatch, caplog
):
"""
Parallel soft-fail check for the critical-curve helper. With
``jax_zero_contour`` missing, ``_critical_curve_list_via_zero_contour``
returns ``[]`` (matching the existing ``ValueError → []`` early-out at
line 1167) with a single warning per process.
"""
_lens_calc_module._OPTIONAL_DEP_WARNED.discard(
"critical_curve_list_via_zero_contour"
)
_patch_missing_jax_zero_contour(monkeypatch)

mp = ag.mp.IsothermalSph(centre=(0.0, 0.0), einstein_radius=2.0)
od = LensCalc.from_mass_obj(mp)

with caplog.at_level(logging.WARNING, logger=_lens_calc_module.__name__):
result = od.tangential_critical_curve_list_via_zero_contour_from(
init_guess=[[1.0, 0.0]]
)

assert result == []
matching = [
r
for r in caplog.records
if "critical_curve_list_via_zero_contour" in r.message
]
assert len(matching) == 1
Loading