Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 8 additions & 37 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,9 @@ class Analysis(ABC):
def __init__(
self,
use_jax: bool = False,
use_jax_for_visualization: Optional[bool] = None,
use_jax_for_visualization: bool = False,
**kwargs,
):
"""
Parameters
----------
use_jax
Run the likelihood through ``jax.jit`` for the fast path. When JAX
is unavailable this silently falls back to numpy with a warning.
use_jax_for_visualization
Whether ``fit_for_visualization`` should dispatch through the
``jax.jit``-cached path. ``None`` (default) follows ``use_jax`` —
users who set ``use_jax=True`` automatically get JIT visualization.
Pass ``False`` to force the eager NumPy plotter even when
``use_jax=True``; pass ``True`` to opt in explicitly. Passing
``True`` while ``use_jax=False`` logs a warning and disables it.
"""
import os
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
use_jax = False
Expand Down Expand Up @@ -82,9 +68,6 @@ def __init__(
use_jax = False
use_jax_for_visualization = False

if use_jax_for_visualization is None:
use_jax_for_visualization = use_jax

if use_jax_for_visualization and not use_jax:
logger.warning(
"use_jax_for_visualization=True requires use_jax=True; "
Expand All @@ -100,40 +83,28 @@ def fit_for_visualization(self, instance):
"""
Build the fit used by the visualizer.

Dispatch over ``self.fit_from`` with a ``jax.jit`` fast path that
follows ``use_jax`` by default:
Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:

* ``self._use_jax_for_visualization`` is ``False`` — plain
``self.fit_from(instance)``. Untouched by JAX. This is the
resolved state when ``use_jax=False`` (the parameter default),
or when the user explicitly passed
``use_jax_for_visualization=False`` to opt out.
* ``self._use_jax_for_visualization`` is ``True`` — lazily construct
* ``use_jax_for_visualization=False`` (default) — plain
``self.fit_from(instance)``. Untouched by JAX.
* ``use_jax_for_visualization=True`` — lazily construct
``jax.jit(self.fit_from)`` on the first call and cache it on the
instance as ``_jitted_fit_from``, then call that for every
subsequent visualization. The first call pays the compile cost;
subsequent calls reuse the cached compiled function. This is the
resolved state when ``use_jax=True`` (the sentinel default
``use_jax_for_visualization=None`` follows ``use_jax``).
subsequent calls reuse the cached compiled function.

Caching is per-``Analysis`` instance so each analysis gets its own
compiled function keyed off that instance's closed-over state
(``self.dataset``, ``self.settings``, etc. — these ride as pytree
aux data via ``register_instance_pytree(FitImaging, no_flatten=...)``
in PyAutoLens).

``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``),
not the base class — this method is only callable on subclasses that
provide it. Downstream visualizers should prefer this over calling
``fit_from`` directly so the JIT seam stays in one place.

For the JIT path to succeed, the ``Fit*`` return type (and every
nested autoarray / galaxy / lens type it carries) must be pytree-
registered. That wiring lives in each analysis subclass (see
``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
Variants that have not yet been pytree-audited must pass
``use_jax_for_visualization=False`` explicitly when constructing
the analysis (or simply leave ``use_jax=False``).
Variants that have not yet been pytree-audited must leave
``use_jax_for_visualization`` at its default of ``False``.
"""
if not self._use_jax_for_visualization:
return self.fit_from(instance=instance)
Expand Down
11 changes: 0 additions & 11 deletions test_autofit/analysis/test_use_jax_for_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
assert analysis._use_jax_for_visualization is False


def test_pyauto_disable_jax_overrides_sentinel_default(monkeypatch):
"""PYAUTO_DISABLE_JAX=1 must still force both off even when the user
constructs Analysis(use_jax=True) and lets the sentinel resolve. This is
a numpy-only check — JAX-conditional sentinel-resolution assertions live
in autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py."""
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is False
assert analysis._use_jax_for_visualization is False


def test_fit_for_visualization_works_without_flag():
analysis = _FittableAnalysis()
result = analysis.fit_for_visualization(instance="sentinel")
Expand Down
Loading