
fix: optional-dep latents fall back to NumPy, soft-fail to NaN #464

@Jammy2211

Description


Overview

When jax_zero_contour isn't installed, the default-enabled
effective_einstein_radius latent crashes the post-fit metric write of
otherwise-converged searches (surfaced on autolens_profiling job
322552 against PyAutoLens #557). The fix is a caller fallback plus a
NaN backstop: PyAutoLens's effective_einstein_radius detects the
missing dependency and routes to the existing NumPy einstein_radius_from
path (so the user still gets a real value), while PyAutoGalaxy's
lens_calc.py soft-fails to NaN / [] as a defensive backstop for any
other direct caller.

Plan

  • PyAutoLens — caller fallback. In effective_einstein_radius,
    detect missing jax_zero_contour on the JAX path and route through
    the existing NumPy einstein_radius_from(grid=fit.dataset.grids.lp)
    branch with a one-time-per-process warning ("falling back to NumPy —
    slower; install jax_zero_contour to enable JIT path").
  • PyAutoGalaxy — backstop soft-fail. Add _maybe_optional_dep_warn
    helper in lens_calc.py; replace the two try/except → raise blocks
    with NaN return (einstein_radius_jit_from) and [] return
    (_critical_curve_list_via_zero_contour). Empty list matches the
    existing ValueError → return [] path at line 1167.
  • Tests in both repos. Mock the import failure and assert: PyAutoLens
    fallback returns a real (finite) Einstein radius; PyAutoGalaxy
    producer returns NaN / [] with exactly one warning per process per
    feature.
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary — soft-fail backstop)
  • PyAutoLens (caller fallback)

Work Classification

Library (both repos)

Branch Survey

Repository      Current Branch  Dirty?
./PyAutoGalaxy  main            clean
./PyAutoLens    main            clean

Suggested branch: feature/optional-dep-latents-soft-fail (both repos)
Worktree root: ~/Code/PyAutoLabs-wt/optional-dep-latents-soft-fail/ (created later by /start_library)

Implementation Steps

PyAutoGalaxy first (backstop, no caller dependency).

  1. autogalaxy/operate/lens_calc.py — add helper at module scope:

    • Add import importlib and _OPTIONAL_DEP_WARNED: set = set() near the top. logger already exists in this module.
    • Add _maybe_optional_dep_warn(import_name: str, feature_name: str) -> bool: tries importlib.import_module(import_name); returns False on success; on ModuleNotFoundError emits one logger.warning per feature_name and returns True. Message: "Optional dependency '%s' not installed; '%s' returning NaN/empty. pip install %s to enable it."
  2. Replace _critical_curve_list_via_zero_contour (lines 1155–1161):

    if _maybe_optional_dep_warn(
        "jax_zero_contour", "critical_curve_list_via_zero_contour"
    ):
        return []
    from jax_zero_contour import ZeroSolver
  3. Replace einstein_radius_jit_from (lines 1566–1572):

    if _maybe_optional_dep_warn(
        "jax_zero_contour", "einstein_radius_jit_from"
    ):
        import jax.numpy as jnp
        return jnp.nan
    from jax_zero_contour import ZeroSolver
  4. Add tests to test_autogalaxy/operate/test_deflections.py:

    • test__einstein_radius_jit_from__missing_jax_zero_contour__returns_nan_and_warns(monkeypatch, caplog) — patch importlib.import_module to raise ModuleNotFoundError for "jax_zero_contour", build a minimal LensCalc, assert return is NaN and one warning record matches feature_name.
    • test__tangential_critical_curve_list_via_zero_contour__missing_dep__returns_empty_list_and_warns(...) — same approach, assert [].
    • test__maybe_optional_dep_warn__logs_only_once_per_name(...) — call twice, assert one record.
    • Each test calls _OPTIONAL_DEP_WARNED.discard(name) at the start so the once-per-process warning state does not leak between tests.

PyAutoLens second (caller fallback).

  1. autolens/analysis/latent.py — add fallback logic to effective_einstein_radius:

    • Add module-level _JAX_ZERO_CONTOUR_FALLBACK_WARNED: bool = False.
    • Add helper _jax_zero_contour_available() -> bool that does importlib.import_module("jax_zero_contour") inside try/except and warns once when missing: "jax_zero_contour not installed; effective_einstein_radius falling back to NumPy path (slower). pip install jax_zero_contour to enable the JIT path."
    • Modify the function (lines 218–228):
      try:
          lens_calc = LensCalc.from_mass_obj(fit.tracer)
          if xp is not np and _jax_zero_contour_available():
              import jax.numpy as jnp
              init_guess = jnp.array(
                  [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
              )
              return lens_calc.einstein_radius_jit_from(init_guess=init_guess)
          return lens_calc.einstein_radius_from(grid=fit.dataset.grids.lp)
      except (ValueError, AttributeError):
          return xp.nan
    • Net effect: JAX path with dep present → unchanged JIT path. JAX path without dep → falls through to NumPy path with a one-time warning. NumPy path → unchanged.
  2. Add tests to test_autolens/analysis/test_latent.py:

    • test__effective_einstein_radius__jax_path__missing_jax_zero_contour__falls_back_to_numpy(monkeypatch, caplog) — patch importlib.import_module to raise for "jax_zero_contour", call effective_einstein_radius(fit=fit, magzero=None, xp=jnp), assert finite return (not NaN) and one warning emitted.
    • Reset _latent_module._JAX_ZERO_CONTOUR_FALLBACK_WARNED = False at start.

Key Files

  • PyAutoGalaxy/autogalaxy/operate/lens_calc.py — helper + replace 2 raise sites (~12 LOC net).
  • PyAutoGalaxy/test_autogalaxy/operate/test_deflections.py — 3 new unit tests.
  • PyAutoLens/autolens/analysis/latent.py — helper + reroute in effective_einstein_radius (~10 LOC net).
  • PyAutoLens/test_autolens/analysis/test_latent.py — 1 new fallback test.

Merge Order

PyAutoGalaxy PR first (the backstop). PyAutoLens PR second; does not strictly depend on the PyAutoGalaxy merge because the PyAutoLens fallback sidesteps the producer entirely. Library-first gate still applies.

Out of Scope

  • Promoting jax_zero_contour to a hard dep (alternative; not preferred per the prompt).
  • autogalaxy/util/plot_utils.py:337-360 — has its own ImportError handling for the plot helper, not reached by default latents.
  • Auditing every optional-dep import in PyAutoGalaxy; scope is the latent-computation surface.
  • Routing _critical_curve_list_via_zero_contour to its marching-squares sibling as a fallback: its callers are plot utilities, so [] (no curves drawn) is acceptable.

Performance Note

The NumPy fallback (einstein_radius_from) evaluates the lens model on a dense grid (~250k evaluations) instead of using the JIT-compiled contour tracer. The cost is incurred post-fit, once per converged search, so an order-of-seconds overhead is acceptable. The fallback warning makes this cost discoverable to the user.
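To make the ~250k figure concrete, here is a hedged NumPy sketch of a dense-grid estimate; a 500 x 500 grid is exactly 250,000 evaluations. It is not PyAutoGalaxy's einstein_radius_from. It assumes a singular isothermal sphere (SIS) lens and the area-based definition R_eff = sqrt(A / pi), where A is the area inside the tangential critical curve. Both choices are illustrative, picked because the SIS gives a closed-form answer (R_eff = theta_e) to check against.

```python
import numpy as np


def sis_tangential_eigenvalue(x, y, theta_e):
    """Tangential eigenvalue 1 - kappa - gamma for a singular isothermal sphere.

    For an SIS, kappa = gamma = theta_e / (2 r), so 1 - kappa - gamma = 1 - theta_e / r,
    which is negative inside r = theta_e (the tangential critical curve).
    """
    r = np.hypot(x, y)
    r = np.where(r == 0.0, 1e-12, r)  # guard the centre against division by zero
    return 1.0 - theta_e / r


def effective_einstein_radius_from_grid(theta_e, half_width=3.0, n=500):
    """Area-based estimate sqrt(A / pi), with A the area where the eigenvalue is negative."""
    coords = np.linspace(-half_width, half_width, n)
    x, y = np.meshgrid(coords, coords)  # n * n evaluations: 250,000 for n = 500
    pixel_area = (coords[1] - coords[0]) ** 2
    inside = sis_tangential_eigenvalue(x, y, theta_e) < 0.0
    area = np.count_nonzero(inside) * pixel_area
    return float(np.sqrt(area / np.pi))
```

For `theta_e=1.6` the estimate lands within about a pixel width of 1.6. The cost grows as n squared, so halving the grid spacing roughly quadruples the number of evaluations.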

Original Prompt


Optional-dep latents should soft-fail to NaN, not re-raise ModuleNotFoundError

Context

Discovered while validating the PyAutoLens #557 magzero fix on the
first-class A100 search profile (autolens_profiling job 322552). The
magzero _mujy latents now soft-fail correctly, but a second post-fit
latent crashed the same metric-JSON write path with a parallel bug:

  File "PyAutoGalaxy/autogalaxy/operate/lens_calc.py", line 1567, in einstein_radius_jit_from
    from jax_zero_contour import ZeroSolver
ModuleNotFoundError: No module named 'jax_zero_contour'

The above exception was the direct cause of the following exception:

  File "PyAutoLens/autolens/analysis/latent.py", line 225, in effective_einstein_radius
    return lens_calc.einstein_radius_jit_from(init_guess=init_guess)
  raise ModuleNotFoundError(
    "jax_zero_contour is required for einstein_radius_jit_from. "
    "Install it with: pip install jax_zero_contour"
  ) from exc

The raise lives at PyAutoGalaxy:autogalaxy/operate/lens_calc.py:1566-1572:

try:
    from jax_zero_contour import ZeroSolver
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "jax_zero_contour is required for einstein_radius_jit_from. "
        "Install it with: pip install jax_zero_contour"
    ) from exc

This is the same default-config-crashes-by-default pattern as the
magzero family (fixed in PyAutoLens #557 with _maybe_magzero_warn),
just keyed on an optional dependency instead of an Analysis kwarg.

The bug

effective_einstein_radius is enabled by default in
PyAutoLens:autolens/config/latent.yaml, and is computed during every
SearchUpdater._compute_latent_samples call as part of the standard
post-fit pipeline. When jax_zero_contour isn't installed (it isn't
in PyAuto's default pyproject.toml deps — it's a JAX-only optional
extra), every search.fit() that reaches the post-fit step crashes
with the message above. The search itself converged; the metric-JSON
write never happens.

This is the same failure shape as
PyAutoPrompt/autolens/magzero_required_latents_crash_search.md
(landed as PyAutoLens #557). The PyAutoLens fix only addressed the
magzero family of latents — the optional-dependency family was not
touched, and surfaces with the same symptoms.

Desired fix

Mirror PyAutoLens #557's _maybe_magzero_warn: replace the re-raise
with a per-process warning + xp.nan return. The "module is missing"
case is structurally identical to the "user kwarg is missing" case —
both mean the latent value is unknown, and the right behaviour is
return-NaN-and-warn, not kill the search.

Sketch (in PyAutoGalaxy:autogalaxy/operate/lens_calc.py or wherever
the latent function lives — PyAutoLens:autolens/analysis/latent.py:225
calls into this):

import importlib
import logging

logger = logging.getLogger(__name__)  # lens_calc.py already defines this

_OPTIONAL_DEP_WARNED: set[str] = set()  # names that have already warned

def _maybe_optional_dep_warn(import_name: str, name: str) -> bool:
    """Return True (and warn once per name) if the optional dependency is missing."""
    try:
        importlib.import_module(import_name)
        return False
    except ModuleNotFoundError:
        if name not in _OPTIONAL_DEP_WARNED:
            logger.warning(
                "Optional dependency '%s' not installed; '%s' latent will "
                "be NaN. pip install %s to enable it, or disable in "
                "config/latent.yaml to silence this warning.",
                import_name, name, import_name,
            )
            _OPTIONAL_DEP_WARNED.add(name)
        return True

einstein_radius_jit_from (and any sibling that does the same
try/except ModuleNotFoundError → re-raise pattern) becomes:

def einstein_radius_jit_from(self, ...):
    if _maybe_optional_dep_warn("jax_zero_contour", "effective_einstein_radius"):
        return xp.nan
    from jax_zero_contour import ZeroSolver
    ...

Test plan

  • Unit test (mirror test_autolens/analysis/test_latent.py's magzero
    cases): mock jax_zero_contour import to raise, assert the latent
    returns NaN and emits one warning per process.
  • End-to-end on a node WITHOUT jax_zero_contour installed: construct
    AnalysisImaging(..., use_jax=True) with default
    config/latent.yaml, run af.Nautilus(...).fit(...) in
    PYAUTO_TEST_MODE=1. Confirm search completes, JSON write happens,
    latent.csv has NaN values for the affected columns.
  • Smoke check: re-run autolens_profiling/searches/nautilus/imaging/mge.py
    on a clean venv (no jax_zero_contour) and confirm no crash.
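One alternative to patching importlib.import_module for the unit test is to put None into sys.modules for the dependency: Python's import system then raises ModuleNotFoundError for both importlib.import_module and a plain `from jax_zero_contour import ZeroSolver`, whether or not the package is installed on the test machine. The sketch below is hypothetical: einstein_radius_or_nan, the "latent_sketch" logger name and _WARNED are stand-ins, and it uses unittest's assertLogs rather than pytest's caplog; under pytest, `monkeypatch.setitem(sys.modules, "jax_zero_contour", None)` does the same job.

```python
import importlib
import logging
import math
import sys
import unittest
from unittest import mock

logger = logging.getLogger("latent_sketch")  # hypothetical logger name
_WARNED: set = set()


def einstein_radius_or_nan() -> float:
    """Hypothetical producer: NaN plus a one-time warning if jax_zero_contour is missing."""
    try:
        importlib.import_module("jax_zero_contour")
    except ModuleNotFoundError:
        if "einstein_radius" not in _WARNED:
            logger.warning("jax_zero_contour not installed; returning NaN.")
            _WARNED.add("einstein_radius")
        return math.nan
    return 1.0  # placeholder for the real ZeroSolver path


class TestMissingOptionalDep(unittest.TestCase):
    def setUp(self):
        _WARNED.clear()  # reset per-process state so each test is deterministic

    def test_returns_nan_and_warns_once(self):
        # A None entry in sys.modules makes importing that name raise
        # ModuleNotFoundError; patch.dict restores sys.modules on exit.
        with mock.patch.dict(sys.modules, {"jax_zero_contour": None}):
            with self.assertLogs("latent_sketch", level="WARNING") as logs:
                self.assertTrue(math.isnan(einstein_radius_or_nan()))
                self.assertTrue(math.isnan(einstein_radius_or_nan()))
            self.assertEqual(len(logs.records), 1)

    def test_import_statement_is_blocked_too(self):
        with mock.patch.dict(sys.modules, {"jax_zero_contour": None}):
            with self.assertRaises(ModuleNotFoundError):
                from jax_zero_contour import ZeroSolver  # noqa: F401
```

Run it with `python -m unittest` or pytest; both tests pass on machines with and without jax_zero_contour installed.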

Affected callers / interaction surface

  • Every PyAutoLens search using default latents on a venv without
    jax_zero_contour — the HPC PyAutoNSS venv on Euclid SAAS in
    particular, where jax_zero_contour was not installed at venv
    creation time (added manually post-hoc to unblock job 322552).
  • Sibling functions in PyAutoGalaxy:autogalaxy/operate/lens_calc.py
    that do the same try / except ModuleNotFoundError → raise pattern:
    grep for raise ModuleNotFoundError in lens_calc.py and adjacent
    modules; any that gate optional-dep imports for default-on latents
    should switch to the soft-fail pattern.
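The grep step can be scripted so the audit is repeatable. This is a small sketch equivalent to `grep -rn "raise ModuleNotFoundError"` over a package directory; the function name find_reraise_sites is hypothetical. It only locates re-raise lines: whether each gated import feeds a default-on latent still needs a manual read.

```python
import re
from pathlib import Path

# Matches a line that re-raises ModuleNotFoundError (the pattern to soft-fail).
_RAISE_PATTERN = re.compile(r"^\s*raise\s+ModuleNotFoundError\b")


def find_reraise_sites(root: str):
    """Yield (path, line_number, stripped_line) for each re-raise site under root."""
    for path in sorted(Path(root).rglob("*.py")):
        text = path.read_text(encoding="utf-8", errors="replace")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if _RAISE_PATTERN.match(line):
                yield path, lineno, line.strip()
```

For example, `for path, n, line in find_reraise_sites("autogalaxy"): print(f"{path}:{n}: {line}")` prints one `file:line: source` entry per site, in the same shape as grep's output.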

Why not just add jax_zero_contour as a hard dep?

That's the pragmatic short-term fix (it's what unblocked job 322552
on the HPC). But it makes a JAX-only utility a required dep of every
PyAutoGalaxy install, including non-JAX numpy-only users — wider blast
radius than this fix needs. Soft-fail keeps the optional dep optional
while making default configs runnable.

Out of scope

  • Promoting jax_zero_contour to a hard dep (alternative; not the
    preferred fix).
  • Auditing every optional-dep code path in PyAutoGalaxy. Scope to
    the latent-computation surface where default configs hit them.
  • Re-reviewing PyAutoLens #557 — that fix is correct for what it
    covered; this is a parallel-but-independent surface.

Cross-references

  • PyAutoLens #557 — fixed the magzero family with the same soft-fail
    pattern this prompt proposes for optional deps.
  • PyAutoPrompt/autolens/magzero_required_latents_crash_search.md —
    the upstream prompt for #557.
  • autolens_profiling HPC job 322552 — the surfacing run; sampling
    completed (64,200 evals, log_Z=+31690.49) but post-fit latent
    crashed before metric JSON write.

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests