Ssameni/puzzletron bypass 2 core#1469
b8ed8ea to ce66fb2
Note: Reviews paused
It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough
Adds a PUZZLE bypass-distillation pipeline: stitched-model factory, keys-to-learn logic, distributed training and LR scheduler, checkpoint discovery/save/load with master-only metadata, HF sharded checkpoint improvements, package exports and example tweak, plus comprehensive unit tests.
Changes
- Bypass Distillation Feature Implementation
- Test Utilities and Infrastructure
- Comprehensive Unit Test Coverage
🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (4 passed)
470fe16 to 8d3db43
8d3db43 to 71edd2d
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1469 +/- ##
===========================================
- Coverage 77.43% 61.15% -16.29%
===========================================
Files 480 485 +5
Lines 52564 53650 +1086
===========================================
- Hits 40703 32808 -7895
- Misses 11861 20842 +8981
71edd2d to f878d8e
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 6
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py (1)
182-184: ⚡ Quick win
Gate the progress bar to rank 0.
Every rank enters this loop, so distributed checkpoint saves will emit one `tqdm` bar per rank. That gets noisy fast and obscures real save failures; please disable the bar on non-master ranks.
As per coding guidelines, "Develop with distributed processing in mind by using `print_rank_0` or `warn_rank_0` to avoid noisy logs and guarding shared side effects..."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py` around lines 182 - 184, The tqdm progress bar is created by every distributed rank because the loop over stitched_module_descriptors uses tqdm unconditionally; guard the progress bar so only rank 0 shows it. Replace the unconditional tqdm(...) in the loop over stitched_module_descriptors.items() with a rank-checked variant (e.g., compute is_rank0 via torch.distributed.get_rank()==0 or the repo's helper like print_rank_0/is_main_process) and either pass disable=not is_rank0 into tqdm or iterate plain stitched_module_descriptors.items() on non-master ranks; keep the loop body and variable names (stitched_module_name, stitched_module_descriptor) unchanged.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (2)
62-75: ⚡ Quick win
Declare this module's public surface with `__all__`.
This file adds public symbols, but the export surface is still implicit. Please make it explicit here so package re-exports stay stable and star-import-safe.
♻️ Suggested shape
 StitchedModulesProcessOwnership = list[int]
 SyncDistributedModelWeightsFn = Callable[[], None]
 Config = Mapping[str, Any]
 Args = Namespace
+__all__ = [
+    "StitchedModuleDescriptor",
+    "StitchedModulesProcessOwnership",
+    "bypass_factory_fn",
+]
+
 @dataclasses.dataclass
 class StitchedModuleDescriptor:
As per coding guidelines, "Define the public API with `__all__` at the top of each Python module" and "Define the public surface with `__all__` near the definitions."
Also applies to: 171-184
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 62 - 75, Add an explicit module export list __all__ that enumerates the public symbols declared in this file (e.g., "StitchedModulesProcessOwnership", "SyncDistributedModelWeightsFn", "Config", "Args", "StitchedModuleDescriptor" and any other public classes/functions like StitchedModule) and place it near these definitions (close to the dataclass and the type aliases) so the module's public API is explicit and stable for star-imports and package re-exports.
249-249: ⚡ Quick win
Hoist this internal import or document the deferral reason.
Deferring an internal import here pushes import failures to the first bypass run. If this is avoiding a real cycle, please add a brief comment naming that reason; otherwise move it to module scope.
♻️ Minimal cleanup
+from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
 from modelopt.torch.puzzletron.tools.logger import mprint
 from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model
 from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype
@@
-    from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
-
     runtime = Namespace(
As per coding guidelines, "Imports: keep imports at the top of the file unless there's a concrete reason ...; add a brief comment if deviating."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` at line 249, The inline import of deci_x_patcher (from modelopt.torch.puzzletron.anymodel.puzzformer) should be hoisted to module scope or documented why it's deferred; either move the import to the top of stitched_model_factory.py to fail fast on import errors, or if it truly avoids a cyclic import, add a brief comment above the local import referencing the cycle and the module(s) involved and keep the import local inside the function that uses deci_x_patcher. Ensure you reference deci_x_patcher and the local import site in your change so future readers know why it's not at module scope.
modelopt/torch/puzzletron/bypass_distillation/training_loop.py (2)
59-70: ⚡ Quick win
Add `__all__` for the new public entry points.
This module now exposes several public functions, but the API surface is still implicit. Please declare it explicitly so package re-exports do not drift.
♻️ Suggested shape
 from .data_classes import GlobalRank, IterNum, IterStatistics, TimeToSaveSignal
 from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership
+__all__ = [
+    "launch_bypass_distillation",
+    "train",
+    "run_bypassed_training",
+    "realize_bypass_checkpoints",
+]
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
As per coding guidelines, "Define the public API with `__all__` at the top of each Python module" and "Define the public surface with `__all__` near the definitions."
Also applies to: 83-84, 228-241, 772-773, 1195-1196
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines 59 - 70, Add an explicit __all__ near the top of the module (close to the public definitions) listing the public API symbols so exports don't drift; include the imported and exposed names such as "find_latest_run_dir", "load_local_state", "save_bypass_checkpoint", "bypass_run_is_complete", "get_distributed_modules_ownership", "get_pipeline_ownership_context", "load_bypass_state", "mark_bypass_run_completed", "set_experiment_dir", "set_experiment_id", and the data/class names "GlobalRank", "IterNum", "IterStatistics", "TimeToSaveSignal", "StitchedModuleDescriptor", "StitchedModulesProcessOwnership" (add any additional public functions from this file) and place the __all__ definition near the module-level definitions as per guidelines.
670-672: ⚡ Quick win
Hoist these internal imports or annotate why they must stay local.
These are internal modules rather than optional dependencies, so keeping them inside runtime paths defers import errors until the middle of training/validation. Please move them to module scope, or add a short comment if there is a concrete cycle/latency reason.
As per coding guidelines, "Imports: keep imports at the top of the file unless there’s a concrete reason ...; add a brief comment if deviating."
Also applies to: 939-944, 1088-1090, 1144-1146
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 24-35: Add an explicit __all__ list at the top of the module that
enumerates the intended public API (e.g., "load_bypass_state",
"update_bypass_checkpoint_state", "save_checkpoint_from_shards",
"StitchedModuleDescriptor", "ModelDescriptor", "aprint", "mprint", "json_dump")
and do not include internal helpers with leading underscores (such as
_save_local_file, _save_local_state); place the __all__ right after the imports
so star-imports only expose the listed names and internal functions remain
private.
- Around line 58-64: The code returns the symlink path `latest_dir` after
validation which can race if the symlink moves; change the return to the
resolved target so the caller gets the stable checkpoint directory (use
`latest_dir.resolve()` instead of `latest_dir`) — update the return in the block
that validates `latest_dir` to return the resolved path string so callers open
the exact checkpoint verified by the prior checks.
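A minimal sketch of the symlink fix described above, not the PR's actual code. The helper name `find_latest_checkpoint_dir`, the `latest` link name, and the `args.json` marker check are assumptions made for illustration; the only point is that the path is resolved once and the same concrete path is both validated and returned. Resolving before validating is a slightly stricter variant of the suggestion.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint_dir(run_dir: Path) -> Optional[str]:
    """Return the concrete checkpoint directory behind the `latest` symlink.

    Hypothetical helper: the link name and the marker-file check are
    illustrative assumptions, not the PR's real validation logic.
    """
    latest_dir = run_dir / "latest"
    if not latest_dir.exists():
        return None
    # Resolve once, then validate and return the same concrete path, so a
    # concurrent retarget of the symlink cannot change what the caller opens.
    resolved = latest_dir.resolve()
    if not (resolved / "args.json").is_file():
        return None
    return str(resolved)
```

Returning the resolved target means the caller never dereferences the symlink a second time.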
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py`:
- Around line 24-32: Add an explicit __all__ list at the top of the module to
declare the public API (for example include "BYPASS_STATE_FILENAME",
"BYPASS_SUBBLOCK_KEYS_TO_LEARN", and any helper functions you intend to export
such as json_dump/json_load if they are re-exposed); place the __all__
definition before other module-level code so future star-imports and package
re-exports are safe and only the intended symbols (e.g., BYPASS_STATE_FILENAME,
BYPASS_SUBBLOCK_KEYS_TO_LEARN, json_dump, json_load) are exported.
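For readers unfamiliar with the effect of `__all__`, the following self-contained sketch builds a throwaway in-memory module and star-imports from it. The module name `demo_bypass_utils`, the helper functions, and the constant's value are placeholders invented for the demo; only the name `BYPASS_STATE_FILENAME` comes from the comment above.

```python
import sys
import types

# Source of a throwaway module; everything except the constant's name is made up.
source = '''
__all__ = ["BYPASS_STATE_FILENAME", "public_helper"]

BYPASS_STATE_FILENAME = "bypass_state.json"  # placeholder value, not from the PR

def public_helper():
    return "public"

def _private_helper():
    return "private"

def unlisted_helper():
    return "unlisted"
'''

demo = types.ModuleType("demo_bypass_utils")
exec(source, demo.__dict__)
sys.modules["demo_bypass_utils"] = demo

# A star-import only binds the names listed in __all__.
namespace = {}
exec("from demo_bypass_utils import *", namespace)
exported = sorted(name for name in namespace if not name.startswith("__"))
print(exported)
```

Note that `unlisted_helper` has no leading underscore, yet it is still excluded once `__all__` is defined, which is what keeps package re-exports from drifting.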
In `@modelopt/torch/puzzletron/bypass_distillation/data_classes.py`:
- Around line 18-22: Add an explicit __all__ near the top of the module listing
the public symbols (e.g., __all__ = ["IterNum", "GlobalRank"]) so the module's
public API is explicit; place it after the imports and do not include
internal/imported modules like dataclasses in the list, and update it to include
any dataclass names defined later in this file if applicable.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 790-792: The code currently derives trust_remote_code from
descriptor.requires_trust_remote_code(), which must be caller-controlled; update
calls that pass trust_remote_code (ModelDescriptorFactory.get /
descriptor.requires_trust_remote_code() usage) to instead read an explicit
opt-in flag (e.g., cfg.trust_remote_code with default False) and pass that to
load_model_config(...) and AutoTokenizer.from_pretrained(...); keep
descriptor.requires_trust_remote_code() for informational checks only (log or
warn if descriptor requires but cfg.trust_remote_code is False) but do not use
it to set the effective trust flag so callers retain explicit control.
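The opt-in behaviour requested above could look roughly like this. It is a sketch under assumptions: the function name, its parameters, and the warning text are hypothetical, and the real code would read `cfg.trust_remote_code` and `descriptor.requires_trust_remote_code()` rather than plain booleans.

```python
from typing import Callable


def resolve_trust_remote_code(
    cfg_trust_remote_code: bool = False,
    descriptor_requires: bool = False,
    warn: Callable[[str], None] = print,
) -> bool:
    """Return the effective flag; only the caller's explicit opt-in enables it."""
    if descriptor_requires and not cfg_trust_remote_code:
        # The descriptor's requirement is informational only: warn, never enable.
        warn(
            "descriptor requires trust_remote_code, but cfg.trust_remote_code is "
            "false; loading may fail until the caller opts in"
        )
    return bool(cfg_trust_remote_code)
```

The returned value is what would be passed on to `load_model_config(...)` and `AutoTokenizer.from_pretrained(...)`, so remote code runs only when the caller asked for it.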
In `@tests/_test_utils/torch/puzzletron/utils.py`:
- Around line 33-61: Add an explicit module export list by defining __all__ at
the top of the module and include the new public symbol name
"PUZZLETRON_FAMILIES" in it; update or create the __all__ list near the module
header so the public API exports PUZZLETRON_FAMILIES (refer to the
PUZZLETRON_FAMILIES symbol) and ensure __all__ is a list of string names for any
other intended public symbols in this file.
---
Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 182-184: The tqdm progress bar is created by every distributed
rank because the loop over stitched_module_descriptors uses tqdm
unconditionally; guard the progress bar so only rank 0 shows it. Replace the
unconditional tqdm(...) in the loop over stitched_module_descriptors.items()
with a rank-checked variant (e.g., compute is_rank0 via
torch.distributed.get_rank()==0 or the repo's helper like
print_rank_0/is_main_process) and either pass disable=not is_rank0 into tqdm or
iterate plain stitched_module_descriptors.items() on non-master ranks; keep the
loop body and variable names (stitched_module_name, stitched_module_descriptor)
unchanged.
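One way to gate the bar, sketched with the standard library only. It assumes a torchrun-style `RANK` environment variable instead of `torch.distributed.get_rank()` or a repo helper, and the function names are hypothetical. `desc` and `disable` are existing `tqdm` parameters.

```python
import os


def is_rank_zero() -> bool:
    """True on the master process; assumes a torchrun-style RANK env var."""
    return int(os.environ.get("RANK", "0")) == 0


def progress_bar_kwargs(description: str) -> dict:
    """Keyword arguments for tqdm that silence the bar on non-master ranks."""
    return {"desc": description, "disable": not is_rank_zero()}


# Intended use (tqdm itself is not imported in this sketch):
#   for name, descriptor in tqdm(descriptors.items(), **progress_bar_kwargs("Saving")):
#       ...  # loop body unchanged
```

Passing `disable` keeps the loop identical on every rank, which avoids a second code path for non-master processes.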
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 62-75: Add an explicit module export list __all__ that enumerates
the public symbols declared in this file (e.g.,
"StitchedModulesProcessOwnership", "SyncDistributedModelWeightsFn", "Config",
"Args", "StitchedModuleDescriptor" and any other public classes/functions like
StitchedModule) and place it near these definitions (close to the dataclass and
the type aliases) so the module's public API is explicit and stable for
star-imports and package re-exports.
- Line 249: The inline import of deci_x_patcher (from
modelopt.torch.puzzletron.anymodel.puzzformer) should be hoisted to module scope
or documented why it's deferred; either move the import to the top of
stitched_model_factory.py to fail fast on import errors, or if it truly avoids a
cyclic import, add a brief comment above the local import referencing the cycle
and the module(s) involved and keep the import local inside the function that
uses deci_x_patcher. Ensure you reference deci_x_patcher and the local import
site in your change so future readers know why it's not at module scope.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 59-70: Add an explicit __all__ near the top of the module (close
to the public definitions) listing the public API symbols so exports don't
drift; include the imported and exposed names such as "find_latest_run_dir",
"load_local_state", "save_bypass_checkpoint", "bypass_run_is_complete",
"get_distributed_modules_ownership", "get_pipeline_ownership_context",
"load_bypass_state", "mark_bypass_run_completed", "set_experiment_dir",
"set_experiment_id", and the data/class names "GlobalRank", "IterNum",
"IterStatistics", "TimeToSaveSignal", "StitchedModuleDescriptor",
"StitchedModulesProcessOwnership" (add any additional public functions from this
file) and place the __all__ definition near the module-level definitions as per
guidelines.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 9c9cb0fd-e208-46ef-9312-459ad201d97e
📒 Files selected for processing (17)
- examples/megatron_bridge/distill.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
f878d8e to 735c1d3
🧹 Nitpick comments (2)
tests/unit/torch/puzzletron/test_launch_bypass_distillation.py (2)
82-83: 💤 Low value
Avoid hardcoding source line numbers in test docstrings.
References to "line 85 of training_loop.py" and "line 99" will silently go stale as `training_loop.py` evolves and mislead future readers. Refer to the function/branch by name instead (e.g., the truthiness check on `bypass.configs`, the `if "keys_to_learn" in override` guard).
Also applies to: 136-137
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/puzzletron/test_launch_bypass_distillation.py` around lines 82 - 83, Update the test docstring to remove hardcoded source line numbers and instead reference code symbols/branches by name: replace "line 85 of training_loop.py" with a description like "the truthiness check on bypass.configs in training_loop.py" and replace "line 99" (and the similar references at 136-137) with "the if 'keys_to_learn' in override guard" (or the relevant function/branch name). Ensure the docstring clearly identifies the behavior being tested using symbol names (e.g., bypass.configs truthiness check, the if "keys_to_learn" guard) rather than numeric line references so it remains correct as training_loop.py changes.
159-183: 💤 Low value
Make the warning-log assertion less brittle
- `any("cfg.trust_remote_code is false" in message ...)` couples the test to exact log wording; asserting that a warning was emitted (and/or matching a stable token like `trust_remote_code`/`requires_trust_remote_code`) would be more resilient.
- The current `messages.append` works today because `_resolve_trust_remote_code` calls `mprint` with a single string argument, but capturing `*args` would make the test robust to any future change in `mprint` call signature.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/unit/torch/puzzletron/test_launch_bypass_distillation.py` around lines 159 - 183, Update the brittle log assertion and message capture in the tests for tl._resolve_trust_remote_code: instead of monkeypatching tl.mprint to messages.append, replace it with a small wrapper that captures all args (e.g., append(" ".join(map(str, args))) ) so the test is robust to changes in mprint signature; then assert that any message contains a stable token such as "trust_remote_code" or "requires_trust_remote_code" (rather than the full phrase "cfg.trust_remote_code is false") in test_trust_remote_code_defaults_to_false_even_when_descriptor_requires_it, leaving test_trust_remote_code_uses_explicit_cfg_opt_in expecting no messages.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@tests/unit/torch/puzzletron/test_launch_bypass_distillation.py`:
- Around line 82-83: Update the test docstring to remove hardcoded source line
numbers and instead reference code symbols/branches by name: replace "line 85 of
training_loop.py" with a description like "the truthiness check on
bypass.configs in training_loop.py" and replace "line 99" (and the similar
references at 136-137) with "the if 'keys_to_learn' in override guard" (or the
relevant function/branch name). Ensure the docstring clearly identifies the
behavior being tested using symbol names (e.g., bypass.configs truthiness check,
the if "keys_to_learn" guard) rather than numeric line references so it remains
correct as training_loop.py changes.
- Around line 159-183: Update the brittle log assertion and message capture in
the tests for tl._resolve_trust_remote_code: instead of monkeypatching tl.mprint
to messages.append, replace it with a small wrapper that captures all args
(e.g., append(" ".join(map(str, args))) ) so the test is robust to changes in
mprint signature; then assert that any message contains a stable token such as
"trust_remote_code" or "requires_trust_remote_code" (rather than the full phrase
"cfg.trust_remote_code is false") in
test_trust_remote_code_defaults_to_false_even_when_descriptor_requires_it,
leaving test_trust_remote_code_uses_explicit_cfg_opt_in expecting no messages.
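The capture wrapper suggested above might be sketched as follows. The message text passed to `capture` is invented for the demo; in the real test the wrapper would presumably be installed with pytest's `monkeypatch.setattr(tl, "mprint", capture)`, which is not done here so the snippet stays standalone.

```python
messages = []


def capture(*args, **kwargs):
    """Stand-in for a patched logger: join all positional args into one string."""
    messages.append(" ".join(map(str, args)))


# Works whether the logger is called with one string or several pieces.
capture("cfg.trust_remote_code is false;", "descriptor requires it")

# Assert on a stable token rather than the full log sentence.
assert any("trust_remote_code" in message for message in messages)
```

Accepting `**kwargs` as well means a later keyword argument on the logger call would not break the test either.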
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: fd245e2b-4e43-4c6e-96dc-f0c332631853
📒 Files selected for processing (18)
- examples/megatron_bridge/distill.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
🚧 Files skipped from review as they are similar to previous changes (15)
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
- tests/_test_utils/torch/puzzletron/utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- examples/megatron_bridge/distill.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
735c1d3 to abbe9ac
/ok to test abbe9ac
/claude review
abbe9ac to 703ce91
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/puzzletron/replacement_library/build_replacement_library.py`:
- Around line 463-479: The fallback in the exception handler for
learned_subblocks_from_keys_to_learn treats mixed legacy keys (keys_to_learn
containing both "mlp" and "attn") as an error; update the handler in
build_replacement_library.py so that when legacy_keys contains both has_mlp and
has_attn you set subblocks_to_extract to include both subblocks (e.g., ["ffn",
"attention"]) instead of raising a ValueError, while preserving the existing
single-only branches and the empty/unknown case handling.
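A rough sketch of the fallback described above, under stated assumptions: the function name is hypothetical, legacy keys are matched by substring, and the empty/unknown case is shown as a `ValueError` only as a placeholder, since the review does not say how the existing code handles it.

```python
def legacy_subblocks_to_extract(legacy_keys):
    """Map legacy keys_to_learn values to subblock names (illustrative only)."""
    keys = [legacy_keys] if isinstance(legacy_keys, str) else list(legacy_keys)
    has_mlp = any("mlp" in key for key in keys)
    has_attn = any("attn" in key for key in keys)
    if has_mlp and has_attn:
        # Mixed legacy keys: extract both subblocks instead of raising.
        return ["ffn", "attention"]
    if has_mlp:
        return ["ffn"]
    if has_attn:
        return ["attention"]
    # Placeholder for the existing empty/unknown handling (assumed to be an error).
    raise ValueError(f"cannot infer subblocks from legacy keys: {keys!r}")
```

For example, `legacy_subblocks_to_extract(["mlp", "attn"])` returns `["ffn", "attention"]`, while the single-key branches behave as before.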
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 595-642: The code in _copy_auto_map_code_files collects
module_names by taking class_ref.split(".")[0], but PretrainedConfig.auto_map
can contain repo-id prefixes like "org/repo--modeling_x.Foo", so the module name
extraction must strip any repo-id portion before splitting by "."; update the
module_names comprehension (used with _iter_auto_map_class_refs) to first take
class_ref.split("--", 1)[-1] (to handle repo_id--class_ref), then split on "."
to get the module base, and keep the existing _module_name_re validation and
warning logic; reference symbols: _copy_auto_map_code_files,
_iter_auto_map_class_refs, module_names, _module_name_re.
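The module-name extraction could be sketched like this. The helper name and the exact regular expression are assumptions (the PR's `_module_name_re` is not shown in this thread); the `split("--", 1)[-1]` step is the one the comment asks for.

```python
import re
from typing import Optional

# Assumed pattern: a plain Python identifier. The PR's real regex may differ.
_MODULE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def module_name_from_class_ref(class_ref: str) -> Optional[str]:
    """Extract the module base name from an auto_map class reference."""
    # Drop an optional "repo_id--" prefix such as "org/repo--modeling_x.Foo".
    local_ref = class_ref.split("--", 1)[-1]
    module_name = local_ref.split(".")[0]
    # Reject anything that is not a bare module name (e.g. path fragments).
    return module_name if _MODULE_NAME_RE.match(module_name) else None
```

Both `"modeling_x.Foo"` and `"org/repo--modeling_x.Foo"` yield `"modeling_x"`, so the existing validation and warning logic can stay unchanged after the prefix is stripped.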
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 64b36a59-cd6f-4aac-aa07-68afea7a2894
📒 Files selected for processing (20)
- examples/megatron_bridge/distill.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
- tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
🚧 Files skipped from review as they are similar to previous changes (15)
- examples/megatron_bridge/distill.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
703ce91 to 4b016d0
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py`:
- Around line 129-143: Normalize and canonicalize the
model_factory["keys_to_learn"] value before it gets incorporated into the run
identity: detect single-string vs list inputs, convert single strings to a
one-element list, flatten/normalize subblock identifiers, sort the list
deterministically and convert to an immutable representation (e.g., a tuple or
sorted tuple of normalized items) so semantically equivalent inputs
("entire_block" vs ["entire_block"], reordered lists) yield the same
fingerprint; apply this normalization wherever keys_to_learn is read in
bypass_utils.py (including the places that build the returned dict around
model_factory and the other occurrences mentioned) so the hashed/slugged
representation uses the canonical form.
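A small sketch of the canonicalization requested above. The function name is hypothetical, and stripping whitespace and dropping duplicates are assumptions about what "normalize" should mean here; nested lists are not handled.

```python
def canonical_keys_to_learn(value):
    """Return a deterministic, hashable form of a keys_to_learn setting."""
    # Accept either a single string or an iterable of strings.
    items = [value] if isinstance(value, str) else list(value)
    # Strip whitespace and drop duplicates (both are assumptions of this sketch).
    normalized = {str(item).strip() for item in items}
    # Sorting makes the result independent of the order the user wrote.
    return tuple(sorted(normalized))
```

With this, `"entire_block"` and `["entire_block"]` both map to `("entire_block",)`, and reordered lists produce the same tuple, so a hash or slug built from it identifies the same run.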
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 900-907: The resume logic computes resume_skip_first_batches
incorrectly by adding saved_skip + resume_cfg.iter_num, which double-counts one
batch when resuming from "final-step-*" checkpoints because iter_num there is
already the next step; update the calculation in the block that sets
resume_skip_first_batches (and the similar block around the other occurrence) to
use saved_skip + max(0, resume_cfg.iter_num - 1) (or subtract 1 from
resume_cfg.iter_num when it is > 0) so the dataloader position aligns with the
restored model/optimizer state; locate the symbol resume_skip_first_batches and
change its assignment accordingly in the code that follows
_get_resume_state_path and where resume_cfg.iter_num is used.
- Around line 823-827: The short-circuiting check using
bypass_run_is_complete(cfg) must be made rank-consistent: have only the master
rank evaluate bypass_run_is_complete(cfg) after set_experiment_id(cfg) and
set_experiment_dir(cfg), then broadcast the resulting boolean to all ranks (same
pattern used in launch_bypass_distillation()), and have all ranks use that
broadcasted value to decide whether to return; also replace the direct mprint
with print_rank_0 to avoid noisy logs and ensure any shared side-effects (e.g.,
skipping writes) are only performed or suppressed based on the broadcasted
result so no rank hits the next barrier unexpectedly.
- Around line 510-532: The current grad-clipping branches sync GPU→CPU by using
Python-side conditionals and .item(); instead, keep the "did we clip?" logic
on-device and only sync once per block to update
cfg.bypass.training.clipping_count. Specifically: for the norm branch, treat the
return of torch.nn.utils.clip_grad_norm_ as a tensor (or wrap it in
torch.as_tensor on the stitched_module device), compute a GPU tensor boolean
clipped_mask = grad_norm > grad_clip, convert that to an integer tensor and
accumulate into a local on-device counter tensor; for the value branch, avoid
.item() by computing max_abs_grad_tensor = torch.stack(grad_maxes).max(), do the
tensor compare on-GPU (e.g. max_abs_grad_tensor > grad_clip) and accumulate
similarly; after processing the stitched block, call .item() once to add the
summed on-device counter to cfg.bypass.training.clipping_count. Use
stitched_module.parameters() device for tensor creation to ensure no implicit
host sync.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: a7f8d701-5245-4c05-8207-33bb418aef4e
📒 Files selected for processing (20)
- examples/megatron_bridge/distill.py
- modelopt/torch/puzzletron/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- tests/unit/torch/puzzletron/test_checkpoint_utils_hf.py
- tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
- tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
✅ Files skipped from review due to trivial changes (1)
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
🚧 Files skipped from review as they are similar to previous changes (15)
- modelopt/torch/puzzletron/__init__.py
- tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py
- examples/megatron_bridge/distill.py
- tests/unit/torch/puzzletron/test_replacement_library_bypass_config.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/tools/sharded_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py
- tests/_test_utils/torch/puzzletron/utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py
- tests/unit/torch/puzzletron/test_launch_bypass_distillation.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
/claude review
/claude review
/claude review
/ok to test ddde4e5
/claude review
Review summary
Reviewed the bypass-distillation core engine for algorithmic correctness, mode/state machine handling, export semantics, compatibility, and performance. CodeRabbit's review covers the style/security surface (e.g., torch.load(weights_only=False), trust_remote_code=True, hardcoded paths) so I focused on logic.
Findings
CRITICAL (1)
- Algorithm: _get_resume_skip_first_batches off-by-one (training_loop.py:130): On every periodic-step / time-based / step-interval resume, the formula returns one fewer than the total batches consumed at save time. The save snapshots iter_num=N before iter_num += 1, so the resume skip count must equal saved_skip + N, but the function returns saved_skip + N - 1. The first iter after resume re-trains on the same batch that was just trained on before the save. The final-step-* path is incidentally correct because its args.json was snapshotted after the loop already advanced past the last iter, but the loop also exits immediately on that path so the skip count is moot. Inline diff suggested.
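The off-by-one can be reproduced with a toy loop. This is a minimal sketch, assuming iter_num is 1-based and the snapshot is written after training on a batch but before the increment; run_until_save, resume_skip_buggy and resume_skip_fixed are hypothetical names for illustration, not the functions in training_loop.py.

```python
from itertools import islice


def run_until_save(batches, skip, save_at_iter):
    """Train on `batches` after skipping `skip` of them; stop at the save point.

    Returns the list of batches trained on and the snapshot written at save time.
    Assumption: iter_num starts at 1 and is snapshotted BEFORE `iter_num += 1`.
    """
    stream = islice(batches, skip, None)
    trained = []
    iter_num = 1
    for batch in stream:
        trained.append(batch)  # "train" on this batch
        if iter_num == save_at_iter:
            snapshot = {"iter_num": iter_num, "skip_first_batches": skip}
            return trained, snapshot
        iter_num += 1
    raise ValueError("ran out of batches before the save point")


def resume_skip_buggy(snapshot):
    # Formula described in the finding: one fewer than the batches consumed.
    return snapshot["skip_first_batches"] + snapshot["iter_num"] - 1


def resume_skip_fixed(snapshot):
    # Batches consumed at save time = earlier skip + iterations completed.
    return snapshot["skip_first_batches"] + snapshot["iter_num"]


batches = list(range(10))
trained, snapshot = run_until_save(batches, skip=2, save_at_iter=3)
first_after_buggy_resume = batches[resume_skip_buggy(snapshot)]
first_after_fixed_resume = batches[resume_skip_fixed(snapshot)]
print(trained, first_after_buggy_resume, first_after_fixed_resume)
```

Under these assumptions the buggy formula resumes on the last batch already trained before the save, while the fixed formula resumes on the next unseen batch.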
IMPORTANT (1)
- ModeState: silent partial-checkpoint survival (bypass_checkpoint_utils.py:232): _save_local_state(overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints) conflates two unrelated concerns. delete_old_checkpoints controls cleanup of older sibling subdirs; it has nothing to do with overwriting files in the current subdir. When set to False, _save_local_file silently skips any pre-existing stitched/{name}.optimizer_state.pth from a partial prior save, then saving_completed is touched and HF weights are rewritten on top, producing a "valid" checkpoint that pairs fresh weights with stale optimizer state. Inline suggestion to drop the overwrite parameter.
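The failure mode can be illustrated without any training code. The sketch below is an assumption-laden stand-in: save_local_file and save_checkpoint are hypothetical helpers that only mimic the skip-if-exists behaviour described above, and the file names and string payloads are placeholders rather than the real checkpoint layout.

```python
import tempfile
from pathlib import Path


def save_local_file(path: Path, payload: str, overwrite: bool) -> bool:
    """Hypothetical stand-in: skip silently when the file exists and overwrite is off."""
    if path.exists() and not overwrite:
        return False
    path.write_text(payload)
    return True


def save_checkpoint(ckpt_dir: Path, step: int, overwrite: bool) -> None:
    """Save optimizer state (skippable), then weights (always rewritten), then the marker."""
    save_local_file(ckpt_dir / "block_0.optimizer_state.pth", f"optimizer@{step}", overwrite)
    (ckpt_dir / "weights.txt").write_text(f"weights@{step}")
    (ckpt_dir / "saving_completed").touch()


def demo(overwrite: bool):
    with tempfile.TemporaryDirectory() as tmp:
        ckpt_dir = Path(tmp)
        # Leftover from a partial prior save into the same subdir.
        (ckpt_dir / "block_0.optimizer_state.pth").write_text("optimizer@100")
        save_checkpoint(ckpt_dir, step=200, overwrite=overwrite)
        return (
            (ckpt_dir / "block_0.optimizer_state.pth").read_text(),
            (ckpt_dir / "weights.txt").read_text(),
            (ckpt_dir / "saving_completed").exists(),
        )


print(demo(overwrite=False))  # stale optimizer state paired with fresh weights
print(demo(overwrite=True))   # optimizer state and weights from the same step
```

With overwrite tied to an unrelated cleanup flag, the first case is reachable whenever that flag is False; always overwriting inside the current subdir removes the mismatch.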
SUGGESTION (1)
- Algorithm: _param_names_for_subblock_key mamba fallback on non-puzzletron configs (stitched_model_factory.py:115-141): When block_configs is None (raw HF model), subblock_mamba falls through and returns all attention params instead of nothing. Low-probability misuse but silent. Either raise or mprint-warn.
Other notes
- Per-file gather in _save_checkpoint_from_distributed_shards (checkpoint_utils_hf.py) is the right pattern for pipeline-parallel state-dict assembly: it avoids the OOM that the unsharded save_checkpoint would hit on rank 0.
- bypass_run_is_complete correctly guards on both config_fingerprint AND on-disk realized/symlink existence, so a manually-deleted checkpoint dir won't be falsely treated as complete.
- expected_bypass_runs and the dispatcher both reset experiment_id=None before set_experiment_id, keeping sweep entries from inheriting each other's IDs (covered by the new unit tests).
- The 1 CRITICAL is currently pinned by test_resume_skip_first_batches_uses_completed_iter_count in the new unit-test file; the test will need updating along with the fix.
Cannot approve until the CRITICAL is addressed.
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
/ok to test b7b7586
/claude review
Deep review summary — bypass distillation core (PR 2/3)
Findings: 1 CRITICAL · 0 IMPORTANT · 0 SUGGESTION
CRITICAL
- Algorithm: gradient clipping on scaled gradients (training_loop.py L561–574): when cfg.bypass.training.use_grad_scaling=True, _clip_stitched_module_grads runs on still-scaled gradients because grad_scaler.unscale_(optimizer) is never called between grad_scaler.scale(loss).backward() and clipping. With a default fp16 init scale of 2¹⁶ and grad_clip=1.0, the true gradient norm is effectively clipped to ~1.5e-5, which silently corrupts training with no error. The clipping_count metric (_clip_stitched_module_grads L162–172) is also degraded since it compares the scaled norm against an unscaled threshold. Masked today by the default use_grad_scaling=False (bf16 path), but the knob is exposed in config and the standard PyTorch AMP pattern documents the required ordering. Inline comment with suggested fix posted.
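The ~1.5e-5 figure follows from simple arithmetic, which the pure-Python sketch below reproduces. It is an illustration only: clip_by_norm mirrors the usual norm-clipping rule (coefficient max_norm / (norm + 1e-6), capped at 1), effective_norm is a hypothetical helper, and the gradient values are made up. In real code the corresponding fix is the documented AMP ordering: scale(loss).backward(), then unscale_(optimizer), then clip, then step and update.

```python
import math


def clip_by_norm(grads, max_norm):
    """Scale grads so their L2 norm is at most max_norm (norm-clipping rule)."""
    norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (norm + 1e-6))
    return [g * coef for g in grads]


def effective_norm(true_grads, scale, max_norm, unscale_first):
    """L2 norm of the gradient the optimizer finally applies, in true (unscaled) units."""
    scaled = [g * scale for g in true_grads]  # what backward() leaves behind under loss scaling
    if unscale_first:
        # Correct order: unscale, then clip against the threshold.
        applied = clip_by_norm([g / scale for g in scaled], max_norm)
    else:
        # Order described in the finding: clip the scaled grads; unscaling happens later.
        applied = [g / scale for g in clip_by_norm(scaled, max_norm)]
    return math.sqrt(sum(g * g for g in applied))


true_grads = [0.3, 0.4]  # true norm 0.5, below the clip threshold
scale = 2.0 ** 16        # fp16 init scale quoted in the finding
wrong = effective_norm(true_grads, scale, max_norm=1.0, unscale_first=False)
right = effective_norm(true_grads, scale, max_norm=1.0, unscale_first=True)
print(f"clip before unscale: {wrong:.3e}")
print(f"unscale before clip: {right:.3e}")
```

Clipping before unscaling caps the applied norm at roughly grad_clip / scale, about 1/65536 here, even though the true norm of 0.5 never needed clipping.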
Verified non-issues / already addressed
- bypass_run_is_complete is correctly evaluated only on master and broadcast (training_loop.py L205–206, L255–256): no rank divergence.
- keys_to_learn is included in the canonicalized fingerprint (bypass_utils.py), so sweeps with the same model overrides but different learnable keys produce distinct experiment IDs.
- _copy_auto_map_code_files (checkpoint_utils_hf.py) handles HF's repo_id--module.Class form via class_ref.split("--", 1)[-1].split(".")[0]; earlier review concern resolved.
- load_and_shard_model / create_sharded_model accept trust_remote_code: bool | None, defaulting to descriptor.requires_trust_remote_code(); backward compatible and routed correctly to _get_auto_class_for_trust_remote_code.
- save_bypass_checkpoint uses save_checkpoint_from_shards (gather-aware) rather than the unsharded variant, which is required for pipeline-parallel ranks; the comment in code accurately captures the prior failure mode.
- Atomic latest symlink update via tmp-symlink + Path.replace (bypass_checkpoint_utils.py L242–246): concurrent readers never see a missing link.
- find_latest_run_dir correctly excludes best-step-* and start-step-* and validates saving_completed exists before accepting any candidate.
- Tied-embeddings handling in load_and_shard_model (last-rank embed_tokens retained for the load, replaced with DummyWTE after, plus tie_weights() on first/last): load order with assign=True correctly accounted for.
- Per-rank stitched optimizer/scaler files keyed by stitched_module_name: safe because each rank only owns disjoint blocks.
Decision
Per the review gate (≥1 CRITICAL ⇒ comment review, not approve), submitting as a comment review. The CRITICAL finding only fires when use_grad_scaling=True, but the fix is one line and the cost of merging without it is silent training-correctness degradation for any user who enables the flag — please address before merge.
🤖 Generated with Claude Code
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
/ok to test 3d684db
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
/claude review
Claude review passed — no blocking issues found. Verified the prior CRITICAL findings (resume skip-first-batches off-by-one, gradient clipping on scaled gradients) and IMPORTANT/SUGGESTION findings (overwrite parameter coupling, subblock_mamba fallback when block_configs is None) are all addressed. The distributed per-file gather in save_checkpoint_from_shards is sound; bypass_run_is_complete correctly guards on fingerprint AND on-disk existence; atomic latest-symlink updates via tmp + Path.replace; tied-embedding handling on the last rank is correct. LGTM.
Summary
This is PR 2 of 3 in the Puzzletron bypass/local-distillation stack.
This PR adds the bypass distillation core engine. It builds on PR 1’s shared infrastructure, but it does not yet wire bypass into
the full Puzzletron pipeline.
Stack:
- ssameni/puzzletron-bypass-1-prereqs: shared prerequisites
- ssameni/puzzletron-bypass-3-integration: Puzzletron integration, configs, docs, GPU coverage
What Changed
- modelopt.torch.puzzletron.bypass_distillation.
- rank 0.
- keys_to_learn to subblock-level targets only:
  - entire_block
  - subblock_attention
  - subblock_ffn
  - subblock_mamba
- trust_remote_code / auto_map handling for checkpoint config saves.
Why
Bypass distillation is the local-distillation stage used to train pruned/reconfigured blocks before they are added to the
replacement library.
Keeping the core engine separate from Puzzletron pipeline wiring makes this PR reviewable on its own: reviewers can focus on
distributed training, checkpoint/resume semantics, and Sewing Kit stitching behavior without also reviewing configs/docs/pipeline
integration.
Tests
Added focused unit coverage for:
- keys_to_learn subblock selection
Summary by CodeRabbit
New Features
Tests