Add Minitron pruning support for Gemma3 via Megatron-Bridge #1604

Merged: kevalmorabia97 merged 3 commits into main from kmorabia/gemma3-mbridge-pruning on Jun 3, 2026

@kevalmorabia97
Collaborator

@kevalmorabia97 kevalmorabia97 commented Jun 2, 2026

What does this PR do?

Type of change: New feature (+ new tests)

Adds Minitron pruning support for Gemma3 models loaded through Megatron-Bridge (Gemma3ForCausalLMGPTModel).

Gemma3's bridge implementation subclasses several megatron-core modules and overrides forward/__init__, so DMRegistry (which matches by nn_cls.forward is parent.forward) does not recognize them and pruning's convert_to_dynamic raised:

KeyError: "<class '...Gemma3LanguageModelEmbedding'> is not registered for a dynamic module!"
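
The matching rule described above can be illustrated with a toy registry. This is only a sketch of the `nn_cls.forward is parent.forward` idea, not the real DMRegistry; the `Registry`, `Embedding`, `PlainSubclass`, and `ScaledEmbedding` names are hypothetical.

```python
class Registry:
    """Toy stand-in for a dynamic-module registry (NOT the real DMRegistry)."""

    def __init__(self):
        self._registered = {}

    def register(self, nn_cls, dynamic_name):
        self._registered[nn_cls] = dynamic_name

    def lookup(self, nn_cls):
        # Exact registrations always match.
        if nn_cls in self._registered:
            return self._registered[nn_cls]
        # A subclass matches a registered parent only if it inherits forward unchanged.
        for parent in nn_cls.__mro__[1:]:
            if parent in self._registered and nn_cls.forward is parent.forward:
                return self._registered[parent]
        raise KeyError(f"{nn_cls!r} is not registered for a dynamic module!")


class Embedding:
    def forward(self, x):
        return x


class PlainSubclass(Embedding):
    """Inherits forward, so it is matched through its parent."""


class ScaledEmbedding(Embedding):
    """Overrides forward, so the parent registration no longer applies."""

    def forward(self, x):
        return 2 * x


registry = Registry()
registry.register(Embedding, "DynamicEmbedding")
print(registry.lookup(PlainSubclass))

try:
    registry.lookup(ScaledEmbedding)
except KeyError as err:
    print("unmatched:", err)

# An explicit registration for the subclass resolves the lookup.
registry.register(ScaledEmbedding, "DynamicEmbedding")
print(registry.lookup(ScaledEmbedding))
```

In this toy model, the explicit registration at the end plays the role that the new Gemma3 registrations play in the PR.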

There were actually two blockers, both fixed here:

  1. Layer spec was discarded. load_mbridge_model_from_hf unconditionally replaced provider.transformer_layer_spec with the generic TE GPT spec (to disable grouped-GEMM), throwing away gemma3_layer_spec. With the generic spec, plain TEDotProductAttention received Gemma3's int window_size and crashed at construction (TypeError: 'int' object is not subscriptable) before pruning even started. Now the override is applied only to MoE models; dense models keep the bridge's native spec.
  2. Custom layers weren't registered as dynamic modules. Added registrations for Gemma3's custom layers.
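
The first fix can be sketched as a small gating function. The `(provider.num_moe_experts or 0) > 0` condition is the one quoted in the review further down; `ToyProvider`, `GENERIC_TE_SPEC`, and `maybe_override_layer_spec` are hypothetical stand-ins, and the sketch omits the non-Mamba check that the loader reportedly also applies.

```python
from dataclasses import dataclass
from typing import Optional

# Placeholder for the generic TE GPT layer spec (hypothetical value).
GENERIC_TE_SPEC = "generic_te_gpt_spec"


@dataclass
class ToyProvider:
    """Hypothetical stand-in for a Megatron-Bridge model provider."""

    transformer_layer_spec: str
    num_moe_experts: Optional[int] = None


def maybe_override_layer_spec(provider: ToyProvider) -> ToyProvider:
    """Replace the layer spec only for MoE providers; dense ones keep their native spec."""
    # `or 0` treats both None and 0 as "dense".
    if (provider.num_moe_experts or 0) > 0:
        provider.transformer_layer_spec = GENERIC_TE_SPEC
    return provider


dense = maybe_override_layer_spec(ToyProvider("gemma3_layer_spec"))
moe = maybe_override_layer_spec(ToyProvider("native_moe_spec", num_moe_experts=8))
print(dense.transformer_layer_spec)
print(moe.transformer_layer_spec)
```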

Changes:

  • modelopt/torch/nas/plugins/mbridge.py (new): register dynamic modules for Gemma3LanguageModelEmbedding (reuses the existing embedding dynamic class), Gemma3SelfAttention, and the fused post-LN TERowParallelLinearLayerNorm used by both self_attention.linear_proj and mlp.linear_fc2. The fused post_layernorm is converted to a dynamic module and sliced by the linear's output_size (== hidden_size).
  • modelopt/torch/nas/plugins/megatron.py: two small reuse seams in _DynamicSelfAttention — build the core-attention dynamic class from type(self.core_attention) (preserves Gemma3TEDotProductAttention) and extract an overridable _convert_linear_proj hook. No behavior change for megatron-core models.
  • modelopt/torch/utils/plugins/mbridge.py: only override the layer spec for MoE models; dense models keep the bridge's native spec.
  • modelopt/torch/nas/plugins/__init__.py: import mbridge under import_plugin("megatron.bridge").
  • Tests: new get_tiny_gemma3 / create_tiny_gemma3_dir helpers; parametrize test_prune_minitron over qwen3 and gemma3.
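
The post-layernorm slicing mentioned in the first bullet can be pictured with plain lists. This is a minimal sketch, assuming the channels to keep are the leading ones (for example because they were already ordered by importance); `ToyFusedLinearPostLN` and `prune_hidden` are hypothetical names, not modelopt APIs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ToyFusedLinearPostLN:
    """Hypothetical fused row-parallel linear + post-layernorm, using plain lists."""

    weight: List[List[float]]  # shape [output_size][input_size]
    ln_weight: List[float]     # shape [output_size]

    @property
    def output_size(self) -> int:
        return len(self.weight)


def prune_hidden(layer: ToyFusedLinearPostLN, hidden_size: int) -> ToyFusedLinearPostLN:
    """Keep the first `hidden_size` output channels of the linear and of its layernorm."""
    if hidden_size > layer.output_size:
        raise ValueError("cannot grow the layer while pruning")
    # The layernorm acts on the linear's output, so both must be sliced identically.
    pruned_weight = [row[:] for row in layer.weight[:hidden_size]]
    pruned_ln = layer.ln_weight[:hidden_size]
    return ToyFusedLinearPostLN(pruned_weight, pruned_ln)


layer = ToyFusedLinearPostLN(
    weight=[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    ln_weight=[0.1, 0.2, 0.3],
)
pruned = prune_hidden(layer, hidden_size=2)
print(pruned.output_size, len(pruned.ln_weight))
```

The point of the sketch is the invariant: after pruning, the layernorm weight length always equals the linear's output size.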

Usage

# Prune a Gemma3 checkpoint with Minitron (same flow as other models)
torchrun --nproc_per_node 2 examples/megatron_bridge/prune_minitron.py \
    --hf_model_name_or_path google/gemma-3-1b-it \
    --prune_target_params 0.6e9 \
    --output_hf_path /tmp/gemma-3-pruned

Testing

tests/examples/megatron_bridge/test_prune_minitron.py now runs for both qwen3 and gemma3. Verified in the megatron-bridge container:

Also verified the gemma3 case standalone on 1 GPU (PP=1) and 2 GPUs (PP=2).

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅ (the spec-override change only affects how dense Megatron-Bridge models are loaded — they now keep their native, correct spec; MoE behavior is unchanged)
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅ (parametrized gemma3 coverage + tiny-model helper)
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ✅

Additional Information

Other Megatron-Bridge models with custom layers were surveyed; only the Gemma family needs this treatment. Gemma1 (embedding scaling via EmbeddingScalingMixin) and Gemma2 (post-LN linear + non-TE Gemma2DotProductAttention + custom output layer + mixin embedding) are intentionally left unsupported because they are outdated. All other dense/MoE LLMs (llama, qwen2/3, qwen3-moe, mistral, nemotron, deepseek, gpt-oss, etc.) use standard megatron-core specs already covered by megatron.py.

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Added Minitron pruning support for Megatron-Bridge Gemma3 models.
    • Improved Megatron-Bridge layer configuration handling for Mixture-of-Experts scenarios.
  • Tests

    • Added Gemma3 model creation utilities for testing.
    • Expanded pruning validation tests to cover both Qwen3 and Gemma3 models.
  • Examples

    • Updated pruning example to adjust tokenizer "use fast" detection based on model architecture.

Gemma3's Megatron-Bridge implementation subclasses several megatron-core
modules (overriding forward/__init__), so they are not matched to their
parents by DMRegistry and pruning's convert_to_dynamic raised KeyError.

Changes:
- nas/plugins/mbridge.py (new): register dynamic modules for Gemma3's
  custom layers - Gemma3LanguageModelEmbedding, Gemma3SelfAttention, and
  the fused post-LN TERowParallelLinearLayerNorm (attention linear_proj +
  MLP linear_fc2). The post_layernorm is converted dynamic and sliced by
  the linear's output_size (== hidden_size).
- nas/plugins/megatron.py: add two reuse seams in _DynamicSelfAttention -
  build the core-attention dynamic class from type(self.core_attention)
  (preserves Gemma3TEDotProductAttention) and extract an overridable
  _convert_linear_proj hook. No behavior change for megatron-core models.
- utils/plugins/mbridge.py: load_mbridge_model_from_hf now only overrides
  the layer spec for MoE models (to disable grouped GEMM); dense models
  keep the bridge's native spec so gemma3_layer_spec survives.
- nas/plugins/__init__.py: import mbridge under import_plugin("megatron.bridge").

Tests:
- Add get_tiny_gemma3 / create_tiny_gemma3_dir helpers.
- Parametrize test_prune_minitron over qwen3 and gemma3.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners June 2, 2026 15:44
@coderabbitai
Contributor

coderabbitai Bot commented Jun 2, 2026

Walkthrough

Adds Minitron pruning support for Megatron-Bridge Gemma3 by refactoring Megatron attention projection conversion, adding a Gemma3 NAS plugin with fused post-LN dynamic layers, refining loader layer-spec overrides, updating the prune script tokenizer detection, and extending tests and test utilities.

Changes

Megatron-Bridge Gemma3 Pruning Support

| Layer / File(s) | Summary |
| --- | --- |
| Megatron plugin projection conversion hook: modelopt/torch/nas/plugins/megatron.py | _DynamicSelfAttention extracts _convert_linear_proj into an overridable helper and preserves the concrete core_attention subclass type during dynamic conversion. |
| Megatron-Bridge Gemma3 NAS plugin: modelopt/torch/nas/plugins/mbridge.py | Adds Gemma3 dynamic wrappers: embedding registration, _convert_post_layernorm, dynamic fused post-LN row-parallel linear variants, and _DynamicGemma3SelfAttention that converts output projections to the fused post-LN variant. |
| NAS plugin module registration: modelopt/torch/nas/plugins/__init__.py | Conditionally imports and exports the new mbridge module under the megatron.bridge context so Gemma3 dynamic wrappers are available when Megatron-Bridge is present. |
| Megatron-Bridge loader layer-spec refinement: modelopt/torch/utils/plugins/mbridge.py | Restricts transformer_layer_spec overrides to only apply for non-Mamba providers that have MoE experts, preserving default Bridge specs for dense models. |
| Gemma3 tiny model test utilities: tests/_test_utils/torch/transformers_models.py | Adds get_tiny_gemma3 and create_tiny_gemma3_dir helpers to build and persist minimal Gemma3 causal LM artifacts for tests. |
| Pruning test parametrization: tests/examples/megatron_bridge/test_prune_minitron.py | Parametrizes test_prune_minitron to run pruning on both Qwen3 and Gemma3 tiny models and updates the test signature. |
| Prune script tokenizer detection: examples/megatron_bridge/prune_minitron.py | Removes NemotronHModelProvider import and derives use_fast_tokenizer from bridge.hf_pretrained.config.architectures presence of NemotronHForCausalLM. |
| Changelog: CHANGELOG.rst | Documents new Minitron pruning support for Megatron-Bridge Gemma3 models under version 0.45. |

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers:

  • claude
  • jenchen13
  • ChenhanYu
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 40.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title 'Add Minitron pruning support for Gemma3 via Megatron-Bridge' accurately and concisely summarizes the main objective of the PR, which is to add pruning support for Gemma3 models through Megatron-Bridge. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | No security anti-patterns detected. trust_remote_code properly parametrized, filtered via is_safe_repo. No unsafe torch.load, numpy.load, eval, or # nosec found. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@kevalmorabia97 kevalmorabia97 requested review from AAnoosheh and removed request for shengliangxu June 2, 2026 15:45
@github-actions
Contributor

github-actions Bot commented Jun 2, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-06-03 06:30 UTC

Contributor

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.


Actionable comments posted: 2

🧹 Nitpick comments (2)
modelopt/torch/nas/plugins/__init__.py (1)

18-27: ⚡ Quick win

Add an explicit package __all__.

This __init__.py now expands the package export surface again, but it still relies on implicit namespace exports. Please define __all__ here so star-imports stay explicit and stable.

♻️ Proposed fix
 from modelopt.torch.utils import import_plugin
 
+__all__ = []
+
 from .torch import *
+from .torch import __all__ as _torch_all
+
+__all__ += _torch_all
 
 with import_plugin("megatron"):
     from .megatron import *
+    from .megatron import __all__ as _megatron_all
     from .megatron_model_stats import *
+    from .megatron_model_stats import __all__ as _megatron_model_stats_all
+
+    __all__ += _megatron_all + _megatron_model_stats_all
 
 with import_plugin("megatron.bridge"):
     from .mbridge import *
+    from .mbridge import __all__ as _mbridge_all
+
+    __all__ += _mbridge_all

As per coding guidelines: **/*.py: Define the public API with __all__ at the top of each Python module and re-export submodules in __init__.py using from .module import * to keep the public API explicit and make star-imports safe.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/nas/plugins/__init__.py` around lines 18 - 27, Add an explicit
__all__ list to this package to make the public API deterministic: inspect the
symbols exported by the local modules referenced (from .torch import *, from
.megatron import *, from .megatron_model_stats import *, from .mbridge import *)
and create a top-level __all__ tuple in this __init__.py that enumerates the
exact names you intend to re-export; keep the import_plugin blocks but remove
reliance on implicit star-exports by listing those exported names (or the
submodule names) in __all__ so star-imports become explicit and stable (use the
module names or the specific function/class names from torch, megatron,
megatron_model_stats, and mbridge).
tests/_test_utils/torch/transformers_models.py (1)

289-315: ⚡ Quick win

Add docstrings for the new public helpers.

get_tiny_gemma3 and create_tiny_gemma3_dir are public utilities. They should document the Gemma3-specific defaults and how return_model changes the return contract.

As per coding guidelines, "Document public APIs with docstrings, including examples when useful; internal helpers should be self-documenting through clear names and structure".

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/_test_utils/torch/transformers_models.py` around lines 289 - 315, Add
docstrings to the public helpers get_tiny_gemma3 and create_tiny_gemma3_dir that
explain the Gemma3-specific defaults (e.g., dtype=torch.bfloat16,
hidden_size=32, intermediate_size=32, num_hidden_layers=2,
num_attention_heads=16, num_key_value_heads=2, head_dim=8,
query_pre_attn_scalar=8, sliding_window=16, max_position_embeddings=32,
vocab_size=32), note the head_dim/query_pre_attn_scalar relationship, and
document how config_kwargs overrides defaults; for create_tiny_gemma3_dir also
document parameters (tmp_path, with_tokenizer, return_model, **config_kwargs),
the return contract when return_model is False vs True (Path vs tuple[Path,
PreTrainedModel]), and include a short usage example showing both return cases
and typical override of a default.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/_test_utils/torch/transformers_models.py`:
- Around line 289-315: Add the new helpers to the module's explicit public API
by updating the __all__ list near the top of the file to include
"get_tiny_gemma3" and "create_tiny_gemma3_dir"; locate the existing __all__
definition (or add one if missing) and append these two symbol names as strings
so they are exported for star-imports and documented as public test utilities.
- Around line 303-310: The helper can produce inconsistent Gemma3 configs
because callers can override head_dim via config_kwargs while
query_pre_attn_scalar remains at its default 8; after merging config_kwargs into
kwargs, explicitly set kwargs["query_pre_attn_scalar"] = kwargs["head_dim"] (or
otherwise copy head_dim into query_pre_attn_scalar) before constructing
Gemma3TextConfig in AutoModelForCausalLM.from_config so the two always remain in
sync; update the code around the kwargs.update(...) and return line to enforce
this.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 954bb772-fe80-4ee0-995d-0ec9ba259292

📥 Commits

Reviewing files that changed from the base of the PR and between 72df833 and 0576d59.

📒 Files selected for processing (7)
  • CHANGELOG.rst
  • modelopt/torch/nas/plugins/__init__.py
  • modelopt/torch/nas/plugins/mbridge.py
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/utils/plugins/mbridge.py
  • tests/_test_utils/torch/transformers_models.py
  • tests/examples/megatron_bridge/test_prune_minitron.py

@codecov

codecov Bot commented Jun 2, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 77.08%. Comparing base (72df833) to head (216045a).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1604      +/-   ##
==========================================
- Coverage   77.43%   77.08%   -0.35%     
==========================================
  Files         480      481       +1     
  Lines       52564    53364     +800     
==========================================
+ Hits        40703    41137     +434     
- Misses      11861    12227     +366     
| Flag | Coverage Δ | |
| --- | --- | --- |
| examples | 41.76% <100.00%> (+0.94%) | ⬆️ |
| gpu | 59.86% <66.66%> (-0.58%) | ⬇️ |
| regression | 15.22% <12.12%> (+0.10%) | ⬆️ |
| unit | 53.73% <12.12%> (+<0.01%) | ⬆️ |


@kevalmorabia97
Collaborator Author

/claude review


@claude claude Bot left a comment


Claude review passed — no blocking issues found.

Verified:

  • The type(self.core_attention) change in _DynamicSelfAttention is a real correctness fix: previously, when a Gemma3 model was converted, module.__class__ was set to a class with (DynamicModule, TEDotProductAttention) bases, and bind_forward_method_if_needed would re-bind TEDotProductAttention.forward onto the instance — silently overriding Gemma3TEDotProductAttention's sliding-window/local-rope behavior. Using type(self.core_attention) makes the dynamic class a proper subclass of the original, preserving forward via MRO.
  • The _convert_linear_proj extraction is a clean Open/Closed seam with no behavior change for non-Gemma3 models; existing call sites (mcore_minitron.py, GPU tests) still match via isinstance through inheritance.
  • The MoE-only spec override ((provider.num_moe_experts or 0) > 0) is the right scope: dense models keep the bridge-native spec (required so that gemma3_layer_spec survives and Gemma3's int window_size is not passed to plain TEDotProductAttention). Correctly handles None and 0.
  • _DynamicGemma3SelfAttention / _DynamicTERowParallelLinearLayerNorm correctly chain super()._setup() → convert post_layernorm sliced by output_size, and override export() to flush the post_layernorm before super().export(). The MRO (_DynamicTEProjRowParallelLinear, TERowParallelLinearLayerNorm) linearizes cleanly.
  • Test parametrization exercises the Gemma3 sliding/full attention layer_types code path on top of Qwen3 coverage.
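
The first point above, about deriving the dynamic class from type(self.core_attention), can be reproduced with toy classes. This is only an illustration of the method-resolution argument, not modelopt code: `DynamicModule` here is an empty marker, and `CoreAttention`, `SlidingWindowAttention`, and `make_dynamic_cls` are hypothetical names.

```python
class DynamicModule:
    """Toy marker base (stand-in only; not modelopt's DynamicModule)."""


class CoreAttention:
    def forward(self):
        return "global attention"


class SlidingWindowAttention(CoreAttention):
    """Subclass with custom behavior, analogous to a model-specific attention."""

    def forward(self):
        return "sliding-window attention"


def make_dynamic_cls(base_cls):
    # Build a dynamic class whose bases are (DynamicModule, base_cls).
    return type(f"Dynamic{base_cls.__name__}", (DynamicModule, base_cls), {})


# Old approach: the dynamic class is always built from the generic parent,
# so swapping __class__ drops the subclass's forward override.
old = SlidingWindowAttention()
old.__class__ = make_dynamic_cls(CoreAttention)

# New approach: build from the module's concrete type, so the MRO still
# contains the subclass and its forward is preserved.
new = SlidingWindowAttention()
new.__class__ = make_dynamic_cls(type(new))

print(old.forward())
print(new.forward())
```

In the toy version the override is lost simply through the class swap; per the review, the real code path additionally re-binds the parent's forward, but the remedy is the same.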

LGTM

kevalmorabia97 and others added 2 commits June 2, 2026 09:42
…dge >0.4 compat)

Megatron-Bridge removed NemotronHModelProvider (NVIDIA-NeMo/Megatron-Bridge#3827);
Nemotron-H now loads as a plain MambaModelProvider. The hard import raised ImportError
on Megatron-Bridge main, and isinstance(provider, NemotronHModelProvider) (used to set
use_fast=True as a tokenizer WAR) no longer discriminates Nemotron-H.

Detect Nemotron-H via the HF architecture ("NemotronHForCausalLM") instead - the same key
the bridge dispatches on - so it is stable across Megatron-Bridge versions and lets us drop
the provider import entirely.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
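
The architecture-based detection described in the commit message above might look roughly like the following. This is a hedged sketch: `should_use_fast_tokenizer` is a hypothetical helper name, and a `SimpleNamespace` stands in for the Hugging Face config object that the script reads via bridge.hf_pretrained.config.

```python
from types import SimpleNamespace


def should_use_fast_tokenizer(hf_config) -> bool:
    """Return True only for Nemotron-H, identified by its HF architecture name."""
    # `architectures` may be missing or None on some configs, so fall back to [].
    architectures = getattr(hf_config, "architectures", None) or []
    return "NemotronHForCausalLM" in architectures


# Stand-in config objects (hypothetical; real code would pass the HF config).
nemotron_h = SimpleNamespace(architectures=["NemotronHForCausalLM"])
gemma3 = SimpleNamespace(architectures=["Gemma3ForCausalLM"])
unknown = SimpleNamespace(architectures=None)

print(should_use_fast_tokenizer(nemotron_h))
print(should_use_fast_tokenizer(gemma3))
print(should_use_fast_tokenizer(unknown))
```

Because the check depends only on a string in the checkpoint config, it does not require importing any provider class from Megatron-Bridge.
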
Default query_pre_attn_scalar to head_dim (Gemma3's convention for all sizes except 27B)
unless the caller overrides it, so overriding head_dim alone produces a consistent config.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
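
The defaulting rule from the commit message above can be sketched as follows. The default values are the ones listed in the CodeRabbit comment; `tiny_gemma3_kwargs` is a hypothetical helper that only builds the keyword dictionary and does not construct a model.

```python
def tiny_gemma3_kwargs(**config_kwargs):
    """Merge caller overrides into tiny-model defaults (hypothetical helper)."""
    kwargs = {
        "hidden_size": 32,
        "intermediate_size": 32,
        "num_hidden_layers": 2,
        "num_attention_heads": 16,
        "num_key_value_heads": 2,
        "head_dim": 8,
        "sliding_window": 16,
        "max_position_embeddings": 32,
        "vocab_size": 32,
    }
    kwargs.update(config_kwargs)
    # Follow head_dim unless the caller set query_pre_attn_scalar explicitly.
    kwargs.setdefault("query_pre_attn_scalar", kwargs["head_dim"])
    return kwargs


print(tiny_gemma3_kwargs()["query_pre_attn_scalar"])
print(tiny_gemma3_kwargs(head_dim=16)["query_pre_attn_scalar"])
print(tiny_gemma3_kwargs(head_dim=16, query_pre_attn_scalar=4)["query_pre_attn_scalar"])
```

Applying the default after the merge, rather than hard-coding it in the defaults dictionary, is what keeps the two values in sync when only head_dim is overridden.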
@kevalmorabia97 kevalmorabia97 requested a review from a team as a code owner June 2, 2026 16:42
@kevalmorabia97 kevalmorabia97 merged commit e40b4d6 into main Jun 3, 2026
68 of 70 checks passed
@kevalmorabia97 kevalmorabia97 deleted the kmorabia/gemma3-mbridge-pruning branch June 3, 2026 06:30

2 participants