refactor: parameterise Ellipse + EllipseMultipole math on xp#408
Merged
Conversation
Adds xp=np keyword to every numeric method on Ellipse, EllipseMultipole, and EllipseMultipoleScaled, replacing bare np.* with xp.* in their bodies so the geometry computations trace under jax.jit. NumPy-path behaviour preserved — xp=np is the default and existing call sites pass nothing. Replaces two Python while-loops in EllipseMultipole.get_shape_angle with xp.mod(angle + period/2, period) - period/2. Boundary semantics flip from (-period/2, period/2] to [-period/2, period/2); a new boundary test pins the new convention (angle == period/2.0 -> -period/2.0). The NaN-raise in points_from_major_axis_from is now gated by `if xp is np:` — under JAX, NaNs propagate naturally through downstream nansum/nanmean. Step 5 of 7 in z_features/ellipse_fitting_jax.md. Issue PyAutoGalaxy#407. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Add an
xp=npkeyword parameter to every numeric method onEllipse,EllipseMultipole, andEllipseMultipoleScaled, replacing barenp.*calls withxp.*so the geometry math traces underjax.jit. RewriteEllipseMultipole.get_shape_angle's two Pythonwhileloops as a singlexp.modarithmetic wrap. Gate the NaN-raise inpoints_from_major_axis_frombehindif xp is np:so the JIT path can let NaNs propagate through downstreamnansum/nanmeanreductions.Step 5 of 7 in
z_features/ellipse_fitting_jax.md. NumPy-path behaviour preserved —xp=npis the default at every entry point and call sites inFitEllipsecontinue to pass nothing. The prompt-2 workspace reference numbers stay byte-stable to all 8 printed decimal places; the prompt-3rtol=1e-12reference arrays on_points_from_major_axiscontinue to pass.JAX trace check confirmed:
jax.jit(lambda ps: ellipse.points_from_major_axis_from(pixel_scale=ps, xp=jnp))compiles cleanly to shape(61, 2). Note:pixel_scalemust be a concrete Python value (not a traced argument) becausetotal_points_fromuses it to compute a static array size — by design, since this is a shape parameter.API Changes
Eight methods on
Ellipse/EllipseMultipole/EllipseMultipoleScaledgain anxp=npkeyword (purely additive — existing callers unaffected). One behaviour change:EllipseMultipole.get_shape_anglepreviously wrapped angles into the open-closed interval(-period/2, period/2]via twowhileloops; thexp.modrewrite wraps into the closed-open interval[-period/2, period/2), so the single boundary valueangle == period/2.0exactly now returns-period/2.0instead of+period/2.0. A new test pins the new convention. No removals, no other behaviour changes. See full details below.Test Plan
python -m pytest test_autogalaxy/ellipse/ -v— 31/31 pass (30 existing + 1 new boundary test)python -m pytest test_autogalaxy/ -x— 869/869 pass, no regressionspython scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.pyfromautogalaxy_workspace_test/— all 8 prompt-2 reference numbers match byte-for-bytePyAutoGalaxy/CLAUDE.md"Never use JAX in unit tests" rule, no JAX imports were added totest_autogalaxy/. JAX parity is verified at the workspace_test level by prompt 7 onceAnalysisEllipseflips to the JIT path.Full API Changes (for automation & release notes)
Added (keyword-only, all default
xp=np)Ellipse.angles_from_x0_from(..., xp=np)Ellipse.ellipse_radii_from_major_axis_from(..., xp=np)Ellipse.x_from_major_axis_from(..., xp=np)Ellipse.y_from_major_axis_from(..., xp=np)Ellipse.points_from_major_axis_from(..., xp=np)EllipseMultipole.get_shape_angle(ellipse, xp=np)EllipseMultipole.points_perturbed_from(..., xp=np)EllipseMultipoleScaled.points_perturbed_from(..., xp=np)Removed
None.
Changed Behaviour
EllipseMultipole.get_shape_angle: boundary caseangle == period/2.0(whereperiod = 360/m) now returns-period/2.0(closed-open interval[-period/2, period/2)) where previously it returned+period/2.0(open-closed(-period/2, period/2]). The newxp.mod-based wrap is JIT-traceable. Pinned bytest__multipole__get_shape_angle__boundary_returns_neg_half_period.Ellipse.points_from_major_axis_from: the NaN-raise (raise NotImplementedError()if any output coordinate is NaN) is now skipped whenxp is not np. Under JAX, NaNs propagate through downstreamnansum/nanmeanreductions instead of raising at JIT-trace time.Migration
xpis keyword-only with a default ofnp; the eight methods continue to behave identically forxp=np.xp=jnpto opt in. Note thatpixel_scaleis a shape parameter — pass a concrete Python value, not a JAX tracer, when JIT-wrapping.get_shape_anglereturning exactly+period/2.0at the boundary needs to flip to-period/2.0. Search-and-replace target: equality against+45.0(or+90.0,+30.0, etc.) following aget_shape_angle()call. We don't believe any callers depend on this; the test pins the new behaviour going forward.🤖 Generated with Claude Code