From df58b69025e3c38b37f64c2eeb42b5855d80ec43 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 7 Jun 2026 15:29:55 +0100 Subject: [PATCH] fix(latent): handle degenerate latent edge cases (quantile n=1, latent exceptions, anti-correlated NaNs) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to the per-batch latent NaN-masking fix; a sanity probe surfaced three further break scenarios in the latent pipeline: A. quantile() crashed for a single weighted sample — np.cumsum(sw)[:-1] is empty for n=1 and cdf[-1] raised IndexError. General Samples bug, exposed when latent masking leaves one finite survivor with weight <= 0.99. Fix: return the sample's value for every quantile when n==1. B. A latent function raising a non-FitException crashed the whole latent pass. _safe_compute (numpy) only caught FitException; the JAX jit per-sample path had no guard. Fix: broaden both to catch any Exception -> NaN row (dropped by the mask). vmap path documented as raise-free (traced once). C. Anti-correlated NaNs (every sample NaN in some latent) dropped all samples -> None, losing all latent output. Fix: greedy column-salvage — sacrifice the worst-coverage latent(s) and retain the maximal-coverage ones with their finite samples. Normal fits are unchanged (stop on first finite row). Adds unit tests in test_autofit for all three. Co-Authored-By: Claude Opus 4.8 (1M context) --- autofit/non_linear/analysis/analysis.py | 79 ++++++++--- autofit/non_linear/samples/pdf.py | 7 + .../analysis/test_latent_variables.py | 126 ++++++++++++++++++ test_autofit/non_linear/samples/test_pdf.py | 15 +++ 4 files changed, 208 insertions(+), 19 deletions(-) diff --git a/autofit/non_linear/analysis/analysis.py b/autofit/non_linear/analysis/analysis.py index 6f626b9b8..df17360b2 100644 --- a/autofit/non_linear/analysis/analysis.py +++ b/autofit/non_linear/analysis/analysis.py @@ -174,10 +174,26 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] = start = time.time() if self.LATENT_BATCH_MODE == "vmap": logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.") + # vmap traces `compute_latent_variables` once for the whole + # batch, so a per-sample try/except is not possible here — + # latent functions on the vmap path must express failures as + # NaN (e.g. `jnp.where`), never by raising. The `jit` and + # numpy paths below do guard per sample. batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model)) elif self.LATENT_BATCH_MODE == "jit": logger.info("JAX: Applying per-sample jit to latent variables (LATENT_BATCH_MODE='jit') -- may take a few seconds on first sample.") jitted_compute_latent = jax.jit(compute_latent_for_model) + n_latents = len(self.LATENT_KEYS) + nan_tuple = tuple(jnp.nan for _ in range(n_latents)) + + def _safe_jitted(p): + # A latent that raises (any exception, not just + # FitException) becomes a NaN row, which the global mask + # below drops — one bad sample must not abort the batch. + try: + return jitted_compute_latent(p) + except Exception: + return nan_tuple def batched_compute_latent(parameters_batch): # Per-sample jit returns one (l1, l2, ..., lN) tuple per @@ -185,7 +201,7 @@ def batched_compute_latent(parameters_batch): # downstream `jnp.stack(latent_values_batch, axis=-1)` # works identically to the vmap path. sample_results = [ - jitted_compute_latent(p) for p in parameters_batch + _safe_jitted(p) for p in parameters_batch ] n_latents = len(sample_results[0]) return tuple( @@ -202,9 +218,12 @@ def batched_compute_latent(parameters_batch): nan_row = np.full(n_latents, np.nan) def _safe_compute(xx): + # Any exception (not just FitException) becomes a NaN row, + # which the global mask below drops — a single failing latent + # evaluation must not abort the whole post-fit latent pass. try: return compute_latent_for_model(xx) - except exc.FitException: + except Exception: return nan_row def batched_compute_latent(x): @@ -255,26 +274,48 @@ def batched_compute_latent(x): else: all_values = np.empty((0, len(self.LATENT_KEYS))) - # Global masking, in two stages: - # 1. Drop a latent column only if it is non-finite for EVERY sample - # (a genuinely degenerate latent, e.g. a µJy latent with no magzero). - col_mask = np.any(np.isfinite(all_values), axis=0) - kept_keys = [k for k, keep in zip(self.LATENT_KEYS, col_mask) if keep] - kept_values = all_values[:, col_mask] - - # 2. Drop individual samples that still carry a NaN in a surviving - # latent (e.g. a FitException NaN row, or a latent that went NaN - # for just that sample). Every survivor now has all `kept_keys`. - if kept_values.size: - row_mask = np.all(np.isfinite(kept_values), axis=1) - else: - row_mask = np.zeros(len(all_samples), dtype=bool) - kept_values = kept_values[row_mask] - kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep] + # Global masking. Every retained sample must share one identical, + # fully-finite latent key set (a `Samples` object is a rectangular + # samples x latents block; NaNs anywhere break `quantile`). + # + # 1. Drop a latent column that is non-finite for EVERY sample + # (genuinely degenerate, e.g. a µJy latent with no magzero). + kept_idx = [ + i + for i in range(all_values.shape[1]) + if np.isfinite(all_values[:, i]).any() + ] + + # 2. Keep samples that are finite across all currently-kept latents. + # Normally at least one sample is finite everywhere and we stop + # immediately (behaviour unchanged). But if NaNs are anti-correlated + # across latents — every sample NaN in some latent — the rectangular + # block is empty. Rather than discard ALL latent output, greedily + # sacrifice the worst-coverage latent and retry, retaining the + # maximal-coverage latents and their finite samples. + while kept_idx: + row_mask = np.isfinite(all_values[:, kept_idx]).all(axis=1) + if row_mask.any(): + break + nan_counts = (~np.isfinite(all_values[:, kept_idx])).sum(axis=0) + kept_idx.pop(int(np.argmax(nan_counts))) print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.") - if not kept_keys or len(kept_samples) == 0: + if not kept_idx: + logger.warning( + "compute_latent_samples: no finite latent samples remained " + "after masking; skipping latent output." + ) + return None + + kept_keys = [self.LATENT_KEYS[i] for i in kept_idx] + kept_values = all_values[:, kept_idx] + row_mask = np.isfinite(kept_values).all(axis=1) + kept_values = kept_values[row_mask] + kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep] + + if len(kept_samples) == 0: logger.warning( "compute_latent_samples: no finite latent samples remained " "after masking; skipping latent output." diff --git a/autofit/non_linear/samples/pdf.py b/autofit/non_linear/samples/pdf.py index 6eb18de9e..3ea842f33 100644 --- a/autofit/non_linear/samples/pdf.py +++ b/autofit/non_linear/samples/pdf.py @@ -481,6 +481,13 @@ def quantile(x, q, weights=None): if np.any(q < 0.0) or np.any(q > 1.0): raise ValueError("Quantiles must be between 0 and 1") + if x.shape[0] == 1: + # A single sample places all PDF mass at one point, so every quantile is + # that value. Handle this explicitly: the weighted branch below computes + # `np.cumsum(sw)[:-1]`, which is empty for one sample and then indexes + # `cdf[-1]`, raising IndexError. + return [float(x[0])] * len(q) + if weights is None: return np.percentile(x, list(100.0 * q)) else: diff --git a/test_autofit/analysis/test_latent_variables.py b/test_autofit/analysis/test_latent_variables.py index 3c9f950cc..22bf74838 100644 --- a/test_autofit/analysis/test_latent_variables.py +++ b/test_autofit/analysis/test_latent_variables.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import autofit as af @@ -223,3 +224,128 @@ def test_complex_model(): assert lens.brightness == 2.0 assert instance.source.brightness == 3.0 + + +class RaisingAnalysis(af.Analysis): + """Latent that raises a non-FitException (e.g. an unexpected solver error) + for some samples.""" + + LATENT_KEYS = ["fwhm"] + + def log_likelihood_function(self, instance): + return 1.0 + + def compute_latent_variables(self, parameters, model): + if parameters[0] < 0: + raise ValueError("boom from latent function") + instance = model.instance_from_vector(vector=parameters) + return (instance.fwhm,) + + +def test_compute_latent_samples_skips_arbitrary_exception_samples(): + """ + A latent function that raises ANY exception (not just FitException) must + drop that sample rather than crashing the whole latent pass. + """ + analysis = RaisingAnalysis() + latent_samples = analysis.compute_latent_samples( + SamplesPDF( + model=af.Model(af.ex.Gaussian), + sample_list=[ + af.Sample(log_likelihood=1.0, log_prior=0.0, weight=1.0, + kwargs={"centre": 1.0, "normalization": 2.0, "sigma": 3.0}), + af.Sample(log_likelihood=-1.0, log_prior=0.0, weight=0.0, + kwargs={"centre": -1.0, "normalization": 2.0, "sigma": 3.0}), + ], + ), + ) + assert len(latent_samples.sample_list) == 1 + assert latent_samples.sample_list[0].kwargs == {("fwhm",): 7.0644601350928475} + + +class AntiCorrelatedNaNAnalysis(af.Analysis): + """Two latents whose NaNs are anti-correlated across samples: latent ``a`` is + NaN where ``centre < 0`` and latent ``b`` is NaN where ``centre >= 0``. No + single sample is finite in both.""" + + LATENT_KEYS = ["a", "b"] + + def log_likelihood_function(self, instance): + return 1.0 + + def compute_latent_variables(self, parameters, model): + c = parameters[0] + a = np.nan if c < 0 else 1.0 + b = np.nan if c >= 0 else 2.0 + return (a, b) + + +def test_compute_latent_samples_salvages_anti_correlated_nans(): + """ + When NaNs are anti-correlated so the rectangular (samples x latents) block is + empty, the masking must NOT discard all latent output. It sacrifices the + worst-coverage latent and retains the maximal-coverage one with its finite + samples — rather than returning None. + """ + analysis = AntiCorrelatedNaNAnalysis() + latent_samples = analysis.compute_latent_samples( + SamplesPDF( + model=af.Model(af.ex.Gaussian), + sample_list=[ + af.Sample(log_likelihood=1.0, log_prior=0.0, weight=1.0, + kwargs={"centre": cv, "normalization": 2.0, "sigma": 3.0}) + for cv in (1.0, -1.0, 2.0, -2.0) + ], + ), + ) + assert latent_samples is not None + # The two latents tie on NaN count; the first (a) is sacrificed, b is kept. + surviving_keys = set(latent_samples.sample_list[0].kwargs) + assert surviving_keys == {("b",)} + # All retained samples share the same single-key set (no KeyError downstream). + assert all(set(s.kwargs) == surviving_keys for s in latent_samples.sample_list) + + +class OneSurvivorAnalysis(af.Analysis): + """Only the ``centre == 1.0`` sample yields a finite latent; all others NaN.""" + + LATENT_KEYS = ["fwhm"] + + def log_likelihood_function(self, instance): + return 1.0 + + def compute_latent_variables(self, parameters, model): + if parameters[0] != 1.0: + return (np.nan,) + instance = model.instance_from_vector(vector=parameters) + return (instance.fwhm,) + + +def test_compute_latent_samples_single_survivor_summary_does_not_crash(): + """ + When masking leaves exactly one finite latent sample whose weight is < 0.99 + (so `pdf_converged` is True and `median_pdf` uses `quantile`), `summary()` and + `median_pdf()` must succeed. Regression for the `quantile` n=1 IndexError. + """ + analysis = OneSurvivorAnalysis() + latent_samples = analysis.compute_latent_samples( + SamplesPDF( + model=af.Model(af.ex.Gaussian), + sample_list=[ + af.Sample(log_likelihood=3.0, log_prior=0.0, weight=0.5, + kwargs={"centre": 1.0, "normalization": 2.0, "sigma": 3.0}), + af.Sample(log_likelihood=2.0, log_prior=0.0, weight=0.3, + kwargs={"centre": 2.0, "normalization": 2.0, "sigma": 3.0}), + af.Sample(log_likelihood=1.0, log_prior=0.0, weight=0.2, + kwargs={"centre": 3.0, "normalization": 2.0, "sigma": 3.0}), + ], + ), + ) + assert latent_samples is not None + assert len(latent_samples.sample_list) == 1 + # weight 0.5 <= 0.99 => pdf_converged True => median_pdf routes to quantile (n=1). + assert latent_samples.pdf_converged + # The real regression: these must not raise. + _ = latent_samples.summary() + instance = latent_samples.median_pdf() + assert instance.fwhm == pytest.approx(7.0644601350928475) diff --git a/test_autofit/non_linear/samples/test_pdf.py b/test_autofit/non_linear/samples/test_pdf.py index 9ba92d046..355574c46 100644 --- a/test_autofit/non_linear/samples/test_pdf.py +++ b/test_autofit/non_linear/samples/test_pdf.py @@ -497,3 +497,18 @@ def test__covariance_matrix(make_samples): assert samples_x5.covariance_matrix == pytest.approx( np.array([[0.90909, -0.90909], [-0.90909, 0.90909]]), 1.0e-4 ) + + +def test__quantile_single_weighted_sample_does_not_crash(): + """ + A weighted `quantile` over a single sample must return that sample's value + for every quantile rather than raising. Previously `np.cumsum(sw)[:-1]` was + empty for one sample and `cdf[-1]` raised IndexError — surfaced when latent + masking left exactly one finite sample whose weight was < 0.99. + """ + from autofit.non_linear.samples.pdf import quantile + + assert quantile(x=[5.0], q=0.5, weights=[0.3]) == [5.0] + assert quantile(x=[5.0], q=[0.16, 0.5, 0.84], weights=[0.3]) == [5.0, 5.0, 5.0] + # n >= 2 weighted path is unchanged. + assert quantile(x=[1.0, 2.0], q=0.5, weights=[0.5, 0.5]) == pytest.approx([1.5])