Skip to content

feat: vmapped_deflections_from for batched subhalo deflections#455

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/vmap-subhalo-deflections
May 27, 2026
Merged

feat: vmapped_deflections_from for batched subhalo deflections#455
Jammy2211 merged 1 commit into
mainfrom
feature/vmap-subhalo-deflections

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • Adds MassProfile.vmapped_deflections_from(grid, params_batch, mask) — a generic classmethod that uses jax.vmap to compute summed deflections from N subhalos in a single GPU launch
  • Each spherical profile provides a small radial_deflection_from static method (just the physics formula); the generic wrapper on the base class handles centre subtraction, radial distance, cartesian reconstruction, masking, and summing
  • Implemented for NFWTruncatedSph and cNFWSph; unsupported profiles raise NotImplementedError with contact info
  • Extracted helper math functions (coord_func_f_from, coord_func_k_from, coord_func_m_from, F_func_from, dev_F_func_from) to module level so radial_deflection_from can call them without an instance; existing instance methods delegate to them

API Changes

New methods on MassProfile:

  • MassProfile.vmapped_deflections_from(grid, params_batch, mask, xp=None) — classmethod
  • MassProfile.radial_deflection_from(r, params, xp) — static method (default raises NotImplementedError)

New static methods:

  • NFWTruncatedSph.radial_deflection_from(r, params, xp) — params = [kappa_s, scale_radius, truncation_radius]
  • cNFWSph.radial_deflection_from(r, params, xp) — params = [kappa_s, scale_radius, core_radius]

New module-level functions:

  • abstract.coord_func_f_from(grid_radius, xp)
  • nfw_truncated.coord_func_k_from(grid_radius, tau, xp)
  • nfw_truncated.coord_func_m_from(grid_radius, tau, xp)
  • cnfw.F_func_from(theta, radius, xp)
  • cnfw.dev_F_func_from(theta, radius, xp)

Test plan

Ref: PyAutoLens#542 (prompt 1 of 4)

🤖 Generated with Claude Code

…x.vmap

Adds a generic vmapped deflection path on MassProfile that computes
summed deflections from N subhalos in a single GPU launch. Each
spherical profile provides a small radial_deflection_from static method
(the profile-specific physics); the generic wrapper handles centre
subtraction, radial distance, cartesian reconstruction, masking, and
summing — written once on the base class.

Implemented for NFWTruncatedSph and cNFWSph. Unsupported profiles raise
NotImplementedError with a clear message. Ref: PyAutoLens#542 prompt 1.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 854a26d into main May 27, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/vmap-subhalo-deflections branch May 27, 2026 10:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant