def _make_subset_dataset(src_system: str, dst_system: str, n_frames: int) -> None:
    """Write a truncated copy of a single-set system directory.

    Copies ``type.raw`` and ``type_map.raw`` (whichever of them exist)
    from ``src_system`` to ``dst_system`` and, for every ``.npy`` file in
    ``src_system/set.000``, saves its first ``n_frames`` entries along
    axis 0 under the same file name in ``dst_system/set.000``.  Files in
    ``set.000`` that do not end in ``.npy`` are ignored.

    Used by ``TestChangeBias`` to shrink the water/data_0 example (80
    frames) down to a tiny subset so that ``dp change-bias`` enumerates
    over only ``n_frames`` frames.  Why this matters: the in-process
    ``main(cmds)`` path runs the model forward over ``nbatches`` frames
    via ``compute_output_stats``, and each frame leaks ~50 MB into
    torch's caching allocator.  At 80 frames (the default,
    ``min(data.get_nbatches()) = 80``) peak RSS hits ~5 GB, which OOMs
    the 7 GB GitHub-hosted CI runner.  Shrinking to ``n_frames=5`` keeps
    the peak at ~800 MB while preserving **determinism**:
    ``test_change_bias_pt2_pte_consistency`` asserts ``atol=1e-10``
    between a .pte and a .pt2 call in the same process, which requires
    every frame to be seen on each call regardless of the shuffle-based
    ``_load_batch_set`` order.  With ``nbatches == total frames`` the
    forward enumerates every frame, so the aggregate bias is invariant
    under shuffle.

    Parameters
    ----------
    src_system : str
        Path of the source system directory; must contain ``set.000``.
    dst_system : str
        Path of the destination system directory; it and its ``set.000``
        sub-directory are created if missing.
    n_frames : int
        Number of leading frames (entries along axis 0) to keep from each
        array.  Values larger than the array length keep the whole array.

    Raises
    ------
    ValueError
        If ``n_frames`` is negative.  A negative value would otherwise be
        interpreted by slicing as "drop the last ``|n_frames|`` frames".
    FileNotFoundError
        If ``src_system/set.000`` does not exist.  Nothing is created
        under ``dst_system`` in that case.
    """
    if n_frames < 0:
        raise ValueError(f"n_frames must be non-negative, got {n_frames}")

    src_set = os.path.join(src_system, "set.000")
    dst_set = os.path.join(dst_system, "set.000")

    # List the source before touching the destination so that a missing
    # ``set.000`` fails without leaving an empty output tree behind.
    # Sorting only makes the write order reproducible across platforms.
    npy_names = sorted(
        fname for fname in os.listdir(src_set) if fname.endswith(".npy")
    )

    os.makedirs(dst_set, exist_ok=True)

    # The type files are per-system rather than per-frame: copy them whole.
    for raw in ("type.raw", "type_map.raw"):
        src = os.path.join(src_system, raw)
        if os.path.isfile(src):
            shutil.copyfile(src, os.path.join(dst_system, raw))

    # Per-frame arrays are truncated to their first ``n_frames`` entries.
    for fname in npy_names:
        arr = np.load(os.path.join(src_set, fname))
        np.save(os.path.join(dst_set, fname), arr[:n_frames])