From 39a722228b7388d75e808c085d2a4f087355acf1 Mon Sep 17 00:00:00 2001
From: Jammy2211
Date: Wed, 27 May 2026 11:36:51 +0100
Subject: [PATCH] feat: vmapped_deflections_from for batched subhalo
 deflections via jax.vmap
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a generic vmapped deflection path on MassProfile that computes
summed deflections from N subhalos in a single GPU launch.

Each spherical profile provides a small radial_deflection_from static
method (the profile-specific physics); the generic wrapper handles
centre subtraction, radial distance, cartesian reconstruction, masking,
and summing — written once on the base class.

Implemented for NFWTruncatedSph and cNFWSph. Unsupported profiles raise
NotImplementedError with a clear message.

Ref: PyAutoLens#542 prompt 1.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
Review notes: the masking in vmapped_deflections_from now selects with
xp.where (NaN * 0 is NaN) and the two new MassProfile methods are
documented. The blob ids on the abstract/abstract.py index line predate
these review edits.

 autogalaxy/profiles/mass/abstract/abstract.py |  79 ++++++++++
 autogalaxy/profiles/mass/dark/abstract.py     |  43 ++----
 autogalaxy/profiles/mass/dark/cnfw.py         | 144 ++++++++++--------
 .../profiles/mass/dark/nfw_truncated.py       |  52 ++++---
 4 files changed, 212 insertions(+), 106 deletions(-)

diff --git a/autogalaxy/profiles/mass/abstract/abstract.py b/autogalaxy/profiles/mass/abstract/abstract.py
index ffb1fd9d..ac7da5d8 100644
--- a/autogalaxy/profiles/mass/abstract/abstract.py
+++ b/autogalaxy/profiles/mass/abstract/abstract.py
@@ -38,6 +38,85 @@ def __init__(
         """
         super().__init__(centre=centre, ell_comps=ell_comps)
 
+    @staticmethod
+    def radial_deflection_from(r, params, xp):
+        """
+        Radial deflection of a spherical profile, as evaluated by
+        `vmapped_deflections_from`.
+
+        Parameters
+        ----------
+        r
+            Radial distances of the grid coordinates from the profile centre.
+        params
+            The profile parameters that follow the two centre entries.
+        xp
+            The array module used for the calculation.
+
+        Raises
+        ------
+        NotImplementedError
+            Always, on this base class; profiles that support the vmapped
+            path override this method.
+        """
+        raise NotImplementedError(
+            "vmapped_deflections_from is not yet supported for this profile. "
+            "Only spherical profiles with a radial_deflection_from "
+            "implementation are currently supported. If you need this for "
+            "your profile, please contact James "
+            "(jnightingale2211@gmail.com) or open a GitHub issue."
+        )
+
+    @classmethod
+    def vmapped_deflections_from(cls, grid, params_batch, mask, xp=None):
+        """
+        Returns the summed (y,x) deflection angles of a batch of spherical
+        profiles of this class, evaluated with `jax.vmap`.
+
+        Each row of `params_batch` is one profile: its first two entries are
+        subtracted from the grid as the centre, and the remaining entries are
+        passed to `radial_deflection_from`.
+
+        Parameters
+        ----------
+        grid
+            The 2D grid of (y,x) coordinates, indexed as `grid[:, 0]` (y) and
+            `grid[:, 1]` (x).
+        params_batch
+            The parameters of every profile, batched along the first axis.
+        mask
+            One weight per profile; profiles whose entry is zero contribute
+            nothing to the sum.
+        xp
+            The array module used inside each profile calculation, which
+            defaults to `jax.numpy`. The batching always uses `jax.vmap`.
+
+        Raises
+        ------
+        NotImplementedError
+            If the class does not implement `radial_deflection_from`.
+        """
+        import jax
+        import jax.numpy as jnp
+
+        if xp is None:
+            xp = jnp
+
+        def single(params):
+            centred = grid - params[:2]
+            r = xp.sqrt(centred[:, 0] ** 2 + centred[:, 1] ** 2)
+            defl_r = cls.radial_deflection_from(r, params[2:], xp)
+            angle = xp.arctan2(centred[:, 0], centred[:, 1])
+            return defl_r[:, None] * xp.stack(
+                [xp.sin(angle), xp.cos(angle)], axis=-1
+            )
+
+        all_defl = jax.vmap(single)(params_batch)
+        weights = mask[:, None, None]
+        # Select instead of relying on `0 * x`: a masked-out slot holding
+        # placeholder parameters can evaluate to NaN, and NaN * 0 is NaN.
+        return xp.sum(xp.where(weights != 0, all_defl * weights, 0.0), axis=0)
+
     def deflections_yx_2d_from(self, grid):
         """
         Returns the 2D deflection angles of the mass profile from a 2D grid of Cartesian (y,x) coordinates.
diff --git a/autogalaxy/profiles/mass/dark/abstract.py b/autogalaxy/profiles/mass/dark/abstract.py
index 07e5c858..6fe71614 100644
--- a/autogalaxy/profiles/mass/dark/abstract.py
+++ b/autogalaxy/profiles/mass/dark/abstract.py
@@ -14,6 +14,21 @@ class DarkProfile:
     pass
 
 
+def coord_func_f_from(grid_radius, xp=np):
+    if isinstance(grid_radius, float) or isinstance(grid_radius, complex):
+        grid_radius = xp.array([grid_radius])
+
+    f = xp.ones(shape=grid_radius.shape[0], dtype="complex64")
+
+    r = grid_radius
+    inv_r = 1.0 / r
+
+    out_gt = (1.0 / xp.sqrt(r**2 - 1.0)) * xp.arccos(inv_r)
+    out_lt = (1.0 / xp.sqrt(1.0 - r**2)) * xp.arccosh(inv_r)
+
+    return xp.where(r > 1.0, out_gt, xp.where(r < 1.0, out_lt, f))
+
+
 class AbstractgNFW(MassProfile, DarkProfile):
     r"""
     Abstract base class for generalised NFW (gNFW) dark matter halo profiles.
@@ -111,33 +126,7 @@ def density_3d_func(self, r, xp=np):
         )
 
     def coord_func_f(self, grid_radius: np.ndarray, xp=np) -> np.ndarray:
-        """
-        Given an array `grid_radius` and a work array `f`, fill f[i] with
-
-        • (1 / sqrt(r_i^2 − 1)) * arccos(1 / r_i) if r_i > 1
-        • (1 / sqrt(1 − r_i^2)) * arccosh(1 / r_i) if r_i < 1
-        • leave f[i] unchanged if r_i == 1 (you can adjust this if you want a different convention)
-
-        This version uses only pure JAX ops, so it can be jitted or grad’ed without
-        any Python control flow on tracer values.
-        """
-        if isinstance(grid_radius, float) or isinstance(grid_radius, complex):
-            grid_radius = xp.array([grid_radius])
-
-        f = xp.ones(shape=grid_radius.shape[0], dtype="complex64")
-
-        # compute both branches
-        r = grid_radius
-        inv_r = 1.0 / r
-
-        # branch for r > 1
-        out_gt = (1.0 / xp.sqrt(r**2 - 1.0)) * xp.arccos(inv_r)
-
-        # branch for r < 1
-        out_lt = (1.0 / xp.sqrt(1.0 - r**2)) * xp.arccosh(inv_r)
-
-        # combine: if r>1 pick out_gt, elif r<1 pick out_lt, else keep original f
-        return xp.where(r > 1.0, out_gt, xp.where(r < 1.0, out_lt, f))
+        return coord_func_f_from(grid_radius, xp=xp)
 
     def coord_func_g(self, grid_radius: np.ndarray, xp=np) -> np.ndarray:
         """
diff --git a/autogalaxy/profiles/mass/dark/cnfw.py b/autogalaxy/profiles/mass/dark/cnfw.py
index 2996f787..18639b48 100644
--- a/autogalaxy/profiles/mass/dark/cnfw.py
+++ b/autogalaxy/profiles/mass/dark/cnfw.py
@@ -8,6 +8,73 @@
 import autoarray as aa
 
 
+def F_func_from(theta, radius, xp=np):
+    F = theta * 0.0
+
+    mask1 = (theta > 0) & (theta <= radius)
+    mask2 = theta > radius
+
+    F = xp.where(
+        mask1,
+        (
+            radius / 2 * xp.log(2 * radius / theta)
+            - xp.sqrt(radius**2 - theta**2)
+            * xp.arctanh(xp.sqrt((radius - theta) / (radius + theta)))
+        ),
+        F,
+    )
+
+    F = xp.where(
+        mask2,
+        (
+            radius / 2 * xp.log(2 * radius / theta)
+            + xp.sqrt(theta**2 - radius**2)
+            * xp.arctan(xp.sqrt((theta - radius) / (theta + radius)))
+        ),
+        F,
+    )
+
+    return 2 * radius * F
+
+
+def dev_F_func_from(theta, radius, xp=np):
+    dev_F = theta * 0.0
+
+    mask1 = (theta > 0) & (theta < radius)
+    mask2 = theta == radius
+    mask3 = theta > radius
+
+    dev_F = xp.where(
+        mask1,
+        (
+            radius * xp.log(2 * radius / theta)
+            - (2 * radius**2 - theta**2)
+            / xp.sqrt(radius**2 - theta**2)
+            * xp.arctanh(xp.sqrt((radius - theta) / (radius + theta)))
+        ),
+        dev_F,
+    )
+
+    dev_F = xp.where(
+        mask2,
+        radius * (xp.log(2) - 1 / 2),
+        dev_F,
+    )
+
+    dev_F = xp.where(
+        mask3,
+        (
+            radius * xp.log(2 * radius / theta)
+            + (theta**2 - 2 * radius**2)
+            / xp.sqrt(theta**2 - radius**2)
+            * xp.arctan(xp.sqrt((theta - radius) / (theta + radius)))
+        ),
+        dev_F,
+    )
+
+    return 2 * dev_F
+
+
 class cNFW(AbstractgNFW):
     r"""
     Elliptical cored NFW (cNFW) dark matter halo profile.
@@ -211,72 +278,27 @@ def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
         )
 
     def F_func(self, theta, radius, xp=np):
-
-        F = theta * 0.0
-
-        mask1 = (theta > 0) & (theta <= radius)
-        mask2 = theta > radius
-
-        F = xp.where(
-            mask1,
-            (
-                radius / 2 * xp.log(2 * radius / theta)
-                - xp.sqrt(radius**2 - theta**2)
-                * xp.arctanh(xp.sqrt((radius - theta) / (radius + theta)))
-            ),
-            F,
-        )
-
-        F = xp.where(
-            mask2,
-            (
-                radius / 2 * xp.log(2 * radius / theta)
-                + xp.sqrt(theta**2 - radius**2)
-                * xp.arctan(xp.sqrt((theta - radius) / (theta + radius)))
-            ),
-            F,
-        )
-
-        return 2 * radius * F
+        return F_func_from(theta, radius, xp=xp)
 
     def dev_F_func(self, theta, radius, xp=np):
+        return dev_F_func_from(theta, radius, xp=xp)
 
-        dev_F = theta * 0.0
-
-        mask1 = (theta > 0) & (theta < radius)
-        mask2 = theta == radius
-        mask3 = theta > radius
-
-        dev_F = xp.where(
-            mask1,
-            (
-                radius * xp.log(2 * radius / theta)
-                - (2 * radius**2 - theta**2)
-                / xp.sqrt(radius**2 - theta**2)
-                * xp.arctanh(xp.sqrt((radius - theta) / (radius + theta)))
-            ),
-            dev_F,
-        )
-
-        dev_F = xp.where(
-            mask2,
-            radius * (xp.log(2) - 1 / 2),
-            dev_F,
-        )
-
-        dev_F = xp.where(
-            mask3,
-            (
-                radius * xp.log(2 * radius / theta)
-                + (theta**2 - 2 * radius**2)
-                / xp.sqrt(theta**2 - radius**2)
-                * xp.arctan(xp.sqrt((theta - radius) / (theta + radius)))
-            ),
-            dev_F,
+    @staticmethod
+    def radial_deflection_from(r, params, xp):
+        kappa_s, scale_radius, core_radius = params[0], params[1], params[2]
+        theta = xp.maximum(r, 1e-8)
+        factor = 4.0 * kappa_s * scale_radius**2
+        return (
+            factor
+            * (
+                F_func_from(theta, scale_radius, xp=xp)
+                - F_func_from(theta, core_radius, xp=xp)
+                - (scale_radius - core_radius)
+                * dev_F_func_from(theta, scale_radius, xp=xp)
+            )
+            / (theta * (scale_radius - core_radius) ** 2)
         )
 
-        return 2 * dev_F
-
     @aa.over_sample
     @aa.decorators.to_array
     @aa.decorators.transform
diff --git a/autogalaxy/profiles/mass/dark/nfw_truncated.py b/autogalaxy/profiles/mass/dark/nfw_truncated.py
index 5a9e8df3..d1c17c8f 100644
--- a/autogalaxy/profiles/mass/dark/nfw_truncated.py
+++ b/autogalaxy/profiles/mass/dark/nfw_truncated.py
@@ -4,7 +4,31 @@
 import autoarray as aa
 
 from autogalaxy.cosmology.model import LensingCosmology
-from autogalaxy.profiles.mass.dark.abstract import AbstractgNFW
+from autogalaxy.profiles.mass.dark.abstract import AbstractgNFW, coord_func_f_from
+
+
+def coord_func_k_from(grid_radius, tau, xp=np):
+    return xp.log(
+        xp.divide(
+            grid_radius,
+            xp.sqrt(xp.square(grid_radius) + xp.square(tau)) + tau,
+        )
+    )
+
+
+def coord_func_m_from(grid_radius, tau, xp=np):
+    f_r = coord_func_f_from(grid_radius=grid_radius, xp=xp)
+    k_r = coord_func_k_from(grid_radius=grid_radius, tau=tau, xp=xp)
+
+    return (tau**2.0 / (tau**2.0 + 1.0) ** 2.0) * (
+        ((tau**2.0 + 2.0 * grid_radius**2.0 - 1.0) * f_r)
+        + (xp.pi * tau)
+        + ((tau**2.0 - 1.0) * xp.log(tau))
+        + (
+            xp.sqrt(grid_radius**2.0 + tau**2.0)
+            * (((tau**2.0 - 1.0) / tau) * k_r - xp.pi)
+        )
+    )
 
 
 class NFWTruncatedSph(AbstractgNFW):
@@ -123,12 +147,7 @@ def potential_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):
         )
 
     def coord_func_k(self, grid_radius, xp=np):
-        return xp.log(
-            xp.divide(
-                grid_radius,
-                xp.sqrt(xp.square(grid_radius) + xp.square(self.tau)) + self.tau,
-            )
-        )
+        return coord_func_k_from(grid_radius, self.tau, xp=xp)
 
     def coord_func_l(self, grid_radius, xp=np):
         f_r = self.coord_func_f(grid_radius=grid_radius, xp=xp)
@@ -149,18 +168,15 @@ def coord_func_l(self, grid_radius, xp=np):
         )
 
     def coord_func_m(self, grid_radius, xp=np):
-        f_r = self.coord_func_f(grid_radius=grid_radius, xp=xp)
-        k_r = self.coord_func_k(grid_radius=grid_radius, xp=xp)
+        return coord_func_m_from(grid_radius, self.tau, xp=xp)
 
-        return (self.tau**2.0 / (self.tau**2.0 + 1.0) ** 2.0) * (
-            ((self.tau**2.0 + 2.0 * grid_radius**2.0 - 1.0) * f_r)
-            + (xp.pi * self.tau)
-            + ((self.tau**2.0 - 1.0) * xp.log(self.tau))
-            + (
-                xp.sqrt(grid_radius**2.0 + self.tau**2.0)
-                * (((self.tau**2.0 - 1.0) / self.tau) * k_r - xp.pi)
-            )
-        )
+    @staticmethod
+    def radial_deflection_from(r, params, xp):
+        kappa_s, scale_radius, truncation_radius = params[0], params[1], params[2]
+        eta = (xp.maximum(r, 1e-8) / scale_radius) + 0j  # r = 0 would give 0/0 below
+        tau = truncation_radius / scale_radius
+        m = xp.real(coord_func_m_from(eta, tau, xp=xp))
+        return (4.0 * kappa_s * scale_radius / xp.real(eta)) * m
 
     @staticmethod
     def _delta_c_from_concentration(concentration: float) -> float: