fix(latent): global masking in compute_latent_samples to prevent KeyError on per-batch NaN drops #1310

Open
Jammy2211 wants to merge 1 commit into main from feature/latent-nan-masking-fix
@Jammy2211
Collaborator

Summary

Fixes a KeyError crash at the end of a search when latent variables are enabled and any latent goes NaN for a subset of samples (e.g. total_lensed_source_flux for a degenerate source, or effective_einstein_radius where the critical-curve solve fails).

Root cause: compute_latent_samples masked finite latent columns per-batch on the JAX path (jnp.all(isfinite, axis=0) over a batch_size=10 chunk). A single NaN sample dropped that latent's whole column for that batch only; other batches kept it. The resulting Sample objects then carried inconsistent kwargs key sets, and Samples.summary() raised KeyError building its model from the first sample's keys. (The NumPy path row-masks first, so it never hit this — the crash was JAX-only, matching the reported traceback.)
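
The failure mode can be reproduced in miniature. The sketch below is not the autofit implementation: it uses NumPy rather than JAX, and the helper `per_batch_kwargs`, the latent names, the values, and the batch size of 2 are all hypothetical, chosen only to show how per-batch column masking leaves samples with different key sets.

```python
import numpy as np

names = ["flux", "radius"]  # hypothetical latent names

# 4 samples x 2 latents; the second sample has a NaN in "radius".
latents = np.array([
    [1.0, 0.5],
    [2.0, np.nan],
    [3.0, 0.7],
    [4.0, 0.8],
])


def per_batch_kwargs(values, names, batch_size):
    """Mask columns batch by batch (the buggy pattern); return one kwargs dict per sample."""
    out = []
    for start in range(0, len(values), batch_size):
        batch = values[start:start + batch_size]
        # Column mask computed from this batch only.
        keep = np.all(np.isfinite(batch), axis=0)
        kept_names = [n for n, k in zip(names, keep) if k]
        for row in batch[:, keep]:
            out.append(dict(zip(kept_names, row.tolist())))
    return out


samples = per_batch_kwargs(latents, names, batch_size=2)
key_sets = [tuple(s) for s in samples]
# The first batch lost "radius"; the second batch kept it.
print(key_sets)
```

Any consumer that assumes every sample shares one key set will fail on a lookup for a key that some samples lack, which is the shape of the reported KeyError.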

Fix: accumulate all batches first, then mask once globally — drop a latent only if it is non-finite for every sample, then drop individual samples still carrying a NaN in a surviving latent. Every retained sample now shares one identical key set. Adds a guard returning None when nothing finite remains (previously an IndexError).
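
The col-then-row masking described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the code in this PR: the function name `global_finite_mask`, its return shape, and the choice to also return `None` when no sample survives the row step are hypothetical.

```python
import numpy as np


def global_finite_mask(values, names):
    """Col-then-row finite masking over all accumulated samples.

    Returns (filtered_values, kept_names, kept_row_indices),
    or None when nothing finite remains (assumed guard behaviour).
    """
    values = np.asarray(values, dtype=float)
    finite = np.isfinite(values)

    # Step 1: drop a latent only if it is non-finite for every sample.
    col_keep = finite.any(axis=0)
    if not col_keep.any():
        return None

    # Step 2: among surviving latents, drop samples that still carry a non-finite value.
    row_keep = finite[:, col_keep].all(axis=1)
    if not row_keep.any():
        return None

    kept_names = [n for n, k in zip(names, col_keep) if k]
    filtered = values[row_keep][:, col_keep]
    return filtered, kept_names, np.flatnonzero(row_keep)


nan = float("nan")
accumulated = [
    [1.0, nan, 0.5],
    [2.0, nan, nan],
    [3.0, nan, 0.7],
]
result = global_finite_mask(accumulated, ["a", "b", "c"])
```

Because both masks are computed once over the full array, every retained row has the same columns, so the kwargs built from `kept_names` are identical in key set across samples.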

Depends on the new autoconf.test_mode.inject_latent_nans hook (lazy import) — merge the PyAutoConf PR first.

API Changes

No signature changes. Behaviour change to Analysis.compute_latent_samples: latent finite-masking is now global instead of per-batch, and a fully-degenerate latent set returns None rather than raising.
See full details below.

Test Plan

  • pytest test_autofit/non_linear passes (NSS optional-dependency tests excluded — pre-existing af.NSS ImportErrors unrelated to this change).
  • latent_nan_robustness.py in each of the autofit/autogalaxy/autolens *_workspace_test repos reproduces the KeyError pre-fix and passes post-fix.
  • latent_variables_smoke.py still passes with the injection env var unset.
Full API Changes (for automation & release notes)

Changed Behaviour

  • autofit.non_linear.analysis.Analysis.compute_latent_samples — finite-value masking of latent samples is now computed once across all samples (global col-then-row) instead of per batch, removing the inconsistent per-sample latent key sets that caused Samples.summary() to raise KeyError. When no finite latent remains it returns None (previously raised IndexError). Signature unchanged.

🤖 Generated with Claude Code

…rror

Latent finite-masking was computed per batch on the JAX path, so a latent
that went NaN for one sample in a batch had its column dropped for that batch
only. Samples then carried inconsistent kwargs key sets and Samples.summary()
raised KeyError. Accumulate all batches and mask once globally (col-then-row);
return None when nothing finite remains. Depends on autoconf.test_mode hook.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@Jammy2211 added the pending-release label (PR queued for the next release build) on Jun 1, 2026