Add Minitron pruning support for Gemma3 via Megatron-Bridge #1604
Gemma3's Megatron-Bridge implementation subclasses several megatron-core
modules (overriding forward/__init__), so they are not matched to their
parents by DMRegistry and pruning's convert_to_dynamic raised KeyError.
Changes:
- nas/plugins/mbridge.py (new): register dynamic modules for Gemma3's
custom layers - Gemma3LanguageModelEmbedding, Gemma3SelfAttention, and
the fused post-LN TERowParallelLinearLayerNorm (attention linear_proj +
MLP linear_fc2). The post_layernorm is converted dynamic and sliced by
the linear's output_size (== hidden_size).
- nas/plugins/megatron.py: add two reuse seams in _DynamicSelfAttention -
build the core-attention dynamic class from type(self.core_attention)
(preserves Gemma3TEDotProductAttention) and extract an overridable
_convert_linear_proj hook. No behavior change for megatron-core models.
- utils/plugins/mbridge.py: load_mbridge_model_from_hf now only overrides
the layer spec for MoE models (to disable grouped GEMM); dense models
keep the bridge's native spec so gemma3_layer_spec survives.
- nas/plugins/__init__.py: import mbridge under import_plugin("megatron.bridge").
Tests:
- Add get_tiny_gemma3 / create_tiny_gemma3_dir helpers.
- Parametrize test_prune_minitron over qwen3 and gemma3.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
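The registry mismatch described in this commit message can be illustrated with a minimal sketch. It assumes the matching rule quoted in the PR description (`nn_cls.forward is parent.forward`); the classes, `REGISTRY`, and `lookup` below are hypothetical stand-ins, not the real megatron-core or ModelOpt code.

```python
# Hypothetical stand-ins: none of these are real megatron-core or ModelOpt classes.
class ParentAttention:
    def forward(self, x):
        return x


class PlainSubclass(ParentAttention):
    """Subclass that inherits forward unchanged."""


class CustomSubclass(ParentAttention):
    """Subclass that overrides forward, as the Gemma3 bridge layers do."""

    def forward(self, x):
        return 2 * x


# Hypothetical registry mapping a registered class to the name of its dynamic class.
REGISTRY = {ParentAttention: "DynamicParentAttention"}


def lookup(cls):
    """Match a subclass to a registered parent only if forward is inherited as-is."""
    for registered, dynamic_name in REGISTRY.items():
        if cls is registered:
            return dynamic_name
        if issubclass(cls, registered) and cls.forward is registered.forward:
            return dynamic_name
    # No match: the caller sees a KeyError, as convert_to_dynamic did for Gemma3.
    raise KeyError(cls.__name__)
```

Under this rule `PlainSubclass` resolves to the parent's dynamic class while `CustomSubclass` raises `KeyError`, which is why the custom Gemma3 layers need their own registrations.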
📝 Walkthrough
Adds Minitron pruning support for Megatron-Bridge Gemma3 by refactoring Megatron attention projection conversion, adding a Gemma3 NAS plugin with fused post-LN dynamic layers, refining loader layer-spec overrides, updating the prune script tokenizer detection, and extending tests and test utilities.
Changes
Megatron-Bridge Gemma3 Pruning Support
🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Warning
CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.
Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.
Actionable comments posted: 2
🧹 Nitpick comments (2)
modelopt/torch/nas/plugins/__init__.py (1)
18-27: ⚡ Quick win
Add an explicit package `__all__`.
This `__init__.py` now expands the package export surface again, but it still relies on implicit namespace exports. Please define `__all__` here so star-imports stay explicit and stable.
♻️ Proposed fix

     from modelopt.torch.utils import import_plugin
     
    +__all__ = []
    +
     from .torch import *
    +from .torch import __all__ as _torch_all
    +
    +__all__ += _torch_all
     
     with import_plugin("megatron"):
         from .megatron import *
    +    from .megatron import __all__ as _megatron_all
         from .megatron_model_stats import *
    +    from .megatron_model_stats import __all__ as _megatron_model_stats_all
    +
    +    __all__ += _megatron_all + _megatron_model_stats_all
     
     with import_plugin("megatron.bridge"):
         from .mbridge import *
    +    from .mbridge import __all__ as _mbridge_all
    +
    +    __all__ += _mbridge_all

As per coding guidelines:
`**/*.py`: Define the public API with `__all__` at the top of each Python module and re-export submodules in `__init__.py` using `from .module import *` to keep the public API explicit and make star-imports safe.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@modelopt/torch/nas/plugins/__init__.py` around lines 18 - 27, Add an explicit __all__ list to this package to make the public API deterministic: inspect the symbols exported by the local modules referenced (from .torch import *, from .megatron import *, from .megatron_model_stats import *, from .mbridge import *) and create a top-level __all__ tuple in this __init__.py that enumerates the exact names you intend to re-export; keep the import_plugin blocks but remove reliance on implicit star-exports by listing those exported names (or the submodule names) in __all__ so star-imports become explicit and stable (use the module names or the specific function/class names from torch, megatron, megatron_model_stats, and mbridge).
tests/_test_utils/torch/transformers_models.py (1)
289-315: ⚡ Quick win
Add docstrings for the new public helpers.
`get_tiny_gemma3` and `create_tiny_gemma3_dir` are public utilities. They should document the Gemma3-specific defaults and how `return_model` changes the return contract.
As per coding guidelines, "Document public APIs with docstrings, including examples when useful; internal helpers should be self-documenting through clear names and structure".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/_test_utils/torch/transformers_models.py` around lines 289 - 315, Add docstrings to the public helpers get_tiny_gemma3 and create_tiny_gemma3_dir that explain the Gemma3-specific defaults (e.g., dtype=torch.bfloat16, hidden_size=32, intermediate_size=32, num_hidden_layers=2, num_attention_heads=16, num_key_value_heads=2, head_dim=8, query_pre_attn_scalar=8, sliding_window=16, max_position_embeddings=32, vocab_size=32), note the head_dim/query_pre_attn_scalar relationship, and document how config_kwargs overrides defaults; for create_tiny_gemma3_dir also document parameters (tmp_path, with_tokenizer, return_model, **config_kwargs), the return contract when return_model is False vs True (Path vs tuple[Path, PreTrainedModel]), and include a short usage example showing both return cases and typical override of a default.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/_test_utils/torch/transformers_models.py`:
- Around line 289-315: Add the new helpers to the module's explicit public API
by updating the __all__ list near the top of the file to include
"get_tiny_gemma3" and "create_tiny_gemma3_dir"; locate the existing __all__
definition (or add one if missing) and append these two symbol names as strings
so they are exported for star-imports and documented as public test utilities.
- Around line 303-310: The helper can produce inconsistent Gemma3 configs
because callers can override head_dim via config_kwargs while
query_pre_attn_scalar remains at its default 8; after merging config_kwargs into
kwargs, explicitly set kwargs["query_pre_attn_scalar"] = kwargs["head_dim"] (or
otherwise copy head_dim into query_pre_attn_scalar) before constructing
Gemma3TextConfig in AutoModelForCausalLM.from_config so the two always remain in
sync; update the code around the kwargs.update(...) and return line to enforce
this.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 954bb772-fe80-4ee0-995d-0ec9ba259292
📒 Files selected for processing (7)
- CHANGELOG.rst
- modelopt/torch/nas/plugins/__init__.py
- modelopt/torch/nas/plugins/mbridge.py
- modelopt/torch/nas/plugins/megatron.py
- modelopt/torch/utils/plugins/mbridge.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
Codecov Report
✅ All modified and coverable lines are covered by tests.
Additional details and impacted files

    @@            Coverage Diff             @@
    ##             main    #1604      +/-   ##
    ==========================================
    - Coverage   77.43%   77.08%   -0.35%
    ==========================================
      Files         480      481       +1
      Lines       52564    53364     +800
    ==========================================
    + Hits        40703    41137     +434
    - Misses      11861    12227     +366
/claude review
Claude review passed — no blocking issues found.
Verified:
- The `type(self.core_attention)` change in `_DynamicSelfAttention` is a real correctness fix: previously, when a Gemma3 model was converted, `module.__class__` was set to a class with `(DynamicModule, TEDotProductAttention)` bases, and `bind_forward_method_if_needed` would re-bind `TEDotProductAttention.forward` onto the instance, silently overriding `Gemma3TEDotProductAttention`'s sliding-window/local-rope behavior. Using `type(self.core_attention)` makes the dynamic class a proper subclass of the original, preserving forward via MRO.
- The `_convert_linear_proj` extraction is a clean Open/Closed seam with no behavior change for non-Gemma3 models; existing call sites (`mcore_minitron.py`, GPU tests) still match via isinstance through inheritance.
- The MoE-only spec override (`(provider.num_moe_experts or 0) > 0`) is the right scope: dense models keep the bridge-native spec (required for Gemma3's `gemma3_layer_spec` to construct `TEDotProductAttention` with Gemma3's `window_size` tuple). Correctly handles None and 0.
- `_DynamicGemma3SelfAttention`/`_DynamicTERowParallelLinearLayerNorm` correctly chain `super()._setup()` → convert post_layernorm sliced by `output_size`, and override `export()` to flush the post_layernorm before `super().export()`. The MRO `(_DynamicTEProjRowParallelLinear, TERowParallelLinearLayerNorm)` linearizes cleanly.
- Test parametrization exercises the Gemma3 sliding/full attention `layer_types` code path on top of Qwen3 coverage.
LGTM
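The first verified point can be sketched in plain Python. This is only an illustration under stated assumptions: `DynamicModule`, `BaseAttention`, `CustomAttention`, and `make_dynamic_class` are hypothetical stand-ins, and the sketch reproduces the effect through a `__class__` swap alone rather than through `bind_forward_method_if_needed`.

```python
class DynamicModule:
    """Hypothetical marker base standing in for a dynamic-module mixin."""


class BaseAttention:
    """Hypothetical stand-in for the generic attention class."""

    def forward(self):
        return "base"


class CustomAttention(BaseAttention):
    """Hypothetical stand-in for a subclass with its own forward."""

    def forward(self):
        return "custom"


def make_dynamic_class(original_cls):
    # The dynamic class lists original_cls as a base, so that class's forward
    # is the first one found in the MRO after DynamicModule.
    return type(f"Dynamic{original_cls.__name__}", (DynamicModule, original_cls), {})


# Hard-coding the generic parent drops the subclass from the MRO.
hard_coded = CustomAttention()
hard_coded.__class__ = make_dynamic_class(BaseAttention)

# Building from type(instance) keeps the subclass, and its forward, in the MRO.
preserved = CustomAttention()
preserved.__class__ = make_dynamic_class(type(preserved))
```

Here `hard_coded.forward()` resolves to the generic implementation while `preserved.forward()` still resolves to the subclass override, and `preserved` remains an instance of both `CustomAttention` and `BaseAttention`.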
…dge >0.4 compat)
Megatron-Bridge removed NemotronHModelProvider (NVIDIA-NeMo/Megatron-Bridge#3827); Nemotron-H now loads as a plain MambaModelProvider. The hard import raised ImportError on Megatron-Bridge main, and isinstance(provider, NemotronHModelProvider) (used to set use_fast=True as a tokenizer workaround) no longer discriminates Nemotron-H. Detect Nemotron-H via the HF architecture ("NemotronHForCausalLM") instead, the same key the bridge dispatches on, so it is stable across Megatron-Bridge versions and lets us drop the provider import entirely.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
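The architecture-based detection described in this commit could look roughly like the sketch below. It assumes the HF config exposes an `architectures` list (as Hugging Face configs generally do); the helper names `is_nemotron_h` and `tokenizer_kwargs_for` are hypothetical, not taken from the prune script.

```python
from types import SimpleNamespace


def is_nemotron_h(hf_config) -> bool:
    """Hypothetical helper: detect Nemotron-H from the HF architecture name."""
    # architectures may be missing or None on some configs, so fall back to an empty list.
    architectures = getattr(hf_config, "architectures", None) or []
    return "NemotronHForCausalLM" in architectures


def tokenizer_kwargs_for(hf_config) -> dict:
    """Hypothetical helper: apply the use_fast=True workaround only for Nemotron-H."""
    return {"use_fast": True} if is_nemotron_h(hf_config) else {}


# Stand-in config objects for illustration; real code would load an HF config instead.
nemotron_cfg = SimpleNamespace(architectures=["NemotronHForCausalLM"])
gemma_cfg = SimpleNamespace(architectures=["Gemma3ForCausalLM"])
bare_cfg = SimpleNamespace(architectures=None)
```

Because the check depends only on a string in the checkpoint's config, it does not require importing any provider class from Megatron-Bridge.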
Default query_pre_attn_scalar to head_dim (Gemma3's convention for all sizes except 27B) unless the caller overrides it, so overriding head_dim alone produces a consistent config.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
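A minimal sketch of the defaulting behaviour this commit describes, assuming the helper merges caller overrides into a dict of defaults before building the config. The function name `build_tiny_gemma3_kwargs` is hypothetical and only a subset of the tiny-model defaults listed in the review is shown.

```python
def build_tiny_gemma3_kwargs(**config_kwargs):
    """Hypothetical helper: merge overrides, then default query_pre_attn_scalar to head_dim."""
    kwargs = {
        "hidden_size": 32,
        "num_attention_heads": 16,
        "num_key_value_heads": 2,
        "head_dim": 8,
    }
    kwargs.update(config_kwargs)
    # Only fill in the scalar if the caller did not pass one explicitly.
    kwargs.setdefault("query_pre_attn_scalar", kwargs["head_dim"])
    return kwargs
```

Using `setdefault` after the merge means `build_tiny_gemma3_kwargs(head_dim=16)` yields a matching scalar, while an explicit `query_pre_attn_scalar` from the caller is left untouched.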
What does this PR do?
Type of change: New feature (+ new tests)
Adds Minitron pruning support for Gemma3 models loaded through Megatron-Bridge (`Gemma3ForCausalLM` → `GPTModel`).

Gemma3's bridge implementation subclasses several megatron-core modules and overrides `forward`/`__init__`, so `DMRegistry` (which matches by `nn_cls.forward is parent.forward`) does not recognize them and pruning's `convert_to_dynamic` raised a `KeyError`.

There were actually two blockers, both fixed here:
- `load_mbridge_model_from_hf` unconditionally replaced `provider.transformer_layer_spec` with the generic TE GPT spec (to disable grouped-GEMM), throwing away `gemma3_layer_spec`. With the generic spec, plain `TEDotProductAttention` received Gemma3's int `window_size` and crashed at construction (`TypeError: 'int' object is not subscriptable`) before pruning even started. Now the override is applied only to MoE models; dense models keep the bridge's native spec.

Changes:
- `modelopt/torch/nas/plugins/mbridge.py` (new): register dynamic modules for `Gemma3LanguageModelEmbedding` (reuses the existing embedding dynamic class), `Gemma3SelfAttention`, and the fused post-LN `TERowParallelLinearLayerNorm` used by both `self_attention.linear_proj` and `mlp.linear_fc2`. The fused `post_layernorm` is converted to a dynamic module and sliced by the linear's `output_size` (== `hidden_size`).
- `modelopt/torch/nas/plugins/megatron.py`: two small reuse seams in `_DynamicSelfAttention`: build the core-attention dynamic class from `type(self.core_attention)` (preserves `Gemma3TEDotProductAttention`) and extract an overridable `_convert_linear_proj` hook. No behavior change for megatron-core models.
- `modelopt/torch/utils/plugins/mbridge.py`: only override the layer spec for MoE models; dense models keep the bridge's native spec.
- `modelopt/torch/nas/plugins/__init__.py`: import `mbridge` under `import_plugin("megatron.bridge")`.
- Tests: add `get_tiny_gemma3` / `create_tiny_gemma3_dir` helpers; parametrize `test_prune_minitron` over qwen3 and gemma3.

Usage
    # Prune a Gemma3 checkpoint with Minitron (same flow as other models)
    torchrun --nproc_per_node 2 examples/megatron_bridge/prune_minitron.py \
        --hf_model_name_or_path google/gemma-3-1b-it \
        --prune_target_params 0.6e9 \
        --output_hf_path /tmp/gemma-3-pruned

Testing
`tests/examples/megatron_bridge/test_prune_minitron.py` now runs for both qwen3 and gemma3. Verified in the megatron-bridge container:

Also verified the gemma3 case standalone on 1 GPU (PP=1) and 2 GPUs (PP=2).
Before your PR is "Ready for review"
CONTRIBUTING.md: N/A

Additional Information
Other Megatron-Bridge models with custom layers were surveyed; only the Gemma family needs this treatment. Gemma1 (embedding scaling via `EmbeddingScalingMixin`) and Gemma2 (post-LN linear + non-TE `Gemma2DotProductAttention` + custom output layer + mixin embedding) are intentionally left out as outdated. All other dense/MoE LLMs (llama, qwen2/3, qwen3-moe, mistral, nemotron, deepseek, gpt-oss, etc.) use standard megatron-core specs already covered by `megatron.py`.

🤖 Generated with Claude Code