Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions autoconf/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,105 @@ def skip_latents():
return is_test_mode() or os.environ.get("PYAUTO_SKIP_LATENTS", "0") == "1"


def latent_nan_inject_spec():
"""
Return the ``PYAUTO_LATENT_NAN_INJECT`` spec string, or ``None``.

This is a **test-only** knob used by the ``*_workspace_test`` integration
suites to deliberately poison latent-variable values with NaNs in an
arbitrary per-sample pattern, reproducing the failure mode where
``compute_latent_samples`` produced ``Sample`` objects with inconsistent
key sets (see ``autofit/non_linear/analysis/analysis.py``). It is
``None`` (a no-op) in every normal/production run because the env var
is unset.

Supported spec, used by :func:`inject_latent_nans`:

- ``"stride:N"`` — set NaN on column 0 of every sample whose **absolute**
index (across the full sample list, not the per-batch index) is a
non-zero multiple of ``N``. Using absolute indices makes the NaN
pattern straddle batch boundaries, which is the exact condition
needed to surface per-batch (rather than global) masking bugs.

Index 0 is deliberately **never** poisoned: with the latent batch
size ``B`` chosen so that ``N >= B``, batch 0 (absolute indices
``0 .. B-1``) is left fully finite, so it produces samples with the
*complete* latent key set and seeds the latent ``Samples`` model with
all keys. A later batch then loses column 0 (under the buggy per-batch
JAX column mask), giving it a *reduced* key set — the mismatch that
raises ``KeyError`` in ``Samples.summary()``. Were index 0 poisoned,
batch 0 would itself be reduced and the mismatch would not surface.
"""
return os.environ.get("PYAUTO_LATENT_NAN_INJECT") or None


def inject_latent_nans(values_2d, start_index):
"""
Apply :func:`latent_nan_inject_spec` to a materialised
``(n_samples, n_latents)`` latent-value array.

Parameters
----------
values_2d
A NumPy or JAX array of latent values for one batch, shape
``(n_samples_in_batch, n_latents)``.
start_index
The absolute index of row 0 of ``values_2d`` within the full sample
list (i.e. the batch offset). NaNs are placed using
``start_index + local_row`` so the pattern is consistent across
batch boundaries.

Returns
-------
The array with NaNs injected per the active spec. When the spec is unset
or unrecognised the input is returned unchanged (true no-op).

Notes
-----
Handles both NumPy and JAX arrays. For JAX the functional ``.at[...].set``
update is used (``jax.numpy`` imported locally, matching the library's
convention of never importing JAX at module scope).
"""
spec = latent_nan_inject_spec()
if not spec:
return values_2d

if not spec.startswith("stride:"):
return values_2d

try:
stride = int(spec.split(":", 1)[1])
except (ValueError, IndexError):
return values_2d
if stride <= 0:
return values_2d

n_rows = values_2d.shape[0]
nan_rows = [
r
for r in range(n_rows)
if (start_index + r) != 0 and (start_index + r) % stride == 0
]
if not nan_rows:
return values_2d

# JAX arrays expose ``.at`` for functional updates; NumPy arrays do not.
if hasattr(values_2d, "at"):
import jax.numpy as jnp

out = values_2d
for r in nan_rows:
out = out.at[r, 0].set(jnp.nan)
return out

import numpy as np

out = np.array(values_2d, copy=True)
for r in nan_rows:
out[r, 0] = np.nan
return out


def with_test_mode_segment(base: Path) -> Path:
"""
Return ``base`` with a ``test_mode`` segment appended when
Expand Down
Loading