Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 60 additions & 19 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,34 @@ def compute_latent_samples(self, samples: Samples, batch_size : Optional[int] =
start = time.time()
if self.LATENT_BATCH_MODE == "vmap":
logger.info("JAX: Applying vmap and jit to likelihood function for latent variables -- may take a few seconds.")
# vmap traces `compute_latent_variables` once for the whole
# batch, so a per-sample try/except is not possible here —
# latent functions on the vmap path must express failures as
# NaN (e.g. `jnp.where`), never by raising. The `jit` and
# numpy paths below do guard per sample.
batched_compute_latent = jax.jit(jax.vmap(compute_latent_for_model))
elif self.LATENT_BATCH_MODE == "jit":
logger.info("JAX: Applying per-sample jit to latent variables (LATENT_BATCH_MODE='jit') -- may take a few seconds on first sample.")
jitted_compute_latent = jax.jit(compute_latent_for_model)
n_latents = len(self.LATENT_KEYS)
nan_tuple = tuple(jnp.nan for _ in range(n_latents))

def _safe_jitted(p):
# A latent that raises (any exception, not just
# FitException) becomes a NaN row, which the global mask
# below drops — one bad sample must not abort the batch.
try:
return jitted_compute_latent(p)
except Exception:
return nan_tuple

def batched_compute_latent(parameters_batch):
# Per-sample jit returns one (l1, l2, ..., lN) tuple per
# sample. Transpose to a tuple of N batched arrays so the
# downstream `jnp.stack(latent_values_batch, axis=-1)`
# works identically to the vmap path.
sample_results = [
jitted_compute_latent(p) for p in parameters_batch
_safe_jitted(p) for p in parameters_batch
]
n_latents = len(sample_results[0])
return tuple(
Expand All @@ -202,9 +218,12 @@ def batched_compute_latent(parameters_batch):
nan_row = np.full(n_latents, np.nan)

def _safe_compute(xx):
# Any exception (not just FitException) becomes a NaN row,
# which the global mask below drops — a single failing latent
# evaluation must not abort the whole post-fit latent pass.
try:
return compute_latent_for_model(xx)
except exc.FitException:
except Exception:
return nan_row

def batched_compute_latent(x):
Expand Down Expand Up @@ -255,26 +274,48 @@ def batched_compute_latent(x):
else:
all_values = np.empty((0, len(self.LATENT_KEYS)))

# Global masking, in two stages:
# 1. Drop a latent column only if it is non-finite for EVERY sample
# (a genuinely degenerate latent, e.g. a µJy latent with no magzero).
col_mask = np.any(np.isfinite(all_values), axis=0)
kept_keys = [k for k, keep in zip(self.LATENT_KEYS, col_mask) if keep]
kept_values = all_values[:, col_mask]

# 2. Drop individual samples that still carry a NaN in a surviving
# latent (e.g. a FitException NaN row, or a latent that went NaN
# for just that sample). Every survivor now has all `kept_keys`.
if kept_values.size:
row_mask = np.all(np.isfinite(kept_values), axis=1)
else:
row_mask = np.zeros(len(all_samples), dtype=bool)
kept_values = kept_values[row_mask]
kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep]
# Global masking. Every retained sample must share one identical,
# fully-finite latent key set (a `Samples` object is a rectangular
# samples x latents block; NaNs anywhere break `quantile`).
#
# 1. Drop a latent column that is non-finite for EVERY sample
# (genuinely degenerate, e.g. a µJy latent with no magzero).
kept_idx = [
i
for i in range(all_values.shape[1])
if np.isfinite(all_values[:, i]).any()
]

# 2. Keep samples that are finite across all currently-kept latents.
# Normally at least one sample is finite everywhere and we stop
# immediately (behaviour unchanged). But if NaNs are anti-correlated
# across latents — every sample NaN in some latent — the rectangular
# block is empty. Rather than discard ALL latent output, greedily
# sacrifice the worst-coverage latent and retry, retaining the
# maximal-coverage latents and their finite samples.
while kept_idx:
row_mask = np.isfinite(all_values[:, kept_idx]).all(axis=1)
if row_mask.any():
break
nan_counts = (~np.isfinite(all_values[:, kept_idx])).sum(axis=0)
kept_idx.pop(int(np.argmax(nan_counts)))

print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.")

if not kept_keys or len(kept_samples) == 0:
if not kept_idx:
logger.warning(
"compute_latent_samples: no finite latent samples remained "
"after masking; skipping latent output."
)
return None

kept_keys = [self.LATENT_KEYS[i] for i in kept_idx]
kept_values = all_values[:, kept_idx]
row_mask = np.isfinite(kept_values).all(axis=1)
kept_values = kept_values[row_mask]
kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep]

if len(kept_samples) == 0:
logger.warning(
"compute_latent_samples: no finite latent samples remained "
"after masking; skipping latent output."
Expand Down
7 changes: 7 additions & 0 deletions autofit/non_linear/samples/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,13 @@ def quantile(x, q, weights=None):
if np.any(q < 0.0) or np.any(q > 1.0):
raise ValueError("Quantiles must be between 0 and 1")

if x.shape[0] == 1:
# A single sample places all PDF mass at one point, so every quantile is
# that value. Handle this explicitly: the weighted branch below computes
# `np.cumsum(sw)[:-1]`, which is empty for one sample and then indexes
# `cdf[-1]`, raising IndexError.
return [float(x[0])] * len(q)

if weights is None:
return np.percentile(x, list(100.0 * q))
else:
Expand Down
126 changes: 126 additions & 0 deletions test_autofit/analysis/test_latent_variables.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

import autofit as af
Expand Down Expand Up @@ -223,3 +224,128 @@ def test_complex_model():
assert lens.brightness == 2.0

assert instance.source.brightness == 3.0


class RaisingAnalysis(af.Analysis):
    """
    Analysis whose latent computation fails with a plain ``ValueError`` (i.e.
    something other than a ``FitException``) whenever the first parameter of
    the sample is negative.
    """

    LATENT_KEYS = ["fwhm"]

    def log_likelihood_function(self, instance):
        # Constant placeholder: the value does not depend on `instance`.
        return 1.0

    def compute_latent_variables(self, parameters, model):
        """Return a one-element tuple holding the instance's ``fwhm``, or
        raise ``ValueError`` when ``parameters[0]`` is negative."""
        leading_parameter = parameters[0]
        if leading_parameter < 0:
            raise ValueError("boom from latent function")
        return (model.instance_from_vector(vector=parameters).fwhm,)


def test_compute_latent_samples_skips_arbitrary_exception_samples():
    """
    If the latent function raises an exception of any type (not only
    FitException) for a sample, that sample is dropped and the remaining
    latent pass still completes.
    """
    analysis = RaisingAnalysis()
    model = af.Model(af.ex.Gaussian)

    # First sample evaluates cleanly; the second has a negative centre, which
    # makes `RaisingAnalysis.compute_latent_variables` raise ValueError.
    sample_list = [
        af.Sample(
            log_likelihood=1.0,
            log_prior=0.0,
            weight=1.0,
            kwargs={"centre": 1.0, "normalization": 2.0, "sigma": 3.0},
        ),
        af.Sample(
            log_likelihood=-1.0,
            log_prior=0.0,
            weight=0.0,
            kwargs={"centre": -1.0, "normalization": 2.0, "sigma": 3.0},
        ),
    ]
    samples = SamplesPDF(model=model, sample_list=sample_list)

    latent_samples = analysis.compute_latent_samples(samples)

    survivors = latent_samples.sample_list
    assert len(survivors) == 1
    assert survivors[0].kwargs == {("fwhm",): 7.0644601350928475}


class AntiCorrelatedNaNAnalysis(af.Analysis):
    """
    Analysis with two latents, ``a`` and ``b``, whose NaNs never coincide:
    ``a`` is NaN when the first parameter is negative and ``b`` is NaN when it
    is zero or positive, so no sample yields a finite value for both.
    """

    LATENT_KEYS = ["a", "b"]

    def log_likelihood_function(self, instance):
        # Constant placeholder: the value does not depend on `instance`.
        return 1.0

    def compute_latent_variables(self, parameters, model):
        """Return ``(a, b)`` with exactly one entry NaN for any ordinary
        (non-NaN) first parameter."""
        centre = parameters[0]

        # The two comparisons are evaluated independently on purpose; they are
        # not collapsed into a single if/else.
        if centre < 0:
            latent_a = np.nan
        else:
            latent_a = 1.0

        if centre >= 0:
            latent_b = np.nan
        else:
            latent_b = 2.0

        return (latent_a, latent_b)


def test_compute_latent_samples_salvages_anti_correlated_nans():
    """
    With anti-correlated NaNs the rectangular (samples x latents) block has no
    fully-finite row. The masking must not throw away all latent output (i.e.
    return None); it should give up the worst-coverage latent and keep the
    remaining one together with the samples that are finite for it.
    """
    analysis = AntiCorrelatedNaNAnalysis()
    model = af.Model(af.ex.Gaussian)

    sample_list = []
    for centre in (1.0, -1.0, 2.0, -2.0):
        sample_list.append(
            af.Sample(
                log_likelihood=1.0,
                log_prior=0.0,
                weight=1.0,
                kwargs={"centre": centre, "normalization": 2.0, "sigma": 3.0},
            )
        )

    latent_samples = analysis.compute_latent_samples(
        SamplesPDF(model=model, sample_list=sample_list),
    )

    assert latent_samples is not None

    retained = latent_samples.sample_list

    # Both latents have the same NaN count, so the first one (a) is dropped
    # and b survives.
    surviving_keys = set(retained[0].kwargs)
    assert surviving_keys == {("b",)}

    # Every retained sample must expose exactly that one key, so downstream
    # lookups cannot hit a KeyError.
    for sample in retained:
        assert set(sample.kwargs) == surviving_keys


class OneSurvivorAnalysis(af.Analysis):
    """
    Analysis whose single latent is finite only for a sample whose first
    parameter equals ``1.0``; every other sample produces NaN.
    """

    LATENT_KEYS = ["fwhm"]

    def log_likelihood_function(self, instance):
        # Constant placeholder: the value does not depend on `instance`.
        return 1.0

    def compute_latent_variables(self, parameters, model):
        """Return ``(fwhm,)`` when ``parameters[0] == 1.0``, else ``(nan,)``."""
        if parameters[0] == 1.0:
            return (model.instance_from_vector(vector=parameters).fwhm,)
        return (np.nan,)


def test_compute_latent_samples_single_survivor_summary_does_not_crash():
    """
    Regression test for the `quantile` n=1 IndexError: when masking leaves a
    single finite latent sample with weight below 0.99 (so `pdf_converged` is
    True and `median_pdf` goes through `quantile`), both `summary()` and
    `median_pdf()` must complete without raising.
    """
    analysis = OneSurvivorAnalysis()
    model = af.Model(af.ex.Gaussian)

    # (log_likelihood, weight, centre) per sample; only centre == 1.0 gives a
    # finite latent in `OneSurvivorAnalysis`.
    sample_specs = (
        (3.0, 0.5, 1.0),
        (2.0, 0.3, 2.0),
        (1.0, 0.2, 3.0),
    )
    sample_list = [
        af.Sample(
            log_likelihood=log_likelihood,
            log_prior=0.0,
            weight=weight,
            kwargs={"centre": centre, "normalization": 2.0, "sigma": 3.0},
        )
        for log_likelihood, weight, centre in sample_specs
    ]

    latent_samples = analysis.compute_latent_samples(
        SamplesPDF(model=model, sample_list=sample_list),
    )

    assert latent_samples is not None
    assert len(latent_samples.sample_list) == 1

    # Surviving weight 0.5 <= 0.99, so pdf_converged is True and median_pdf is
    # routed to quantile with n=1.
    assert latent_samples.pdf_converged

    # The actual regression check: neither call may raise.
    latent_samples.summary()
    median_instance = latent_samples.median_pdf()
    assert median_instance.fwhm == pytest.approx(7.0644601350928475)
15 changes: 15 additions & 0 deletions test_autofit/non_linear/samples/test_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,18 @@ def test__covariance_matrix(make_samples):
assert samples_x5.covariance_matrix == pytest.approx(
np.array([[0.90909, -0.90909], [-0.90909, 0.90909]]), 1.0e-4
)


def test__quantile_single_weighted_sample_does_not_crash():
    """
    For a single weighted sample, `quantile` must return that sample's value
    at every requested quantile instead of raising. Before the fix,
    `np.cumsum(sw)[:-1]` was empty for one sample and `cdf[-1]` raised
    IndexError, which showed up when latent masking left exactly one finite
    sample whose weight was < 0.99.
    """
    from autofit.non_linear.samples.pdf import quantile

    # (requested quantile(s), expected result) for the one-sample input.
    single_sample_cases = (
        (0.5, [5.0]),
        ([0.16, 0.5, 0.84], [5.0, 5.0, 5.0]),
    )
    for q, expected in single_sample_cases:
        assert quantile(x=[5.0], q=q, weights=[0.3]) == expected

    # The weighted path for two or more samples is unchanged.
    two_sample_median = quantile(x=[1.0, 2.0], q=0.5, weights=[0.5, 0.5])
    assert two_sample_median == pytest.approx([1.5])
Loading