Overview
Adds a latent/ runtime-profiling package to autolens_profiling, mirroring the existing likelihood_runtime/ structure. It provides per-latent timing across CPU/GPU × fp64/mixed-precision (mp) configurations so that the cache behaviour of LensCalc.einstein_radius_jit_from() (behind effective_einstein_radius, the only computationally non-trivial latent) can be quantified empirically.
Closes the eighth sub-prompt of the latent_refactor epic.
Plan
- Mirror the likelihood_runtime/ layout: latent/README.md, latent/sweep.py, latent/aggregate.py, and per-latent scripts at latent/<category>/<latent_name>.py.
- Per-latent scripts (one per library latent key in autolens/config/latent.yaml): set up a minimal context (tracer + fit + grid) and time the latent function in isolation. Report the eager numpy baseline; single-JIT compile, first-call and steady-state times; and vmap batched evaluation (where applicable).
- Methodology emphasis on the first-call vs cached-call split for effective_einstein_radius: the closure cache at lens_calc.py:1580-1586 is the main thing worth profiling here. The three flux latents (total_lens_flux_mujy, total_lensed_source_flux_mujy, total_source_flux_mujy) are trivially cheap and serve as baselines.
- sweep.py and aggregate.py are heavy copy-adapts of their likelihood_runtime/ counterparts, with the cell list replaced and the entry-point script convention pointed at the new package.
- Workspace config/latent.yaml enables the 5 library latents so the profiling scripts can isolate them via latent_keys_enabled() filtering.
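A minimal, stdlib-only sketch of the first-call vs cached-call split the methodology emphasises. The helper time_first_vs_steady and the lru_cache-backed toy_cached_latent are hypothetical stand-ins (not autolens APIs): the sleep mimics expensive first-call work and lru_cache plays the role of the closure cache.

```python
import functools
import statistics
import time
from typing import Any, Callable, Dict


def time_first_vs_steady(fn: Callable[..., Any], *args: Any, n_steady: int = 20) -> Dict[str, float]:
    """Time the first call on its own, then the median of n_steady repeat calls.

    For a jitted latent the first call also pays for tracing, compilation and
    any cache population. If fn returns JAX arrays, a real script should call
    jax.block_until_ready on the result before stopping each clock.
    """
    start = time.perf_counter()
    fn(*args)
    first_call_s = time.perf_counter() - start

    repeats = []
    for _ in range(n_steady):
        start = time.perf_counter()
        fn(*args)
        repeats.append(time.perf_counter() - start)

    return {
        "first_call_s": first_call_s,
        "steady_state_median_s": statistics.median(repeats),
    }


# Hypothetical stand-in for a cached latent: the sleep mimics the expensive
# first-call work, and lru_cache plays the role of the closure cache.
@functools.lru_cache(maxsize=None)
def toy_cached_latent(key: int) -> float:
    time.sleep(0.05)
    return float(key) ** 0.5


timings = time_first_vs_steady(toy_cached_latent, 4)
print(timings)
```

Reporting the median of the repeats, rather than the mean, keeps one slow outlier (GC pause, scheduler hiccup) from blurring the cached-call number.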
Detailed implementation plan
Affected Repositories
- autolens_profiling (primary, only)
Work Classification
Workspace (profiling repo, no library code)
Branch Survey
| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_profiling | main | (check on worktree create) |
Suggested branch: feature/latent-profiling
Worktree root: ~/Code/PyAutoLabs-wt/latent-profiling/
Implementation Steps
- Create the latent/ directory with a latent/imaging/ subdirectory (all 5 latents are imaging-derived in the day-1 catalogue).
- Write latent/sweep.py by copying and modifying likelihood_runtime/sweep.py. The cell list becomes the 5 latents:
  - imaging/total_lens_flux_mujy
  - imaging/total_lensed_source_flux_mujy
  - imaging/total_source_flux_mujy
  - imaging/magnification
  - imaging/effective_einstein_radius
- Write latent/aggregate.py by copying and modifying likelihood_runtime/aggregate.py. The schema stays per-config-per-cell.
- Write 5 per-latent scripts at latent/imaging/<name>.py. Each script:
  - Uses al.fixtures.make_masked_imaging_7x7() for the dataset (smoke-style minimal fixture).
  - Builds an SIE lens + Sersic source model and fits it to obtain a real FitImaging.
  - Enables ONLY its own latent via a conf.instance["latent"] override before constructing the Analysis.
  - Times analysis.compute_latent_variables(parameters, model) both directly (single call) and via jax.vmap (batched).
  - For effective_einstein_radius specifically, explicitly times first call vs steady state to surface the _zero_contour_cache behaviour.
  - Writes a JSON output with the standard schema (matching the likelihood_runtime per-cell JSON shape).
- Write latent/README.md with the sections from likelihood_runtime/README.md (Methodology, 6-config matrix, Mixed precision, Scripts, Driving the matrix, How to read output), plus a "When the cache helps / hurts" section specific to the einstein-radius latent.
- Add the latent/ output directories to .gitignore (mirroring likelihood_runtime).
- Add a workspace config/latent.yaml enabling all 5 library latents (overriding the library's default of false so the profiling scripts can isolate each latent via toggle).
- Verify by running one config (local_cpu_fp64) on one latent (effective_einstein_radius) end-to-end.
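A stdlib-only sketch of two per-latent-script chores from the steps above: switching on exactly one latent, and writing a per-config-per-cell JSON record. Both helpers (single_latent_toggles, write_cell_result) and the JSON schema are hypothetical; how the toggle mapping is pushed into conf.instance["latent"] is not shown, because the exact config layout is an assumption here.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable

# The five day-1 library latents, as listed in the sweep cell list.
LATENT_KEYS = (
    "total_lens_flux_mujy",
    "total_lensed_source_flux_mujy",
    "total_source_flux_mujy",
    "magnification",
    "effective_einstein_radius",
)


def single_latent_toggles(target: str, keys: Iterable[str] = LATENT_KEYS) -> Dict[str, bool]:
    """Map every latent key to False except target (hypothetical helper)."""
    key_tuple = tuple(keys)
    if target not in key_tuple:
        raise KeyError(f"unknown latent: {target!r}")
    return {key: key == target for key in key_tuple}


def write_cell_result(output_dir: Path, config: str, cell: str, timings: Dict[str, float]) -> Path:
    """Write one per-config-per-cell JSON record; the schema here is assumed."""
    # e.g. <output_dir>/local_cpu_fp64/imaging/magnification.json
    path = output_dir / config / f"{cell}.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    record = {"config": config, "cell": cell, "timings": timings}
    path.write_text(json.dumps(record, indent=2))
    return path


# Usage with placeholder timings; a real script would fill these from its measurements.
toggles = single_latent_toggles("effective_einstein_radius")
out_path = write_cell_result(
    Path(tempfile.mkdtemp()),
    "local_cpu_fp64",
    "imaging/effective_einstein_radius",
    {"first_call_s": 1.0, "steady_state_median_s": 0.01},
)
```

Keeping the cell path (imaging/<name>) as the on-disk layout means each output file maps one-to-one onto a sweep cell.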
Key Files
- latent/README.md: new methodology doc (Opus prose)
- latent/sweep.py: new driver (copy-modify from likelihood_runtime)
- latent/aggregate.py: new aggregator (copy-modify from likelihood_runtime)
- latent/imaging/{total_lens_flux_mujy,total_lensed_source_flux_mujy,total_source_flux_mujy,magnification,effective_einstein_radius}.py: 5 per-latent scripts
- config/latent.yaml: workspace overrides
- .gitignore: exclude profiling output dirs
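A sketch of what a per-config-per-cell aggregator might do: fold every JSON record into a cell → config table and report which (config, cell) pairs the sweep failed to produce. The record schema, the aggregate/missing_cells helpers and the config name local_cpu_mp are hypothetical; only local_cpu_fp64 and the cell names come from this plan.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

Table = Dict[str, Dict[str, float]]  # cell -> config -> metric value


def aggregate(output_dir: Path, metric: str = "steady_state_median_s") -> Table:
    """Collect every per-config-per-cell JSON under output_dir into one table."""
    table: Table = {}
    for path in sorted(output_dir.rglob("*.json")):
        record = json.loads(path.read_text())
        table.setdefault(record["cell"], {})[record["config"]] = record["timings"][metric]
    return table


def missing_cells(table: Table, configs: Iterable[str], cells: Iterable[str]) -> List[Tuple[str, str]]:
    """List (config, cell) pairs the sweep was expected to produce but did not."""
    cell_list = list(cells)  # materialise so a generator is not exhausted by the first config
    return [
        (config, cell)
        for config in configs
        for cell in cell_list
        if config not in table.get(cell, {})
    ]


# Usage: two placeholder records written by hand in the assumed schema.
root = Path(tempfile.mkdtemp())
for config, value in (("local_cpu_fp64", 0.02), ("local_cpu_mp", 0.015)):
    path = root / config / "imaging" / "magnification.json"
    path.parent.mkdir(parents=True, exist_ok=True)
    record = {"config": config, "cell": "imaging/magnification", "timings": {"steady_state_median_s": value}}
    path.write_text(json.dumps(record))

table = aggregate(root)
gaps = missing_cells(
    table,
    ["local_cpu_fp64", "local_cpu_mp"],
    ["imaging/magnification", "imaging/effective_einstein_radius"],
)
```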
Constraints to preserve
- Per memory feedback_jax_pure_callback_const_fold: profile via vmap, not single-jit on a concrete instance; single-jit can look 20-30× faster than vmap because pure_callback gets constant-folded.
- Per memory feedback_jax_validation_vmap_not_jit: use fitness._vmap(jnp.array(params)) to force tracer propagation. The single-jit-on-concrete trap that hid latent regressions for #536 applies here too.
- Per memory feedback_jax_closure_cache_busts: don't construct fresh closures per call; preserve LensCalc's _zero_contour_cache mechanism.
- Per memory feedback_jax_gpu_prealloc: document XLA_PYTHON_CLIENT_MEM_FRACTION=0.5 + renice for GPU runs in the README.
- Per memory feedback_ship_workspace_binary_leak: profiling output dirs (PNG / JSON / log) must be .gitignored before shipping.
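A toy illustration of why fresh closures bust a cache. jax.jit keys its trace/compile caches on the wrapped Python function (among other things), so building a new closure on every call typically forces a retrace. ToyJitCache and make_latent below are hypothetical and model only that identity-keyed behaviour; they are not JAX internals or LensCalc's actual _zero_contour_cache.

```python
from typing import Callable, Dict


class ToyJitCache:
    """Toy compile cache keyed on function identity (hypothetical model)."""

    def __init__(self) -> None:
        self.compiles = 0
        self._cache: Dict[Callable[[float], float], Callable[[float], float]] = {}

    def __call__(self, fn: Callable[[float], float], x: float) -> float:
        if fn not in self._cache:  # functions hash and compare by identity
            self.compiles += 1  # a real jit would trace + compile here
            self._cache[fn] = fn
        return self._cache[fn](x)


def make_latent(scale: float) -> Callable[[float], float]:
    def latent(x: float) -> float:
        return scale * x

    return latent


# Anti-pattern: a fresh closure per call, so every call is a cache miss.
busting = ToyJitCache()
for x in (1.0, 2.0, 3.0):
    busting(make_latent(2.0), x)

# Preferred: build the closure once and reuse it, so only the first call misses.
reusing = ToyJitCache()
latent_fn = make_latent(2.0)
for x in (1.0, 2.0, 3.0):
    reusing(latent_fn, x)
```

Because the cache dict holds a reference to each closure, the three fresh closures are guaranteed to be distinct objects, which is what makes every call in the first loop a miss.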
Original Prompt
Parent epic: PyAutoPrompt/z_features/latent_refactor.md. Depends on PyAutoLens #534 (shipped).
Mirror autolens_profiling/likelihood_runtime/ structure with a latent/ package: README + sweep + aggregate + per-latent scripts. Emphasis on first-call vs cached-call timing for effective_einstein_radius (closure cache behaviour). Sonnet for per-latent script bodies, Opus for README. Honour all the JAX-profiling memories (vmap not jit on concrete, no constant-folding traps, GPU prealloc, closure cache busts).