Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 65 additions & 29 deletions autofit/non_linear/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,23 @@ def _safe_compute(xx):
def batched_compute_latent(x):
return np.array([_safe_compute(xx) for xx in x])

from autoconf.test_mode import inject_latent_nans

parameter_array = np.array(samples.parameter_lists)
latent_samples = []

# process in batches
# Compute every batch first and accumulate the raw, UN-masked latent
# values into one (n_samples, n_latents) array. Masking is then done
# ONCE, globally, after the loop (see below).
#
# Doing the finite mask per batch (the previous behaviour) was a bug:
# a latent that went NaN for a single sample in one batch had its whole
# column dropped *for that batch only*, while other batches kept it.
# The resulting `Sample` objects then carried inconsistent kwargs key
# sets, and `Samples.summary()` raised `KeyError` building its model
# from the first sample's keys. Masking globally guarantees every
# retained sample shares one identical key set.
all_values = []
all_samples = []
for i in range(0, len(parameter_array), batch_size):

batch = parameter_array[i:i + batch_size]
Expand All @@ -225,36 +238,59 @@ def batched_compute_latent(x):
if self._use_jax:
import jax.numpy as jnp
latent_values_batch = jnp.stack(latent_values_batch, axis=-1) # (batch, n_latents)
mask = jnp.all(jnp.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, mask]
else:
# Drop samples whose latent computation failed (e.g. FitException from
# model assertions surfaced as a NaN row in _safe_compute). This leaves
# the per-latent column mask to continue handling degenerate latent
# dimensions that produce NaN for all remaining samples.
row_mask = np.all(np.isfinite(latent_values_batch), axis=1)
latent_values_batch = latent_values_batch[row_mask]
batch_samples = [s for s, keep in zip(batch_samples, row_mask) if keep]

if len(latent_values_batch):
col_mask = np.all(np.isfinite(latent_values_batch), axis=0)
latent_values_batch = latent_values_batch[:, col_mask]

for sample, values in zip(batch_samples, latent_values_batch):

kwargs = {k: float(v) for k, v in zip(self.LATENT_KEYS, values)}

latent_samples.append(
Sample(
log_likelihood=sample.log_likelihood,
log_prior=sample.log_prior,
weight=sample.weight,
kwargs=kwargs,
)
)

# Unify to NumPy so the global masking below is a single code path
# for both backends (latent values are scalars, host transfer is
# cheap and was already forced by the downstream `float(v)`).
latent_values_batch = np.asarray(latent_values_batch)

# Test-only NaN injection (no-op unless PYAUTO_LATENT_NAN_INJECT set).
latent_values_batch = inject_latent_nans(latent_values_batch, start_index=i)

all_values.append(latent_values_batch)
all_samples.extend(batch_samples)

if all_values:
all_values = np.concatenate(all_values, axis=0)
else:
all_values = np.empty((0, len(self.LATENT_KEYS)))

# Global masking, in two stages:
# 1. Drop a latent column only if it is non-finite for EVERY sample
# (a genuinely degenerate latent, e.g. a µJy latent with no magzero).
col_mask = np.any(np.isfinite(all_values), axis=0)
kept_keys = [k for k, keep in zip(self.LATENT_KEYS, col_mask) if keep]
kept_values = all_values[:, col_mask]

# 2. Drop individual samples that still carry a NaN in a surviving
# latent (e.g. a FitException NaN row, or a latent that went NaN
# for just that sample). Every survivor now has all `kept_keys`.
if kept_values.size:
row_mask = np.all(np.isfinite(kept_values), axis=1)
else:
row_mask = np.zeros(len(all_samples), dtype=bool)
kept_values = kept_values[row_mask]
kept_samples = [s for s, keep in zip(all_samples, row_mask) if keep]

print(f"Time to compute latent variables: {time.time() - start_latent} seconds for {len(samples)} samples.")

if not kept_keys or len(kept_samples) == 0:
logger.warning(
"compute_latent_samples: no finite latent samples remained "
"after masking; skipping latent output."
)
return None

latent_samples = [
Sample(
log_likelihood=sample.log_likelihood,
log_prior=sample.log_prior,
weight=sample.weight,
kwargs={k: float(v) for k, v in zip(kept_keys, values)},
)
for sample, values in zip(kept_samples, kept_values)
]

return type(samples)(
sample_list=latent_samples,
model=simple_model_for_kwargs(latent_samples[0].kwargs),
Expand Down
Loading