From 4b016d09304a72385ab299e5a5e9e035b124a7d3 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 12 May 2026 12:32:01 +0200 Subject: [PATCH 1/8] Split bypass distillation core Signed-off-by: Sepehr Sameni --- examples/megatron_bridge/distill.py | 2 +- modelopt/torch/puzzletron/__init__.py | 1 + .../bypass_distillation/__init__.py | 24 + .../bypass_checkpoint_utils.py | 266 ++++ .../bypass_distillation/bypass_utils.py | 422 ++++++ .../bypass_distillation/data_classes.py | 44 + .../stitched_model_factory.py | 644 +++++++++ .../bypass_distillation/training_loop.py | 1275 +++++++++++++++++ .../build_replacement_library.py | 10 +- .../puzzletron/tools/checkpoint_utils_hf.py | 204 ++- .../tools/sharded_checkpoint_utils.py | 23 +- tests/_test_utils/torch/puzzletron/utils.py | 42 + .../test_bypass_checkpoint_utils.py | 381 +++++ .../puzzletron/test_bypass_keys_to_learn.py | 256 ++++ .../puzzletron/test_bypass_lr_scheduler.py | 127 ++ .../torch/puzzletron/test_bypass_utils.py | 223 +++ .../puzzletron/test_checkpoint_utils_hf.py | 91 ++ .../test_launch_bypass_distillation.py | 248 ++++ .../test_replacement_library_bypass_config.py | 57 + .../test_stitched_model_factory_buffers.py | 76 + 20 files changed, 4372 insertions(+), 44 deletions(-) create mode 100644 modelopt/torch/puzzletron/bypass_distillation/__init__.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/data_classes.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/training_loop.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py create 
mode 100644 tests/unit/torch/puzzletron/test_bypass_utils.py create mode 100644 tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py create mode 100644 tests/unit/torch/puzzletron/test_launch_bypass_distillation.py create mode 100644 tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py create mode 100644 tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py diff --git a/examples/megatron_bridge/distill.py b/examples/megatron_bridge/distill.py index 16a0a85f842..5c4224c03dd 100644 --- a/examples/megatron_bridge/distill.py +++ b/examples/megatron_bridge/distill.py @@ -343,7 +343,7 @@ def _build_model_provider(hf_path): load=checkpoint_dir, # Resume from this directory (if exists) most_recent_k=5, # Keeps 5 most recent checkpoints (not metric-based) ckpt_format="torch_dist", - async_save=False, + async_save=True, fully_parallel_save=True, ), rng=RNGConfig(seed=args.seed), diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py index 15389dedfa2..0af53b5cef3 100644 --- a/modelopt/torch/puzzletron/__init__.py +++ b/modelopt/torch/puzzletron/__init__.py @@ -19,6 +19,7 @@ anymodel, block_config, build_library_and_stats, + bypass_distillation, dataset, entrypoint, mip, diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py new file mode 100644 index 00000000000..119cbd5cdaf --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation (blockwise local distillation) for the PUZZLE framework. + +This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer +block configurations using per-block knowledge distillation from a teacher model. +""" + +from .training_loop import launch_bypass_distillation + +__all__ = ["launch_bypass_distillation"] diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py new file mode 100644 index 00000000000..89260bcc5f9 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Checkpoint utilities for bypass distillation.""" + +import os +import re +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Union + +import torch +from omegaconf import DictConfig +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint_from_shards +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.utils.robust_json import json_dump + +from .bypass_utils import load_bypass_state, update_bypass_checkpoint_state +from .stitched_model_factory import StitchedModuleDescriptor + +__all__ = ["find_latest_run_dir", "load_local_state", "save_bypass_checkpoint"] + + +def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the latest plain-step checkpoint directory within a run parent directory. + + Resume prefers the manifest's final checkpoint, then the latest plain step + checkpoint. It must not pick ``best-step-*`` because validation-best snapshots + can be stale relative to the latest optimizer state, nor ``start-step-*``. + """ + run_parent_dir = Path(run_parent_dir) + + state = load_bypass_state(run_parent_dir) + if state is not None: + checkpoints = state.get("checkpoints", {}) + for role in ("final", "resume"): + candidate = checkpoints.get(role) + if candidate and (Path(candidate) / "saving_completed").exists(): + return str(candidate) + + # Check for the "latest" symlink. Current checkpoints only update it for + # plain periodic resume checkpoints, but older runs may have pointed it at a + # best/start/final checkpoint. Validate the target name before accepting it. 
+ latest_dir = run_parent_dir / "latest" + if latest_dir.exists(): + latest_resolved = latest_dir.resolve() + if ( + re.match(r"^step-\d+-ckpt$", latest_resolved.name) + and (latest_resolved / "saving_completed").exists() + ): + return str(latest_resolved) + + # Fallback: scan plain ``step-NNNNNN-ckpt`` directories only. + # Treat a missing parent dir as "no previous runs" rather than fatal — this + # handles two cases cleanly: a freshly-wiped bypass dir, and the race where + # non-master ranks reach this function before master finishes the + # ``set_experiment_dir`` mkdir on a shared filesystem. + if not run_parent_dir.exists(): + return None + step_re = re.compile(r"^step-(\d+)-ckpt$") + candidate_dirs: list[tuple[int, Path]] = [] + for d in run_parent_dir.iterdir(): + if not d.is_dir(): + continue + match = step_re.match(d.name) + if match: + candidate_dirs.append((int(match.group(1)), d)) + + if not candidate_dirs: + return None + + candidate_dirs.sort(key=lambda x: x[0], reverse=True) + for _, ckpt_dir in candidate_dirs: + if (ckpt_dir / "saving_completed").exists(): + return str(ckpt_dir) + return None + + +def load_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_path: str | Path, +) -> None: + """Load optimizer and grad-scaler state for each stitched module. + + Weights are NOT loaded here — they live in the HF checkpoint at + ``checkpoint_path`` and must be loaded into the student model via + ``load_and_shard_model`` before this function runs (typically by setting + ``init_checkpoint_path`` to the resume directory). This avoids + persisting the same parameters twice (once in ``stitched/*.pth`` and + once in the HF state dict). + + Modifies ``stitched_module_descriptors`` in place. + """ + device = torch.device(f"cuda:{dist.local_rank()}") + load_dir = Path(checkpoint_path) + + if not load_dir.exists(): + raise RuntimeError(f'Can\'t load local state. 
"{load_dir}" does not exist.') + + for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + optimizer_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" + ) + mprint( + f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}" + ) + loaded_optimizer_state = torch.load( + optimizer_state_path, map_location=device, weights_only=True + ) + optimizer.load_state_dict(loaded_optimizer_state) + del loaded_optimizer_state + + # Restore GradScaler state (only relevant when use_grad_scaling=True; for the + # default bf16 / use_grad_scaling=False path the scaler is disabled and its + # state is a no-op, but we still load it if present for forward-compatibility). + # Older checkpoints predating this save path won't have the file — skip silently. + if grad_scaler is not None: + grad_scaler_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.grad_scaler.pth" + ) + if grad_scaler_state_path.exists(): + mprint( + f"Loading grad_scaler state for module {stitched_module_name} " + f"from {grad_scaler_state_path}" + ) + loaded_scaler_state = torch.load( + grad_scaler_state_path, map_location=device, weights_only=True + ) + grad_scaler.load_state_dict(loaded_scaler_state) + del loaded_scaler_state + + +def _save_local_file(obj, save_path: Path | str, overwrite=True): + save_path = Path(save_path) + if save_path.exists(): + if not overwrite: + mprint(f'WARNING: Local save path "{save_path}" already exists. Skipping') + return + torch.save(obj, save_path) + + +def _save_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + overwrite=True, +) -> None: + """Persist optimizer and grad-scaler state for each stitched module. + + Weights are intentionally NOT saved here. 
The same trainable parameters + would otherwise land on disk twice — once as ``stitched/{block}.state_dict.pth`` + and once as part of the HF checkpoint that ``save_bypass_checkpoint`` + writes at the top level via ``save_checkpoint(model, ...)``. The HF + checkpoint is the single source of truth for weights; this directory + only carries the optimizer/scaler state that the HF format doesn't + cover. + """ + save_dir = Path(checkpoint_dir) / "stitched" + + if dist.is_master(): + save_dir.mkdir(parents=True, exist_ok=True) + + # Main process creates the directory, so we must wait for it to finish + dist.barrier() + + for stitched_module_name, stitched_module_descriptor in tqdm( + stitched_module_descriptors.items(), disable=not dist.is_master() + ): + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" + aprint( + f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" + ) + _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + + # Persist GradScaler state. Required for correct resume when + # use_grad_scaling=True (state dict carries running scale + growth tracker). + # For the default bf16 / use_grad_scaling=False path the state dict is trivial + # but cheap, so save unconditionally whenever a scaler exists — keeps the + # save/load paths symmetric with the optimizer. 
+ if grad_scaler is not None: + grad_scaler_state_path = save_dir / f"{stitched_module_name}.grad_scaler.pth" + mprint( + f"Saving grad_scaler state for module {stitched_module_name} " + f"to {grad_scaler_state_path}" + ) + _save_local_file(grad_scaler.state_dict(), grad_scaler_state_path, overwrite=overwrite) + + dist.barrier() + + +def save_bypass_checkpoint( + cfg: DictConfig, + descriptor: ModelDescriptor, + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + reference_checkpoint_dir: Optional[Path] = None, + checkpoint_role: str = "resume", +) -> None: + """Save a bypass distillation checkpoint.""" + checkpoint_dir = Path(checkpoint_dir) + mprint("Starting checkpoint save") + mprint(f"Saving checkpoint to {checkpoint_dir}") + + # Save stitched module states + _save_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=checkpoint_dir, + overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, + ) + # Save as HF checkpoint. Must use the gather-aware variant: bypass training is + # pipeline-parallel so each rank's `model.state_dict()` only carries its own + # owned blocks. The unsharded `save_checkpoint` would have every rank write a + # partial `model.safetensors.index.json` to the same path (last writer wins), + # producing an index that omits most ranks' weights — resume then leaves params + # on the meta device. + save_checkpoint_from_shards(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor) + + if dist.is_master(): + if checkpoint_role == "resume": + # Create 'latest' symlink via tmp-symlink + atomic rename so concurrent + # readers on a shared filesystem never observe a missing `latest`. The + # plain unlink + symlink_to pair leaves a brief window where the link + # doesn't exist; Path.replace (== os.replace) is atomic on POSIX. 
+ latest_symlink = Path(cfg.bypass.experiment_dir) / "latest" + tmp_symlink = latest_symlink.with_name(f".latest_tmp_{os.getpid()}") + tmp_symlink.unlink(missing_ok=True) + tmp_symlink.symlink_to(checkpoint_dir.name) + tmp_symlink.replace(latest_symlink) + # Save config args json + json_dump(cfg.bypass, checkpoint_dir / "args.json") + model_factory_cfg = cfg.bypass.get("model_factory", {}) + json_dump( + {"keys_to_learn": model_factory_cfg.get("keys_to_learn", "entire_block")}, + checkpoint_dir / "bypass_config.json", + ) + # Save completed file + completed_file = checkpoint_dir / "saving_completed" + completed_file.touch() + update_bypass_checkpoint_state(cfg, checkpoint_dir, checkpoint_role) + + dist.barrier() + mprint("Checkpoint save done") diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py new file mode 100644 index 00000000000..4402e3f9217 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -0,0 +1,422 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for bypass distillation.""" + +import hashlib +import json +from collections.abc import Sequence +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.utils.robust_json import json_dump, json_load + +__all__ = [ + "BYPASS_STATE_FILENAME", + "BYPASS_SUBBLOCK_KEYS_TO_LEARN", + "bypass_run_is_complete", + "expected_bypass_runs", + "get_bypass_config_fingerprint", + "get_bypass_experiment_fingerprint", + "get_bypass_run_identity", + "get_bypass_state_path", + "get_distributed_modules_ownership", + "get_pipeline_ownership_context", + "learned_subblocks_from_keys_to_learn", + "load_bypass_state", + "mark_bypass_run_completed", + "normalize_keys_to_learn", + "set_experiment_dir", + "set_experiment_id", + "update_bypass_checkpoint_state", + "write_bypass_state", +] + +BYPASS_STATE_FILENAME = "bypass_state.json" +BYPASS_SUBBLOCK_KEYS_TO_LEARN = frozenset( + {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"} +) + + +def _to_plain_container(value: Any) -> Any: + if isinstance(value, (DictConfig, ListConfig)): + return OmegaConf.to_container(value, resolve=True) + return value + + +def normalize_keys_to_learn(keys_to_learn: Any) -> dict[str, Any]: + """Normalize bypass ``keys_to_learn`` into v1 subblock semantics.""" + keys_to_learn = _to_plain_container(keys_to_learn) + if isinstance(keys_to_learn, str): + if keys_to_learn in BYPASS_SUBBLOCK_KEYS_TO_LEARN: + return {"mode": "subblocks", "subblocks": (keys_to_learn,)} + raise ValueError( + "keys_to_learn must be one of " + f"{sorted(BYPASS_SUBBLOCK_KEYS_TO_LEARN)}, got {keys_to_learn!r}" + ) + + if isinstance(keys_to_learn, Sequence): + values = tuple(keys_to_learn) + if not all(isinstance(value, str) for value in values): + raise TypeError(f"keys_to_learn entries must be strings, got {keys_to_learn!r}") + if not values: + raise ValueError("keys_to_learn 
cannot be empty") + invalid = [value for value in values if value not in BYPASS_SUBBLOCK_KEYS_TO_LEARN] + if invalid: + raise ValueError( + f"keys_to_learn supports only subblock keys in v1; invalid entries: {invalid!r}" + ) + if "entire_block" in values and len(set(values)) > 1: + raise ValueError("keys_to_learn cannot mix 'entire_block' with other subblock keys") + return {"mode": "subblocks", "subblocks": tuple(dict.fromkeys(values))} + + raise TypeError(f"Unsupported keys_to_learn={keys_to_learn!r}") + + +def learned_subblocks_from_keys_to_learn(keys_to_learn: Any) -> list[str]: + """Return replacement-library subblocks represented by ``keys_to_learn``.""" + normalized = normalize_keys_to_learn(keys_to_learn) + subblocks = set(normalized["subblocks"]) + if subblocks == {"entire_block"}: + return ["block"] + + out: list[str] = [] + if "subblock_attention" in subblocks or "subblock_mamba" in subblocks: + out.append("attention") + if "subblock_ffn" in subblocks: + out.append("ffn") + return out + + +def _slug(value: Any) -> str: + text = str(value).strip().lower().replace("subblock_", "") + keep = [ch if ch.isalnum() else "_" for ch in text] + slug = "".join(keep).strip("_") + while "__" in slug: + slug = slug.replace("__", "_") + return slug or "custom" + + +def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: + """Return the config subset that defines a bypass output. + + The full Hydra config carries mutable runtime counters, checkpoint paths and + logging fields. Those should not decide whether a completed bypass run can + be reused. This identity intentionally keeps architecture, training budget, + data shape and learning-target fields, because changing any of them changes + the produced checkpoint. 
+ """ + bypass = _to_plain_container(cfg.bypass) + training = bypass.get("training", {}) + data = bypass.get("data", {}) + model = bypass.get("model", {}) + model_factory = bypass.get("model_factory", {}) + return { + "model": { + "student_weights_dtype": model.get("student_weights_dtype"), + "model_config_overrides": model.get("model_config_overrides"), + }, + "model_factory": { + "factory": model_factory.get("factory"), + "block_loss_func": model_factory.get("block_loss_func"), + "gqa_init_mode": model_factory.get("gqa_init_mode"), + "mlp_init_mode": model_factory.get("mlp_init_mode"), + "mlp_init_config": model_factory.get("mlp_init_config"), + "linear_init_mode": model_factory.get("linear_init_mode"), + "submodule_for_loss_calculation": model_factory.get("submodule_for_loss_calculation"), + "keys_to_learn": model_factory.get("keys_to_learn"), + }, + "training": { + "learning_rate": training.get("learning_rate"), + "training_tokens": training.get("training_tokens"), + "micro_batch_size": training.get("micro_batch_size"), + "grad_accumulation_steps": training.get("grad_accumulation_steps"), + "weight_decay": training.get("weight_decay"), + "decay_lr": training.get("decay_lr"), + "beta1": training.get("beta1"), + "beta2": training.get("beta2"), + "grad_clip": training.get("grad_clip"), + "grad_clip_type": training.get("grad_clip_type"), + "warmup_ratio": training.get("warmup_ratio"), + "min_lr_factor": training.get("min_lr_factor"), + }, + "data": { + "dataset_path": cfg.get("dataset_path", None), + "block_size": data.get("block_size"), + "data_column": data.get("data_column"), + "fim_rate": data.get("fim_rate"), + "fim_spm_rate": data.get("fim_spm_rate"), + "bos_rate": data.get("bos_rate"), + "source_datasets_to_discard": data.get("source_datasets_to_discard"), + "load_from_disk": data.get("load_from_disk"), + "keep_in_memory": data.get("keep_in_memory"), + "shuffle_train_data_seed": data.get("shuffle_train_data_seed"), + "val_dataset_name": 
data.get("val_dataset_name"), + "max_eval_samples": data.get("max_eval_samples"), + "eval_samples_per_process": data.get("eval_samples_per_process"), + }, + "validation": { + "disable_validation": bypass.get("disable_validation"), + "save_best_ckpt": bypass.get("save_best_ckpt"), + "realize_best_or_latest": bypass.get("realize_best_or_latest"), + "eval_interval": training.get("eval_interval"), + "val_micro_batch_size": training.get("val_micro_batch_size"), + }, + "seed": bypass.get("seed"), + "dtype": bypass.get("dtype"), + } + + +def get_bypass_config_fingerprint(cfg: DictConfig) -> str: + identity = get_bypass_run_identity(cfg) + payload = json.dumps(identity, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def get_bypass_experiment_fingerprint(cfg: DictConfig) -> str: + """Return a stable ID fingerprint for the architecture and learning target. + + Training budget and data settings are deliberately excluded so a longer + rerun can resume the same architecture from its previous final checkpoint. + The full config fingerprint is still recorded in bypass_state.json and used + for skip-if-complete decisions. 
+ """ + identity = get_bypass_run_identity(cfg) + experiment_identity = { + "model": identity["model"], + "model_factory": { + "factory": identity["model_factory"]["factory"], + "block_loss_func": identity["model_factory"]["block_loss_func"], + "keys_to_learn": identity["model_factory"]["keys_to_learn"], + "gqa_init_mode": identity["model_factory"]["gqa_init_mode"], + "mlp_init_mode": identity["model_factory"]["mlp_init_mode"], + "mlp_init_config": identity["model_factory"]["mlp_init_config"], + "linear_init_mode": identity["model_factory"]["linear_init_mode"], + "submodule_for_loss_calculation": identity["model_factory"][ + "submodule_for_loss_calculation" + ], + }, + } + payload = json.dumps(experiment_identity, sort_keys=True, default=str, separators=(",", ":")) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def set_experiment_id(cfg: DictConfig) -> None: + """Set the experiment ID based on the model config overrides. + + The ID encodes every override that affects the produced student so that + sweeps over (FFN size × KV heads) or (num_experts × KV heads) get distinct + directories instead of clobbering each other. 
+ """ + if cfg.bypass.experiment_id is not None: + return + + overrides = cfg.bypass.model.model_config_overrides + parts: list[str] = [] + + if "ffn" in overrides: + ffn_override = overrides.ffn[0] + if "intermediate_size" in ffn_override and ffn_override["intermediate_size"] is not None: + parts.append(f"ffn_{ffn_override['intermediate_size']}") + elif "moe" in ffn_override and ffn_override["moe"] is not None: + parts.append(f"experts_{ffn_override['moe']['num_local_experts']}") + + if "attention" in overrides: + attn_override = overrides.attention[0] + if ( + "num_key_value_heads" in attn_override + and attn_override["num_key_value_heads"] is not None + ): + parts.append(f"heads_{attn_override['num_key_value_heads']}") + + keys_to_learn = cfg.bypass.model_factory.get("keys_to_learn", None) + if keys_to_learn not in (None, "entire_block"): + parts.append(_slug(keys_to_learn)) + + if not parts: + parts.append("custom") + + # Keep the readable architecture prefix, but suffix it with the config + # fingerprint so two runs with the same architecture but different learning + # target or training budget cannot collide in the same experiment_dir. + cfg.bypass.experiment_id = "bypass_" + "_".join(parts) + cfg.bypass.experiment_id += f"_{get_bypass_experiment_fingerprint(cfg)[:8]}" + + +def set_experiment_dir(cfg: DictConfig) -> None: + """Set the experiment directory for the bypass run. + + Stores the path as a string in the OmegaConf node (OmegaConf only supports + primitive types natively). Use sites should reconstruct ``Path(...)`` as needed. 
+ """ + experiment_dir = Path(cfg.puzzle_dir) / "bypass" / "bypass_runs" / cfg.bypass.experiment_id + cfg.bypass.experiment_dir = str(experiment_dir) + if dist.is_master(): + experiment_dir.mkdir(parents=True, exist_ok=True) + + +def get_bypass_state_path(experiment_dir: str | Path) -> Path: + return Path(experiment_dir) / BYPASS_STATE_FILENAME + + +def load_bypass_state(experiment_dir: str | Path) -> dict[str, Any] | None: + state_path = get_bypass_state_path(experiment_dir) + if not state_path.exists(): + return None + return json_load(state_path) + + +def write_bypass_state(cfg: DictConfig, state: dict[str, Any]) -> None: + if not dist.is_master(): + return + json_dump(state, get_bypass_state_path(cfg.bypass.experiment_dir)) + + +def _base_bypass_state(cfg: DictConfig) -> dict[str, Any]: + return { + "version": 1, + "experiment_id": cfg.bypass.get("experiment_id", None), + "config_fingerprint": get_bypass_config_fingerprint(cfg), + "identity": get_bypass_run_identity(cfg), + "status": "running", + "checkpoints": {}, + "realized_checkpoint": None, + "ckpts_symlink": None, + } + + +def update_bypass_checkpoint_state( + cfg: DictConfig, checkpoint_dir: str | Path, checkpoint_role: str +) -> None: + if not dist.is_master(): + return + state = load_bypass_state(cfg.bypass.experiment_dir) or _base_bypass_state(cfg) + state["status"] = "running" + state["config_fingerprint"] = get_bypass_config_fingerprint(cfg) + state["identity"] = get_bypass_run_identity(cfg) + state.setdefault("checkpoints", {})[checkpoint_role] = str(Path(checkpoint_dir)) + write_bypass_state(cfg, state) + + +def mark_bypass_run_completed( + cfg: DictConfig, realized_checkpoint: str | Path, ckpts_symlink: str | Path +) -> None: + state = load_bypass_state(cfg.bypass.experiment_dir) or _base_bypass_state(cfg) + state["status"] = "completed" + state["config_fingerprint"] = get_bypass_config_fingerprint(cfg) + state["identity"] = get_bypass_run_identity(cfg) + state["realized_checkpoint"] = 
str(realized_checkpoint) + state["ckpts_symlink"] = str(ckpts_symlink) + write_bypass_state(cfg, state) + if dist.is_master(): + (Path(cfg.bypass.experiment_dir) / "_DONE").touch() + + +def bypass_run_is_complete(cfg: DictConfig) -> bool: + state = load_bypass_state(cfg.bypass.experiment_dir) + if state is None: + return False + if state.get("status") != "completed": + return False + if state.get("config_fingerprint") != get_bypass_config_fingerprint(cfg): + return False + realized = state.get("realized_checkpoint") + symlink = state.get("ckpts_symlink") + if not realized or not Path(realized).exists(): + return False + if not symlink or not Path(symlink).exists(): + return False + return True + + +def expected_bypass_runs(cfg: DictConfig) -> list[dict[str, Any]]: + """Return expected run metadata for the current bypass config or sweep.""" + runs: list[dict[str, Any]] = [] + configs_list = cfg.bypass.get("configs", None) + overrides = configs_list if configs_list else [None] + + for override in overrides: + run_cfg = OmegaConf.create( + { + "puzzle_dir": cfg.puzzle_dir, + "dataset_path": cfg.get("dataset_path", None), + "descriptor": cfg.get("descriptor", None), + "bypass": OmegaConf.to_container(cfg.bypass, resolve=True), + } + ) + OmegaConf.set_struct(run_cfg, False) + if override: + run_cfg.bypass.experiment_id = None + if "model_config_overrides" in override: + run_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + run_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + set_experiment_id(run_cfg) + experiment_dir = ( + Path(run_cfg.puzzle_dir) / "bypass" / "bypass_runs" / run_cfg.bypass.experiment_id + ) + runs.append( + { + "experiment_id": run_cfg.bypass.experiment_id, + "experiment_dir": str(experiment_dir), + "config_fingerprint": get_bypass_config_fingerprint(run_cfg), + } + ) + return runs + + +def get_distributed_modules_ownership(module_count: int, world_size: int) -> list[int]: + 
"""Map module (block) indices to GPU ranks for pipeline-parallel distribution.""" + modules_process_ownership: list[int] = [] + + for i in range(world_size): + num_modules_for_process = module_count // world_size + if i < module_count % world_size: + num_modules_for_process += 1 + + modules_process_ownership.extend([i] * num_modules_for_process) + + return modules_process_ownership + + +def get_pipeline_ownership_context( + module_ownership: Sequence[int], rank: int | None = None +) -> dict[str, Any]: + """Return local module indices and neighboring pipeline ranks for ``rank``.""" + if rank is None: + rank = dist.rank() + owned_indices = [i for i, owner in enumerate(module_ownership) if owner == rank] + if not owned_indices: + raise RuntimeError( + f"rank {rank} owns no modules in pipeline ownership map {list(module_ownership)}" + ) + + min_owned_index = min(owned_indices) + max_owned_index = max(owned_indices) + prev_rank = None if min_owned_index == 0 else module_ownership[min_owned_index - 1] + next_rank = ( + None + if max_owned_index + 1 >= len(module_ownership) + else module_ownership[max_owned_index + 1] + ) + return { + "owned_indices": owned_indices, + "owned_index_set": set(owned_indices), + "prev_rank": prev_rank, + "next_rank": next_rank, + } diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py new file mode 100644 index 00000000000..bb04f68e4bf --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Data classes for bypass distillation training.""" + +import dataclasses +from typing import TypeAlias + +__all__ = ["GlobalRank", "IterNum", "IterStatistics", "LocalTrainingStats", "TimeToSaveSignal"] + +IterNum: TypeAlias = int +GlobalRank: TypeAlias = int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class IterStatistics: + step_num: int + token_count: int + iter_duration: float + lr: float + clipping_count: int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class LocalTrainingStats: + iter_num: int + stitched_module_losses: dict[str, float] + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TimeToSaveSignal: + step_num: int diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py new file mode 100644 index 00000000000..c44be3e7e3f --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -0,0 +1,644 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating stitched teacher/student models for bypass distillation.""" + +import copy +import dataclasses +import re +from argparse import Namespace +from collections import OrderedDict +from contextlib import nullcontext +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, Sequence + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import AdamW, Optimizer +from transformers import PretrainedConfig, PreTrainedModel + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher +from modelopt.torch.puzzletron.pruning.pruning_utils import GQAInitMode, LinearInitMode, MlpInitMode +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + FunctionTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, + always_true_predicate, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model +from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype + +from .bypass_utils import get_pipeline_ownership_context, normalize_keys_to_learn + +__all__ = [ + "Args", + "Config", + "StitchedModuleDescriptor", + "StitchedModulesProcessOwnership", + "SyncDistributedModelWeightsFn", + "bypass_factory_fn", +] + 
+StitchedModulesProcessOwnership = list[int] +SyncDistributedModelWeightsFn = Callable[[], None] +Config = Mapping[str, Any] +Args = Namespace + + +@dataclasses.dataclass +class StitchedModuleDescriptor: + stitched_module: StitchedModule + owned_parameters: dict[str, torch.nn.Parameter] + owned_buffers: dict[str, torch.Tensor] + optimizer: Optional[Optimizer] = None + grad_scaler: Optional[GradScaler] = None + + +def _autocast_context(descriptor: ModelDescriptor): + return ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if descriptor.uses_autocast() + else nullcontext() + ) + + +def _param_names_for_subblock_key( + model: PreTrainedModel, + descriptor: ModelDescriptor, + subblock_key: str, +) -> set[str]: + lm_config = descriptor.get_language_model_config(model.config) + weight_groups = descriptor.get_weight_groups( + model.state_dict().keys(), lm_config.num_hidden_layers + ) + + attn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_attention") + ] + ffn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_ffn") + ] + if subblock_key == "subblock_attention": + group_names = attn_group_names + elif subblock_key == "subblock_ffn": + group_names = ffn_group_names + elif subblock_key == "subblock_mamba": + group_names = attn_group_names # Mamba params live in _attention groups + elif subblock_key == "entire_block": + group_names = attn_group_names + ffn_group_names + else: + raise ValueError(f"Unsupported subblock key: {subblock_key!r}") + + # block_configs lives on the outer puzzletron-converted config for nested + # HF configs (for example Qwen3-VL), not necessarily on the language sub-config. 
+ block_configs = getattr(model.config, "block_configs", None) or getattr( + lm_config, "block_configs", None + ) + + collected: list[str] = [] + for group_name in group_names: + if block_configs is not None: + m = re.match(r"block_(\d+)_attention", group_name) + if m: + block_idx = int(m.group(1)) + if block_idx < len(block_configs): + attention_cfg = getattr(block_configs[block_idx], "attention", None) + is_mamba = getattr(attention_cfg, "mamba", None) is not None + if subblock_key == "subblock_attention" and is_mamba: + continue + if subblock_key == "subblock_mamba" and not is_mamba: + continue + collected.extend(weight_groups[group_name]) + return set(collected) + + +def _set_keys_to_learn( + model: PreTrainedModel, + descriptor: ModelDescriptor, + keys_to_learn: str | Sequence[str], +) -> None: + """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``. + + Bypass v1 supports only descriptor-backed subblock keys. This keeps training + selection aligned with replacement-library extraction. + """ + normalized = normalize_keys_to_learn(keys_to_learn) + param_names = set() + for subblock_key in normalized["subblocks"]: + param_names.update(_param_names_for_subblock_key(model, descriptor, subblock_key)) + # In pipeline-parallel training a rank may own only blocks that don't match + # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention + # bypass has no GQA params after the _mamba rename). That is a valid state: + # those blocks are tracked as non-trainable and omitted from numeric loss stats. + if not param_names: + return + + # Set requires_grad to True for the selected parameters. 
+ for param_name, param in model.named_parameters(): + if param_name in param_names and torch.is_floating_point(param): + param.requires_grad_(True) + + +def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]: + all_non_persistent = set() + for module_name, submodule in module.named_modules(): + for buffer_name in submodule._non_persistent_buffers_set: + full_name = f"{module_name}.{buffer_name}" if module_name else buffer_name + all_non_persistent.add(full_name) + return all_non_persistent + + +def bypass_factory_fn( + teacher_model: PreTrainedModel, + descriptor: ModelDescriptor, + cfg: DictConfig, + model_blocks_process_ownership: Sequence[int], + student_model: Optional[PreTrainedModel] = None, +) -> tuple[ + PreTrainedModel, + StitchedModule, + StitchedModule, + StitchedModule, + OrderedDict[str, StitchedModuleDescriptor], + PretrainedConfig, +]: + """Unified factory function for bypass (blockwise local) distillation. + + Handles all layer types — FFN, attention (GQA/MHA), MoE experts, Mamba, and whole blocks — + through a single pipeline. Behavior is driven entirely by ``model_factory`` config fields: + + - ``mlp_init_mode``: how student FFN / MoE weights are initialised + - ``"ExpertRemoval"``: select top-N experts from teacher (MoE models) + - ``"Truncate"`` / ``"PruneByActivationsLog"``: prune FFN channels (dense models) + - ``"CopyAsIs"``: copy weights unchanged (attention-only or Mamba-only runs) + - ``gqa_init_mode``: how attention KV heads are initialised (optional, default ``AverageKV``). + Irrelevant when the student has the same number of KV heads as the teacher. + - ``keys_to_learn``: which subblock parameters to train. + Accepts ``"subblock_ffn"``, ``"subblock_attention"``, ``"subblock_mamba"``, + ``"entire_block"``, or a list of those keys. + + The stitching logic (pipeline-parallel per-block KD) is architecture-agnostic and unchanged + regardless of which layer type is being distilled. 
+ + Args: + teacher_model: The teacher model to use for stitching. + descriptor: Model descriptor for layer naming and pruning mixin lookup. + cfg: The bypass config section. + model_blocks_process_ownership: Ownership mapping of model blocks to process ranks. + student_model: Optionally provided pre-built student model (skips initialisation). + + Returns: + Tuple of (student_model, teacher_stitched, teacher_val_stitched, + student_val_stitched, stitched_module_descriptors, student_config) + """ + device = torch.device(f"cuda:{dist.local_rank()}") + model_config_overrides = cfg.model.model_config_overrides + + _block_loss_funcs: dict[str, Callable[..., Any]] = { + "normalized_mse_loss": normalized_mse_loss, + "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, + "batched_normalized_mse_loss": batched_normalized_mse_loss, + } + block_loss_func = _block_loss_funcs[cfg.model_factory.block_loss_func] + mprint(f"{block_loss_func.__name__=}") + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + # Initialize student_model + if student_model is None: + mprint("Creating student model from teacher model") + + with _autocast_context(descriptor): + if isinstance(model_config_overrides, DictConfig): + config_to_override = OmegaConf.to_container(model_config_overrides, resolve=True) + else: + config_to_override = model_config_overrides + mprint(f"{config_to_override=}") + student_model_config = update_model_config( + model_config=teacher_model.config, + model_config_overrides=config_to_override, + ) + student_model_config.use_cache = False + + mprint(f"Student model config:\n {format_block_configs(student_model_config)}") + + runtime = Namespace( + device=device, + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + ) + + with deci_x_patcher( + 
model_descriptor=descriptor, + block_configs=getattr(student_model_config, "block_configs", None), + ): + student_model = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=student_model_config, + owned_block_indexes=owned_block_indexes, + trust_remote_code=cfg.get("trust_remote_code", False), + device=device, + ) + # `_init_weights` is HF's per-module initializer; apply it across the + # whole model rather than passing the model itself as a single module. + student_model.apply(student_model._init_weights) + + student_weights_dtype = parse_dtype(cfg.model.student_weights_dtype) + descriptor.init_rotary_embedding(student_model, runtime) + student_model.type(student_weights_dtype) + + mlp_init_mode = MlpInitMode(cfg.model_factory.mlp_init_mode or MlpInitMode.CopyAsIs) + + # For expert removal, use the model-specific pruning mixin so that model-specific + # key paths (e.g. backbone.layers.{i}.mixer for Nemotron-H vs model.layers.{i}.mlp + # for GPT-OSS) are handled correctly. For all other init modes the legacy inline + # key logic in create_child_state_dict is sufficient. + _mixins = [] + if mlp_init_mode == MlpInitMode.ExpertRemoval: + _expert_mixin = descriptor.pruning_mixins().get("experts_removal") + if _expert_mixin is not None: + _mixins.append(_expert_mixin) + + # If any attention layer has fewer KV heads in the student than the teacher, use the + # model-specific KV heads mixin so that k_proj/v_proj weights are correctly sliced + # rather than copied verbatim from the (larger) teacher state dict. 
+ _kv_mixin = descriptor.pruning_mixins().get("kv_heads") + if _kv_mixin is not None: + _student_kv = [ + b.attention.num_key_value_heads + for b in student_model_config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + _teacher_kv = [ + b.attention.num_key_value_heads + for b in teacher_model.config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + assert len(_student_kv) == len(_teacher_kv), ( + f"KV-head block-config length mismatch: student={len(_student_kv)} " + f"teacher={len(_teacher_kv)} — check model_config_overrides" + ) + if _student_kv != _teacher_kv: + _mixins.append(_kv_mixin) + + # If any FFN layer has a smaller intermediate_size in the student than the teacher, + # use the model-specific FFN-intermediate mixin. The generic create_child_state_dict + # path is hardcoded to `model.layers.{i}.mlp.*` (Llama-style), so for families that + # place FFN under a different prefix (e.g. `backbone.layers.{i}.mixer.*` for + # Nemotron-H/H_v2) the mixin is required to slice up_proj/down_proj correctly. + # Filter out no_op FFN blocks (their intermediate_size is None) — relevant for + # hybrid families where each layer is exactly one of {attention, ffn, mamba}. 
+ _ffn_mixin = descriptor.pruning_mixins().get("ffn_intermediate") + if _ffn_mixin is not None and mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + _student_ffn = [ + b.ffn.intermediate_size + for b in student_model_config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + _teacher_ffn = [ + b.ffn.intermediate_size + for b in teacher_model.config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + assert len(_student_ffn) == len(_teacher_ffn), ( + f"FFN-intermediate block-config length mismatch: student={len(_student_ffn)} " + f"teacher={len(_teacher_ffn)} — check model_config_overrides" + ) + if _student_ffn != _teacher_ffn: + _mixins.append(_ffn_mixin) + + if len(_mixins) == 0: + pruning_mixin = None + elif len(_mixins) == 1: + pruning_mixin = _mixins[0] + else: + pruning_mixin = _mixins + + # GQA init mode is optional: only relevant when the student has fewer KV heads than + # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. 
+ gqa_init_mode = GQAInitMode(cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV)) + + student_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=teacher_model.state_dict(), + new_state_dict=student_model.state_dict(), + original_config=teacher_model.config, + new_config=student_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=cfg.model_factory.mlp_init_config, + owned_block_indexes=owned_block_indexes, + linear_init_mode=LinearInitMode( + cfg.model_factory.linear_init_mode or LinearInitMode.Random + ), + ) + + # Load student state dict + missing_keys, unexpected_keys = student_model.load_state_dict( + student_state_dict, strict=False + ) + assert len(unexpected_keys) == 0, f"{unexpected_keys=}" + # GQA models have learnable logit parameters not present in the teacher state dict; + # allow those to be absent and assert nothing else is missing. + non_gqa_missing = [k for k in missing_keys if not re.search(r"gqa_\w+_logits", k)] + assert len(non_gqa_missing) == 0, f"Unexpected missing keys: {non_gqa_missing}" + + else: + mprint("Student model provided explicitly, not using teacher model to instantiate") + student_model_config = student_model.config + + # Set up training parameters + lm_config = descriptor.get_language_model_config(student_model_config) + all_block_indices = list(range(lm_config.num_hidden_layers)) + + student_model.requires_grad_(False) + keys_to_learn = cfg.model_factory.keys_to_learn + mprint(f"Keys to learn: {keys_to_learn}") + + _set_keys_to_learn(model=student_model, descriptor=descriptor, keys_to_learn=keys_to_learn) + + dist.barrier() + mprint(f"Global rank: {dist.rank()}, {owned_block_indexes=}") + dist.barrier() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + dist.barrier() + + # Every rank derives ownership from the same `model_blocks_process_ownership` + # list, so this guard fires identically on every rank when 
world_size exceeds + # num_hidden_layers — no NCCL hang from a single rank diverging. + ranks_with_blocks = set(model_blocks_process_ownership) + empty_ranks = [r for r in range(dist.size()) if r not in ranks_with_blocks] + if empty_ranks: + raise RuntimeError( + f"world_size ({dist.size()}) exceeds num_hidden_layers " + f"({len(all_block_indices)}); ranks {empty_ranks} would own 0 blocks. " + f"Pipeline-parallel bypass distillation does not support idle ranks — " + f"reduce nproc_per_node to at most num_hidden_layers." + ) + + ownership_context = get_pipeline_ownership_context(model_blocks_process_ownership) + prev_rank: Optional[int] = ownership_context["prev_rank"] + next_rank: Optional[int] = ownership_context["next_rank"] + + teacher_parameters = set(teacher_model.parameters()) + teacher_buffers = set(teacher_model.buffers()) + + # Setup the student model's submodules for knowledge distillation training + with _autocast_context(descriptor), torch.device(device): + stitched_module_descriptors = OrderedDict[str, StitchedModuleDescriptor]() + submodule_for_loss_calculation = cfg.model_factory.submodule_for_loss_calculation + + teacher_target = ModuleTarget("teacher", teacher_model) + teacher_stitcher = Needle() + teacher_val_stitcher = Needle() + + student_target = ModuleTarget("student", student_model) + student_val_stitcher = Needle() + + for local_block_index, global_block_index in enumerate(sorted(owned_block_indexes)): + module_name = descriptor.layer_block_name(global_block_index) + module = student_model.get_submodule(module_name) + + submodule_name = "" + submodule_input_descriptor = submodule_name + submodule_output_descriptor = submodule_name + + if submodule_for_loss_calculation is not None: + assert hasattr(module, submodule_for_loss_calculation) + submodule_output_descriptor = submodule_for_loss_calculation + + input_descriptor = f"{module_name}.{submodule_input_descriptor}".rstrip(".") + output_descriptor = 
f"{module_name}.{submodule_output_descriptor}".rstrip(".") + + # Receive activations from previous rank + if global_block_index > 0 and local_block_index == 0 and prev_rank is not None: + teacher_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + teacher_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + student_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="student_activations", adapter=lambda x: InputArgs(x) + ), + student_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + + # Send activations to next rank or register model output + if local_block_index + 1 == len(owned_block_indexes): + if next_rank is None: + student_val_stitcher.stitch( + student_target.output(name=""), + ExternalTarget().output("model_output"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=""), + ExternalTarget().output("model_output"), + ) + else: + teacher_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + teacher_val_stitcher.stitch( + teacher_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"), + ) + student_val_stitcher.stitch( + student_target.output(name=module_name), + RemoteTarget(peer_rank=next_rank).value(name="student_activations"), + ) + + # Bypass training stitches + teacher_stitcher.stitch( + teacher_target.input(name=input_descriptor), + 
ExternalTarget().input(name=input_descriptor), + ).stitch( + teacher_target.output(name=output_descriptor), + ExternalTarget().output(name=output_descriptor), + ) + + # Create the student block stitched module + student_stitched_module_loss_target = FunctionTarget( + "module_loss_func", block_loss_func + ) + student_stitched_module_name = f"block_{global_block_index}" + student_submodule_target = ModuleTarget("student_submodule", module) + # When a block returns a tuple, ``v[0]`` is the hidden state by + # HF convention — every HF transformer block (Llama, Qwen, GPT-OSS, + # NemotronH, …) returns ``(hidden_states, *aux)``, with ``aux`` + # varying (attention weights, KV cache, router logits, …) but + # element 0 always being the hidden state. Puzzletron is HF-format- + # only, so this assumption holds across every supported family. + student_stitched_module = ( + Needle() + .stitch( + ExternalTarget().input(name=input_descriptor), + student_submodule_target.input(name=submodule_input_descriptor), + ) + .stitch( + ExternalTarget().output( + name=output_descriptor, + adapter=lambda v: InputArgs(target=v) + if not isinstance(v, tuple) + else InputArgs(target=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_submodule_target.output( + name=submodule_output_descriptor, + adapter=lambda v: InputArgs(input=v) + if not isinstance(v, tuple) + else InputArgs(input=v[0]), + ), + student_stitched_module_loss_target.input(), + ) + .stitch( + student_stitched_module_loss_target.output(), + ExternalTarget().output(name="loss"), + ) + .knot( + ignore_extra_overrides=True, + capture_cache_outputs_predicate=always_true_predicate, + ) + ) + + assert "learning_rate" in cfg.training + # Do NOT enable dummy params: blocks with no real trainable parameters + # (e.g. Mamba blocks during an attention-only bypass run) should produce + # NaN loss so they are excluded from statistics — identical to the + # optimizer=None path in the training loop. 
+ + student_module_parameters = { + p_name: p + for p_name, p in student_stitched_module.named_parameters() + if p not in teacher_parameters and "dummy_param" not in p_name + } + student_module_buffers = { + p_name: p + for p_name, p in student_stitched_module.named_buffers() + if p not in teacher_buffers + and p_name not in _get_all_non_persistent_buffers_set(student_stitched_module) + } + + trainable_params = { + p_name: p for p_name, p in student_module_parameters.items() if p.requires_grad + } + + optimizer = ( + AdamW( + list(trainable_params.values()), + lr=cfg.training.learning_rate, + weight_decay=cfg.training.weight_decay, + betas=(cfg.training.beta1, cfg.training.beta2), + fused=True, + ) + if len(trainable_params) > 0 + else None + ) + + grad_scaler = ( + None + if optimizer is None + else GradScaler(device=device.type, enabled=cfg.training.use_grad_scaling) + ) + + stitched_module_descriptors[student_stitched_module_name] = StitchedModuleDescriptor( + stitched_module=student_stitched_module, + owned_parameters=student_module_parameters, + owned_buffers=student_module_buffers, + optimizer=optimizer, + grad_scaler=grad_scaler, + ) + + teacher_stitched_module = teacher_stitcher.knot(ignore_extra_overrides=True) + teacher_val_stitched_module = teacher_val_stitcher.knot(ignore_extra_overrides=True) + student_val_stitched_module = student_val_stitcher.knot(ignore_extra_overrides=True) + + local_trainable_param_count = sum( + p.numel() + for descriptor_ in stitched_module_descriptors.values() + for p in descriptor_.owned_parameters.values() + if p.requires_grad + ) + global_trainable_param_count = dist.allreduce(local_trainable_param_count, reduction="sum") + if global_trainable_param_count == 0: + raise ValueError( + f"keys_to_learn={keys_to_learn!r} did not match any trainable student parameters" + ) + + return ( + student_model, + teacher_stitched_module, + teacher_val_stitched_module, + student_val_stitched_module, + stitched_module_descriptors, + 
student_model_config, + ) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py new file mode 100644 index 00000000000..b9cbc4060ff --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -0,0 +1,1275 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation training loop for per-block knowledge distillation. + +This module implements the blockwise local distillation (BLD) stage of the PUZZLE framework. +It trains alternative transformer block configurations using per-block knowledge distillation +from a teacher model, producing a library of "puzzle pieces" with different efficiency/performance +trade-offs. 
+""" + +import logging +import math +import os +import shutil +import sys +import time +import traceback +from collections import OrderedDict, defaultdict +from contextlib import nullcontext +from pathlib import Path +from statistics import mean +from typing import Optional + +import datasets +import torch +import transformers +from omegaconf import DictConfig, OmegaConf +from torch.utils.data.dataloader import DataLoader +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase + +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) +from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule +from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses +from modelopt.torch.utils.robust_json import json_load + +from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint +from .bypass_utils import ( + bypass_run_is_complete, + get_distributed_modules_ownership, + get_pipeline_ownership_context, + load_bypass_state, + mark_bypass_run_completed, + set_experiment_dir, + set_experiment_id, +) +from .data_classes import GlobalRank, IterNum, IterStatistics, TimeToSaveSignal +from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership + +__all__ = [ + "GlobalRank", + "IterNum", + "IterStatistics", + "StitchedModuleDescriptor", + "StitchedModulesProcessOwnership", + "TimeToSaveSignal", + 
"bypass_run_is_complete", + "find_latest_run_dir", + "get_distributed_modules_ownership", + "get_pipeline_ownership_context", + "launch_bypass_distillation", + "load_bypass_state", + "load_local_state", + "mark_bypass_run_completed", + "realize_bypass_checkpoints", + "run_bypassed_training", + "save_bypass_checkpoint", + "set_experiment_dir", + "set_experiment_id", + "train", +] + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def _autocast_context(descriptor: ModelDescriptor): + return ( + torch.autocast(device_type="cuda", dtype=torch.bfloat16) + if descriptor.uses_autocast() + else nullcontext() + ) + + +def _resolve_trust_remote_code(cfg: DictConfig, descriptor: ModelDescriptor) -> bool: + trust_remote_code = bool(cfg.get("trust_remote_code", False)) + if descriptor.requires_trust_remote_code() and not trust_remote_code: + descriptor_name = getattr(descriptor, "__name__", descriptor.__class__.__name__) + mprint( + f"WARNING: descriptor {descriptor_name} usually requires trust_remote_code=True, " + "but cfg.trust_remote_code is false; loading will proceed without executing " + "custom checkpoint code." + ) + return trust_remote_code + + +def _get_resume_state_path(cfg: DictConfig, resume_checkpoint_path: Optional[str]) -> Optional[str]: + if cfg.bypass.init_checkpoint_path is not None: + if resume_checkpoint_path is not None: + mprint( + f"Ignoring resume checkpoint state from {resume_checkpoint_path} because " + f"bypass.init_checkpoint_path={cfg.bypass.init_checkpoint_path} is set" + ) + return None + return resume_checkpoint_path + + +def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: + """Top-level entry point for bypass distillation stage. + + Runs sewing-kit pipeline-parallel per-block knowledge distillation. + + Supports multiple bypass configurations via ``bypass.configs`` list. + Each entry overrides ``bypass.model.model_config_overrides`` and optionally + ``bypass.model_factory.keys_to_learn``, then runs a full bypass training. 
+ + If ``bypass.configs`` is absent or empty, runs a single bypass training + with the settings already in ``bypass``. + + Args: + hydra_cfg: The full Hydra configuration with a 'bypass' section. + """ + configs_list = hydra_cfg.bypass.get("configs", None) + + if not configs_list: + # Single config mode — run once with whatever is in bypass already + set_experiment_id(hydra_cfg) + set_experiment_dir(hydra_cfg) + dist.barrier() + bypass_complete = bypass_run_is_complete(hydra_cfg) if dist.is_master() else None + bypass_complete = dist.broadcast(bypass_complete, src=0) + if bypass_complete: + mprint( + f"Bypass distillation already completed for {hydra_cfg.bypass.experiment_id}, skipping" + ) + return + mprint("Starting bypass distillation (single config)") + run_bypassed_training(hydra_cfg) + mprint("Bypass distillation completed") + return + + base_model_config_overrides = OmegaConf.to_container( + hydra_cfg.bypass.model.model_config_overrides, resolve=True + ) + base_keys_to_learn = hydra_cfg.bypass.model_factory.keys_to_learn + + mprint(f"Starting bypass distillation sweep ({len(configs_list)} configs)") + for i, override in enumerate(configs_list): + mprint(f"Bypass config {i + 1}/{len(configs_list)}: {override}") + + hydra_cfg.bypass.model.model_config_overrides = OmegaConf.create( + base_model_config_overrides + ) + hydra_cfg.bypass.model_factory.keys_to_learn = base_keys_to_learn + + # Apply overrides for this run + if "model_config_overrides" in override: + hydra_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + hydra_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + + # Reset per-run state so each config starts fresh + hydra_cfg.bypass.experiment_id = None + hydra_cfg.bypass.iter_num = 1 + hydra_cfg.bypass.step_num = 1 + hydra_cfg.bypass.token_count = 0 + hydra_cfg.bypass.best_val_loss = 1e9 + hydra_cfg.bypass.training.clipping_count = 0 + # Per-block bookkeeping for the 
Stitched-Module-Losses table. Mirrored + # into cfg.bypass on every log chunk so save_bypass_checkpoint's + # args.json snapshot carries them, and resume can restore the columns + # instead of trivially re-anchoring to the first post-resume chunk. + hydra_cfg.bypass.best_losses_by_name = {} + hydra_cfg.bypass.best_steps_by_name = {} + hydra_cfg.bypass.initial_losses_by_name = {} + + set_experiment_id(hydra_cfg) + set_experiment_dir(hydra_cfg) + dist.barrier() + bypass_complete = bypass_run_is_complete(hydra_cfg) if dist.is_master() else None + bypass_complete = dist.broadcast(bypass_complete, src=0) + if bypass_complete: + mprint( + f"Bypass config {i + 1}/{len(configs_list)} " + f"({hydra_cfg.bypass.experiment_id}) already completed, skipping" + ) + else: + run_bypassed_training(hydra_cfg) + mprint(f"Bypass config {i + 1}/{len(configs_list)} completed") + + mprint("Bypass distillation sweep completed") + + +def _flush_loss_buffer( + local_buffer: dict[int, dict[str, float]], + stitched_losses_history: Optional[dict[int, dict[str, float]]], +) -> None: + """All-gather buffered per-iter losses and merge into master's history. + + Pickle-based ``all_gather_object`` was previously called on every micro-batch; + batching to log-chunk boundaries reduces that cost ~``iters_per_log_chunk``×. + All ranks must call this so the collective doesn't deadlock; only master + actually accumulates into ``stitched_losses_history``. 
+ """ + if not local_buffer: + return + gathered = dist.allgather(local_buffer) + if dist.is_master(): + assert stitched_losses_history is not None + for rank_buf in gathered: + for it, losses in rank_buf.items(): + stitched_losses_history.setdefault(it, {}).update(losses) + + +def _delete_old_checkpoints( + experiment_dir: Path, + glob_pattern: str, + keep_name: str, +) -> None: + if not dist.is_master(): + return + for old_ckpt_path in experiment_dir.glob(glob_pattern): + if old_ckpt_path.name != keep_name: + shutil.rmtree(str(old_ckpt_path)) + + +def _save_training_checkpoint( + *, + cfg: DictConfig, + descriptor: ModelDescriptor, + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + subdir_name: str, + checkpoint_role: str, + cleanup_glob: str | None = None, +) -> None: + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + checkpoint_role=checkpoint_role, + ) + if cleanup_glob and cfg.bypass.model.model_overrides.delete_old_checkpoints: + _delete_old_checkpoints(Path(cfg.bypass.experiment_dir), cleanup_glob, subdir_name) + + +def train( + cfg: DictConfig, + descriptor: ModelDescriptor, + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + teacher_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + stitched_modules_process_ownership: StitchedModulesProcessOwnership, + train_dataloader: Optional[DataLoader], + val_dataloader: Optional[DataLoader], + student_model_config: PretrainedConfig, + skip_first_batches: int = 0, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +) -> None: + """Inner training loop for bypass distillation.""" + device = torch.device(f"cuda:{dist.local_rank()}") + + dist.barrier() + + # Anchor the time-based save 
def train(
    cfg: DictConfig,
    descriptor: ModelDescriptor,
    student_model: torch.nn.Module,
    student_stitched_model: StitchedModule,
    teacher_stitched_model: StitchedModule,
    stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor],
    stitched_modules_process_ownership: StitchedModulesProcessOwnership,
    train_dataloader: Optional[DataLoader],
    val_dataloader: Optional[DataLoader],
    student_model_config: PretrainedConfig,
    skip_first_batches: int = 0,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> None:
    """Run the inner training loop for bypass distillation.

    Each iteration (one micro-batch) runs the teacher under ``no_grad`` to
    capture per-module inputs/outputs, then runs every locally held stitched
    module against those captures and back-propagates its own loss. Optimizers
    step once per ``grad_accumulation_steps`` iterations. Progress counters
    (``iter_num``, ``step_num``, ``token_count``, best-loss tables) live on
    ``cfg.bypass`` and are mutated in place so checkpoints snapshot them.

    Args:
        cfg: Run config; ``cfg.bypass`` is read and mutated throughout.
        descriptor: Model descriptor used for autocast and checkpointing.
        student_model: Model handed to the checkpoint writer.
        student_stitched_model: Stitched student used for validation.
        teacher_stitched_model: Stitched teacher producing the captured
            activations that drive every student module.
        stitched_module_descriptors: Name -> descriptor (stitched module,
            optimizer, grad scaler). A ``None`` optimizer marks a module as
            not trainable.
        stitched_modules_process_ownership: Owner rank per stitched module.
        train_dataloader: Training data; required on master, unused elsewhere.
        val_dataloader: Validation data; ``None`` disables validation.
        student_model_config: Not read by this function.
        skip_first_batches: Number of batches master discards before training.
        tokenizer: Not read by this function.

    Raises:
        RuntimeError: On an unknown ``grad_clip_type``, or deliberately after
            a save when ``cfg.bypass.kill_after_first_save`` is set.
    """
    device = torch.device(f"cuda:{dist.local_rank()}")

    dist.barrier()

    # Anchor the time-based save interval at training start, not module import.
    # Earlier this was a module-level `time_start = time.time()`, which made
    # the first time-based save fire immediately if the module was imported
    # well before train() actually ran (e.g. via test collection or Hydra config
    # resolution).
    time_last_save = time.time()
    iter_t0 = time.time()

    resumed_iter_num = cfg.bypass.iter_num
    mprint(f"resumed_iter_num: {resumed_iter_num}")

    # Number of total stitched modules
    global_stitched_modules_count = len(stitched_modules_process_ownership)
    # Number of stitched modules per process
    num_stitched_modules_per_process = [
        sum(1 for x in stitched_modules_process_ownership if x == owner_rank)
        for owner_rank in range(dist.size())
    ]
    ownership_context = get_pipeline_ownership_context(stitched_modules_process_ownership)
    owned_stitched_module_indices = ownership_context["owned_indices"]
    mprint(f"{global_stitched_modules_count=}")
    mprint(f"{num_stitched_modules_per_process=}")
    dist.barrier()

    if dist.is_master():
        # {iter_num: {stitched_module_name: loss}}
        stitched_losses_history = dict[IterNum, dict[str, float]]()
    else:
        stitched_losses_history = None

    # Save checkpoint before training starts
    if cfg.bypass.save_checkpoint_before_training and not cfg.bypass.disable_checkpoint_save:
        subdir_name = f"start-step-{cfg.bypass.step_num:06d}-ckpt"
        _save_training_checkpoint(
            cfg=cfg,
            descriptor=descriptor,
            model=student_model,
            stitched_module_descriptors=stitched_module_descriptors,
            subdir_name=subdir_name,
            checkpoint_role="start",
        )

    # Track statistics for each iteration
    iter_stats_history: dict[IterNum, IterStatistics] = {}

    # Create fake input ids for the teacher model
    fake_input_ids = fake_tensor(
        torch.ones(
            size=(cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size),
            dtype=torch.long,
            device=device,
        )
    )

    prev_rank: Optional[int] = ownership_context["prev_rank"]
    next_rank: Optional[int] = ownership_context["next_rank"]

    torch.cuda.synchronize()

    mprint(
        f"Grad scaling status: {'enabled' if cfg.bypass.training.use_grad_scaling else 'disabled'}"
    )

    # Only master consumes the dataloader — `next(train_iterator)` is gated by
    # `if dist.is_master()` further down. Building the iterator (or running
    # skip_first_batches against it) on non-master ranks wastes startup time
    # and memory proportional to the dataset, since each tokenizes the full
    # corpus only to throw it away.
    train_iterator = None
    if dist.is_master():
        assert train_dataloader is not None
        train_iterator = iter(train_dataloader)

    # Advance past the first `skip_first_batches` batches before the training loop
    # starts. Used either to skip a known-bad batch range during debugging, or to
    # roll the data iterator forward when resuming a run (model + optimizer state
    # are restored from the checkpoint, but the dataloader itself starts fresh).
    if dist.is_master() and skip_first_batches > 0:
        assert train_iterator is not None
        mprint(f"Skipping first {skip_first_batches} batches before training")
        for _ in range(skip_first_batches):
            next(train_iterator)

    mprint("Waiting for everyone before training starts")
    dist.barrier()

    step_to_save = None
    # Track best loss value for each block. Seeded from cfg.bypass so resume
    # picks up where the previous run left off.
    best_losses_by_name: dict[str, float] = dict(cfg.bypass.get("best_losses_by_name", {}))
    best_steps_by_name: dict[str, int] = dict(cfg.bypass.get("best_steps_by_name", {}))
    # Anchor for the "Δ from initial" column: per-block loss from the first log chunk.
    initial_losses_by_name: dict[str, float] = dict(cfg.bypass.get("initial_losses_by_name", {}))
    # Comprehension variable deliberately not called `descriptor`, so it cannot
    # be confused with the model-descriptor parameter of this function.
    non_trainable_stitched_module_names = {
        name
        for name, module_descriptor in stitched_module_descriptors.items()
        if module_descriptor.optimizer is None
    }

    # log_interval is in optimizer-step units; multiply by grad_accum to land in
    # micro-batch units, which is what the per-iter loss collection counts.
    iters_per_log_chunk = (
        cfg.bypass.training.log_interval * cfg.bypass.training.grad_accumulation_steps
    )
    # Per-rank local buffer of {iter_num: {block_name: loss}}. Losses are
    # accumulated locally on every rank and only gathered across ranks at
    # log-chunk boundaries (see the `_flush_loss_buffer` calls in the loop).
    local_losses_buffer: dict[int, dict[str, float]] = {}
    # Buffer variable. Initialised on the active device so non-master ranks
    # never hand a CPU tensor to a downstream GPU op if the master-only-fetch
    # invariant is ever relaxed (today only master replaces this in the loop).
    input_ids = torch.zeros(1, 1, dtype=torch.int64, device=device)

    aprint(
        f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}"
    )

    # Train loop start
    while True:
        # Check if we've reached the maximum number of steps. `step_num` is 1-based
        # and incremented at the END of each iteration, so we must use `>` (not `>=`)
        # to ensure step `max_steps` itself runs before exiting.
        if cfg.bypass.step_num > cfg.bypass.training.max_steps:
            # Drain any residual buffered losses (< log-chunk boundary) into
            # master's history. Must run on every rank — collective op.
            # NOTE: the logging block sits further down in the loop body and
            # is not reached after this `break`, so a final partial chunk is
            # gathered but not printed.
            _flush_loss_buffer(local_losses_buffer, stitched_losses_history)
            local_losses_buffer.clear()
            if (
                cfg.bypass.model.model_overrides.save_checkpoint_when_done
                and not cfg.bypass.disable_checkpoint_save
            ):
                mprint("Saving final checkpoint before training completion")
                subdir_name = f"final-step-{cfg.bypass.step_num:06d}-ckpt"
                _save_training_checkpoint(
                    cfg=cfg,
                    descriptor=descriptor,
                    model=student_model,
                    stitched_module_descriptors=stitched_module_descriptors,
                    checkpoint_role="final",
                    subdir_name=subdir_name,
                    cleanup_glob="step-*",
                )
            break

        is_accumulating = cfg.bypass.iter_num % cfg.bypass.training.grad_accumulation_steps != 0
        # Determine and set the learning rate for this iteration
        lr = (
            _get_lr(cfg, cfg.bypass.step_num)
            if cfg.bypass.training.decay_lr
            else cfg.bypass.training.learning_rate
        )
        for stitched_module_descriptor in stitched_module_descriptors.values():
            optimizer = stitched_module_descriptor.optimizer
            if optimizer is not None:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr

        if dist.is_master():
            assert train_iterator is not None
            train_data = next(train_iterator)
            input_ids = train_data["input_ids"]
            input_ids = input_ids.to(device)

        with _autocast_context(descriptor), torch.no_grad():
            # Only the first pipeline rank feeds real tokens; later ranks pass a
            # placeholder because their inputs come from captured activations.
            teacher_input_ids = input_ids if prev_rank is None else fake_input_ids
            teacher_output = teacher_stitched_model({}, {}, teacher_input_ids)

        input_overrides = teacher_output.captured_inputs
        output_overrides = teacher_output.captured_outputs

        del teacher_output

        input_overrides["teacher_inputs"] = InputArgs(fake_input_ids)

        # Collect per-block loss tensors and batch the GPU→CPU copy to a
        # single sync point at the end of the per-block loop. Doing
        # ``.to("cpu").item()`` per block forced one CUDA synchronization per
        # block per iter, serialising the GPU pipeline across N blocks.
        iter_loss_tensors: dict[str, torch.Tensor] = {}

        for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items():
            stitched_module = stitched_module_descriptor.stitched_module
            optimizer = stitched_module_descriptor.optimizer
            grad_scaler = stitched_module_descriptor.grad_scaler

            if optimizer is not None:
                assert grad_scaler is not None

                with _autocast_context(descriptor):
                    stitched_module_output = stitched_module(
                        input_overrides=input_overrides,
                        output_overrides=output_overrides,
                    )
                    stitched_module_loss = stitched_module_output.captured_outputs["loss"]
                    del stitched_module_output
                # Divide so that gradients summed over the accumulation window
                # average, rather than add up, the micro-batch losses.
                scaled_stitched_module_loss = (
                    stitched_module_loss / cfg.bypass.training.grad_accumulation_steps
                )
                grad_scaler.scale(scaled_stitched_module_loss).backward()
                iter_loss_tensors[stitched_module_name] = stitched_module_loss.detach()
                del scaled_stitched_module_loss
            else:
                # No real trainable parameters on this rank/block. Keep this out
                # of the numeric loss stream so genuine non-finite losses from
                # trainable blocks remain visible instead of being conflated with
                # an intentional "not trainable" sentinel.
                stitched_module_loss = None

            del stitched_module_loss

            if not is_accumulating:
                if optimizer is not None:
                    assert grad_scaler is not None
                    grad_clip = cfg.bypass.training.grad_clip
                    if grad_clip is not None:
                        # Gradients still carry the GradScaler loss scale at
                        # this point. Unscale them first so `grad_clip` is
                        # compared against (and applied to) true gradient
                        # magnitudes. Called once per optimizer step;
                        # `grad_scaler.step` below detects this and does not
                        # unscale a second time.
                        grad_scaler.unscale_(optimizer)
                        if cfg.bypass.training.grad_clip_type == "norm":
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                parameters=stitched_module.parameters(),
                                max_norm=grad_clip,
                            )
                            if grad_norm > grad_clip:
                                cfg.bypass.training.clipping_count += 1
                        elif cfg.bypass.training.grad_clip_type == "value":
                            # Stack per-param maxes into a single GPU tensor and
                            # reduce before `.item()` so we sync once per block
                            # instead of once per parameter (same rationale as
                            # the per-block loss batching in this loop).
                            grad_maxes = [
                                p.grad.abs().max()
                                for p in stitched_module.parameters()
                                if p.grad is not None
                            ]
                            if grad_maxes:
                                max_abs_grad = torch.stack(grad_maxes).max().item()
                            else:
                                max_abs_grad = 0.0
                            if max_abs_grad > grad_clip:
                                cfg.bypass.training.clipping_count += 1
                            torch.nn.utils.clip_grad_value_(
                                parameters=stitched_module.parameters(),
                                clip_value=grad_clip,
                            )
                        else:
                            raise RuntimeError(f"Invalid {cfg.bypass.training.grad_clip_type}")

                    grad_scaler.step(optimizer)
                    grad_scaler.update()
                    optimizer.zero_grad(set_to_none=True)

        # Single GPU→CPU sync for all per-block losses collected above. Stacking
        # into a 1-D tensor lets us issue exactly one ``.to("cpu")`` instead of
        # one per block.
        if iter_loss_tensors:
            loss_stack = torch.stack([t.flatten()[0] for t in iter_loss_tensors.values()])
            iter_stitched_module_losses: dict[str, float] = dict(
                zip(iter_loss_tensors.keys(), loss_stack.to("cpu").tolist())
            )
        else:
            iter_stitched_module_losses = {}

        if dist.is_master() and cfg.bypass.iter_num == resumed_iter_num:
            mprint(f"Starting from iter {cfg.bypass.iter_num}")

        # Buffer this rank's per-block losses locally. The cross-rank gather
        # happens only at log-chunk boundaries (`_flush_loss_buffer`), i.e. one
        # gather per `iters_per_log_chunk` micro-batches.
        local_losses_buffer[cfg.bypass.iter_num] = iter_stitched_module_losses
        if len(local_losses_buffer) >= iters_per_log_chunk:
            _flush_loss_buffer(local_losses_buffer, stitched_losses_history)
            local_losses_buffer.clear()

        cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter
        iter_t1 = time.time()
        iter_duration = iter_t1 - iter_t0
        iter_stats_history[cfg.bypass.iter_num] = IterStatistics(
            token_count=cfg.bypass.token_count,
            iter_duration=iter_duration,
            step_num=cfg.bypass.step_num,
            lr=lr,
            clipping_count=cfg.bypass.training.clipping_count,
        )
        iter_t0 = iter_t1

        # Time-based save signal (broadcast from master)
        save_signal = [step_to_save]
        if dist.is_master():
            if cfg.bypass.model.model_overrides.save_interval_seconds is not None:
                time_now = time.time()
                if (
                    time_now - time_last_save
                    >= cfg.bypass.model.model_overrides.save_interval_seconds
                ):
                    mprint(
                        f"Time to save! {cfg.bypass.model.model_overrides.save_interval_seconds=}, "
                        f"{time_last_save=}, {time_now=}"
                    )
                    # Schedule the save a few steps ahead; the value is
                    # broadcast below so every rank saves at the same step.
                    step_to_save = cfg.bypass.step_num + 5
                    save_signal = [step_to_save]
                    time_last_save = time_now

        step_to_save = dist.broadcast(save_signal[0], src=0)

        # Logging
        if dist.is_master():
            assert stitched_losses_history is not None
            # `iters_per_log_chunk` is computed once before the loop (in
            # micro-batch units = log_interval × grad_accum) and reused for
            # both the gather-batching threshold and this log drain.
            while len(stitched_losses_history) >= iters_per_log_chunk:
                lowest_iter = next(iter(stitched_losses_history.keys()))

                log_chunk = {
                    it: losses
                    for it, losses in stitched_losses_history.items()
                    if it - lowest_iter < iters_per_log_chunk
                }
                if len(log_chunk) < iters_per_log_chunk:
                    break

                highest_iter = list(log_chunk.keys())[-1]
                highest_iter_stats = iter_stats_history[highest_iter]

                losses_by_name = defaultdict[str, list[float]](list)
                for losses in log_chunk.values():
                    for name, loss in losses.items():
                        losses_by_name[name].append(loss)

                losses_by_name_avg = {name: mean(losses) for name, losses in losses_by_name.items()}
                non_finite_losses_by_name = {
                    name: loss
                    for name, loss in losses_by_name_avg.items()
                    if not math.isfinite(loss)
                }
                if non_finite_losses_by_name:
                    cfg.bypass.non_finite_losses_by_name = dict(non_finite_losses_by_name)
                    mprint(f"Non-finite stitched losses detected: {non_finite_losses_by_name}")

                # Anchor "Δ from initial" at the very first iter's per-block losses
                # (lowest_iter — typically iter 1 on a fresh run, the resumed iter
                # otherwise). Using the first chunk's *average* would tautologically
                # make Δ == 0 on the first row, since "Loss Value" is that same average.
                if not initial_losses_by_name:
                    initial_losses_by_name.update(stitched_losses_history[lowest_iter])

                # Update best losses tracking. Record the optimizer-step number
                # so the "Best Step" column matches the header's "step N/max" units.
                for name, current_loss in losses_by_name_avg.items():
                    if not math.isfinite(current_loss):
                        continue
                    if name not in best_losses_by_name or current_loss < best_losses_by_name[name]:
                        best_losses_by_name[name] = current_loss
                        best_steps_by_name[name] = highest_iter_stats.step_num

                # Mirror to cfg.bypass so save_bypass_checkpoint's args.json snapshot
                # carries these forward across resumes.
                cfg.bypass.best_losses_by_name = dict(best_losses_by_name)
                cfg.bypass.best_steps_by_name = dict(best_steps_by_name)
                cfg.bypass.initial_losses_by_name = dict(initial_losses_by_name)

                chunk_iter_durations = [
                    iter_stats_history[it].iter_duration for it in log_chunk.keys()
                ]
                avg_chunk_iter_duration = mean(chunk_iter_durations)
                # Report time in step units (= grad_accum × per-iter), since one
                # step is one optimizer update — what the user actually thinks of
                # as "a training step." Tokens/sec is invariant to that framing.
                avg_step_time = (
                    avg_chunk_iter_duration * cfg.bypass.training.grad_accumulation_steps
                )
                avg_token_speed = cfg.bypass.training.tokens_per_iter / avg_chunk_iter_duration
                mprint(
                    f"step {highest_iter_stats.step_num}/{cfg.bypass.training.max_steps:,}:"
                    f" avg_step_time={avg_step_time * 1000:.2f}ms"
                    f" avg_token_speed={avg_token_speed:,.0f}[tok/s]"
                )
                mprint(
                    format_stitched_losses(
                        losses_dict=losses_by_name_avg,
                        best_steps_dict=best_steps_by_name,
                        best_values_dict=best_losses_by_name,
                        initial_values_dict=initial_losses_by_name,
                        not_trainable_names=non_trainable_stitched_module_names,
                        step_number=highest_iter_stats.step_num,
                        title="Stitched Module Losses",
                    )
                )

                if cfg.bypass.wandb_log:
                    try:
                        import wandb

                        wandb.log(
                            {
                                "step": highest_iter_stats.step_num,
                                "token_count": highest_iter_stats.token_count,
                                "token_speed": avg_token_speed,
                                "lr": highest_iter_stats.lr,
                                "grad_clipping": highest_iter_stats.clipping_count,
                            },
                            step=highest_iter_stats.step_num,
                        )
                    except ImportError:
                        pass

                # Drop the logged chunk so the histories stay bounded.
                for it in log_chunk.keys():
                    del iter_stats_history[it]
                    del stitched_losses_history[it]

        # Validation
        if (
            not is_accumulating
            and (cfg.bypass.step_num % cfg.bypass.training.eval_interval) == 0
            and val_dataloader is not None
        ):
            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
                calculate_losses_pipeline,
            )

            losses, _ = calculate_losses_pipeline(
                stitched_model=student_stitched_model,
                dataloader=val_dataloader,
                descriptor=descriptor,
            )

            val_loss = float("inf")
            if losses is not None and "lm_loss" in losses:
                val_loss = losses["lm_loss"]["avg"]
                mprint(f"Validation loss at iter {cfg.bypass.iter_num}: {val_loss:.4f}")

            # Broadcast val_loss so all ranks agree on checkpoint decisions
            val_loss = dist.broadcast(val_loss, src=dist.size() - 1)

            if val_loss < cfg.bypass.best_val_loss:
                cfg.bypass.best_val_loss = val_loss
                if not cfg.bypass.disable_checkpoint_save and cfg.bypass.save_best_ckpt:
                    subdir_name = f"best-step-{cfg.bypass.step_num:06d}-ckpt"
                    _save_training_checkpoint(
                        cfg=cfg,
                        descriptor=descriptor,
                        model=student_model,
                        stitched_module_descriptors=stitched_module_descriptors,
                        checkpoint_role="best",
                        subdir_name=subdir_name,
                        cleanup_glob="best-step-*",
                    )
                    if cfg.bypass.kill_after_first_save:
                        raise RuntimeError("Done saving checkpoint, kill_after_first_save=True")

        # Checkpoint saving (step-based or time-based)
        if not is_accumulating and (
            (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0
            or step_to_save == cfg.bypass.step_num
        ):
            if not cfg.bypass.disable_checkpoint_save:
                if (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0:
                    mprint("Saving step-interval checkpoint")
                elif step_to_save == cfg.bypass.step_num:
                    mprint("Saving time-based checkpoint")

                subdir_name = f"step-{cfg.bypass.step_num:06d}-ckpt"
                _save_training_checkpoint(
                    cfg=cfg,
                    descriptor=descriptor,
                    model=student_model,
                    stitched_module_descriptors=stitched_module_descriptors,
                    checkpoint_role="resume",
                    subdir_name=subdir_name,
                    cleanup_glob="step-*",
                )

                if cfg.bypass.kill_after_first_save:
                    dist.barrier()
                    raise RuntimeError("Done saving checkpoint, kill_after_first_save=True")

        cfg.bypass.iter_num += 1
        if not is_accumulating:
            cfg.bypass.step_num += 1

    mprint("Finished successfully!")
rate decay scheduler (cosine with warmup) +def _get_lr(cfg: DictConfig, step: int) -> float: + warmup_steps = cfg.bypass.training.warmup_steps + lr_decay_steps = cfg.bypass.training.lr_decay_steps + # Degenerate budget (e.g. tiny `training_tokens` in tests): no room for cosine decay. + # Skip warmup/decay entirely and return base LR — avoids ZeroDivisionError on + # `lr_decay_steps - warmup_steps` and `step / warmup_steps`. + if lr_decay_steps <= warmup_steps: + return cfg.bypass.training.learning_rate + + # 1) linear warmup for warmup_steps steps + if step <= warmup_steps: + if warmup_steps == 0: + # Defensive: training loop's step starts at 1 so this branch is + # unreachable today, but a future caller passing step=0 would hit + # a ZeroDivisionError on `step / warmup_steps` below. + return cfg.bypass.training.learning_rate + lr = cfg.bypass.training.learning_rate * step / warmup_steps + # 2) if step > lr_decay_steps, return min learning rate + elif step > lr_decay_steps: + lr = cfg.bypass.training.min_lr + # 3) in between, use cosine decay down to min learning rate + else: + decay_ratio = (step - warmup_steps) / (lr_decay_steps - warmup_steps) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + lr = cfg.bypass.training.min_lr + coeff * ( + cfg.bypass.training.learning_rate - cfg.bypass.training.min_lr + ) + + return lr + + +def run_bypassed_training(cfg: DictConfig): + """Setup and orchestrate bypass distillation training.""" + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN + ) + + # Suppress debug messages from HuggingFace libraries + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + device = torch.device(f"cuda:{dist.local_rank()}") + + set_experiment_id(cfg) + set_experiment_dir(cfg) + if bypass_run_is_complete(cfg): + mprint(f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping") 
def run_bypassed_training(cfg: DictConfig) -> None:
    """Setup and orchestrate bypass distillation training.

    In order: resolves the experiment id/dir and returns early if the run is
    already marked complete; detects a checkpoint to resume from; loads the
    student (if a weight path is known) and the teacher; derives the training
    budget (steps, iters, tokens) on ``cfg.bypass.training``; builds the
    dataloaders and the stitched modules; restores resume state; optionally
    validates teacher/student; does a dummy forward over the stitched modules;
    calls ``train``; and finally, on master, realizes the checkpoint symlink
    and marks the run completed.

    Args:
        cfg: Run config. ``cfg.bypass`` (and ``cfg.teacher_dir``) are mutated
            in place with derived values.

    Raises:
        Exception: Anything raised during setup/training is printed to stderr
            with its traceback and re-raised unchanged.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN
    )

    # Suppress debug messages from HuggingFace libraries
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    device = torch.device(f"cuda:{dist.local_rank()}")

    set_experiment_id(cfg)
    set_experiment_dir(cfg)
    if bypass_run_is_complete(cfg):
        mprint(f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping")
        return

    descriptor = ModelDescriptorFactory.get(cfg.descriptor)
    trust_remote_code = _resolve_trust_remote_code(cfg, descriptor)
    OmegaConf.update(cfg, "bypass.trust_remote_code", trust_remote_code, force_add=True)
    # NOTE(review): this config is loaded again below after `teacher_dir` is
    # expanded; this first load is only used to read `num_hidden_layers`.
    teacher_model_config = load_model_config(cfg.teacher_dir, trust_remote_code=trust_remote_code)

    try:
        mprint("Waiting for distributed setup...")
        dist.barrier()

        if cfg.bypass.disable_initial_validate:
            cfg.bypass.validate_teacher_model = False
            cfg.bypass.validate_student_model = False

        if cfg.bypass.teacher_model_load_on_cpu:
            assert not cfg.bypass.validate_teacher_model, (
                "Teacher model validation is too slow on CPU"
            )

        num_hidden_layers = descriptor.get_language_model_config(
            teacher_model_config
        ).num_hidden_layers

        # One owner rank per transformer block; this rank keeps only its own.
        model_blocks_process_ownership = get_distributed_modules_ownership(
            module_count=num_hidden_layers,
            world_size=dist.size(),
        )

        owned_block_indexes = set(
            block_index
            for block_index, owner_rank in enumerate(model_blocks_process_ownership)
            if owner_rank == dist.rank()
        )

        cfg.teacher_dir = str(Path(cfg.teacher_dir).expanduser())
        teacher_model_config = load_model_config(
            cfg.teacher_dir,
            trust_remote_code=trust_remote_code,
        )
        # Disable KV cache during bypass forward passes. Set the attribute directly rather
        # than passing it as an AutoConfig override — some custom configs (GptOss, Qwen3-VL, etc.)
        # don't accept it as a known kwarg and would raise via the strict unused-kwargs check.
        if hasattr(teacher_model_config, "use_cache"):
            teacher_model_config.use_cache = False
        if hasattr(teacher_model_config, "text_config") and hasattr(
            teacher_model_config.text_config, "use_cache"
        ):
            teacher_model_config.text_config.use_cache = False

        # Resume detection has to run BEFORE the weight-loading branch below
        # so a resume can route through ``load_and_shard_model`` (the HF
        # checkpoint at ``resume_checkpoint_path`` is now the single source
        # of truth for weights — see _save_local_state docstring).
        # set_experiment_id / set_experiment_dir are idempotent and only
        # depend on cfg.bypass.model.model_config_overrides + cfg.puzzle_dir,
        # so it's safe to call them this early.
        resume_checkpoint_path: Optional[str] = None
        resume_cfg: Optional[DictConfig] = None
        resume_skip_first_batches = cfg.bypass.training.skip_first_batches
        if cfg.bypass.resume_checkpoint_path is not None:
            resume_checkpoint_path = cfg.bypass.resume_checkpoint_path
        elif cfg.bypass.find_last_ckpt_for_resume:
            _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir)
            if _ckpt_dir is None:
                mprint("Couldn't find any run dir for resume, assuming this is the first job")
            else:
                mprint(
                    f"`cfg.bypass.find_last_ckpt_for_resume` is True. "
                    f"Auto-found a checkpoint to resume: `{_ckpt_dir}`"
                )
                resume_checkpoint_path = _ckpt_dir

        resume_state_path = _get_resume_state_path(cfg, resume_checkpoint_path)
        if resume_state_path:
            resume_cfg = DictConfig(json_load(Path(resume_state_path) / "args.json"))
            saved_skip = resume_cfg.training.get(
                "skip_first_batches", cfg.bypass.training.skip_first_batches
            )
            # Roll the fresh dataloader past everything the previous run
            # consumed: its own initial skip plus the iterations it trained.
            resume_skip_first_batches = saved_skip + resume_cfg.iter_num
            # Reuse the saved seeds so the resumed data order matches.
            if "data" in resume_cfg and "shuffle_train_data_seed" in resume_cfg.data:
                cfg.bypass.data.shuffle_train_data_seed = resume_cfg.data.shuffle_train_data_seed
            if "seed" in resume_cfg:
                cfg.bypass.seed = resume_cfg.seed

        # Both ``init_checkpoint_path`` and ``resume_checkpoint_path`` point at
        # an HF-format directory; share the same loader. ``init_checkpoint_path``
        # wins if both are set (explicit user override beats auto-detect).
        # NOTE(review): when both are set, weights come from the init path while
        # `load_local_state` below still restores state from the resume path —
        # confirm that combination is intended.
        weight_load_path = cfg.bypass.init_checkpoint_path or resume_state_path
        student_model = None
        if weight_load_path is not None:
            mprint(f"Loading student model from {weight_load_path}")
            student_model = load_and_shard_model(
                descriptor=descriptor,
                checkpoint_path=weight_load_path,
                owned_block_indexes=owned_block_indexes,
                trust_remote_code=trust_remote_code,
            )

        cfg.bypass.training.min_lr = (
            cfg.bypass.training.learning_rate * cfg.bypass.training.min_lr_factor
        )
        cfg.bypass.training.batch_size_per_iter = cfg.bypass.training.micro_batch_size
        cfg.bypass.training.tokens_per_iter = (
            cfg.bypass.data.block_size * cfg.bypass.training.batch_size_per_iter
        )
        requested_iters = math.ceil(
            cfg.bypass.training.training_tokens / cfg.bypass.training.tokens_per_iter
        )
        # The loop steps optimizers only after a full grad-accum window, so round
        # the requested token budget up to complete optimizer-step units and report
        # that actual budget back to the user.
        cfg.bypass.training.max_steps = math.ceil(
            requested_iters / cfg.bypass.training.grad_accumulation_steps
        )
        cfg.bypass.training.max_iters = (
            cfg.bypass.training.max_steps * cfg.bypass.training.grad_accumulation_steps
        )
        cfg.bypass.training.max_token_count = (
            cfg.bypass.training.max_iters * cfg.bypass.training.tokens_per_iter
        )
        cfg.bypass.training.lr_decay_steps = cfg.bypass.training.max_steps

        if cfg.bypass.training.val_micro_batch_size is None:
            cfg.bypass.training.val_micro_batch_size = cfg.bypass.training.micro_batch_size

        if cfg.bypass.training.warmup_steps is None:
            cfg.bypass.training.warmup_steps = 0

        mprint(f"\n{format_global_config(cfg.bypass, 'Bypass Configurations')}")
        mprint(f"Max token count: {cfg.bypass.training.max_token_count:,}")

        seed = cfg.bypass.seed
        torch.manual_seed(seed)

        tokenizer = AutoTokenizer.from_pretrained(
            cfg.teacher_dir,
            trust_remote_code=trust_remote_code,
            token=True,
        )

        assert teacher_model_config is not None

        mprint(f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}")
        teacher_model = load_and_shard_model(
            descriptor=descriptor,
            checkpoint_path=cfg.teacher_dir,
            owned_block_indexes=owned_block_indexes,
            model_config=teacher_model_config,
            trust_remote_code=trust_remote_code,
        )

        # The teacher is never trained.
        teacher_model.requires_grad_(False)

        # Create dataloaders
        from modelopt.torch.puzzletron.utils.data.dataloaders import (
            create_train_dataloader,
            create_validation_dataloader,
            load_from_disk_fn,
            load_streaming_fn,
        )

        if cfg.bypass.data.eval_samples_per_process is not None:
            max_eval_samples = cfg.bypass.data.eval_samples_per_process * dist.size()
        else:
            max_eval_samples = cfg.bypass.data.max_eval_samples

        load_dataset_fn = (
            load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn
        )

        # Only master ever fetches from the train dataloader (training_loop.train
        # gates `next(train_iterator)` on `dist.is_master()`), so skip the
        # potentially-large HF dataset load + tokenisation on non-master ranks.
        if dist.is_master():
            train_dataloader = create_train_dataloader(
                seed=seed,
                tokenizer=tokenizer,
                block_size=cfg.bypass.data.block_size,
                dataset_path=cfg.dataset_path,
                content_field=cfg.bypass.data.data_column,
                fim_rate=cfg.bypass.data.fim_rate,
                fim_spm_rate=cfg.bypass.data.fim_spm_rate,
                micro_batch_size=cfg.bypass.training.micro_batch_size,
                load_dataset_fn=load_dataset_fn,
                keep_in_memory=cfg.bypass.data.keep_in_memory,
                source_datasets_to_discard=cfg.bypass.data.get(
                    "source_datasets_to_discard", tuple()
                ),
                bos_rate=cfg.bypass.data.bos_rate,
                shuffle_seed=cfg.bypass.data.shuffle_train_data_seed,
            )
        else:
            train_dataloader = None

        val_dataloader = None
        # Note: val_dataloader is kept constructed on every rank even though only
        # master reads from it inside calculate_losses_pipeline. The validation
        # block uses `val_dataloader is not None` as a "validation enabled" gate
        # that must agree across ranks — and calculate_losses_pipeline itself is
        # pipeline-parallel and requires every rank to enter it. Skipping
        # construction on non-master ranks would break those invariants.
        if not cfg.bypass.disable_validation:
            val_dataloader = create_validation_dataloader(
                accelerator=None,
                seed=seed,
                tokenizer=tokenizer,
                block_size=cfg.bypass.data.block_size,
                dataset=cfg.dataset_path,
                content_field=cfg.bypass.data.data_column,
                fim_rate=cfg.bypass.data.fim_rate,
                fim_spm_rate=cfg.bypass.data.fim_spm_rate,
                micro_batch_size=cfg.bypass.training.val_micro_batch_size,
                eval_samples=max_eval_samples,
                load_dataset_fn=load_dataset_fn,
                dataset_name=cfg.bypass.data.val_dataset_name,
                keep_in_memory=cfg.bypass.data.keep_in_memory,
                source_datasets_to_discard=cfg.bypass.data.get(
                    "source_datasets_to_discard", tuple()
                ),
                bos_rate=cfg.bypass.data.bos_rate,
            )

        # set_experiment_id / set_experiment_dir already ran above (before
        # weight loading) so the resume detection could use experiment_dir.

        dist.barrier()

        with torch.device(device):
            # The factory is looked up by name on the factory module, so the
            # config selects which stitching strategy is used.
            stitched_model_factory_fn = getattr(
                stitched_model_factory_module, cfg.bypass.model_factory.factory
            )
            (
                student_model,
                teacher_stitched_model,
                teacher_val_stitched_module,
                student_val_stitched_model,
                stitched_module_descriptors,
                student_model_config,
            ) = stitched_model_factory_fn(
                teacher_model=teacher_model,
                descriptor=descriptor,
                cfg=cfg.bypass,
                model_blocks_process_ownership=model_blocks_process_ownership,
                student_model=student_model,
            )

        # ``resume_state_path`` was determined earlier (before weight
        # loading); the student weights are already in place via
        # ``load_and_shard_model``. Only the optimizer/scaler state needs to
        # be restored from the per-block ``stitched/`` files.
        if resume_state_path:
            load_local_state(
                stitched_module_descriptors=stitched_module_descriptors,
                checkpoint_path=resume_state_path,
            )

            assert resume_cfg is not None

            # Periodic checkpoints are saved before the loop increments counters,
            # so their args.json is inclusive and needs a +1 bump. Final
            # checkpoints are saved after the loop already advanced beyond the
            # last completed step, so their counters are already the next values.
            resume_from_final = Path(resume_state_path).name.startswith("final-step-")
            counter_bump = 0 if resume_from_final else 1
            cfg.bypass.iter_num = resume_cfg.iter_num + counter_bump
            cfg.bypass.token_count = resume_cfg.token_count
            cfg.bypass.step_num = resume_cfg.step_num + counter_bump
            cfg.bypass.best_val_loss = resume_cfg.best_val_loss
            cfg.bypass.training.clipping_count = resume_cfg.training.clipping_count
            # Per-block bookkeeping. .get() defaults handle resume from older ckpts
            # that predate these fields.
            cfg.bypass.best_losses_by_name = resume_cfg.get("best_losses_by_name", {})
            cfg.bypass.best_steps_by_name = resume_cfg.get("best_steps_by_name", {})
            cfg.bypass.initial_losses_by_name = resume_cfg.get("initial_losses_by_name", {})
            mprint(f"Resume from iter_num: {cfg.bypass.iter_num}")

            # Only copy wandb.run_id if it exists in resume config
            if hasattr(resume_cfg, "wandb") and hasattr(resume_cfg.wandb, "run_id"):
                cfg.bypass.wandb.run_id = resume_cfg.wandb.run_id

            # A resumed run skips the pre-training checkpoint and the
            # initial teacher/student validation passes.
            cfg.bypass.save_checkpoint_before_training = False
            cfg.bypass.validate_teacher_model = False
            cfg.bypass.validate_student_model = False

            cfg.bypass.resume_checkpoint_path = resume_state_path

        # Initialize Weights and Biases
        if cfg.bypass.wandb_log:
            try:
                import wandb

                wandb.init(
                    project=cfg.bypass.wandb.project,
                    entity=cfg.bypass.wandb.entity,
                    config=dict(cfg.bypass),
                )
            except ImportError:
                mprint("wandb not installed, disabling wandb logging")
                cfg.bypass.wandb_log = False
        else:
            mprint("Weights & Biases logging disabled (wandb_log=False)")

        if cfg.bypass.validate_teacher_model and val_dataloader is not None:
            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
                calculate_losses_pipeline,
            )

            mprint("Evaluating teacher model:")
            losses, _ = calculate_losses_pipeline(
                stitched_model=teacher_val_stitched_module,
                dataloader=val_dataloader,
                descriptor=descriptor,
            )
            if losses is not None:
                mprint(f"Teacher validation losses: {losses}")
            mprint("Evaluated teacher model")

        torch.cuda.empty_cache()
        dist.barrier()

        parameter_count = sum(p.numel() for p in student_model.parameters())
        aprint(f"Model parameter count: {parameter_count:,}")
        cfg.bypass.parameter_count = parameter_count

        dist.barrier()
        mprint("Performing dummy runs on stitched modules:")
        torch.cuda.synchronize()
        with (
            torch.no_grad(),
            _autocast_context(descriptor),
            torch.device(device),
        ):
            input_ids = torch.ones(
                (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size),
                dtype=torch.long,
            )
            dummy_fake_input_ids = fake_tensor(input_ids)
            mprint(f"Dummy runs on stitched modules with shape: {dummy_fake_input_ids.shape=}")
            teacher_output = teacher_stitched_model({}, {}, input_ids)
            for stitched_module_descriptor in stitched_module_descriptors.values():
                stitched_module = stitched_module_descriptor.stitched_module
                stitched_module(
                    input_overrides={
                        **teacher_output.captured_inputs,
                        "teacher_inputs": InputArgs(dummy_fake_input_ids),
                    },
                    output_overrides=teacher_output.captured_outputs,
                )
                # Zero any parameter whose name contains "iter_num" after the
                # dummy forward. NOTE(review): presumably these are iteration
                # counters that the forward pass advances — confirm.
                for name, param in stitched_module.named_parameters(recurse=True):
                    if "iter_num" in name:
                        param.data = torch.zeros_like(param.data)
                del name, param
            del input_ids, dummy_fake_input_ids, teacher_output
        torch.cuda.synchronize()
        dist.barrier()

        # Drop this function's reference to the teacher model.
        del teacher_model

        if cfg.bypass.validate_student_model and val_dataloader is not None:
            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
                calculate_losses_pipeline,
            )

            mprint("Validating model before training:")
            losses, _ = calculate_losses_pipeline(
                stitched_model=student_val_stitched_model,
                dataloader=val_dataloader,
                descriptor=descriptor,
            )
            if losses is not None:
                mprint(f"Student validation losses: {losses}")

        dist.barrier()
        torch.cuda.empty_cache()
        dist.barrier()

        train(
            cfg=cfg,
            descriptor=descriptor,
            student_model=student_model,
            student_stitched_model=student_val_stitched_model,
            teacher_stitched_model=teacher_stitched_model,
            stitched_module_descriptors=stitched_module_descriptors,
            stitched_modules_process_ownership=model_blocks_process_ownership,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            student_model_config=student_model_config,
            skip_first_batches=resume_skip_first_batches,
            tokenizer=tokenizer,
        )

        aprint("Finished training successfully!")
        dist.barrier()

    except Exception:
        # Print the traceback explicitly so distributed runs surface it on every
        # rank's stderr (workers under torchrun otherwise lose ordering), then
        # re-raise so test frameworks see the real exception instead of a
        # generic SystemExit(1).
        print(traceback.format_exc(), file=sys.stderr)
        raise

    # Publish the result once: master creates the symlink and completion
    # marker while the other ranks wait at the surrounding barriers.
    dist.barrier()
    if dist.is_master():
        mprint("Realizing bypass checkpoints")
        realized_checkpoint, ckpts_symlink = realize_bypass_checkpoints(cfg)
        mark_bypass_run_completed(cfg, realized_checkpoint, ckpts_symlink)
    dist.barrier()
Path(cfg.bypass.experiment_dir) / "latest" + if fallback.exists(): + checkpoint_dir = fallback.resolve() + else: + raise FileNotFoundError( + f"Could not find a bypass checkpoint to realize in {cfg.bypass.experiment_dir}" + ) + + ckpts_dir = Path(cfg.puzzle_dir) / "ckpts" + ckpts_dir.mkdir(parents=True, exist_ok=True) + + symlink_name = ckpts_dir / cfg.bypass.experiment_id + if symlink_name.exists() or symlink_name.is_symlink(): + symlink_name.unlink() + + symlink_name.symlink_to(checkpoint_dir.resolve(), target_is_directory=True) + mprint(f"Created symlink: {symlink_name} -> {checkpoint_dir}") + return checkpoint_dir, symlink_name diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py index 999ec6c690a..b4edbdd385c 100644 --- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py +++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py @@ -43,6 +43,7 @@ from ..anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory from ..block_config import AttentionConfig, BlockConfig, FFNConfig +from ..bypass_distillation.bypass_utils import learned_subblocks_from_keys_to_learn from ..mip.utils import sort_replacements from ..tools.checkpoint_utils import ( SAFETENSORS_SUBBLOCKS_DIR_NAME, @@ -459,14 +460,7 @@ def _infer_subblocks_to_extract( else: bypass_config = json.loads(bypass_config_path.read_text()) keys_to_learn = bypass_config.get("keys_to_learn", "entire_block") - if keys_to_learn == "entire_block": - subblocks_to_extract = ["block"] - elif "mlp" in keys_to_learn and "attn" not in keys_to_learn: - subblocks_to_extract = ["ffn"] - elif "attn" in keys_to_learn and "mlp" not in keys_to_learn: - subblocks_to_extract = ["attention"] - else: - raise ValueError(f"Unrecognized {keys_to_learn=}") + subblocks_to_extract = learned_subblocks_from_keys_to_learn(keys_to_learn) return subblocks_to_extract diff 
--git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 1240d1c9b65..e5f6fb5df91 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -21,7 +21,10 @@ import concurrent.futures import dataclasses import fcntl +import inspect import os +import re +import shutil import time from collections import defaultdict from collections.abc import Callable, Mapping @@ -88,10 +91,20 @@ def force_cache_dynamic_modules( and "AutoConfig" in config.auto_map.keys() ) if has_remote_code and trust_remote_code: - for class_reference in config.auto_map.values(): + for class_reference in _iter_auto_map_class_refs(config.auto_map): _ = get_class_from_dynamic_module(class_reference, checkpoint_dir) +def _iter_auto_map_class_refs(auto_map: Mapping[str, Any]): + for value in auto_map.values(): + if isinstance(value, str): + yield value + elif isinstance(value, (list, tuple)): + for item in value: + if isinstance(item, str): + yield item + + def load_model_config( checkpoint_dir: Path | str, model_config_overrides: Mapping | None = None, @@ -135,16 +148,23 @@ def load_model_config( return config +_FALLBACK_WARNED_CLASSES: set[str] = set() + + def _get_model_class_from_config(config: PretrainedConfig) -> type: """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf).""" if hasattr(config, "architectures") and config.architectures: model_class_name = config.architectures[0] if hasattr(transformers, model_class_name): return getattr(transformers, model_class_name) - mprint( - f"Warning: {model_class_name} not found in transformers, " - "falling back to AutoModelForCausalLM" - ) + # Warn at most once per missing class per process — the fallback path + # may be hit thousands of times during scoring/realize loops. 
+ if model_class_name not in _FALLBACK_WARNED_CLASSES: + _FALLBACK_WARNED_CLASSES.add(model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, " + "falling back to AutoModelForCausalLM" + ) return AutoModelForCausalLM @@ -209,10 +229,9 @@ def save_checkpoint_from_shards( """ Save a checkpoint when the model's weights are sharded across distributed ranks. - Gathers each rank's partial state dictionary onto rank 0 and writes a complete checkpoint - (including the safetensors index and subblocks) from the merged weights. On a single-process - run, saves directly from the local state dict. Only rank 0 performs the filesystem write; - non-master ranks only participate in the gather. + On distributed runs, rank 0 gathers only tensor-name metadata up front and then gathers + tensors one safetensors file at a time. This avoids materializing the full model from all + ranks on rank 0 while still producing a single HF-compatible checkpoint/index. Parameters: model (PreTrainedModel): The model instance whose local state_dict contains this rank's @@ -222,31 +241,113 @@ def save_checkpoint_from_shards( the safetensors index. 
""" - local_sd = {k: v.cpu() for k, v in model.state_dict().items()} + local_sd = {k: v.detach().cpu() for k, v in model.state_dict().items()} if dist_utils.size() > 1: - save_err: str | None = None + _save_checkpoint_from_distributed_shards(model.config, local_sd, checkpoint_dir, descriptor) + dist_utils.barrier() + else: + _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) + + +def _save_checkpoint_from_distributed_shards( + model_config: PretrainedConfig, + local_state_dict: dict[str, torch.Tensor], + checkpoint_dir: Path | str, + descriptor: "ModelDescriptor", +) -> None: + if not isinstance(checkpoint_dir, Path): + checkpoint_dir = Path(checkpoint_dir) + + local_keys = list(local_state_dict.keys()) + gathered_keys: list[list[str] | None] | None = ( + [None] * dist_utils.size() if dist_utils.is_master() else None + ) + tdist.gather_object(local_keys, gathered_keys, dst=0) + + owner_by_key = None + weight_map = None + setup_err = None + if dist_utils.is_master(): + try: + assert gathered_keys is not None + checkpoint_dir.mkdir(parents=True, exist_ok=True) + save_model_config(model_config, checkpoint_dir) + + # Match the old full_sd.update(rank_order) behavior for duplicate tied + # weights by letting the highest rank that owns a key supply it. 
+ owner_by_key = { + key: rank + for rank, keys in enumerate(gathered_keys) + if keys is not None + for key in keys + } + full_keys = list(owner_by_key) + + output_emb_weight_name = f"{descriptor.output_embedding_name()}.weight" + if getattr(model_config, "tie_word_embeddings", False): + owner_by_key.pop(output_emb_weight_name, None) + full_keys = [key for key in full_keys if key != output_emb_weight_name] + + lm_config = descriptor.get_language_model_config(model_config) + subblock_keys = descriptor.get_weight_groups( + layer_names=full_keys, + num_hidden_layers=lm_config.num_hidden_layers, + ) + weight_map = { + key: f"subblocks_safetensors/{subblock}.safetensors" + for subblock, layer_keys in subblock_keys.items() + for key in layer_keys + } + + index = {"metadata": {"format": "pt"}, "weight_map": weight_map} + _write_file_process_safe(json_dumps(index), checkpoint_dir / SAFE_WEIGHTS_INDEX_NAME) + (checkpoint_dir / SAFETENSORS_SUBBLOCKS_DIR_NAME).mkdir(parents=True, exist_ok=True) + except Exception as e: + setup_err = repr(e) + owner_by_key = {} + weight_map = {} + + payload = [setup_err, owner_by_key, weight_map] + tdist.broadcast_object_list(payload, src=0) + setup_err, owner_by_key, weight_map = payload + if setup_err is not None: + raise RuntimeError(f"Checkpoint setup failed on rank 0: {setup_err}") + assert owner_by_key is not None + assert weight_map is not None + + for relative_filename in sorted(set(weight_map.values())): + local_file_tensors = { + key: tensor.contiguous() + for key, tensor in local_state_dict.items() + if owner_by_key.get(key) == dist_utils.rank() + and weight_map.get(key) == relative_filename + } + gathered_tensors: list[dict[str, torch.Tensor] | None] | None = ( + [None] * dist_utils.size() if dist_utils.is_master() else None + ) + tdist.gather_object(local_file_tensors, gathered_tensors, dst=0) if dist_utils.is_master(): - gathered: list[dict] = [None] * dist_utils.size() - tdist.gather_object(local_sd, gathered, dst=0) - full_sd: 
dict[str, torch.Tensor] = {} - for shard_sd in gathered: - if shard_sd is None: - continue - full_sd.update(shard_sd) - try: - _save_checkpoint(model.config, full_sd, checkpoint_dir, descriptor) - except Exception as e: - save_err = repr(e) + assert gathered_tensors is not None + file_state_dict: dict[str, torch.Tensor] = {} + for shard_tensors in gathered_tensors: + if shard_tensors: + file_state_dict.update(shard_tensors) + file_err = None + if file_state_dict: + try: + safe_save_file( + tensors=file_state_dict, + filename=checkpoint_dir / relative_filename, + metadata={"format": "pt"}, + ) + except Exception as e: + file_err = repr(e) else: - tdist.gather_object(local_sd, dst=0) - err_box = [save_err] + file_err = None + err_box = [file_err] tdist.broadcast_object_list(err_box, src=0) - # Barrier ensures all ranks wait until file I/O completes before continuing - dist_utils.barrier() if err_box[0] is not None: - raise RuntimeError(f"Checkpoint save failed on rank 0: {err_box[0]}") - else: - _save_checkpoint(model.config, local_sd, checkpoint_dir, descriptor) + raise RuntimeError(f"Checkpoint save failed for {relative_filename}: {err_box[0]}") def _save_checkpoint( @@ -265,9 +366,10 @@ def _save_checkpoint( save_model_config(model_config, checkpoint_dir) # Phase 2: Build weight map using descriptor and write index + lm_config = descriptor.get_language_model_config(model_config) subblock_keys = descriptor.get_weight_groups( layer_names=state_dict.keys(), - num_hidden_layers=model_config.num_hidden_layers, + num_hidden_layers=lm_config.num_hidden_layers, ) weight_map = {} @@ -490,6 +592,47 @@ def _build_safetensors_weight_map( return weight_map +def _copy_auto_map_code_files(model_config: PretrainedConfig, checkpoint_dir: Path) -> None: + """Copy custom modeling Python files referenced in ``auto_map`` to the checkpoint dir. + + ``PretrainedConfig.save_pretrained()`` only copies the config class's own source file + (e.g. ``configuration_nemotron_h.py``). 
Trust-remote-code models also need ``modeling_*.py`` + (and any other auto_map-referenced ``.py``) present alongside ``config.json``, otherwise + later ``AutoConfig.from_pretrained(..., trust_remote_code=True)`` calls fail with + "does not appear to have a file named modeling_*.py". + + We discover the source directory from the config class itself (via ``inspect.getfile``) + and copy every distinct ``.py`` referenced by the auto_map values. + """ + if not hasattr(model_config, "auto_map") or not isinstance(model_config.auto_map, dict): + return + + try: + source_dir = Path(inspect.getfile(type(model_config))).parent + except (TypeError, OSError): + # Built-in / non-file-backed config class — nothing to copy. + return + + # Module names must look like Python identifiers — refuse anything with separators + # or relative-path components so a malformed/hostile auto_map can't drive shutil.copy + # outside source_dir / checkpoint_dir. + _module_name_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + module_names = { + class_ref.split("--", 1)[-1].split(".")[0] + for class_ref in _iter_auto_map_class_refs(model_config.auto_map) + } + + for module_name in module_names: + if not _module_name_re.match(module_name): + mprint(f"Warning: skipping non-identifier auto_map module name: {module_name!r}") + continue + filename = f"{module_name}.py" + src = source_dir / filename + dst = Path(checkpoint_dir) / filename + if src.exists() and not dst.exists(): + shutil.copy(src, dst) + + def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: if hasattr(model_config, "block_configs"): model_config.block_configs = [ @@ -497,3 +640,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str for conf in model_config.block_configs ] model_config.save_pretrained(checkpoint_dir) + _copy_auto_map_code_files(model_config, Path(checkpoint_dir)) diff --git a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py 
b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py index 986c5c0107f..9a9ebbaade1 100644 --- a/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py @@ -121,8 +121,12 @@ def load_and_shard_model( checkpoint_path: str | Path, owned_block_indexes: set[int] | Literal["auto"] = "auto", model_config: PretrainedConfig | None = None, + trust_remote_code: bool | None = None, ): checkpoint_path = Path(checkpoint_path) + effective_trust_remote_code = ( + descriptor.requires_trust_remote_code() if trust_remote_code is None else trust_remote_code + ) runtime = SimpleNamespace( device=torch.device(dist.local_rank()), dtype=torch.bfloat16, @@ -135,8 +139,9 @@ def load_and_shard_model( with runtime.device: if model_config is None: - trust_remote_code = descriptor.requires_trust_remote_code() - model_config = load_model_config(checkpoint_path, trust_remote_code=trust_remote_code) + model_config = load_model_config( + checkpoint_path, trust_remote_code=effective_trust_remote_code + ) num_hidden_layers = descriptor.get_language_model_config(model_config).num_hidden_layers if owned_block_indexes == "auto": @@ -159,6 +164,7 @@ def load_and_shard_model( descriptor=descriptor, model_config=model_config, owned_block_indexes=owned_block_indexes, + trust_remote_code=effective_trust_remote_code, ) if (checkpoint_path / SAFE_WEIGHTS_NAME).exists() or ( @@ -231,6 +237,7 @@ def create_sharded_model( owned_block_indexes: set[int], device: str | torch.device | None = "meta", dtype: torch.dtype | None = torch.float32, + trust_remote_code: bool | None = None, ): if isinstance(device, str): device = torch.device(device) @@ -240,10 +247,16 @@ def create_sharded_model( with EmptyInitOnDevice(device="meta", dtype=dtype): # Get model class from config.architectures (works for CausalLM, VL models, etc.) 
model_class = _get_model_class_from_config(model_config) - trust_remote_code = descriptor.requires_trust_remote_code() - if trust_remote_code: + effective_trust_remote_code = ( + descriptor.requires_trust_remote_code() + if trust_remote_code is None + else trust_remote_code + ) + if effective_trust_remote_code: auto_cls = _get_auto_class_for_trust_remote_code(model_config) - model = auto_cls.from_config(model_config, trust_remote_code=trust_remote_code) + model = auto_cls.from_config( + model_config, trust_remote_code=effective_trust_remote_code + ) elif model_class is AutoModelForCausalLM: model = AutoModelForCausalLM.from_config(model_config) else: diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index ea0a6fd2193..2ce44858ef5 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -16,6 +16,7 @@ import os from pathlib import Path +import pytest import torch from _test_utils.torch.transformers_models import get_tiny_tokenizer from datasets import Dataset, DatasetDict @@ -25,6 +26,47 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.export import copy_hf_ckpt_remote_code +__all__ = [ + "PUZZLETRON_FAMILIES", + "create_and_save_small_hf_model", + "save_dummy_dataset", + "setup_test_model_and_data", +] + +# Shared parametrize tuple for puzzletron GPU integration tests. +# Fields: (hf_model_name, converter, hybrid_override_pattern, has_moe_layers). +# To add a new model family, append a single pytest.param row here — every test +# that imports PUZZLETRON_FAMILIES picks it up automatically. 
+PUZZLETRON_FAMILIES = [ + pytest.param("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False, id="llama-3.1-8B"), + pytest.param("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B"), + pytest.param( + "mistralai/Mistral-Small-24B-Instruct-2501", + "mistral_small", + None, + False, + id="mistral-small-24B", + ), + pytest.param( + "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", + "nemotron_h", + "*E", + True, + id="nemotron-3-30B-A3B", + ), + pytest.param( + "nvidia/NVIDIA-Nemotron-Nano-12B-v2", + "nemotron_h_v2", + "*-", + False, + id="nemotron-nano-12B-v2", + ), + pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b"), + pytest.param("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False, id="qwen2.5-7B"), + pytest.param("Qwen/Qwen3-8B", "qwen3", None, False, id="qwen3-8B"), + pytest.param("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True, id="qwen3-VL-30B-A3B"), +] + def setup_test_model_and_data( tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py new file mode 100644 index 00000000000..4f8cd88c7ef --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for ``bypass_checkpoint_utils``. + +The save/resume contract here is the most important regression surface in the +bypass feature: a wrong checkpoint pick or a missing ``saving_completed`` +marker silently restarts training from the wrong iteration. + +What's covered here (CPU-only, codecov-visible): + * ``find_latest_run_dir`` — every branch of the regex/scan/symlink logic. + * ``_save_local_file`` — overwrite/skip semantics. + * ``_save_local_state`` — same save-path assertions as the GPU file + (optimizer / grad_scaler written, state_dict not), but on CPU so codecov picks them + up. The GPU file's ``test_load_local_state_*`` cases stay there because + ``load_local_state`` constructs ``torch.device(f"cuda:{rank}")`` directly. + * ``save_bypass_checkpoint`` — orchestration: ``latest`` symlink update, + ``args.json`` dump, ``saving_completed`` marker, master-only gating. +""" + +import os +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf +from torch.amp.grad_scaler import GradScaler + +from modelopt.torch.puzzletron.bypass_distillation import bypass_checkpoint_utils as bcu +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + StitchedModuleDescriptor, +) + +# --------------------------------------------------------------------------- +# Shared fixture: silence the dist helpers so these run single-process / CPU. +# Mirrors tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py:56-62.
+# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcu_no_dist(monkeypatch): + monkeypatch.setattr(bcu.dist, "local_rank", lambda: 0) + monkeypatch.setattr(bcu.dist, "is_master", lambda: True) + monkeypatch.setattr(bcu.dist, "barrier", lambda: None) + return bcu + + +def _make_descriptor(*, with_optimizer: bool = True, with_scaler: bool = True): + """Build a CPU-only StitchedModuleDescriptor — the GPU file's helper minus + the configurable init_scale (we don't round-trip the scaler here).""" + module = nn.Linear(4, 4, bias=False) + owned_parameters = dict(module.named_parameters()) + optimizer = torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None + scaler = GradScaler(device="cpu", enabled=True, init_scale=2.0**16) if with_scaler else None + return StitchedModuleDescriptor( + stitched_module=module, + owned_parameters=owned_parameters, + owned_buffers={}, + optimizer=optimizer, + grad_scaler=scaler, + ) + + +# --------------------------------------------------------------------------- +# find_latest_run_dir +# --------------------------------------------------------------------------- + + +def test_find_latest_run_dir_returns_none_for_empty_dir(tmp_path: Path): + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_picks_only_step_with_marker(tmp_path: Path): + step_dir = tmp_path / "step-000010-ckpt" + step_dir.mkdir() + (step_dir / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(step_dir) + + +def test_find_latest_run_dir_picks_highest_step_number(tmp_path: Path): + """When several plain step checkpoints have completed markers, the highest + integer wins — not lexicographic order, not insertion order.""" + for i in (5, 10, 20): + d = tmp_path / f"step-{i:06d}-ckpt" + d.mkdir() + (d / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "step-000020-ckpt") + + +def 
test_find_latest_run_dir_skips_step_without_marker(tmp_path: Path): + """A partially-written checkpoint (no ``saving_completed``) must be skipped + even when it has a higher step number — otherwise resume would crash on a + truncated state dict.""" + high = tmp_path / "step-000099-ckpt" + high.mkdir() + # No saving_completed → must be ignored. + low = tmp_path / "step-000050-ckpt" + low.mkdir() + (low / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(low) + + +def test_find_latest_run_dir_returns_none_when_no_step_has_marker(tmp_path: Path): + (tmp_path / "step-000010-ckpt").mkdir() + (tmp_path / "step-000020-ckpt").mkdir() + # No saving_completed anywhere. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_excludes_non_plain_step_names(tmp_path: Path): + """``best-step-*`` / ``start-step-*`` / ``final-step-*`` aren't valid resume + targets — pinned by the docstring on lines 39-42.""" + for name in ("best-step-000099-ckpt", "start-step-000001-ckpt", "final-step-000050-ckpt"): + d = tmp_path / name + d.mkdir() + (d / "saving_completed").touch() + # No plain step-*-ckpt at all. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_uses_latest_symlink_fast_path(tmp_path: Path): + """The ``latest`` symlink, when present and complete, short-circuits the + scan — even when a numerically higher step dir also has a marker. This + matters because the scan branch can be slow on filesystems with many + step dirs (NFS, lustre).""" + target = tmp_path / "step-000010-ckpt" + target.mkdir() + (target / "saving_completed").touch() + (tmp_path / "latest").symlink_to(target.name) + + higher = tmp_path / "step-000020-ckpt" + higher.mkdir() + (higher / "saving_completed").touch() + + # Symlink wins despite higher step existing, but returns the resolved target + # so callers open the same checkpoint that was validated. 
+ assert bcu.find_latest_run_dir(tmp_path) == str(target.resolve()) + + +def test_find_latest_run_dir_falls_through_when_latest_lacks_marker(tmp_path: Path): + """A ``latest`` symlink whose target lacks ``saving_completed`` (interrupted + save) must be ignored, falling through to the highest completed step.""" + incomplete = tmp_path / "step-000020-ckpt" + incomplete.mkdir() + # No saving_completed. + (tmp_path / "latest").symlink_to(incomplete.name) + + completed = tmp_path / "step-000010-ckpt" + completed.mkdir() + (completed / "saving_completed").touch() + + assert bcu.find_latest_run_dir(tmp_path) == str(completed) + + +def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): + """`latest` is a resume pointer, so old symlinks to best checkpoints are ignored.""" + best = tmp_path / "best-step-000020-ckpt" + best.mkdir() + (best / "saving_completed").touch() + (tmp_path / "latest").symlink_to(best.name) + + completed = tmp_path / "step-000010-ckpt" + completed.mkdir() + (completed / "saving_completed").touch() + + assert bcu.find_latest_run_dir(tmp_path) == str(completed) + + +# --------------------------------------------------------------------------- +# _save_local_file +# --------------------------------------------------------------------------- + + +def test_save_local_file_writes_object_to_disk(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"a": torch.tensor([1, 2, 3])}, target) + assert target.exists() + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["a"], torch.tensor([1, 2, 3])) + + +def test_save_local_file_overwrite_true_replaces_contents(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=True) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([99])) + + +def 
test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + # Second save should be a no-op. + bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=False) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([1])) + + +# --------------------------------------------------------------------------- +# _save_local_state: optimizer + grad_scaler only. +# Weights deliberately do NOT land here — the HF checkpoint at the same +# directory carries the full student state dict via ``save_checkpoint``. +# Saving the per-block weights again would just double the disk footprint. +# --------------------------------------------------------------------------- + + +def test_save_local_state_writes_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.optimizer_state.pth").exists() + assert (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_does_not_write_weights_state_dict(tmp_path: Path, bcu_no_dist): + """Pin the de-duplication: weights live in the HF checkpoint, not here.""" + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + assert not (tmp_path / "stitched" / "block_0.state_dict.pth").exists() + + +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.optimizer_state.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +def 
test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict( + [("block_0", _make_descriptor(with_optimizer=False, with_scaler=False))] + ) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert not (stitched / "block_0.optimizer_state.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +# --------------------------------------------------------------------------- +# save_bypass_checkpoint — orchestration: symlink, args.json, marker +# --------------------------------------------------------------------------- + + +def _make_save_cfg(experiment_dir: Path, *, delete_old: bool = True): + """Minimal cfg shape used by ``save_bypass_checkpoint``. + + ``cfg.bypass`` is the object that gets dumped to ``args.json``, so it must + be JSON-serialisable (or DictConfig-with-primitives, which json_dump handles). + """ + return OmegaConf.create( + { + "bypass": { + "experiment_dir": str(experiment_dir), + "model": {"model_overrides": {"delete_old_checkpoints": delete_old}}, + "iter_num": 7, + } + } + ) + + +@pytest.fixture +def patched_save(monkeypatch, bcu_no_dist): + """Stub out the heavy callees so the test only exercises the orchestration + logic in ``save_bypass_checkpoint``.""" + monkeypatch.setattr(bcu_no_dist, "_save_local_state", lambda **kwargs: None) + monkeypatch.setattr(bcu_no_dist, "save_checkpoint_from_shards", lambda **kwargs: None) + return bcu_no_dist + + +def test_save_bypass_checkpoint_creates_latest_symlink_and_marker(tmp_path: Path, patched_save): + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "step-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + latest = experiment_dir / "latest" + 
assert latest.is_symlink() + # Symlink target is relative — just the dir name, so it resolves under experiment_dir. + assert os.readlink(latest) == "step-000007-ckpt" + assert latest.resolve() == checkpoint_dir.resolve() + assert (checkpoint_dir / "args.json").exists() + assert (checkpoint_dir / "saving_completed").exists() + + +def test_save_bypass_checkpoint_replaces_existing_latest_symlink(tmp_path: Path, patched_save): + """A stale ``latest`` from a prior save must be replaced, not appended to. + Without ``unlink(missing_ok=True)`` the symlink_to() call would raise + FileExistsError mid-save and leave the run unable to checkpoint.""" + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + old_target = experiment_dir / "step-000003-ckpt" + old_target.mkdir() + new_target = experiment_dir / "step-000007-ckpt" + new_target.mkdir() + (experiment_dir / "latest").symlink_to(old_target.name) + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=new_target, + ) + + assert os.readlink(experiment_dir / "latest") == "step-000007-ckpt" + + +def test_save_bypass_checkpoint_best_does_not_replace_latest(tmp_path: Path, patched_save): + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + resume_target = experiment_dir / "step-000003-ckpt" + resume_target.mkdir() + best_target = experiment_dir / "best-step-000007-ckpt" + best_target.mkdir() + (experiment_dir / "latest").symlink_to(resume_target.name) + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=best_target, + checkpoint_role="best", + ) + + assert os.readlink(experiment_dir / "latest") == "step-000003-ckpt" + assert (best_target / "saving_completed").exists() + assert (best_target / "bypass_config.json").exists() + + +def 
test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master( + tmp_path: Path, monkeypatch, patched_save +): + """Non-master ranks must not write the symlink, args.json, or marker — + only rank 0 owns those files. The other ranks still call _save_local_state + (their owned blocks) but stop short of the per-experiment metadata.""" + monkeypatch.setattr(patched_save.dist, "is_master", lambda: False) + + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "step-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + assert not (experiment_dir / "latest").exists() + assert not (checkpoint_dir / "args.json").exists() + assert not (checkpoint_dir / "saving_completed").exists() diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py new file mode 100644 index 00000000000..6e7b663b6e0 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -0,0 +1,256 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_set_keys_to_learn`` in stitched_model_factory.py. 
+ +This function is the single source of truth for which subblock parameters get +trained during a bypass run. Its branches (subblock_ffn / subblock_attention / +subblock_mamba / entire_block / list) and its hybrid-model ``block_configs`` +filter are all silent on misuse — a regression here would freeze the wrong +layers and produce a worse-than-teacher checkpoint with no loud failure. +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import _set_keys_to_learn + +# --------------------------------------------------------------------------- +# Fixtures: a minimal Llama-shaped model and a Llama-shaped descriptor stub +# --------------------------------------------------------------------------- + + +def _make_dense_model(num_layers: int = 2) -> nn.Module: + """Build a tiny model whose named_parameters mimic Llama's naming. + + Parameters live under ``model.layers.{i}.self_attn.{q,k,v,o}_proj.weight`` + and ``model.layers.{i}.mlp.{up,down}_proj.weight``. The function never reads + parameter shapes, so size doesn't matter — what matters is that the names + match what `_set_keys_to_learn` expects to see in `named_parameters()` and + `state_dict().keys()`. + """ + model = nn.Module() + model_inner = nn.Module() + layers = nn.ModuleList() + for _ in range(num_layers): + layer = nn.Module() + # attention + layer.self_attn = nn.Module() + for proj in ("q_proj", "k_proj", "v_proj", "o_proj"): + setattr(layer.self_attn, proj, nn.Linear(4, 4, bias=False)) + # feed-forward + layer.mlp = nn.Module() + for proj in ("up_proj", "down_proj"): + setattr(layer.mlp, proj, nn.Linear(4, 4, bias=False)) + layers.append(layer) + model_inner.layers = layers + model.model = model_inner + # `_set_keys_to_learn` reads `model.config` only to pass through to + # `descriptor.get_language_model_config` — a SimpleNamespace is enough. 
+ model.config = SimpleNamespace() + # Start with everything frozen so any True flag is something the function set. + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def _make_descriptor(num_layers: int, *, block_configs=None): + """Build a descriptor stub exposing only what ``_set_keys_to_learn`` calls. + + - ``get_language_model_config(config)`` returns an object with + ``num_hidden_layers`` and (optionally) ``block_configs``. + - ``get_weight_groups(state_dict_keys, num_hidden_layers)`` returns + ``{"block_{i}_attention": [...], "block_{i}_ffn": [...]}``. + """ + + def get_language_model_config(_config): + ns = SimpleNamespace(num_hidden_layers=num_layers) + if block_configs is not None: + ns.block_configs = block_configs + return ns + + def get_weight_groups(state_dict_keys, n): + groups: dict[str, list[str]] = {} + for i in range(n): + attn_prefix = f"model.layers.{i}.self_attn." + ffn_prefix = f"model.layers.{i}.mlp." + groups[f"block_{i}_attention"] = [ + k for k in state_dict_keys if k.startswith(attn_prefix) + ] + groups[f"block_{i}_ffn"] = [k for k in state_dict_keys if k.startswith(ffn_prefix)] + return groups + + return SimpleNamespace( + get_language_model_config=get_language_model_config, + get_weight_groups=get_weight_groups, + ) + + +def _trainable_names(model: nn.Module) -> set[str]: + return {n for n, p in model.named_parameters() if p.requires_grad} + + +# --------------------------------------------------------------------------- +# Single-string subblock keys (dense model) +# --------------------------------------------------------------------------- + + +def test_subblock_ffn_trains_only_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_ffn") + trainable = _trainable_names(model) + assert all(".mlp." in n for n in trainable), trainable + assert not any(".self_attn." 
in n for n in trainable), trainable + # Both layers' mlp params must be trainable, not just one. + assert any("model.layers.0.mlp." in n for n in trainable) + assert any("model.layers.1.mlp." in n for n in trainable) + + +def test_subblock_attention_trains_only_self_attn(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + assert all(".self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +def test_entire_block_trains_attention_and_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "entire_block") + trainable = _trainable_names(model) + # Both groups present. + assert any(".self_attn." in n for n in trainable), trainable + assert any(".mlp." in n for n in trainable), trainable + # Equal to the union of every model parameter. + assert trainable == {n for n, _ in model.named_parameters()} + + +def test_subblock_key_list_trains_union_of_subblocks(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, ["subblock_attention", "subblock_ffn"]) + trainable = _trainable_names(model) + assert any(".self_attn." in n for n in trainable), trainable + assert any(".mlp." 
in n for n in trainable), trainable + assert trainable == {n for n, _ in model.named_parameters()} + + +def test_mixed_subblock_and_exact_name_list_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="supports only subblock keys"): + _set_keys_to_learn( + model, + descriptor, + ["subblock_attention", "model.layers.0.self_attn.q_proj.weight"], + ) + + +# --------------------------------------------------------------------------- +# Hybrid model: subblock_mamba vs subblock_attention should partition by +# block_configs[i].attention.mamba — this is the path most likely to +# silently misroute training under future descriptor changes. +# --------------------------------------------------------------------------- + + +def _hybrid_block_configs(): + """Block 0: Mamba. Block 1: GQA. Detected via ``attention.mamba is not None``.""" + return [ + SimpleNamespace(attention=SimpleNamespace(mamba=SimpleNamespace())), # Mamba + SimpleNamespace(attention=SimpleNamespace(mamba=None)), # GQA + ] + + +def test_subblock_mamba_on_hybrid_trains_only_mamba_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_mamba") + trainable = _trainable_names(model) + # Block 0 (Mamba) attention-group params should be trainable; block 1 (GQA) must not. + assert any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any("model.layers.1.self_attn." in n for n in trainable), trainable + # FFN params are never trainable under subblock_mamba. + assert not any(".mlp." 
in n for n in trainable), trainable + + +def test_subblock_attention_on_hybrid_trains_only_gqa_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + # Block 1 (GQA) attention-group params are trainable; block 0 (Mamba) must not. + assert any("model.layers.1.self_attn." in n for n in trainable), trainable + assert not any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +# --------------------------------------------------------------------------- +# Unsupported free-form key forms +# --------------------------------------------------------------------------- + + +def test_explicit_param_name_list_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + target = "model.layers.0.self_attn.q_proj.weight" + with pytest.raises(ValueError, match="subblock keys"): + _set_keys_to_learn(model, descriptor, [target]) + + +def test_regex_string_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="keys_to_learn must be one of"): + _set_keys_to_learn(model, descriptor, r"q_proj") + + +def test_empty_key_list_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="cannot be empty"): + _set_keys_to_learn(model, descriptor, []) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "keys_to_learn", + ["subblock_ffn", "subblock_attention", "entire_block"], +) +def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): + """Integer / 
non-floating buffers exposed as parameters must stay frozen. + + The function explicitly guards on ``torch.is_floating_point(param)``; this + test pins that guard so a future refactor doesn't accidentally try to + enable grad on int tensors (which would raise at runtime). + """ + model = _make_dense_model(num_layers=2) + # Inject an int "param" alongside a real one. + int_param = nn.Parameter(torch.zeros(2, dtype=torch.long), requires_grad=False) + model.model.layers[0].self_attn.register_parameter("int_counter", int_param) + descriptor = _make_descriptor(num_layers=2) + # Should not raise even though the int param's name matches the attention group. + _set_keys_to_learn(model, descriptor, keys_to_learn) + # The int counter must remain frozen regardless. + assert not model.model.layers[0].self_attn.int_counter.requires_grad diff --git a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py new file mode 100644 index 00000000000..38701ba8be3 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the cosine-with-warmup LR scheduler used by bypass distillation. + +``_get_lr`` is the scheduler invoked every step inside ``train``. 
An off-by-one +in the cosine ramp would silently degrade convergence — bypass jobs run for +hours and produce subtly worse student weights. The degenerate-budget guard +matters for tests and short sweeps where ``training_tokens`` is small. + +Schedule shape (warmup_steps=W, lr_decay_steps=D): + + step ∈ [0, W]: linear ramp 0 → base_lr (warmup branch) + step ∈ (W, D]: cosine decay base_lr → min_lr (cosine branch) + step > D: clamped to min_lr (post-decay branch) + +The cosine uses ``decay_ratio = (step - W) / (D - W)`` so the boundary cases +align: at step=W+1 the cosine has just started (decay_ratio = 1/(D-W)) and at +step=D it reaches min_lr exactly (decay_ratio=1, coeff=0). +""" + +import math + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.bypass_distillation.training_loop import _get_lr + + +def _make_cfg( + *, + warmup_steps: int, + lr_decay_steps: int, + learning_rate: float = 1.0, + min_lr: float = 0.1, +): + return OmegaConf.create( + { + "bypass": { + "training": { + "warmup_steps": warmup_steps, + "lr_decay_steps": lr_decay_steps, + "learning_rate": learning_rate, + "min_lr": min_lr, + } + } + } + ) + + +def test_degenerate_budget_returns_base_lr(): + """When ``lr_decay_steps <= warmup_steps`` (tiny test budgets), the scheduler + must short-circuit to ``learning_rate`` rather than divide by zero.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=10, learning_rate=0.5) + assert _get_lr(cfg, step=0) == 0.5 + assert _get_lr(cfg, step=1) == 0.5 + assert _get_lr(cfg, step=99) == 0.5 + + +def test_degenerate_budget_warmup_greater_than_decay(): + """``lr_decay_steps < warmup_steps`` is also caught by the same guard.""" + cfg = _make_cfg(warmup_steps=20, lr_decay_steps=10, learning_rate=0.7) + assert _get_lr(cfg, step=5) == 0.7 + + +def test_warmup_linear_ramp(): + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=100, learning_rate=1.0) + assert _get_lr(cfg, step=0) == pytest.approx(0.0) + assert _get_lr(cfg, step=5) == 
pytest.approx(0.5) + assert _get_lr(cfg, step=10) == pytest.approx(1.0) + + +def test_cosine_starts_decaying_immediately_after_warmup(): + """At ``step == warmup_steps + 1`` the cosine branch is entered with + ``decay_ratio = 1/(D-W)`` — already a small step below base LR, not a + duplicate plateau at base LR. This is the boundary the previous formula + got wrong (it used ``step - W - 1`` and gave ``decay_ratio == 0`` here).""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # decay_ratio = (11 - 10) / 10 = 0.1 + expected = 0.5 * (1.0 + math.cos(math.pi * 0.1)) + assert _get_lr(cfg, step=11) == pytest.approx(expected) + # Strictly below base LR — the cosine has begun. + assert _get_lr(cfg, step=11) < 1.0 + + +def test_cosine_endpoint_returns_min_lr(): + """At ``step == lr_decay_steps`` the cosine branch reaches its endpoint: + ``decay_ratio == 1`` → ``coeff == 0`` → returns ``min_lr`` exactly. The + post-decay clamp at ``step == lr_decay_steps + 1`` is then a no-op + continuation, not a correction for an off-by-one.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=20) == pytest.approx(0.1) + + +def test_cosine_midpoint_is_halfway(): + """At the cosine midpoint, ``coeff == 0.5`` → returns ``(lr + min_lr) / 2``.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # Midpoint of the post-warmup window: step such that decay_ratio == 0.5. + # decay_ratio = (step - 10) / (20 - 10) → step = 15 gives ratio 0.5. 
+ expected_coeff = 0.5 * (1.0 + math.cos(math.pi * 0.5)) + assert _get_lr(cfg, step=15) == pytest.approx(expected_coeff) + + +def test_post_decay_clamps_to_min_lr(): + """``step > lr_decay_steps`` always returns ``min_lr`` exactly.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=21) == 0.1 + assert _get_lr(cfg, step=1000) == 0.1 + + +def test_min_lr_zero_decays_to_zero(): + """Common config: ``min_lr=0`` → cosine endpoint is exactly 0.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=30, learning_rate=2.0, min_lr=0.0) + assert _get_lr(cfg, step=30) == pytest.approx(0.0) + assert _get_lr(cfg, step=31) == 0.0 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py new file mode 100644 index 00000000000..0b43a97c01c --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for get_distributed_modules_ownership in bypass_utils.py.""" + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import ( + get_bypass_config_fingerprint, + get_distributed_modules_ownership, + get_pipeline_ownership_context, + set_experiment_id, +) + + +def test_single_gpu_all_to_rank_0(): + """With world_size=1, all 4 modules should be assigned to rank 0.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=1) + assert ownership == [0, 0, 0, 0] + + +def test_even_distribution(): + """With world_size=2 and 4 modules, each rank should own exactly 2 modules.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 2 + assert len(ownership) == 4 + + +def test_uneven_distribution(): + """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1.""" + ownership = get_distributed_modules_ownership(module_count=3, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 1 + assert len(ownership) == 3 + + +@pytest.mark.parametrize( + ("module_count", "world_size"), + [ + (1, 1), + (4, 1), + (4, 2), + (4, 4), + (7, 3), + (10, 4), + (1, 2), + ], +) +def test_total_equals_module_count(module_count, world_size): + """The length of the ownership list must always equal module_count.""" + ownership = get_distributed_modules_ownership(module_count=module_count, world_size=world_size) + assert len(ownership) == module_count + + +def test_consecutive_ownership(): + """Each rank should own a contiguous block of indices (no interleaving).""" + ownership = get_distributed_modules_ownership(module_count=7, world_size=3) + # Verify that once we see a new rank, we never see the previous rank again. 
+ seen_ranks = set() + prev_rank = ownership[0] + seen_ranks.add(prev_rank) + for rank in ownership[1:]: + if rank != prev_rank: + assert rank not in seen_ranks, ( + f"Rank {rank} appears non-consecutively in ownership list: {ownership}" + ) + seen_ranks.add(rank) + prev_rank = rank + + +def test_single_module(): + """With world_size=2 and only 1 module, rank 0 should be the sole owner.""" + ownership = get_distributed_modules_ownership(module_count=1, world_size=2) + assert ownership == [0] + assert len(ownership) == 1 + + +def test_pipeline_ownership_context_returns_neighbors(): + ownership = [0, 0, 1, 1, 2] + + assert get_pipeline_ownership_context(ownership, rank=0) == { + "owned_indices": [0, 1], + "owned_index_set": {0, 1}, + "prev_rank": None, + "next_rank": 1, + } + assert get_pipeline_ownership_context(ownership, rank=1) == { + "owned_indices": [2, 3], + "owned_index_set": {2, 3}, + "prev_rank": 0, + "next_rank": 2, + } + assert get_pipeline_ownership_context(ownership, rank=2) == { + "owned_indices": [4], + "owned_index_set": {4}, + "prev_rank": 1, + "next_rank": None, + } + + +def test_pipeline_ownership_context_rejects_idle_rank(): + with pytest.raises(RuntimeError, match="owns no modules"): + get_pipeline_ownership_context([0, 0, 1], rank=2) + + +def _experiment_cfg(keys_to_learn: str): + return OmegaConf.create( + { + "descriptor": "test_descriptor", + "dataset_path": "/tmp/dataset_a", + "bypass": { + "experiment_id": None, + "dtype": "bf16", + "seed": 42, + "data": { + "block_size": 64, + "data_column": "text", + "fim_rate": 0, + "fim_spm_rate": 0, + "bos_rate": 1.0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "shuffle_train_data_seed": 123, + "val_dataset_name": "valid", + "max_eval_samples": 4, + "eval_samples_per_process": None, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": 1024, + "micro_batch_size": 1, + "grad_accumulation_steps": 1, + "weight_decay": 0.1, + "decay_lr": True, + 
"beta1": 0.9, + "beta2": 0.95, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "warmup_ratio": 0.05, + "min_lr_factor": 1e-5, + }, + "model": { + "student_weights_dtype": "bf16", + "model_config_overrides": { + "attention": [{"num_key_value_heads": 1, "no_op": None}] + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": keys_to_learn, + }, + "disable_validation": False, + "save_best_ckpt": True, + "realize_best_or_latest": "best", + }, + } + ) + + +def test_experiment_id_includes_learning_target_and_fingerprint(): + attention_cfg = _experiment_cfg("subblock_attention") + ffn_cfg = _experiment_cfg("subblock_ffn") + + set_experiment_id(attention_cfg) + set_experiment_id(ffn_cfg) + + assert attention_cfg.bypass.experiment_id.startswith("bypass_heads_1_attention_") + assert ffn_cfg.bypass.experiment_id.startswith("bypass_heads_1_ffn_") + assert attention_cfg.bypass.experiment_id != ffn_cfg.bypass.experiment_id + + +def test_experiment_id_falls_back_when_no_architecture_parts_exist(): + cfg = _experiment_cfg("entire_block") + cfg.bypass.model.model_config_overrides = {} + + set_experiment_id(cfg) + + assert cfg.bypass.experiment_id.startswith("bypass_custom_") + assert cfg.bypass.experiment_id != "bypass_None" + + +def test_config_fingerprint_changes_with_dataset_path(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.dataset_path = "/tmp/dataset_b" + assert get_bypass_config_fingerprint(cfg) != original + + +def test_config_fingerprint_changes_with_shuffle_seed(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.bypass.data.shuffle_train_data_seed = 456 + assert get_bypass_config_fingerprint(cfg) != 
original + + +def test_experiment_id_does_not_change_with_dataset_path(): + cfg_a = _experiment_cfg("subblock_attention") + cfg_b = _experiment_cfg("subblock_attention") + cfg_b.dataset_path = "/tmp/dataset_b" + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id diff --git a/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py new file mode 100644 index 00000000000..e702256d606 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from types import SimpleNamespace + +import torch + +from modelopt.torch.puzzletron.tools import checkpoint_utils_hf as cuhf + + +def test_save_checkpoint_uses_descriptor_language_model_config(tmp_path, monkeypatch): + calls = {} + + class Descriptor: + @staticmethod + def get_language_model_config(config): + return config.text_config + + @staticmethod + def get_weight_groups(layer_names, num_hidden_layers): + calls["num_hidden_layers"] = num_hidden_layers + return {"weights": list(layer_names)} + + @staticmethod + def output_embedding_name(): + return "lm_head" + + monkeypatch.setattr(cuhf, "save_model_config", lambda *args, **kwargs: None) + monkeypatch.setattr(cuhf, "save_subblocks", lambda *args, **kwargs: None) + + cfg = SimpleNamespace( + text_config=SimpleNamespace(num_hidden_layers=7), + tie_word_embeddings=False, + ) + cuhf._save_checkpoint(cfg, {"some.weight": torch.zeros(1)}, tmp_path, Descriptor) + + assert calls["num_hidden_layers"] == 7 + + +def test_copy_auto_map_code_files_ignores_non_string_entries(tmp_path, monkeypatch): + source_dir = tmp_path / "source" + checkpoint_dir = tmp_path / "checkpoint" + source_dir.mkdir() + checkpoint_dir.mkdir() + (source_dir / "modeling_custom.py").write_text("# modeling\n") + (source_dir / "tokenization_custom.py").write_text("# tokenizer\n") + + monkeypatch.setattr(cuhf.inspect, "getfile", lambda _cls: source_dir / "configuration.py") + + cfg = SimpleNamespace( + auto_map={ + "AutoConfig": "configuration_custom.CustomConfig", + "AutoModelForCausalLM": "modeling_custom.CustomModel", + "AutoTokenizer": [None, "tokenization_custom.CustomTokenizer"], + } + ) + + cuhf._copy_auto_map_code_files(cfg, checkpoint_dir) + + assert (checkpoint_dir / "modeling_custom.py").exists() + assert (checkpoint_dir / "tokenization_custom.py").exists() + + +def test_copy_auto_map_code_files_strips_repo_id_prefix(tmp_path, monkeypatch): + source_dir = tmp_path / "source" + checkpoint_dir = tmp_path / "checkpoint" + source_dir.mkdir() + 
checkpoint_dir.mkdir() + (source_dir / "modeling_custom.py").write_text("# modeling\n") + + monkeypatch.setattr(cuhf.inspect, "getfile", lambda _cls: source_dir / "configuration.py") + + cfg = SimpleNamespace( + auto_map={"AutoModelForCausalLM": "org/repo--modeling_custom.CustomModel"} + ) + + cuhf._copy_auto_map_code_files(cfg, checkpoint_dir) + + assert (checkpoint_dir / "modeling_custom.py").exists() diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py new file mode 100644 index 00000000000..5975612809a --- /dev/null +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``launch_bypass_distillation`` (sweep dispatcher). + +The dispatcher's job is to iterate over ``bypass.configs``, apply each override +to the live ``hydra_cfg``, reset the per-run state machine, and invoke +``run_bypassed_training``. Reordering or dropping a reset would silently make +the second sweep entry resume from the first entry's iter counter — a bug +that would only surface as wasted compute and confused checkpoint dirs. + +We patch ``run_bypassed_training`` to a recorder so this stays a pure-Python +test (no GPU, no real training). 
+""" + +import json +from pathlib import Path + +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.bypass_distillation.training_loop as tl + + +def _base_cfg(tmp_path, configs=None): + """Build a minimal cfg shape that ``launch_bypass_distillation`` reads. + + Includes only the keys touched by the dispatcher itself; ``run_bypassed_training`` + is mocked so its richer requirements are irrelevant here. + """ + cfg = { + "puzzle_dir": str(tmp_path / "puzzletron_bypass_unit"), + "descriptor": "test_descriptor", + "bypass": { + "model": {"model_config_overrides": {"intermediate_size": 1024}}, + "model_factory": {"keys_to_learn": "subblock_ffn"}, + "experiment_id": "stale-id", + "iter_num": 999, + "step_num": 999, + "token_count": 999_999, + "best_val_loss": 0.0, + "training": {"clipping_count": 42}, + }, + } + if configs is not None: + cfg["bypass"]["configs"] = configs + return OmegaConf.create(cfg) + + +def _record_calls(monkeypatch): + """Patch ``run_bypassed_training`` to capture deep-copied cfg snapshots.""" + snapshots = [] + + def _recorder(cfg): + # Deep-copy via container conversion; the live cfg is mutated between calls. + snapshots.append(OmegaConf.to_container(cfg, resolve=True)) + + monkeypatch.setattr(tl, "run_bypassed_training", _recorder) + return snapshots + + +def test_no_configs_key_runs_once(monkeypatch, tmp_path): + """Absent ``bypass.configs`` is the single-config path — one call, no resets.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=None) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + # Single-config path doesn't touch the state machine — values remain as supplied. 
+ assert snapshots[0]["bypass"]["iter_num"] == 999 + assert snapshots[0]["bypass"]["training"]["clipping_count"] == 42 + + +def test_empty_configs_list_runs_once(monkeypatch, tmp_path): + """``configs: []`` must hit the same branch as missing — the truthiness + check on ``bypass.configs`` treats both as 'no sweep'.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[]) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + + +def test_two_configs_run_twice_with_distinct_overrides(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ], + ) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 2 + assert snapshots[0]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 256} + assert snapshots[1]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 128} + + +def test_keys_to_learn_override_applied(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[{"keys_to_learn": "subblock_attention"}]) + tl.launch_bypass_distillation(cfg) + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" + + +def test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): + """Every sweep entry must see iter_num=1, step_num=1, token_count=0, + best_val_loss=1e9, clipping_count=0, and a fresh experiment_id even when the + previous entry left the cfg in some other state.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ], + ) + tl.launch_bypass_distillation(cfg) + for snap in snapshots: + assert snap["bypass"]["experiment_id"].startswith("bypass_ffn_") + assert snap["bypass"]["iter_num"] == 1 + 
assert snap["bypass"]["step_num"] == 1 + assert snap["bypass"]["token_count"] == 0 + assert snap["bypass"]["best_val_loss"] == 1e9 + assert snap["bypass"]["training"]["clipping_count"] == 0 + + +def test_override_without_keys_to_learn_leaves_cfg_value_untouched(monkeypatch, tmp_path): + """A sweep entry that only sets ``model_config_overrides`` must not clobber + the inherited ``keys_to_learn`` (the dispatcher's `if "keys_to_learn" in override` + guard).""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=[{"model_config_overrides": {"intermediate_size": 256}}]) + tl.launch_bypass_distillation(cfg) + # keys_to_learn was set to "subblock_ffn" in _base_cfg — must survive. + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" + + +def test_sweep_entry_without_keys_to_learn_uses_base_not_previous_override(monkeypatch, tmp_path): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + tmp_path, + configs=[ + {"keys_to_learn": "subblock_attention"}, + {"model_config_overrides": {"intermediate_size": 256}}, + ], + ) + tl.launch_bypass_distillation(cfg) + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" + assert snapshots[1]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" + + +def test_trust_remote_code_defaults_to_false_even_when_descriptor_requires_it(monkeypatch): + class DescriptorRequiringTrust: + @staticmethod + def requires_trust_remote_code(): + return True + + messages = [] + + def capture_message(*args): + messages.append(" ".join(map(str, args))) + + monkeypatch.setattr(tl, "mprint", capture_message) + + assert tl._resolve_trust_remote_code(OmegaConf.create({}), DescriptorRequiringTrust) is False + assert any("trust_remote_code" in message for message in messages) + + +def test_trust_remote_code_uses_explicit_cfg_opt_in(monkeypatch): + class DescriptorRequiringTrust: + @staticmethod + def requires_trust_remote_code(): + return True + + messages = 
[] + + def capture_message(*args): + messages.append(" ".join(map(str, args))) + + monkeypatch.setattr(tl, "mprint", capture_message) + + cfg = OmegaConf.create({"trust_remote_code": True}) + assert tl._resolve_trust_remote_code(cfg, DescriptorRequiringTrust) is True + assert messages == [] + + +def test_resume_state_ignored_when_init_checkpoint_path_wins(monkeypatch): + messages = [] + + def capture_message(*args): + messages.append(" ".join(map(str, args))) + + monkeypatch.setattr(tl, "mprint", capture_message) + cfg = OmegaConf.create({"bypass": {"init_checkpoint_path": "/tmp/init-ckpt"}}) + + assert tl._get_resume_state_path(cfg, "/tmp/resume-ckpt") is None + assert any("init_checkpoint_path" in message for message in messages) + + +def test_resume_state_used_when_no_init_checkpoint_path(): + cfg = OmegaConf.create({"bypass": {"init_checkpoint_path": None}}) + + assert tl._get_resume_state_path(cfg, "/tmp/resume-ckpt") == "/tmp/resume-ckpt" + + +def test_flush_loss_buffer_single_rank_without_process_group(): + local_buffer = {1: {"block_0": 0.25}} + stitched_losses_history = {} + + tl._flush_loss_buffer(local_buffer, stitched_losses_history) + + assert stitched_losses_history == local_buffer + + +def test_realize_bypass_checkpoints_uses_resolved_symlink_target(monkeypatch, tmp_path: Path): + monkeypatch.chdir(tmp_path) + experiment_dir = Path("puzzle/bypass/bypass_runs/run_0") + checkpoint_dir = experiment_dir / "final-step-000002-ckpt" + checkpoint_dir.mkdir(parents=True) + (experiment_dir / "bypass_state.json").write_text( + json.dumps({"checkpoints": {"final": str(checkpoint_dir)}}) + ) + cfg = OmegaConf.create( + { + "puzzle_dir": "puzzle", + "bypass": { + "experiment_dir": str(experiment_dir), + "experiment_id": "run_0", + "realize_best_or_latest": "latest", + }, + } + ) + + realized_checkpoint, ckpts_symlink = tl.realize_bypass_checkpoints(cfg) + + assert realized_checkpoint == checkpoint_dir.resolve() + assert ckpts_symlink.readlink() == 
checkpoint_dir.resolve() + assert ckpts_symlink.exists() diff --git a/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py new file mode 100644 index 00000000000..07f46c0327b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for bypass checkpoint metadata consumed by replacement-library extraction.""" + +import json +from pathlib import Path + +import pytest + +from modelopt.torch.puzzletron.replacement_library.build_replacement_library import ( + _infer_subblocks_to_extract, +) + + +@pytest.mark.parametrize( + ("keys_to_learn", "expected_subblocks"), + [ + ("entire_block", ["block"]), + ("subblock_ffn", ["ffn"]), + ("subblock_attention", ["attention"]), + ("subblock_mamba", ["attention"]), + (["subblock_attention", "subblock_ffn"], ["attention", "ffn"]), + ], +) +def test_infer_subblocks_to_extract_accepts_bypass_keys( + tmp_path: Path, + keys_to_learn, + expected_subblocks, +): + checkpoint_dir = tmp_path / "checkpoint" + checkpoint_dir.mkdir() + (checkpoint_dir / "bypass_config.json").write_text(json.dumps({"keys_to_learn": keys_to_learn})) + + assert _infer_subblocks_to_extract(checkpoint_dir, []) == expected_subblocks + + +@pytest.mark.parametrize("keys_to_learn", ["mlp", "attn", ["mlp", "attn"]]) +def test_infer_subblocks_to_extract_rejects_legacy_keys(tmp_path: Path, keys_to_learn): + checkpoint_dir = tmp_path / "checkpoint" + checkpoint_dir.mkdir() + (checkpoint_dir / "bypass_config.json").write_text(json.dumps({"keys_to_learn": keys_to_learn})) + + with pytest.raises(ValueError, match="keys_to_learn"): + _infer_subblocks_to_extract(checkpoint_dir, []) diff --git a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py new file mode 100644 index 00000000000..5fab764b565 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_get_all_non_persistent_buffers_set``. + +This helper is what ``bypass_factory_fn`` uses to decide which buffers belong +to ``owned_buffers`` (and therefore get checkpointed) versus which are +recomputed on every forward (RoPE caches, attention masks, etc.). A regression +that drops the module-name prefix would cause the post-resume model to silently +load buffers under wrong names. +""" + +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + _get_all_non_persistent_buffers_set, +) + + +def test_module_with_no_buffers_returns_empty_set(): + assert _get_all_non_persistent_buffers_set(nn.Module()) == set() + + +def test_persistent_buffer_excluded_non_persistent_included(): + m = nn.Module() + m.register_buffer("p", torch.zeros(1), persistent=True) + m.register_buffer("np", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"np"} + + +def test_nested_submodule_paths_are_fully_qualified(): + """Sub-module non-persistent buffers must surface as ``submodule_name.buffer_name`` + so the matching key in ``state_dict()`` and the bypass save/restore code agree.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("nb", torch.zeros(1), persistent=False) + outer.add_module("inner", inner) + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"inner.nb"} + + +def test_top_level_buffer_has_no_leading_dot(): + """Module name is "" at the root — fully-qualified name must not start + with a dot, otherwise 
it won't match any state_dict key.""" + m = nn.Module() + m.register_buffer("x", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"x"} + assert not any(name.startswith(".") for name in out) + + +def test_mix_of_persistent_and_non_persistent_in_nested_module(): + """The full discrimination: only the nested non-persistent buffer should + appear, with its fully-qualified path.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("keep", torch.zeros(1), persistent=True) # persistent → excluded + inner.register_buffer("rope_cache", torch.zeros(1), persistent=False) + outer.add_module("attn", inner) + outer.register_buffer("global_keep", torch.zeros(1), persistent=True) # → excluded + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"attn.rope_cache"} From 24a2eca53d02863c92375ec08f63227c50405055 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Mon, 1 Jun 2026 14:39:47 +0200 Subject: [PATCH 2/8] Fix bypass distillation reuse and resume guards Signed-off-by: Sepehr Sameni --- .../bypass_distillation/bypass_utils.py | 45 ++++-- .../bypass_distillation/training_loop.py | 105 +++++++++----- .../torch/puzzletron/test_bypass_utils.py | 48 ++++++- .../test_launch_bypass_distillation.py | 130 ++++++++++++++++++ 4 files changed, 280 insertions(+), 48 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py index 4402e3f9217..6baf42c4c7a 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -81,13 +81,20 @@ def normalize_keys_to_learn(keys_to_learn: Any) -> dict[str, Any]: raise ValueError( f"keys_to_learn supports only subblock keys in v1; invalid entries: {invalid!r}" ) - if "entire_block" in values and len(set(values)) > 1: + subblocks = tuple(sorted(set(values))) + if "entire_block" in subblocks and 
len(subblocks) > 1: raise ValueError("keys_to_learn cannot mix 'entire_block' with other subblock keys") - return {"mode": "subblocks", "subblocks": tuple(dict.fromkeys(values))} + return {"mode": "subblocks", "subblocks": subblocks} raise TypeError(f"Unsupported keys_to_learn={keys_to_learn!r}") +def _canonical_keys_to_learn(keys_to_learn: Any) -> tuple[str, ...] | None: + if keys_to_learn is None: + return None + return normalize_keys_to_learn(keys_to_learn)["subblocks"] + + def learned_subblocks_from_keys_to_learn(keys_to_learn: Any) -> list[str]: """Return replacement-library subblocks represented by ``keys_to_learn``.""" normalized = normalize_keys_to_learn(keys_to_learn) @@ -112,14 +119,24 @@ def _slug(value: Any) -> str: return slug or "custom" +def _teacher_dir_identity(cfg: DictConfig) -> str | None: + teacher_dir = cfg.get("teacher_dir", None) + if teacher_dir is None: + return None + teacher_dir = str(teacher_dir) + if teacher_dir.startswith("~"): + return str(Path(teacher_dir).expanduser()) + return teacher_dir + + def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: """Return the config subset that defines a bypass output. The full Hydra config carries mutable runtime counters, checkpoint paths and logging fields. Those should not decide whether a completed bypass run can - be reused. This identity intentionally keeps architecture, training budget, - data shape and learning-target fields, because changing any of them changes - the produced checkpoint. + be reused. This identity intentionally keeps teacher source, architecture, + training budget, data shape and learning-target fields, because changing any + of them changes the produced checkpoint. 
""" bypass = _to_plain_container(cfg.bypass) training = bypass.get("training", {}) @@ -127,6 +144,10 @@ def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: model = bypass.get("model", {}) model_factory = bypass.get("model_factory", {}) return { + "teacher": { + "teacher_dir": _teacher_dir_identity(cfg), + "descriptor": cfg.get("descriptor", None), + }, "model": { "student_weights_dtype": model.get("student_weights_dtype"), "model_config_overrides": model.get("model_config_overrides"), @@ -139,7 +160,7 @@ def get_bypass_run_identity(cfg: DictConfig) -> dict[str, Any]: "mlp_init_config": model_factory.get("mlp_init_config"), "linear_init_mode": model_factory.get("linear_init_mode"), "submodule_for_loss_calculation": model_factory.get("submodule_for_loss_calculation"), - "keys_to_learn": model_factory.get("keys_to_learn"), + "keys_to_learn": _canonical_keys_to_learn(model_factory.get("keys_to_learn")), }, "training": { "learning_rate": training.get("learning_rate"), @@ -189,15 +210,16 @@ def get_bypass_config_fingerprint(cfg: DictConfig) -> str: def get_bypass_experiment_fingerprint(cfg: DictConfig) -> str: - """Return a stable ID fingerprint for the architecture and learning target. + """Return a stable ID fingerprint for the teacher, architecture and learning target. Training budget and data settings are deliberately excluded so a longer - rerun can resume the same architecture from its previous final checkpoint. + rerun can resume the same teacher and architecture from its previous final checkpoint. The full config fingerprint is still recorded in bypass_state.json and used for skip-if-complete decisions. 
""" identity = get_bypass_run_identity(cfg) experiment_identity = { + "teacher": identity["teacher"], "model": identity["model"], "model_factory": { "factory": identity["model_factory"]["factory"], @@ -244,9 +266,9 @@ def set_experiment_id(cfg: DictConfig) -> None: ): parts.append(f"heads_{attn_override['num_key_value_heads']}") - keys_to_learn = cfg.bypass.model_factory.get("keys_to_learn", None) - if keys_to_learn not in (None, "entire_block"): - parts.append(_slug(keys_to_learn)) + keys_to_learn = _canonical_keys_to_learn(cfg.bypass.model_factory.get("keys_to_learn", None)) + if keys_to_learn is not None and keys_to_learn != ("entire_block",): + parts.append(_slug("_".join(keys_to_learn))) if not parts: parts.append("custom") @@ -354,6 +376,7 @@ def expected_bypass_runs(cfg: DictConfig) -> list[dict[str, Any]]: run_cfg = OmegaConf.create( { "puzzle_dir": cfg.puzzle_dir, + "teacher_dir": cfg.get("teacher_dir", None), "dataset_path": cfg.get("dataset_path", None), "descriptor": cfg.get("descriptor", None), "bypass": OmegaConf.to_container(cfg.bypass, resolve=True), diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index b9cbc4060ff..baae78fb071 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -53,6 +53,7 @@ from modelopt.torch.puzzletron.tools.logger import aprint, mprint from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses +from modelopt.torch.utils.logging import print_rank_0 from modelopt.torch.utils.robust_json import json_load from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint @@ -125,6 +126,60 @@ def _get_resume_state_path(cfg: DictConfig, resume_checkpoint_path: Optional[str return 
resume_checkpoint_path +def _get_resume_skip_first_batches(saved_skip: int, resume_iter_num: int) -> int: + return saved_skip + max(0, resume_iter_num - 1) + + +def _finalize_bypass_run(cfg: DictConfig) -> None: + """Realize and mark a completed bypass run when a checkpoint exists.""" + if cfg.bypass.get("disable_checkpoint_save", False): + mprint( + "Bypass checkpoint saving is disabled; skipping checkpoint realization " + "and completion marker" + ) + return + + if not dist.is_master(): + return + + mprint("Realizing bypass checkpoints") + try: + realized_checkpoint, ckpts_symlink = realize_bypass_checkpoints(cfg) + except FileNotFoundError as err: + mprint(f"{err}; skipping bypass completion marker") + return + mark_bypass_run_completed(cfg, realized_checkpoint, ckpts_symlink) + + +def _clip_stitched_module_grads( + stitched_module: StitchedModule, grad_clip: float, grad_clip_type: str +) -> int: + params_with_grads = [p for p in stitched_module.parameters() if p.grad is not None] + if not params_with_grads: + return 0 + + device = params_with_grads[0].device + clipped_count = torch.zeros((), dtype=torch.int64, device=device) + if grad_clip_type == "norm": + grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=params_with_grads, + max_norm=grad_clip, + ) + grad_norm = torch.as_tensor(grad_norm, device=device) + clipped_count += (grad_norm > grad_clip).to(torch.int64) + elif grad_clip_type == "value": + max_abs_grad = torch.stack([p.grad.detach().abs().max() for p in params_with_grads]).max() + clipped_count += (max_abs_grad > grad_clip).to(torch.int64) + torch.nn.utils.clip_grad_value_( + parameters=params_with_grads, + clip_value=grad_clip, + ) + else: + raise RuntimeError(f"Invalid {grad_clip_type}") + + return int(clipped_count.item()) + + def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: """Top-level entry point for bypass distillation stage. 
@@ -507,35 +562,11 @@ def train( if optimizer is not None: grad_clip = cfg.bypass.training.grad_clip if grad_clip is not None: - if cfg.bypass.training.grad_clip_type == "norm": - grad_norm = torch.nn.utils.clip_grad_norm_( - parameters=stitched_module.parameters(), - max_norm=grad_clip, - ) - if grad_norm > grad_clip: - cfg.bypass.training.clipping_count += 1 - elif cfg.bypass.training.grad_clip_type == "value": - # Stack per-param maxes into a single GPU tensor and - # reduce before `.item()` so we sync once per block - # instead of once per parameter (see per-block batching - # rationale at lines 301-304). - grad_maxes = [ - p.grad.abs().max() - for p in stitched_module.parameters() - if p.grad is not None - ] - if grad_maxes: - max_abs_grad = torch.stack(grad_maxes).max().item() - else: - max_abs_grad = 0.0 - if max_abs_grad > grad_clip: - cfg.bypass.training.clipping_count += 1 - torch.nn.utils.clip_grad_value_( - parameters=stitched_module.parameters(), - clip_value=grad_clip, - ) - else: - raise RuntimeError(f"Invalid {cfg.bypass.training.grad_clip_type}") + cfg.bypass.training.clipping_count += _clip_stitched_module_grads( + stitched_module=stitched_module, + grad_clip=grad_clip, + grad_clip_type=cfg.bypass.training.grad_clip_type, + ) assert grad_scaler is not None grad_scaler.step(optimizer) @@ -822,8 +853,11 @@ def run_bypassed_training(cfg: DictConfig): set_experiment_id(cfg) set_experiment_dir(cfg) - if bypass_run_is_complete(cfg): - mprint(f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping") + dist.barrier() + bypass_complete = bypass_run_is_complete(cfg) if dist.is_master() else None + bypass_complete = dist.broadcast(bypass_complete, src=0) + if bypass_complete: + print_rank_0(f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping") return descriptor = ModelDescriptorFactory.get(cfg.descriptor) @@ -903,7 +937,9 @@ def run_bypassed_training(cfg: DictConfig): saved_skip = resume_cfg.training.get( "skip_first_batches", 
cfg.bypass.training.skip_first_batches ) - resume_skip_first_batches = saved_skip + resume_cfg.iter_num + resume_skip_first_batches = _get_resume_skip_first_batches( + saved_skip, resume_cfg.iter_num + ) if "data" in resume_cfg and "shuffle_train_data_seed" in resume_cfg.data: cfg.bypass.data.shuffle_train_data_seed = resume_cfg.data.shuffle_train_data_seed if "seed" in resume_cfg: @@ -1228,10 +1264,7 @@ def run_bypassed_training(cfg: DictConfig): raise dist.barrier() - if dist.is_master(): - mprint("Realizing bypass checkpoints") - realized_checkpoint, ckpts_symlink = realize_bypass_checkpoints(cfg) - mark_bypass_run_completed(cfg, realized_checkpoint, ckpts_symlink) + _finalize_bypass_run(cfg) dist.barrier() diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py index 0b43a97c01c..58f1d2955a7 100644 --- a/tests/unit/torch/puzzletron/test_bypass_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -117,10 +117,11 @@ def test_pipeline_ownership_context_rejects_idle_rank(): get_pipeline_ownership_context([0, 0, 1], rank=2) -def _experiment_cfg(keys_to_learn: str): +def _experiment_cfg(keys_to_learn): return OmegaConf.create( { "descriptor": "test_descriptor", + "teacher_dir": "/tmp/teacher_a", "dataset_path": "/tmp/dataset_a", "bypass": { "experiment_id": None, @@ -214,6 +215,34 @@ def test_config_fingerprint_changes_with_shuffle_seed(): assert get_bypass_config_fingerprint(cfg) != original +def test_config_fingerprint_changes_with_teacher_dir(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.teacher_dir = "/tmp/teacher_b" + assert get_bypass_config_fingerprint(cfg) != original + + +def test_config_fingerprint_changes_with_descriptor(): + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + cfg.descriptor = "other_descriptor" + assert get_bypass_config_fingerprint(cfg) != original + + +def 
test_config_fingerprint_canonicalizes_single_keys_to_learn(): + cfg_a = _experiment_cfg("entire_block") + cfg_b = _experiment_cfg(["entire_block"]) + + assert get_bypass_config_fingerprint(cfg_a) == get_bypass_config_fingerprint(cfg_b) + + +def test_config_fingerprint_canonicalizes_keys_to_learn_order(): + cfg_a = _experiment_cfg(["subblock_ffn", "subblock_attention"]) + cfg_b = _experiment_cfg(["subblock_attention", "subblock_ffn"]) + + assert get_bypass_config_fingerprint(cfg_a) == get_bypass_config_fingerprint(cfg_b) + + def test_experiment_id_does_not_change_with_dataset_path(): cfg_a = _experiment_cfg("subblock_attention") cfg_b = _experiment_cfg("subblock_attention") @@ -221,3 +250,20 @@ def test_experiment_id_does_not_change_with_dataset_path(): set_experiment_id(cfg_a) set_experiment_id(cfg_b) assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id + + +def test_experiment_id_changes_with_teacher_source(): + cfg_a = _experiment_cfg("subblock_attention") + cfg_b = _experiment_cfg("subblock_attention") + cfg_b.teacher_dir = "/tmp/teacher_b" + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id != cfg_b.bypass.experiment_id + + +def test_experiment_id_canonicalizes_keys_to_learn_order(): + cfg_a = _experiment_cfg(["subblock_ffn", "subblock_attention"]) + cfg_b = _experiment_cfg(["subblock_attention", "subblock_ffn"]) + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py index 5975612809a..117b6606744 100644 --- a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -28,6 +28,7 @@ import json from pathlib import Path +import torch from omegaconf import OmegaConf import modelopt.torch.puzzletron.bypass_distillation.training_loop as tl 
@@ -213,6 +214,12 @@ def test_resume_state_used_when_no_init_checkpoint_path(): assert tl._get_resume_state_path(cfg, "/tmp/resume-ckpt") == "/tmp/resume-ckpt" +def test_resume_skip_first_batches_uses_completed_iter_count(): + assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=0) == 10 + assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=1) == 10 + assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=7) == 16 + + def test_flush_loss_buffer_single_rank_without_process_group(): local_buffer = {1: {"block_0": 0.25}} stitched_losses_history = {} @@ -222,6 +229,129 @@ def test_flush_loss_buffer_single_rank_without_process_group(): assert stitched_losses_history == local_buffer +def test_run_bypassed_training_broadcasts_completion_skip(monkeypatch, tmp_path): + cfg = _base_cfg(tmp_path) + cfg.bypass.experiment_id = None + checks = [] + broadcasts = [] + messages = [] + + def fail(*args, **kwargs): + raise AssertionError("training setup should not run after completed bypass check") + + monkeypatch.setattr(tl.dist, "local_rank", lambda: 0) + monkeypatch.setattr(tl.dist, "barrier", lambda: None) + monkeypatch.setattr(tl.dist, "is_master", lambda: True) + monkeypatch.setattr( + tl.dist, "broadcast", lambda value, src: broadcasts.append((value, src)) or value + ) + monkeypatch.setattr( + tl, "bypass_run_is_complete", lambda cfg_arg: checks.append(cfg_arg) or True + ) + monkeypatch.setattr(tl, "print_rank_0", lambda *args, **kwargs: messages.append(args[0])) + monkeypatch.setattr(tl.ModelDescriptorFactory, "get", fail) + + tl.run_bypassed_training(cfg) + + assert checks == [cfg] + assert broadcasts == [(True, 0)] + assert messages == [f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping"] + + +def test_run_bypassed_training_non_master_uses_broadcasted_completion(monkeypatch, tmp_path): + cfg = _base_cfg(tmp_path) + cfg.bypass.experiment_id = None + + def fail(*args, **kwargs): + raise 
AssertionError("non-master should not evaluate completion or continue setup") + + monkeypatch.setattr(tl.dist, "local_rank", lambda: 0) + monkeypatch.setattr(tl.dist, "barrier", lambda: None) + monkeypatch.setattr(tl.dist, "is_master", lambda: False) + monkeypatch.setattr(tl.dist, "broadcast", lambda value, src: True) + monkeypatch.setattr(tl, "bypass_run_is_complete", fail) + monkeypatch.setattr(tl.ModelDescriptorFactory, "get", fail) + + tl.run_bypassed_training(cfg) + + +def test_clip_stitched_module_grads_norm_counts_clipped_block(): + module = torch.nn.Linear(2, 1, bias=False) + module.weight.grad = torch.full_like(module.weight, 10.0) + + assert tl._clip_stitched_module_grads(module, grad_clip=0.1, grad_clip_type="norm") == 1 + assert torch.linalg.vector_norm(module.weight.grad) <= 0.1 + 1e-6 + + +def test_clip_stitched_module_grads_value_counts_clipped_block(): + module = torch.nn.Linear(2, 1, bias=False) + module.weight.grad = torch.tensor([[0.05, 2.0]]) + + assert tl._clip_stitched_module_grads(module, grad_clip=0.5, grad_clip_type="value") == 1 + assert module.weight.grad.abs().max() <= 0.5 + + +def test_clip_stitched_module_grads_returns_zero_when_below_threshold(): + module = torch.nn.Linear(2, 1, bias=False) + module.weight.grad = torch.full_like(module.weight, 0.01) + + assert tl._clip_stitched_module_grads(module, grad_clip=1.0, grad_clip_type="value") == 0 + + +def test_finalize_bypass_run_skips_realization_when_checkpoint_saving_disabled(monkeypatch): + cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": True}}) + + def fail(*args, **kwargs): + raise AssertionError("checkpoint realization should be skipped") + + monkeypatch.setattr(tl.dist, "is_master", lambda: True) + monkeypatch.setattr(tl, "realize_bypass_checkpoints", fail) + monkeypatch.setattr(tl, "mark_bypass_run_completed", fail) + + tl._finalize_bypass_run(cfg) + + +def test_finalize_bypass_run_skips_completion_when_no_checkpoint_exists(monkeypatch): + cfg = 
OmegaConf.create({"bypass": {"disable_checkpoint_save": False}}) + completed = False + + def missing_checkpoint(_cfg): + raise FileNotFoundError("missing checkpoint") + + def mark_completed(*args, **kwargs): + nonlocal completed + completed = True + + monkeypatch.setattr(tl.dist, "is_master", lambda: True) + monkeypatch.setattr(tl, "realize_bypass_checkpoints", missing_checkpoint) + monkeypatch.setattr(tl, "mark_bypass_run_completed", mark_completed) + + tl._finalize_bypass_run(cfg) + + assert completed is False + + +def test_finalize_bypass_run_marks_realized_checkpoint(monkeypatch): + cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": False}}) + realized = Path("/tmp/realized") + symlink = Path("/tmp/ckpts/run_0") + completed = {} + + monkeypatch.setattr(tl.dist, "is_master", lambda: True) + monkeypatch.setattr(tl, "realize_bypass_checkpoints", lambda _cfg: (realized, symlink)) + monkeypatch.setattr( + tl, + "mark_bypass_run_completed", + lambda cfg_arg, realized_arg, symlink_arg: completed.update( + cfg=cfg_arg, realized=realized_arg, symlink=symlink_arg + ), + ) + + tl._finalize_bypass_run(cfg) + + assert completed == {"cfg": cfg, "realized": realized, "symlink": symlink} + + def test_realize_bypass_checkpoints_uses_resolved_symlink_target(monkeypatch, tmp_path: Path): monkeypatch.chdir(tmp_path) experiment_dir = Path("puzzle/bypass/bypass_runs/run_0") From 47e007493e0656395ff6dca83435501230319725 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 2 Jun 2026 12:16:34 +0200 Subject: [PATCH 3/8] Fix bypass review issues Signed-off-by: Sepehr Sameni --- .../bypass_checkpoint_utils.py | 12 ++------ .../stitched_model_factory.py | 2 ++ .../bypass_distillation/training_loop.py | 2 +- .../test_bypass_checkpoint_utils.py | 30 ++++++++++++------- .../puzzletron/test_bypass_keys_to_learn.py | 7 +++++ .../test_launch_bypass_distillation.py | 4 +-- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git 
a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py index 89260bcc5f9..f7219a9ef7e 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -150,19 +150,14 @@ def load_local_state( del loaded_scaler_state -def _save_local_file(obj, save_path: Path | str, overwrite=True): +def _save_local_file(obj, save_path: Path | str): save_path = Path(save_path) - if save_path.exists(): - if not overwrite: - mprint(f'WARNING: Local save path "{save_path}" already exists. Skipping') - return torch.save(obj, save_path) def _save_local_state( stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], checkpoint_dir: Path | str, - overwrite=True, ) -> None: """Persist optimizer and grad-scaler state for each stitched module. @@ -193,7 +188,7 @@ def _save_local_state( aprint( f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" ) - _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + _save_local_file(optimizer.state_dict(), optimizer_state_path) # Persist GradScaler state. Required for correct resume when # use_grad_scaling=True (state dict carries running scale + growth tracker). @@ -206,7 +201,7 @@ def _save_local_state( f"Saving grad_scaler state for module {stitched_module_name} " f"to {grad_scaler_state_path}" ) - _save_local_file(grad_scaler.state_dict(), grad_scaler_state_path, overwrite=overwrite) + _save_local_file(grad_scaler.state_dict(), grad_scaler_state_path) dist.barrier() @@ -229,7 +224,6 @@ def save_bypass_checkpoint( _save_local_state( stitched_module_descriptors=stitched_module_descriptors, checkpoint_dir=checkpoint_dir, - overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, ) # Save as HF checkpoint. 
Must use the gather-aware variant: bypass training is # pipeline-parallel so each rank's `model.state_dict()` only carries its own diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py index c44be3e7e3f..d7ed78a64c5 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -124,6 +124,8 @@ def _param_names_for_subblock_key( block_configs = getattr(model.config, "block_configs", None) or getattr( lm_config, "block_configs", None ) + if subblock_key == "subblock_mamba" and block_configs is None: + raise ValueError("keys_to_learn='subblock_mamba' requires model config block_configs") collected: list[str] = [] for group_name in group_names: diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index baae78fb071..84cc359957e 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -127,7 +127,7 @@ def _get_resume_state_path(cfg: DictConfig, resume_checkpoint_path: Optional[str def _get_resume_skip_first_batches(saved_skip: int, resume_iter_num: int) -> int: - return saved_skip + max(0, resume_iter_num - 1) + return saved_skip + max(0, resume_iter_num) def _finalize_bypass_run(cfg: DictConfig) -> None: diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index 4f8cd88c7ef..dd717e7e1ac 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -21,7 +21,7 @@ What's covered here (CPU-only, codecov-visible): * ``find_latest_run_dir`` — every branch of the regex/scan/symlink logic. 
- * ``_save_local_file`` — overwrite/skip semantics. + * ``_save_local_file`` — checkpoint-local file writes. * ``_save_local_state`` — same three save-path assertions as the GPU file (state_dict / optimizer / grad_scaler), but on CPU so codecov picks them up. The GPU file's ``test_load_local_state_*`` cases stay there because @@ -196,20 +196,11 @@ def test_save_local_file_writes_object_to_disk(tmp_path: Path): def test_save_local_file_overwrite_true_replaces_contents(tmp_path: Path): target = tmp_path / "blob.pth" bcu._save_local_file({"v": torch.tensor([1])}, target) - bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=True) + bcu._save_local_file({"v": torch.tensor([99])}, target) loaded = torch.load(target, weights_only=True) assert torch.equal(loaded["v"], torch.tensor([99])) -def test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): - target = tmp_path / "blob.pth" - bcu._save_local_file({"v": torch.tensor([1])}, target) - # Second save should be a no-op. - bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=False) - loaded = torch.load(target, weights_only=True) - assert torch.equal(loaded["v"], torch.tensor([1])) - - # --------------------------------------------------------------------------- # _save_local_state: optimizer + grad_scaler only. 
# Weights deliberately do NOT land here — the HF checkpoint at the same @@ -226,6 +217,23 @@ def test_save_local_state_writes_optimizer_and_grad_scaler(tmp_path: Path, bcu_n assert (stitched / "block_0.grad_scaler.pth").exists() +def test_save_local_state_overwrites_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + stale_optimizer_state = {"stale": torch.tensor([1])} + stale_scaler_state = {"stale": torch.tensor([1])} + torch.save(stale_optimizer_state, stitched / "block_0.optimizer_state.pth") + torch.save(stale_scaler_state, stitched / "block_0.grad_scaler.pth") + + bcu_no_dist._save_local_state(descriptors, tmp_path) + + optimizer_state = torch.load(stitched / "block_0.optimizer_state.pth", weights_only=True) + grad_scaler_state = torch.load(stitched / "block_0.grad_scaler.pth", weights_only=True) + assert "stale" not in optimizer_state + assert "stale" not in grad_scaler_state + + def test_save_local_state_does_not_write_weights_state_dict(tmp_path: Path, bcu_no_dist): """Pin the de-duplication: weights live in the HF checkpoint, not here.""" descriptors = OrderedDict([("block_0", _make_descriptor())]) diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py index 6e7b663b6e0..39a7d8afeb2 100644 --- a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -131,6 +131,13 @@ def test_subblock_attention_trains_only_self_attn(): assert not any(".mlp." 
in n for n in trainable), trainable +def test_subblock_mamba_without_block_configs_is_rejected(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match="subblock_mamba.*block_configs"): + _set_keys_to_learn(model, descriptor, "subblock_mamba") + + def test_entire_block_trains_attention_and_mlp(): model = _make_dense_model(num_layers=2) descriptor = _make_descriptor(num_layers=2) diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py index 117b6606744..35ecd4c6d69 100644 --- a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -216,8 +216,8 @@ def test_resume_state_used_when_no_init_checkpoint_path(): def test_resume_skip_first_batches_uses_completed_iter_count(): assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=0) == 10 - assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=1) == 10 - assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=7) == 16 + assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=1) == 11 + assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=7) == 17 def test_flush_loss_buffer_single_rank_without_process_group(): From 13e23b979f476add1f1fdbb5a7d3fe83af850224 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 2 Jun 2026 13:14:02 +0200 Subject: [PATCH 4/8] Fix AMP gradient clipping in bypass training Signed-off-by: Sepehr Sameni --- modelopt/torch/puzzletron/bypass_distillation/training_loop.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 84cc359957e..45fc0901462 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ 
b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -562,6 +562,7 @@ def train( if optimizer is not None: grad_clip = cfg.bypass.training.grad_clip if grad_clip is not None: + grad_scaler.unscale_(optimizer) cfg.bypass.training.clipping_count += _clip_stitched_module_grads( stitched_module=stitched_module, grad_clip=grad_clip, From 3d684db7fb5a16248a9379cf45cfd609065c0463 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 2 Jun 2026 13:29:42 +0200 Subject: [PATCH 5/8] Fix bypass GradScaler type narrowing Signed-off-by: Sepehr Sameni --- modelopt/torch/puzzletron/bypass_distillation/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 45fc0901462..876a9eec5ca 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -560,6 +560,7 @@ def train( if not is_accumulating: if optimizer is not None: + assert grad_scaler is not None grad_clip = cfg.bypass.training.grad_clip if grad_clip is not None: grad_scaler.unscale_(optimizer) @@ -569,7 +570,6 @@ def train( grad_clip_type=cfg.bypass.training.grad_clip_type, ) - assert grad_scaler is not None grad_scaler.step(optimizer) grad_scaler.update() optimizer.zero_grad(set_to_none=True) From 34558b2c8ee6962f02d5dbe0302809bd764a5958 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 2 Jun 2026 15:09:03 +0200 Subject: [PATCH 6/8] Tighten bypass unit tests Signed-off-by: Sepehr Sameni --- .../bypass_distillation/training_loop.py | 43 ++++++++---- .../test_bypass_checkpoint_utils.py | 27 -------- .../puzzletron/test_bypass_keys_to_learn.py | 2 +- .../puzzletron/test_bypass_lr_scheduler.py | 33 ++++----- .../torch/puzzletron/test_bypass_utils.py | 68 +++---------------- .../test_launch_bypass_distillation.py | 53 ++++++++++----- 
.../test_stitched_model_factory_buffers.py | 14 ---- 7 files changed, 93 insertions(+), 147 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 876a9eec5ca..f2ed37f1f45 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -38,6 +38,8 @@ import torch import transformers from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import Optimizer from torch.utils.data.dataloader import DataLoader from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase @@ -180,6 +182,28 @@ def _clip_stitched_module_grads( return int(clipped_count.item()) +def _step_stitched_module_optimizer( + stitched_module: StitchedModule, + optimizer: Optimizer, + grad_scaler: GradScaler, + grad_clip: Optional[float], + grad_clip_type: str, +) -> int: + clipped_count = 0 + if grad_clip is not None: + grad_scaler.unscale_(optimizer) + clipped_count = _clip_stitched_module_grads( + stitched_module=stitched_module, + grad_clip=grad_clip, + grad_clip_type=grad_clip_type, + ) + + grad_scaler.step(optimizer) + grad_scaler.update() + optimizer.zero_grad(set_to_none=True) + return clipped_count + + def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: """Top-level entry point for bypass distillation stage. 
@@ -561,18 +585,13 @@ def train( if not is_accumulating: if optimizer is not None: assert grad_scaler is not None - grad_clip = cfg.bypass.training.grad_clip - if grad_clip is not None: - grad_scaler.unscale_(optimizer) - cfg.bypass.training.clipping_count += _clip_stitched_module_grads( - stitched_module=stitched_module, - grad_clip=grad_clip, - grad_clip_type=cfg.bypass.training.grad_clip_type, - ) - - grad_scaler.step(optimizer) - grad_scaler.update() - optimizer.zero_grad(set_to_none=True) + cfg.bypass.training.clipping_count += _step_stitched_module_optimizer( + stitched_module=stitched_module, + optimizer=optimizer, + grad_scaler=grad_scaler, + grad_clip=cfg.bypass.training.grad_clip, + grad_clip_type=cfg.bypass.training.grad_clip_type, + ) # Single GPU→CPU sync for all per-block losses collected above. Stacking # into a 1-D tensor lets us issue exactly one ``.to("cpu")`` instead of diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index dd717e7e1ac..c8b0dd09a0c 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -21,7 +21,6 @@ What's covered here (CPU-only, codecov-visible): * ``find_latest_run_dir`` — every branch of the regex/scan/symlink logic. - * ``_save_local_file`` — checkpoint-local file writes. * ``_save_local_state`` — same three save-path assertions as the GPU file (state_dict / optimizer / grad_scaler), but on CPU so codecov picks them up. 
The GPU file's ``test_load_local_state_*`` cases stay there because @@ -80,10 +79,6 @@ def _make_descriptor(*, with_optimizer: bool = True, with_scaler: bool = True): # --------------------------------------------------------------------------- -def test_find_latest_run_dir_returns_none_for_empty_dir(tmp_path: Path): - assert bcu.find_latest_run_dir(tmp_path) is None - - def test_find_latest_run_dir_picks_only_step_with_marker(tmp_path: Path): step_dir = tmp_path / "step-000010-ckpt" step_dir.mkdir() @@ -179,28 +174,6 @@ def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): assert bcu.find_latest_run_dir(tmp_path) == str(completed) - -# --------------------------------------------------------------------------- -# _save_local_file -# --------------------------------------------------------------------------- - - -def test_save_local_file_writes_object_to_disk(tmp_path: Path): - target = tmp_path / "blob.pth" - bcu._save_local_file({"a": torch.tensor([1, 2, 3])}, target) - assert target.exists() - loaded = torch.load(target, weights_only=True) - assert torch.equal(loaded["a"], torch.tensor([1, 2, 3])) - - -def test_save_local_file_overwrite_true_replaces_contents(tmp_path: Path): - target = tmp_path / "blob.pth" - bcu._save_local_file({"v": torch.tensor([1])}, target) - bcu._save_local_file({"v": torch.tensor([99])}, target) - loaded = torch.load(target, weights_only=True) - assert torch.equal(loaded["v"], torch.tensor([99])) - - # --------------------------------------------------------------------------- # _save_local_state: optimizer + grad_scaler only. 
# Weights deliberately do NOT land here — the HF checkpoint at the same diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py index 39a7d8afeb2..3de0a6ab050 100644 --- a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -243,7 +243,7 @@ def test_empty_key_list_is_rejected(): @pytest.mark.parametrize( "keys_to_learn", - ["subblock_ffn", "subblock_attention", "entire_block"], + ["subblock_attention", "entire_block"], ) def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): """Integer / non-floating buffers exposed as parameters must stay frozen. diff --git a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py index 38701ba8be3..101657f7d13 100644 --- a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py +++ b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py @@ -60,19 +60,23 @@ def _make_cfg( ) -def test_degenerate_budget_returns_base_lr(): +@pytest.mark.parametrize( + ("warmup_steps", "lr_decay_steps", "learning_rate"), + [ + (10, 10, 0.5), + (20, 10, 0.7), + ], +) +def test_degenerate_budget_returns_base_lr(warmup_steps, lr_decay_steps, learning_rate): """When ``lr_decay_steps <= warmup_steps`` (tiny test budgets), the scheduler must short-circuit to ``learning_rate`` rather than divide by zero.""" - cfg = _make_cfg(warmup_steps=10, lr_decay_steps=10, learning_rate=0.5) - assert _get_lr(cfg, step=0) == 0.5 - assert _get_lr(cfg, step=1) == 0.5 - assert _get_lr(cfg, step=99) == 0.5 - - -def test_degenerate_budget_warmup_greater_than_decay(): - """``lr_decay_steps < warmup_steps`` is also caught by the same guard.""" - cfg = _make_cfg(warmup_steps=20, lr_decay_steps=10, learning_rate=0.7) - assert _get_lr(cfg, step=5) == 0.7 + cfg = _make_cfg( + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + learning_rate=learning_rate, + 
) + assert _get_lr(cfg, step=0) == learning_rate + assert _get_lr(cfg, step=99) == learning_rate def test_warmup_linear_ramp(): @@ -118,10 +122,3 @@ def test_post_decay_clamps_to_min_lr(): cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) assert _get_lr(cfg, step=21) == 0.1 assert _get_lr(cfg, step=1000) == 0.1 - - -def test_min_lr_zero_decays_to_zero(): - """Common config: ``min_lr=0`` → cosine endpoint is exactly 0.""" - cfg = _make_cfg(warmup_steps=10, lr_decay_steps=30, learning_rate=2.0, min_lr=0.0) - assert _get_lr(cfg, step=30) == pytest.approx(0.0) - assert _get_lr(cfg, step=31) == 0.0 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py index 58f1d2955a7..90799d7e33f 100644 --- a/tests/unit/torch/puzzletron/test_bypass_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -26,67 +26,21 @@ ) -def test_single_gpu_all_to_rank_0(): - """With world_size=1, all 4 modules should be assigned to rank 0.""" - ownership = get_distributed_modules_ownership(module_count=4, world_size=1) - assert ownership == [0, 0, 0, 0] - - -def test_even_distribution(): - """With world_size=2 and 4 modules, each rank should own exactly 2 modules.""" - ownership = get_distributed_modules_ownership(module_count=4, world_size=2) - assert ownership.count(0) == 2 - assert ownership.count(1) == 2 - assert len(ownership) == 4 - - -def test_uneven_distribution(): - """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1.""" - ownership = get_distributed_modules_ownership(module_count=3, world_size=2) - assert ownership.count(0) == 2 - assert ownership.count(1) == 1 - assert len(ownership) == 3 - - @pytest.mark.parametrize( - ("module_count", "world_size"), + ("module_count", "world_size", "expected_ownership"), [ - (1, 1), - (4, 1), - (4, 2), - (4, 4), - (7, 3), - (10, 4), - (1, 2), + (4, 1, [0, 0, 0, 0]), + (4, 2, [0, 0, 1, 1]), + (3, 2, [0, 0, 1]), + (7, 3, [0, 
0, 0, 1, 1, 2, 2]), + (1, 2, [0]), ], ) -def test_total_equals_module_count(module_count, world_size): - """The length of the ownership list must always equal module_count.""" - ownership = get_distributed_modules_ownership(module_count=module_count, world_size=world_size) - assert len(ownership) == module_count - - -def test_consecutive_ownership(): - """Each rank should own a contiguous block of indices (no interleaving).""" - ownership = get_distributed_modules_ownership(module_count=7, world_size=3) - # Verify that once we see a new rank, we never see the previous rank again. - seen_ranks = set() - prev_rank = ownership[0] - seen_ranks.add(prev_rank) - for rank in ownership[1:]: - if rank != prev_rank: - assert rank not in seen_ranks, ( - f"Rank {rank} appears non-consecutively in ownership list: {ownership}" - ) - seen_ranks.add(rank) - prev_rank = rank - - -def test_single_module(): - """With world_size=2 and only 1 module, rank 0 should be the sole owner.""" - ownership = get_distributed_modules_ownership(module_count=1, world_size=2) - assert ownership == [0] - assert len(ownership) == 1 +def test_distributed_modules_ownership(module_count, world_size, expected_ownership): + assert ( + get_distributed_modules_ownership(module_count=module_count, world_size=world_size) + == expected_ownership + ) def test_pipeline_ownership_context_returns_neighbors(): diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py index 35ecd4c6d69..f655b0da926 100644 --- a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -30,6 +30,7 @@ import torch from omegaconf import OmegaConf +from torch.amp.grad_scaler import GradScaler import modelopt.torch.puzzletron.bypass_distillation.training_loop as tl @@ -106,13 +107,6 @@ def test_two_configs_run_twice_with_distinct_overrides(monkeypatch, tmp_path): assert 
snapshots[1]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 128} -def test_keys_to_learn_override_applied(monkeypatch, tmp_path): - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg(tmp_path, configs=[{"keys_to_learn": "subblock_attention"}]) - tl.launch_bypass_distillation(cfg) - assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" - - def test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): """Every sweep entry must see iter_num=1, step_num=1, token_count=0, best_val_loss=1e9, clipping_count=0, and a fresh experiment_id even when the @@ -135,17 +129,6 @@ def test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): assert snap["bypass"]["training"]["clipping_count"] == 0 -def test_override_without_keys_to_learn_leaves_cfg_value_untouched(monkeypatch, tmp_path): - """A sweep entry that only sets ``model_config_overrides`` must not clobber - the inherited ``keys_to_learn`` (the dispatcher's `if "keys_to_learn" in override` - guard).""" - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg(tmp_path, configs=[{"model_config_overrides": {"intermediate_size": 256}}]) - tl.launch_bypass_distillation(cfg) - # keys_to_learn was set to "subblock_ffn" in _base_cfg — must survive. 
- assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" - - def test_sweep_entry_without_keys_to_learn_uses_base_not_previous_override(monkeypatch, tmp_path): snapshots = _record_calls(monkeypatch) cfg = _base_cfg( @@ -298,6 +281,40 @@ def test_clip_stitched_module_grads_returns_zero_when_below_threshold(): assert tl._clip_stitched_module_grads(module, grad_clip=1.0, grad_clip_type="value") == 0 +def test_step_stitched_module_optimizer_unscales_before_clipping(monkeypatch): + module = torch.nn.Linear(1, 1, bias=False) + optimizer = torch.optim.SGD(module.parameters(), lr=0.0) + grad_scaler = GradScaler(device="cpu", enabled=True, init_scale=16.0) + grad_scaler.scale(module.weight.sum() * 2.0).backward() + assert module.weight.grad is not None + assert torch.equal(module.weight.grad, torch.full_like(module.weight, 32.0)) + observed = {} + + def capture_clip(stitched_module, grad_clip, grad_clip_type): + observed["stitched_module"] = stitched_module + observed["grad_clip"] = grad_clip + observed["grad_clip_type"] = grad_clip_type + observed["grad"] = module.weight.grad.detach().clone() + return 1 + + monkeypatch.setattr(tl, "_clip_stitched_module_grads", capture_clip) + + clipped_count = tl._step_stitched_module_optimizer( + stitched_module=module, + optimizer=optimizer, + grad_scaler=grad_scaler, + grad_clip=1.0, + grad_clip_type="norm", + ) + + assert clipped_count == 1 + assert observed["stitched_module"] is module + assert observed["grad_clip"] == 1.0 + assert observed["grad_clip_type"] == "norm" + assert torch.equal(observed["grad"], torch.full_like(module.weight, 2.0)) + assert module.weight.grad is None + + def test_finalize_bypass_run_skips_realization_when_checkpoint_saving_disabled(monkeypatch): cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": True}}) diff --git a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py index 
5fab764b565..d7ab66e1264 100644 --- a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py +++ b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py @@ -30,10 +30,6 @@ ) -def test_module_with_no_buffers_returns_empty_set(): - assert _get_all_non_persistent_buffers_set(nn.Module()) == set() - - def test_persistent_buffer_excluded_non_persistent_included(): m = nn.Module() m.register_buffer("p", torch.zeros(1), persistent=True) @@ -53,16 +49,6 @@ def test_nested_submodule_paths_are_fully_qualified(): assert out == {"inner.nb"} -def test_top_level_buffer_has_no_leading_dot(): - """Module name is "" at the root — fully-qualified name must not start - with a dot, otherwise it won't match any state_dict key.""" - m = nn.Module() - m.register_buffer("x", torch.zeros(1), persistent=False) - out = _get_all_non_persistent_buffers_set(m) - assert out == {"x"} - assert not any(name.startswith(".") for name in out) - - def test_mix_of_persistent_and_non_persistent_in_nested_module(): """The full discrimination: only the nested non-persistent buffer should appear, with its fully-qualified path.""" From 4e75b7bb5b778a59915afaca52dbc2f5a46602f9 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Tue, 2 Jun 2026 15:33:16 +0200 Subject: [PATCH 7/8] Apply Ruff format to bypass tests Signed-off-by: Sepehr Sameni --- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index c8b0dd09a0c..3c709c0670a 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -174,6 +174,7 @@ def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): assert bcu.find_latest_run_dir(tmp_path) == str(completed) + # --------------------------------------------------------------------------- # 
_save_local_state: optimizer + grad_scaler only. # Weights deliberately do NOT land here — the HF checkpoint at the same From 98cbde36d36d278a4b50362820ace23b52e5256f Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 3 Jun 2026 09:20:35 +0200 Subject: [PATCH 8/8] Deduplicate bypass unit tests Signed-off-by: Sepehr Sameni --- .../test_bypass_checkpoint_utils.py | 196 ++++--------- .../puzzletron/test_bypass_dataloaders.py | 100 +++---- .../puzzletron/test_bypass_keys_to_learn.py | 150 +++------- .../torch/puzzletron/test_bypass_losses.py | 79 ++--- .../puzzletron/test_bypass_lr_scheduler.py | 80 ++---- .../torch/puzzletron/test_bypass_utils.py | 165 +++++------ .../puzzletron/test_checkpoint_utils_hf.py | 22 +- .../test_launch_bypass_distillation.py | 271 +++++++----------- .../test_replacement_library_bypass_config.py | 59 ++-- .../test_stitched_model_factory_buffers.py | 31 +- 10 files changed, 408 insertions(+), 745 deletions(-) diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index 3c709c0670a..81b340be150 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -74,105 +74,58 @@ def _make_descriptor(*, with_optimizer: bool = True, with_scaler: bool = True): ) +def _make_checkpoint_dir(parent: Path, name: str, *, completed: bool = True) -> Path: + checkpoint_dir = parent / name + checkpoint_dir.mkdir(parents=True) + if completed: + (checkpoint_dir / "saving_completed").touch() + return checkpoint_dir + + # --------------------------------------------------------------------------- # find_latest_run_dir # --------------------------------------------------------------------------- -def test_find_latest_run_dir_picks_only_step_with_marker(tmp_path: Path): - step_dir = tmp_path / "step-000010-ckpt" - step_dir.mkdir() - (step_dir / "saving_completed").touch() - assert 
bcu.find_latest_run_dir(tmp_path) == str(step_dir) - - -def test_find_latest_run_dir_picks_highest_step_number(tmp_path: Path): - """When several plain step checkpoints have completed markers, the highest - integer wins — not lexicographic order, not insertion order.""" - for i in (5, 10, 20): - d = tmp_path / f"step-{i:06d}-ckpt" - d.mkdir() - (d / "saving_completed").touch() - assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "step-000020-ckpt") - - -def test_find_latest_run_dir_skips_step_without_marker(tmp_path: Path): - """A partially-written checkpoint (no ``saving_completed``) must be skipped - even when it has a higher step number — otherwise resume would crash on a - truncated state dict.""" - high = tmp_path / "step-000099-ckpt" - high.mkdir() - # No saving_completed → must be ignored. - low = tmp_path / "step-000050-ckpt" - low.mkdir() - (low / "saving_completed").touch() - assert bcu.find_latest_run_dir(tmp_path) == str(low) - - -def test_find_latest_run_dir_returns_none_when_no_step_has_marker(tmp_path: Path): - (tmp_path / "step-000010-ckpt").mkdir() - (tmp_path / "step-000020-ckpt").mkdir() - # No saving_completed anywhere. 
- assert bcu.find_latest_run_dir(tmp_path) is None +def test_find_latest_run_dir_scans_highest_completed_plain_step(tmp_path: Path): + """The scan branch picks the highest completed plain step checkpoint only.""" + scan_dir = tmp_path / "scan" + _make_checkpoint_dir(scan_dir, "step-000005-ckpt") + expected = _make_checkpoint_dir(scan_dir, "step-000020-ckpt") + _make_checkpoint_dir(scan_dir, "step-000099-ckpt", completed=False) + for name in ("best-step-000099-ckpt", "start-step-000001-ckpt", "final-step-000050-ckpt"): + _make_checkpoint_dir(scan_dir, name) + assert bcu.find_latest_run_dir(scan_dir) == str(expected) -def test_find_latest_run_dir_excludes_non_plain_step_names(tmp_path: Path): - """``best-step-*`` / ``start-step-*`` / ``final-step-*`` aren't valid resume - targets — pinned by the docstring on lines 39-42.""" - for name in ("best-step-000099-ckpt", "start-step-000001-ckpt", "final-step-000050-ckpt"): - d = tmp_path / name - d.mkdir() - (d / "saving_completed").touch() - # No plain step-*-ckpt at all. - assert bcu.find_latest_run_dir(tmp_path) is None + no_completed_plain_steps = tmp_path / "no_completed_plain_steps" + _make_checkpoint_dir(no_completed_plain_steps, "step-000010-ckpt", completed=False) + _make_checkpoint_dir(no_completed_plain_steps, "best-step-000020-ckpt") + assert bcu.find_latest_run_dir(no_completed_plain_steps) is None -def test_find_latest_run_dir_uses_latest_symlink_fast_path(tmp_path: Path): +def test_find_latest_run_dir_handles_latest_symlink_fast_path_and_fallbacks(tmp_path: Path): """The ``latest`` symlink, when present and complete, short-circuits the scan — even when a numerically higher step dir also has a marker. 
This matters because the scan branch can be slow on filesystems with many step dirs (NFS, lustre).""" - target = tmp_path / "step-000010-ckpt" - target.mkdir() - (target / "saving_completed").touch() - (tmp_path / "latest").symlink_to(target.name) - - higher = tmp_path / "step-000020-ckpt" - higher.mkdir() - (higher / "saving_completed").touch() - - # Symlink wins despite higher step existing, but returns the resolved target - # so callers open the same checkpoint that was validated. - assert bcu.find_latest_run_dir(tmp_path) == str(target.resolve()) - - -def test_find_latest_run_dir_falls_through_when_latest_lacks_marker(tmp_path: Path): - """A ``latest`` symlink whose target lacks ``saving_completed`` (interrupted - save) must be ignored, falling through to the highest completed step.""" - incomplete = tmp_path / "step-000020-ckpt" - incomplete.mkdir() - # No saving_completed. - (tmp_path / "latest").symlink_to(incomplete.name) + complete_latest = tmp_path / "complete_latest" + target = _make_checkpoint_dir(complete_latest, "step-000010-ckpt") + _make_checkpoint_dir(complete_latest, "step-000020-ckpt") + (complete_latest / "latest").symlink_to(target.name) + assert bcu.find_latest_run_dir(complete_latest) == str(target.resolve()) - completed = tmp_path / "step-000010-ckpt" - completed.mkdir() - (completed / "saving_completed").touch() + incomplete_latest = tmp_path / "incomplete_latest" + incomplete = _make_checkpoint_dir(incomplete_latest, "step-000020-ckpt", completed=False) + completed = _make_checkpoint_dir(incomplete_latest, "step-000010-ckpt") + (incomplete_latest / "latest").symlink_to(incomplete.name) + assert bcu.find_latest_run_dir(incomplete_latest) == str(completed) - assert bcu.find_latest_run_dir(tmp_path) == str(completed) - - -def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): - """`latest` is a resume pointer, so old symlinks to best checkpoints are ignored.""" - best = tmp_path / "best-step-000020-ckpt" - 
best.mkdir() - (best / "saving_completed").touch() - (tmp_path / "latest").symlink_to(best.name) - - completed = tmp_path / "step-000010-ckpt" - completed.mkdir() - (completed / "saving_completed").touch() - - assert bcu.find_latest_run_dir(tmp_path) == str(completed) + latest_to_best = tmp_path / "latest_to_best" + best = _make_checkpoint_dir(latest_to_best, "best-step-000020-ckpt") + completed = _make_checkpoint_dir(latest_to_best, "step-000010-ckpt") + (latest_to_best / "latest").symlink_to(best.name) + assert bcu.find_latest_run_dir(latest_to_best) == str(completed) # --------------------------------------------------------------------------- @@ -183,18 +136,14 @@ def test_find_latest_run_dir_ignores_latest_to_best_checkpoint(tmp_path: Path): # --------------------------------------------------------------------------- -def test_save_local_state_writes_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): +def test_save_local_state_writes_only_optimizer_and_grad_scaler_state(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict([("block_0", _make_descriptor())]) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" assert (stitched / "block_0.optimizer_state.pth").exists() assert (stitched / "block_0.grad_scaler.pth").exists() + assert not (stitched / "block_0.state_dict.pth").exists() - -def test_save_local_state_overwrites_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): - descriptors = OrderedDict([("block_0", _make_descriptor())]) - bcu_no_dist._save_local_state(descriptors, tmp_path) - stitched = tmp_path / "stitched" stale_optimizer_state = {"stale": torch.tensor([1])} stale_scaler_state = {"stale": torch.tensor([1])} torch.save(stale_optimizer_state, stitched / "block_0.optimizer_state.pth") @@ -208,29 +157,21 @@ def test_save_local_state_overwrites_optimizer_and_grad_scaler(tmp_path: Path, b assert "stale" not in grad_scaler_state -def test_save_local_state_does_not_write_weights_state_dict(tmp_path: Path, 
bcu_no_dist): - """Pin the de-duplication: weights live in the HF checkpoint, not here.""" - descriptors = OrderedDict([("block_0", _make_descriptor())]) - bcu_no_dist._save_local_state(descriptors, tmp_path) - assert not (tmp_path / "stitched" / "block_0.state_dict.pth").exists() - - -def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): - descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) - bcu_no_dist._save_local_state(descriptors, tmp_path) - stitched = tmp_path / "stitched" - assert (stitched / "block_0.optimizer_state.pth").exists() - assert not (stitched / "block_0.grad_scaler.pth").exists() - - -def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): - descriptors = OrderedDict( - [("block_0", _make_descriptor(with_optimizer=False, with_scaler=False))] - ) - bcu_no_dist._save_local_state(descriptors, tmp_path) - stitched = tmp_path / "stitched" - assert not (stitched / "block_0.optimizer_state.pth").exists() - assert not (stitched / "block_0.grad_scaler.pth").exists() +def test_save_local_state_respects_optional_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): + for name, descriptor, expected_files in [ + ("full", _make_descriptor(), {"block_0.optimizer_state.pth", "block_0.grad_scaler.pth"}), + ( + "no_scaler", + _make_descriptor(with_scaler=False), + {"block_0.optimizer_state.pth"}, + ), + ("no_optimizer", _make_descriptor(with_optimizer=False, with_scaler=False), set()), + ]: + checkpoint_dir = tmp_path / name + descriptors = OrderedDict([("block_0", descriptor)]) + bcu_no_dist._save_local_state(descriptors, checkpoint_dir) + stitched = checkpoint_dir / "stitched" + assert {path.name for path in stitched.glob("*")} == expected_files # --------------------------------------------------------------------------- @@ -264,11 +205,14 @@ def patched_save(monkeypatch, bcu_no_dist): return bcu_no_dist -def 
test_save_bypass_checkpoint_creates_latest_symlink_and_marker(tmp_path: Path, patched_save): +def test_save_bypass_checkpoint_updates_latest_symlink_and_marker(tmp_path: Path, patched_save): experiment_dir = tmp_path / "exp" experiment_dir.mkdir() + old_target = experiment_dir / "step-000003-ckpt" + old_target.mkdir() checkpoint_dir = experiment_dir / "step-000007-ckpt" checkpoint_dir.mkdir() + (experiment_dir / "latest").symlink_to(old_target.name) cfg = _make_save_cfg(experiment_dir) patched_save.save_bypass_checkpoint( @@ -288,30 +232,6 @@ def test_save_bypass_checkpoint_creates_latest_symlink_and_marker(tmp_path: Path assert (checkpoint_dir / "saving_completed").exists() -def test_save_bypass_checkpoint_replaces_existing_latest_symlink(tmp_path: Path, patched_save): - """A stale ``latest`` from a prior save must be replaced, not appended to. - Without ``unlink(missing_ok=True)`` the symlink_to() call would raise - FileExistsError mid-save and leave the run unable to checkpoint.""" - experiment_dir = tmp_path / "exp" - experiment_dir.mkdir() - old_target = experiment_dir / "step-000003-ckpt" - old_target.mkdir() - new_target = experiment_dir / "step-000007-ckpt" - new_target.mkdir() - (experiment_dir / "latest").symlink_to(old_target.name) - - cfg = _make_save_cfg(experiment_dir) - patched_save.save_bypass_checkpoint( - cfg=cfg, - descriptor=None, - model=None, - stitched_module_descriptors=OrderedDict(), - checkpoint_dir=new_target, - ) - - assert os.readlink(experiment_dir / "latest") == "step-000007-ckpt" - - def test_save_bypass_checkpoint_best_does_not_replace_latest(tmp_path: Path, patched_save): experiment_dir = tmp_path / "exp" experiment_dir.mkdir() diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py index 17057019154..9cffd53da0c 100644 --- a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py +++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py @@ -15,6 +15,7 @@ 
"""Tests for bypass-distillation dataloader behavior added by this PR.""" +from contextlib import nullcontext from types import SimpleNamespace import pytest @@ -206,73 +207,42 @@ def __iter__(self): yield {"text": []} -def test_constant_length_dataset_no_chat_template_adds_role_tags_and_warns_once(monkeypatch): - monkeypatch.setattr(dataset_module, "_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED", False) - tokenizer = _NoChatTemplateTokenizer() - dataset = ConstantLengthDataset( - tokenizer, - _ConversationDataset(), - infinite=False, - seq_length=2, - num_of_sequences=1, - chars_per_token=100, - content_field="text", - fim_rate=0.0, - fim_spm_rate=0.0, - label_shift=False, - ) - - with pytest.warns(UserWarning, match="no chat_template"): - realized = list(dataset) - - assert tokenizer.seen_texts == ["user: hello\nassistant: world"] - assert len(realized) == 1 - assert torch.equal(realized[0]["input_ids"], torch.tensor([0, 1])) - assert torch.equal(realized[0]["targets"], torch.tensor([0, 1])) - - -def test_constant_length_dataset_uses_tokenizer_chat_template_when_available(monkeypatch): - monkeypatch.setattr(dataset_module, "_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED", False) - tokenizer = _ChatTemplateTokenizer() - dataset = ConstantLengthDataset( - tokenizer, - _ConversationDataset(), - infinite=False, - seq_length=2, - num_of_sequences=1, - chars_per_token=100, - content_field="text", - fim_rate=0.0, - fim_spm_rate=0.0, - label_shift=False, - ) - - realized = list(dataset) - - assert tokenizer.template_messages == [ +def test_constant_length_dataset_formats_conversation_messages(monkeypatch): + expected_messages = [ {"role": "user", "content": {"text": "hello"}}, {"role": "assistant", "content": "world"}, ] - assert tokenizer.seen_texts == ["templated chat"] - assert len(realized) == 1 - - -def test_constant_length_dataset_handles_empty_message_list(): - tokenizer = _NoChatTemplateTokenizer() - dataset = ConstantLengthDataset( - tokenizer, - _EmptyConversationDataset(), 
- infinite=False, - seq_length=2, - num_of_sequences=1, - chars_per_token=100, - content_field="text", - fim_rate=0.0, - fim_spm_rate=0.0, - label_shift=False, - ) + for tokenizer, raw_dataset, expected_texts, warning_context in [ + ( + _NoChatTemplateTokenizer(), + _ConversationDataset(), + ["user: hello\nassistant: world"], + pytest.warns(UserWarning, match="no chat_template"), + ), + (_ChatTemplateTokenizer(), _ConversationDataset(), ["templated chat"], nullcontext()), + (_NoChatTemplateTokenizer(), _EmptyConversationDataset(), [""], nullcontext()), + ]: + monkeypatch.setattr(dataset_module, "_CHAT_TEMPLATE_FALLBACK_WARNING_EMITTED", False) + dataset = ConstantLengthDataset( + tokenizer, + raw_dataset, + infinite=False, + seq_length=2, + num_of_sequences=1, + chars_per_token=100, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + label_shift=False, + ) - realized = list(dataset) + with warning_context: + realized = list(dataset) - assert tokenizer.seen_texts == [""] - assert len(realized) == 1 + assert tokenizer.seen_texts == expected_texts + assert len(realized) == 1 + if isinstance(tokenizer, _ChatTemplateTokenizer): + assert tokenizer.template_messages == expected_messages + else: + assert torch.equal(realized[0]["input_ids"], torch.tensor([0, 1])) + assert torch.equal(realized[0]["targets"], torch.tensor([0, 1])) diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py index 3de0a6ab050..f9d124618c2 100644 --- a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -106,69 +106,28 @@ def _trainable_names(model: nn.Module) -> set[str]: # --------------------------------------------------------------------------- -# Single-string subblock keys (dense model) +# Dense-model key semantics # --------------------------------------------------------------------------- -def test_subblock_ffn_trains_only_mlp(): - model 
= _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - _set_keys_to_learn(model, descriptor, "subblock_ffn") - trainable = _trainable_names(model) - assert all(".mlp." in n for n in trainable), trainable - assert not any(".self_attn." in n for n in trainable), trainable - # Both layers' mlp params must be trainable, not just one. - assert any("model.layers.0.mlp." in n for n in trainable) - assert any("model.layers.1.mlp." in n for n in trainable) - - -def test_subblock_attention_trains_only_self_attn(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - _set_keys_to_learn(model, descriptor, "subblock_attention") - trainable = _trainable_names(model) - assert all(".self_attn." in n for n in trainable), trainable - assert not any(".mlp." in n for n in trainable), trainable - - -def test_subblock_mamba_without_block_configs_is_rejected(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - with pytest.raises(ValueError, match="subblock_mamba.*block_configs"): - _set_keys_to_learn(model, descriptor, "subblock_mamba") - - -def test_entire_block_trains_attention_and_mlp(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - _set_keys_to_learn(model, descriptor, "entire_block") - trainable = _trainable_names(model) - # Both groups present. - assert any(".self_attn." in n for n in trainable), trainable - assert any(".mlp." in n for n in trainable), trainable - # Equal to the union of every model parameter. 
- assert trainable == {n for n, _ in model.named_parameters()} +def test_dense_subblock_keys_select_expected_parameters(): + for keys_to_learn, include_fragments, exclude_fragments, trains_every_param in [ + ("subblock_ffn", [".mlp."], [".self_attn."], False), + ("subblock_attention", [".self_attn."], [".mlp."], False), + ("entire_block", [".self_attn.", ".mlp."], [], True), + (["subblock_attention", "subblock_ffn"], [".self_attn.", ".mlp."], [], True), + ]: + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, keys_to_learn) + trainable = _trainable_names(model) - -def test_subblock_key_list_trains_union_of_subblocks(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - _set_keys_to_learn(model, descriptor, ["subblock_attention", "subblock_ffn"]) - trainable = _trainable_names(model) - assert any(".self_attn." in n for n in trainable), trainable - assert any(".mlp." in n for n in trainable), trainable - assert trainable == {n for n, _ in model.named_parameters()} - - -def test_mixed_subblock_and_exact_name_list_is_rejected(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - with pytest.raises(ValueError, match="supports only subblock keys"): - _set_keys_to_learn( - model, - descriptor, - ["subblock_attention", "model.layers.0.self_attn.q_proj.weight"], - ) + for fragment in include_fragments: + assert any(fragment in n for n in trainable), (keys_to_learn, trainable) + for fragment in exclude_fragments: + assert not any(fragment in n for n in trainable), (keys_to_learn, trainable) + if trains_every_param: + assert trainable == {n for n, _ in model.named_parameters()} # --------------------------------------------------------------------------- @@ -186,27 +145,21 @@ def _hybrid_block_configs(): ] -def test_subblock_mamba_on_hybrid_trains_only_mamba_block(): - model = _make_dense_model(num_layers=2) - descriptor = 
_make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) - _set_keys_to_learn(model, descriptor, "subblock_mamba") - trainable = _trainable_names(model) - # Block 0 (Mamba) attention-group params should be trainable; block 1 (GQA) must not. - assert any("model.layers.0.self_attn." in n for n in trainable), trainable - assert not any("model.layers.1.self_attn." in n for n in trainable), trainable - # FFN params are never trainable under subblock_mamba. - assert not any(".mlp." in n for n in trainable), trainable - +def test_hybrid_subblock_keys_partition_attention_by_block_type(): + for keys_to_learn, included_block, excluded_block in [ + ("subblock_mamba", 0, 1), + ("subblock_attention", 1, 0), + ]: + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, keys_to_learn) + trainable = _trainable_names(model) -def test_subblock_attention_on_hybrid_trains_only_gqa_block(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) - _set_keys_to_learn(model, descriptor, "subblock_attention") - trainable = _trainable_names(model) - # Block 1 (GQA) attention-group params are trainable; block 0 (Mamba) must not. - assert any("model.layers.1.self_attn." in n for n in trainable), trainable - assert not any("model.layers.0.self_attn." in n for n in trainable), trainable - assert not any(".mlp." in n for n in trainable), trainable + assert any(f"model.layers.{included_block}.self_attn." in n for n in trainable), trainable + assert not any(f"model.layers.{excluded_block}.self_attn." in n for n in trainable), ( + trainable + ) + assert not any(".mlp." 
in n for n in trainable), trainable # --------------------------------------------------------------------------- @@ -214,26 +167,19 @@ def test_subblock_attention_on_hybrid_trains_only_gqa_block(): # --------------------------------------------------------------------------- -def test_explicit_param_name_list_is_rejected(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) +def test_unsupported_keys_to_learn_are_rejected(): target = "model.layers.0.self_attn.q_proj.weight" - with pytest.raises(ValueError, match="subblock keys"): - _set_keys_to_learn(model, descriptor, [target]) - - -def test_regex_string_is_rejected(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - with pytest.raises(ValueError, match="keys_to_learn must be one of"): - _set_keys_to_learn(model, descriptor, r"q_proj") - - -def test_empty_key_list_is_rejected(): - model = _make_dense_model(num_layers=2) - descriptor = _make_descriptor(num_layers=2) - with pytest.raises(ValueError, match="cannot be empty"): - _set_keys_to_learn(model, descriptor, []) + for keys_to_learn, match in [ + ("subblock_mamba", "subblock_mamba.*block_configs"), + (["subblock_attention", target], "supports only subblock keys"), + ([target], "subblock keys"), + (r"q_proj", "keys_to_learn must be one of"), + ([], "cannot be empty"), + ]: + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + with pytest.raises(ValueError, match=match): + _set_keys_to_learn(model, descriptor, keys_to_learn) # --------------------------------------------------------------------------- @@ -241,11 +187,7 @@ def test_empty_key_list_is_rejected(): # --------------------------------------------------------------------------- -@pytest.mark.parametrize( - "keys_to_learn", - ["subblock_attention", "entire_block"], -) -def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): +def 
test_subblock_keys_skip_non_floating_point_params(): """Integer / non-floating buffers exposed as parameters must stay frozen. The function explicitly guards on ``torch.is_floating_point(param)``; this @@ -258,6 +200,6 @@ def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): model.model.layers[0].self_attn.register_parameter("int_counter", int_param) descriptor = _make_descriptor(num_layers=2) # Should not raise even though the int param's name matches the attention group. - _set_keys_to_learn(model, descriptor, keys_to_learn) + _set_keys_to_learn(model, descriptor, "subblock_attention") # The int counter must remain frozen regardless. assert not model.model.layers[0].self_attn.int_counter.requires_grad diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py index 85c9490abf2..8e4b3f77bf3 100644 --- a/tests/unit/torch/puzzletron/test_bypass_losses.py +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -45,7 +45,7 @@ def test_batched_normalized_mse_loss_matches_manual_relative_l2(): torch.testing.assert_close(loss, expected) -def test_batched_normalized_mse_loss_zero_target_is_finite(): +def test_batched_normalized_mse_loss_handles_zero_targets(): """All-zero target slice must not produce NaN/Inf. With the relative-L2 formula ``sum((x-t)^2) / (sum(t^2) + eps)``, an all-zero @@ -54,67 +54,43 @@ def test_batched_normalized_mse_loss_zero_target_is_finite(): by construction (that's what zero-magnitude targets mean), but the test pins the property we actually care about: finiteness, not magnitude. 
""" - input_ = torch.full((1, 8), 1.0) - target = torch.zeros(1, 8) - loss = batched_normalized_mse_loss(input_, target) + loss = batched_normalized_mse_loss(torch.full((1, 8), 1.0), torch.zeros(1, 8)) assert torch.isfinite(loss) assert not torch.isnan(loss) + zero_loss = batched_normalized_mse_loss(torch.zeros(2, 4), torch.zeros(2, 4)) + torch.testing.assert_close(zero_loss, torch.tensor(0.0)) -def test_batched_normalized_mse_loss_zero_input_and_target(): - """Both zero should give exactly 0.0 — numerator is zero, denominator is eps.""" - input_ = torch.zeros(2, 4) - target = torch.zeros(2, 4) - loss = batched_normalized_mse_loss(input_, target) - assert loss.item() == 0.0 - -def test_batched_normalized_mse_loss_rejects_shape_mismatch(): - input_ = torch.randn(2, 3) - target = torch.randn(2, 1) - - with pytest.raises(ValueError, match="input and target shapes must match exactly"): - batched_normalized_mse_loss(input_, target) - - -def test_batched_normalized_mse_loss_rejects_invalid_batch_dim(): +def test_batched_normalized_mse_loss_rejects_invalid_inputs(): input_ = torch.randn(2, 3) target = torch.randn(2, 3) - with pytest.raises(ValueError, match="batch_dims contains invalid dimension"): - batched_normalized_mse_loss(input_, target, batch_dims=(2,)) - + for args, kwargs, match in [ + ((input_, torch.randn(2, 1)), {}, "input and target shapes must match exactly"), + ((input_, target), {"batch_dims": (2,)}, "batch_dims contains invalid dimension"), + ((input_, target), {"epsilon": 0.0}, "epsilon must be strictly positive"), + ]: + with pytest.raises(ValueError, match=match): + batched_normalized_mse_loss(*args, **kwargs) -def test_batched_normalized_mse_loss_rejects_invalid_options(): - input_ = torch.randn(2, 3) - target = torch.randn(2, 3) - with pytest.raises(ValueError, match="epsilon must be strictly positive"): - batched_normalized_mse_loss(input_, target, epsilon=0.0) - - -def test_format_stitched_losses_keeps_trainable_nan_visible(): - out = 
format_stitched_losses( +def test_format_stitched_losses_reports_expected_summary_states(): + non_finite_out = format_stitched_losses( {"block_0": float("nan"), "block_1": 1.0}, initial_values_dict={"block_0": 0.5, "block_1": 2.0}, not_trainable_names={"block_2"}, step_number=3, ) + assert "nan" in non_finite_out + assert "non-finite" in non_finite_out + assert "Skipped=1" in non_finite_out + assert "No trainable blocks found" not in non_finite_out - assert "nan" in out - assert "non-finite" in out - assert "Skipped=1" in out - assert "No trainable blocks found" not in out - - -def test_format_stitched_losses_empty_trainable_reports_skipped_blocks(): - out = format_stitched_losses({}, not_trainable_names={"block_0", "block_1"}) + empty_out = format_stitched_losses({}, not_trainable_names={"block_0", "block_1"}) + assert empty_out == "No trainable losses found; skipped 2 non-trainable blocks" - assert out == "No trainable losses found; skipped 2 non-trainable blocks" - - -def test_format_stitched_losses_reports_delta_from_initial_and_filters_stale_history(): - out = format_stitched_losses( + delta_out = format_stitched_losses( {"block_0": 1.0, "block_1": 3.0}, best_steps_dict={"block_0": 5, "block_9": 99}, best_values_dict={"block_0": 0.5, "block_9": 9.0}, @@ -122,10 +98,9 @@ def test_format_stitched_losses_reports_delta_from_initial_and_filters_stale_his not_trainable_names={"block_2"}, step_number=8, ) - - assert "↓ -1.0e+00 (-50%)" in out - assert "↔ 0.0e+00" in out - assert "Step 5" in out - assert "Step 99" not in out - assert "Skipped=1" in out - assert "Avg=2.00e+00" in out + assert "↓ -1.0e+00 (-50%)" in delta_out + assert "↔ 0.0e+00" in delta_out + assert "Step 5" in delta_out + assert "Step 99" not in delta_out + assert "Skipped=1" in delta_out + assert "Avg=2.00e+00" in delta_out diff --git a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py index 101657f7d13..0bdd5962ba9 100644 --- 
a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py +++ b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py @@ -60,65 +60,37 @@ def _make_cfg( ) -@pytest.mark.parametrize( - ("warmup_steps", "lr_decay_steps", "learning_rate"), - [ - (10, 10, 0.5), - (20, 10, 0.7), - ], -) -def test_degenerate_budget_returns_base_lr(warmup_steps, lr_decay_steps, learning_rate): +def test_degenerate_budget_returns_base_lr(): """When ``lr_decay_steps <= warmup_steps`` (tiny test budgets), the scheduler must short-circuit to ``learning_rate`` rather than divide by zero.""" - cfg = _make_cfg( - warmup_steps=warmup_steps, - lr_decay_steps=lr_decay_steps, - learning_rate=learning_rate, - ) - assert _get_lr(cfg, step=0) == learning_rate - assert _get_lr(cfg, step=99) == learning_rate + for warmup_steps, lr_decay_steps, learning_rate in [(10, 10, 0.5), (20, 10, 0.7)]: + cfg = _make_cfg( + warmup_steps=warmup_steps, + lr_decay_steps=lr_decay_steps, + learning_rate=learning_rate, + ) + assert _get_lr(cfg, step=0) == learning_rate + assert _get_lr(cfg, step=99) == learning_rate -def test_warmup_linear_ramp(): +def test_lr_schedule_matches_key_points(): cfg = _make_cfg(warmup_steps=10, lr_decay_steps=100, learning_rate=1.0) - assert _get_lr(cfg, step=0) == pytest.approx(0.0) - assert _get_lr(cfg, step=5) == pytest.approx(0.5) - assert _get_lr(cfg, step=10) == pytest.approx(1.0) - + for step, expected, name in [ + (0, 0.0, "warmup start"), + (5, 0.5, "warmup midpoint"), + (10, 1.0, "warmup end"), + ]: + assert _get_lr(cfg, step=step) == pytest.approx(expected), name -def test_cosine_starts_decaying_immediately_after_warmup(): - """At ``step == warmup_steps + 1`` the cosine branch is entered with - ``decay_ratio = 1/(D-W)`` — already a small step below base LR, not a - duplicate plateau at base LR. 
This is the boundary the previous formula - got wrong (it used ``step - W - 1`` and gave ``decay_ratio == 0`` here).""" cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) - # decay_ratio = (11 - 10) / 10 = 0.1 - expected = 0.5 * (1.0 + math.cos(math.pi * 0.1)) - assert _get_lr(cfg, step=11) == pytest.approx(expected) - # Strictly below base LR — the cosine has begun. + cosine_start = 0.5 * (1.0 + math.cos(math.pi * 0.1)) + cosine_midpoint = 0.5 * (1.0 + math.cos(math.pi * 0.5)) + for step, expected, name in [ + (11, cosine_start, "cosine starts immediately after warmup"), + (15, cosine_midpoint, "cosine midpoint"), + (20, 0.0, "cosine endpoint"), + (21, 0.0, "post-decay clamp"), + (1000, 0.0, "long post-decay clamp"), + ]: + assert _get_lr(cfg, step=step) == pytest.approx(expected), name assert _get_lr(cfg, step=11) < 1.0 - - -def test_cosine_endpoint_returns_min_lr(): - """At ``step == lr_decay_steps`` the cosine branch reaches its endpoint: - ``decay_ratio == 1`` → ``coeff == 0`` → returns ``min_lr`` exactly. The - post-decay clamp at ``step == lr_decay_steps + 1`` is then a no-op - continuation, not a correction for an off-by-one.""" - cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) - assert _get_lr(cfg, step=20) == pytest.approx(0.1) - - -def test_cosine_midpoint_is_halfway(): - """At the cosine midpoint, ``coeff == 0.5`` → returns ``(lr + min_lr) / 2``.""" - cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) - # Midpoint of the post-warmup window: step such that decay_ratio == 0.5. - # decay_ratio = (step - 10) / (20 - 10) → step = 15 gives ratio 0.5. 
- expected_coeff = 0.5 * (1.0 + math.cos(math.pi * 0.5)) - assert _get_lr(cfg, step=15) == pytest.approx(expected_coeff) - - -def test_post_decay_clamps_to_min_lr(): - """``step > lr_decay_steps`` always returns ``min_lr`` exactly.""" - cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) - assert _get_lr(cfg, step=21) == 0.1 - assert _get_lr(cfg, step=1000) == 0.1 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py index 90799d7e33f..4cca9ff5499 100644 --- a/tests/unit/torch/puzzletron/test_bypass_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -26,47 +26,54 @@ ) -@pytest.mark.parametrize( - ("module_count", "world_size", "expected_ownership"), - [ +def test_distributed_modules_ownership(): + for module_count, world_size, expected_ownership in [ (4, 1, [0, 0, 0, 0]), (4, 2, [0, 0, 1, 1]), (3, 2, [0, 0, 1]), (7, 3, [0, 0, 0, 1, 1, 2, 2]), (1, 2, [0]), - ], -) -def test_distributed_modules_ownership(module_count, world_size, expected_ownership): - assert ( - get_distributed_modules_ownership(module_count=module_count, world_size=world_size) - == expected_ownership - ) + ]: + assert ( + get_distributed_modules_ownership(module_count=module_count, world_size=world_size) + == expected_ownership + ) -def test_pipeline_ownership_context_returns_neighbors(): +def test_pipeline_ownership_context_returns_neighbors_and_rejects_idle_rank(): ownership = [0, 0, 1, 1, 2] - assert get_pipeline_ownership_context(ownership, rank=0) == { - "owned_indices": [0, 1], - "owned_index_set": {0, 1}, - "prev_rank": None, - "next_rank": 1, - } - assert get_pipeline_ownership_context(ownership, rank=1) == { - "owned_indices": [2, 3], - "owned_index_set": {2, 3}, - "prev_rank": 0, - "next_rank": 2, - } - assert get_pipeline_ownership_context(ownership, rank=2) == { - "owned_indices": [4], - "owned_index_set": {4}, - "prev_rank": 1, - "next_rank": None, - } - - -def 
test_pipeline_ownership_context_rejects_idle_rank(): + for rank, expected_context in [ + ( + 0, + { + "owned_indices": [0, 1], + "owned_index_set": {0, 1}, + "prev_rank": None, + "next_rank": 1, + }, + ), + ( + 1, + { + "owned_indices": [2, 3], + "owned_index_set": {2, 3}, + "prev_rank": 0, + "next_rank": 2, + }, + ), + ( + 2, + { + "owned_indices": [4], + "owned_index_set": {4}, + "prev_rank": 1, + "next_rank": None, + }, + ), + ]: + assert get_pipeline_ownership_context(ownership, rank=rank) == expected_context + with pytest.raises(RuntimeError, match="owns no modules"): get_pipeline_ownership_context([0, 0, 1], rank=2) @@ -155,49 +162,37 @@ def test_experiment_id_falls_back_when_no_architecture_parts_exist(): assert cfg.bypass.experiment_id != "bypass_None" -def test_config_fingerprint_changes_with_dataset_path(): - cfg = _experiment_cfg("subblock_attention") - original = get_bypass_config_fingerprint(cfg) - cfg.dataset_path = "/tmp/dataset_b" - assert get_bypass_config_fingerprint(cfg) != original - - -def test_config_fingerprint_changes_with_shuffle_seed(): - cfg = _experiment_cfg("subblock_attention") - original = get_bypass_config_fingerprint(cfg) - cfg.bypass.data.shuffle_train_data_seed = 456 - assert get_bypass_config_fingerprint(cfg) != original - - -def test_config_fingerprint_changes_with_teacher_dir(): - cfg = _experiment_cfg("subblock_attention") - original = get_bypass_config_fingerprint(cfg) - cfg.teacher_dir = "/tmp/teacher_b" - assert get_bypass_config_fingerprint(cfg) != original - - -def test_config_fingerprint_changes_with_descriptor(): - cfg = _experiment_cfg("subblock_attention") - original = get_bypass_config_fingerprint(cfg) - cfg.descriptor = "other_descriptor" - assert get_bypass_config_fingerprint(cfg) != original - - -def test_config_fingerprint_canonicalizes_single_keys_to_learn(): - cfg_a = _experiment_cfg("entire_block") - cfg_b = _experiment_cfg(["entire_block"]) - - assert get_bypass_config_fingerprint(cfg_a) == 
get_bypass_config_fingerprint(cfg_b) - - -def test_config_fingerprint_canonicalizes_keys_to_learn_order(): - cfg_a = _experiment_cfg(["subblock_ffn", "subblock_attention"]) - cfg_b = _experiment_cfg(["subblock_attention", "subblock_ffn"]) - - assert get_bypass_config_fingerprint(cfg_a) == get_bypass_config_fingerprint(cfg_b) - - -def test_experiment_id_does_not_change_with_dataset_path(): +def test_config_fingerprint_changes_with_identity_inputs(): + for name, mutate_cfg in [ + ("dataset path", lambda cfg: setattr(cfg, "dataset_path", "/tmp/dataset_b")), + ( + "shuffle seed", + lambda cfg: setattr(cfg.bypass.data, "shuffle_train_data_seed", 456), + ), + ("teacher dir", lambda cfg: setattr(cfg, "teacher_dir", "/tmp/teacher_b")), + ("descriptor", lambda cfg: setattr(cfg, "descriptor", "other_descriptor")), + ]: + cfg = _experiment_cfg("subblock_attention") + original = get_bypass_config_fingerprint(cfg) + mutate_cfg(cfg) + assert get_bypass_config_fingerprint(cfg) != original, name + + +def test_config_fingerprint_and_experiment_id_canonicalize_keys_to_learn(): + for keys_a, keys_b in [ + ("entire_block", ["entire_block"]), + (["subblock_ffn", "subblock_attention"], ["subblock_attention", "subblock_ffn"]), + ]: + cfg_a = _experiment_cfg(keys_a) + cfg_b = _experiment_cfg(keys_b) + assert get_bypass_config_fingerprint(cfg_a) == get_bypass_config_fingerprint(cfg_b) + + set_experiment_id(cfg_a) + set_experiment_id(cfg_b) + assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id + + +def test_experiment_id_uses_teacher_source_not_dataset_path(): cfg_a = _experiment_cfg("subblock_attention") cfg_b = _experiment_cfg("subblock_attention") cfg_b.dataset_path = "/tmp/dataset_b" @@ -205,19 +200,7 @@ def test_experiment_id_does_not_change_with_dataset_path(): set_experiment_id(cfg_b) assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id - -def test_experiment_id_changes_with_teacher_source(): - cfg_a = _experiment_cfg("subblock_attention") - cfg_b = 
_experiment_cfg("subblock_attention") - cfg_b.teacher_dir = "/tmp/teacher_b" - set_experiment_id(cfg_a) - set_experiment_id(cfg_b) - assert cfg_a.bypass.experiment_id != cfg_b.bypass.experiment_id - - -def test_experiment_id_canonicalizes_keys_to_learn_order(): - cfg_a = _experiment_cfg(["subblock_ffn", "subblock_attention"]) - cfg_b = _experiment_cfg(["subblock_attention", "subblock_ffn"]) - set_experiment_id(cfg_a) - set_experiment_id(cfg_b) - assert cfg_a.bypass.experiment_id == cfg_b.bypass.experiment_id + cfg_c = _experiment_cfg("subblock_attention") + cfg_c.teacher_dir = "/tmp/teacher_b" + set_experiment_id(cfg_c) + assert cfg_a.bypass.experiment_id != cfg_c.bypass.experiment_id diff --git a/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py index e702256d606..2a3712901a1 100644 --- a/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py +++ b/tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py @@ -49,7 +49,7 @@ def output_embedding_name(): assert calls["num_hidden_layers"] == 7 -def test_copy_auto_map_code_files_ignores_non_string_entries(tmp_path, monkeypatch): +def test_copy_auto_map_code_files_copies_valid_local_code_references(tmp_path, monkeypatch): source_dir = tmp_path / "source" checkpoint_dir = tmp_path / "checkpoint" source_dir.mkdir() @@ -62,7 +62,7 @@ def test_copy_auto_map_code_files_ignores_non_string_entries(tmp_path, monkeypat cfg = SimpleNamespace( auto_map={ "AutoConfig": "configuration_custom.CustomConfig", - "AutoModelForCausalLM": "modeling_custom.CustomModel", + "AutoModelForCausalLM": "org/repo--modeling_custom.CustomModel", "AutoTokenizer": [None, "tokenization_custom.CustomTokenizer"], } ) @@ -71,21 +71,3 @@ def test_copy_auto_map_code_files_ignores_non_string_entries(tmp_path, monkeypat assert (checkpoint_dir / "modeling_custom.py").exists() assert (checkpoint_dir / "tokenization_custom.py").exists() - - -def 
test_copy_auto_map_code_files_strips_repo_id_prefix(tmp_path, monkeypatch): - source_dir = tmp_path / "source" - checkpoint_dir = tmp_path / "checkpoint" - source_dir.mkdir() - checkpoint_dir.mkdir() - (source_dir / "modeling_custom.py").write_text("# modeling\n") - - monkeypatch.setattr(cuhf.inspect, "getfile", lambda _cls: source_dir / "configuration.py") - - cfg = SimpleNamespace( - auto_map={"AutoModelForCausalLM": "org/repo--modeling_custom.CustomModel"} - ) - - cuhf._copy_auto_map_code_files(cfg, checkpoint_dir) - - assert (checkpoint_dir / "modeling_custom.py").exists() diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py index f655b0da926..28d597296fa 100644 --- a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -72,56 +72,41 @@ def _recorder(cfg): return snapshots -def test_no_configs_key_runs_once(monkeypatch, tmp_path): - """Absent ``bypass.configs`` is the single-config path — one call, no resets.""" - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg(tmp_path, configs=None) - tl.launch_bypass_distillation(cfg) - assert len(snapshots) == 1 - # Single-config path doesn't touch the state machine — values remain as supplied. 
- assert snapshots[0]["bypass"]["iter_num"] == 999 - assert snapshots[0]["bypass"]["training"]["clipping_count"] == 42 - - -def test_empty_configs_list_runs_once(monkeypatch, tmp_path): - """``configs: []`` must hit the same branch as missing — the truthiness - check on ``bypass.configs`` treats both as 'no sweep'.""" - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg(tmp_path, configs=[]) - tl.launch_bypass_distillation(cfg) - assert len(snapshots) == 1 - - -def test_two_configs_run_twice_with_distinct_overrides(monkeypatch, tmp_path): +def test_single_config_modes_run_once_without_reset(monkeypatch, tmp_path): + """Absent and empty ``bypass.configs`` both use the single-config path.""" + for configs in (None, []): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(tmp_path, configs=configs) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + # Single-config path doesn't touch the state machine. + assert snapshots[0]["bypass"]["iter_num"] == 999 + assert snapshots[0]["bypass"]["training"]["clipping_count"] == 42 + + +def test_sweep_configs_apply_overrides_reset_state_and_restore_base_keys(monkeypatch, tmp_path): + """Each sweep entry gets its override, reset counters, and base keys fallback.""" snapshots = _record_calls(monkeypatch) cfg = _base_cfg( tmp_path, configs=[ - {"model_config_overrides": {"intermediate_size": 256}}, + { + "model_config_overrides": {"intermediate_size": 256}, + "keys_to_learn": "subblock_attention", + }, {"model_config_overrides": {"intermediate_size": 128}}, ], ) tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 2 assert snapshots[0]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 256} + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" assert snapshots[1]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 128} + assert snapshots[1]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" - -def 
test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): - """Every sweep entry must see iter_num=1, step_num=1, token_count=0, - best_val_loss=1e9, clipping_count=0, and a fresh experiment_id even when the - previous entry left the cfg in some other state.""" - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg( - tmp_path, - configs=[ - {"model_config_overrides": {"intermediate_size": 256}}, - {"model_config_overrides": {"intermediate_size": 128}}, - ], - ) - tl.launch_bypass_distillation(cfg) - for snap in snapshots: - assert snap["bypass"]["experiment_id"].startswith("bypass_ffn_") + for snap, expected_prefix in zip(snapshots, ["bypass_attention_", "bypass_ffn_"], strict=True): + assert snap["bypass"]["experiment_id"].startswith(expected_prefix) assert snap["bypass"]["iter_num"] == 1 assert snap["bypass"]["step_num"] == 1 assert snap["bypass"]["token_count"] == 0 @@ -129,21 +114,7 @@ def test_per_run_state_reset_before_each_call(monkeypatch, tmp_path): assert snap["bypass"]["training"]["clipping_count"] == 0 -def test_sweep_entry_without_keys_to_learn_uses_base_not_previous_override(monkeypatch, tmp_path): - snapshots = _record_calls(monkeypatch) - cfg = _base_cfg( - tmp_path, - configs=[ - {"keys_to_learn": "subblock_attention"}, - {"model_config_overrides": {"intermediate_size": 256}}, - ], - ) - tl.launch_bypass_distillation(cfg) - assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" - assert snapshots[1]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" - - -def test_trust_remote_code_defaults_to_false_even_when_descriptor_requires_it(monkeypatch): +def test_resolve_trust_remote_code_requires_explicit_cfg_opt_in(monkeypatch): class DescriptorRequiringTrust: @staticmethod def requires_trust_remote_code(): @@ -158,27 +129,18 @@ def capture_message(*args): assert tl._resolve_trust_remote_code(OmegaConf.create({}), DescriptorRequiringTrust) is False assert any("trust_remote_code" in message for 
message in messages) + messages.clear() - -def test_trust_remote_code_uses_explicit_cfg_opt_in(monkeypatch): - class DescriptorRequiringTrust: - @staticmethod - def requires_trust_remote_code(): - return True - - messages = [] - - def capture_message(*args): - messages.append(" ".join(map(str, args))) - - monkeypatch.setattr(tl, "mprint", capture_message) - - cfg = OmegaConf.create({"trust_remote_code": True}) - assert tl._resolve_trust_remote_code(cfg, DescriptorRequiringTrust) is True + assert ( + tl._resolve_trust_remote_code( + OmegaConf.create({"trust_remote_code": True}), DescriptorRequiringTrust + ) + is True + ) assert messages == [] -def test_resume_state_ignored_when_init_checkpoint_path_wins(monkeypatch): +def test_resume_state_path_prefers_explicit_init_checkpoint(monkeypatch): messages = [] def capture_message(*args): @@ -190,19 +152,10 @@ def capture_message(*args): assert tl._get_resume_state_path(cfg, "/tmp/resume-ckpt") is None assert any("init_checkpoint_path" in message for message in messages) - -def test_resume_state_used_when_no_init_checkpoint_path(): - cfg = OmegaConf.create({"bypass": {"init_checkpoint_path": None}}) - + cfg.bypass.init_checkpoint_path = None assert tl._get_resume_state_path(cfg, "/tmp/resume-ckpt") == "/tmp/resume-ckpt" -def test_resume_skip_first_batches_uses_completed_iter_count(): - assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=0) == 10 - assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=1) == 11 - assert tl._get_resume_skip_first_batches(saved_skip=10, resume_iter_num=7) == 17 - - def test_flush_loss_buffer_single_rank_without_process_group(): local_buffer = {1: {"block_0": 0.25}} stitched_losses_history = {} @@ -212,73 +165,72 @@ def test_flush_loss_buffer_single_rank_without_process_group(): assert stitched_losses_history == local_buffer -def test_run_bypassed_training_broadcasts_completion_skip(monkeypatch, tmp_path): - cfg = _base_cfg(tmp_path) - 
cfg.bypass.experiment_id = None - checks = [] - broadcasts = [] - messages = [] - - def fail(*args, **kwargs): - raise AssertionError("training setup should not run after completed bypass check") - - monkeypatch.setattr(tl.dist, "local_rank", lambda: 0) - monkeypatch.setattr(tl.dist, "barrier", lambda: None) - monkeypatch.setattr(tl.dist, "is_master", lambda: True) - monkeypatch.setattr( - tl.dist, "broadcast", lambda value, src: broadcasts.append((value, src)) or value - ) - monkeypatch.setattr( - tl, "bypass_run_is_complete", lambda cfg_arg: checks.append(cfg_arg) or True - ) - monkeypatch.setattr(tl, "print_rank_0", lambda *args, **kwargs: messages.append(args[0])) - monkeypatch.setattr(tl.ModelDescriptorFactory, "get", fail) - - tl.run_bypassed_training(cfg) - - assert checks == [cfg] - assert broadcasts == [(True, 0)] - assert messages == [f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping"] - - -def test_run_bypassed_training_non_master_uses_broadcasted_completion(monkeypatch, tmp_path): - cfg = _base_cfg(tmp_path) - cfg.bypass.experiment_id = None - - def fail(*args, **kwargs): - raise AssertionError("non-master should not evaluate completion or continue setup") - - monkeypatch.setattr(tl.dist, "local_rank", lambda: 0) - monkeypatch.setattr(tl.dist, "barrier", lambda: None) - monkeypatch.setattr(tl.dist, "is_master", lambda: False) - monkeypatch.setattr(tl.dist, "broadcast", lambda value, src: True) - monkeypatch.setattr(tl, "bypass_run_is_complete", fail) - monkeypatch.setattr(tl.ModelDescriptorFactory, "get", fail) - - tl.run_bypassed_training(cfg) - - -def test_clip_stitched_module_grads_norm_counts_clipped_block(): - module = torch.nn.Linear(2, 1, bias=False) - module.weight.grad = torch.full_like(module.weight, 10.0) - - assert tl._clip_stitched_module_grads(module, grad_clip=0.1, grad_clip_type="norm") == 1 - assert torch.linalg.vector_norm(module.weight.grad) <= 0.1 + 1e-6 - +def 
test_run_bypassed_training_skips_completed_runs_on_all_ranks(monkeypatch, tmp_path): + for is_master in (True, False): + cfg = _base_cfg(tmp_path) + cfg.bypass.experiment_id = None + checks = [] + broadcasts = [] + messages = [] -def test_clip_stitched_module_grads_value_counts_clipped_block(): - module = torch.nn.Linear(2, 1, bias=False) - module.weight.grad = torch.tensor([[0.05, 2.0]]) - - assert tl._clip_stitched_module_grads(module, grad_clip=0.5, grad_clip_type="value") == 1 - assert module.weight.grad.abs().max() <= 0.5 + def fail(*args, **kwargs): + raise AssertionError("training setup should not run after completed bypass check") + def check_complete(cfg_arg): + checks.append(cfg_arg) + return True -def test_clip_stitched_module_grads_returns_zero_when_below_threshold(): - module = torch.nn.Linear(2, 1, bias=False) - module.weight.grad = torch.full_like(module.weight, 0.01) + def broadcast(value, src): + broadcasts.append((value, src)) + return True - assert tl._clip_stitched_module_grads(module, grad_clip=1.0, grad_clip_type="value") == 0 + monkeypatch.setattr(tl.dist, "local_rank", lambda: 0) + monkeypatch.setattr(tl.dist, "barrier", lambda: None) + monkeypatch.setattr(tl.dist, "is_master", lambda: is_master) + monkeypatch.setattr(tl.dist, "broadcast", broadcast) + monkeypatch.setattr(tl, "bypass_run_is_complete", check_complete) + monkeypatch.setattr(tl, "print_rank_0", lambda *args, **kwargs: messages.append(args[0])) + monkeypatch.setattr(tl.ModelDescriptorFactory, "get", fail) + + tl.run_bypassed_training(cfg) + + if is_master: + assert checks == [cfg] + assert broadcasts == [(True, 0)] + else: + assert checks == [] + assert broadcasts == [(None, 0)] + assert messages == [f"Bypass run {cfg.bypass.experiment_id} is already complete, skipping"] + + +def test_clip_stitched_module_grads_counts_only_clipped_blocks(): + for grad, grad_clip, grad_clip_type, expected_count, validate_grad in [ + ( + torch.full((1, 2), 10.0), + 0.1, + "norm", + 1, + lambda 
module: torch.linalg.vector_norm(module.weight.grad) <= 0.1 + 1e-6, + ), + ( + torch.tensor([[0.05, 2.0]]), + 0.5, + "value", + 1, + lambda module: module.weight.grad.abs().max() <= 0.5, + ), + ( + torch.full((1, 2), 0.01), + 1.0, + "value", + 0, + lambda module: torch.equal(module.weight.grad, torch.full_like(module.weight, 0.01)), + ), + ]: + module = torch.nn.Linear(2, 1, bias=False) + module.weight.grad = grad + assert tl._clip_stitched_module_grads(module, grad_clip, grad_clip_type) == expected_count + assert validate_grad(module) def test_step_stitched_module_optimizer_unscales_before_clipping(monkeypatch): @@ -315,46 +267,31 @@ def capture_clip(stitched_module, grad_clip, grad_clip_type): assert module.weight.grad is None -def test_finalize_bypass_run_skips_realization_when_checkpoint_saving_disabled(monkeypatch): +def test_finalize_bypass_run_marks_completion_only_after_realization(monkeypatch): + monkeypatch.setattr(tl.dist, "is_master", lambda: True) + cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": True}}) def fail(*args, **kwargs): raise AssertionError("checkpoint realization should be skipped") - monkeypatch.setattr(tl.dist, "is_master", lambda: True) monkeypatch.setattr(tl, "realize_bypass_checkpoints", fail) monkeypatch.setattr(tl, "mark_bypass_run_completed", fail) - tl._finalize_bypass_run(cfg) - -def test_finalize_bypass_run_skips_completion_when_no_checkpoint_exists(monkeypatch): + completed = {} cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": False}}) - completed = False - - def missing_checkpoint(_cfg): - raise FileNotFoundError("missing checkpoint") - - def mark_completed(*args, **kwargs): - nonlocal completed - completed = True - - monkeypatch.setattr(tl.dist, "is_master", lambda: True) - monkeypatch.setattr(tl, "realize_bypass_checkpoints", missing_checkpoint) - monkeypatch.setattr(tl, "mark_bypass_run_completed", mark_completed) + monkeypatch.setattr( + tl, "realize_bypass_checkpoints", lambda _cfg: (_ for _ in 
()).throw(FileNotFoundError) + ) + monkeypatch.setattr(tl, "mark_bypass_run_completed", lambda *args: completed.update(hit=True)) tl._finalize_bypass_run(cfg) - assert completed is False + assert completed == {} - -def test_finalize_bypass_run_marks_realized_checkpoint(monkeypatch): - cfg = OmegaConf.create({"bypass": {"disable_checkpoint_save": False}}) realized = Path("/tmp/realized") symlink = Path("/tmp/ckpts/run_0") - completed = {} - - monkeypatch.setattr(tl.dist, "is_master", lambda: True) monkeypatch.setattr(tl, "realize_bypass_checkpoints", lambda _cfg: (realized, symlink)) monkeypatch.setattr( tl, diff --git a/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py index 07f46c0327b..018807ee97b 100644 --- a/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py +++ b/tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py @@ -25,33 +25,32 @@ ) -@pytest.mark.parametrize( - ("keys_to_learn", "expected_subblocks"), - [ - ("entire_block", ["block"]), - ("subblock_ffn", ["ffn"]), - ("subblock_attention", ["attention"]), - ("subblock_mamba", ["attention"]), - (["subblock_attention", "subblock_ffn"], ["attention", "ffn"]), - ], -) -def test_infer_subblocks_to_extract_accepts_bypass_keys( - tmp_path: Path, - keys_to_learn, - expected_subblocks, -): - checkpoint_dir = tmp_path / "checkpoint" - checkpoint_dir.mkdir() - (checkpoint_dir / "bypass_config.json").write_text(json.dumps({"keys_to_learn": keys_to_learn})) - - assert _infer_subblocks_to_extract(checkpoint_dir, []) == expected_subblocks - - -@pytest.mark.parametrize("keys_to_learn", ["mlp", "attn", ["mlp", "attn"]]) -def test_infer_subblocks_to_extract_rejects_legacy_keys(tmp_path: Path, keys_to_learn): - checkpoint_dir = tmp_path / "checkpoint" - checkpoint_dir.mkdir() - (checkpoint_dir / "bypass_config.json").write_text(json.dumps({"keys_to_learn": keys_to_learn})) - - with 
pytest.raises(ValueError, match="keys_to_learn"): - _infer_subblocks_to_extract(checkpoint_dir, []) +def test_infer_subblocks_to_extract_accepts_bypass_keys(tmp_path: Path): + for i, (keys_to_learn, expected_subblocks) in enumerate( + [ + ("entire_block", ["block"]), + ("subblock_ffn", ["ffn"]), + ("subblock_attention", ["attention"]), + ("subblock_mamba", ["attention"]), + (["subblock_attention", "subblock_ffn"], ["attention", "ffn"]), + ] + ): + checkpoint_dir = tmp_path / f"checkpoint_{i}" + checkpoint_dir.mkdir() + (checkpoint_dir / "bypass_config.json").write_text( + json.dumps({"keys_to_learn": keys_to_learn}) + ) + + assert _infer_subblocks_to_extract(checkpoint_dir, []) == expected_subblocks + + +def test_infer_subblocks_to_extract_rejects_legacy_keys(tmp_path: Path): + for i, keys_to_learn in enumerate(["mlp", "attn", ["mlp", "attn"]]): + checkpoint_dir = tmp_path / f"legacy_checkpoint_{i}" + checkpoint_dir.mkdir() + (checkpoint_dir / "bypass_config.json").write_text( + json.dumps({"keys_to_learn": keys_to_learn}) + ) + + with pytest.raises(ValueError, match="keys_to_learn"): + _infer_subblocks_to_extract(checkpoint_dir, []) diff --git a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py index d7ab66e1264..bccb3b8d020 100644 --- a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py +++ b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py @@ -30,33 +30,16 @@ ) -def test_persistent_buffer_excluded_non_persistent_included(): - m = nn.Module() - m.register_buffer("p", torch.zeros(1), persistent=True) - m.register_buffer("np", torch.zeros(1), persistent=False) - out = _get_all_non_persistent_buffers_set(m) - assert out == {"np"} - - -def test_nested_submodule_paths_are_fully_qualified(): - """Sub-module non-persistent buffers must surface as ``submodule_name.buffer_name`` - so the matching key in ``state_dict()`` and the bypass save/restore code 
agree.""" +def test_non_persistent_buffers_are_reported_with_qualified_paths(): + """Only non-persistent buffers should appear, including nested names.""" outer = nn.Module() - inner = nn.Module() - inner.register_buffer("nb", torch.zeros(1), persistent=False) - outer.add_module("inner", inner) - out = _get_all_non_persistent_buffers_set(outer) - assert out == {"inner.nb"} + outer.register_buffer("global_keep", torch.zeros(1), persistent=True) + outer.register_buffer("scratch", torch.zeros(1), persistent=False) - -def test_mix_of_persistent_and_non_persistent_in_nested_module(): - """The full discrimination: only the nested non-persistent buffer should - appear, with its fully-qualified path.""" - outer = nn.Module() inner = nn.Module() - inner.register_buffer("keep", torch.zeros(1), persistent=True) # persistent → excluded + inner.register_buffer("keep", torch.zeros(1), persistent=True) inner.register_buffer("rope_cache", torch.zeros(1), persistent=False) outer.add_module("attn", inner) - outer.register_buffer("global_keep", torch.zeros(1), persistent=True) # → excluded + out = _get_all_non_persistent_buffers_set(outer) - assert out == {"attn.rope_cache"} + assert out == {"scratch", "attn.rope_cache"}