def latent_nan_inject_spec():
    """
    Return the ``PYAUTO_LATENT_NAN_INJECT`` spec string, or ``None``.

    This is a **test-only** knob used by the ``*_workspace_test`` integration
    suites to deliberately poison latent-variable values with NaNs in an
    arbitrary per-sample pattern, reproducing the failure mode where
    ``compute_latent_samples`` produced ``Sample`` objects with inconsistent
    key sets (see ``autofit/non_linear/analysis/analysis.py``). It is
    ``None`` (a no-op) in every normal/production run because the env var
    is unset. An env var set to the empty string is also reported as
    ``None``.

    Supported spec, used by :func:`inject_latent_nans`:

    - ``"stride:N"`` — set NaN on column 0 of every sample whose **absolute**
      index (across the full sample list, not the per-batch index) is a
      non-zero multiple of ``N``. Using absolute indices makes the NaN
      pattern straddle batch boundaries, which is the exact condition
      needed to surface per-batch (rather than global) masking bugs.

      Index 0 is deliberately **never** poisoned: with the latent batch
      size ``B`` chosen so that ``N >= B``, batch 0 (absolute indices
      ``0 .. B-1``) is left fully finite, so it produces samples with the
      *complete* latent key set and seeds the latent ``Samples`` model with
      all keys. A later batch then loses column 0 (under the buggy per-batch
      JAX column mask), giving it a *reduced* key set — the mismatch that
      raises ``KeyError`` in ``Samples.summary()``. Were index 0 poisoned,
      batch 0 would itself be reduced and the mismatch would not surface.
    """
    return os.environ.get("PYAUTO_LATENT_NAN_INJECT") or None


def inject_latent_nans(values_2d, start_index):
    """
    Apply :func:`latent_nan_inject_spec` to a materialised
    ``(n_samples, n_latents)`` latent-value array.

    Parameters
    ----------
    values_2d
        A NumPy or JAX array of latent values for one batch, shape
        ``(n_samples_in_batch, n_latents)``.
    start_index
        The absolute index of row 0 of ``values_2d`` within the full sample
        list (i.e. the batch offset). NaNs are placed using
        ``start_index + local_row`` so the pattern is consistent across
        batch boundaries.

    Returns
    -------
    The array with NaNs injected per the active spec. The input object itself
    is returned unchanged (true no-op) when the spec is unset, unrecognised,
    has a non-integer or non-positive stride, selects no row of this batch,
    or when the batch has no latent columns. Otherwise a new array is
    returned and the input is never mutated.

    Notes
    -----
    Handles both NumPy and JAX arrays. For JAX the functional ``.at[...].set``
    update is used (``jax.numpy`` imported locally, matching the library's
    convention of never importing JAX at module scope).
    """
    spec = latent_nan_inject_spec()
    if not spec or not spec.startswith("stride:"):
        return values_2d

    try:
        stride = int(spec.split(":", 1)[1])
    except (ValueError, IndexError):
        # Malformed stride (e.g. "stride:abc"): treat as "no injection".
        return values_2d
    if stride <= 0:
        return values_2d

    n_rows = values_2d.shape[0]
    # Local rows whose absolute index is a non-zero multiple of the stride.
    # Absolute index 0 is skipped on purpose so the first sample stays finite.
    nan_rows = [
        r
        for r in range(n_rows)
        if (start_index + r) != 0 and (start_index + r) % stride == 0
    ]
    if not nan_rows:
        return values_2d

    # With zero latent columns there is no column 0 to poison; writing to it
    # would raise IndexError on the NumPy path, so treat it as a no-op.
    if values_2d.shape[1] == 0:
        return values_2d

    # JAX arrays expose ``.at`` for functional updates; NumPy arrays do not.
    if hasattr(values_2d, "at"):
        import jax.numpy as jnp

        # One indexed update instead of one functional update per row, each
        # of which would build a fresh array.
        return values_2d.at[jnp.asarray(nan_rows), 0].set(jnp.nan)

    import numpy as np

    # Copy first so the caller's array is never modified in place.
    out = np.array(values_2d, copy=True)
    out[nan_rows, 0] = np.nan
    return out