fix(latent): global masking in compute_latent_samples to prevent KeyError on per-batch NaN drops#1310
Open
Jammy2211 wants to merge 1 commit into
Open
fix(latent): global masking in compute_latent_samples to prevent KeyError on per-batch NaN drops#1310Jammy2211 wants to merge 1 commit into
Jammy2211 wants to merge 1 commit into
Conversation
…rror Latent finite-masking was computed per batch on the JAX path, so a latent that went NaN for one sample in a batch had its column dropped for that batch only. Samples then carried inconsistent kwargs key sets and Samples.summary() raised KeyError. Accumulate all batches and mask once globally (col-then-row); return None when nothing finite remains. Depends on autoconf.test_mode hook. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes a
KeyErrorcrash at the end of a search when latent variables are enabled and any latent goes NaN for a subset of samples (e.g.total_lensed_source_fluxfor a degenerate source, oreffective_einstein_radiuswhere the critical-curve solve fails).Root cause:
compute_latent_samplesmasked finite latent columns per-batch on the JAX path (jnp.all(isfinite, axis=0)over abatch_size=10chunk). A single NaN sample dropped that latent's whole column for that batch only; other batches kept it. The resultingSampleobjects then carried inconsistentkwargskey sets, andSamples.summary()raisedKeyErrorbuilding its model from the first sample's keys. (The NumPy path row-masks first, so it never hit this — the crash was JAX-only, matching the reported traceback.)Fix: accumulate all batches first, then mask once globally — drop a latent only if it is non-finite for every sample, then drop individual samples still carrying a NaN in a surviving latent. Every retained sample now shares one identical key set. Adds a guard returning
Nonewhen nothing finite remains (previously anIndexError).Depends on the new
autoconf.test_mode.inject_latent_nanshook (lazy import) — merge the PyAutoConf PR first.API Changes
No signature changes. Behaviour change to
Analysis.compute_latent_samples: latent finite-masking is now global instead of per-batch, and a fully-degenerate latent set returnsNonerather than raising.See full details below.
Test Plan
pytest test_autofit/non_linearpasses (NSS optional-dependency tests excluded — pre-existingaf.NSSImportErrors unrelated to this change).latent_nan_robustness.pyin the autofit/autogalaxy/autolens*_workspace_testrepos reproduce theKeyErrorpre-fix and pass post-fix.latent_variables_smoke.pystill passes with the injection env var unset.Full API Changes (for automation & release notes)
Changed Behaviour
autofit.non_linear.analysis.Analysis.compute_latent_samples— finite-value masking of latent samples is now computed once across all samples (global col-then-row) instead of per batch, removing the inconsistent per-sample latent key sets that causedSamples.summary()to raiseKeyError. When no finite latent remains it returnsNone(previously raisedIndexError). Signature unchanged.🤖 Generated with Claude Code