
triage: drift in jit/ regression constants (autolens_profiling F1) #67

@Jammy2211

Overview

Follow-up F1 of the autolens_profiling z_feature (tracker: PyAutoPrompt/z_features/autolens_profiling.md). Phase 1 mirror smoke runs on PyAutoLens 2026.5.14.2 (2026-05-16) found two of the four hardcoded EXPECTED_LOG_LIKELIHOOD_* regression constants in jax_profiling/jit/ have drifted off the current eager log-likelihood. One drift is small (likely numerical), the other is huge with a sign change (almost certainly a real behaviour change in PointSolver / FitPositionsImagePairAll). The mirrored copies in PyAutoLabs/autolens_profiling/likelihood/ carry the same constants verbatim, so both repos need to be refreshed together (or — if a real bug is identified — held until the upstream PyAuto* bug is fixed).

Plan

  • Reproduce the drift cleanly in this repo (autolens_workspace_developer) on the current PyAutoLens release, recording the actual eager log-likelihood for each of the 4 constants.
  • Spot-smoke the 6 other jit scripts that weren't hit during the Phase 1 mirror smoke to find any other drifts (imaging/pixelization.py, imaging/delaunay.py, interferometer/pixelization.py, interferometer/delaunay.py, datacube/delaunay.py, point_source/source_plane.py).
  • For point_source/image_plane.py specifically: light bisect on PyAutoLens / PyAutoArray / PyAutoGalaxy since the constant was last set, focusing on PointSolver, FitPositionsImagePairAll, and positions-noise-map handling. Decide whether the new -362 value is correct or evidence of a real regression.
  • For imaging/mge.py: investigate whether the 0.6% drift is pure floating-point noise from a dependency bump or an algorithmic change (compare per-step JIT timings; if the timing profile keeps the same shape and every number moves only slightly, the drift is likely numerical).
  • Outcome: either refresh the constants (in both _developer and autolens_profiling) via a coordinated PR pair, OR file an upstream bug against the responsible PyAuto* repo and leave the assertion failing as load-bearing while the bug is open.

Detailed implementation plan

Affected Repositories

  • PyAutoLabs/autolens_workspace_developer (primary — where the source-of-truth constants live)
  • PyAutoLabs/autolens_profiling (mirror — constants carried verbatim in Phase 1; same refresh applies)
  • Potentially PyAutoLabs/PyAutoLens / PyAutoArray / PyAutoGalaxy (if a real bug is found — file separate issue)

Work Classification

Workspace work (scripts + regression assertion constants; no library code changes in this repo).

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| ./autolens_workspace_developer | main | dirty (pre-existing in-progress work unrelated to this task; worktree gives a clean copy) |
| ./autolens_profiling | main | clean (just shipped Phase 1) |

Suggested branch: feature/jit-regression-drift (on both repos when the time comes)
Worktree root: ~/Code/PyAutoLabs-wt/jit-regression-drift/

All four EXPECTED_LOG_LIKELIHOOD_* constants

| Script | Line | Constant | Phase 1 smoke result |
|---|---|---|---|
| jax_profiling/jit/imaging/mge.py | 853 | EXPECTED_LOG_LIKELIHOOD_HST = 27379.38890685539 | drifted to 27542.08 (+0.6%) |
| jax_profiling/jit/interferometer/mge.py | 489 | EXPECTED_LOG_LIKELIHOOD_SMA = -3153.8939746810656 | PASSED |
| jax_profiling/jit/point_source/image_plane.py | 444 | EXPECTED_LOG_LIKELIHOOD_IMAGE_PLANE = 0.07475703623045682 | drifted to -362.21 (sign change, ~5000×) |
| jax_profiling/jit/point_source/source_plane.py | 492 | EXPECTED_LOG_LIKELIHOOD_SOURCE_PLANE = -294.1401881258811 | untested in Phase 1; verify first |

Plus 6 other scripts in jax_profiling/jit/ that don't carry an EXPECTED_LOG_LIKELIHOOD_* constant but do have other internal assertions; worth a re-smoke pass to catch any other drift.
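
To make the size of each drift concrete, the relative change can be computed directly from the two drifted rows of the table above. This is only a minimal sketch: `relative_drift` is a hypothetical helper that does not exist in the jit scripts, and the "actual" values are the rounded Phase 1 smoke numbers quoted in this issue, not fresh measurements.

```python
# Hypothetical helper (not part of jax_profiling/jit/): quantify the drift
# between a hardcoded constant and the value seen in the Phase 1 smoke run.


def relative_drift(expected: float, actual: float) -> float:
    """Return (actual - expected) / |expected|."""
    return (actual - expected) / abs(expected)


# (hardcoded constant, rounded Phase 1 smoke value), as quoted in this issue.
observed = {
    "EXPECTED_LOG_LIKELIHOOD_HST": (27379.38890685539, 27542.08),
    "EXPECTED_LOG_LIKELIHOOD_IMAGE_PLANE": (0.07475703623045682, -362.21),
}

for name, (expected, actual) in observed.items():
    drift = relative_drift(expected, actual)
    sign_flip = (expected > 0) != (actual > 0)
    print(f"{name}: relative drift {drift:+.4g}, sign change: {sign_flip}")
```

Because the image-plane constant is close to zero, its relative drift is dominated by the tiny denominator; the sign change is the more telling signal for that script.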

Implementation Steps

  1. Set up the worktree (gives a clean main copy unaffected by the dirty canonical checkout):

    source admin_jammy/software/worktree.sh
    worktree_create jit-regression-drift autolens_workspace_developer autolens_profiling
  2. Smoke all 10 jit scripts in the worktree from a clean baseline, recording actual log-likelihoods. Use:

    cd ~/Code/PyAutoLabs-wt/jit-regression-drift/autolens_workspace_developer
    source ../activate.sh
    for f in jax_profiling/jit/{imaging,interferometer,point_source}/*.py jax_profiling/jit/datacube/delaunay.py; do
      echo "=== $f ==="
      python "$f" 2>&1 | tail -30
    done
  3. Investigate point_source/image_plane.py first (biggest drift). Likely suspects:

    • FitPositionsImagePairAll chi-squared formula change.
    • PointSolver triangle-refinement loop output drift.
    • positions_noise_map handling change (squaring, sign).

    Use git log -L 444,444:jax_profiling/jit/point_source/image_plane.py to find when the constant was last set, then review the upstream PyAutoLens / PyAutoArray / PyAutoGalaxy commits since then.

  4. Investigate imaging/mge.py (smaller drift). Compare per-step JIT timings to detect algorithmic vs numerical change.

  5. Decide per script, in this order:

    • If the new value is correct: refresh both copies (this repo + autolens_profiling mirror).
    • If the new value is wrong: file an upstream bug, do NOT refresh. Cross-link from this issue.
  6. PR pair (only after step 5 decisions):

    • autolens_workspace_developer: feature/jit-regression-drift updating the constants and adding a CHANGELOG-style note in each script's docstring.
    • autolens_profiling: feature/jit-regression-drift refreshing the same constants in the Phase 1 mirror. Cross-reference both PRs.
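
One caveat with the shell loop in step 2 is that piping through `tail` hides each script's exit status, so a failed assertion is only visible by reading the output. A possible Python equivalent that also records exit codes is sketched below. The function names `smoke` and `smoke_all` are hypothetical, and the sketch assumes a failing regression assertion makes the script exit non-zero (true for an uncaught `AssertionError`, but not verified against these scripts).

```python
import subprocess
import sys
from pathlib import Path


def smoke(script: Path, tail: int = 30) -> tuple[int, str]:
    """Run one script with the current interpreter.

    Returns (exit code, last `tail` lines of combined stdout/stderr).
    """
    proc = subprocess.run(
        [sys.executable, str(script)],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into stdout, like 2>&1
        text=True,
    )
    lines = proc.stdout.splitlines()
    return proc.returncode, "\n".join(lines[-tail:])


def smoke_all(root: Path) -> dict[str, int]:
    """Smoke the same set of scripts as the shell loop; map script -> exit code."""
    scripts: list[Path] = []
    for sub in ("imaging", "interferometer", "point_source"):
        scripts += sorted((root / sub).glob("*.py"))
    scripts.append(root / "datacube" / "delaunay.py")

    results: dict[str, int] = {}
    for script in scripts:
        code, output = smoke(script)
        print(f"=== {script} (exit {code}) ===\n{output}")
        results[str(script)] = code
    return results
```

Run from the worktree checkout after sourcing the environment, `smoke_all(Path("jax_profiling/jit"))` would return one exit code per script, which makes it easier to separate scripts that fail an assertion from those that merely print a different log-likelihood.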

Key Files

  • jax_profiling/jit/imaging/mge.py — refresh EXPECTED_LOG_LIKELIHOOD_HST (L853).
  • jax_profiling/jit/point_source/image_plane.py — refresh or hold EXPECTED_LOG_LIKELIHOOD_IMAGE_PLANE (L444).
  • jax_profiling/jit/point_source/source_plane.py — verify, refresh if drifted.
  • jax_profiling/jit/interferometer/mge.py — passed in Phase 1; re-confirm.
  • autolens_profiling/likelihood/{imaging/mge.py,point_source/image_plane.py,point_source/source_plane.py,interferometer/mge.py} — mirror constants (identical to source).
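
Since the mirror is meant to carry the constants verbatim, it may be worth checking that mechanically before and after any refresh. The sketch below parses Python source with the standard-library `ast` module and compares every `EXPECTED_LOG_LIKELIHOOD_*` assignment. Both function names are hypothetical, and the sketch assumes each constant is a plain `NAME = <numeric literal>` assignment, as the table in this issue suggests.

```python
import ast


def expected_constants(source: str) -> dict[str, float]:
    """Collect every `EXPECTED_LOG_LIKELIHOOD_* = <literal>` assignment.

    Assumes plain assignments to literals; annotated assignments are ignored
    and a non-literal right-hand side raises ValueError from literal_eval.
    """
    found: dict[str, float] = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Assign):
            continue
        for target in node.targets:
            if isinstance(target, ast.Name) and target.id.startswith(
                "EXPECTED_LOG_LIKELIHOOD_"
            ):
                found[target.id] = ast.literal_eval(node.value)
    return found


def mirror_mismatches(primary_src: str, mirror_src: str) -> dict[str, tuple]:
    """Return {name: (primary value, mirror value)} for constants that differ.

    A constant missing on one side shows up with None on that side.
    """
    primary = expected_constants(primary_src)
    mirror = expected_constants(mirror_src)
    return {
        name: (primary.get(name), mirror.get(name))
        for name in primary.keys() | mirror.keys()
        if primary.get(name) != mirror.get(name)
    }
```

In practice the two arguments would be the text of a source script and its mirror, for example `jax_profiling/jit/imaging/mge.py` and `autolens_profiling/likelihood/imaging/mge.py` read with `Path.read_text()`. An empty result means the pair is in sync.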

Out of scope

  • Phase 2+ of the autolens_profiling z_feature (those are independent migration phases).
  • JAX gradient profiling (gated on the gradient story stabilising).
  • The pre-existing dirty state of the _developer canonical checkout (unrelated work).

Original Prompt

The full prompt content is preserved at PyAutoPrompt/issued/jit_regression_constant_drift.md after this issue is created. See also the z_features tracker at PyAutoPrompt/z_features/autolens_profiling.md (Follow-ups section).
