Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions searches/README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,72 @@
# searches

Sampler / search profiling — Nautilus first, with the folder layout designed so that other samplers (Dynesty, Emcee, BlackJAX, NumPyro, NSS, LBFGS, PocoMC) can slot in under their own prompts.
Sampler / search profiling for the PyAutoLens HST MGE lens-modelling likelihood. Each subfolder drives a single sampler family directly against the real likelihood — bypassing `af.NonLinearSearch` — so the per-sampler convergence characteristics (wall time, likelihood evaluations, posterior ESS, evals/time to ML) can be compared on identical footing.

Populated by **Phase 3** of the `autolens_profiling` roadmap. Will mirror `autolens_workspace_developer/searches_minimal/`.
## Why bypass `af.NonLinearSearch`?

See the top-level [README](../README.md) for the full phase plan.
`af.NonLinearSearch` adds caching, multi-process forking, output formatting, and result hierarchies that are valuable for production fits but obscure the underlying sampler's cost. The scripts in this section call the sampler library directly and instrument every likelihood evaluation through a shared `MLTracker`. The result is a clean apples-to-apples comparison of:

- Wall time and likelihood-evaluation count to **Nautilus's default convergence** (`n_eff=10000`, `f_live=0.01`).
- Per-evaluation likelihood cost (NumPy baseline vs JAX-JIT'd path).
- Evals-to-ML and time-to-ML — the eval index and wall time at which the running max log L first came within 1 nat of the final maximum.

## Shared helpers

| File | Role |
|------|------|
| [`_setup.py`](./_setup.py) | Builds the HST imaging dataset, the MGE + Isothermal + ExternalShear lens model with an MGE source bulge, and the `AnalysisImaging` object. The dataset, mask, and model mirror the reference setup in [`likelihood/imaging/mge.py`](../likelihood/imaging/mge.py) so likelihood values are directly comparable across the two sections. |
| [`_metrics.py`](./_metrics.py) | `MLTracker` — records the log-likelihood and wall time of every evaluation, computes evals-to-ML and time-to-ML headline numbers. Also offers `MLTracker.from_log_l_history` for samplers that JIT their likelihood and only expose log-L per dead/live point post hoc. |

## Supported samplers

| Sampler | Folder | Status | Notes |
|---------|--------|--------|-------|
| Nautilus | [`nautilus/`](./nautilus/README.md) | ✓ profiled | Both NumPy (`simple.py`) and JAX-JIT (`jax.py`) variants. |
| Dynesty | _planned_ | not yet mirrored | Static nested sampling; reference scripts at `autolens_workspace_developer/searches_minimal/dynesty_simple.py`. |
| Emcee | _planned_ | not yet mirrored | Affine-invariant ensemble MCMC. |
| BlackJAX (NUTS, SMC) | _planned_ | not yet mirrored | Pure-JAX HMC family. Gradient pathology surfaced in upstream `sweep_findings.md`; HMC viability depends on first fixing NaN-gradient hot spots. |
| NumPyro (ESS) | _planned_ | not yet mirrored | Ensemble slice sampler under JAX. |
| PocoMC | _planned_ | not yet mirrored | Preconditioned Monte Carlo. |
| NSS (simple, jit, grad) | _planned_ | not yet mirrored | Nested slice sampler; `nss_jit.py` shows VRAM ceiling on consumer GPUs (see `sweep_findings.md`). |
| LBFGS | _planned_ | not yet mirrored | Not a sampler; serves as the maximum-likelihood reference point. |

Each row above corresponds to one or more scripts under `autolens_workspace_developer/searches_minimal/`; the mirror migration here under their own follow-up prompts.

## Versioned artifacts

Each script writes a JSON + PNG pair to:

```
results/searches/<sampler>/<script>_summary_v<al.__version__>.{json,png}
```

The JSON carries the structured timings + sampler config + best-fit summary. The PNG is a bar chart of the headline timings (wall time, time per eval, time to ML; plus JIT compile time on JAX scripts).

Old versions are retained alongside new ones; Phase 4's dashboard surfaces the latest per axis.

## Running a script

From the repo root (cwd matters because `_setup.build_dataset()` resolves `dataset/imaging/hst/` relative to the repo root via `Path(__file__).resolve().parent.parent`):

```bash
cd autolens_profiling
python searches/nautilus/simple.py
python searches/nautilus/jax.py
```

Or as modules:

```bash
python -m searches.nautilus.simple
python -m searches.nautilus.jax
```

Both invocation styles work — each script injects the repo root into `sys.path` before importing `searches._{setup,metrics}` for robustness.

**Requirements:** `nautilus-sampler` for the Nautilus scripts (`pip install nautilus-sampler`). The JAX variant additionally needs a working JAX install.

**Codex / sandboxed runs:**

```bash
NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python searches/nautilus/simple.py
```
81 changes: 81 additions & 0 deletions searches/_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
Shared per-evaluation tracker used by every script in this folder.

Wrap a log-likelihood callable with ``MLTracker.wrap`` (or call
``tracker.record(log_l)`` manually) and the tracker stores the log L and
wall-clock time of every evaluation. After the run, ``finalise`` returns
the eval index and wall time at which the running max log L first came
within ``tolerance`` nats of the final maximum -- the "evals to ML" /
"time to ML" headline numbers used in the comparison.

For JAX paths where the likelihood runs inside ``jax.jit`` (and a Python
callback is impossible without forcing a host round-trip), use
``MLTracker.from_log_l_history`` instead with the full per-eval log L
sequence reconstructed from the sampler's dead-point + live-point state.
"""

from __future__ import annotations

import time
from typing import Callable, Optional, Sequence


class MLTracker:
"""Record per-evaluation log L and wall time, compute evals/time to ML."""

def __init__(self):
self.t0 = time.time()
self.history_log_l: list[float] = []
self.history_wall: list[float] = []

def record(self, log_l: float) -> None:
self.history_log_l.append(float(log_l))
self.history_wall.append(time.time() - self.t0)

def wrap(self, fn: Callable) -> Callable:
"""Decorate a log-likelihood callable so every call is recorded."""

def wrapped(*args, **kwargs):
log_l = fn(*args, **kwargs)
self.record(log_l)
return log_l

return wrapped

def finalise(
self, max_log_l: Optional[float] = None, tolerance: float = 1.0
) -> tuple[Optional[int], Optional[float]]:
"""Return (evals_to_ml, time_to_ml) — the eval index and wall time at
which the running max first came within ``tolerance`` nats of the
final maximum. ``(None, None)`` if no evaluations were recorded."""
if not self.history_log_l:
return None, None
if max_log_l is None:
max_log_l = max(self.history_log_l)
target = max_log_l - tolerance
for i, log_l in enumerate(self.history_log_l):
if log_l >= target:
return i + 1, self.history_wall[i]
return None, None

@staticmethod
def from_log_l_history(
log_l_history: Sequence[float],
total_sampling_time: float,
tolerance: float = 1.0,
) -> tuple[Optional[int], Optional[float]]:
"""Variant for samplers that run their likelihood inside JIT and only
expose log L per dead/live point post hoc. ``time_to_ml`` is linearly
interpolated from the total sampling time -- evaluations are assumed
evenly distributed over the run, which is a reasonable approximation
for nested sampling (each step is roughly the same cost)."""
if not log_l_history:
return None, None
max_log_l = max(log_l_history)
target = max_log_l - tolerance
for i, log_l in enumerate(log_l_history):
if log_l >= target:
evals_to_ml = i + 1
time_to_ml = total_sampling_time * (evals_to_ml / len(log_l_history))
return evals_to_ml, time_to_ml
return None, None
108 changes: 108 additions & 0 deletions searches/_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""
Shared setup for the ``searches/`` profiling scripts.

Builds the HST imaging dataset, the MGE + Isothermal + ExternalShear lens model
with an MGE source bulge, and the ``AnalysisImaging`` object used by every
sampler in this section. The dataset, mask, and model mirror the reference setup
in ``likelihood/imaging/mge.py`` so the likelihood value is directly comparable
across the two sections.

Usage
-----

from searches._setup import build_dataset, build_model, build_analysis

dataset = build_dataset()
model = build_model(mask_radius=3.5)
analysis = build_analysis(dataset, use_jax=False)
"""

from pathlib import Path

import numpy as np

import autofit as af
import autolens as al

_WORKSPACE_ROOT = Path(__file__).resolve().parent.parent
_DATASET_SUBPATH = Path("dataset") / "imaging" / "hst"

PIXEL_SCALE = 0.05
MASK_RADIUS = 3.5


def build_dataset(mask_radius: float = MASK_RADIUS) -> al.Imaging:
"""Load the HST imaging dataset with mask + radial-bin over-sampling applied."""
dataset_path = _DATASET_SUBPATH

if al.util.dataset.should_simulate(str(dataset_path)):
raise FileNotFoundError(
f"Input dataset missing at '{dataset_path}'. The autolens_profiling "
f"repo mirrors only the curated datasets needed for default smoke "
f"runs. To regenerate, use the source-of-truth script at "
f"autolens_workspace_developer/jax_profiling/dataset_setup/imaging.py "
f"and copy the result into autolens_profiling/dataset/."
)

dataset = al.Imaging.from_fits(
data_path=dataset_path / "data.fits",
psf_path=dataset_path / "psf.fits",
noise_map_path=dataset_path / "noise_map.fits",
pixel_scales=PIXEL_SCALE,
)

mask = al.Mask2D.circular(
shape_native=dataset.shape_native,
pixel_scales=dataset.pixel_scales,
radius=mask_radius,
)
dataset = dataset.apply_mask(mask=mask)
dataset = dataset.apply_over_sampling(over_sample_size_lp=4)

over_sample_size = al.util.over_sample.over_sample_size_via_radial_bins_from(
grid=dataset.grid,
sub_size_list=[4, 2, 1],
radial_list=[0.3, 0.6],
centre_list=[(0.0, 0.0)],
)
dataset = dataset.apply_over_sampling(over_sample_size_lp=over_sample_size)
return dataset


def build_model(
mask_radius: float = MASK_RADIUS, total_gaussians: int = 20
) -> af.Collection:
"""Build the lens + source model used in ``jax_profiling/imaging/mge.py``."""
lens_bulge = al.model_util.mge_model_from(
mask_radius=mask_radius,
total_gaussians=total_gaussians,
centre_prior_is_uniform=True,
)
mass = af.Model(al.mp.Isothermal)
shear = af.Model(al.mp.ExternalShear)
lens = af.Model(al.Galaxy, redshift=0.5, bulge=lens_bulge, mass=mass, shear=shear)

source_bulge = al.model_util.mge_model_from(
mask_radius=mask_radius,
total_gaussians=total_gaussians,
centre_prior_is_uniform=False,
)
source = af.Model(al.Galaxy, redshift=1.0, bulge=source_bulge)

return af.Collection(galaxies=af.Collection(lens=lens, source=source))


def build_analysis(dataset: al.Imaging, use_jax: bool = False) -> al.AnalysisImaging:
"""Build the analysis object. Set ``use_jax=True`` for the pure-JAX path."""
return al.AnalysisImaging(dataset=dataset, use_jax=use_jax)


def format_best_fit(instance) -> str:
"""Terse one-line summary of the lens mass + shear of a best-fit instance."""
mass = instance.galaxies.lens.mass
shear = instance.galaxies.lens.shear
return (
f"lens.mass.einstein_radius={mass.einstein_radius:.4f} "
f"lens.mass.centre=({mass.centre[0]:.3f}, {mass.centre[1]:.3f}) "
f"shear=({shear.gamma_1:.4f}, {shear.gamma_2:.4f})"
)
47 changes: 47 additions & 0 deletions searches/nautilus/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# searches/nautilus

[Nautilus](https://github.com/johannesulf/nautilus) is a production nested-importance sampler that combines neural-network-based importance sampling with traditional nested sampling. It's gradient-free, so it sidesteps the JAX-gradient pathologies that affect HMC-family samplers on this likelihood, and is a strong baseline for end-to-end "what does a real sampler do on this lens model" timing.

These scripts drive Nautilus directly against the HST MGE imaging likelihood, bypassing `af.NonLinearSearch`. They are wiring tests + run-time profiling, **not** converged science fits — `n_live=200` is below what you'd use in production, but enough to see per-evaluation cost and reach the default `n_eff=10000` / `f_live=0.01` termination on this MGE setup.

## Scripts

| Script | Likelihood backend | What it profiles |
|--------|--------------------|------------------|
| [`simple.py`](./simple.py) | NumPy (`use_jax=False`) | Baseline: pure-NumPy log-likelihood passed straight to Nautilus. Highest per-evaluation cost; useful as the reference point against which JAX speedup is measured. |
| [`jax.py`](./jax.py) | JAX JIT (`use_jax=True`, `jax.jit`-compiled likelihood) | Reports JIT compile time separately. Per-evaluation cost is JAX kernel + a single Python ↔ JAX boundary crossing per call (Nautilus itself is NumPy-only). Compare versus a pure-JAX nested sampler like NSS-JIT (not yet mirrored) for the no-boundary-crossing variant. |

Both share the same Nautilus configuration so timings are directly comparable: `n_live=200`, default `n_eff=10000`, default `f_live=0.01`. Both use the shared `_setup` / `_metrics` from one folder up.

## What each script reports

- **Best fit**: max-likelihood lens mass / shear parameters (one-line summary).
- **Max log L** and **log evidence**.
- **Wall time** for the sampling phase (excluding JIT compile for `jax.py`).
- **JIT compile time** (one-shot warmup; `jax.py` only).
- **Likelihood evaluations** and **time per eval** (ms).
- **ESS** (effective sample size) and **posterior samples**.
- **Convergence** indicator (Nautilus's `n_eff` / `f_live` defaults are reached).
- **Evals to ML** and **time to ML** via the shared `MLTracker`.

The headline JSON+PNG pair is written to `results/searches/nautilus/` per the [section README](../README.md#versioned-artifacts) convention.

## Headline run-times (populated by Phase 4)

| Script | Backend | Wall time | Time / eval | Evals → ML | Time → ML |
|--------|---------|-----------|-------------|-----------|-----------|
| `simple.py` | NumPy | _populated_ | _populated_ | _populated_ | _populated_ |
| `jax.py` | JAX JIT | _populated_ | _populated_ | _populated_ | _populated_ |

Numbers are filled in by Phase 4's `scripts/build_readme.py` from the latest `*_summary_v<version>.json` under `results/searches/nautilus/`.

## Expected behaviour

For reference: prior sweep runs on this exact MGE setup (recorded in `autolens_workspace_developer/searches_minimal/sweep_findings.md`) put converged log-evidence at around **logZ ≈ -169k**. A non-converged early-stop reading of `logZ ≈ -191k` is roughly what you'll see after a few minutes of sampling. The likelihood landscape anneals slowly — fully converged runs at `n_live=100` take ~30–60 minutes on GPU.

The JAX variant's wall time is dominated by the NumPy/JAX boundary crossings, not the JAX kernel. A future NSS-JIT mirror will surface the no-boundary-crossing alternative.

## Caveats

- **`use_jax=True` and JIT compile**: `_setup.build_analysis(dataset, use_jax=True)` returns an analysis object that the Nautilus wrapper feeds via `jax.jit`. If the underlying JAX-jitted likelihood path has an upstream regression (see [PyAutoLens#514](https://github.com/PyAutoLabs/PyAutoLens/issues/514)), the `jax.py` script may produce different log evidence than `simple.py` — that's an upstream issue, not a Nautilus issue.
- **GPU memory**: on a 6 GB consumer GPU (e.g. RTX 2060), `jax.py` with `n_live=200` fits comfortably for Nautilus (gradient-free, no curvature storage). The same `n_live` causes OOM in NSS-JIT (which stores curvature for HMC-style moves) per the upstream sweep findings.
Loading