fix: subhalo redshift as free parameter raises TracerBoolConversionError under JAX #498

@Jammy2211

Overview

Reported on Slack by @qiuhan96 (working with an undergraduate at the University of Groningen) on the public pip release of autolens. Setting a subhalo's redshift as a free parameter (af.UniformPrior) raises jax.errors.TracerBoolConversionError during fitting. The only workaround today is use_jax=False, which is much slower. The cause is that autolens/lens/tracer_util.py applies Python-level <, <=, ==, > and float() to the redshift values, and those values are traced arrays under jax.jit when the subhalo redshift is free.
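
The failure mode can be reproduced without autolens. The following is a minimal sketch, assuming only that jax is installed; branch_with_python_if and branch_with_where are hypothetical stand-ins, not functions from the library.

```python
import jax
import jax.numpy as jnp


def branch_with_python_if(redshift, first_plane_redshift=0.5):
    # A Python `if` needs a concrete bool, so this fails when `redshift` is traced.
    if redshift <= first_plane_redshift:
        return jnp.zeros(2)
    return jnp.ones(2)


def branch_with_where(redshift, first_plane_redshift=0.5):
    # jnp.where evaluates both candidates and selects between them, so it traces cleanly.
    return jnp.where(redshift <= first_plane_redshift, jnp.zeros(2), jnp.ones(2))


try:
    jax.jit(branch_with_python_if)(0.3)
    python_if_failed = False
except jax.errors.TracerBoolConversionError:
    python_if_failed = True

selected = jax.jit(branch_with_where)(0.3)
```

Called without jax.jit, both functions return the same result; only the traced call distinguishes them, which is why the bug appears only when the redshift is a free parameter.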

Plan

  • Add a unit test that fits a model with a free-parameter subhalo redshift under JAX and asserts no TracerBoolConversionError.
  • Reformulate grid_2d_at_redshift_from so the redshift comparisons used to choose where to insert the subhalo plane are JAX-friendly (e.g. compute candidate traced grids and select with jnp.where / jax.lax.switch, rather than branching on Python booleans).
  • Stop calling float() and pairwise < on potentially-traced redshifts in plane_redshifts_from / planes_from — sort by the concrete redshifts of the non-subhalo galaxies (which are Python floats) and carry the subhalo redshift through as a traced scalar.
  • Stop forcing the traced subhalo centre back through tuple(...) in AnalysisLens.tracer_via_instance_from (autolens/analysis/analysis/lens.py:116).
  • Cover the edge cases in tests: subhalo before lens plane, between lens and source, equal to lens redshift, equal to source redshift.
  • Verify the full Nautilus/Dynesty fit from the reporter's script runs end-to-end with use_jax=True.
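
The third bullet above (sort by the concrete redshifts and carry the subhalo redshift through as a traced scalar) might look roughly like this. It is a sketch only: split_redshifts and insertion_index are hypothetical helpers, not existing autolens functions, and the sketch assumes that fixed redshifts arrive as Python int or float values while the single free subhalo redshift is anything else.

```python
import jax
import jax.numpy as jnp


def split_redshifts(redshifts):
    """Separate Python-number redshifts from everything else (assumed traced)."""
    concrete = sorted(z for z in redshifts if isinstance(z, (int, float)))
    traced = [z for z in redshifts if not isinstance(z, (int, float))]
    return concrete, traced


def insertion_index(redshifts):
    """Index at which the (single, assumed) traced redshift would join the sorted planes."""
    concrete, traced = split_redshifts(redshifts)
    planes = jnp.asarray(concrete)
    # Count the concrete planes strictly in front of the subhalo.
    # The result stays a traced integer; no float() or Python `<` touches the tracer.
    return jnp.sum(traced[0] > planes)


# Lens at z=0.5, source at z=1.0, subhalo redshift supplied as a traced argument.
index = jax.jit(lambda z: insertion_index([0.5, z, 1.0]))(0.6)
```

Only the Python floats go through sorted, so the pairwise comparisons never see a tracer.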

Detailed implementation plan

Affected Repositories

  • Jammy2211/PyAutoLens (primary)

Branch Survey

Repository      Current Branch    Dirty?
./PyAutoLens    main              clean

Suggested branch: feature/subhalo-redshift-jax-fix

Implementation Steps

  1. Reproduce. Add a regression test in test_autolens/lens/test_tracer_util.py (or a new test_autolens/analysis/test_analysis_lens_jax.py) that builds a 3-galaxy model (lens at z=0.5, subhalo with af.UniformPrior(0.2, 0.9), source at z=1.0) and calls jax.jit(analysis.fit_from)(instance), asserting it returns without raising.

  2. Refactor autolens/lens/tracer_util.py:plane_redshifts_from. The list comprehension [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift] (line 49) and the sorted(galaxies, key=lambda g: g.redshift) call (line 46) both fail on traced redshifts. Either:

    • Split inputs into "concrete-redshift galaxies" (sorted with Python sorted) and "traced-redshift galaxies" (the subhalo), and recombine without calling float(); or
    • Accept that this helper only ever sees Python-float redshifts and move the subhalo handling out before the call.
  3. Refactor autolens/lens/tracer_util.py:grid_2d_at_redshift_from (line 199). The three branches that fail under JAX are:

    • Line 249: if redshift <= plane_redshifts[0]: return grid.copy()
    • Line 257: [... if galaxies[0].redshift == redshift]
    • Lines 267-268: for plane_index, plane_redshift in enumerate(plane_redshifts): if redshift > plane_redshift: plane_index_insert = plane_index + 1

    Replace with a structured selection: compute the traced grid for every candidate insertion position (before plane 0, between each pair, after the last plane), then pick the right one via jax.lax.switch (or jnp.where over a stacked result) using a comparison vector built with jnp.less / jnp.less_equal. The numpy path keeps the existing implementation behind if xp is np:.

  4. Update autolens/analysis/analysis/lens.py:99-116:AnalysisLens.tracer_via_instance_from. The line instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0]) forces a Python tuple(...) of traced scalars. Either keep the centre as a traced 2-vector or skip the round-trip and pass the traced centre directly into the downstream Tracer build.

  5. Tests. Cover four scenarios in test_autolens/lens/test_tracer_util.py:

    • subhalo redshift < lens.redshift
    • subhalo redshift == lens.redshift
    • lens.redshift < subhalo redshift < source.redshift (the typical case)
    • subhalo redshift == source.redshift

    Run each under both xp=np and xp=jnp (the latter inside jax.jit).

  6. End-to-end check. Run the reporter's reproduction script (the model defined above plus a Nautilus/Dynesty search with a tiny nlive) with use_jax=True and confirm the fit completes.
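
The structured selection described in step 3 can be sketched with jax.lax.switch. This is an illustration under stated assumptions, not the library's implementation: trace_through is a toy stand-in for ray-tracing (the 0.1 * z shift is made up), grid_at_redshift is a hypothetical name, and the plane redshifts are assumed to be concrete Python floats passed as a static tuple.

```python
import jax
import jax.numpy as jnp


def trace_through(grid, plane_redshifts):
    """Toy ray-tracing: shift the grid once per plane it passes (made-up deflection)."""
    for z in plane_redshifts:
        grid = grid - 0.1 * z
    return grid


def grid_at_redshift(grid, redshift, plane_redshifts):
    planes = jnp.asarray(plane_redshifts)
    # Number of planes strictly in front of `redshift`, as a traced integer.
    # A redshift equal to a plane's redshift therefore selects the grid arriving at that plane.
    index = jnp.sum(jnp.less(planes, redshift))
    # One branch per insertion position: before plane 0, after plane 0, after plane 1, ...
    branches = [
        (lambda g, n=n: trace_through(g, plane_redshifts[:n]))
        for n in range(len(plane_redshifts) + 1)
    ]
    return jax.lax.switch(index, branches, grid)


grid = jnp.array([[1.0, 1.0], [2.0, 2.0]])
# plane_redshifts is static (a hashable tuple of Python floats); redshift is traced.
traced_grid = jax.jit(grid_at_redshift, static_argnums=2)(grid, 0.6, (0.5, 1.0))
```

Every branch is traced at compile time, so all branches must return arrays of the same shape and dtype; only the selected branch runs at call time.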

Key Files

  • autolens/lens/tracer_util.py: plane_redshifts_from, planes_from, grid_2d_at_redshift_from.
  • autolens/analysis/analysis/lens.py: AnalysisLens.tracer_via_instance_from (subhalo branch at lines 99-116).
  • test_autolens/lens/test_tracer_util.py: new JAX regression tests.
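
The four edge-case scenarios could be exercised under both array modules along these lines. This is a sketch with a toy function: planes_in_front is hypothetical and stands in for whatever quantity the refactored helpers expose, and the real tests would call the autolens functions instead.

```python
import numpy as np
import jax
import jax.numpy as jnp


def planes_in_front(redshift, plane_redshifts, xp=np):
    """Count planes strictly in front of `redshift` using the array module `xp`."""
    return xp.sum(xp.asarray(plane_redshifts) < redshift)


LENS_Z, SOURCE_Z = 0.5, 1.0
# (subhalo redshift, expected number of planes strictly in front of it)
SCENARIOS = [(0.3, 0), (LENS_Z, 0), (0.7, 1), (SOURCE_Z, 1)]

results = {}
for subhalo_z, expected in SCENARIOS:
    # NumPy path: the redshift is an ordinary Python float.
    numpy_count = int(planes_in_front(subhalo_z, (LENS_Z, SOURCE_Z), xp=np))
    # JAX path: the redshift is traced because it is an argument of the jitted function.
    jitted = jax.jit(lambda z: planes_in_front(z, (LENS_Z, SOURCE_Z), xp=jnp))
    jax_count = int(jitted(subhalo_z))
    results[subhalo_z] = (numpy_count, jax_count, expected)
```

In a pytest suite the loop would more naturally become a parametrized test, with one case per scenario and array module.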

Out of scope

  • Changes to autogalaxy or autoarray.
  • Anything related to the instance.perturb shortcut at lens.py:99 — that path is unaffected by the bug (subhalo redshift is taken from the model, not perturbed).

Original Prompt

Free-parameter subhalo redshift breaks under JAX (TracerBoolConversionError)

Reporter

Reported on Slack by @qiuhan96 (with an undergraduate at the University of Groningen). Running the public pip release of autolens on Python 3.13.

Symptom

Setting the subhalo's redshift as a free parameter (an af.UniformPrior) raises a jax.errors.TracerBoolConversionError during model fitting. The fit only runs if use_jax=False is set, which is much slower.

Reproduction

import autofit as af
import autolens as al

mask_radius = 3.0  # assumed value; the original report does not define mask_radius

bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius,
    total_gaussians=20,
    gaussian_per_basis=2,
    centre_prior_is_uniform=True,
)

mass = af.Model(al.mp.Isothermal)
shear = af.Model(al.mp.ExternalShear)

lens = af.Model(al.Galaxy, redshift=0.5, bulge=bulge, mass=mass, shear=shear)

# Subhalo
subhalo_mass = af.Model(al.mp.IsothermalSph)
subhalo_mass.centre_0 = af.UniformPrior(lower_limit=-0.1, upper_limit=0.1)
subhalo_mass.centre_1 = af.UniformPrior(lower_limit=1.2, upper_limit=1.8)
subhalo_mass.einstein_radius = af.UniformPrior(lower_limit=0.01, upper_limit=0.4)

# Trigger: free-parameter redshift
redshift_subhalo = af.UniformPrior(lower_limit=0.2, upper_limit=0.9)
# redshift_subhalo = 0.6   # <-- works fine

subhalo_galaxy = af.Model(al.Galaxy, redshift=redshift_subhalo, mass=subhalo_mass)

# Source
bulge = al.model_util.mge_model_from(
    mask_radius=mask_radius, total_gaussians=20, centre_prior_is_uniform=False
)
source = af.Model(al.Galaxy, redshift=1.0, bulge=bulge)

model = af.Collection(galaxies=af.Collection(lens=lens, subhalo=subhalo_galaxy, source=source))

When redshift_subhalo is a UniformPrior, the fit raises:

File autolens/analysis/analysis/lens.py:99, in AnalysisLens.tracer_via_instance_from
    subhalo_centre = tracer_util.grid_2d_at_redshift_from(
        galaxies=instance.galaxies,
        redshift=instance.galaxies.subhalo.redshift,
        ...
    )
File autolens/lens/tracer_util.py:247, in grid_2d_at_redshift_from
    plane_redshifts = plane_redshifts_from(galaxies=galaxies)
File autolens/lens/tracer_util.py:46, in plane_redshifts_from
    galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
...
jax._src.core.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].

Workaround: setting use_jax=False lets the fit run, but is much slower.

Root cause

autolens/lens/tracer_util.py performs several Python-level operations on the redshift values that fail when one of them is a JAX traced scalar:

  • Line 46: sorted(galaxies, key=lambda g: g.redshift) does pairwise < comparisons on Python objects holding traced redshifts.
  • Line 49: [float(g.redshift) for g in ...] calls float() on a traced scalar.
  • Line 249: if redshift <= plane_redshifts[0]: is a Python branch on a traced boolean.
  • Line 257: [plane_index for ... if galaxies[0].redshift == redshift] filters on a traced boolean.
  • Lines 267-268: for ...: if redshift > plane_redshift: plane_index_insert = plane_index + 1 again branches on a traced boolean and uses a Python integer to index the inserted plane.

grid_2d_at_redshift_from is called from autolens/analysis/analysis/lens.py:99 whenever instance.galaxies.subhalo exists, with redshift=instance.galaxies.subhalo.redshift. When that redshift is free, the value passed in is a traced array and every comparison above is illegal under jax.jit.
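
The first two items in the list above can be checked in isolation. This is a minimal sketch, assuming a hypothetical ToyGalaxy class that carries only a redshift attribute; it is not the autolens Galaxy.

```python
import jax


class ToyGalaxy:
    """Hypothetical stand-in for a galaxy; only the redshift attribute matters here."""

    def __init__(self, redshift):
        self.redshift = redshift


def sort_galaxies(subhalo_redshift):
    galaxies = [ToyGalaxy(0.5), ToyGalaxy(subhalo_redshift), ToyGalaxy(1.0)]
    # sorted() compares the keys with `<` and converts each result to a Python bool.
    ordered = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
    return [galaxy.redshift for galaxy in ordered]


def redshift_as_float(subhalo_redshift):
    return float(subhalo_redshift)


# With a concrete redshift the sort works as intended.
eager_order = sort_galaxies(0.6)

# Under jax.jit the argument is traced and both helpers raise.
# ConcretizationTypeError and its subclass TracerBoolConversionError derive from TypeError.
errors = {}
for name, fn in [("sorted", sort_galaxies), ("float", redshift_as_float)]:
    try:
        jax.jit(fn)(0.6)
    except TypeError as error:
        errors[name] = type(error).__name__
```

The sort fails with TracerBoolConversionError, matching the traceback above, because the comparison result itself is a traced boolean.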
