
feat: re-key linear_light_profile_intensity_dict to survive jax.jit round-trip #448

@Jammy2211

Description


Overview

Make fit.linear_light_profile_intensity_dict survive a jax.jit round-trip so that the visualization path works for any model containing a linear light profile (al.lp_linear.Gaussian, al.lmp_linear.GaussianGradient, etc.). Currently, Part 2 of modeling_visualization_jit.py only works with a parametric Sersic. Any LightProfileLinear fails with a KeyError inside the visualizer callback, because the dict is keyed by object identity and pytree unflattening produces fresh objects.
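The failure mode can be reproduced in a few lines of plain Python. This is a minimal sketch, not the PyAutoGalaxy code: `ToyLinearProfile`, `flatten` and `unflatten` are hypothetical stand-ins for a linear light profile and for pytree flattening/unflattening, which rebuilds a new object from the leaves. Because the class keeps Python's default identity-based `__hash__`/`__eq__`, the rebuilt object is a different dict key even though its parameters are identical.

```python
class ToyLinearProfile:
    """Hypothetical stand-in for a LightProfileLinear (default identity hash/eq)."""

    def __init__(self, centre, sigma):
        self.centre = centre
        self.sigma = sigma


def flatten(profile):
    # Mimic a pytree flatten: the numeric fields become leaves.
    return [profile.centre, profile.sigma]


def unflatten(leaves):
    # Mimic a pytree unflatten: a *fresh* object is built from the leaves.
    return ToyLinearProfile(*leaves)


profile = ToyLinearProfile(centre=(0.0, 0.0), sigma=1.0)
intensity_dict = {profile: 2.5}  # keyed by object identity

rebuilt = unflatten(flatten(profile))

try:
    intensity_dict[rebuilt]
    lookup_failed = False
except KeyError:
    lookup_failed = True

print(lookup_failed)  # True: the rebuilt object is not the same key
```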

Plan

  • Attach a stable _pytree_token (drawn from a monotonic counter) to every LightProfileLinear instance
  • Override __hash__ / __eq__ on LightProfileLinear to use the token instead of id()
  • Add __exclude_identifier_fields__ so the token never affects PyAutoFit's model unique_id
  • Register LightProfileLinear as a pytree with no_flatten=("_pytree_token",) so the token is carried across JIT as static aux data
  • Revert Part 2 of modeling_visualization_jit.py to the full MGE model (parametric basis of GaussianGradient, NFWSph mass, MGE source) and verify that subplot_fit.png is written with no KeyError
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary — hosts LightProfileLinear)
  • PyAutoLens (pytree registration site)
  • autolens_workspace_test (Part 2 revert)

Work Classification

Both (library + workspace)

Branch Survey

| Repository | Current Branch | Dirty? |
|---|---|---|
| PyAutoGalaxy | main | clean |
| PyAutoLens | main | clean |
| autolens_workspace_test | main | clean |

Suggested branch: feature/linear-light-profile-intensity-dict-pytree
Worktree root: ~/Code/PyAutoLabs-wt/linear-light-profile-intensity-dict-pytree/

Option chosen: B (stable identifier token)

Option B was chosen over A (structural path) and C (value-based eq/hash). A requires threading paths through the construction of linear_obj_func.light_profile_list, which has a higher blast radius. C risks surprising collisions between distinct profiles that happen to have identical parameters.
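To make the collision risk of Option C concrete, here is a small sketch that models value-based equality with a frozen dataclass. `ValueKeyedGaussian` is a hypothetical name, not a PyAutoGalaxy class. Two distinct profiles that share parameter values (for example, one in the lens and one in the source) collapse into a single dict entry, so one solved intensity silently overwrites the other.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ValueKeyedGaussian:
    """Hypothetical profile whose __eq__/__hash__ are derived from its field values."""

    centre: tuple
    sigma: float


lens_gaussian = ValueKeyedGaussian(centre=(0.0, 0.0), sigma=1.0)
source_gaussian = ValueKeyedGaussian(centre=(0.0, 0.0), sigma=1.0)

intensity_dict = {}
intensity_dict[lens_gaussian] = 2.5
intensity_dict[source_gaussian] = 0.7  # overwrites the lens entry

print(lens_gaussian is source_gaussian)  # False: two distinct objects
print(len(intensity_dict))               # 1: but only one dict entry
print(intensity_dict[lens_gaussian])     # 0.7
```

A per-instance token avoids this, since two separately constructed profiles always receive different tokens regardless of their parameters.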

Implementation Steps

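  1. In PyAutoGalaxy/autogalaxy/profiles/light/linear/abstract.py, add the following to LightProfileLinear (the itertools import belongs at module level; only the additions are shown):
    ```python
    import itertools


    class LightProfileLinear(...):  # existing base classes and members unchanged
        # Class-level counter shared by all instances; each new instance takes the next value.
        _pytree_token_counter = itertools.count()

        # Keep the token out of PyAutoFit's model unique_id.
        __exclude_identifier_fields__ = ("_pytree_token",)

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._pytree_token = next(LightProfileLinear._pytree_token_counter)

        def __hash__(self):
            # Hash by token rather than id(), so a rebuilt object carrying the same token hashes equal.
            return self._pytree_token

        def __eq__(self, other):
            # Defer to Python's default comparison for non-LightProfileLinear operands.
            if not isinstance(other, LightProfileLinear):
                return NotImplemented
            return self._pytree_token == other._pytree_token
    ```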
  2. In PyAutoLens/autolens/imaging/model/analysis.py, inside _register_fit_imaging_pytrees (after line 151), add: register_instance_pytree(LightProfileLinear, no_flatten=("_pytree_token",))
  3. Revert Part 2 of autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py to the full MGE model (use git log to find the pre-split state).
  4. Verify:
    • modeling_visualization_jit.py runs end-to-end, subplot_fit.png is written, and no KeyError is raised
    • the multi_galaxy_mge.py regression identifier 5a3c480de681f6958048b22b3db8ecf9 is unchanged (the Option B audit confirmed that the identifier already filters out _pytree_token because it is not an __init__ argument; __exclude_identifier_fields__ is belt-and-braces)
    • the PyAutoGalaxy and PyAutoLens test suites pass
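The round trip that steps 1 and 2 rely on can be sketched without JAX. `register_instance_pytree` is not reproduced here; the hypothetical helpers `flatten_instance` and `unflatten_instance` below only assume that fields named in `no_flatten` travel as static aux data and that unflattening restores the instance's `__dict__` without re-running `__init__` (if `__init__` ran again, a fresh token would be minted and the lookup would fail). Under those assumptions, the rebuilt profile finds its entry in a token-keyed dict.

```python
import itertools


class ToyLinearGaussian:
    """Hypothetical stand-in for LightProfileLinear with the proposed token."""

    _pytree_token_counter = itertools.count()

    def __init__(self, sigma):
        self.sigma = sigma
        self._pytree_token = next(ToyLinearGaussian._pytree_token_counter)

    def __hash__(self):
        return hash(self._pytree_token)

    def __eq__(self, other):
        if not isinstance(other, ToyLinearGaussian):
            return NotImplemented
        return self._pytree_token == other._pytree_token


def flatten_instance(obj, no_flatten=()):
    # Dynamic fields become leaves; no_flatten fields ride along as static aux data.
    leaf_names = sorted(k for k in vars(obj) if k not in no_flatten)
    leaves = [getattr(obj, k) for k in leaf_names]
    static = tuple((k, getattr(obj, k)) for k in no_flatten)
    aux = (type(obj), tuple(leaf_names), static)  # hashable, as static aux data must be
    return leaves, aux


def unflatten_instance(aux, leaves):
    cls, leaf_names, static = aux
    obj = cls.__new__(cls)  # bypass __init__ so no new token is minted
    obj.__dict__.update(zip(leaf_names, leaves))
    obj.__dict__.update(static)
    return obj


profile = ToyLinearGaussian(sigma=1.0)
intensity_dict = {profile: 2.5}

leaves, aux = flatten_instance(profile, no_flatten=("_pytree_token",))
rebuilt = unflatten_instance(aux, leaves)

print(rebuilt is profile)       # False: a fresh object
print(intensity_dict[rebuilt])  # 2.5: found via the token
```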

Key Files

  • PyAutoGalaxy/autogalaxy/profiles/light/linear/abstract.py: LightProfileLinear class, plus lp_instance_from at line 148
  • PyAutoGalaxy/autogalaxy/abstract_fit.py (line 122): linear_light_profile_intensity_dict construction
  • PyAutoLens/autolens/imaging/model/analysis.py: _register_fit_imaging_pytrees
  • autolens_workspace_test/scripts/imaging/modeling_visualization_jit.py: Part 2

Audits completed (de-risked Option B)

  • Class structure: LightProfileLinear is a plain class (no dataclass, no __slots__). Its __hash__/__eq__ override takes precedence in the MRO over any definitions on its GeometryProfile ancestors. Minor caveat: unpickled instances keep the token they were pickled with.
  • Call sites: the only dict whose behaviour changes is the target one, linear_light_profile_intensity_dict. LinearObj subclasses (Mapper, LightProfileLinearObjFuncList) are not LightProfileLinear instances, so other identity-keyed dicts are unaffected.
  • Identifier: PyAutoFit/autofit/mapper/identifier.py:127-128 filters __dict__ down to __init__ arguments, so _pytree_token is naturally excluded. __exclude_identifier_fields__ is added explicitly for clarity and as insurance against future code paths.
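The identifier audit can be illustrated with a hypothetical filter that keeps only attributes named as `__init__` parameters and drops anything listed in `__exclude_identifier_fields__`. This is not PyAutoFit's code (the real logic lives in autofit/mapper/identifier.py); `identifier_fields` and `ToyLinearGaussian` are illustrative names. It only shows why `_pytree_token`, which is assigned inside `__init__` but is not an `__init__` argument, never reaches the identifier, so two profiles with equal parameters but different tokens yield the same fields.

```python
import inspect
import itertools


class ToyLinearGaussian:
    """Hypothetical profile carrying the proposed token."""

    _pytree_token_counter = itertools.count()
    __exclude_identifier_fields__ = ("_pytree_token",)

    def __init__(self, centre=(0.0, 0.0), sigma=1.0):
        self.centre = centre
        self.sigma = sigma
        self._pytree_token = next(ToyLinearGaussian._pytree_token_counter)


def identifier_fields(obj):
    # Keep only attributes that are __init__ parameters (as the audit describes),
    # then drop any explicitly excluded fields as a second safeguard.
    init_args = set(inspect.signature(type(obj).__init__).parameters) - {"self"}
    excluded = set(getattr(type(obj), "__exclude_identifier_fields__", ()))
    return {
        k: v for k, v in vars(obj).items() if k in init_args and k not in excluded
    }


first = ToyLinearGaussian(sigma=2.0)
second = ToyLinearGaussian(sigma=2.0)

print(identifier_fields(first))                               # {'centre': (0.0, 0.0), 'sigma': 2.0}
print(identifier_fields(first) == identifier_fields(second))  # True despite different tokens
```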

Original Prompt


Make `fit.linear_light_profile_intensity_dict` survive a `jax.jit` round-trip so the visualization path works for any model containing a linear light profile. See `admin_jammy/prompt/issued/linear_light_profile_intensity_dict_pytree.md` for the full text.
