[Fix] Muon optimizer per-expert orthogonalization for MoE models #1582
HAOCHENYE merged 6 commits into InternLM:main from
Conversation
@claude please edit the title and PR message.
xtuner/v1/config/optim.py (Outdated)

```python
    num_adamw = 0

    # Get MoE config if available
    num_experts = getattr(model, "n_routed_experts", 1) or 1
```
I believe it should be `model.config`? `n_routed_experts` is not an attribute of `model` per se.
```diff
- num_experts = getattr(model, "n_routed_experts", 1) or 1
+ num_experts = getattr(model.config, "n_routed_experts", 1) or 1
```
xtuner/v1/config/optim.py (Outdated)

```python
    if is_moe_model:
        logger.info(
            f"Muon params: {(num_muon + num_muon_moe) // 1e6}M "
            f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
```
`int // 1e6` might introduce precision loss and hence confusion here. For example, `9e5 // 1e6` would result in `0.0`, and `5.99e6 // 1e6` would be `5.0`. `f"{xxx / 1e6:.2f}"` might be better here.
```diff
- f"(regular: {num_muon // 1e6}M, MoE expert: {num_muon_moe // 1e6}M), "
+ f"(regular: {num_muon / 1e6:.2f}M, MoE expert: {num_muon_moe / 1e6:.2f}M), "
```
xtuner/v1/optim/muon.py (Outdated)

```python
    # For regular case (num_experts=1), this adds a batch dimension of size 1
    M = X.size(-2) // num_experts
    N = X.size(-1)
    X = X.view(num_experts, M, N)
```
`X.view(num_experts, -1, N)` would suffice?
```diff
- X = X.view(num_experts, M, N)
+ X = X.view(num_experts, -1, N)
```
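The equivalence is easy to sanity-check; a minimal NumPy sketch (`reshape` standing in for torch's `view`, all dims made up):

```python
import numpy as np

# Inferring the middle dimension with -1 gives the same result as
# computing M = rows // num_experts explicitly.
num_experts, M, N = 4, 3, 5
X = np.arange(num_experts * M * N, dtype=np.float32).reshape(num_experts * M, N)

explicit = X.reshape(num_experts, M, N)
inferred = X.reshape(num_experts, -1, N)  # middle dim inferred as 3
print(np.array_equal(explicit, inferred))  # True
```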
xtuner/v1/optim/muon.py (Outdated)

```python
    # Transpose if needed (when rows > cols) for numerical stability in NS iteration
    # This ensures X @ X.mT produces a smaller square matrix
    need_transpose = G.size(-2) > G.size(-1)
```
Now that X has been reshaped, `need_transpose` should be conditioned on the shape of X instead of that of G?
```diff
- need_transpose = G.size(-2) > G.size(-1)
+ need_transpose = X.size(-2) > X.size(-1)
```
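A shape-only sketch of the transpose rationale (NumPy, dimensions borrowed from the w2 discussion; purely illustrative): with a tall per-expert block, `X @ X.mT` is the large (M, M) Gram matrix, so the iteration runs on the transpose to keep it at (N, N).

```python
import numpy as np

num_experts, M, N = 2, 2048, 768         # tall per-expert block: rows > cols
X = np.zeros((num_experts, M, N))

gram_tall = X @ X.transpose(0, 2, 1)     # (2, 2048, 2048) -- large square matrix
Xt = X.transpose(0, 2, 1)                # transpose when rows > cols ...
gram_small = Xt @ Xt.transpose(0, 2, 1)  # ... so NS iterates on (2, 768, 768)
print(gram_tall.shape, gram_small.shape)
```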
```python
    X = X.view(num_experts, M, N)

    # Transpose if rows > cols for numerical stability
    need_transpose = G.size(-2) > G.size(-1)
```
The same concern as in the non-Triton NS implementation applies here.
```diff
- need_transpose = G.size(-2) > G.size(-1)
+ need_transpose = X.size(-2) > X.size(-1)
```
xtuner/v1/optim/muon.py (Outdated)

```python
    return newton_schulz_func(X, epsilon=epsilon, num_experts=num_experts).reshape(original_shape)


def adjust_lr_rms_norm(lr, param_shape):
```
Since we split up the expert weights, the `adjust_lr` functions should also use the per-expert shape instead of the full tensor shape. The current implementation would scale the base lr by sqrt(max(2 * n_routed_experts * expert_dim, model_dim)) for fused_w1w3, which is way larger than sqrt(max(expert_dim, model_dim)).
```python
def adjust_lr_rms_norm(lr, param_shape, num_experts=1):
    # Adjust learning rate for constant element-wise RMS norm
    # https://arxiv.org/abs/2502.16982
    A = param_shape.size(-2) // num_experts
    B = param_shape.size(-1)
    adjusted_ratio = 0.2 * math.sqrt(max(A, B))
    adjusted_lr = lr * adjusted_ratio
    return adjusted_lr
```
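To put rough numbers on the concern (dimensions borrowed from the Qwen3-30B-A3B shapes quoted in the test file; the 0.2 factor cancels out of the ratio):

```python
import math

num_experts, expert_dim, model_dim = 128, 768, 2048

# fused_w1w3 stacks w1 and w3 for all experts: rows = 2 * num_experts * expert_dim
fused_rows = 2 * num_experts * expert_dim                  # 196608
scale_full_tensor = math.sqrt(max(fused_rows, model_dim))  # full-tensor shape (buggy)
scale_per_expert = math.sqrt(max(expert_dim, model_dim))   # per-expert shape (fixed)

print(round(scale_full_tensor, 1))                     # ~443.4
print(round(scale_full_tensor / scale_per_expert, 1))  # ~9.8x larger lr scale
```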
Claude: Updated the branch with nil0x9's review comments. Changes made (commit
@claude review
xtuner/v1/config/optim.py (Outdated)

```python
    num_adamw = 0

    # Get MoE config if available
    num_experts = getattr(model.config, "n_routed_experts", 1) or 1
```
Claude: Warning: model.config will raise AttributeError if the model object doesn't have a config attribute (e.g., non-HuggingFace models, or models wrapped in custom classes). Consider a safer access pattern:
```diff
- num_experts = getattr(model.config, "n_routed_experts", 1) or 1
+ num_experts = getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1
```
This way if model lacks a config attribute, it gracefully falls back to 1 instead of crashing.
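A minimal check of that fallback behavior (`MoEModel` and `BareModel` are hypothetical stand-ins, not classes from the repo):

```python
class Config:
    n_routed_experts = 128

class MoEModel:    # hypothetical model exposing a config
    config = Config()

class BareModel:   # hypothetical model without a .config attribute
    pass

def safe_num_experts(model):
    # getattr(model, "config", None) never raises; getattr(None, ...) falls back
    return getattr(getattr(model, "config", None), "n_routed_experts", 1) or 1

print(safe_num_experts(MoEModel()))   # 128
print(safe_num_experts(BareModel()))  # 1
```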
xtuner/v1/config/optim.py (Outdated)

```diff
     for name, p in model.named_parameters():
         n = p.numel()
         num_total += n
         if name in trainable_names:
             num_total_requires_grad += n
             is_muon_tensor = p.ndim >= 2 and "embed_tokens" not in name and "lm_head" not in name
             if is_muon_tensor:
-                num_muon += n
+                # Check if this is an MoE expert parameter
+                if is_moe_model and any(pattern in name for pattern in all_expert_patterns):
+                    num_muon_moe += n
+                else:
+                    num_muon += n
             else:
                 num_adamw += n
         else:
```
Claude: Nit: The logic for classifying parameters into muon/adamw/moe groups iterates over model.named_parameters() twice — once here for counting, and again below (lines 105-117) for building the lists. Consider combining both loops to avoid the redundant iteration over all parameters.
```python
# Copyright (c) OpenMMLab. All rights reserved.
"""Test Muon optimizer Newton-Schulz functions with/without torch.compile.

Test shapes are based on Qwen3-30B-A3B model config:
- hidden_size: 2048
- num_experts: 128
- moe_intermediate_size: 768
- intermediate_size: 6144 (for shared expert)

MoE expert weight shapes:
- w1/w3: (num_experts * moe_intermediate_size, hidden_size) = (98304, 2048)
  per expert: (768, 2048)
- w2: (hidden_size, num_experts * moe_intermediate_size) = (2048, 98304)
  per expert: (2048, 768)

For testing, we use scaled-down versions to keep tests fast while maintaining
representative shapes.

================================================================================
IMPORTANT: DTensor Compatibility Note
================================================================================

The zeropower_via_newtonschulz5 function supports DTensor input, but with a
known limitation when M > N (e.g., w2 weights where hidden_size > moe_intermediate_size).

Root Cause Analysis (verified by /tmp/test_dtensor_root_cause_detailed.py):
---------------------------------------------------------------------------
When M > N, the Newton-Schulz algorithm transposes the input matrix:
    X = G.view(1, M, N).mT  # becomes (1, N, M)

For a DTensor sharded on dim 0 (M dimension):
1. After view(1, M, N): placements become Shard(dim=1)
2. After mT: placements become Shard(dim=2)  # the M dimension moves to dim 2
3. X @ X.mT produces a Partial(sum) DTensor  # contraction dim is sharded
4. Partial values are not correctly reduced in subsequent operations
5. Error accumulates across 5 Newton-Schulz iterations:
   Iter 1: X max ~0.016
   Iter 2: X max ~0.060
   Iter 3: X max ~0.099
   Iter 4: X max ~0.29
   Iter 5: X max ~47.5 (EXPLOSION!)
6. Final result is completely wrong (e.g., 0.1 -> 47.5)

Verification Results:
- M < N (w1/w3): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly
- M > N (w2): ✗ FAIL - A @ A.mT produces Partial(sum), results explode
- M = N (square): ✓ PASS - A @ A.mT produces Shard(dim=1), results match exactly

Workaround:
For DTensor with M > N (w2 weights), convert to local tensor:
    result = zeropower_via_newtonschulz5(dtensor.to_local(), num_experts=1)

Note:
This is NOT a torch.compile issue. The same problem occurs with or without
torch.compile. It's a fundamental limitation of DTensor's Partial placement
handling in complex matrix operation chains.

newton_schulz_triton:
Does not support DTensor at all due to direct Triton kernel usage.
Must use .to_local() to convert before calling.
"""
```
Claude: Warning: The module docstring contains ~40 lines of debug analysis (DTensor root cause, iteration-by-iteration error growth, references to /tmp/test_dtensor_root_cause_detailed.py). This reads like investigation notes rather than test documentation.
Consider trimming this to a brief summary of what's tested and why, with a link to an issue or doc if the DTensor limitation needs to be tracked. The reference to /tmp/test_dtensor_root_cause_detailed.py is particularly problematic since that file doesn't exist in the repo.
```python
        G, epsilon=self.epsilon, num_experts=num_experts
    )

    max_diff = (result1 - result2).abs().max().item()
```
Claude: Nit: The tolerance was relaxed from 2e-2 to 3e-2 without explanation. The larger test matrix sizes likely explain this, but a brief comment on why would help (e.g., "Larger matrices amplify numerical differences between PyTorch and Triton kernels").
AGENTS.md (Outdated)

```diff
@@ -0,0 +1,59 @@
+# AGENTS.md - Project Guidelines for AI Assistants
```
Claude: Warning: Per project guidelines ("One logical change per PR. Do not mix bug fixes with features or refactors"), this documentation file should be in a separate PR. Also, this repo already has .claude/CLAUDE.md with project conventions — having a second set of guidelines in AGENTS.md at the repo root creates a risk of divergence. Consider whether this is needed at all, or if it should extend the existing CLAUDE.md.
Claude:

## Summary
Fixes Muon optimizer to apply Newton-Schulz orthogonalization per-expert instead of on the concatenated matrix for MoE models. The core approach — reshaping to

## Issues
Warning
Nit

## Verdict
COMMENT — The core algorithmic fix is correct and the
Fix Muon optimizer to apply orthogonalization per expert matrix instead of to the concatenated large matrix for MoE models.
Add guidelines for git commit messages and PR descriptions.
Force-pushed from 72e2a42 to 8be55e0
- Use model.config instead of model for n_routed_experts attribute
- Fix logging precision: use f-string formatting (:.2f) instead of // 1e6
- Simplify X.view(num_experts, -1, N) instead of computing M explicitly
- Fix need_transpose to use reshaped X shape instead of original G shape
- Fix adjust_lr functions to use per-expert shape (divide by num_experts)
Force-pushed from 8be55e0 to 35dec9e
[Enhance] enhance muon config and args (InternLM#1610)
Force-pushed from 35dec9e to 9503cbd
## Summary
Fix the Muon optimizer to apply Newton-Schulz orthogonalization per expert matrix instead of on the concatenated large matrix for MoE models. Previously, all expert weights were concatenated into a single matrix before orthogonalization, which is mathematically incorrect — each expert should be orthogonalized independently with its own spectral norm.
## Changes
- Add a `num_experts` parameter to `zeropower_via_newtonschulz5` and `newton_schulz_triton` to support per-expert orthogonalization via batch matrix operations.
- Split parameter groups in `optim.py` into regular and MoE expert params, passing `num_experts` through the optimizer pipeline.
- Add tests verifying consistency with `torch.compile` and between the PyTorch/Triton implementations.

## Root Cause
The Muon optimizer treated MoE expert weight matrices as a single concatenated matrix during Newton-Schulz orthogonalization. This meant the spectral norm was computed over all experts jointly, and the iterative orthogonalization mixed information across experts — producing incorrect gradient updates for MoE models.
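The effect can be sketched with an SVD-based polar factor (the exact semi-orthogonal factor that Newton-Schulz iterations approximate); this is an illustration, not the PR's implementation, and all names and dimensions are made up:

```python
import numpy as np

def orthogonalize(mat):
    # Nearest semi-orthogonal matrix: the polar factor U @ Vt from the SVD
    u, _, vt = np.linalg.svd(mat, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(0)
num_experts, M, N = 4, 8, 16
G = rng.standard_normal((num_experts * M, N))

# Old path: orthogonalize the whole concatenated (num_experts * M, N) matrix
concat = orthogonalize(G)

# Fixed path: reshape to (num_experts, M, N), orthogonalize each expert alone
per_expert = np.stack([orthogonalize(g) for g in G.reshape(num_experts, M, N)])
per_expert = per_expert.reshape(num_experts * M, N)

# Only the per-expert result makes each expert block semi-orthogonal,
# and the two paths genuinely disagree:
blk = per_expert[:M]
print(np.allclose(blk @ blk.T, np.eye(M)))  # True
print(np.allclose(per_expert, concat))      # False
```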
## Fix
Reshape the concatenated expert matrix from `(num_experts * M, N)` to `(num_experts, M, N)` and use batch matrix multiplication to orthogonalize each expert independently. For non-MoE models (`num_experts=1`), this is mathematically equivalent to the original code path.

## Test Plan
Add `tests/optim/test_muon_compile.py` with tests for:
- `zeropower_via_newtonschulz5` with and without `torch.compile` for both regular and MoE matrices.
- `newton_schulz_triton` basic correctness (shape, no NaN/Inf).