From 4ce3ba6b1bb18fba7e06da25748e05935c27bab9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Mon, 1 Jun 2026 16:03:43 +0100 Subject: [PATCH] fix(latent): global masking in compute_latent_samples to prevent KeyError Latent finite-masking was computed per batch on the JAX path, so a latent that went NaN for one sample in a batch had its column dropped for that batch only. Samples then carried inconsistent kwargs key sets and Samples.summary() raised KeyError. Accumulate all batches and mask once globally (col-then-row); return None when nothing finite remains. Depends on autoconf.test_mode hook. Co-Authored-By: Claude Opus 4.8 (1M context) --- autofit/non_linear/analysis/analysis.py | 94 +++++++++++++++++-------- 1 file changed, 65 insertions(+), 29 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 097ceb911..6f626b9b8 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -210,10 +210,23 @@ def _safe_compute(xx): def batched_compute_latent(x): return np.array([_safe_compute(xx) for xx in x]) + from autoconf.test_mode import inject_latent_nans + parameter_array = np.array(samples.parameter_lists) - latent_samples = [] - # process in batches + # Compute every batch first and accumulate the raw, UN-masked latent + # values into one (n_samples, n_latents) array. Masking is then done + # ONCE, globally, after the loop (see below). + # + # Doing the finite mask per batch (the previous behaviour) was a bug: + # a latent that went NaN for a single sample in one batch had its whole + # column dropped *for that batch only*, while other batches kept it. + # The resulting `Sample` objects then carried inconsistent kwargs key + # sets, and `Samples.summary()` raised `KeyError` building its model + # from the first sample's keys. Masking globally guarantees every + # retained sample shares one identical key set. + all_values = [] + all_samples = [] for i in range(0, len(parameter_array), batch_size): batch = parameter_array[i:i + batch_size] @@ -225,36 +238,59 @@ def batched_compute_latent(x): if self._use_jax: import jax.numpy as jnp latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents) - mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0) - latent_values_batch = latent_values_batch[:, mask] - else: - # Drop samples whose latent computation failed (e.g. FitException from - # model assertions surfaced as a NaN row in _safe_compute). This leaves - # the per-latent column mask to continue handling degenerate latent - # dimensions that produce NaN for all remaining samples. - row_mask = np.all(np.isfinite(latent_values_batch), axis=1) - latent_values_batch = latent_values_batch[row_mask] - batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep] - - if len(latent_values_batch): - col_mask = np.all(np.isfinite(latent_values_batch), axis=0) - latent_values_batch = latent_values_batch[:, col_mask] - - for sample, values in zip(batch_samples, latent_values_batch): - - kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)} - - latent_samples.append( - Sample( - log_likelihood=sample.log_likelihood, - log_prior=sample.log_prior, - weight=sample.weight, - kwargs=kwargs, - ) - ) + + # Unify to NumPy so the global masking below is a single code path + # for both backends (latent values are scalars, host transfer is + # cheap and was already forced by the downstream `float(v)`). + latent_values_batch = np.asarray(latent_values_batch) + + # Test-only NaN injection (no-op unless PYAUTO_LATENT_NAN_INJECT set). + latent_values_batch = inject_latent_nans(latent_values_batch, start_index=i) + + all_values.append(latent_values_batch) + all_samples.extend(batch_samples) + + if all_values: + all_values = np.concatenate(all_values, axis=0) + else: + all_values = np.empty((0, len(self.LATENT_KEYS))) + + # Global masking, in two stages: + # 1. Drop a latent column only if it is non-finite for EVERY sample + # (a genuinely degenerate latent, e.g. a µJy latent with no magzero). + col_mask = np.any(np.isfinite(all_values), axis=0) + kept_keys = [k for k, keep in zip(self.LATENT_KEYS, col_mask) if keep] + kept_values = all_values[:, col_mask] + + # 2. Drop individual samples that still carry a NaN in a surviving + # latent (e.g. a FitException NaN row, or a latent that went NaN + # for just that sample). Every survivor now has all `kept_keys`. + if kept_values.size: + row_mask = np.all(np.isfinite(kept_values), axis=1) + else: + row_mask = np.zeros(len(all_samples), dtype=bool) + kept_values = kept_values[row_mask] + kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep] print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.") + if not kept_keys or len(kept_samples) == 0: + logger.warning( + "compute_latent_samples: no finite latent samples remained " + "after masking; skipping latent output." + ) + return None + + latent_samples = [ + Sample( + log_likelihood=sample.log_likelihood, + log_prior=sample.log_prior, + weight=sample.weight, + kwargs={k: float(v) for k, v in zip(kept_keys, values)}, + ) + for sample, values in zip(kept_samples, kept_values) + ] + return type(samples)( sample_list=latent_samples, model=simple_model_for_kwargs(latent_samples[0].kwargs),