
feat: latent variable runtime profiling package #23

@Jammy2211


Overview

Adds a latent/ runtime-profiling package to autolens_profiling, mirroring the existing likelihood_runtime/ structure. Provides per-latent timing across CPU/GPU × fp64/mp (mixed-precision) configurations so the cache behaviour of LensCalc.einstein_radius_jit_from() (the only computationally non-trivial latent) can be quantified empirically.

Closes the eighth sub-prompt of the latent_refactor epic.

Plan

  • Mirror likelihood_runtime/ layout: latent/README.md, latent/sweep.py, latent/aggregate.py, and latent/<category>/<latent_name>.py per-latent scripts.
  • Per-latent scripts (one per default-enabled key in autolens/config/latent.yaml): each sets up a minimal context (tracer + fit + grid) and times the latent function in isolation, reporting an eager NumPy baseline, single-JIT compile + first-call + steady-state timings, and vmap-batched evaluation (where applicable).
  • The methodology emphasises the first-call vs cached-call split for effective_einstein_radius, since the closure cache at lens_calc.py:1580-1586 is the main thing worth profiling here. The three flux latents (total_lens_flux_mujy, total_lensed_source_flux_mujy, total_source_flux_mujy) are trivially cheap and serve as baselines.
  • sweep.py and aggregate.py are heavily adapted copies of their likelihood_runtime/ counterparts, with the cell list replaced and the entry-point script convention pointed at the new package.
  • Workspace config/latent.yaml enables the 5 library latents so the profiling scripts can isolate them via latent_keys_enabled() filtering.
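A minimal sketch of the isolation idea, assuming the per-latent scripts see a flat mapping of latent key to on/off toggle. The helpers `enabled_keys` and `isolate` are hypothetical illustrations, not the real `latent_keys_enabled()` API; only the five key names come from this plan. The point it shows: a latent can only be isolated if the workspace config has it switched on.

```python
# The five library latents listed in this plan (all imaging-derived).
LATENT_KEYS = (
    "total_lens_flux_mujy",
    "total_lensed_source_flux_mujy",
    "total_source_flux_mujy",
    "magnification",
    "effective_einstein_radius",
)


def enabled_keys(toggles):
    """Hypothetical stand-in for latent_keys_enabled(): switched-on keys in catalogue order."""
    return [key for key in LATENT_KEYS if toggles.get(key, False)]


def isolate(toggles, target):
    """Return toggles with every latent off except `target` (which must already be enabled)."""
    if not toggles.get(target, False):
        raise ValueError(f"{target} is not enabled in the workspace config")
    return {key: key == target for key in toggles}


# Mirrors the workspace config/latent.yaml enabling all five latents.
workspace = {key: True for key in LATENT_KEYS}
only_einstein = isolate(workspace, "effective_einstein_radius")
print(enabled_keys(only_einstein))  # ['effective_einstein_radius']
```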
Detailed implementation plan

Affected Repositories

  • autolens_profiling (primary, only)

Work Classification

Workspace (profiling repo, no library code)

Branch Survey

Repository             | Current Branch | Dirty?
./autolens_profiling   | main           | (check on worktree create)

Suggested branch: feature/latent-profiling
Worktree root: ~/Code/PyAutoLabs-wt/latent-profiling/

Implementation Steps

  1. Create the latent/ directory with a single subdirectory, latent/imaging/ (all 5 latents are imaging-derived in the day-1 catalogue).
  2. Write latent/sweep.py — copy/modify likelihood_runtime/sweep.py. Cell list becomes the 5 latents:
    • imaging/total_lens_flux_mujy
    • imaging/total_lensed_source_flux_mujy
    • imaging/total_source_flux_mujy
    • imaging/magnification
    • imaging/effective_einstein_radius
  3. Write latent/aggregate.py — copy/modify likelihood_runtime/aggregate.py. Schema stays per-config-per-cell.
  4. Write 5 per-latent scripts at latent/imaging/<name>.py. Each:
    • Uses al.fixtures.make_masked_imaging_7x7() for the dataset (smoke-style minimal fixture).
    • Builds an SIE lens + Sersic source model and fits it to obtain a real FitImaging.
    • Enables ONLY its own latent via conf.instance["latent"] override before constructing the Analysis.
    • Times analysis.compute_latent_variables(parameters, model) directly (single-call) and via jax.vmap (batched).
    • For effective_einstein_radius specifically: explicitly time first-call vs steady-state to surface the _zero_contour_cache behaviour.
    • Writes a JSON output with the standard schema (matching likelihood_runtime per-cell JSON shape).
  5. Write latent/README.md — sections from likelihood_runtime/README.md (Methodology, 6-config matrix, Mixed precision, Scripts, Driving the matrix, How to read output) plus a "When the cache helps / hurts" section specific to the einstein-radius latent.
  6. Add latent/ outputs to .gitignore (mirror likelihood_runtime).
  7. Add a workspace config/latent.yaml enabling all 5 library latents (overriding the library default of false so the profiling scripts can isolate each one via its toggle).
  8. Verify by running one config (local_cpu_fp64) on one latent (effective_einstein_radius) end-to-end.
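The first-call vs steady-state split in step 4 can be sketched as a small timing harness. This is an illustration under assumptions, not the per-latent script itself: in the real scripts `fn` would wrap `analysis.compute_latent_variables` or the vmapped fitness, whereas here a toy cached function stands in for the `_zero_contour_cache` behaviour, and the JSON field names are hypothetical rather than the likelihood_runtime schema. Results are forced with `block_until_ready` when the output has it (JAX arrays do); for pytree outputs, `jax.block_until_ready` would be the more general choice.

```python
import json
import statistics
import time


def _wait(out):
    """Block on asynchronous results (JAX arrays expose block_until_ready) before stopping the clock."""
    ready = getattr(out, "block_until_ready", None)
    return ready() if callable(ready) else out


def time_first_vs_steady(fn, *args, repeats=10):
    """Time the first call on its own, then `repeats` steady-state calls."""
    start = time.perf_counter()
    _wait(fn(*args))
    first = time.perf_counter() - start

    steady = []
    for _ in range(repeats):
        start = time.perf_counter()
        _wait(fn(*args))
        steady.append(time.perf_counter() - start)

    return {
        "first_call_s": first,
        "steady_median_s": statistics.median(steady),
        "steady_min_s": min(steady),
        "repeats": repeats,
    }


# Toy stand-in for a cached latent: slow on a cache miss, a dict lookup afterwards.
_toy_cache = {}


def toy_cached_latent(n):
    if n not in _toy_cache:
        time.sleep(0.05)  # simulates trace/compile plus contour construction on a miss
        _toy_cache[n] = float(sum(range(n)))
    return _toy_cache[n]


if __name__ == "__main__":
    record = {
        "cell": "imaging/effective_einstein_radius",  # hypothetical field names
        "config": "local_cpu_fp64",
        "timings": time_first_vs_steady(toy_cached_latent, 1000),
    }
    print(json.dumps(record, indent=2))
```

Reporting the minimum alongside the median makes it easier to spot noisy steady-state runs, while the separate first-call figure is what exposes the cache.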

Key Files

  • latent/README.md — new methodology doc (Opus prose)
  • latent/sweep.py — new driver (copy-modify from likelihood_runtime)
  • latent/aggregate.py — new aggregator (copy-modify from likelihood_runtime)
  • latent/imaging/{total_lens_flux_mujy,total_lensed_source_flux_mujy,total_source_flux_mujy,magnification,effective_einstein_radius}.py — 5 per-latent scripts
  • config/latent.yaml — workspace overrides
  • .gitignore — exclude profiling output dirs
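A sketch of how sweep.py might expand the config × cell matrix into one subprocess per run. The four config names other than local_cpu_fp64 are illustrative assumptions, and mixed-precision switching is left out because its mechanism is project-specific. `JAX_PLATFORMS` and `JAX_ENABLE_X64` are standard JAX environment variables; GPU runs get `XLA_PYTHON_CLIENT_MEM_FRACTION=0.5` from the constraints below and are launched under `nice`, one way to apply the renice.

```python
import os
import sys
from itertools import product
from pathlib import Path

# The five cells from step 2 of the implementation plan.
CELLS = (
    "imaging/total_lens_flux_mujy",
    "imaging/total_lensed_source_flux_mujy",
    "imaging/total_source_flux_mujy",
    "imaging/magnification",
    "imaging/effective_einstein_radius",
)

# Illustrative CPU/GPU x fp64/mp configs; only local_cpu_fp64 is named in the plan.
CONFIGS = {
    "local_cpu_fp64": {"device": "cpu", "precision": "fp64"},
    "local_cpu_mp": {"device": "cpu", "precision": "mp"},
    "local_gpu_fp64": {"device": "gpu", "precision": "fp64"},
    "local_gpu_mp": {"device": "gpu", "precision": "mp"},
}


def build_runs(root=Path("latent"), niceness=10):
    """Expand the config x cell matrix into one subprocess spec per run."""
    runs = []
    for (name, cfg), cell in product(CONFIGS.items(), CELLS):
        env = dict(os.environ)
        cmd = [sys.executable, str(root / f"{cell}.py")]
        if cfg["device"] == "cpu":
            env["JAX_PLATFORMS"] = "cpu"  # keep JAX off any visible GPU
        else:
            # GPU runs: cap XLA's preallocation and lower process priority.
            env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
            cmd = ["nice", "-n", str(niceness), *cmd]
        if cfg["precision"] == "fp64":
            env["JAX_ENABLE_X64"] = "1"
        # Mixed-precision ("mp") switching is project-specific and not modelled here.
        runs.append({"config": name, "cell": cell, "cmd": cmd, "env": env})
    return runs


if __name__ == "__main__":
    for run in build_runs():
        # A real driver would call subprocess.run(run["cmd"], env=run["env"], check=True).
        print(run["config"], " ".join(run["cmd"]))
```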

Constraints to preserve

  • Per memory feedback_jax_pure_callback_const_fold: profile via vmap, not single-jit on a concrete instance — single-jit can look 20-30× faster than vmap because pure_callback gets constant-folded.
  • Per memory feedback_jax_validation_vmap_not_jit: use fitness._vmap(jnp.array(params)) to force tracer propagation. The single-jit-on-concrete trap that hid latent regressions for #536 applies here too.
  • Per memory feedback_jax_closure_cache_busts: don't construct fresh closures per call — preserve the LensCalc's _zero_contour_cache mechanism.
  • Per memory feedback_jax_gpu_prealloc: document XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 + renice for GPU runs in the README.
  • Per memory feedback_ship_workspace_binary_leak: profiling output dirs (PNG / JSON / log) must be .gitignored before shipping.
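The closure-cache constraint can be illustrated with a simplified model. JAX's trace/compile caching is tied to the wrapped callable, so building a fresh closure on every call forces repeated work. The `CompileCache` class below is a hypothetical toy keyed on function identity; it is not how `_zero_contour_cache` is implemented, only a demonstration of why reusing one closure matters.

```python
class CompileCache:
    """Toy model of a compile cache keyed on the identity of the function object."""

    def __init__(self):
        self._compiled = {}
        self.compiles = 0

    def run(self, fn, x):
        if fn not in self._compiled:
            self.compiles += 1       # stands in for an expensive trace + compile
            self._compiled[fn] = fn  # a real cache would store the compiled artefact
        return self._compiled[fn](x)


def make_latent(scale):
    def latent(x):
        return scale * x
    return latent


# Anti-pattern: a fresh closure on every call, so every call is a cache miss.
bad = CompileCache()
for x in range(5):
    bad.run(make_latent(2.0), x)

# Pattern to preserve: build the closure once and reuse it.
good = CompileCache()
latent = make_latent(2.0)
for x in range(5):
    good.run(latent, x)

print(bad.compiles, good.compiles)  # 5 1
```

Holding the function object itself as the key (rather than `id(fn)`) keeps discarded closures alive in the toy cache, which avoids false hits from recycled object ids.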

Original Prompt


Parent epic: PyAutoPrompt/z_features/latent_refactor.md. Depends on PyAutoLens #534 (shipped).

Mirror autolens_profiling/likelihood_runtime/ structure with a latent/ package: README + sweep + aggregate + per-latent scripts. Emphasis on first-call vs cached-call timing for effective_einstein_radius (closure cache behaviour). Sonnet for per-latent script bodies, Opus for README. Honour all the JAX-profiling memories (vmap not jit on concrete, no constant-folding traps, GPU prealloc, closure cache busts).
