diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index a9b649403..097ceb911 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -58,17 +58,12 @@ class Analysis(ABC): def __init__( self, use_jax: bool = False, - use_jax_for_visualization: bool = False, **kwargs, ): import os if os.environ.get("PYAUTO_DISABLE_JAX") == "1": use_jax = False - use_jax_for_visualization = False - # If the user requested JAX but it isn't installed (e.g. Python <3.11 - # without the [jax] extra), fall back to numpy with a loud warning - # rather than crashing later when the analysis tries to jit-compile. if use_jax: import importlib.util import warnings @@ -88,55 +83,20 @@ def __init__( stacklevel=2, ) use_jax = False - use_jax_for_visualization = False - - if use_jax_for_visualization and not use_jax: - logger.warning( - "use_jax_for_visualization=True requires use_jax=True; " - "disabling use_jax_for_visualization." - ) - use_jax_for_visualization = False self._use_jax = use_jax - self._use_jax_for_visualization = use_jax_for_visualization self.kwargs = kwargs def fit_for_visualization(self, instance): """ Build the fit used by the visualizer. - Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path: - - * ``use_jax_for_visualization=False`` (default) — plain - ``self.fit_from(instance)``. Untouched by JAX. - * ``use_jax_for_visualization=True`` — lazily construct - ``jax.jit(self.fit_from)`` on the first call and cache it on the - instance as ``_jitted_fit_from``, then call that for every - subsequent visualization. The first call pays the compile cost; - subsequent calls reuse the cached compiled function. - - Caching is per-``Analysis`` instance so each analysis gets its own - compiled function keyed off that instance's closed-over state - (``self.dataset``, ``self.settings``, etc. — these ride as pytree - aux data via ``register_instance_pytree(FitImaging, no_flatten=...)`` - in PyAutoLens). - - For the JIT path to succeed, the ``Fit*`` return type (and every - nested autoarray / galaxy / lens type it carries) must be pytree- - registered. That wiring lives in each analysis subclass (see - ``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens). - Variants that have not yet been pytree-audited must leave - ``use_jax_for_visualization`` at its default of ``False``. + Delegates to ``self.fit_from(instance)``. When ``use_jax=True``, + the profile evaluations inside ``fit_from`` dispatch to JAX via + the decorator chain. The per-function JIT caches warm up on the + first call and are reused on all subsequent quick updates. """ - if not self._use_jax_for_visualization: - return self.fit_from(instance=instance) - - if getattr(self, "_jitted_fit_from", None) is None: - import jax - - self._jitted_fit_from = jax.jit(self.fit_from) - - return self._jitted_fit_from(instance) + return self.fit_from(instance=instance) def __getattr__(self, item: str): """ @@ -444,15 +404,8 @@ def supports_background_update(self) -> bool: @property def supports_jax_visualization(self) -> bool: - """ - Whether the visualizer can work directly with JAX arrays. - - Derived from the ``use_jax_for_visualization`` flag passed at - construction time. Subclasses may override to force a specific - answer (e.g. an Analysis that has been audited to support JAX - visualization unconditionally). - """ - return self._use_jax_for_visualization + """Whether the visualizer can work directly with JAX arrays.""" + return self._use_jax def perform_quick_update(self, paths, instance): raise NotImplementedError diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index 8736230cd..9b95d48f5 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -163,6 +163,35 @@ def __init__( if self.paths is not None: self.check_log_likelihood(fitness=self) + if ( + self.iterations_per_quick_update is not None + and self._xp.__name__.startswith("jax") + ): + self._warmup_visualization() + + def _warmup_visualization(self): + """Pre-compile the JAX operations used by ``fit_for_visualization``. + + The first call to ``fit_for_visualization`` triggers ~200 small + per-function JAX JIT compilations (one per profile method per + decorator). Running them here moves that cost to search setup + so every quick update during sampling is fast. + """ + logger.info( + "Warming up visualization (one-time JAX compilation)..." + ) + try: + instance = self.model.instance_from_prior_medians() + fit = self.analysis.fit_for_visualization(instance=instance) + _ = fit.model_data + except Exception: + logger.warning( + "Visualization warm-up failed (non-fatal); " + "first quick update may be slow." + ) + else: + logger.info("Visualization warm-up complete.") + @property def _xp(self): return self.analysis._xp diff --git a/test_autofit/analysis/test_use_jax_for_visualization.py b/test_autofit/analysis/test_use_jax_for_visualization.py index f4061e39a..e24b6ffa4 100644 --- a/test_autofit/analysis/test_use_jax_for_visualization.py +++ b/test_autofit/analysis/test_use_jax_for_visualization.py @@ -1,4 +1,10 @@ -"""Tests for the ``use_jax_for_visualization`` flag on ``Analysis``.""" +"""Tests for the visualization path on ``Analysis``. + +The ``use_jax_for_visualization`` flag has been removed — visualization +now always follows ``use_jax``. These tests verify the simplified +``fit_for_visualization`` dispatch and the ``supports_jax_visualization`` +property. +""" import importlib.util @@ -8,12 +14,10 @@ def _jax_installed() -> bool: - """Check jax availability without importing it (per numpy-only rule).""" return importlib.util.find_spec("jax") is not None class _FittableAnalysis(af.Analysis): - """Minimal Analysis subclass with a trivial ``fit_from`` for dispatch tests.""" def __init__(self, **kwargs): super().__init__(**kwargs) @@ -27,52 +31,37 @@ def fit_from(self, instance): return ("fit", instance) -def test_default_flag_is_false(): +def test_default_flags(): analysis = af.Analysis() assert analysis._use_jax is False - assert analysis._use_jax_for_visualization is False assert analysis.supports_jax_visualization is False -def test_flag_requires_use_jax(caplog): - with caplog.at_level("WARNING"): - analysis = af.Analysis(use_jax=False, use_jax_for_visualization=True) - assert analysis._use_jax_for_visualization is False - assert any("requires use_jax=True" in r.message for r in caplog.records) - - -@pytest.mark.skipif(not _jax_installed(), reason="jax not installed; fallback path tested below") -def test_flag_accepted_when_use_jax_true(): - analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True) +@pytest.mark.skipif(not _jax_installed(), reason="jax not installed") +def test_use_jax_enables_jax_visualization(): + analysis = af.Analysis(use_jax=True) assert analysis._use_jax is True - assert analysis._use_jax_for_visualization is True assert analysis.supports_jax_visualization is True -@pytest.mark.skipif(_jax_installed(), reason="jax installed; happy path tested above") -def test_use_jax_true_falls_back_to_numpy_when_jax_missing(recwarn): - """When jax isn't installed, use_jax=True should silently downgrade - to use_jax=False after emitting a UserWarning. Affects 3.9/3.10 - where the [jax] extra is gated out.""" - analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True) +@pytest.mark.skipif(_jax_installed(), reason="jax installed") +def test_use_jax_true_falls_back_when_jax_missing(recwarn): + analysis = af.Analysis(use_jax=True) assert analysis._use_jax is False - assert analysis._use_jax_for_visualization is False assert any("JAX is not installed" in str(w.message) for w in recwarn) -def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch): +def test_pyauto_disable_jax_env_var(monkeypatch): monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1") - analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True) + analysis = af.Analysis(use_jax=True) assert analysis._use_jax is False - assert analysis._use_jax_for_visualization is False -def test_fit_for_visualization_works_without_flag(): +def test_fit_for_visualization_delegates_to_fit_from(): analysis = _FittableAnalysis() result = analysis.fit_for_visualization(instance="sentinel") assert result == ("fit", "sentinel") assert analysis.fit_from_calls == 1 - assert getattr(analysis, "_jitted_fit_from", None) is None def test_subclass_can_override_supports_jax_visualization(): @@ -82,5 +71,4 @@ def supports_jax_visualization(self): return True analysis = ForcedAnalysis() - assert analysis._use_jax_for_visualization is False assert analysis.supports_jax_visualization is True