Skip to content

refactor: parameterise Ellipse + EllipseMultipole math on xp#408

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/ellipse-xp
May 14, 2026
Merged

refactor: parameterise Ellipse + EllipseMultipole math on xp#408
Jammy2211 merged 1 commit into
mainfrom
feature/ellipse-xp

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Add an xp=np keyword parameter to every numeric method on Ellipse, EllipseMultipole, and EllipseMultipoleScaled, replacing bare np.* calls with xp.* so the geometry math traces under jax.jit. Rewrite EllipseMultipole.get_shape_angle's two Python while loops as a single xp.mod arithmetic wrap. Gate the NaN-raise in points_from_major_axis_from behind if xp is np: so the JIT path can let NaNs propagate through downstream nansum / nanmean reductions.

Step 5 of 7 in z_features/ellipse_fitting_jax.md. NumPy-path behaviour preserved — xp=np is the default at every entry point and call sites in FitEllipse continue to pass nothing. The prompt-2 workspace reference numbers stay byte-stable to all 8 printed decimal places; the prompt-3 rtol=1e-12 reference arrays on _points_from_major_axis continue to pass.

JAX trace check confirmed: jax.jit(lambda ps: ellipse.points_from_major_axis_from(pixel_scale=ps, xp=jnp)) compiles cleanly to shape (61, 2). Note: pixel_scale must be a concrete Python value (not a traced argument) because total_points_from uses it to compute a static array size — by design, since this is a shape parameter.

API Changes

Eight methods on Ellipse / EllipseMultipole / EllipseMultipoleScaled gain an xp=np keyword (purely additive — existing callers unaffected). One behaviour change: EllipseMultipole.get_shape_angle previously wrapped angles into the open-closed interval (-period/2, period/2] via two while loops; the xp.mod rewrite wraps into the closed-open interval [-period/2, period/2), so the single boundary value angle == period/2.0 exactly now returns -period/2.0 instead of +period/2.0. A new test pins the new convention. No removals, no other behaviour changes. See full details below.

Test Plan

  • python -m pytest test_autogalaxy/ellipse/ -v — 31/31 pass (30 existing + 1 new boundary test)
  • python -m pytest test_autogalaxy/ -x — 869/869 pass, no regressions
  • python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py from autogalaxy_workspace_test/ — all 8 prompt-2 reference numbers match byte-for-byte
  • Per PyAutoGalaxy/CLAUDE.md "Never use JAX in unit tests" rule, no JAX imports were added to test_autogalaxy/. JAX parity is verified at the workspace_test level by prompt 7 once AnalysisEllipse flips to the JIT path.
Full API Changes (for automation & release notes)

Added (keyword-only, all default xp=np)

  • Ellipse.angles_from_x0_from(..., xp=np)
  • Ellipse.ellipse_radii_from_major_axis_from(..., xp=np)
  • Ellipse.x_from_major_axis_from(..., xp=np)
  • Ellipse.y_from_major_axis_from(..., xp=np)
  • Ellipse.points_from_major_axis_from(..., xp=np)
  • EllipseMultipole.get_shape_angle(ellipse, xp=np)
  • EllipseMultipole.points_perturbed_from(..., xp=np)
  • EllipseMultipoleScaled.points_perturbed_from(..., xp=np)

Removed

None.

Changed Behaviour

  • EllipseMultipole.get_shape_angle: boundary case angle == period/2.0 (where period = 360/m) now returns -period/2.0 (closed-open interval [-period/2, period/2)) where previously it returned +period/2.0 (open-closed (-period/2, period/2]). The new xp.mod-based wrap is JIT-traceable. Pinned by test__multipole__get_shape_angle__boundary_returns_neg_half_period.
  • Ellipse.points_from_major_axis_from: the NaN-raise (raise NotImplementedError() if any output coordinate is NaN) is now skipped when xp is not np. Under JAX, NaNs propagate through downstream nansum/nanmean reductions instead of raising at JIT-trace time.

Migration

  • Existing callers: no change. xp is keyword-only with a default of np; the eight methods continue to behave identically for xp=np.
  • New JAX path: pass xp=jnp to opt in. Note that pixel_scale is a shape parameter — pass a concrete Python value, not a JAX tracer, when JIT-wrapping.
  • Code that relied on get_shape_angle returning exactly +period/2.0 at the boundary needs to flip to -period/2.0. Search-and-replace target: equality against +45.0 (or +90.0, +30.0, etc.) following a get_shape_angle() call. We don't believe any callers depend on this; the test pins the new behaviour going forward.

🤖 Generated with Claude Code

Adds xp=np keyword to every numeric method on Ellipse, EllipseMultipole,
and EllipseMultipoleScaled, replacing bare np.* with xp.* in their bodies
so the geometry computations trace under jax.jit. NumPy-path behaviour
preserved — xp=np is the default and existing call sites pass nothing.

Replaces two Python while-loops in EllipseMultipole.get_shape_angle with
xp.mod(angle + period/2, period) - period/2. Boundary semantics flip from
(-period/2, period/2] to [-period/2, period/2); a new boundary test pins
the new convention (angle == period/2.0 -> -period/2.0).

The NaN-raise in points_from_major_axis_from is now gated by `if xp is
np:` — under JAX, NaNs propagate naturally through downstream
nansum/nanmean.

Step 5 of 7 in z_features/ellipse_fitting_jax.md. Issue PyAutoGalaxy#407.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 14, 2026
@Jammy2211 Jammy2211 merged commit e2ad662 into main May 14, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/ellipse-xp branch May 14, 2026 18:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant