From 0b895a6c4f3f673b00db7a14381b89795fd1a9f7 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 17 May 2026 19:09:50 +0100 Subject: [PATCH] revert: default use_jax_for_visualization to False (reverts #1278) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #1278 made `use_jax_for_visualization` default to follow `use_jax` (`Optional[bool] = None` resolving to `use_jax`). That caused every Nautilus quick-update under `use_jax=True` to evaluate `jax.jit(self.fit_from)(instance=instance)` where `instance` is a `ModelInstance` — a plain Python object that is not pytree-registered. JAX raises a `TypeError` trying to abstract it. On real pipeline runs (e.g. `z_projects/euclid/scripts/initial_lens_model.py`) the exception was swallowed deeper in the visualizer's outer guards; visible symptom was source-plane FITS files written all-zero and posteriors collapsing to the full prior on Einstein radius across every Euclid tile. Reverts the default to `bool = False`. Drops the sentinel-resolution block. Explicit opt-in (`use_jax_for_visualization=True`) and the existing `use_jax_for_visualization=True and not use_jax` warning both remain. No code in the PyAuto ecosystem relied on the implicit- on behaviour introduced by #1278. Co-Authored-By: Claude Opus 4.7 (1M context) --- autofit/non_linear/analysis/analysis.py | 45 ++++--------------- .../test_use_jax_for_visualization.py | 11 ----- 2 files changed, 8 insertions(+), 48 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index e5211a9b6..0bccc01b8 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -36,23 +36,9 @@ class Analysis(ABC): def __init__( self, use_jax: bool = False, - use_jax_for_visualization: Optional[bool] = None, + use_jax_for_visualization: bool = False, **kwargs, ): - """ - Parameters - ---------- - use_jax - Run the likelihood through ``jax.jit`` for the fast path. When JAX - is unavailable this silently falls back to numpy with a warning. - use_jax_for_visualization - Whether ``fit_for_visualization`` should dispatch through the - ``jax.jit``-cached path. ``None`` (default) follows ``use_jax`` — - users who set ``use_jax=True`` automatically get JIT visualization. - Pass ``False`` to force the eager NumPy plotter even when - ``use_jax=True``; pass ``True`` to opt in explicitly. Passing - ``True`` while ``use_jax=False`` logs a warning and disables it. - """ import os if os.environ.get("PYAUTO_DISABLE_JAX") == "1": use_jax = False @@ -82,9 +68,6 @@ def __init__( use_jax = False use_jax_for_visualization = False - if use_jax_for_visualization is None: - use_jax_for_visualization = use_jax - if use_jax_for_visualization and not use_jax: logger.warning( "use_jax_for_visualization=True requires use_jax=True; " @@ -100,21 +83,15 @@ def fit_for_visualization(self, instance): """ Build the fit used by the visualizer. - Dispatch over ``self.fit_from`` with a ``jax.jit`` fast path that - follows ``use_jax`` by default: + Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path: - * ``self._use_jax_for_visualization`` is ``False`` — plain - ``self.fit_from(instance)``. Untouched by JAX. This is the - resolved state when ``use_jax=False`` (the parameter default), - or when the user explicitly passed - ``use_jax_for_visualization=False`` to opt out. - * ``self._use_jax_for_visualization`` is ``True`` — lazily construct + * ``use_jax_for_visualization=False`` (default) — plain + ``self.fit_from(instance)``. Untouched by JAX. + * ``use_jax_for_visualization=True`` — lazily construct ``jax.jit(self.fit_from)`` on the first call and cache it on the instance as ``_jitted_fit_from``, then call that for every subsequent visualization. The first call pays the compile cost; - subsequent calls reuse the cached compiled function. This is the - resolved state when ``use_jax=True`` (the sentinel default - ``use_jax_for_visualization=None`` follows ``use_jax``). + subsequent calls reuse the cached compiled function. Caching is per-``Analysis`` instance so each analysis gets its own compiled function keyed off that instance's closed-over state @@ -122,18 +99,12 @@ def fit_for_visualization(self, instance): aux data via ``register_instance_pytree(FitImaging, no_flatten=...)`` in PyAutoLens). - ``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``), - not the base class — this method is only callable on subclasses that - provide it. Downstream visualizers should prefer this over calling - ``fit_from`` directly so the JIT seam stays in one place. - For the JIT path to succeed, the ``Fit*`` return type (and every nested autoarray / galaxy / lens type it carries) must be pytree- registered. That wiring lives in each analysis subclass (see ``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens). - Variants that have not yet been pytree-audited must pass - ``use_jax_for_visualization=False`` explicitly when constructing - the analysis (or simply leave ``use_jax=False``). + Variants that have not yet been pytree-audited must leave + ``use_jax_for_visualization`` at its default of ``False``. """ if not self._use_jax_for_visualization: return self.fit_from(instance=instance) diff --git a/test_autofit/analysis/test_use_jax_for_visualization.py b/test_autofit/analysis/test_use_jax_for_visualization.py index f2caca1f4..f4061e39a 100644 --- a/test_autofit/analysis/test_use_jax_for_visualization.py +++ b/test_autofit/analysis/test_use_jax_for_visualization.py @@ -67,17 +67,6 @@ def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch): assert analysis._use_jax_for_visualization is False -def test_pyauto_disable_jax_overrides_sentinel_default(monkeypatch): - """PYAUTO_DISABLE_JAX=1 must still force both off even when the user - constructs Analysis(use_jax=True) and lets the sentinel resolve. This is - a numpy-only check — JAX-conditional sentinel-resolution assertions live - in autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py.""" - monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1") - analysis = af.Analysis(use_jax=True) - assert analysis._use_jax is False - assert analysis._use_jax_for_visualization is False - - def test_fit_for_visualization_works_without_flag(): analysis = _FittableAnalysis() result = analysis.fit_for_visualization(instance="sentinel")