
Ssameni/puzzletron bypass 2 core #1469

Open
Separius wants to merge 9 commits into main from ssameni/puzzletron-bypass-2-core


@Separius (Contributor) commented May 12, 2026

Summary

This is PR 2 of 3 in the Puzzletron bypass/local-distillation stack.

This PR adds the bypass distillation core engine. It builds on PR 1’s shared infrastructure but does not yet wire bypass into the full Puzzletron pipeline.

Stack:

  1. ssameni/puzzletron-bypass-1-prereqs: shared prerequisites
  2. This PR: bypass distillation core
  3. ssameni/puzzletron-bypass-3-integration: Puzzletron integration, configs, docs, GPU coverage

What Changed

  • Added modelopt.torch.puzzletron.bypass_distillation.
  • Added bypass run identity/fingerprinting, experiment naming, state manifests, and completion tracking.
  • Added stitched teacher/student model construction for local blockwise distillation.
  • Added bypass training loop with:
    • pipeline-parallel teacher activation stitching
    • per-block student losses
    • gradient accumulation
    • checkpoint/resume support
    • validation hooks
    • best/latest checkpoint realization
  • Added bypass checkpoint helpers for saving optimizer/scaler state and HF-format model checkpoints.
  • Added scalable distributed checkpoint saving that gathers tensors per safetensors file instead of materializing all shards on rank 0.
  • Restricted v1 keys_to_learn to subblock-level targets only:
    • entire_block
    • subblock_attention
    • subblock_ffn
    • subblock_mamba
    • lists of those keys
  • Added pipeline ownership helper for deriving owned blocks and neighboring PP ranks.
  • Added robust trust_remote_code / auto_map handling for checkpoint config saves.
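Restricting keys_to_learn to subblock-level targets amounts to accepting either a single key or a list of keys and rejecting anything else. The sketch below shows that validation under stated assumptions: normalize_keys_to_learn is a hypothetical helper, and the order-preserving deduplication is an illustrative choice; the PR's own normalization lives in bypass_utils.py and may behave differently.

```python
from collections.abc import Sequence
from typing import Union

# Subblock-level targets accepted in v1, as listed in the PR description.
ALLOWED_KEYS = ("entire_block", "subblock_attention", "subblock_ffn", "subblock_mamba")


def normalize_keys_to_learn(keys: Union[str, Sequence[str]]) -> list[str]:
    """Hypothetical helper: turn one key or a list of keys into a validated list."""
    if isinstance(keys, str):
        keys = [keys]
    normalized: list[str] = []
    for key in keys:
        if key not in ALLOWED_KEYS:
            raise ValueError(f"Unsupported keys_to_learn entry {key!r}; expected one of {ALLOWED_KEYS}")
        if key not in normalized:  # drop duplicates, keep first-seen order
            normalized.append(key)
    if not normalized:
        raise ValueError("keys_to_learn must not be empty")
    return normalized
```

A parameter-level key such as "mlp.down_proj" is rejected with a ValueError, which is the behavior the v1 restriction implies.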

Why

Bypass distillation is the local-distillation stage used to train pruned/reconfigured blocks before they are added to the replacement library.

Keeping the core engine separate from Puzzletron pipeline wiring makes this PR reviewable on its own: reviewers can focus on distributed training, checkpoint/resume semantics, and Sewing Kit stitching behavior without also reviewing configs/docs/pipeline integration.

Tests

Added focused unit coverage for:

  • bypass checkpoint utilities
  • bypass run identity/fingerprinting and completion state
  • keys_to_learn subblock selection
  • LR scheduler behavior
  • launch/sweep dispatch behavior
  • HF checkpoint utility behavior
  • stitched model factory buffer ownership
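Run identity and fingerprinting are easiest to test when the ID is a pure function of the configuration. One common way to get that is to hash a canonical JSON encoding of the config, so dictionary key order cannot change the result. The sketch below assumes that approach; run_fingerprint, experiment_id, the "bypass" prefix, and the 12-character truncation are illustrative choices, not necessarily what bypass_utils.py does.

```python
import hashlib
import json
from typing import Any


def run_fingerprint(config: dict[str, Any]) -> str:
    """Hypothetical helper: SHA-256 of a canonical JSON encoding of the run config."""
    # sort_keys plus fixed separators yield one string per logical config;
    # default=str lets values such as pathlib.Path be serialized.
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def experiment_id(config: dict[str, Any], prefix: str = "bypass", length: int = 12) -> str:
    """Short, stable experiment ID derived from the full fingerprint."""
    return f"{prefix}_{run_fingerprint(config)[:length]}"
```

Lists are hashed in their given order, so two configs that differ only in the order of keys_to_learn would get different IDs unless that list is normalized (for example, sorted) first.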

Summary by CodeRabbit

  • New Features

    • Full bypass distillation workflow: blockwise stitched student/teacher distillation, deterministic experiment IDs/fingerprints, resume-capable checkpointing with atomic symlink updates, optional asynchronous checkpoint saves, and improved support for trust-remote-code model loading.
    • Public bypass-distillation entrypoint exposed for easier launch.
  • Tests

    • Extensive unit tests covering bypass utilities, checkpoint behavior, LR scheduler, stitched factory, and training orchestration.
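The "atomic symlink updates" mentioned in this summary usually follow a create-then-rename pattern: build the new link under a temporary name, then rename it over the old one, which is atomic on POSIX when both paths are on the same filesystem. The sketch below assumes that technique and a "latest" link inside the run directory; update_latest_symlink and resolve_latest are hypothetical names. Resolving the link before handing it to callers, as a later review comment on find_latest_run_dir suggests, gives them a directory that cannot move underneath them.

```python
import os
from pathlib import Path
from typing import Optional


def update_latest_symlink(run_root: Path, target_dir: Path, link_name: str = "latest") -> Path:
    """Point run_root/link_name at target_dir with no window where the link is missing."""
    link_path = run_root / link_name
    tmp_link = run_root / f".{link_name}.tmp.{os.getpid()}"
    if tmp_link.is_symlink() or tmp_link.exists():
        tmp_link.unlink()  # leftover from an interrupted update
    # A relative target keeps the checkpoint tree relocatable.
    tmp_link.symlink_to(os.path.relpath(target_dir, run_root))
    # os.replace renames the link itself (it does not follow it) and is atomic on POSIX.
    os.replace(tmp_link, link_path)
    return link_path


def resolve_latest(run_root: Path, link_name: str = "latest") -> Optional[Path]:
    """Return the real directory behind the link, or None if it is missing or dangling."""
    link_path = run_root / link_name
    if not link_path.is_symlink():
        return None
    resolved = link_path.resolve()
    return resolved if resolved.is_dir() else None
```

Only the master rank would call update_latest_symlink, in line with the master-only metadata behavior described for save_bypass_checkpoint.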

@copy-pr-bot Bot commented May 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@github-actions Bot (Contributor) commented May 12, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1469/

Built to branch gh-pages at 2026-06-02 13:37 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from b8ed8ea to ce66fb2 on May 12, 2026 10:52

@coderabbitai Bot (Contributor) commented May 12, 2026


Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds a PUZZLE bypass-distillation pipeline: stitched-model factory, keys-to-learn logic, distributed training and LR scheduler, checkpoint discovery/save/load with master-only metadata, HF sharded checkpoint improvements, package exports and example tweak, plus comprehensive unit tests.

Changes

Bypass Distillation Feature Implementation

| Layer | File(s) | Summary |
|---|---|---|
| Core data structures and type aliases | modelopt/torch/puzzletron/bypass_distillation/data_classes.py | Frozen dataclasses for iteration statistics, local training stats, and save signals; type aliases for iteration number and global rank. |
| Keys-to-learn, identity, and run state | modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py | Normalization of keys_to_learn, learned-subblock labeling, deterministic experiment/run fingerprinting and IDs, experiment-dir creation, bypass_state.json lifecycle, expected-run enumeration, and pipeline-ownership utilities. |
| Checkpoint discovery, local state load/save, and orchestration | modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py | find_latest_run_dir selection, load_local_state for per-stitched-module optimizer/scaler state, _save_local_file/_save_local_state, and save_bypass_checkpoint with shard-aware saving and master-only metadata/symlink updates. |
| Stitched model descriptor and factory | modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py | StitchedModuleDescriptor dataclass; parameter-name selection for subblocks (attention/FFN/Mamba/entire_block), _set_keys_to_learn, non-persistent buffer collection, and bypass_factory_fn building stitched teacher/student graphs and per-block optimizers/scalers. |
| Training loop orchestration and execution | modelopt/torch/puzzletron/bypass_distillation/training_loop.py | launch_bypass_distillation dispatcher (single/sweep), run_bypassed_training orchestration, the train loop with loss buffering/flush, validation, checkpoint scheduling, and the _get_lr cosine-with-warmup scheduler. |
| HF checkpoint and auto_map enhancements | modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py | Distributed sharded-save rewrite that avoids materializing all shards on rank 0, nested auto_map class-ref iteration, throttled fallback warnings, and copying auto_map source files into checkpoints. |
| Sharded loading trust_remote_code plumbing | modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py | Adds trust_remote_code parameter handling to model-config loading and sharded model creation, routing through the remote-code auto class when enabled. |
| Package exports and example config tweak | modelopt/torch/puzzletron/__init__.py, modelopt/torch/puzzletron/bypass_distillation/__init__.py, examples/megatron_bridge/distill.py | Exports the bypass_distillation submodule, re-exports launch_bypass_distillation, and enables async checkpoint saving in the example. |
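The _get_lr cosine-with-warmup scheduler listed above is, in its usual textbook form, a linear warmup to the peak rate followed by a half-cosine decay to a floor. The function below is a sketch of that standard schedule, assuming per-step updates and a hold at min_lr once the budget is exhausted; its name and arguments are illustrative and may not match _get_lr.

```python
import math


def cosine_with_warmup_lr(step: int, max_lr: float, min_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to max_lr, cosine decay to min_lr, then hold at min_lr."""
    if warmup_steps > 0 and step < warmup_steps:
        # Steps 0..warmup_steps-1 ramp linearly up to max_lr.
        return max_lr * (step + 1) / warmup_steps
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0 or step >= total_steps:
        # Degenerate budget, or past the end of training: clamp to the floor.
        return min_lr
    progress = (step - warmup_steps) / decay_steps  # in [0, 1)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes from 1 down toward 0
    return min_lr + coeff * (max_lr - min_lr)
```

With this formulation the rate at step warmup_steps equals max_lr, the midpoint of the decay phase sits at (max_lr + min_lr) / 2, and any step at or beyond total_steps returns min_lr, which matches the properties listed for test_bypass_lr_scheduler.py.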

Test Utilities and Infrastructure

| Layer | File(s) | Summary |
|---|---|---|
| Test family parametrization | tests/_test_utils/torch/puzzletron/utils.py | Adds the PUZZLETRON_FAMILIES pytest parametrization list and exports it for shared tests. |

Comprehensive Unit Test Coverage

| Layer | File(s) | Summary |
|---|---|---|
| Checkpoint utility tests | tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py | CPU-only tests for checkpoint discovery, local save semantics, symlink/metadata orchestration, and master-only behavior. |
| Parameter selection tests | tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py | Tests for _set_keys_to_learn across dense and hybrid cases, union/list inputs, and validation error cases. |
| LR scheduler tests | tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py | Tests for warmup, cosine decay, endpoint/midpoint correctness, degenerate budgets, and clamping to min_lr. |
| Bypass utility tests | tests/unit/torch/puzzletron/test_bypass_utils.py | Tests for module ownership partitioning, pipeline ownership context, and experiment ID/fingerprint behavior and stability. |
| HF checkpoint utility tests | tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py | Tests ensuring that the descriptor's language-model config is used for weight grouping, and that auto_map code-file copying filters non-string entries and strips repo prefixes. |
| Launch dispatcher tests | tests/unit/torch/puzzletron/test_launch_bypass_distillation.py | Tests for sweep dispatching, per-run state reset, override application, resume-path precedence, and trust_remote_code resolution behavior. |
| Buffer management tests | tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py | Tests for non-persistent buffer detection and fully-qualified naming across nested modules. |
| Replacement library bypass-config tests | tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py | Tests _infer_subblocks_to_extract driven by bypass_config.json with valid and legacy/invalid keys_to_learn inputs. |
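The module ownership partitioning and pipeline ownership context exercised by test_bypass_utils.py come down to deciding which contiguous run of blocks each pipeline-parallel rank owns, and which ranks sit immediately before and after it. The following is a minimal sketch assuming a near-even contiguous split in which leftover blocks go to the earliest ranks; PipelineOwnership and pipeline_ownership are hypothetical names, and the PR's get_pipeline_ownership_context may partition blocks differently.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PipelineOwnership:
    owned_blocks: tuple[int, ...]
    prev_rank: Optional[int]  # rank that sends us activations; None on the first stage
    next_rank: Optional[int]  # rank we send activations to; None on the last stage


def pipeline_ownership(num_blocks: int, world_size: int, rank: int) -> PipelineOwnership:
    """Hypothetical helper: contiguous block ownership plus PP neighbors for one rank."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside [0, {world_size})")
    base, extra = divmod(num_blocks, world_size)
    # The first `extra` ranks each take one additional block.
    start = rank * base + min(rank, extra)
    count = base + (1 if rank < extra else 0)
    return PipelineOwnership(
        owned_blocks=tuple(range(start, start + count)),
        prev_rank=rank - 1 if rank > 0 else None,
        next_rank=rank + 1 if rank < world_size - 1 else None,
    )
```

For example, 10 blocks on 4 ranks gives rank 0 blocks 0 to 2 and rank 3 blocks 8 and 9, with rank 3 having no next rank.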

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 41.18%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Title check | ❓ Inconclusive | The title 'Ssameni/puzzletron bypass 2 core' is vague and does not clearly convey the main change; it reads as a branch name rather than a descriptive PR title. | Revise to a more descriptive title that captures the core contribution, e.g., 'Add bypass distillation engine for Puzzletron pipeline-parallel training' or 'Implement bypass distillation core with stitched model factory and training loop'. |
✅ Passed checks (4 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | All security checks pass: torch.load uses weights_only=True; trust_remote_code defaults False; no eval/exec, nosec, or new dependencies. |


@Separius force-pushed the ssameni/puzzletron-bypass-2-core branch 3 times, most recently from 470fe16 to 8d3db43 on May 12, 2026 14:39
@Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from 8d3db43 to 71edd2d on June 1, 2026 08:00

@codecov Bot commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 46.11486% with 638 lines in your changes missing coverage. Please review.
✅ Project coverage is 61.15%. Comparing base (72df833) to head (4e75b7b).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ch/puzzletron/bypass_distillation/training_loop.py | 33.93% | 368 Missing ⚠️ |
| ...tron/bypass_distillation/stitched_model_factory.py | 40.09% | 127 Missing ⚠️ |
| ...lopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 31.18% | 64 Missing ⚠️ |
| ...rch/puzzletron/bypass_distillation/bypass_utils.py | 74.19% | 48 Missing ⚠️ |
| ...ron/bypass_distillation/bypass_checkpoint_utils.py | 75.92% | 26 Missing ⚠️ |
| ...torch/puzzletron/tools/sharded_checkpoint_utils.py | 0.00% | 5 Missing ⚠️ |
Additional details and impacted files
```diff
@@             Coverage Diff             @@
##             main    #1469       +/-   ##
===========================================
- Coverage   77.43%   61.15%   -16.29%
===========================================
  Files         480      485        +5
  Lines       52564    53650     +1086
===========================================
- Hits        40703    32808     -7895
- Misses      11861    20842     +8981
```

| Flag | Coverage Δ |
|---|---|
| unit | 53.71% <46.11%> (-0.01%) ⬇️ |

Flags with carried-forward coverage won't be shown.


@Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from 71edd2d to f878d8e on June 1, 2026 08:26
@Separius marked this pull request as ready for review June 1, 2026 08:38
@Separius requested review from a team as code owners June 1, 2026 08:38
@Separius requested a review from kevalmorabia97 June 1, 2026 08:38
@coderabbitai Bot (Contributor) left a comment


Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.


Actionable comments posted: 6

🧹 Nitpick comments (5)
modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py (1)

182-184: ⚡ Quick win

Gate the progress bar to rank 0.

Every rank enters this loop, so distributed checkpoint saves will emit one tqdm bar per rank. That gets noisy fast and obscures real save failures; please disable the bar on non-master ranks.

As per coding guidelines, "Develop with distributed processing in mind by using print_rank_0 or warn_rank_0 to avoid noisy logs and guarding shared side effects..."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 182 - 184, The tqdm progress bar is created by every distributed
rank because the loop over stitched_module_descriptors uses tqdm
unconditionally; guard the progress bar so only rank 0 shows it. Replace the
unconditional tqdm(...) in the loop over stitched_module_descriptors.items()
with a rank-checked variant (e.g., compute is_rank0 via
torch.distributed.get_rank()==0 or the repo's helper like
print_rank_0/is_main_process) and either pass disable=not is_rank0 into tqdm or
iterate plain stitched_module_descriptors.items() on non-master ranks; keep the
loop body and variable names (stitched_module_name, stitched_module_descriptor)
unchanged.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (2)

62-75: ⚡ Quick win

Declare this module’s public surface with __all__.

This file adds public symbols, but the export surface is still implicit. Please make it explicit here so package re-exports stay stable and star-import-safe.

♻️ Suggested shape
```diff
 StitchedModulesProcessOwnership = list[int]
 SyncDistributedModelWeightsFn = Callable[[], None]
 Config = Mapping[str, Any]
 Args = Namespace
 
+__all__ = [
+    "StitchedModuleDescriptor",
+    "StitchedModulesProcessOwnership",
+    "bypass_factory_fn",
+]
+
 
 @dataclasses.dataclass
 class StitchedModuleDescriptor:
```

As per coding guidelines, "Define the public API with __all__ at the top of each Python module" and "Define the public surface with __all__ near the definitions."

Also applies to: 171-184

🤖 Prompt for AI Agents
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 62 - 75, Add an explicit module export list __all__ that enumerates
the public symbols declared in this file (e.g.,
"StitchedModulesProcessOwnership", "SyncDistributedModelWeightsFn", "Config",
"Args", "StitchedModuleDescriptor" and any other public classes/functions like
StitchedModule) and place it near these definitions (close to the dataclass and
the type aliases) so the module's public API is explicit and stable for
star-imports and package re-exports.
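The "star-import-safe" argument follows from how Python decides what a "from module import *" statement binds: if the module defines __all__, exactly those names; otherwise every name that does not start with an underscore, including helpers the module merely imported. The snippet restates that documented rule as a small function (star_import_names is a made-up helper, not part of the PR) and applies it to a throwaway module object.

```python
import json
import types


def star_import_names(module: types.ModuleType) -> list[str]:
    """Names a star-import would bind, following the language reference rule."""
    public = getattr(module, "__all__", None)
    if public is not None:
        return list(public)
    # Without __all__, every name not starting with "_" is exported.
    return [name for name in vars(module) if not name.startswith("_")]


# Throwaway stand-in for a module such as stitched_model_factory.py.
demo = types.ModuleType("demo_factory")
demo.json = json  # a re-imported helper that would leak without __all__
demo.bypass_factory_fn = lambda: None
demo._private_helper = lambda: None

print(sorted(star_import_names(demo)))  # ['bypass_factory_fn', 'json']
demo.__all__ = ["bypass_factory_fn"]
print(star_import_names(demo))  # ['bypass_factory_fn']
```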

249-249: ⚡ Quick win

Hoist this internal import or document the deferral reason.

Deferring an internal import here pushes import failures to the first bypass run. If this is avoiding a real cycle, please add a brief comment naming that reason; otherwise move it to module scope.

♻️ Minimal cleanup
```diff
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
 from modelopt.torch.puzzletron.tools.logger import mprint
 from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model
 from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype
@@
-            from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
-
             runtime = Namespace(
```

As per coding guidelines, "Imports: keep imports at the top of the file unless there’s a concrete reason ...; add a brief comment if deviating."

🤖 Prompt for AI Agents
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` at
line 249, The inline import of deci_x_patcher (from
modelopt.torch.puzzletron.anymodel.puzzformer) should be hoisted to module scope
or documented why it's deferred; either move the import to the top of
stitched_model_factory.py to fail fast on import errors, or if it truly avoids a
cyclic import, add a brief comment above the local import referencing the cycle
and the module(s) involved and keep the import local inside the function that
uses deci_x_patcher. Ensure you reference deci_x_patcher and the local import
site in your change so future readers know why it's not at module scope.
modelopt/torch/puzzletron/bypass_distillation/training_loop.py (2)

59-70: ⚡ Quick win

Add __all__ for the new public entry points.

This module now exposes several public functions, but the API surface is still implicit. Please declare it explicitly so package re-exports do not drift.

♻️ Suggested shape
```diff
 from .data_classes import GlobalRank, IterNum, IterStatistics, TimeToSaveSignal
 from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership
 
+__all__ = [
+    "launch_bypass_distillation",
+    "train",
+    "run_bypassed_training",
+    "realize_bypass_checkpoints",
+]
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
```

As per coding guidelines, "Define the public API with __all__ at the top of each Python module" and "Define the public surface with __all__ near the definitions."

Also applies to: 83-84, 228-241, 772-773, 1195-1196

🤖 Prompt for AI Agents
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
59 - 70, Add an explicit __all__ near the top of the module (close to the public
definitions) listing the public API symbols so exports don't drift; include the
imported and exposed names such as "find_latest_run_dir", "load_local_state",
"save_bypass_checkpoint", "bypass_run_is_complete",
"get_distributed_modules_ownership", "get_pipeline_ownership_context",
"load_bypass_state", "mark_bypass_run_completed", "set_experiment_dir",
"set_experiment_id", and the data/class names "GlobalRank", "IterNum",
"IterStatistics", "TimeToSaveSignal", "StitchedModuleDescriptor",
"StitchedModulesProcessOwnership" (add any additional public functions from this
file) and place the __all__ definition near the module-level definitions as per
guidelines.

670-672: ⚡ Quick win

Hoist these internal imports or annotate why they must stay local.

These are internal modules rather than optional dependencies, so keeping them inside runtime paths defers import errors until the middle of training/validation. Please move them to module scope, or add a short comment if there is a concrete cycle/latency reason.

As per coding guidelines, "Imports: keep imports at the top of the file unless there’s a concrete reason ...; add a brief comment if deviating."

Also applies to: 939-944, 1088-1090, 1144-1146

🤖 Prompt for all review comments with AI agents
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 24-35: Add an explicit __all__ list at the top of the module that
enumerates the intended public API (e.g., "load_bypass_state",
"update_bypass_checkpoint_state", "save_checkpoint_from_shards",
"StitchedModuleDescriptor", "ModelDescriptor", "aprint", "mprint", "json_dump")
and do not include internal helpers with leading underscores (such as
_save_local_file, _save_local_state); place the __all__ right after the imports
so star-imports only expose the listed names and internal functions remain
private.
- Around line 58-64: The code returns the symlink path `latest_dir` after
validation which can race if the symlink moves; change the return to the
resolved target so the caller gets the stable checkpoint directory (use
`latest_dir.resolve()` instead of `latest_dir`) — update the return in the block
that validates `latest_dir` to return the resolved path string so callers open
the exact checkpoint verified by the prior checks.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py`:
- Around line 24-32: Add an explicit __all__ list at the top of the module to
declare the public API (for example include "BYPASS_STATE_FILENAME",
"BYPASS_SUBBLOCK_KEYS_TO_LEARN", and any helper functions you intend to export
such as json_dump/json_load if they are re-exposed); place the __all__
definition before other module-level code so future star-imports and package
re-exports are safe and only the intended symbols (e.g., BYPASS_STATE_FILENAME,
BYPASS_SUBBLOCK_KEYS_TO_LEARN, json_dump, json_load) are exported.

In `@modelopt/torch/puzzletron/bypass_distillation/data_classes.py`:
- Around line 18-22: Add an explicit __all__ near the top of the module listing
the public symbols (e.g., __all__ = ["IterNum", "GlobalRank"]) so the module's
public API is explicit; place it after the imports and do not include
internal/imported modules like dataclasses in the list, and update it to include
any dataclass names defined later in this file if applicable.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 790-792: The code currently derives trust_remote_code from
descriptor.requires_trust_remote_code(), which must be caller-controlled; update
calls that pass trust_remote_code (ModelDescriptorFactory.get /
descriptor.requires_trust_remote_code() usage) to instead read an explicit
opt-in flag (e.g., cfg.trust_remote_code with default False) and pass that to
load_model_config(...) and AutoTokenizer.from_pretrained(...); keep
descriptor.requires_trust_remote_code() for informational checks only (log or
warn if descriptor requires but cfg.trust_remote_code is False) but do not use
it to set the effective trust flag so callers retain explicit control.

In `@tests/_test_utils/torch/puzzletron/utils.py`:
- Around line 33-61: Add an explicit module export list by defining __all__ at
the top of the module and include the new public symbol name
"PUZZLETRON_FAMILIES" in it; update or create the __all__ list near the module
header so the public API exports PUZZLETRON_FAMILIES (refer to the
PUZZLETRON_FAMILIES symbol) and ensure __all__ is a list of string names for any
other intended public symbols in this file.
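The trust_remote_code comment asks for a policy in which only an explicit caller opt-in enables remote code, while the descriptor's requirement is used for diagnostics only. A minimal sketch of that policy follows. resolve_trust_remote_code, the plain-dict cfg, and the warn callback are assumptions for illustration; the PR's own helper is referred to elsewhere in this thread as _resolve_trust_remote_code, and its real signature is not shown here.

```python
from collections.abc import Callable, Mapping
from typing import Any


def resolve_trust_remote_code(
    cfg: Mapping[str, Any],
    descriptor_requires_remote_code: bool,
    warn: Callable[[str], None] = print,
) -> bool:
    """Return the effective trust_remote_code flag; only an explicit opt-in enables it."""
    trust = bool(cfg.get("trust_remote_code", False))
    if descriptor_requires_remote_code and not trust:
        # Informational only: the descriptor never turns the flag on by itself.
        warn(
            "Model descriptor requires trust_remote_code, but cfg.trust_remote_code is false; "
            "loading may fail unless you opt in explicitly."
        )
    return trust
```

The returned value, not the descriptor's answer, would then be what gets passed to load_model_config(...) and AutoTokenizer.from_pretrained(...).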


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9c9cb0fd-e208-46ef-9312-459ad201d97e

📥 Commits

Reviewing files that changed from the base of the PR and between 54fb87e and f878d8e.

📒 Files selected for processing (17)
  • examples/megatron_bridge/distill.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py

Comment thread modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py Outdated
Comment thread modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
Comment thread modelopt/torch/puzzletron/bypass_distillation/data_classes.py
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Comment thread tests/_test_utils/torch/puzzletron/utils.py
@Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from f878d8e to 735c1d3 on June 1, 2026 09:10

@coderabbitai Bot (Contributor) left a comment


🧹 Nitpick comments (2)
tests/unit/torch/puzzletron/test_launch_bypass_distillation.py (2)

82-83: 💤 Low value

Avoid hardcoding source line numbers in test docstrings.

References to "line 85 of training_loop.py" and "line 99" will silently go stale as training_loop.py evolves and mislead future readers. Refer to the function/branch by name instead (e.g., the truthiness check on bypass.configs, the if "keys_to_learn" in override guard).

Also applies to: 136-137

🤖 Prompt for AI Agents
In `@tests/unit/torch/puzzletron/test_launch_bypass_distillation.py` around lines
82 - 83, Update the test docstring to remove hardcoded source line numbers and
instead reference code symbols/branches by name: replace "line 85 of
training_loop.py" with a description like "the truthiness check on
bypass.configs in training_loop.py" and replace "line 99" (and the similar
references at 136-137) with "the if 'keys_to_learn' in override guard" (or the
relevant function/branch name). Ensure the docstring clearly identifies the
behavior being tested using symbol names (e.g., bypass.configs truthiness check,
the if "keys_to_learn" guard) rather than numeric line references so it remains
correct as training_loop.py changes.

159-183: 💤 Low value

Make the warning-log assertion less brittle.

  • any("cfg.trust_remote_code is false" in message ...) couples the test to exact log wording; asserting that a warning was emitted (and/or matching a stable token like trust_remote_code / requires_trust_remote_code) would be more resilient.
  • The current messages.append works today because _resolve_trust_remote_code calls mprint with a single string argument, but capturing *args would make the test robust to any future change in mprint call signature.
🤖 Prompt for AI Agents
In `@tests/unit/torch/puzzletron/test_launch_bypass_distillation.py` around lines
159 - 183, Update the brittle log assertion and message capture in the tests for
tl._resolve_trust_remote_code: instead of monkeypatching tl.mprint to
messages.append, replace it with a small wrapper that captures all args (e.g.,
append(" ".join(map(str, args))) ) so the test is robust to changes in mprint
signature; then assert that any message contains a stable token such as
"trust_remote_code" or "requires_trust_remote_code" (rather than the full phrase
"cfg.trust_remote_code is false") in
test_trust_remote_code_defaults_to_false_even_when_descriptor_requires_it,
leaving test_trust_remote_code_uses_explicit_cfg_opt_in expecting no messages.
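The suggested capture pattern can be sketched as follows. make_capture and emit_warning are hypothetical stand-ins; in the real test the wrapper would replace tl.mprint, for example via pytest's monkeypatch.setattr. Recording every positional argument keeps the test working if the logging call gains arguments, whereas messages.append accepts exactly one.

```python
from collections.abc import Callable
from typing import Any


def make_capture(messages: list[str]) -> Callable[..., None]:
    """Build a drop-in replacement for a print-like logger that records each call."""

    def capture(*args: Any, **kwargs: Any) -> None:
        # Join all positional arguments so multi-argument calls are still captured;
        # keyword arguments such as sep/end are accepted and ignored.
        messages.append(" ".join(map(str, args)))

    return capture


def emit_warning(log: Callable[..., None]) -> None:
    """Hypothetical code under test that logs with two positional arguments."""
    log("requires_trust_remote_code=True", "but cfg.trust_remote_code is false")


messages: list[str] = []
emit_warning(make_capture(messages))
# Assert on a stable token rather than the full sentence.
assert any("trust_remote_code" in message for message in messages)
```

Passing a bare list.append as the logger here would raise TypeError, because list.append takes exactly one argument.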

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: fd245e2b-4e43-4c6e-96dc-f0c332631853

📥 Commits

Reviewing files that changed from the base of the PR and between f878d8e and 735c1d3.

📒 Files selected for processing (18)
  • examples/megatron_bridge/distill.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
🚧 Files skipped from review as they are similar to previous changes (15)
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • examples/megatron_bridge/distill.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py

Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from 735c1d3 to abbe9ac on June 1, 2026 09:33

Separius commented Jun 1, 2026

/ok to test abbe9ac


Separius commented Jun 1, 2026

/claude review

Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from abbe9ac to 703ce91 on June 1, 2026 11:42
coderabbitai bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/puzzletron/replacement_library/build_replacement_library.py`:
- Around line 463-479: The fallback in the exception handler for
learned_subblocks_from_keys_to_learn treats mixed legacy keys (keys_to_learn
containing both "mlp" and "attn") as an error; update the handler in
build_replacement_library.py so that when legacy_keys contains both has_mlp and
has_attn you set subblocks_to_extract to include both subblocks (e.g., ["ffn",
"attention"]) instead of raising a ValueError, while preserving the existing
single-only branches and the empty/unknown case handling.
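A minimal sketch of the suggested fallback, assuming legacy `keys_to_learn` entries can be recognized by the substrings `mlp` and `attn`. `legacy_keys_to_subblocks` is a hypothetical name, and raising `ValueError` for the empty or unknown case is an assumption about what the existing handler does.

```python
def legacy_keys_to_subblocks(keys_to_learn):
    """Map legacy mlp/attn keys_to_learn entries to subblock names (hypothetical helper)."""
    keys = [keys_to_learn] if isinstance(keys_to_learn, str) else list(keys_to_learn)
    # Assumption: legacy keys are recognized by substring.
    has_mlp = any("mlp" in key for key in keys)
    has_attn = any("attn" in key for key in keys)
    if has_mlp and has_attn:
        # Mixed legacy keys: extract both subblocks instead of raising.
        return ["ffn", "attention"]
    if has_mlp:
        return ["ffn"]
    if has_attn:
        return ["attention"]
    # Assumed behavior for the empty/unknown case; the real handler may differ.
    raise ValueError(f"Cannot infer subblocks from keys_to_learn={keys_to_learn!r}")
```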

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 595-642: The code in _copy_auto_map_code_files collects
module_names by taking class_ref.split(".")[0], but PretrainedConfig.auto_map
can contain repo-id prefixes like "org/repo--modeling_x.Foo", so the module name
extraction must strip any repo-id portion before splitting by "."; update the
module_names comprehension (used with _iter_auto_map_class_refs) to first take
class_ref.split("--", 1)[-1] (to handle repo_id--class_ref), then split on "."
to get the module base, and keep the existing _module_name_re validation and
warning logic; reference symbols: _copy_auto_map_code_files,
_iter_auto_map_class_refs, module_names, _module_name_re.
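The module-name extraction can be sketched as follows. `iter_class_refs`, `auto_map_module_names`, and `MODULE_NAME_RE` are hypothetical stand-ins for `_iter_auto_map_class_refs`, the `module_names` comprehension, and `_module_name_re`; the regex in particular is an assumption. `auto_map` values may be a single string or a `[slow, fast]` tokenizer pair in which one side can be `None`, so the iterator handles both forms.

```python
import re
import warnings

# Hypothetical validation pattern; the real _module_name_re may differ.
MODULE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def iter_class_refs(auto_map):
    """Yield every non-empty class reference string in a config.auto_map dict."""
    for value in auto_map.values():
        refs = value if isinstance(value, (list, tuple)) else [value]
        for ref in refs:
            if ref:
                yield ref


def auto_map_module_names(auto_map):
    names = set()
    for class_ref in iter_class_refs(auto_map):
        # Drop an optional "org/repo--" prefix, then keep the module part before the first ".".
        module = class_ref.split("--", 1)[-1].split(".")[0]
        if MODULE_NAME_RE.match(module):
            names.add(module)
        else:
            warnings.warn(f"Skipping unexpected auto_map entry: {class_ref!r}")
    return names


auto_map = {
    "AutoConfig": "org/repo--configuration_x.XConfig",
    "AutoModelForCausalLM": "modeling_x.XForCausalLM",
    "AutoTokenizer": ["tokenization_x.XTokenizer", None],
}
module_names = auto_map_module_names(auto_map)
```

Each surviving name would then map to a `<module>.py` file to copy next to the saved config.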

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 64b36a59-cd6f-4aac-aa07-68afea7a2894

📥 Commits

Reviewing files that changed from the base of the PR and between abbe9ac and 703ce91.

📒 Files selected for processing (20)
  • examples/megatron_bridge/distill.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
  • tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
🚧 Files skipped from review as they are similar to previous changes (15)
  • examples/megatron_bridge/distill.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py

Comment thread modelopt/torch/puzzletron/replacement_library/build_replacement_library.py Outdated
Comment thread modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Separius force-pushed the ssameni/puzzletron-bypass-2-core branch from 703ce91 to 4b016d0 on June 1, 2026 11:58
coderabbitai bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py`:
- Around line 129-143: Normalize and canonicalize the
model_factory["keys_to_learn"] value before it gets incorporated into the run
identity: detect single-string vs list inputs, convert single strings to a
one-element list, flatten/normalize subblock identifiers, sort the list
deterministically and convert to an immutable representation (e.g., a tuple or
sorted tuple of normalized items) so semantically equivalent inputs
("entire_block" vs ["entire_block"], reordered lists) yield the same
fingerprint; apply this normalization wherever keys_to_learn is read in
bypass_utils.py (including the places that build the returned dict around
model_factory and the other occurrences mentioned) so the hashed/slugged
representation uses the canonical form.
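A sketch of the canonicalization, assuming the run identity is derived by hashing a JSON dump of `model_factory`. `canonical_keys_to_learn` and `run_fingerprint` are hypothetical names, and dropping duplicate keys is a design choice beyond what the comment asks for.

```python
import hashlib
import json


def canonical_keys_to_learn(value):
    """Return an order-insensitive, hashable form of keys_to_learn."""
    items = [value] if isinstance(value, str) else list(value)
    # Strip whitespace, drop duplicates, and sort so equivalent inputs compare equal.
    return tuple(sorted({str(item).strip() for item in items}))


def run_fingerprint(model_factory):
    """Hash the run-defining fields, using the canonical keys_to_learn."""
    identity = dict(model_factory)
    identity["keys_to_learn"] = list(canonical_keys_to_learn(identity["keys_to_learn"]))
    payload = json.dumps(identity, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]
```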

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 900-907: The resume logic computes resume_skip_first_batches
incorrectly by adding saved_skip + resume_cfg.iter_num, which double-counts one
batch when resuming from "final-step-*" checkpoints because iter_num there is
already the next step; update the calculation in the block that sets
resume_skip_first_batches (and the similar block around the other occurrence) to
use saved_skip + max(0, resume_cfg.iter_num - 1) (or subtract 1 from
resume_cfg.iter_num when it is > 0) so the dataloader position aligns with the
restored model/optimizer state; locate the symbol resume_skip_first_batches and
change its assignment accordingly in the code that follows
_get_resume_state_path and where resume_cfg.iter_num is used.
- Around line 823-827: The short-circuiting check using
bypass_run_is_complete(cfg) must be made rank-consistent: have only the master
rank evaluate bypass_run_is_complete(cfg) after set_experiment_id(cfg) and
set_experiment_dir(cfg), then broadcast the resulting boolean to all ranks (same
pattern used in launch_bypass_distillation()), and have all ranks use that
broadcasted value to decide whether to return; also replace the direct mprint
with print_rank_0 to avoid noisy logs and ensure any shared side-effects (e.g.,
skipping writes) are only performed or suppressed based on the broadcasted
result so no rank hits the next barrier unexpectedly.
- Around line 510-532: The current grad-clipping branches sync GPU→CPU by using
Python-side conditionals and .item(); instead, keep the "did we clip?" logic
on-device and only sync once per block to update
cfg.bypass.training.clipping_count. Specifically: for the norm branch, treat the
return of torch.nn.utils.clip_grad_norm_ as a tensor (or wrap it in
torch.as_tensor on the stitched_module device), compute a GPU tensor boolean
clipped_mask = grad_norm > grad_clip, convert that to an integer tensor and
accumulate into a local on-device counter tensor; for the value branch, avoid
.item() by computing max_abs_grad_tensor = torch.stack(grad_maxes).max(), do the
tensor compare on-GPU (e.g. max_abs_grad_tensor > grad_clip) and accumulate
similarly; after processing the stitched block, call .item() once to add the
summed on-device counter to cfg.bypass.training.clipping_count. Use
stitched_module.parameters() device for tensor creation to ensure no implicit
host sync.
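A sketch of the on-device counting, assuming one call per stitched block and a 0-d integer counter tensor created on the block's device. `clip_block_grads` is a hypothetical helper. `torch.nn.utils.clip_grad_norm_` returns the pre-clipping total norm as a tensor, which keeps the comparison off the host.

```python
import torch


def clip_block_grads(module, grad_clip, mode, clip_counter):
    """Clip one stitched block's gradients; add 1 to clip_counter if clipping fired.

    clip_counter is a 0-d integer tensor on the gradients' device, so the
    "did we clip?" comparison stays on-device and no .item() runs here.
    """
    params = [p for p in module.parameters() if p.grad is not None]
    if not params:
        return clip_counter
    if mode == "norm":
        # clip_grad_norm_ returns the total norm measured before clipping, as a tensor.
        total_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
        clipped = total_norm > grad_clip
    elif mode == "value":
        # Measure the largest |grad| before clip_grad_value_ modifies gradients in place.
        max_abs_grad = torch.stack([p.grad.detach().abs().max() for p in params]).max()
        torch.nn.utils.clip_grad_value_(params, grad_clip)
        clipped = max_abs_grad > grad_clip
    else:
        raise ValueError(f"Unknown grad clipping mode: {mode!r}")
    return clip_counter + clipped.to(clip_counter.dtype)


torch.manual_seed(0)
block = torch.nn.Linear(4, 4)
block(torch.randn(8, 4)).pow(2).sum().mul(1000.0).backward()
counter = torch.zeros((), dtype=torch.long, device=block.weight.device)
counter = clip_block_grads(block, 1.0, "norm", counter)
# Single host sync per block, e.g. cfg.bypass.training.clipping_count += int(counter.item())
clipping_count = int(counter.item())
```

Only the final `int(counter.item())` forces a device-to-host sync, and it happens once per block.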

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: a7f8d701-5245-4c05-8207-33bb418aef4e

📥 Commits

Reviewing files that changed from the base of the PR and between 703ce91 and 4b016d0.

📒 Files selected for processing (20)
  • examples/megatron_bridge/distill.py
  • modelopt/torch/puzzletron/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
  • tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
  • tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
🚧 Files skipped from review as they are similar to previous changes (15)
  • modelopt/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
  • examples/megatron_bridge/distill.py
  • tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
  • tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py

Comment thread modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py Outdated
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
kevalmorabia97 commented

/claude review

Separius requested a review from AAnoosheh on June 1, 2026 15:39
kevalmorabia97 commented

/claude review

kevalmorabia97 commented

/claude review

kevalmorabia97 commented

/ok to test ddde4e5

kevalmorabia97 commented

/claude review
Only review modelopt/torch/puzzletron folder changes

Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py Outdated
Comment thread modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py Outdated
claude bot left a comment

Review summary

Reviewed the bypass-distillation core engine for algorithmic correctness, mode/state machine handling, export semantics, compatibility, and performance. CodeRabbit's review covers the style/security surface (e.g., torch.load(weights_only=False), trust_remote_code=True, hardcoded paths), so I focused on logic.

Findings

CRITICAL (1)

  • Algorithm — _get_resume_skip_first_batches off-by-one (training_loop.py:130): On every periodic-step / time-based / step-interval resume, the formula returns one fewer than the total batches consumed at save time. The save snapshots iter_num=N before iter_num += 1, so the resume skip count must equal saved_skip + N, but the function returns saved_skip + N - 1. The first iter after resume re-trains on the same batch that was just trained on before the save. The final-step-* path is incidentally correct because its args.json was snapshotted after the loop already advanced past the last iter, but the loop also exits immediately on that path so the skip count is moot. Inline diff suggested.
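The counting argument can be checked with a toy loop. It assumes that `iter_num` starts at 1 and counts completed steps, that the run started fresh with an initial skip, and, as the finding states, that the state is snapshotted before `iter_num += 1`. `resume_skip_first_batches` and `run_until_checkpoint` are hypothetical, not the repository's `_get_resume_skip_first_batches`.

```python
def resume_skip_first_batches(saved_skip, iter_num):
    """Batches to skip on resume when iter_num counts completed steps starting from 1."""
    return saved_skip + iter_num


def run_until_checkpoint(dataset, skip_first_batches, save_at_iter):
    """Toy loop: snapshot state after step `iter_num` trains, before `iter_num += 1`."""
    trained = []
    iter_num = 1
    for batch in dataset[skip_first_batches:]:
        trained.append(batch)
        if iter_num == save_at_iter:
            return trained, {"skip_first_batches": skip_first_batches, "iter_num": iter_num}
        iter_num += 1
    raise RuntimeError("dataset exhausted before the checkpoint step")


dataset = list(range(100))
trained, state = run_until_checkpoint(dataset, skip_first_batches=3, save_at_iter=10)
skip = resume_skip_first_batches(state["skip_first_batches"], state["iter_num"])
```

Here `trained` is `dataset[3:13]`, so resuming at `dataset[13]` replays nothing; a `saved_skip + iter_num - 1` formula would restart at `dataset[12]`, which was already trained on.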

IMPORTANT (1)

  • ModeState — silent partial-checkpoint survival (bypass_checkpoint_utils.py:232): _save_local_state(overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints) conflates two unrelated concerns. delete_old_checkpoints controls cleanup of older sibling subdirs; it has nothing to do with overwriting files in the current subdir. When set to False, _save_local_file silently skips any pre-existing stitched/{name}.optimizer_state.pth from a partial prior save, then saving_completed is touched and HF weights are rewritten on top — producing a "valid" checkpoint that pairs fresh weights with stale optimizer state. Inline suggestion to drop the overwrite parameter.
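A sketch of a save path without an `overwrite` switch. It assumes each state file can be written as bytes (standing in for `torch.save` output) and that `saving_completed` is the completeness marker. `save_local_state` is a hypothetical signature, and the write-to-temp-then-`os.replace` step is an extra precaution rather than something the review requires.

```python
import os
import tempfile
from pathlib import Path


def save_local_state(ckpt_dir: Path, files: dict) -> None:
    """Write every state file unconditionally, then mark the directory complete.

    `files` maps a file name to serialized bytes (a stand-in for torch.save output).
    """
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    marker = ckpt_dir / "saving_completed"
    # Drop any old marker first so a crash mid-save never looks like a finished checkpoint.
    marker.unlink(missing_ok=True)
    for name, data in files.items():
        final_path = ckpt_dir / name
        tmp_path = final_path.with_name(final_path.name + ".tmp")
        tmp_path.write_bytes(data)
        # Replace whatever a partial earlier save left behind instead of skipping it.
        os.replace(tmp_path, final_path)
    marker.touch()


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "iter-000100"  # hypothetical checkpoint directory name
    ckpt.mkdir()
    # Leftover from an interrupted save.
    (ckpt / "block_0.optimizer_state.pth").write_bytes(b"stale")
    save_local_state(ckpt, {"block_0.optimizer_state.pth": b"fresh"})
    restored = (ckpt / "block_0.optimizer_state.pth").read_bytes()
    completed = (ckpt / "saving_completed").exists()
    leftover_tmp = any(p.suffix == ".tmp" for p in ckpt.iterdir())
```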

SUGGESTION (1)

  • Algorithm — _param_names_for_subblock_key mamba fallback on non-puzzletron configs (stitched_model_factory.py:115-141): When block_configs is None (raw HF model), subblock_mamba falls through and returns all attention params instead of nothing. Low-probability misuse but silent. Either raise or mprint-warn.

Other notes

  • Per-file gather in _save_checkpoint_from_distributed_shards (checkpoint_utils_hf.py) is the right pattern for pipeline-parallel state-dict assembly — avoids the OOM that the unsharded save_checkpoint would hit on rank 0.
  • bypass_run_is_complete correctly guards on both config_fingerprint AND on-disk realized/symlink existence, so a manually-deleted checkpoint dir won't be falsely treated as complete.
  • expected_bypass_runs and the dispatcher both reset experiment_id=None before set_experiment_id, keeping sweep entries from inheriting each other's IDs (covered by the new unit tests).
  • The 1 CRITICAL is currently pinned by test_resume_skip_first_batches_uses_completed_iter_count in the new unit-test file; the test will need updating along with the fix.
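A sketch of a completion check with both guards, combined with the rank-0-evaluate-then-broadcast pattern requested in the earlier CodeRabbit comment on training_loop.py. The manifest name `bypass_state.json` and its keys `config_fingerprint` and `realized_checkpoint` are hypothetical. `torch.distributed.broadcast_object_list` is the real collective; the helper falls back to a local check when no process group is initialized.

```python
import json
import tempfile
from pathlib import Path

import torch.distributed as dist


def run_is_complete(run_dir: Path, expected_fingerprint: str) -> bool:
    """True only if the recorded fingerprint matches AND the realized checkpoint still exists."""
    manifest = run_dir / "bypass_state.json"  # hypothetical manifest file
    if not manifest.is_file():
        return False
    state = json.loads(manifest.read_text())
    if state.get("config_fingerprint") != expected_fingerprint:
        return False
    realized_name = state.get("realized_checkpoint")
    # A manually deleted checkpoint directory must not count as complete.
    return bool(realized_name) and (run_dir / realized_name).exists()


def rank_consistent(check) -> bool:
    """Evaluate check() on rank 0 only and broadcast the boolean to every rank."""
    if not (dist.is_available() and dist.is_initialized()):
        return bool(check())
    result = [bool(check())] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(result, src=0)
    return bool(result[0])


with tempfile.TemporaryDirectory() as tmp:
    run = Path(tmp)
    (run / "best-step-100").mkdir()
    manifest = {"config_fingerprint": "abc123", "realized_checkpoint": "best-step-100"}
    (run / "bypass_state.json").write_text(json.dumps(manifest))
    complete = rank_consistent(lambda: run_is_complete(run, "abc123"))
    stale_fp = rank_consistent(lambda: run_is_complete(run, "def456"))
    (run / "best-step-100").rmdir()
    deleted = rank_consistent(lambda: run_is_complete(run, "abc123"))
```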

Cannot approve until the CRITICAL is addressed.


Separius commented Jun 2, 2026

/ok to test b7b7586


Separius commented Jun 2, 2026

/claude review
Only review modelopt/torch/puzzletron folder changes

Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py Outdated
claude bot left a comment

Deep review summary — bypass distillation core (PR 2/3)

Findings: 1 CRITICAL · 0 IMPORTANT · 0 SUGGESTION

CRITICAL

  1. Algorithm — gradient clipping on scaled gradients (training_loop.py L561–574): when cfg.bypass.training.use_grad_scaling=True, _clip_stitched_module_grads runs on still-scaled gradients because grad_scaler.unscale_(optimizer) is never called between grad_scaler.scale(loss).backward() and clipping. With a default fp16 init scale of 2¹⁶ and grad_clip=1.0, the true gradient norm is effectively clipped to ~1.5e-5 — silent training corruption with no error. The clipping_count metric (_clip_stitched_module_grads L162–172) is also degraded since it compares scaled-norm against unscaled threshold. Masked today by the default use_grad_scaling=False (bf16 path), but the knob is exposed in config and the standard PyTorch AMP pattern documents the required ordering. Inline comment with suggested fix posted.
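The standard loss-scaling order is sketched below, assuming a single optimizer over the whole model. With GradScaler's default initial scale of 2¹⁶, skipping `unscale_` turns `grad_clip=1.0` into an effective true-gradient threshold of 1/65536 ≈ 1.5e-5. `train_step` is hypothetical, and the scaler is enabled only when CUDA is available so the sketch also runs on CPU.

```python
import torch


def train_step(model, inputs, targets, optimizer, scaler, grad_clip):
    """One step with loss scaling: unscale before clipping so grad_clip is in true units."""
    optimizer.zero_grad(set_to_none=True)
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()
    # Divide the gradients by the current scale; without this, clipping sees scaled gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step() skips the update if unscale_ found inf/NaN gradients; update() adjusts the scale.
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# A disabled GradScaler is a pass-through, so the same code path runs on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
inputs = torch.randn(16, 4, device=device)
targets = torch.randn(16, 2, device=device)
losses = [train_step(model, inputs, targets, optimizer, scaler, 1.0).item() for _ in range(20)]
```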

Verified non-issues / already addressed

  • bypass_run_is_complete is correctly evaluated only on master and broadcast (training_loop.py L205–206, L255–256) — no rank divergence.
  • keys_to_learn is included in the canonicalized fingerprint (bypass_utils.py), so sweeps with the same model overrides but different learnable keys produce distinct experiment IDs.
  • _copy_auto_map_code_files (checkpoint_utils_hf.py) handles HF's repo_id--module.Class form via class_ref.split("--", 1)[-1].split(".")[0] — earlier review concern resolved.
  • load_and_shard_model / create_sharded_model accept trust_remote_code: bool | None, defaulting to descriptor.requires_trust_remote_code() — backward compatible and routed correctly to _get_auto_class_for_trust_remote_code.
  • save_bypass_checkpoint uses save_checkpoint_from_shards (gather-aware) rather than the unsharded variant — required for pipeline-parallel ranks; comment in code accurately captures the prior failure mode.
  • Atomic latest symlink update via tmp-symlink + Path.replace (bypass_checkpoint_utils.py L242–246) — concurrent readers never see a missing link.
  • find_latest_run_dir correctly excludes best-step-* and start-step-* and validates saving_completed exists before accepting any candidate.
  • Tied-embeddings handling in load_and_shard_model (last-rank embed_tokens retained for the load, replaced with DummyWTE after, plus tie_weights() on first/last) — load order with assign=True correctly accounted for.
  • Per-rank stitched optimizer/scaler files keyed by stitched_module_name — safe because each rank only owns disjoint blocks.
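A sketch of the atomic `latest` update on POSIX, where `rename(2)` replaces the existing link in one step. `point_latest_at` is a hypothetical name, and storing a relative target is a design choice, not necessarily what bypass_checkpoint_utils.py does.

```python
import os
import tempfile
from pathlib import Path


def point_latest_at(link: Path, target: Path) -> None:
    """Atomically repoint `link` at `target` using a temporary symlink plus rename."""
    tmp_link = link.with_name(f".{link.name}.tmp.{os.getpid()}")
    if tmp_link.is_symlink() or tmp_link.exists():
        tmp_link.unlink()
    # A relative target keeps the run directory relocatable.
    tmp_link.symlink_to(os.path.relpath(target, link.parent), target_is_directory=True)
    # On POSIX, rename(2) swaps the link in one step: readers see the old or new target.
    tmp_link.replace(link)


with tempfile.TemporaryDirectory() as tmp:
    run = Path(tmp)
    (run / "step-100").mkdir()
    (run / "step-200").mkdir()
    latest = run / "latest"
    point_latest_at(latest, run / "step-100")
    first = os.readlink(latest)
    point_latest_at(latest, run / "step-200")
    second = os.readlink(latest)
    resolved_ok = latest.resolve() == (run / "step-200").resolve()
```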

Decision

Per the review gate (≥1 CRITICAL ⇒ comment review, not approve), submitting as a comment review. The CRITICAL finding only fires when use_grad_scaling=True, but the fix is one line and the cost of merging without it is silent training-correctness degradation for any user who enables the flag — please address before merge.

🤖 Generated with Claude Code

Separius added 2 commits June 2, 2026 13:14
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>

Separius commented Jun 2, 2026

/ok to test 3d684db

Separius added 2 commits June 2, 2026 15:09
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>

Separius commented Jun 2, 2026

/claude review
Only review modelopt/torch/puzzletron folder changes

claude bot left a comment

Claude review passed — no blocking issues found. Verified the prior CRITICAL findings (resume skip-first-batches off-by-one, gradient clipping on scaled gradients) and IMPORTANT/SUGGESTION findings (overwrite parameter coupling, subblock_mamba fallback when block_configs is None) are all addressed. The distributed per-file gather in save_checkpoint_from_shards is sound; bypass_run_is_complete correctly guards on fingerprint AND on-disk existence; atomic latest-symlink updates via tmp + Path.replace; tied-embedding handling on the last rank is correct. LGTM.
