refactor: parameterise Ellipse + EllipseMultipole math on xp #407

@Jammy2211

Overview

Add an xp=np keyword argument to every numeric method on Ellipse, EllipseMultipole, and EllipseMultipoleScaled, replacing bare np.* calls with xp.* so the geometry computations trace under jax.jit. Also replaces the two Python while loops in EllipseMultipole.get_shape_angle with a single xp.mod-based wrap so that method is JIT-safe too. Step 5 of 7 in the ellipse_fitting_jax feature (PyAutoPrompt/z_features/ellipse_fitting_jax.md). NumPy-path behaviour is preserved — xp=np is the default at every entry point and call sites in FitEllipse continue to pass nothing.
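
A minimal sketch of the xp=np pattern described above, assuming a generic axis-aligned ellipse rather than PyAutoGalaxy's actual formulas. The names angles_from and points_from and their parameters are hypothetical; only the idea of passing the array namespace as a keyword and threading it into inner calls comes from this issue.

```python
import numpy as np


def angles_from(total_points, xp=np):
    """Evenly spaced angles in [0, 2*pi), dropping the duplicate endpoint."""
    # linspace and pi are available on both numpy and jax.numpy.
    return xp.linspace(0.0, 2.0 * xp.pi, total_points + 1)[:-1]


def points_from(major_axis, axis_ratio, total_points, xp=np):
    """Hypothetical (y, x) points on an axis-aligned ellipse."""
    angles = angles_from(total_points, xp=xp)  # thread xp into inner calls
    x = major_axis * xp.cos(angles)
    y = major_axis * axis_ratio * xp.sin(angles)
    return xp.stack((y, x), axis=-1)


# Callers that pass nothing get the NumPy default.
points = points_from(major_axis=2.0, axis_ratio=0.5, total_points=4)
```

Passing xp=jnp (with jax.numpy imported as jnp) would route the same body through JAX, while existing callers that omit xp keep the NumPy behaviour.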

Plan

  • Thread xp=np through five Ellipse.*_from_major_axis_from methods that return arrays; replace np.* calls with xp.* in their bodies.
  • Thread xp=np through EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from.
  • Replace the two Python while loops in EllipseMultipole.get_shape_angle with a single arithmetic wrap using xp.mod. Document the [-180/m, 180/m] vs [-180/m, 180/m) boundary semantics change in a docstring note (the strict < and > comparisons in the loops leave both endpoints unchanged, so only angle = 180/m exactly is affected) and pin it with a numpy-only boundary test.
  • Guard the if np.sum(idx) > 0: raise NotImplementedError() NaN check in points_from_major_axis_from behind if xp is np:. A Python if on a traced value cannot run under jax.jit, and under JAX NaNs propagate through nansum/nanmean downstream, so the raise is kept on the numpy path only.
  • Ellipse.total_points_from stays numpy (returns a Python int used as a static shape outside the JIT trace).
  • No JAX in unit tests per PyAutoGalaxy/CLAUDE.md rule. JAX parity is verified at the workspace_test level — implicitly through prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 wires the JIT path. Until then, the xp=jnp path is exercised only manually during development.
  • Numpy-path numerics must not drift: prompt 3's rtol=1e-12 reference arrays and prompt 2's workspace_test reference numbers continue to pass.
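
The NaN guard in the plan above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: stack_points is a hypothetical helper, and fake_xp is a stand-in namespace used here in place of jax.numpy so the example runs with NumPy alone.

```python
import types

import numpy as np


def stack_points(x, y, xp=np):
    """Stack (y, x) points; raise on NaNs only on the eager NumPy path."""
    if xp is np:
        # Concrete arrays: a Python `if` on the reduction is safe here.
        idx = np.logical_or(np.isnan(x), np.isnan(y))
        if np.sum(idx) > 0:
            raise NotImplementedError()
    # On any other backend the check is skipped and NaNs flow downstream.
    return xp.stack((y, x), axis=-1)


# Stand-in for a non-NumPy namespace (assumption: only `stack` is needed).
fake_xp = types.SimpleNamespace(stack=np.stack)

x = np.array([1.0, np.nan])
y = np.array([0.0, 2.0])

try:
    stack_points(x, y)
    raised = False
except NotImplementedError:
    raised = True

passed_through = stack_points(x, y, xp=fake_xp)
```

The identity check xp is np is resolved in plain Python before any array work, which is why it is safe to evaluate while a function is being traced.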

Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary)

Work Classification

Library

Branch Survey

Repository | Current Branch | Dirty?
./PyAutoGalaxy | main | clean

Suggested branch: feature/ellipse-xp
Worktree root: ~/Code/PyAutoLabs-wt/ellipse-xp/ (created later by /start_library)

Implementation Steps

  1. autogalaxy/ellipse/ellipse/ellipse.py: add xp=np to:

    • Ellipse.angles_from_x0_from(self, pixel_scale, n_i=0, xp=np) — replace np.linspace with xp.linspace.
    • Ellipse.ellipse_radii_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.divide, np.sqrt, np.add, np.sin, np.cos. Thread xp into the inner angles_from_x0_from call.
    • Ellipse.x_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.cos. Thread xp into inner calls.
    • Ellipse.y_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.sin. Thread xp into inner calls.
    • Ellipse.points_from_major_axis_from(self, pixel_scale, n_i=0, xp=np) — replace np.stack. Wrap the NaN-check (idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError()) in if xp is np:. Thread xp into x_from_major_axis_from / y_from_major_axis_from calls.
  2. Ellipse.total_points_from — unchanged. Stays numpy because it returns a Python int used outside JIT traces.

  3. autogalaxy/ellipse/ellipse/ellipse_multipole.py:

    • EllipseMultipole.get_shape_angle(self, ellipse, xp=np) — replace the two while angle < -180/self.m: angle += 360/self.m / while angle > 180/self.m: angle -= 360/self.m loops with:
      period = 360.0 / self.m
      angle = xp.mod(angle + period / 2.0, period) - period / 2.0
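      The original loops use strict < and > comparisons, so they leave both endpoints unchanged and return angles in [-period/2, period/2]; the rewrite returns [-period/2, period/2). The only input whose result changes is angle = period/2.0 exactly, which now maps to -period/2.0. Document the boundary difference in a docstring note and pin that boundary case in a new unit test.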
    • EllipseMultipole.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np) — replace np.arctan2, np.cos, np.sin, np.stack. multipole_comps_from / multipole_k_m_and_phi_m_from from convert.py operate on Python tuples and don't trip JIT; leave untouched.
    • EllipseMultipoleScaled.points_perturbed_from(self, pixel_scale, points, ellipse, n_i=0, xp=np) — same treatment as EllipseMultipole.points_perturbed_from.
  4. Call sites in FitEllipse (autogalaxy/ellipse/fit_ellipse.py:69-136) and elsewhere must NOT be updated in this PR — they pass nothing, get the numpy default, behaviour unchanged. Threading xp into those call sites happens in prompts 6 and 7.

  5. test_autogalaxy/ellipse/test_ellipse.py:

    • Add test__multipole__get_shape_angle__boundary — numpy-only. Input ellipse with angle() such that the un-normalised offset equals period/2.0 exactly; assert the rewrite returns -period/2.0 (the new convention). Documents the boundary-case change vs the original while-loop behaviour.
    • Do NOT add JAX-parity tests here. PyAutoGalaxy/CLAUDE.md: "Never use JAX in unit tests." JAX parity is checked at the workspace_test level via prompt 2's autogalaxy_workspace_test/scripts/jax_likelihood_functions/ellipse/ scripts once prompt 7 flips them to the JIT path.
  6. Run python -m pytest test_autogalaxy/ellipse/ -v from the worktree. All tests must pass, including prompt 3's rtol=1e-12 pins on _points_from_major_axis. If those drift, the numpy-path semantics have changed and there is a bug.

  7. Run python scripts/jax_likelihood_functions/ellipse/{fit,multipoles}.py from autogalaxy_workspace_test/. Confirm all 8 reference numbers match the prompt-2 baseline.
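
The wrap from step 3 can be compared against the loop form with a small standalone sketch. The function names wrap_with_loops and wrap_with_mod are hypothetical and take the angle and m directly instead of an ellipse object; the loop conditions and the xp.mod expression are the ones quoted in step 3.

```python
import numpy as np


def wrap_with_loops(angle, m):
    """Loop form quoted in step 3; strict comparisons leave +/-180/m as-is."""
    while angle < -180 / m:
        angle += 360 / m
    while angle > 180 / m:
        angle -= 360 / m
    return angle


def wrap_with_mod(angle, m, xp=np):
    """Arithmetic wrap into [-period/2, period/2) with no Python control flow."""
    period = 360.0 / m
    return xp.mod(angle + period / 2.0, period) - period / 2.0


m = 4
period = 360.0 / m  # 90 degrees

# Away from the upper boundary the two forms agree.
interior = [100.0, -100.0, 10.0, 200.0]
loops = [wrap_with_loops(a, m) for a in interior]
mods = [float(wrap_with_mod(a, m)) for a in interior]

# At angle = period / 2 exactly, the loop keeps the value and mod flips its sign.
boundary_loops = wrap_with_loops(period / 2.0, m)
boundary_mod = float(wrap_with_mod(period / 2.0, m))
```

A boundary test along these lines only needs the last comparison: feed in an offset equal to period/2.0 and assert the result is -period/2.0.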

Key Files

  • autogalaxy/ellipse/ellipse/ellipse.py — thread xp through five methods.
  • autogalaxy/ellipse/ellipse/ellipse_multipole.py — thread xp through three methods, rewrite get_shape_angle.
  • test_autogalaxy/ellipse/test_ellipse.py — one new test (boundary case for get_shape_angle).

Testing Approach

  • pytest: must remain green, including prompt-3 reference-array pins at rtol=1e-12.
  • Workspace parity: prompt-2 reference numbers byte-stable.
  • JAX-path verification: deferred to workspace_test level (prompt 7).
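
As an illustration of what a reference-array pin at rtol=1e-12 looks like, here is a minimal sketch. radii_from is a hypothetical polar-form ellipse radius, not PyAutoGalaxy's formula, and the reference values are those of this toy input rather than the prompt-3 arrays.

```python
import numpy as np


def radii_from(angles, major_axis, axis_ratio, xp=np):
    """Hypothetical polar-form ellipse radius; not PyAutoGalaxy's formula."""
    minor_axis = major_axis * axis_ratio
    return xp.divide(
        major_axis * minor_axis,
        xp.sqrt(
            xp.add(
                (minor_axis * xp.cos(angles)) ** 2.0,
                (major_axis * xp.sin(angles)) ** 2.0,
            )
        ),
    )


angles = np.array([0.0, np.pi / 2.0])
# For this toy input the radius is the semi-major axis at angle 0
# and the semi-minor axis at pi/2.
reference = np.array([2.0, 1.0])

radii = radii_from(angles, major_axis=2.0, axis_ratio=0.5)
np.testing.assert_allclose(radii, reference, rtol=1e-12)
```

A tolerance this tight tolerates only floating-point rounding, so a relative change of even 1e-9 on the numpy path would fail the pin.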

Original Prompt

Step 5 of the ellipse-JAX series. With the 2D interpolator in place from prompt 4, the next blocker is the geometry math in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py and @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Every routine on Ellipse uses bare np.*, and EllipseMultipole.get_shape_angle uses Python while loops to wrap an angle into [-180/m, 180/m] — both incompatible with jax.jit tracing. Convert these to the xp=np pattern documented in @PyAutoGalaxy/CLAUDE.md "JAX Support" section.

Please:

  1. Add xp=np as a keyword argument to every method in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse.py that returns a numerical array:

    • Ellipse.angles_from_x0_from
    • Ellipse.ellipse_radii_from_major_axis_from
    • Ellipse.x_from_major_axis_from
    • Ellipse.y_from_major_axis_from
    • Ellipse.points_from_major_axis_from

    Replace bare np.* with xp.* inside the function bodies (xp.linspace, xp.sin, xp.cos, xp.divide, xp.add, xp.sqrt, xp.stack). The total_points_from method stays numpy — its return type is a Python int and it's used to set static shapes outside the JIT trace.

    Special case in points_from_major_axis_from: the idx = np.logical_or(np.isnan(x), np.isnan(y)); if np.sum(idx) > 0: raise NotImplementedError() guard is JAX-incompatible (Python if on a traced value). Replace with if xp is np: around the guard — under JAX, NaNs propagate through downstream nansum/nanmean and we'd rather see them than crash inside a JIT trace.

  2. Same treatment for EllipseMultipole.points_perturbed_from and EllipseMultipoleScaled.points_perturbed_from in @PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py. Add xp=np, swap np.* for xp.*. The multipole_comps_from and multipole_k_m_and_phi_m_from helpers from @PyAutoGalaxy/autogalaxy/convert.py are called outside the math loop on Python tuples — leave those as-is unless they trip JIT (verify by tracing).

  3. Replace the while loops in EllipseMultipole.get_shape_angle (@PyAutoGalaxy/autogalaxy/ellipse/ellipse/ellipse_multipole.py:66-69) with arithmetic that JAX can trace. The intent is "wrap angle into the open interval (-180/m, 180/m]". A direct replacement using xp.mod works:

    period = 360.0 / self.m
    angle = xp.mod(angle + period / 2.0, period) - period / 2.0

    This produces values in [-period/2, period/2) rather than (-period/2, period/2], which is a tiny boundary-case difference. Verify against the existing tests in @PyAutoGalaxy/test_autogalaxy/ellipse/ and add a test pinning the new behaviour at the boundary (angle = period/2.0) so future changes don't drift unnoticed.

  4. The existing call sites in FitEllipse and elsewhere don't pass xp — they get the numpy default and behaviour is unchanged. Don't thread xp through the call sites in this prompt; that happens in prompt 6 and 7 where it actually matters.

  5. Add unit tests in @PyAutoGalaxy/test_autogalaxy/ellipse/test_ellipse.py that for one fixed Ellipse and one fixed EllipseMultipole, the xp=np and xp=jnp paths produce numerically identical points to rtol=1e-6. Gate the JAX side with pytest.importorskip("jax").

(NOTE: step 5 above conflicts with PyAutoGalaxy/CLAUDE.md's "Never use JAX in unit tests" rule — the issuing session caught this and dropped the JAX-parity tests during issue creation. Only the boundary test for get_shape_angle is added in test_autogalaxy/.)

  6. Test bar:
    • python -m pytest test_autogalaxy/ellipse/ -v passes.
    • The reference numbers from prompt 2's workspace_test scripts are unchanged on the numpy path.
