From b87bf49c8493084a8dedbb13f2a5e70c8b4b5eb5 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 2 Jan 2026 23:46:26 +0800 Subject: [PATCH 01/10] feat(pt): add Muon optimizer fix(pt): Muon bug fix feat&fix(pt): Muon add bf16 support feat(pt): use tf32 for Muon fix(pt): Use 1e-8 for Muon feat(pt): Update Muon fix(pt): use the same lr for adam inside Muon feat(pt): add match_rms for Muon feat(pt): adjust Muon feat(pt): Update Muon (cherry picked from commit 9b4e63da9555eff2b8bdee555ddfc1910dfebfb8) --- deepmd/pt/optimizer/__init__.py | 5 +- deepmd/pt/optimizer/muon.py | 369 ++++++++++++++++++++++++++++++++ deepmd/pt/train/training.py | 15 +- source/tests/pt/test_muon.py | 211 ++++++++++++++++++ 4 files changed, 598 insertions(+), 2 deletions(-) create mode 100644 deepmd/pt/optimizer/muon.py create mode 100644 source/tests/pt/test_muon.py diff --git a/deepmd/pt/optimizer/__init__.py b/deepmd/pt/optimizer/__init__.py index 4c069cf2ea..6da11ebd0a 100644 --- a/deepmd/pt/optimizer/__init__.py +++ b/deepmd/pt/optimizer/__init__.py @@ -8,5 +8,8 @@ from .LKF import ( LKFOptimizer, ) +from .muon import ( + MuonOptimizer, +) -__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer"] +__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer", "MuonOptimizer"] diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/muon.py new file mode 100644 index 0000000000..aa5185447b --- /dev/null +++ b/deepmd/pt/optimizer/muon.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +Muon optimizer for DeePMD-kit PyTorch backend. + +Muon is an optimizer that applies Newton-Schulz orthogonalization to the gradient +before using momentum, resulting in orthogonalized updates for weight matrices. +This can improve training stability and convergence for certain architectures. 
+ +Reference: + https://github.com/KellerJordan/Muon +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from torch.optim.optimizer import ( + Optimizer, +) + +if TYPE_CHECKING: + from collections.abc import ( + Iterable, + ) + + +def zeropower_via_newtonschulz5( + G: torch.Tensor, + steps: int = 5, + eps: float = 1e-7, +) -> torch.Tensor: + """ + Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration. + + Uses quintic Newton-Schulz iteration to compute the orthogonal component of the + input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T. + + This implementation matches PyTorch official Muon behavior: it always performs + Newton-Schulz in bfloat16 and returns a bfloat16 tensor. + + Parameters + ---------- + G : torch.Tensor + Input matrix to orthogonalize with shape (..., M, N). + steps : int + Number of Newton-Schulz iterations with default 5. + eps : float + Numerical stability epsilon for norm clamping with default 1e-7. + + Returns + ------- + torch.Tensor + Orthogonalized matrix in bfloat16 with same shape as input. + + Raises + ------ + ValueError + If G has fewer than 2 dimensions. + ValueError + If steps >= 100 (guard for efficiency). + """ + # === Step 1. Validate === + if G.ndim < 2: + raise ValueError("Input must have at least 2 dimensions (..., M, N).") + if steps >= 100: + raise ValueError("Number of steps must be less than 100 for efficiency.") + + a, b, c = (3.4445, -4.7750, 2.0315) + + # === Step 2. Cast to bf16 (match official Muon) === + X = G.to(dtype=torch.bfloat16) + + # === Step 3. Transpose tall matrices === + if X.size(-2) > X.size(-1): + X = X.mT + + # === Step 4. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps) + + # === Step 5. 
Newton-Schulz iterations with fused GEMM === + for _ in range(steps): + A = X @ X.mT + # gram_update = b*A + c*(A@A) via addmm/baddbmm + # X = a*X + gram_update@X via addmm/baddbmm + if X.ndim == 2: + gram_update = torch.addmm(A, A, A, beta=b, alpha=c) + X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0) + else: + gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0) + + # === Step 6. Transpose back if needed === + if G.size(-2) > G.size(-1): + X = X.mT + + return X + + +def _prepare_muon_momentum( + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + beta: float, + nesterov: bool, +) -> tuple[torch.Tensor, tuple[int, ...]]: + """ + Prepare momentum update and reshape for batched Newton-Schulz. + + Parameters + ---------- + grad : torch.Tensor + Gradient tensor. + momentum_buffer : torch.Tensor + Momentum buffer (will be updated in-place). + beta : float + Momentum coefficient. + nesterov : bool + Whether to use Nesterov momentum. + + Returns + ------- + update : torch.Tensor + Reshaped update tensor with shape (M, N). + original_shape : tuple[int, ...] + Original shape before reshape. + """ + # === Step 1. Update momentum buffer === + momentum_buffer.lerp_(grad, 1 - beta) + update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer + + # === Step 2. Handle tensor -> matrix reshape === + original_shape = update.shape + if update.ndim > 2: + update = update.reshape(update.shape[0], -1) + + return update, original_shape + + +class MuonOptimizer(Optimizer): + """ + Muon optimizer with auxiliary Adam for non-matrix parameters. 
+ + This optimizer applies different update rules based on parameter dimensionality: + - For 2D+ parameters (weight matrices): Muon update with Newton-Schulz orthogonalization + - For 1D parameters (biases, layer norms): Standard Adam update + + This hybrid approach is effective because Muon's orthogonalization is designed + for weight matrices, while Adam is more suitable for biases and normalization params. + + Parameters + ---------- + params : iterable + Iterable of parameters to optimize. + lr : float + Learning rate with default 1e-3. + momentum : float + Momentum coefficient for Muon with default 0.95. + weight_decay : float + Weight decay coefficient (applied only to >=2D params) with default 0.001. + ns_steps : int + Number of Newton-Schulz iterations with default 5. + adam_betas : tuple[float, float] + Adam beta coefficients with default (0.9, 0.95). + adam_eps : float + Adam epsilon with default 1e-7. + nesterov : bool + Whether to use Nesterov momentum for Muon with default True. + lr_adjust : float + Learning rate adjustment factor for Adam (1D params). + - If lr_adjust <= 0: use match-RMS scaling for Muon update, + scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly. + - If lr_adjust > 0: use rectangular correction for Muon update, + scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate. + Default is 10.0 (Adam lr = lr/10). + lr_adjust_coeff : float + Coefficient for match-RMS scaling with default 0.2. + Only effective when lr_adjust <= 0. + + Examples + -------- + >>> optimizer = MuonOptimizer(model.parameters(), lr=1e-3) + >>> for epoch in range(epochs): + ... optimizer.zero_grad() + ... loss.backward() + ... 
optimizer.step() + """ + + def __init__( + self, + params: Iterable[torch.Tensor] | Iterable[dict[str, Any]], + lr: float = 1e-3, + momentum: float = 0.95, + weight_decay: float = 0.001, + ns_steps: int = 5, + adam_betas: tuple[float, float] = (0.9, 0.95), + adam_eps: float = 1e-7, + nesterov: bool = True, + lr_adjust: float = 10.0, + lr_adjust_coeff: float = 0.2, + ) -> None: + defaults = { + "lr": lr, + "momentum": momentum, + "weight_decay": weight_decay, + "ns_steps": ns_steps, + "adam_betas": adam_betas, + "adam_eps": adam_eps, + "nesterov": nesterov, + "lr_adjust": lr_adjust, + "lr_adjust_coeff": lr_adjust_coeff, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step( + self, + closure: callable | None = None, + ) -> torch.Tensor | None: + """ + Perform a single optimization step. + + Parameters + ---------- + closure : callable, optional + A closure that reevaluates the model and returns the loss. + + Returns + ------- + loss : float, optional + The loss value if closure is provided. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + weight_decay = group["weight_decay"] + ns_steps = group["ns_steps"] + adam_betas = group["adam_betas"] + adam_eps = group["adam_eps"] + nesterov = group["nesterov"] + lr_adjust = group["lr_adjust"] + lr_adjust_coeff = group["lr_adjust_coeff"] + + # === Step 1. 
Collect params with gradients and separate by type === + muon_params: list[torch.Tensor] = [] # For weight decay (>=2D only) + muon_entries: list[tuple[torch.nn.Parameter, torch.Tensor, tuple]] = [] + # Adam batch lists + adam_params: list[torch.Tensor] = [] + adam_grads_fp32: list[torch.Tensor] = [] + adam_exp_avgs: list[torch.Tensor] = [] + adam_exp_avg_sqs: list[torch.Tensor] = [] + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + if grad.dtype != p.dtype: + grad = grad.to(dtype=p.dtype) + + state = self.state[p] + + if p.ndim >= 2: + # Muon path: collect for weight decay + muon_params.append(p) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + update, orig_shape = _prepare_muon_momentum( + grad, state["momentum_buffer"], momentum, nesterov + ) + muon_entries.append((p, update, orig_shape)) + else: + # Adam path: state tensors forced to FP32 + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + adam_params.append(p) + # Cast grad to FP32 for Adam computation + adam_grads_fp32.append(grad.float()) + adam_exp_avgs.append(state["exp_avg"]) + adam_exp_avg_sqs.append(state["exp_avg_sq"]) + + # === Step 2. Foreach weight decay (only >=2D params) === + if weight_decay > 0 and muon_params: + torch._foreach_mul_(muon_params, 1.0 - lr * weight_decay) + + # === Step 3. 
Adam update for 1D params (FP32 computation) === + if adam_params: + # Determine Adam learning rate + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + + # Update momentum estimates in FP32 + torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0]) + grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) + torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + + # Compute updates with bias correction (per-param beta_pow) + for i, p in enumerate(adam_params): + state = self.state[p] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr / bias_corr1 + # FP32 computation: compute full delta in FP32, then cast once + denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(adam_eps) + delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) + p.add_(delta_fp32.to(p.dtype)) + + # === Step 4. Batched Newton-Schulz for Muon parameters === + if not muon_entries: + continue + + # Group by (rows, cols, device) for batched processing + # Note: dtype is not included since NS internally converts to bf16 + buckets: dict[ + tuple[int, int, torch.device], + list[tuple[torch.nn.Parameter, torch.Tensor, tuple]], + ] = {} + for entry in muon_entries: + p, update, orig_shape = entry + key = (update.shape[0], update.shape[1], update.device) + if key not in buckets: + buckets[key] = [] + buckets[key].append(entry) + + # Process each bucket + for bucket in buckets.values(): + # === Pre-compute bucket-level scaling constants === + # Get matrix dimensions from first entry + m, n = bucket[0][1].shape[-2], bucket[0][1].shape[-1] + # Scaling: match-RMS (lr_adjust<=0) or rectangular correction + if lr_adjust <= 0: + scale = lr_adjust_coeff * math.sqrt(float(max(m, n))) + else: + scale = max(1.0, m / n) ** 0.5 + + # === Stack and orthogonalize === + if len(bucket) == 1: + # Single parameter: 2D path with addmm (faster, correct behavior) + p, update, orig_shape = bucket[0] + orth = zeropower_via_newtonschulz5(update, 
steps=ns_steps) + # === Apply scaling and update parameters === + orth.mul_(scale) + p.add_(orth.reshape(orig_shape), alpha=-lr) + else: + # Multiple parameters: 3D batched path with baddbmm + stacked = torch.stack( + [item[1].contiguous() for item in bucket], dim=0 + ) + orth_stacked = zeropower_via_newtonschulz5(stacked, steps=ns_steps) + # === Apply scaling and update parameters === + orth_stacked.mul_(scale) + for i, (p, _, orig_shape) in enumerate(bucket): + p.add_(orth_stacked[i].reshape(orig_shape), alpha=-lr) + + return loss diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 0dfbe94b6b..4f935c613a 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -45,6 +45,7 @@ AdaMuonOptimizer, KFOptimizerWrapper, LKFOptimizer, + MuonOptimizer, ) from deepmd.pt.train.wrapper import ( ModelWrapper, @@ -158,6 +159,7 @@ def __init__( def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: opt_type = params.get("opt_type", "Adam") opt_param = { + # LKF parameters "kf_blocksize": params.get("kf_blocksize", 5120), "kf_start_pref_e": params.get("kf_start_pref_e", 1), "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), @@ -741,6 +743,17 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: float(self.opt_param.get("adam_beta1", 0.9)), float(self.opt_param.get("adam_beta2", 0.95)), ), + ) + elif self.opt_type == "Muon": + self.optimizer = MuonOptimizer( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + momentum=float(self.opt_param.get("muon_momentum", 0.95)), + weight_decay=float(self.opt_param.get("weight_decay", 0.001)), + adam_betas=( + float(self.opt_param.get("adam_beta1", 0.9)), + float(self.opt_param.get("adam_beta2", 0.95)), + ), lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), ) @@ -820,7 +833,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: print_str = f"Step {_step_id}: sample 
system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type in ["Adam", "AdamW", "AdaMuon"]: + if self.opt_type in ["Adam", "AdamW", "AdaMuon", "Muon"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr diff --git a/source/tests/pt/test_muon.py b/source/tests/pt/test_muon.py new file mode 100644 index 0000000000..4658ac5b08 --- /dev/null +++ b/source/tests/pt/test_muon.py @@ -0,0 +1,211 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +# NOTE: avoid torch thread reconfiguration errors during import. +import torch + +torch_set_num_interop_threads = getattr(torch, "set_num_interop_threads", None) +torch_set_num_threads = getattr(torch, "set_num_threads", None) +if torch_set_num_interop_threads is not None: + torch.set_num_interop_threads = lambda *args, **kwargs: None # type: ignore[assignment] +if torch_set_num_threads is not None: + torch.set_num_threads = lambda *args, **kwargs: None # type: ignore[assignment] + +from deepmd.pt.optimizer.muon import ( + MuonOptimizer, + zeropower_via_newtonschulz5, +) +from deepmd.pt.utils import ( + env, +) + + +class TestNewtonSchulzOrthogonalization(unittest.TestCase): + """Test Newton-Schulz orthogonalization algorithm.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_square_matrix_approximate_orthogonality(self) -> None: + """Test that output is approximately orthogonal for square matrices.""" + torch.manual_seed(42) + G = torch.randn(4, 4, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G, steps=5) + + # X @ X.T should be approximately identity (diagonal dominant) + # Note: NS returns bf16, so use relaxed tolerance + XXT = X.float() @ X.float().T + # Check diagonal elements are close to 1 (relaxed tolerance for bf16 + 5 iterations) + diag = torch.diag(XXT) + self.assertTrue( + torch.allclose( + diag, torch.ones(4, dtype=torch.float32, device=self.device), atol=0.5 + 
), + f"Diagonal not close to 1: {diag}", + ) + # Check off-diagonal elements are relatively small + off_diag_norm = (XXT - torch.diag(diag)).norm() + self.assertLess( + off_diag_norm, 1.5, f"Off-diagonal norm too large: {off_diag_norm}" + ) + + def test_output_shape_preserved(self) -> None: + """Test that output shape matches input shape and dtype is bf16.""" + torch.manual_seed(42) + for shape in [(4, 4), (6, 4), (4, 6), (3, 4, 4)]: + G = torch.randn(*shape, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G, steps=5) + self.assertEqual( + X.shape, G.shape, f"Shape mismatch for input shape {shape}" + ) + self.assertEqual( + X.dtype, torch.bfloat16, f"Output should be bf16, got {X.dtype}" + ) + + +class TestMuonOptimizer(unittest.TestCase): + """Test MuonOptimizer class.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_optimizer_step(self) -> None: + """Test basic optimizer step.""" + torch.manual_seed(42) + # Simple model with 2D and 1D parameters + model = torch.nn.Sequential( + torch.nn.Linear(10, 20, device=self.device), + torch.nn.ReLU(), + torch.nn.Linear(20, 5, device=self.device), + ) + + optimizer = MuonOptimizer(model.parameters(), lr=0.02) + + # Dummy forward-backward pass + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + + # Store initial params + initial_params = [p.clone() for p in model.parameters()] + + # Optimizer step + optimizer.step() + + # Verify parameters changed + for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)): + self.assertFalse( + torch.allclose(p, init_p), + f"Parameter {i} did not change after optimizer step", + ) + + def test_weight_decay(self) -> None: + """Test weight decay application.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = MuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) + + initial_weight_norm = model.weight.norm().item() + + # Multiple steps 
with gradients + for _ in range(10): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Weight norm should decrease due to weight decay + final_weight_norm = model.weight.norm().item() + self.assertLess( + final_weight_norm, + initial_weight_norm, + "Weight norm should decrease with weight decay", + ) + + def test_muon_for_2d_adam_for_1d(self) -> None: + """Test that Muon is applied to 2D params and Adam to 1D params.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = MuonOptimizer(model.parameters(), lr=0.02) + + # Forward-backward + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Check state - weight (2D) should have momentum_buffer + weight_state = optimizer.state[model.weight] + self.assertIn("momentum_buffer", weight_state) + self.assertNotIn("exp_avg", weight_state) + + # Bias (1D) should have exp_avg and exp_avg_sq + bias_state = optimizer.state[model.bias] + self.assertIn("exp_avg", bias_state) + self.assertIn("exp_avg_sq", bias_state) + self.assertNotIn("momentum_buffer", bias_state) + + def test_closure(self) -> None: + """Test optimizer with closure.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 5, device=self.device) + optimizer = MuonOptimizer(model.parameters(), lr=0.02) + + def closure(): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + return loss + + loss = optimizer.step(closure) + self.assertIsNotNone(loss) + + +class TestMuonOptimizerStateDict(unittest.TestCase): + """Test optimizer state dict save/load.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_state_dict_save_load(self) -> None: + """Test saving and loading optimizer state.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = 
MuonOptimizer(model.parameters(), lr=0.02) + + # Run a few steps to populate state + for _ in range(3): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Save state + state_dict = optimizer.state_dict() + + # Create new optimizer and load state + optimizer2 = MuonOptimizer(model.parameters(), lr=0.02) + optimizer2.load_state_dict(state_dict) + + # Verify state matches + for (_, s1), (_, s2) in zip(optimizer.state.items(), optimizer2.state.items()): + for key in s1: + if isinstance(s1[key], torch.Tensor): + self.assertTrue( + torch.allclose(s1[key], s2[key]), + f"State mismatch for key {key}", + ) + else: + self.assertEqual(s1[key], s2[key]) + + +if __name__ == "__main__": + unittest.main() From 3ea9e66f65d91490b0ed8ee95d136d52d286db10 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 11 Jan 2026 10:59:36 +0800 Subject: [PATCH 02/10] doc & test: Muon (cherry picked from commit 46fcb7d8d106fbeadc9af169248f4bdf6bc7ced9) --- deepmd/pt/optimizer/muon.py | 123 +++++++++++++++++---------- deepmd/pt/train/training.py | 6 +- deepmd/utils/argcheck.py | 3 +- source/tests/pt/test_muon.py | 160 ++++++++++++++--------------------- 4 files changed, 147 insertions(+), 145 deletions(-) diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/muon.py index aa5185447b..c9763f9e7f 100644 --- a/deepmd/pt/optimizer/muon.py +++ b/deepmd/pt/optimizer/muon.py @@ -6,8 +6,35 @@ before using momentum, resulting in orthogonalized updates for weight matrices. This can improve training stability and convergence for certain architectures. -Reference: - https://github.com/KellerJordan/Muon +Algorithm +--------- +For >=2D parameters (weight matrices), the Muon update is: + + 1. Momentum update (Nesterov): + m_t = beta * m_{t-1} + (1 - beta) * g_t + update = beta * m_t + (1 - beta) * g_t + + 2. 
Newton-Schulz orthogonalization (quintic iteration): + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + + 3. Scaling: scale = coeff * sqrt(max(m, n)) [match-RMS mode] + scale = sqrt(max(1, m/n)) [rectangular mode] + + 4. Parameter update: theta -= lr * scale * orth(update) + +For 1D parameters (biases, norms), standard Adam is used. + +Dtype Behavior +-------------- +- Newton-Schulz iterations: always bfloat16 (matches official Muon) +- Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability +- Gradients: cast to parameter dtype before momentum update + +Reference +--------- +https://github.com/KellerJordan/Muon """ from __future__ import ( @@ -30,11 +57,20 @@ Iterable, ) +# ============================================================================ +# Constants +# ============================================================================ + +# Newton-Schulz iteration count +NS_STEPS: int = 5 +# Numerical stability epsilon for norm clamping +NS_EPS: float = 1e-7 +# Adam epsilon for numerical stability +ADAM_EPS: float = 1e-7 + def zeropower_via_newtonschulz5( G: torch.Tensor, - steps: int = 5, - eps: float = 1e-7, ) -> torch.Tensor: """ Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration. @@ -42,6 +78,11 @@ def zeropower_via_newtonschulz5( Uses quintic Newton-Schulz iteration to compute the orthogonal component of the input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T. + Mathematical formulation: + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + This implementation matches PyTorch official Muon behavior: it always performs Newton-Schulz in bfloat16 and returns a bfloat16 tensor. @@ -49,10 +90,6 @@ def zeropower_via_newtonschulz5( ---------- G : torch.Tensor Input matrix to orthogonalize with shape (..., M, N). 
- steps : int - Number of Newton-Schulz iterations with default 5. - eps : float - Numerical stability epsilon for norm clamping with default 1e-7. Returns ------- @@ -63,14 +100,10 @@ def zeropower_via_newtonschulz5( ------ ValueError If G has fewer than 2 dimensions. - ValueError - If steps >= 100 (guard for efficiency). """ # === Step 1. Validate === if G.ndim < 2: raise ValueError("Input must have at least 2 dimensions (..., M, N).") - if steps >= 100: - raise ValueError("Number of steps must be less than 100 for efficiency.") a, b, c = (3.4445, -4.7750, 2.0315) @@ -82,10 +115,10 @@ def zeropower_via_newtonschulz5( X = X.mT # === Step 4. Normalize Frobenius norm to at most 1 === - X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps) + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS) # === Step 5. Newton-Schulz iterations with fused GEMM === - for _ in range(steps): + for _ in range(NS_STEPS): A = X @ X.mT # gram_update = b*A + c*(A@A) via addmm/baddbmm # X = a*X + gram_update@X via addmm/baddbmm @@ -107,11 +140,13 @@ def _prepare_muon_momentum( grad: torch.Tensor, momentum_buffer: torch.Tensor, beta: float, - nesterov: bool, ) -> tuple[torch.Tensor, tuple[int, ...]]: """ Prepare momentum update and reshape for batched Newton-Schulz. + Uses Nesterov momentum: update = beta*m_t + (1-beta)*g_t, where m_t is + the updated momentum buffer. + Parameters ---------- grad : torch.Tensor @@ -120,8 +155,6 @@ def _prepare_muon_momentum( Momentum buffer (will be updated in-place). beta : float Momentum coefficient. - nesterov : bool - Whether to use Nesterov momentum. Returns ------- @@ -132,7 +165,8 @@ def _prepare_muon_momentum( """ # === Step 1. Update momentum buffer === momentum_buffer.lerp_(grad, 1 - beta) - update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer + # Nesterov lookahead + update = grad.lerp(momentum_buffer, beta) # === Step 2. 
Handle tensor -> matrix reshape === original_shape = update.shape @@ -147,12 +181,24 @@ class MuonOptimizer(Optimizer): Muon optimizer with auxiliary Adam for non-matrix parameters. This optimizer applies different update rules based on parameter dimensionality: - - For 2D+ parameters (weight matrices): Muon update with Newton-Schulz orthogonalization + - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization - For 1D parameters (biases, layer norms): Standard Adam update This hybrid approach is effective because Muon's orthogonalization is designed for weight matrices, while Adam is more suitable for biases and normalization params. + Update Rules + ------------ + Muon (>=2D params): + 1. Momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t + 2. Nesterov lookahead: update = beta*m_t + (1-beta)*g_t + 3. Newton-Schulz orthogonalization: orth = NS(update) + 4. Scaling: scale = coeff*sqrt(max(m,n)) or sqrt(max(1, m/n)) + 5. Parameter update: theta -= lr * scale * orth + + Adam (1D params): + Standard Adam with bias correction, all computations in float32. + Parameters ---------- params : iterable @@ -163,21 +209,15 @@ class MuonOptimizer(Optimizer): Momentum coefficient for Muon with default 0.95. weight_decay : float Weight decay coefficient (applied only to >=2D params) with default 0.001. - ns_steps : int - Number of Newton-Schulz iterations with default 5. adam_betas : tuple[float, float] Adam beta coefficients with default (0.9, 0.95). - adam_eps : float - Adam epsilon with default 1e-7. - nesterov : bool - Whether to use Nesterov momentum for Muon with default True. lr_adjust : float - Learning rate adjustment factor for Adam (1D params). - - If lr_adjust <= 0: use match-RMS scaling for Muon update, + Learning rate adjustment mode for Muon scaling and Adam learning rate. + - If lr_adjust <= 0: use match-RMS scaling for Muon, scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly. 
- - If lr_adjust > 0: use rectangular correction for Muon update, - scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate. - Default is 10.0 (Adam lr = lr/10). + - If lr_adjust > 0: use rectangular correction for Muon, + scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust. + Default is 0.0 (match-RMS scaling). lr_adjust_coeff : float Coefficient for match-RMS scaling with default 0.2. Only effective when lr_adjust <= 0. @@ -197,21 +237,15 @@ def __init__( lr: float = 1e-3, momentum: float = 0.95, weight_decay: float = 0.001, - ns_steps: int = 5, adam_betas: tuple[float, float] = (0.9, 0.95), - adam_eps: float = 1e-7, - nesterov: bool = True, - lr_adjust: float = 10.0, + lr_adjust: float = 0.0, lr_adjust_coeff: float = 0.2, ) -> None: defaults = { "lr": lr, "momentum": momentum, "weight_decay": weight_decay, - "ns_steps": ns_steps, "adam_betas": adam_betas, - "adam_eps": adam_eps, - "nesterov": nesterov, "lr_adjust": lr_adjust, "lr_adjust_coeff": lr_adjust_coeff, } @@ -232,8 +266,8 @@ def step( Returns ------- - loss : float, optional - The loss value if closure is provided. + torch.Tensor | None + The loss value if closure is provided, otherwise None. 
""" loss = None if closure is not None: @@ -244,10 +278,7 @@ def step( lr = group["lr"] momentum = group["momentum"] weight_decay = group["weight_decay"] - ns_steps = group["ns_steps"] adam_betas = group["adam_betas"] - adam_eps = group["adam_eps"] - nesterov = group["nesterov"] lr_adjust = group["lr_adjust"] lr_adjust_coeff = group["lr_adjust_coeff"] @@ -276,7 +307,7 @@ def step( if "momentum_buffer" not in state: state["momentum_buffer"] = torch.zeros_like(grad) update, orig_shape = _prepare_muon_momentum( - grad, state["momentum_buffer"], momentum, nesterov + grad, state["momentum_buffer"], momentum ) muon_entries.append((p, update, orig_shape)) else: @@ -315,7 +346,7 @@ def step( bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr / bias_corr1 # FP32 computation: compute full delta in FP32, then cast once - denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(adam_eps) + denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS) delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) @@ -351,7 +382,7 @@ def step( if len(bucket) == 1: # Single parameter: 2D path with addmm (faster, correct behavior) p, update, orig_shape = bucket[0] - orth = zeropower_via_newtonschulz5(update, steps=ns_steps) + orth = zeropower_via_newtonschulz5(update) # === Apply scaling and update parameters === orth.mul_(scale) p.add_(orth.reshape(orig_shape), alpha=-lr) @@ -360,7 +391,7 @@ def step( stacked = torch.stack( [item[1].contiguous() for item in bucket], dim=0 ) - orth_stacked = zeropower_via_newtonschulz5(stacked, steps=ns_steps) + orth_stacked = zeropower_via_newtonschulz5(stacked) # === Apply scaling and update parameters === orth_stacked.mul_(scale) for i, (p, _, orig_shape) in enumerate(bucket): diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 4f935c613a..08cc07013b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -171,6 +171,8 @@ def get_opt_param(params: dict[str, Any]) -> 
tuple[str, dict[str, Any]]: "momentum": params.get("momentum", 0.95), "adam_beta1": params.get("adam_beta1", 0.9), "adam_beta2": params.get("adam_beta2", 0.95), + "lr_adjust": params.get("lr_adjust", 0.0), + "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), } return opt_type, opt_param @@ -748,13 +750,13 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer = MuonOptimizer( self.wrapper.parameters(), lr=self.lr_exp.start_lr, - momentum=float(self.opt_param.get("muon_momentum", 0.95)), + momentum=float(self.opt_param.get("momentum", 0.95)), weight_decay=float(self.opt_param.get("weight_decay", 0.001)), adam_betas=( float(self.opt_param.get("adam_beta1", 0.9)), float(self.opt_param.get("adam_beta2", 0.95)), ), - lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), + lr_adjust=float(self.opt_param.get("lr_adjust", 0.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), ) if optimizer_state_dict is not None and self.restart_training: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 935762cdc7..dcb3ea6c0a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3425,6 +3425,7 @@ def training_args( float, optional=True, default=0.95, + alias=["muon_momentum"], doc=doc_only_pt_supported + "Momentum coefficient for AdaMuon optimizer.", ), @@ -3456,7 +3457,7 @@ def training_args( "lr_adjust", float, optional=True, - default=10.0, + default=0.0, doc=doc_only_pt_supported + "Learning rate adjustment factor for Adam (1D params). " "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. " diff --git a/source/tests/pt/test_muon.py b/source/tests/pt/test_muon.py index 4658ac5b08..7889ef9066 100644 --- a/source/tests/pt/test_muon.py +++ b/source/tests/pt/test_muon.py @@ -1,16 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest -# NOTE: avoid torch thread reconfiguration errors during import. 
import torch -torch_set_num_interop_threads = getattr(torch, "set_num_interop_threads", None) -torch_set_num_threads = getattr(torch, "set_num_threads", None) -if torch_set_num_interop_threads is not None: - torch.set_num_interop_threads = lambda *args, **kwargs: None # type: ignore[assignment] -if torch_set_num_threads is not None: - torch.set_num_threads = lambda *args, **kwargs: None # type: ignore[assignment] - from deepmd.pt.optimizer.muon import ( MuonOptimizer, zeropower_via_newtonschulz5, @@ -26,16 +18,15 @@ class TestNewtonSchulzOrthogonalization(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE - def test_square_matrix_approximate_orthogonality(self) -> None: - """Test that output is approximately orthogonal for square matrices.""" + def test_orthogonalization(self) -> None: + """Test that NS produces approximately orthogonal output.""" torch.manual_seed(42) G = torch.randn(4, 4, dtype=torch.float32, device=self.device) - X = zeropower_via_newtonschulz5(G, steps=5) + X = zeropower_via_newtonschulz5(G) - # X @ X.T should be approximately identity (diagonal dominant) - # Note: NS returns bf16, so use relaxed tolerance + # X @ X.T should be approximately identity + # Note: NS uses bf16 internally, 5 iterations gives ~0.1-0.3 error XXT = X.float() @ X.float().T - # Check diagonal elements are close to 1 (relaxed tolerance for bf16 + 5 iterations) diag = torch.diag(XXT) self.assertTrue( torch.allclose( @@ -43,24 +34,25 @@ def test_square_matrix_approximate_orthogonality(self) -> None: ), f"Diagonal not close to 1: {diag}", ) - # Check off-diagonal elements are relatively small off_diag_norm = (XXT - torch.diag(diag)).norm() self.assertLess( off_diag_norm, 1.5, f"Off-diagonal norm too large: {off_diag_norm}" ) - def test_output_shape_preserved(self) -> None: - """Test that output shape matches input shape and dtype is bf16.""" + def test_shape_and_dtype(self) -> None: + """Test that output preserves shape and returns bf16.""" 
torch.manual_seed(42) - for shape in [(4, 4), (6, 4), (4, 6), (3, 4, 4)]: + for shape in [(4, 4), (6, 4), (3, 4, 4)]: G = torch.randn(*shape, dtype=torch.float32, device=self.device) - X = zeropower_via_newtonschulz5(G, steps=5) - self.assertEqual( - X.shape, G.shape, f"Shape mismatch for input shape {shape}" - ) - self.assertEqual( - X.dtype, torch.bfloat16, f"Output should be bf16, got {X.dtype}" - ) + X = zeropower_via_newtonschulz5(G) + self.assertEqual(X.shape, G.shape) + self.assertEqual(X.dtype, torch.bfloat16) + + def test_invalid_input(self) -> None: + """Test that <2D input raises ValueError.""" + G_1d = torch.randn(10, dtype=torch.float32, device=self.device) + with self.assertRaises(ValueError): + zeropower_via_newtonschulz5(G_1d) class TestMuonOptimizer(unittest.TestCase): @@ -69,102 +61,83 @@ class TestMuonOptimizer(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE - def test_optimizer_step(self) -> None: - """Test basic optimizer step.""" + def test_step(self) -> None: + """Test basic optimizer step changes parameters.""" torch.manual_seed(42) - # Simple model with 2D and 1D parameters model = torch.nn.Sequential( torch.nn.Linear(10, 20, device=self.device), torch.nn.ReLU(), torch.nn.Linear(20, 5, device=self.device), ) - optimizer = MuonOptimizer(model.parameters(), lr=0.02) - # Dummy forward-backward pass x = torch.randn(4, 10, device=self.device) - y = model(x) - loss = y.sum() - loss.backward() + model(x).sum().backward() - # Store initial params initial_params = [p.clone() for p in model.parameters()] - - # Optimizer step optimizer.step() - # Verify parameters changed for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)): - self.assertFalse( - torch.allclose(p, init_p), - f"Parameter {i} did not change after optimizer step", - ) + self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change") def test_weight_decay(self) -> None: - """Test weight decay application.""" + """Test weight decay 
reduces parameter norm.""" torch.manual_seed(42) model = torch.nn.Linear(10, 10, device=self.device) optimizer = MuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) - initial_weight_norm = model.weight.norm().item() - - # Multiple steps with gradients + initial_norm = model.weight.norm().item() for _ in range(10): optimizer.zero_grad() x = torch.randn(4, 10, device=self.device) - y = model(x) - loss = y.sum() - loss.backward() + model(x).sum().backward() optimizer.step() - # Weight norm should decrease due to weight decay - final_weight_norm = model.weight.norm().item() - self.assertLess( - final_weight_norm, - initial_weight_norm, - "Weight norm should decrease with weight decay", - ) + self.assertLess(model.weight.norm().item(), initial_norm) - def test_muon_for_2d_adam_for_1d(self) -> None: - """Test that Muon is applied to 2D params and Adam to 1D params.""" + def test_muon_adam_separation(self) -> None: + """Test Muon for 2D params, Adam for 1D params.""" torch.manual_seed(42) model = torch.nn.Linear(10, 10, device=self.device) optimizer = MuonOptimizer(model.parameters(), lr=0.02) - # Forward-backward x = torch.randn(4, 10, device=self.device) - y = model(x) - loss = y.sum() - loss.backward() + model(x).sum().backward() optimizer.step() - # Check state - weight (2D) should have momentum_buffer - weight_state = optimizer.state[model.weight] - self.assertIn("momentum_buffer", weight_state) - self.assertNotIn("exp_avg", weight_state) + # 2D weight uses Muon (momentum_buffer) + self.assertIn("momentum_buffer", optimizer.state[model.weight]) + self.assertNotIn("exp_avg", optimizer.state[model.weight]) + # 1D bias uses Adam (exp_avg, exp_avg_sq) + self.assertIn("exp_avg", optimizer.state[model.bias]) + self.assertIn("exp_avg_sq", optimizer.state[model.bias]) + self.assertNotIn("momentum_buffer", optimizer.state[model.bias]) - # Bias (1D) should have exp_avg and exp_avg_sq - bias_state = optimizer.state[model.bias] - self.assertIn("exp_avg", bias_state) - 
self.assertIn("exp_avg_sq", bias_state) - self.assertNotIn("momentum_buffer", bias_state) - - def test_closure(self) -> None: - """Test optimizer with closure.""" + def test_lr_adjust_modes(self) -> None: + """Test lr_adjust modes: match-RMS (<=0) vs rectangular (>0).""" torch.manual_seed(42) - model = torch.nn.Linear(10, 5, device=self.device) - optimizer = MuonOptimizer(model.parameters(), lr=0.02) - def closure(): - optimizer.zero_grad() - x = torch.randn(4, 10, device=self.device) - y = model(x) - loss = y.sum() - loss.backward() - return loss + model1 = torch.nn.Linear(10, 20, bias=False, device=self.device) + model2 = torch.nn.Linear(10, 20, bias=False, device=self.device) + model2.load_state_dict(model1.state_dict()) + + opt1 = MuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0) + opt2 = MuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0) + + x = torch.randn(4, 10, device=self.device) + + opt1.zero_grad() + model1(x).sum().backward() + opt1.step() - loss = optimizer.step(closure) - self.assertIsNotNone(loss) + opt2.zero_grad() + model2(x).sum().backward() + opt2.step() + + self.assertFalse( + torch.allclose(model1.weight, model2.weight), + "Different lr_adjust modes should produce different updates", + ) class TestMuonOptimizerStateDict(unittest.TestCase): @@ -179,30 +152,25 @@ def test_state_dict_save_load(self) -> None: model = torch.nn.Linear(10, 10, device=self.device) optimizer = MuonOptimizer(model.parameters(), lr=0.02) - # Run a few steps to populate state for _ in range(3): optimizer.zero_grad() x = torch.randn(4, 10, device=self.device) - y = model(x) - loss = y.sum() - loss.backward() + model(x).sum().backward() optimizer.step() - # Save state state_dict = optimizer.state_dict() - # Create new optimizer and load state optimizer2 = MuonOptimizer(model.parameters(), lr=0.02) optimizer2.load_state_dict(state_dict) - # Verify state matches - for (_, s1), (_, s2) in zip(optimizer.state.items(), optimizer2.state.items()): + # Verify state 
matches by param id, not iteration order + for p in model.parameters(): + s1 = optimizer.state.get(p, {}) + s2 = optimizer2.state.get(p, {}) + self.assertEqual(len(s1), len(s2)) for key in s1: if isinstance(s1[key], torch.Tensor): - self.assertTrue( - torch.allclose(s1[key], s2[key]), - f"State mismatch for key {key}", - ) + self.assertTrue(torch.allclose(s1[key], s2[key])) else: self.assertEqual(s1[key], s2[key]) From 40b5ed415c2584d9d38e655fcdc1c00d815b9bd1 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 11 Jan 2026 12:05:14 +0800 Subject: [PATCH 03/10] refactor: compile Muon (cherry picked from commit 1dd737f9b44f3bb2ea82c74bd51c4653147b87ec) --- deepmd/pt/optimizer/muon.py | 396 ++++++++++++++++++++++-------------- deepmd/pt/train/training.py | 4 +- deepmd/utils/argcheck.py | 2 +- 3 files changed, 241 insertions(+), 161 deletions(-) diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/muon.py index c9763f9e7f..6044b5bcff 100644 --- a/deepmd/pt/optimizer/muon.py +++ b/deepmd/pt/optimizer/muon.py @@ -30,7 +30,8 @@ -------------- - Newton-Schulz iterations: always bfloat16 (matches official Muon) - Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability -- Gradients: cast to parameter dtype before momentum update +- Muon gradients: cast to parameter dtype before momentum update +- Adam gradients: cast to float32 for update computation Reference --------- @@ -67,113 +68,117 @@ NS_EPS: float = 1e-7 # Adam epsilon for numerical stability ADAM_EPS: float = 1e-7 +# Quintic Newton-Schulz polynomial coefficients +NS_COEFF_A: float = 3.4445 +NS_COEFF_B: float = -4.7750 +NS_COEFF_C: float = 2.0315 -def zeropower_via_newtonschulz5( +def _maybe_compile( + fn: callable, +) -> callable: + """Compile a function if torch.compile is available.""" + if hasattr(torch, "compile"): + return torch.compile(fn, fullgraph=True, dynamic=True) + return fn + + +@_maybe_compile +def _zeropower_via_newtonschulz5_2d( G: torch.Tensor, ) -> torch.Tensor: """ - 
Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration. - - Uses quintic Newton-Schulz iteration to compute the orthogonal component of the - input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T. + Orthogonalize a 2D matrix via quintic Newton-Schulz iteration. Mathematical formulation: X_0 = G / ||G||_F X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T Coefficients: a=3.4445, b=-4.7750, c=2.0315 + """ + # === Step 1. Cast to bf16 and transpose tall matrices === + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) - This implementation matches PyTorch official Muon behavior: it always performs - Newton-Schulz in bfloat16 and returns a bfloat16 tensor. + # === Step 2. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS) - Parameters - ---------- - G : torch.Tensor - Input matrix to orthogonalize with shape (..., M, N). + # === Step 3. Newton-Schulz iterations with fused GEMM === + for _ in range(NS_STEPS): + A = torch.mm(X, X.transpose(-2, -1)) + gram_update = torch.addmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) + X = torch.addmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) - Returns - ------- - torch.Tensor - Orthogonalized matrix in bfloat16 with same shape as input. + # === Step 4. Transpose back if needed === + if transposed: + X = X.transpose(-2, -1) - Raises - ------ - ValueError - If G has fewer than 2 dimensions. - """ - # === Step 1. Validate === - if G.ndim < 2: - raise ValueError("Input must have at least 2 dimensions (..., M, N).") + return X - a, b, c = (3.4445, -4.7750, 2.0315) - # === Step 2. Cast to bf16 (match official Muon) === - X = G.to(dtype=torch.bfloat16) +@_maybe_compile +def _zeropower_via_newtonschulz5_3d( + G: torch.Tensor, +) -> torch.Tensor: + """ + Orthogonalize a 3D batch of matrices via quintic Newton-Schulz iteration. - # === Step 3. 
Transpose tall matrices === - if X.size(-2) > X.size(-1): - X = X.mT + Mathematical formulation: + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + """ + # === Step 1. Cast to bf16 and transpose tall matrices === + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) - # === Step 4. Normalize Frobenius norm to at most 1 === + # === Step 2. Normalize Frobenius norm to at most 1 === X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS) - # === Step 5. Newton-Schulz iterations with fused GEMM === + # === Step 3. Newton-Schulz iterations with batched fused GEMM === for _ in range(NS_STEPS): - A = X @ X.mT - # gram_update = b*A + c*(A@A) via addmm/baddbmm - # X = a*X + gram_update@X via addmm/baddbmm - if X.ndim == 2: - gram_update = torch.addmm(A, A, A, beta=b, alpha=c) - X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0) - else: - gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c) - X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0) - - # === Step 6. Transpose back if needed === - if G.size(-2) > G.size(-1): - X = X.mT + A = torch.bmm(X, X.transpose(-2, -1)) + gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) + X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) + + # === Step 4. Transpose back if needed === + if transposed: + X = X.transpose(-2, -1) return X -def _prepare_muon_momentum( - grad: torch.Tensor, - momentum_buffer: torch.Tensor, - beta: float, -) -> tuple[torch.Tensor, tuple[int, ...]]: +def zeropower_via_newtonschulz5( + G: torch.Tensor, +) -> torch.Tensor: """ - Prepare momentum update and reshape for batched Newton-Schulz. + Compute the zeroth power (orthogonalization) via Newton-Schulz iteration. - Uses Nesterov momentum: update = beta*m_t + (1-beta)*g_t, where m_t is - the updated momentum buffer. 
+ Dispatches to compiled 2D or 3D kernels for best performance. Parameters ---------- - grad : torch.Tensor - Gradient tensor. - momentum_buffer : torch.Tensor - Momentum buffer (will be updated in-place). - beta : float - Momentum coefficient. + G : torch.Tensor + Input matrix with shape (M, N) or batched input with shape (B, M, N). Returns ------- - update : torch.Tensor - Reshaped update tensor with shape (M, N). - original_shape : tuple[int, ...] - Original shape before reshape. - """ - # === Step 1. Update momentum buffer === - momentum_buffer.lerp_(grad, 1 - beta) - # Nesterov lookahead - update = grad.lerp(momentum_buffer, beta) - - # === Step 2. Handle tensor -> matrix reshape === - original_shape = update.shape - if update.ndim > 2: - update = update.reshape(update.shape[0], -1) + torch.Tensor + Orthogonalized tensor in bfloat16 with same shape as input. - return update, original_shape + Raises + ------ + ValueError + If input is not 2D or 3D. + """ + if G.ndim == 2: + return _zeropower_via_newtonschulz5_2d(G) + if G.ndim == 3: + return _zeropower_via_newtonschulz5_3d(G) + raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.") class MuonOptimizer(Optimizer): @@ -217,7 +222,7 @@ class MuonOptimizer(Optimizer): scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly. - If lr_adjust > 0: use rectangular correction for Muon, scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust. - Default is 0.0 (match-RMS scaling). + Default is 10.0 (Adam lr = lr/10). lr_adjust_coeff : float Coefficient for match-RMS scaling with default 0.2. Only effective when lr_adjust <= 0. 
@@ -238,7 +243,7 @@ def __init__( momentum: float = 0.95, weight_decay: float = 0.001, adam_betas: tuple[float, float] = (0.9, 0.95), - lr_adjust: float = 0.0, + lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, ) -> None: defaults = { @@ -250,6 +255,46 @@ def __init__( "lr_adjust_coeff": lr_adjust_coeff, } super().__init__(params, defaults) + # Static parameter routing: built once on first step() call. + self._routing_built = False + self._routing: list[dict[str, Any]] = [] + + def _build_param_routing(self) -> None: + """ + Classify parameters into Muon and Adam routes (static routing). + + Routing logic: + - >=2D parameters → Muon path (Newton-Schulz + momentum) + - 1D parameters → Adam path (standard Adam update) + """ + if self._routing_built: + return + + self._routing = [] + for group in self.param_groups: + muon_params: list[dict[str, Any]] = [] + adam_params: list[dict[str, Any]] = [] + + for p in group["params"]: + if p.ndim >= 2: + muon_params.append( + { + "param": p, + "rows": int(p.shape[0]), + "cols": int(p.numel() // p.shape[0]), + } + ) + else: + adam_params.append({"param": p}) + + self._routing.append( + { + "muon_params": muon_params, + "adam_params": adam_params, + } + ) + + self._routing_built = True @torch.no_grad() def step( @@ -274,7 +319,11 @@ def step( with torch.enable_grad(): loss = closure() - for group in self.param_groups: + # Build static parameter routing on first call. + self._build_param_routing() + + for group_idx, group in enumerate(self.param_groups): + route = self._routing[group_idx] lr = group["lr"] momentum = group["momentum"] weight_decay = group["weight_decay"] @@ -282,119 +331,150 @@ def step( lr_adjust = group["lr_adjust"] lr_adjust_coeff = group["lr_adjust_coeff"] - # === Step 1. Collect params with gradients and separate by type === - muon_params: list[torch.Tensor] = [] # For weight decay (>=2D only) - muon_entries: list[tuple[torch.nn.Parameter, torch.Tensor, tuple]] = [] - # Adam batch lists + # === Step 1. 
Adam update for 1D parameters (biases, norms, etc.) === adam_params: list[torch.Tensor] = [] adam_grads_fp32: list[torch.Tensor] = [] adam_exp_avgs: list[torch.Tensor] = [] adam_exp_avg_sqs: list[torch.Tensor] = [] + adam_states: list[dict[str, Any]] = [] - for p in group["params"]: - if p.grad is None: + for entry in route["adam_params"]: + p = entry["param"] + grad = p.grad + if grad is None: continue - grad = p.grad - if grad.dtype != p.dtype: - grad = grad.to(dtype=p.dtype) + grad_fp32 = grad.float() state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_params.append(p) + adam_grads_fp32.append(grad_fp32) + adam_exp_avgs.append(state["exp_avg"]) + adam_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_states.append(state) - if p.ndim >= 2: - # Muon path: collect for weight decay - muon_params.append(p) - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(grad) - update, orig_shape = _prepare_muon_momentum( - grad, state["momentum_buffer"], momentum - ) - muon_entries.append((p, update, orig_shape)) - else: - # Adam path: state tensors forced to FP32 - if "exp_avg" not in state: - state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) - state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) - state["beta1_pow"] = 1.0 - state["beta2_pow"] = 1.0 - state["beta1_pow"] *= adam_betas[0] - state["beta2_pow"] *= adam_betas[1] - adam_params.append(p) - # Cast grad to FP32 for Adam computation - adam_grads_fp32.append(grad.float()) - adam_exp_avgs.append(state["exp_avg"]) - adam_exp_avg_sqs.append(state["exp_avg_sq"]) - - # === Step 2. 
Foreach weight decay (only >=2D params) === - if weight_decay > 0 and muon_params: - torch._foreach_mul_(muon_params, 1.0 - lr * weight_decay) - - # === Step 3. Adam update for 1D params (FP32 computation) === if adam_params: - # Determine Adam learning rate adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust - # Update momentum estimates in FP32 + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0]) grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) - # Compute updates with bias correction (per-param beta_pow) for i, p in enumerate(adam_params): - state = self.state[p] + state = adam_states[i] bias_corr1 = 1 - state["beta1_pow"] bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr / bias_corr1 - # FP32 computation: compute full delta in FP32, then cast once + # delta = -step_size * m_hat / (sqrt(v_hat) + eps) denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS) delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 4. Batched Newton-Schulz for Muon parameters === - if not muon_entries: + # === Step 2. 
Muon update for >=2D parameters (weight matrices) === + muon_params_for_decay: list[torch.Tensor] = [] + muon_grads: list[torch.Tensor] = [] + muon_momentum_buffers: list[torch.Tensor] = [] + active_entries: list[tuple[dict[str, Any], torch.Tensor]] = [] + + for entry in route["muon_params"]: + p = entry["param"] + grad = p.grad + if grad is None: + continue + + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + + buf = state["momentum_buffer"] + if grad.dtype != buf.dtype: + grad = grad.to(dtype=buf.dtype) + + muon_params_for_decay.append(p) + muon_grads.append(grad) + muon_momentum_buffers.append(buf) + active_entries.append((entry, grad)) + + if weight_decay > 0 and muon_params_for_decay: + torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay) + + if not active_entries: continue - # Group by (rows, cols, device) for batched processing - # Note: dtype is not included since NS internally converts to bf16 + # m_t = beta * m_{t-1} + (1 - beta) * g_t + torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum) + # update = beta * m_t + (1 - beta) * g_t + muon_updates = torch._foreach_lerp( + muon_grads, muon_momentum_buffers, momentum + ) + buckets: dict[ - tuple[int, int, torch.device], - list[tuple[torch.nn.Parameter, torch.Tensor, tuple]], + tuple[int, int, torch.device, torch.dtype], + list[tuple[dict[str, Any], torch.Tensor]], ] = {} - for entry in muon_entries: - p, update, orig_shape = entry - key = (update.shape[0], update.shape[1], update.device) - if key not in buckets: - buckets[key] = [] - buckets[key].append(entry) - - # Process each bucket - for bucket in buckets.values(): - # === Pre-compute bucket-level scaling constants === - # Get matrix dimensions from first entry - m, n = bucket[0][1].shape[-2], bucket[0][1].shape[-1] - # Scaling: match-RMS (lr_adjust<=0) or rectangular correction + + for idx, entry_info in enumerate(active_entries): + entry, _ = entry_info + p 
= entry["param"] + bucket_key = (entry["rows"], entry["cols"], p.device, p.dtype) + if bucket_key not in buckets: + buckets[bucket_key] = [] + buckets[bucket_key].append((entry, muon_updates[idx])) + + for (rows, cols, _device, dtype), bucket_entries in buckets.items(): + # scale = coeff * sqrt(max(m, n)) [match-RMS mode] + # scale = sqrt(max(1, m/n)) [rectangular mode] if lr_adjust <= 0: - scale = lr_adjust_coeff * math.sqrt(float(max(m, n))) + scale = lr_adjust_coeff * math.sqrt(float(max(rows, cols))) else: - scale = max(1.0, m / n) ** 0.5 - - # === Stack and orthogonalize === - if len(bucket) == 1: - # Single parameter: 2D path with addmm (faster, correct behavior) - p, update, orig_shape = bucket[0] - orth = zeropower_via_newtonschulz5(update) - # === Apply scaling and update parameters === + scale = max(1.0, rows / cols) ** 0.5 + + if len(bucket_entries) == 1: + entry, update_tensor = bucket_entries[0] + update_matrix = update_tensor.reshape(rows, cols) + if not update_matrix.is_contiguous(): + update_matrix = update_matrix.contiguous() + + orth = _zeropower_via_newtonschulz5_2d(update_matrix) orth.mul_(scale) - p.add_(orth.reshape(orig_shape), alpha=-lr) - else: - # Multiple parameters: 3D batched path with baddbmm - stacked = torch.stack( - [item[1].contiguous() for item in bucket], dim=0 + delta = orth.reshape(entry["param"].shape) + if delta.dtype != dtype: + delta = delta.to(dtype) + entry["param"].add_(delta, alpha=-lr) + continue + + matrices: list[torch.Tensor] = [] + params: list[torch.Tensor] = [] + orig_shapes: list[tuple[int, ...]] = [] + + for entry, update_tensor in bucket_entries: + update_matrix = update_tensor.reshape(rows, cols) + matrices.append( + update_matrix + if update_matrix.is_contiguous() + else update_matrix.contiguous() ) - orth_stacked = zeropower_via_newtonschulz5(stacked) - # === Apply scaling and update parameters === - orth_stacked.mul_(scale) - for i, (p, _, orig_shape) in enumerate(bucket): - 
p.add_(orth_stacked[i].reshape(orig_shape), alpha=-lr) + params.append(entry["param"]) + orig_shapes.append(entry["param"].shape) + + stacked = torch.stack(matrices, dim=0) + orth = _zeropower_via_newtonschulz5_3d(stacked) + orth.mul_(scale) + if orth.dtype != dtype: + orth = orth.to(dtype) + + for i, _ in enumerate(bucket_entries): + params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr) return loss diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 08cc07013b..d986d2f1b8 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -171,7 +171,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: "momentum": params.get("momentum", 0.95), "adam_beta1": params.get("adam_beta1", 0.9), "adam_beta2": params.get("adam_beta2", 0.95), - "lr_adjust": params.get("lr_adjust", 0.0), + "lr_adjust": params.get("lr_adjust", 10.0), "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), } return opt_type, opt_param @@ -756,7 +756,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: float(self.opt_param.get("adam_beta1", 0.9)), float(self.opt_param.get("adam_beta2", 0.95)), ), - lr_adjust=float(self.opt_param.get("lr_adjust", 0.0)), + lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), ) if optimizer_state_dict is not None and self.restart_training: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index dcb3ea6c0a..d07ca0bdb2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3457,7 +3457,7 @@ def training_args( "lr_adjust", float, optional=True, - default=0.0, + default=10.0, doc=doc_only_pt_supported + "Learning rate adjustment factor for Adam (1D params). " "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. 
" From eb1a5f99e4f81140d2ee12218d92532792ff4f9a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 11 Jan 2026 12:39:32 +0800 Subject: [PATCH 04/10] feat: add muon_min_2d_dim parameter for Muon optimizer behavior (cherry picked from commit c6f7e9f74391a676b3c7eaea449b1503e946f700) --- deepmd/pt/optimizer/muon.py | 206 +++++++++++++++++++++++++++++++---- deepmd/pt/train/training.py | 9 +- deepmd/utils/argcheck.py | 12 ++ source/tests/pt/test_muon.py | 24 ++++ 4 files changed, 226 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/muon.py index 6044b5bcff..bea54dd3b1 100644 --- a/deepmd/pt/optimizer/muon.py +++ b/deepmd/pt/optimizer/muon.py @@ -78,9 +78,14 @@ def _maybe_compile( fn: callable, ) -> callable: """Compile a function if torch.compile is available.""" - if hasattr(torch, "compile"): - return torch.compile(fn, fullgraph=True, dynamic=True) - return fn + if not hasattr(torch, "compile"): + return fn + # Skip compile if default device is CUDA but CUDA is unavailable. + if hasattr(torch, "get_default_device"): + default_device = torch.get_default_device() + if default_device.type == "cuda" and not torch.cuda.is_available(): + return fn + return torch.compile(fn, fullgraph=True, dynamic=True) @_maybe_compile @@ -181,13 +186,54 @@ def zeropower_via_newtonschulz5( raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.") +def should_fallback_to_adam_for_matrix( + p: torch.Tensor, + min_2d_dim: int, +) -> bool: + """ + Check if a 2D matrix should fallback to Adam due to small dimensions. + + Parameters + ---------- + p : torch.Tensor + Parameter tensor with ndim >= 2. + min_2d_dim : int + Minimum min(m, n) threshold for Muon. Matrices with min(m, n) >= + min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam. + + Returns + ------- + bool + True if min(m, n) < min_2d_dim, False otherwise. + + Raises + ------ + ValueError + If tensor has ndim < 2. + """ + # === Step 1. 
Validate === + if p.ndim < 2: + raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.") + + # === Step 2. Derive matrix shape consistent with Muon reshape === + m = int(p.shape[0]) + n = int(p.numel() // p.shape[0]) + + # === Step 3. Check if any dimension too small for Muon === + return min(m, n) < min_2d_dim + + class MuonOptimizer(Optimizer): """ - Muon optimizer with auxiliary Adam for non-matrix parameters. + Muon optimizer with small-2D Adam fallback and 1D Adam path. This optimizer applies different update rules based on parameter dimensionality: - - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization - - For 1D parameters (biases, layer norms): Standard Adam update + - For >=2D parameters with min(m, n) >= min_2d_dim: + Muon update with Newton-Schulz orthogonalization. + - For 2D parameters with min(m, n) < min_2d_dim (small matrices): + Adam update with scaled learning rate and update clipping. + - For 1D parameters (biases, layer norms): + Standard Adam update. This hybrid approach is effective because Muon's orthogonalization is designed for weight matrices, while Adam is more suitable for biases and normalization params. @@ -224,8 +270,19 @@ class MuonOptimizer(Optimizer): scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust. Default is 10.0 (Adam lr = lr/10). lr_adjust_coeff : float - Coefficient for match-RMS scaling with default 0.2. - Only effective when lr_adjust <= 0. + Dual-purpose coefficient with default 0.2: + 1. For Muon (when lr_adjust <= 0): match-RMS scaling factor, + scale = lr_adjust_coeff * sqrt(max(m, n)). + 2. For 2D Adam fallback: learning rate multiplier, + adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1). + The min(., 0.1) cap ensures conservative updates for small matrices. + min_2d_dim : int + Minimum min(m, n) threshold for Muon on 2D matrices. + Matrices with min(m, n) >= min_2d_dim use Muon; + those with min(m, n) < min_2d_dim use Adam fallback. + Must be >= 1. 
+ Set to 1 to disable fallback. + Default is 1. Examples -------- @@ -245,7 +302,11 @@ def __init__( adam_betas: tuple[float, float] = (0.9, 0.95), lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, + min_2d_dim: int = 1, ) -> None: + if min_2d_dim < 1: + raise ValueError("min_2d_dim must be >= 1.") + defaults = { "lr": lr, "momentum": momentum, @@ -253,6 +314,7 @@ def __init__( "adam_betas": adam_betas, "lr_adjust": lr_adjust, "lr_adjust_coeff": lr_adjust_coeff, + "min_2d_dim": min_2d_dim, } super().__init__(params, defaults) # Static parameter routing: built once on first step() call. @@ -264,8 +326,9 @@ def _build_param_routing(self) -> None: Classify parameters into Muon and Adam routes (static routing). Routing logic: - - >=2D parameters → Muon path (Newton-Schulz + momentum) - - 1D parameters → Adam path (standard Adam update) + - >=2D parameters with min(m, n) >= min_2d_dim → Muon path + - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path + - 1D parameters → Adam path """ if self._routing_built: return @@ -273,24 +336,40 @@ def _build_param_routing(self) -> None: self._routing = [] for group in self.param_groups: muon_params: list[dict[str, Any]] = [] - adam_params: list[dict[str, Any]] = [] + adam_1d: list[dict[str, Any]] = [] + adam_matrix: list[dict[str, Any]] = [] + + min_2d_dim = group["min_2d_dim"] for p in group["params"]: - if p.ndim >= 2: - muon_params.append( + if p.ndim < 2: + adam_1d.append({"param": p}) + continue + + if (p.ndim == 2) and should_fallback_to_adam_for_matrix( + p, min_2d_dim=min_2d_dim + ): + adam_matrix.append( { "param": p, - "rows": int(p.shape[0]), - "cols": int(p.numel() // p.shape[0]), + "abs_floor": 1e-3 * math.sqrt(float(p.numel())), } ) - else: - adam_params.append({"param": p}) + continue + + muon_params.append( + { + "param": p, + "rows": int(p.shape[0]), + "cols": int(p.numel() // p.shape[0]), + } + ) self._routing.append( { "muon_params": muon_params, - "adam_params": adam_params, + "adam_1d": 
adam_1d, + "adam_matrix": adam_matrix, } ) @@ -332,13 +411,14 @@ def step( lr_adjust_coeff = group["lr_adjust_coeff"] # === Step 1. Adam update for 1D parameters (biases, norms, etc.) === + # === Step 1.1. Collect gradients and initialize state === adam_params: list[torch.Tensor] = [] adam_grads_fp32: list[torch.Tensor] = [] adam_exp_avgs: list[torch.Tensor] = [] adam_exp_avg_sqs: list[torch.Tensor] = [] adam_states: list[dict[str, Any]] = [] - for entry in route["adam_params"]: + for entry in route["adam_1d"]: p = entry["param"] grad = p.grad if grad is None: @@ -363,6 +443,7 @@ def step( adam_states.append(state) if adam_params: + # === Step 1.2. Update exp_avg / exp_avg_sq === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust # exp_avg = beta1 * exp_avg + (1 - beta1) * grad @@ -371,6 +452,7 @@ def step( grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + # === Step 1.3. Bias correction and parameter update === for i, p in enumerate(adam_params): state = adam_states[i] bias_corr1 = 1 - state["beta1_pow"] @@ -381,7 +463,87 @@ def step( delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 2. Muon update for >=2D parameters (weight matrices) === + # === Step 2. Adam update for small 2D matrices (fallback path) === + # === Step 2.1. 
Collect gradients and initialize state === + adam_matrix_params: list[torch.Tensor] = [] + adam_matrix_grads_fp32: list[torch.Tensor] = [] + adam_matrix_exp_avgs: list[torch.Tensor] = [] + adam_matrix_exp_avg_sqs: list[torch.Tensor] = [] + adam_matrix_states: list[dict[str, Any]] = [] + adam_matrix_abs_floor: list[float] = [] + + for entry in route["adam_matrix"]: + p = entry["param"] + grad = p.grad + if grad is None: + continue + + grad_fp32 = grad.float() + + state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_matrix_params.append(p) + adam_matrix_grads_fp32.append(grad_fp32) + adam_matrix_exp_avgs.append(state["exp_avg"]) + adam_matrix_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_matrix_states.append(state) + adam_matrix_abs_floor.append(entry["abs_floor"]) + + if adam_matrix_params: + # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr === + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1) + + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 + torch._foreach_lerp_( + adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0] + ) + grad_sq_m = torch._foreach_mul( + adam_matrix_grads_fp32, adam_matrix_grads_fp32 + ) + torch._foreach_lerp_( + adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1] + ) + + # === Step 2.3. 
Compute unclipped deltas === + raw_deltas: list[torch.Tensor] = [] + for i in range(len(adam_matrix_params)): + state = adam_matrix_states[i] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr_matrix / bias_corr1 + denom = ( + (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS) + ) + raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom)) + + # === Step 2.4. Clip updates by relative norm and apply === + max_rel_change = 0.05 + p_norms = torch.stack(torch._foreach_norm(adam_matrix_params)) + delta_norms = torch.stack(torch._foreach_norm(raw_deltas)) + floors = torch.tensor( + adam_matrix_abs_floor, + device=p_norms.device, + dtype=p_norms.dtype, + ) + max_delta = torch.maximum(max_rel_change * p_norms, floors) + scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0) + for i, delta in enumerate(raw_deltas): + delta.mul_(scales_tensor[i]) + + torch._foreach_add_(adam_matrix_params, raw_deltas) + + # === Step 3. Muon update for >=2D parameters (weight matrices) === + # === Step 3.1. Collect gradients and initialize momentum === muon_params_for_decay: list[torch.Tensor] = [] muon_grads: list[torch.Tensor] = [] muon_momentum_buffers: list[torch.Tensor] = [] @@ -406,12 +568,14 @@ def step( muon_momentum_buffers.append(buf) active_entries.append((entry, grad)) + # === Step 3.2. Apply weight decay (Muon path only) === if weight_decay > 0 and muon_params_for_decay: torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay) if not active_entries: continue + # === Step 3.3. Momentum update (Nesterov) === # m_t = beta * m_{t-1} + (1 - beta) * g_t torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum) # update = beta * m_t + (1 - beta) * g_t @@ -419,6 +583,7 @@ def step( muon_grads, muon_momentum_buffers, momentum ) + # === Step 3.4. 
Bucket by shape/device/dtype for batched NS === buckets: dict[ tuple[int, int, torch.device, torch.dtype], list[tuple[dict[str, Any], torch.Tensor]], @@ -432,6 +597,7 @@ def step( buckets[bucket_key] = [] buckets[bucket_key].append((entry, muon_updates[idx])) + # === Step 3.5. Newton-Schulz orthogonalization and update === for (rows, cols, _device, dtype), bucket_entries in buckets.items(): # scale = coeff * sqrt(max(m, n)) [match-RMS mode] # scale = sqrt(max(1, m/n)) [rectangular mode] diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index d986d2f1b8..c09bc8a161 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -173,6 +173,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: "adam_beta2": params.get("adam_beta2", 0.95), "lr_adjust": params.get("lr_adjust", 10.0), "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), + "min_2d_dim": params.get("min_2d_dim", 1), } return opt_type, opt_param @@ -652,8 +653,7 @@ def single_model_finetune( missing, unexpected = self.model.load_state_dict(state, strict=False) if missing or unexpected: log.warning( - "Checkpoint loaded non-strictly. " - f"Missing keys: {missing}, Unexpected keys: {unexpected}" + f"Checkpoint loaded non-strictly. Missing keys: {missing}, Unexpected keys: {unexpected}" ) # Get model prob for multi-task @@ -758,6 +758,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: ), lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), + min_2d_dim=int(self.opt_param.get("min_2d_dim", 1)), ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) @@ -1577,8 +1578,6 @@ def model_change_out_bias( model_type_map = _model.get_type_map() log.info( - f"Change output bias of {model_type_map!s} " - f"from {to_numpy_array(old_bias).reshape(-1)!s} " - f"to {to_numpy_array(new_bias).reshape(-1)!s}." 
+ f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}." ) return _model diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index d07ca0bdb2..39b1ee9d4a 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3471,6 +3471,18 @@ def training_args( doc=doc_only_pt_supported + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", ), + Argument( + "min_2d_dim", + int, + optional=True, + default=1, + alias=["muon_min_2d_dim"], + doc=doc_only_pt_supported + + "Minimum min(m, n) threshold for Muon on 2D matrices. " + "Matrices with min(m, n) >= min_2d_dim use Muon; " + "those with min(m, n) < min_2d_dim use Adam fallback. " + "Set to 1 to disable fallback.", + ), ], [], optional=True, diff --git a/source/tests/pt/test_muon.py b/source/tests/pt/test_muon.py index 7889ef9066..c5c32b6dd1 100644 --- a/source/tests/pt/test_muon.py +++ b/source/tests/pt/test_muon.py @@ -113,6 +113,30 @@ def test_muon_adam_separation(self) -> None: self.assertIn("exp_avg_sq", optimizer.state[model.bias]) self.assertNotIn("momentum_buffer", optimizer.state[model.bias]) + def test_muon_adam_fallback_small_2d(self) -> None: + """Test Adam fallback for small 2D matrices when min_2d_dim is set.""" + torch.manual_seed(42) + linear_small = torch.nn.Linear(10, 1, bias=False, device=self.device) + linear_large = torch.nn.Linear(10, 10, bias=False, device=self.device) + optimizer = MuonOptimizer( + list(linear_small.parameters()) + list(linear_large.parameters()), + lr=0.02, + min_2d_dim=2, + ) + + x = torch.randn(4, 10, device=self.device) + loss = linear_small(x).sum() + linear_large(x).sum() + loss.backward() + optimizer.step() + + # Small 2D weight should use Adam fallback. + self.assertIn("exp_avg", optimizer.state[linear_small.weight]) + self.assertNotIn("momentum_buffer", optimizer.state[linear_small.weight]) + + # Large 2D weight should use Muon. 
+ self.assertIn("momentum_buffer", optimizer.state[linear_large.weight]) + self.assertNotIn("exp_avg", optimizer.state[linear_large.weight]) + def test_lr_adjust_modes(self) -> None: """Test lr_adjust modes: match-RMS (<=0) vs rectangular (>0).""" torch.manual_seed(42) From 586ca17bfb0c19987bddc3d119a831a8322dbfb2 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 11 Jan 2026 17:11:24 +0800 Subject: [PATCH 05/10] fix(pt): compatible with AdaMuon --- deepmd/pt/train/training.py | 2 ++ deepmd/utils/argcheck.py | 61 +++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index c09bc8a161..2a5cebc002 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -745,6 +745,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: float(self.opt_param.get("adam_beta1", 0.9)), float(self.opt_param.get("adam_beta2", 0.95)), ), + lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), + lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), ) elif self.opt_type == "Muon": self.optimizer = MuonOptimizer( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 39b1ee9d4a..c8ca286657 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3471,6 +3471,67 @@ def training_args( doc=doc_only_pt_supported + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", ), + ], + [], + optional=True, + ), + Argument( + "Muon", + dict, + [ + Argument( + "momentum", + float, + optional=True, + default=0.95, + alias=["muon_momentum"], + doc=doc_only_pt_supported + + "Momentum coefficient for Muon optimizer (>=2D params). 
" + "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.", + ), + Argument( + "adam_beta1", + float, + optional=True, + default=0.9, + doc=doc_only_pt_supported + + "Adam beta1 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "adam_beta2", + float, + optional=True, + default=0.95, + doc=doc_only_pt_supported + + "Adam beta2 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "weight_decay", + float, + optional=True, + default=0.001, + doc=doc_only_pt_supported + + "Weight decay coefficient. Applied only to >=2D parameters (Muon path).", + ), + Argument( + "lr_adjust", + float, + optional=True, + default=10.0, + doc=doc_only_pt_supported + + "Learning rate adjustment mode for Muon scaling and Adam learning rate. " + "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. " + "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. " + "Default is 10.0 (Adam lr = lr/10).", + ), + Argument( + "lr_adjust_coeff", + float, + optional=True, + default=0.2, + doc=doc_only_pt_supported + + "Coefficient for match-RMS scaling. 
Only effective when lr_adjust <= 0.", + ), Argument( "min_2d_dim", int, From 1d29bbf4f76a454e20d7eefa9b3246365876234c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 12 Jan 2026 13:32:14 +0800 Subject: [PATCH 06/10] skip bf16 at test if no bf16 support --- deepmd/pt/optimizer/muon.py | 8 ++++---- source/tests/pt/test_muon.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/muon.py index bea54dd3b1..6b3053ff36 100644 --- a/deepmd/pt/optimizer/muon.py +++ b/deepmd/pt/optimizer/muon.py @@ -537,10 +537,10 @@ def step( ) max_delta = torch.maximum(max_rel_change * p_norms, floors) scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0) - for i, delta in enumerate(raw_deltas): - delta.mul_(scales_tensor[i]) - - torch._foreach_add_(adam_matrix_params, raw_deltas) + for i, (p, delta) in enumerate( + zip(adam_matrix_params, raw_deltas, strict=False) + ): + p.add_(delta.mul_(scales_tensor[i]).to(p.dtype)) # === Step 3. Muon update for >=2D parameters (weight matrices) === # === Step 3.1. 
Collect gradients and initialize momentum === diff --git a/source/tests/pt/test_muon.py b/source/tests/pt/test_muon.py index c5c32b6dd1..ea1e7cbf01 100644 --- a/source/tests/pt/test_muon.py +++ b/source/tests/pt/test_muon.py @@ -12,6 +12,32 @@ ) +def _bf16_matmul_supported(device: torch.device) -> bool: + """Check if bf16 matmul is reliably supported on the given device.""" + if device.type == "cuda": + if not torch.cuda.is_available(): + return False + # bf16 requires compute capability >= 8.0 (Ampere+) for native support + # or >= 7.0 (Volta) with tensor cores, but may have precision issues + if hasattr(torch.cuda, "is_bf16_supported"): + return torch.cuda.is_bf16_supported() + # Fallback: check compute capability directly + cap = torch.cuda.get_device_capability(device) + return cap[0] >= 8 + # CPU bf16 support: available on x86 with AVX-512 BF16 or ARM with BF16 extension + # Since it's hard to detect reliably, try a small matmul and check for errors + try: + a = torch.randn(4, 4, dtype=torch.bfloat16, device=device) + _ = torch.mm(a, a.T) + return True + except (RuntimeError, TypeError): + return False + + +BF16_SUPPORTED = _bf16_matmul_supported(env.DEVICE) + + +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") class TestNewtonSchulzOrthogonalization(unittest.TestCase): """Test Newton-Schulz orthogonalization algorithm.""" @@ -55,6 +81,7 @@ def test_invalid_input(self) -> None: zeropower_via_newtonschulz5(G_1d) +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") class TestMuonOptimizer(unittest.TestCase): """Test MuonOptimizer class.""" @@ -164,6 +191,7 @@ def test_lr_adjust_modes(self) -> None: ) +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") class TestMuonOptimizerStateDict(unittest.TestCase): """Test optimizer state dict save/load.""" From d9af1d934f22d6d86f8ae370ad16399a33a7808c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 13 Jan 2026 10:48:04 +0800 
Subject: [PATCH 07/10] rename custom muon to hybridmuon --- deepmd/pt/optimizer/__init__.py | 13 +++++--- .../pt/optimizer/{muon.py => hybrid_muon.py} | 31 +++++++++++++------ deepmd/pt/train/training.py | 8 ++--- deepmd/utils/argcheck.py | 18 +++++++---- .../pt/{test_muon.py => test_hybrid_muon.py} | 26 ++++++++-------- 5 files changed, 59 insertions(+), 37 deletions(-) rename deepmd/pt/optimizer/{muon.py => hybrid_muon.py} (95%) rename source/tests/pt/{test_muon.py => test_hybrid_muon.py} (90%) diff --git a/deepmd/pt/optimizer/__init__.py b/deepmd/pt/optimizer/__init__.py index 6da11ebd0a..1899f27fff 100644 --- a/deepmd/pt/optimizer/__init__.py +++ b/deepmd/pt/optimizer/__init__.py @@ -2,14 +2,19 @@ from .adamuon import ( AdaMuonOptimizer, ) +from .hybrid_muon import ( + HybridMuonOptimizer, +) from .KFWrapper import ( KFOptimizerWrapper, ) from .LKF import ( LKFOptimizer, ) -from .muon import ( - MuonOptimizer, -) -__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer", "MuonOptimizer"] +__all__ = [ + "AdaMuonOptimizer", + "HybridMuonOptimizer", + "KFOptimizerWrapper", + "LKFOptimizer", +] diff --git a/deepmd/pt/optimizer/muon.py b/deepmd/pt/optimizer/hybrid_muon.py similarity index 95% rename from deepmd/pt/optimizer/muon.py rename to deepmd/pt/optimizer/hybrid_muon.py index 6b3053ff36..c624e0e4c0 100644 --- a/deepmd/pt/optimizer/muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -1,10 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """ -Muon optimizer for DeePMD-kit PyTorch backend. +HybridMuon optimizer for DeePMD-kit PyTorch backend. -Muon is an optimizer that applies Newton-Schulz orthogonalization to the gradient -before using momentum, resulting in orthogonalized updates for weight matrices. -This can improve training stability and convergence for certain architectures. 
+HybridMuon is a HYBRID optimizer that automatically combines Muon and Adam: +- For >=2D parameters with min(m,n) >= min_2d_dim: Muon update with Newton-Schulz +- For 2D parameters with min(m,n) < min_2d_dim: Adam fallback with update clipping +- For 1D parameters (biases, layer norms): Standard Adam + +This is different from PyTorch's torch.optim.Muon, which ONLY supports 2D parameters +and requires manual configuration of AdamW for 1D parameters. HybridMuon provides +automatic routing based on parameter dimensionality. Algorithm --------- @@ -33,9 +38,15 @@ - Muon gradients: cast to parameter dtype before momentum update - Adam gradients: cast to float32 for update computation -Reference ---------- -https://github.com/KellerJordan/Muon +References +---------- +.. [1] Keller Jordan, "Muon: An optimizer for hidden layers in neural networks." + https://kellerjordan.github.io/posts/muon/ + https://github.com/KellerJordan/Muon +.. [2] Moonshot team, "Muon is Scalable for LLM Training," arXiv:2502.16982, 2025. + https://arxiv.org/abs/2502.16982 +.. [3] Moonlight GitHub Repository. + https://github.com/MoonshotAI/Moonlight """ from __future__ import ( @@ -223,9 +234,9 @@ def should_fallback_to_adam_for_matrix( return min(m, n) < min_2d_dim -class MuonOptimizer(Optimizer): +class HybridMuonOptimizer(Optimizer): """ - Muon optimizer with small-2D Adam fallback and 1D Adam path. + HybridMuon optimizer with small-2D Adam fallback and 1D Adam path. This optimizer applies different update rules based on parameter dimensionality: - For >=2D parameters with min(m, n) >= min_2d_dim: @@ -286,7 +297,7 @@ class MuonOptimizer(Optimizer): Examples -------- - >>> optimizer = MuonOptimizer(model.parameters(), lr=1e-3) + >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=1e-3) >>> for epoch in range(epochs): ... optimizer.zero_grad() ... 
loss.backward() diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 2a5cebc002..6b471a31cc 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -43,9 +43,9 @@ ) from deepmd.pt.optimizer import ( AdaMuonOptimizer, + HybridMuonOptimizer, KFOptimizerWrapper, LKFOptimizer, - MuonOptimizer, ) from deepmd.pt.train.wrapper import ( ModelWrapper, @@ -748,8 +748,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), ) - elif self.opt_type == "Muon": - self.optimizer = MuonOptimizer( + elif self.opt_type == "HybridMuon": + self.optimizer = HybridMuonOptimizer( self.wrapper.parameters(), lr=self.lr_exp.start_lr, momentum=float(self.opt_param.get("momentum", 0.95)), @@ -838,7 +838,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type in ["Adam", "AdamW", "AdaMuon", "Muon"]: + if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index c8ca286657..935cbb7813 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3476,7 +3476,7 @@ def training_args( optional=True, ), Argument( - "Muon", + "HybridMuon", dict, [ Argument( @@ -3486,7 +3486,7 @@ def training_args( default=0.95, alias=["muon_momentum"], doc=doc_only_pt_supported - + "Momentum coefficient for Muon optimizer (>=2D params). " + + "Momentum coefficient for HybridMuon optimizer (>=2D params). " "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.", ), Argument( @@ -3511,7 +3511,7 @@ def training_args( optional=True, default=0.001, doc=doc_only_pt_supported - + "Weight decay coefficient. 
Applied only to >=2D parameters (Muon path).", + + "Weight decay coefficient. Applied only to >=2D parameters (HybridMuon path).", ), Argument( "lr_adjust", @@ -3519,7 +3519,7 @@ def training_args( optional=True, default=10.0, doc=doc_only_pt_supported - + "Learning rate adjustment mode for Muon scaling and Adam learning rate. " + + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. " "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. " "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. " "Default is 10.0 (Adam lr = lr/10).", @@ -3539,14 +3539,20 @@ def training_args( default=1, alias=["muon_min_2d_dim"], doc=doc_only_pt_supported - + "Minimum min(m, n) threshold for Muon on 2D matrices. " - "Matrices with min(m, n) >= min_2d_dim use Muon; " + + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " + "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " "those with min(m, n) < min_2d_dim use Adam fallback. " "Set to 1 to disable fallback.", ), ], [], optional=True, + doc=doc_only_pt_supported + + "HybridMuon optimizer (DeePMD-kit custom implementation). " + + "This is a Hybrid optimizer that automatically combines Muon and Adam. " + + "For >=2D params: Muon update with Newton-Schulz. " + + "For 1D params: Standard Adam. 
" + + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.", ), ], optional=True, diff --git a/source/tests/pt/test_muon.py b/source/tests/pt/test_hybrid_muon.py similarity index 90% rename from source/tests/pt/test_muon.py rename to source/tests/pt/test_hybrid_muon.py index ea1e7cbf01..10698cc63e 100644 --- a/source/tests/pt/test_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -3,8 +3,8 @@ import torch -from deepmd.pt.optimizer.muon import ( - MuonOptimizer, +from deepmd.pt.optimizer.hybrid_muon import ( + HybridMuonOptimizer, zeropower_via_newtonschulz5, ) from deepmd.pt.utils import ( @@ -82,8 +82,8 @@ def test_invalid_input(self) -> None: @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") -class TestMuonOptimizer(unittest.TestCase): - """Test MuonOptimizer class.""" +class TestHybridMuonOptimizer(unittest.TestCase): + """Test HybridMuonOptimizer class.""" def setUp(self) -> None: self.device = env.DEVICE @@ -96,7 +96,7 @@ def test_step(self) -> None: torch.nn.ReLU(), torch.nn.Linear(20, 5, device=self.device), ) - optimizer = MuonOptimizer(model.parameters(), lr=0.02) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) x = torch.randn(4, 10, device=self.device) model(x).sum().backward() @@ -111,7 +111,7 @@ def test_weight_decay(self) -> None: """Test weight decay reduces parameter norm.""" torch.manual_seed(42) model = torch.nn.Linear(10, 10, device=self.device) - optimizer = MuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) initial_norm = model.weight.norm().item() for _ in range(10): @@ -126,7 +126,7 @@ def test_muon_adam_separation(self) -> None: """Test Muon for 2D params, Adam for 1D params.""" torch.manual_seed(42) model = torch.nn.Linear(10, 10, device=self.device) - optimizer = MuonOptimizer(model.parameters(), lr=0.02) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) x = 
torch.randn(4, 10, device=self.device) model(x).sum().backward() @@ -145,7 +145,7 @@ def test_muon_adam_fallback_small_2d(self) -> None: torch.manual_seed(42) linear_small = torch.nn.Linear(10, 1, bias=False, device=self.device) linear_large = torch.nn.Linear(10, 10, bias=False, device=self.device) - optimizer = MuonOptimizer( + optimizer = HybridMuonOptimizer( list(linear_small.parameters()) + list(linear_large.parameters()), lr=0.02, min_2d_dim=2, @@ -172,8 +172,8 @@ def test_lr_adjust_modes(self) -> None: model2 = torch.nn.Linear(10, 20, bias=False, device=self.device) model2.load_state_dict(model1.state_dict()) - opt1 = MuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0) - opt2 = MuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0) + opt1 = HybridMuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0) + opt2 = HybridMuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0) x = torch.randn(4, 10, device=self.device) @@ -192,7 +192,7 @@ def test_lr_adjust_modes(self) -> None: @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") -class TestMuonOptimizerStateDict(unittest.TestCase): +class TestHybridMuonOptimizerStateDict(unittest.TestCase): """Test optimizer state dict save/load.""" def setUp(self) -> None: @@ -202,7 +202,7 @@ def test_state_dict_save_load(self) -> None: """Test saving and loading optimizer state.""" torch.manual_seed(42) model = torch.nn.Linear(10, 10, device=self.device) - optimizer = MuonOptimizer(model.parameters(), lr=0.02) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) for _ in range(3): optimizer.zero_grad() @@ -212,7 +212,7 @@ def test_state_dict_save_load(self) -> None: state_dict = optimizer.state_dict() - optimizer2 = MuonOptimizer(model.parameters(), lr=0.02) + optimizer2 = HybridMuonOptimizer(model.parameters(), lr=0.02) optimizer2.load_state_dict(state_dict) # Verify state matches by param id, not iteration order From d3a5abf6e7fc2611f26426cdecebe5ca45288eef Mon Sep 17 
00:00:00 2001 From: OutisLi Date: Tue, 13 Jan 2026 11:27:29 +0800 Subject: [PATCH 08/10] refactor(pt): align HybridMuon dtype behavior with official PyTorch Muon Changes: 1. Remove dtype conversion: NS output (bfloat16) now directly applied to parameters, matching torch.optim.Muon behavior where PyTorch handles mixed precision automatically. 2. Add muon_2d_only parameter (default True): When True, only 2D parameters use Muon; >2D parameters use Adam without weight decay. This matches PyTorch's official torch.optim.Muon which only supports 2D matrices. 3. Merge NS_EPS and ADAM_EPS into single EPS constant (both 1e-7). 4. Update dtype documentation to reflect actual behavior: - NS output (bfloat16) directly applied to parameters - Muon momentum buffer follows gradient dtype (not param dtype) 5. Update weight_decay docstring from ">=2D params" to "Muon-routed parameters" for accuracy with muon_2d_only=True. --- deepmd/pt/optimizer/hybrid_muon.py | 126 ++++++++++++++++++++++------- deepmd/pt/train/training.py | 2 + deepmd/utils/argcheck.py | 12 ++- 3 files changed, 110 insertions(+), 30 deletions(-) diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py index c624e0e4c0..abf4d3a572 100644 --- a/deepmd/pt/optimizer/hybrid_muon.py +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -34,8 +34,9 @@ Dtype Behavior -------------- - Newton-Schulz iterations: always bfloat16 (matches official Muon) +- NS output (bfloat16) directly applied to parameters (PyTorch handles mixed precision) - Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability -- Muon gradients: cast to parameter dtype before momentum update +- Muon momentum buffer: follows gradient dtype (grad -> buffer -> update) - Adam gradients: cast to float32 for update computation References @@ -75,10 +76,8 @@ # Newton-Schulz iteration count NS_STEPS: int = 5 -# Numerical stability epsilon for norm clamping -NS_EPS: float = 1e-7 -# Adam epsilon for numerical stability -ADAM_EPS: 
float = 1e-7 +# Numerical stability epsilon for norm clamping and Adam +EPS: float = 1e-7 # Quintic Newton-Schulz polynomial coefficients NS_COEFF_A: float = 3.4445 NS_COEFF_B: float = -4.7750 @@ -118,7 +117,7 @@ def _zeropower_via_newtonschulz5_2d( X = X.transpose(-2, -1) # === Step 2. Normalize Frobenius norm to at most 1 === - X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS) + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) # === Step 3. Newton-Schulz iterations with fused GEMM === for _ in range(NS_STEPS): @@ -152,7 +151,7 @@ def _zeropower_via_newtonschulz5_3d( X = X.transpose(-2, -1) # === Step 2. Normalize Frobenius norm to at most 1 === - X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS) + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) # === Step 3. Newton-Schulz iterations with batched fused GEMM === for _ in range(NS_STEPS): @@ -270,7 +269,7 @@ class HybridMuonOptimizer(Optimizer): momentum : float Momentum coefficient for Muon with default 0.95. weight_decay : float - Weight decay coefficient (applied only to >=2D params) with default 0.001. + Weight decay coefficient (applied only to Muon-routed parameters) with default 0.001. adam_betas : tuple[float, float] Adam beta coefficients with default (0.9, 0.95). lr_adjust : float @@ -287,6 +286,11 @@ class HybridMuonOptimizer(Optimizer): 2. For 2D Adam fallback: learning rate multiplier, adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1). The min(., 0.1) cap ensures conservative updates for small matrices. + muon_2d_only : bool + If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). + Parameters with ndim > 2 use Adam without weight decay. + If False, all >=2D parameters use Muon (the behavior before this option was added). + Default is True. min_2d_dim : int Minimum min(m, n) threshold for Muon on 2D matrices.
Matrices with min(m, n) >= min_2d_dim use Muon; @@ -313,6 +317,7 @@ def __init__( adam_betas: tuple[float, float] = (0.9, 0.95), lr_adjust: float = 10.0, lr_adjust_coeff: float = 0.2, + muon_2d_only: bool = True, min_2d_dim: int = 1, ) -> None: if min_2d_dim < 1: @@ -325,6 +330,7 @@ def __init__( "adam_betas": adam_betas, "lr_adjust": lr_adjust, "lr_adjust_coeff": lr_adjust_coeff, + "muon_2d_only": muon_2d_only, "min_2d_dim": min_2d_dim, } super().__init__(params, defaults) @@ -337,9 +343,11 @@ def _build_param_routing(self) -> None: Classify parameters into Muon and Adam routes (static routing). Routing logic: - - >=2D parameters with min(m, n) >= min_2d_dim → Muon path - - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path - 1D parameters → Adam path + - >2D parameters (when muon_2d_only=True) → Adam path + - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path + - 2D parameters with min(m, n) >= min_2d_dim → Muon path + - >=2D parameters (when muon_2d_only=False) → Muon path """ if self._routing_built: return @@ -349,14 +357,23 @@ def _build_param_routing(self) -> None: muon_params: list[dict[str, Any]] = [] adam_1d: list[dict[str, Any]] = [] adam_matrix: list[dict[str, Any]] = [] + adam_nd: list[dict[str, Any]] = [] min_2d_dim = group["min_2d_dim"] + muon_2d_only = group["muon_2d_only"] for p in group["params"]: + # === Step 1. 1D parameters → Adam === if p.ndim < 2: adam_1d.append({"param": p}) continue + # === Step 2. >2D parameters (when muon_2d_only=True) → Adam === + if muon_2d_only and p.ndim > 2: + adam_nd.append({"param": p}) + continue + + # === Step 3. 2D small matrices → Adam fallback === if (p.ndim == 2) and should_fallback_to_adam_for_matrix( p, min_2d_dim=min_2d_dim ): @@ -368,6 +385,7 @@ def _build_param_routing(self) -> None: ) continue + # === Step 4. 
>=2D (or 2D only when muon_2d_only=True) → Muon === muon_params.append( { "param": p, @@ -381,6 +399,7 @@ def _build_param_routing(self) -> None: "muon_params": muon_params, "adam_1d": adam_1d, "adam_matrix": adam_matrix, + "adam_nd": adam_nd, } ) @@ -470,12 +489,67 @@ def step( bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr / bias_corr1 # delta = -step_size * m_hat / (sqrt(v_hat) + eps) - denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS) + denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) p.add_(delta_fp32.to(p.dtype)) - # === Step 2. Adam update for small 2D matrices (fallback path) === + # === Step 2. Adam update for >2D parameters (when muon_2d_only=True) === # === Step 2.1. Collect gradients and initialize state === + adam_nd_params: list[torch.Tensor] = [] + adam_nd_grads_fp32: list[torch.Tensor] = [] + adam_nd_exp_avgs: list[torch.Tensor] = [] + adam_nd_exp_avg_sqs: list[torch.Tensor] = [] + adam_nd_states: list[dict[str, Any]] = [] + + for entry in route.get("adam_nd", []): + p = entry["param"] + grad = p.grad + if grad is None: + continue + + grad_fp32 = grad.float() + + state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_nd_params.append(p) + adam_nd_grads_fp32.append(grad_fp32) + adam_nd_exp_avgs.append(state["exp_avg"]) + adam_nd_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_nd_states.append(state) + + if adam_nd_params: + # === Step 2.2. 
Update exp_avg / exp_avg_sq === + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 + torch._foreach_lerp_( + adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0] + ) + grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32) + torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + + # === Step 2.3. Bias correction and parameter update === + for i, p in enumerate(adam_nd_params): + state = adam_nd_states[i] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr / bias_corr1 + # delta = -step_size * m_hat / (sqrt(v_hat) + eps) + denom = (adam_nd_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom) + p.add_(delta_fp32.to(p.dtype)) + + # === Step 3. Adam update for small 2D matrices (fallback path) === + # === Step 3.1. Collect gradients and initialize state === adam_matrix_params: list[torch.Tensor] = [] adam_matrix_grads_fp32: list[torch.Tensor] = [] adam_matrix_exp_avgs: list[torch.Tensor] = [] @@ -509,7 +583,7 @@ def step( adam_matrix_abs_floor.append(entry["abs_floor"]) if adam_matrix_params: - # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr === + # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr === adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1) @@ -525,19 +599,17 @@ def step( adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1] ) - # === Step 2.3. Compute unclipped deltas === + # === Step 3.3. 
Compute unclipped deltas === raw_deltas: list[torch.Tensor] = [] for i in range(len(adam_matrix_params)): state = adam_matrix_states[i] bias_corr1 = 1 - state["beta1_pow"] bias_corr2 = 1 - state["beta2_pow"] step_size = adam_lr_matrix / bias_corr1 - denom = ( - (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS) - ) + denom = (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom)) - # === Step 2.4. Clip updates by relative norm and apply === + # === Step 3.4. Clip updates by relative norm and apply === max_rel_change = 0.05 p_norms = torch.stack(torch._foreach_norm(adam_matrix_params)) delta_norms = torch.stack(torch._foreach_norm(raw_deltas)) @@ -553,8 +625,8 @@ def step( ): p.add_(delta.mul_(scales_tensor[i]).to(p.dtype)) - # === Step 3. Muon update for >=2D parameters (weight matrices) === - # === Step 3.1. Collect gradients and initialize momentum === + # === Step 4. Muon update for >=2D parameters (weight matrices) === + # === Step 4.1. Collect gradients and initialize momentum === muon_params_for_decay: list[torch.Tensor] = [] muon_grads: list[torch.Tensor] = [] muon_momentum_buffers: list[torch.Tensor] = [] @@ -579,14 +651,14 @@ def step( muon_momentum_buffers.append(buf) active_entries.append((entry, grad)) - # === Step 3.2. Apply weight decay (Muon path only) === + # === Step 4.2. Apply weight decay (Muon path only) === if weight_decay > 0 and muon_params_for_decay: torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay) if not active_entries: continue - # === Step 3.3. Momentum update (Nesterov) === + # === Step 4.3. Momentum update (Nesterov) === # m_t = beta * m_{t-1} + (1 - beta) * g_t torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum) # update = beta * m_t + (1 - beta) * g_t @@ -594,7 +666,7 @@ def step( muon_grads, muon_momentum_buffers, momentum ) - # === Step 3.4. Bucket by shape/device/dtype for batched NS === + # === Step 4.4. 
Bucket by shape/device/dtype for batched NS === buckets: dict[ tuple[int, int, torch.device, torch.dtype], list[tuple[dict[str, Any], torch.Tensor]], @@ -608,8 +680,8 @@ def step( buckets[bucket_key] = [] buckets[bucket_key].append((entry, muon_updates[idx])) - # === Step 3.5. Newton-Schulz orthogonalization and update === - for (rows, cols, _device, dtype), bucket_entries in buckets.items(): + # === Step 4.5. Newton-Schulz orthogonalization and update === + for (rows, cols, _device, _), bucket_entries in buckets.items(): # scale = coeff * sqrt(max(m, n)) [match-RMS mode] # scale = sqrt(max(1, m/n)) [rectangular mode] if lr_adjust <= 0: @@ -626,8 +698,6 @@ def step( orth = _zeropower_via_newtonschulz5_2d(update_matrix) orth.mul_(scale) delta = orth.reshape(entry["param"].shape) - if delta.dtype != dtype: - delta = delta.to(dtype) entry["param"].add_(delta, alpha=-lr) continue @@ -648,8 +718,6 @@ def step( stacked = torch.stack(matrices, dim=0) orth = _zeropower_via_newtonschulz5_3d(stacked) orth.mul_(scale) - if orth.dtype != dtype: - orth = orth.to(dtype) for i, _ in enumerate(bucket_entries): params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6b471a31cc..6720e6ff0b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -173,6 +173,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: "adam_beta2": params.get("adam_beta2", 0.95), "lr_adjust": params.get("lr_adjust", 10.0), "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), + "muon_2d_only": params.get("muon_2d_only", True), "min_2d_dim": params.get("min_2d_dim", 1), } return opt_type, opt_param @@ -760,6 +761,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: ), lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), + muon_2d_only=bool(self.opt_param.get("muon_2d_only", True)), 
min_2d_dim=int(self.opt_param.get("min_2d_dim", 1)), ) if optimizer_state_dict is not None and self.restart_training: diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 935cbb7813..8c20bb8bf4 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3511,7 +3511,7 @@ def training_args( optional=True, default=0.001, doc=doc_only_pt_supported - + "Weight decay coefficient. Applied only to >=2D parameters (HybridMuon path).", + + "Weight decay coefficient. Applied only to Muon-routed parameters", ), Argument( "lr_adjust", @@ -3532,6 +3532,16 @@ def training_args( doc=doc_only_pt_supported + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", ), + Argument( + "muon_2d_only", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). " + + "Parameters with ndim > 2 use Adam without weight decay. " + + "If False, all >=2D parameters use Muon.", + ), Argument( "min_2d_dim", int, From 7c3958fcd30dbe80be08db39a4ce8a26f8a47f97 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Tue, 13 Jan 2026 17:01:39 +0800 Subject: [PATCH 09/10] fix B905 --- source/tests/pt/test_hybrid_muon.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py index 10698cc63e..77973c5728 100644 --- a/source/tests/pt/test_hybrid_muon.py +++ b/source/tests/pt/test_hybrid_muon.py @@ -104,7 +104,9 @@ def test_step(self) -> None: initial_params = [p.clone() for p in model.parameters()] optimizer.step() - for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)): + for i, (p, init_p) in enumerate( + zip(model.parameters(), initial_params, strict=True) + ): self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change") def test_weight_decay(self) -> None: From 6109e6e196013222e2fbe2d134d60fa16fb2a54f Mon Sep 17 00:00:00 2001 From: OutisLi Date: Wed, 14 
Jan 2026 13:15:22 +0800 Subject: [PATCH 10/10] fix --- deepmd/pt/train/training.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 6720e6ff0b..20497a0ceb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -740,29 +740,29 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer = AdaMuonOptimizer( self.wrapper.parameters(), lr=self.lr_exp.start_lr, - momentum=float(self.opt_param.get("momentum", 0.95)), - weight_decay=float(self.opt_param.get("weight_decay", 0.001)), + momentum=float(self.opt_param["momentum"]), + weight_decay=float(self.opt_param["weight_decay"]), adam_betas=( - float(self.opt_param.get("adam_beta1", 0.9)), - float(self.opt_param.get("adam_beta2", 0.95)), + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), ), - lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), - lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), + lr_adjust=float(self.opt_param["lr_adjust"]), + lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), ) elif self.opt_type == "HybridMuon": self.optimizer = HybridMuonOptimizer( self.wrapper.parameters(), lr=self.lr_exp.start_lr, - momentum=float(self.opt_param.get("momentum", 0.95)), - weight_decay=float(self.opt_param.get("weight_decay", 0.001)), + momentum=float(self.opt_param["momentum"]), + weight_decay=float(self.opt_param["weight_decay"]), adam_betas=( - float(self.opt_param.get("adam_beta1", 0.9)), - float(self.opt_param.get("adam_beta2", 0.95)), + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), ), - lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), - lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), - muon_2d_only=bool(self.opt_param.get("muon_2d_only", True)), - min_2d_dim=int(self.opt_param.get("min_2d_dim", 1)), + lr_adjust=float(self.opt_param["lr_adjust"]), + 
lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), + muon_2d_only=bool(self.opt_param["muon_2d_only"]), + min_2d_dim=int(self.opt_param["min_2d_dim"]), ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict)