Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 7 additions & 54 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,12 @@ class Analysis(ABC):
def __init__(
self,
use_jax: bool = False,
use_jax_for_visualization: bool = False,
**kwargs,
):
import os
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
use_jax = False
use_jax_for_visualization = False

# If the user requested JAX but it isn't installed (e.g. Python <3.11
# without the [jax] extra), fall back to numpy with a loud warning
# rather than crashing later when the analysis tries to jit-compile.
if use_jax:
import importlib.util
import warnings
Expand All @@ -88,55 +83,20 @@ def __init__(
stacklevel=2,
)
use_jax = False
use_jax_for_visualization = False

if use_jax_for_visualization and not use_jax:
logger.warning(
"use_jax_for_visualization=True requires use_jax=True; "
"disabling use_jax_for_visualization."
)
use_jax_for_visualization = False

self._use_jax = use_jax
self._use_jax_for_visualization = use_jax_for_visualization
self.kwargs = kwargs

def fit_for_visualization(self, instance):
"""
Build the fit used by the visualizer.

Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:

* ``use_jax_for_visualization=False`` (default) — plain
``self.fit_from(instance)``. Untouched by JAX.
* ``use_jax_for_visualization=True`` — lazily construct
``jax.jit(self.fit_from)`` on the first call and cache it on the
instance as ``_jitted_fit_from``, then call that for every
subsequent visualization. The first call pays the compile cost;
subsequent calls reuse the cached compiled function.

Caching is per-``Analysis`` instance so each analysis gets its own
compiled function keyed off that instance's closed-over state
(``self.dataset``, ``self.settings``, etc. — these ride as pytree
aux data via ``register_instance_pytree(FitImaging, no_flatten=...)``
in PyAutoLens).

For the JIT path to succeed, the ``Fit*`` return type (and every
nested autoarray / galaxy / lens type it carries) must be pytree-
registered. That wiring lives in each analysis subclass (see
``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
Variants that have not yet been pytree-audited must leave
``use_jax_for_visualization`` at its default of ``False``.
Delegates to ``self.fit_from(instance)``. When ``use_jax=True``,
the profile evaluations inside ``fit_from`` dispatch to JAX via
the decorator chain. The per-function JIT caches warm up on the
first call and are reused on all subsequent quick updates.
"""
if not self._use_jax_for_visualization:
return self.fit_from(instance=instance)

if getattr(self, "_jitted_fit_from", None) is None:
import jax

self._jitted_fit_from = jax.jit(self.fit_from)

return self._jitted_fit_from(instance)
return self.fit_from(instance=instance)

def __getattr__(self, item: str):
"""
Expand Down Expand Up @@ -444,15 +404,8 @@ def supports_background_update(self) -> bool:

@property
def supports_jax_visualization(self) -> bool:
"""
Whether the visualizer can work directly with JAX arrays.

Derived from the ``use_jax_for_visualization`` flag passed at
construction time. Subclasses may override to force a specific
answer (e.g. an Analysis that has been audited to support JAX
visualization unconditionally).
"""
return self._use_jax_for_visualization
"""Whether the visualizer can work directly with JAX arrays."""
return self._use_jax

def perform_quick_update(self, paths, instance):
raise NotImplementedError
Expand Down
29 changes: 29 additions & 0 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,35 @@ def __init__(
if self.paths is not None:
self.check_log_likelihood(fitness=self)

if (
self.iterations_per_quick_update is not None
and self._xp.__name__.startswith("jax")
):
self._warmup_visualization()

def _warmup_visualization(self):
"""Pre-compile the JAX operations used by ``fit_for_visualization``.

The first call to ``fit_for_visualization`` triggers ~200 small
per-function JAX JIT compilations (one per profile method per
decorator). Running them here moves that cost to search setup
so every quick update during sampling is fast.
"""
logger.info(
"Warming up visualization (one-time JAX compilation)..."
)
try:
instance = self.model.instance_from_prior_medians()
fit = self.analysis.fit_for_visualization(instance=instance)
_ = fit.model_data
except Exception:
logger.warning(
"Visualization warm-up failed (non-fatal); "
"first quick update may be slow."
)
else:
logger.info("Visualization warm-up complete.")

@property
def _xp(self):
return self.analysis._xp
Expand Down
46 changes: 17 additions & 29 deletions test_autofit/analysis/test_use_jax_for_visualization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
"""Tests for the ``use_jax_for_visualization`` flag on ``Analysis``."""
"""Tests for the visualization path on ``Analysis``.

The ``use_jax_for_visualization`` flag has been removed — visualization
now always follows ``use_jax``. These tests verify the simplified
``fit_for_visualization`` dispatch and the ``supports_jax_visualization``
property.
"""

import importlib.util

Expand All @@ -8,12 +14,10 @@


def _jax_installed() -> bool:
"""Check jax availability without importing it (per numpy-only rule)."""
return importlib.util.find_spec("jax") is not None


class _FittableAnalysis(af.Analysis):
"""Minimal Analysis subclass with a trivial ``fit_from`` for dispatch tests."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
Expand All @@ -27,52 +31,37 @@ def fit_from(self, instance):
return ("fit", instance)


def test_default_flag_is_false():
def test_default_flags():
analysis = af.Analysis()
assert analysis._use_jax is False
assert analysis._use_jax_for_visualization is False
assert analysis.supports_jax_visualization is False


def test_flag_requires_use_jax(caplog):
with caplog.at_level("WARNING"):
analysis = af.Analysis(use_jax=False, use_jax_for_visualization=True)
assert analysis._use_jax_for_visualization is False
assert any("requires use_jax=True" in r.message for r in caplog.records)


@pytest.mark.skipif(not _jax_installed(), reason="jax not installed; fallback path tested below")
def test_flag_accepted_when_use_jax_true():
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
@pytest.mark.skipif(not _jax_installed(), reason="jax not installed")
def test_use_jax_enables_jax_visualization():
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is True
assert analysis._use_jax_for_visualization is True
assert analysis.supports_jax_visualization is True


@pytest.mark.skipif(_jax_installed(), reason="jax installed; happy path tested above")
def test_use_jax_true_falls_back_to_numpy_when_jax_missing(recwarn):
"""When jax isn't installed, use_jax=True should silently downgrade
to use_jax=False after emitting a UserWarning. Affects 3.9/3.10
where the [jax] extra is gated out."""
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
@pytest.mark.skipif(_jax_installed(), reason="jax installed")
def test_use_jax_true_falls_back_when_jax_missing(recwarn):
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is False
assert analysis._use_jax_for_visualization is False
assert any("JAX is not installed" in str(w.message) for w in recwarn)


def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
def test_pyauto_disable_jax_env_var(monkeypatch):
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is False
assert analysis._use_jax_for_visualization is False


def test_fit_for_visualization_works_without_flag():
def test_fit_for_visualization_delegates_to_fit_from():
analysis = _FittableAnalysis()
result = analysis.fit_for_visualization(instance="sentinel")
assert result == ("fit", "sentinel")
assert analysis.fit_from_calls == 1
assert getattr(analysis, "_jitted_fit_from", None) is None


def test_subclass_can_override_supports_jax_visualization():
Expand All @@ -82,5 +71,4 @@ def supports_jax_visualization(self):
return True

analysis = ForcedAnalysis()
assert analysis._use_jax_for_visualization is False
assert analysis.supports_jax_visualization is True
Loading