From fc683d3b2564cb78c2b1d52b30931ae7237920fa Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:15:49 +0800 Subject: [PATCH 01/18] feat(dpmodel): add _project_frames helper for frame-aware DPA4 grid ops --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 31 ++++ .../dpmodel/test_dpa4_project_frames.py | 174 ++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 source/tests/common/dpmodel/test_dpa4_project_frames.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index dc7d2ab1a8..bb4ab0855e 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -97,6 +97,37 @@ def _softmax_last_axis(x: Any) -> Any: return e_x / xp.sum(e_x, axis=-1, keepdims=True) +def _project_frames(coeff: Any, proj: ChannelLinear, n_frames: int) -> Any: + """ + Apply a channel-only linear map to each Wigner-D frame independently. + + Parameters + ---------- + coeff + Frame-packed coefficients with shape ``(N, D, F, n_frames * C_in)``. + proj : ChannelLinear + Linear map acting on the per-frame channel axis (``C_in -> C_out``). + n_frames : int + Number of Wigner-D frames packed along the trailing axis. + + Returns + ------- + Array + Projected coefficients with shape ``(N, D, F, n_frames * C_out)``. + + Notes + ----- + ``to_grid`` and ``from_grid`` are frame-wise linear and commute with any + channel map, so applying the map at coefficient resolution here is identical + to applying it on the grid field while touching ``n_frames``-fold fewer rows + than the ``G``-point grid. 
+ """ + xp = array_api_compat.array_namespace(coeff) + n_batch, coeff_dim, n_focus, _ = coeff.shape + projected = proj(xp.reshape(coeff, (n_batch, coeff_dim, n_focus, n_frames, -1))) + return xp.reshape(projected, (n_batch, coeff_dim, n_focus, -1)) + + class GridProduct(NativeOP): """Parameter-free quadratic grid product ``u(g) * v(g)``.""" diff --git a/source/tests/common/dpmodel/test_dpa4_project_frames.py b/source/tests/common/dpmodel/test_dpa4_project_frames.py new file mode 100644 index 0000000000..485621156e --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_project_frames.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the frame-aware DPA4 grid helper ``_project_frames``. + +These mirror the current pt ``deepmd.pt.model.descriptor.sezm_nn.grid_net`` +(refactored by PR #5552 to operate on coefficients). pt imports live inside +the test functions because ruff TID253 bans module-level ``deepmd.pt`` imports +under ``source/tests/common``. 
+""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + GridProduct as DPGridProduct, +) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + _project_frames, +) +from deepmd.dpmodel.descriptor.dpa4_nn.so3 import ( + ChannelLinear as DPChannelLinear, +) + + +@pytest.mark.parametrize("n_frames", [1, 2, 3]) # number of Wigner-D frames +def test_project_frames_parity(n_frames) -> None: + """The dpmodel ``_project_frames`` matches pt with a weight-copied ChannelLinear.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + _project_frames as pt_project_frames, + ) + from deepmd.pt.model.descriptor.sezm_nn.so3 import ( + ChannelLinear as PTChannelLinear, + ) + + c_in, c_out = 4, 6 + # pin to CPU so torch.from_numpy fp64 inputs and the module agree under the + # CUDA-default-device CI configuration + pt_proj = PTChannelLinear( + in_channels=c_in, + out_channels=c_out, + dtype=torch.float64, + bias=False, + trainable=True, + seed=11, + ).to("cpu") + rng = np.random.default_rng(2026) + with torch.no_grad(): + for p in pt_proj.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + state = {k: v.detach().cpu().numpy() for k, v in pt_proj.state_dict().items()} + assert set(state) == {"weight"} + + dp_proj = DPChannelLinear( + in_channels=c_in, + out_channels=c_out, + precision="float64", + bias=False, + trainable=True, + seed=11, + ) + dp_proj.weight = state["weight"] + + n_batch, coeff_dim, n_focus = 5, 9, 2 + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * c_in)) + + dp_out = _project_frames(coeff, dp_proj, n_frames) + pt_out = pt_project_frames(torch.from_numpy(coeff), pt_proj, n_frames) + assert dp_out.shape == (n_batch, coeff_dim, n_focus, n_frames * c_out) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + + +def test_project_frames_torch_namespace() -> None: + 
"""``_project_frames`` on torch input matches the numpy-input result. + + Array-API pitfall guard: the helper must work with any array namespace. + """ + import torch + + c_in, c_out, n_frames = 4, 5, 2 + dp_proj = DPChannelLinear( + in_channels=c_in, + out_channels=c_out, + precision="float64", + bias=False, + trainable=True, + seed=21, + ) + rng = np.random.default_rng(99) + coeff = rng.normal(size=(3, 9, 2, n_frames * c_in)) + + np_out = _project_frames(coeff, dp_proj, n_frames) + torch_out = _project_frames(torch.from_numpy(coeff), dp_proj, n_frames) + np.testing.assert_allclose( + np.asarray(np_out), + torch_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + + +def test_grid_product_parity() -> None: + """The dpmodel ``GridProduct`` matches pt over a real S2 projector's grid fns.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + GridProduct as PTGridProduct, + S2GridNet as PTS2GridNet, + ) + + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet as DPS2GridNet, + ) + + lmax, channels, n_focus = 2, 4, 1 + # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for + # their (parameter-free, deterministic) _to_grid/_from_grid projectors. 
+ pt_net = PTS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + dtype=torch.float64, + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ).to("cpu") + dp_net = DPS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + precision="float64", + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ) + + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(314) + left = rng.normal(size=(5, coeff_dim, n_focus, channels)) + right = rng.normal(size=(5, coeff_dim, n_focus, channels)) + scalar = rng.normal(size=(5, n_focus, 2 * channels)) + + dp_out = DPGridProduct().call( + left, + right, + scalar, + to_grid=dp_net._to_grid, + from_grid=dp_net._from_grid, + ) + pt_out = PTGridProduct()( + torch.from_numpy(left), + torch.from_numpy(right), + torch.from_numpy(scalar), + to_grid=pt_net._to_grid, + from_grid=pt_net._from_grid, + ) + assert dp_out.shape == (5, coeff_dim, n_focus, channels) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) From 9e3dbec6fad3c25f27042c335dc5860160c8fa77 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:24:37 +0800 Subject: [PATCH 02/18] feat(dpmodel): generalize GridMLP to frame-aware (n_frames) for DPA4 SO3 grid --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 49 ++- .../dpmodel/test_dpa4_gridmlp_frames.py | 320 ++++++++++++++++++ .../pt/model/test_dpa4_dpmodel_parity.py | 1 + 3 files changed, 355 insertions(+), 15 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index bb4ab0855e..dd68365cfd 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -176,15 +176,20 @@ class GridMLP(NativeOP): """ Polynomial point-wise MLP applied 
independently at every grid point. - Specialized to the S2 ``n_frames == 1`` case, so the per-frame packing of - the pt ``GridMLP`` collapses to a plain channel concatenation in self mode. + Frame-aware port of the pt ``GridMLP``: operands are packed as + ``(N, D, F, n_frames * C)`` and every channel projection is applied to each + Wigner-D frame independently through :func:`_project_frames`. The S2 case + (``n_frames == 1``) reduces to a plain per-channel projection, byte-for-byte + identical to the previous S2-only specialization. Parameters ---------- channels : int - Number of channels per grid point. + Number of channels per grid point (per frame). mode : str Pairing mode, either ``"self"`` or ``"cross"``. + n_frames : int + Number of Wigner-D frames packed along the trailing channel axis. precision : str Parameter precision. trainable : bool @@ -198,6 +203,7 @@ def __init__( *, channels: int, mode: str, + n_frames: int, precision: str = DEFAULT_PRECISION, trainable: bool = True, seed: int | list[int] | None = None, @@ -206,6 +212,7 @@ def __init__( self.mode = str(mode).lower() if self.mode not in {"self", "cross"}: raise ValueError("`mode` must be either 'self' or 'cross'") + self.n_frames = int(n_frames) self.precision = precision self.trainable = bool(trainable) self.input_channels = ( @@ -241,7 +248,7 @@ def call( self, left: Any, right: Any, - scalar_pair: Any, + scalar_pair: Any = None, *, to_grid: Callable[[Any], Any], from_grid: Callable[[Any], Any], @@ -249,29 +256,38 @@ def call( """ Apply the polynomial point-wise MLP on coefficient operands. - In self mode both projections see the concatenation of the two operands - and can form self and cross quadratic channel terms. In cross mode the - query and context roles stay separate: ``(W_q query) * (W_c context)``. + In self mode both projections see the per-frame concatenation of the + two operands and can form self and cross quadratic channel terms. 
In + cross mode the query and context roles stay separate: + ``(W_q query) * (W_c context)``. Parameters ---------- left, right - Coefficient operands with shape ``(N, D, F, C)``. + Coefficient operands with shape ``(N, D, F, n_frames * C)``. scalar_pair Invariant routing signal; unused on this path. to_grid, from_grid Coefficient/grid projectors supplied by the owning grid net. """ + # === Step 1. Channel projections at coefficient resolution === if self.mode == "self": xp = array_api_compat.array_namespace(left) - fused = xp.concat([left, right], axis=-1) # (N, D, F, 2C) - left = self.left_proj(fused) - right = self.right_proj(fused) + left_shape = tuple(left.shape) + shape = (*left_shape[:-1], self.n_frames, -1) + fused = xp.concat( + [xp.reshape(left, shape), xp.reshape(right, shape)], axis=-1 + ) # per-frame concat -> (N, D, F, n_frames, 2C) + fused = xp.reshape(fused, (*left_shape[:-1], -1)) # (N, D, F, n_frames*2C) + left = _project_frames(fused, self.left_proj, self.n_frames) + right = _project_frames(fused, self.right_proj, self.n_frames) else: - left = self.left_proj(left) - right = self.right_proj(right) - coeff = from_grid(to_grid(left) * to_grid(right)) # (N, D, F, 2C) - return self.out_proj(coeff) # (N, D, F, C) + left = _project_frames(left, self.left_proj, self.n_frames) + right = _project_frames(right, self.right_proj, self.n_frames) + + # === Step 2. Quadratic product on the grid, projected back === + coeff = from_grid(to_grid(left) * to_grid(right)) + return _project_frames(coeff, self.out_proj, self.n_frames) def serialize(self) -> dict[str, Any]: """Serialize the GridMLP to a dict. 
@@ -285,6 +301,7 @@ def serialize(self) -> dict[str, Any]: "config": { "channels": self.channels, "mode": self.mode, + "n_frames": self.n_frames, "precision": np.dtype(PRECISION_DICT[self.precision]).name, "trainable": self.trainable, "seed": None, @@ -310,6 +327,7 @@ def deserialize(cls, data: dict[str, Any]) -> GridMLP: obj = cls( channels=int(config["channels"]), mode=str(config["mode"]), + n_frames=int(config["n_frames"]), precision=str(config["precision"]), trainable=bool(config["trainable"]), seed=config.get("seed"), @@ -589,6 +607,7 @@ def __init__( self.grid_op: NativeOP = GridMLP( channels=self.channels, mode=self.mode, + n_frames=self.n_frames, precision=self.precision, trainable=trainable, seed=child_seed(seed, 1), diff --git a/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py b/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py new file mode 100644 index 0000000000..6af6a71016 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the frame-aware DPA4 ``GridMLP``. + +These mirror the current pt ``deepmd.pt.model.descriptor.sezm_nn.grid_net`` +``GridMLP``, which packs operands as ``(N, D, F, n_frames * C)`` and projects +each Wigner-D frame independently. pt imports live inside the test functions +because ruff TID253 bans module-level ``deepmd.pt`` imports under +``source/tests/common``. + +The ``to_grid``/``from_grid`` callables are supplied as namespace-agnostic +closures that reproduce the pt ``BaseGridNet`` frame-aware projector einsums +(``"gdk,ndfkc->ngfc"`` / ``"dkg,ngfc->ndfkc"``) with random matrices. The same +matrices are fed to both backends, so the closures only need to be identical, +not orthonormal. ``test_gridmlp_s2_regression`` additionally checks the +``n_frames == 1`` path against a real S2 projector's grid functions. 
+""" + +import array_api_compat +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + GridMLP as DPGridMLP, +) + + +def _make_grid_fns(to_mat, from_mat, n_frames): + """Build namespace-agnostic frame-aware ``to_grid``/``from_grid`` closures. + + Parameters + ---------- + to_mat : np.ndarray + Shape ``(G, D, n_frames)``; reproduces ``projector.to_grid_mat``. + from_mat : np.ndarray + Shape ``(D, n_frames, G)``; reproduces ``projector.from_grid_mat``. + n_frames : int + Number of Wigner-D frames. + """ + + def to_grid(coeff): + # einsum "gdk,ndfkc->ngfc": sum over d (coeff dim) and k (frame). + xp = array_api_compat.array_namespace(coeff) + dev = array_api_compat.device(coeff) + mat = xp.asarray(to_mat, device=dev) + if mat.dtype != coeff.dtype: + mat = xp.astype(mat, coeff.dtype) + n_batch, coeff_dim, n_focus, _ = coeff.shape + # (N, D, F, n_frames, C) + cv = xp.reshape(coeff, (n_batch, coeff_dim, n_focus, n_frames, -1)) + # (N, G, D, F, n_frames, C) + prod = mat[None, :, :, None, :, None] * cv[:, None, :, :, :, :] + prod = xp.sum(prod, axis=4) # contract frame -> (N, G, D, F, C) + return xp.sum(prod, axis=2) # contract coeff dim -> (N, G, F, C) + + def from_grid(grid): + # einsum "dkg,ngfc->ndfkc": sum over g (grid point). 
+ xp = array_api_compat.array_namespace(grid) + dev = array_api_compat.device(grid) + mat = xp.asarray(from_mat, device=dev) + if mat.dtype != grid.dtype: + mat = xp.astype(mat, grid.dtype) + n_batch, _, n_focus, n_channels = grid.shape + coeff_dim = from_mat.shape[0] + # (N, F, G, C) so the grid axis lands at position 4 of the 6D product + grid_p = xp.permute_dims(grid, (0, 2, 1, 3)) + # (N, D, F, n_frames, G, C) + prod = mat[None, :, None, :, :, None] * grid_p[:, None, :, None, :, :] + prod = xp.sum(prod, axis=4) # contract grid -> (N, D, F, n_frames, C) + return xp.reshape(prod, (n_batch, coeff_dim, n_focus, n_frames * n_channels)) + + return to_grid, from_grid + + +def _copy_pt_to_dp(pt_mlp, dp_mlp): + """Copy pt ``GridMLP`` state-dict weights into the dpmodel ``GridMLP``.""" + state = {k: v.detach().cpu().numpy() for k, v in pt_mlp.state_dict().items()} + assert set(state) == { + "left_proj.weight", + "right_proj.weight", + "out_proj.weight", + } + dp_mlp.left_proj.weight = state["left_proj.weight"] + dp_mlp.right_proj.weight = state["right_proj.weight"] + dp_mlp.out_proj.weight = state["out_proj.weight"] + + +@pytest.mark.parametrize("n_frames", [1, 2, 3]) # number of Wigner-D frames +@pytest.mark.parametrize("mode", ["self", "cross"]) # operand pairing mode +def test_gridmlp_parity(n_frames, mode) -> None: + """The dpmodel ``GridMLP`` matches pt over identical frame-aware grid fns.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + GridMLP as PTGridMLP, + ) + + channels, n_batch, coeff_dim, n_focus, grid_size = 4, 5, 9, 2, 7 + rng = np.random.default_rng(2026) + to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + np_to_grid, np_from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + # pin to CPU so torch.from_numpy fp64 inputs and the module agree under the + # CUDA-default-device CI configuration + pt_mlp = PTGridMLP( + channels=channels, + mode=mode, + 
n_frames=n_frames, + dtype=torch.float64, + trainable=True, + seed=7, + ).to("cpu") + with torch.no_grad(): + for p in pt_mlp.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_mlp = DPGridMLP( + channels=channels, + mode=mode, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=7, + ) + _copy_pt_to_dp(pt_mlp, dp_mlp) + + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + + dp_out = dp_mlp.call(left, right, to_grid=np_to_grid, from_grid=np_from_grid) + pt_out = pt_mlp( + torch.from_numpy(left), + torch.from_numpy(right), + None, + to_grid=np_to_grid, + from_grid=np_from_grid, + ) + assert dp_out.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # operand pairing mode +def test_gridmlp_s2_regression(mode) -> None: + """``n_frames == 1`` ``GridMLP`` matches pt over a real S2 projector. + + Guards the S2 path: the frame-aware reshape with ``n_frames == 1`` is an + identity, so the output stays byte-identical to the previous S2-only + specialization (which this generalization replaces). + """ + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + GridMLP as PTGridMLP, + S2GridNet as PTS2GridNet, + ) + + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet as DPS2GridNet, + ) + + lmax, channels, n_focus = 2, 4, 1 + # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for + # their (parameter-free, deterministic) _to_grid/_from_grid S2 projectors. 
+ pt_net = PTS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + dtype=torch.float64, + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ).to("cpu") + dp_net = DPS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + precision="float64", + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ) + + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(99) + + pt_mlp = PTGridMLP( + channels=channels, + mode=mode, + n_frames=1, + dtype=torch.float64, + trainable=True, + seed=13, + ).to("cpu") + with torch.no_grad(): + for p in pt_mlp.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_mlp = DPGridMLP( + channels=channels, + mode=mode, + n_frames=1, + precision="float64", + trainable=True, + seed=13, + ) + _copy_pt_to_dp(pt_mlp, dp_mlp) + + left = rng.normal(size=(5, coeff_dim, n_focus, channels)) + right = rng.normal(size=(5, coeff_dim, n_focus, channels)) + + dp_out = dp_mlp.call( + left, right, to_grid=dp_net._to_grid, from_grid=dp_net._from_grid + ) + pt_out = pt_mlp( + torch.from_numpy(left), + torch.from_numpy(right), + None, + to_grid=pt_net._to_grid, + from_grid=pt_net._from_grid, + ) + assert dp_out.shape == (5, coeff_dim, n_focus, channels) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # operand pairing mode +def test_gridmlp_serialize_roundtrip(mode) -> None: + """Serialize -> deserialize -> forward is identical; n_frames in config.""" + channels, n_frames, n_batch, coeff_dim, n_focus, grid_size = 4, 2, 3, 9, 2, 7 + rng = np.random.default_rng(404) + to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + to_grid, from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + mlp = DPGridMLP( + 
channels=channels, + mode=mode, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=5, + ) + # perturb to non-default weights + mlp.left_proj.weight = mlp.left_proj.weight + 0.1 * rng.normal( + size=mlp.left_proj.weight.shape + ) + mlp.right_proj.weight = mlp.right_proj.weight + 0.1 * rng.normal( + size=mlp.right_proj.weight.shape + ) + mlp.out_proj.weight = mlp.out_proj.weight + 0.1 * rng.normal( + size=mlp.out_proj.weight.shape + ) + + data = mlp.serialize() + assert data["@version"] == 1 + assert data["config"]["n_frames"] == n_frames + restored = DPGridMLP.deserialize(data) + + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + out0 = mlp.call(left, right, to_grid=to_grid, from_grid=from_grid) + out1 = restored.call(left, right, to_grid=to_grid, from_grid=from_grid) + np.testing.assert_allclose( + np.asarray(out0), np.asarray(out1), rtol=1e-12, atol=1e-12 + ) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # operand pairing mode +def test_gridmlp_torch_namespace(mode) -> None: + """``GridMLP.call`` on torch input matches the numpy-input result. + + Array-API pitfall guard: the dpmodel forward must work with any namespace. 
+ """ + import torch + + channels, n_frames, n_batch, coeff_dim, n_focus, grid_size = 4, 2, 3, 9, 2, 7 + rng = np.random.default_rng(77) + to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + to_grid, from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + mlp = DPGridMLP( + channels=channels, + mode=mode, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=9, + ) + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + + np_out = mlp.call(left, right, to_grid=to_grid, from_grid=from_grid) + torch_out = mlp.call( + torch.from_numpy(left), + torch.from_numpy(right), + to_grid=to_grid, + from_grid=from_grid, + ) + np.testing.assert_allclose( + np.asarray(np_out), + torch_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 90a986f77f..7970915c9c 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -1712,6 +1712,7 @@ def test_grid_mlp(self, mode) -> None: dp_mod = DPGridMLP( channels=self.channels, mode=mode, + n_frames=1, precision="float64", seed=9, ) From db407ba52ada18ede70e8d0fe3e956689fcde4c7 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:30:52 +0800 Subject: [PATCH 03/18] feat(dpmodel): generalize GridBranch to frame-aware (n_frames) for DPA4 SO3 grid --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 37 +- .../dpmodel/test_dpa4_gridbranch_frames.py | 332 ++++++++++++++++++ .../pt/model/test_dpa4_dpmodel_parity.py | 3 +- 3 files changed, 362 insertions(+), 10 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 
dd68365cfd..32e9e5f07c 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -359,12 +359,20 @@ class GridBranch(NativeOP): quadratic product of grid fields, so rotations only act through the grid argument and the operation remains as band-limited as the product path. + Frame-aware port of the pt ``GridBranch``: operands are packed as + ``(N, D, F, n_frames * C)`` and every channel projection is applied to each + Wigner-D frame independently through :func:`_project_frames`. The S2 case + (``n_frames == 1``) reduces to a plain per-channel projection, byte-for-byte + identical to the previous S2-only specialization. + Parameters ---------- channels : int - Number of channels per grid point. + Number of channels per grid point (per frame). n_branches : int Number of scalar-routed product branches. + n_frames : int + Number of Wigner-D frames packed along the trailing channel axis. precision : str Parameter precision. trainable : bool @@ -378,6 +386,7 @@ def __init__( *, channels: int, n_branches: int, + n_frames: int, precision: str = DEFAULT_PRECISION, trainable: bool = True, seed: int | list[int] | None = None, @@ -386,6 +395,7 @@ def __init__( self.n_branches = int(n_branches) if self.n_branches < 1: raise ValueError("`n_branches` must be positive") + self.n_frames = int(n_frames) self.precision = precision self.trainable = bool(trainable) self.left_proj = ChannelLinear( @@ -433,23 +443,27 @@ def call( """ Apply scalar-routed grid branch mixing on coefficient operands. - The channel maps are applied at coefficient resolution and the grid - transform is deferred to the injected ``to_grid``/``from_grid`` - callables, matching the pt ``GridBranch`` specialized to the S2 - ``n_frames == 1`` case (so no per-frame packing is needed). 
+ The channel maps are applied at coefficient resolution (per Wigner-D + frame via :func:`_project_frames`) and the grid transform is deferred to + the injected ``to_grid``/``from_grid`` callables, matching the pt + ``GridBranch``. The router operates on invariant scalars only, so the + softmax is frame-independent. Parameters ---------- left, right - Coefficient operands with shape ``(N, D, F, C)``. + Coefficient operands with shape ``(N, D, F, n_frames * C)``. scalar_pair Invariant router source with shape ``(N, F, 2*C)``. to_grid, from_grid Coefficient/grid projectors supplied by the owning grid net. """ xp = array_api_compat.array_namespace(left) - left = self.left_proj(left) # (N, D, F, N_branches * C) - right = self.right_proj(right) # (N, D, F, N_branches * C) + # === Step 1. Branch channel projections at coefficient resolution === + left = _project_frames(left, self.left_proj, self.n_frames) + right = _project_frames(right, self.right_proj, self.n_frames) + + # === Step 2. Quadratic branches on the grid, routed by scalars === value = to_grid(left) * to_grid(right) # (N, G, F, N_branches * C) n_batch, n_grid, n_focus, _ = value.shape value = xp.reshape( @@ -459,7 +473,9 @@ def call( router = _softmax_last_axis(self.router(scalar_pair)) # (N, F, N_branches) # einsum "ngfhc,nfh->ngfc" as a broadcast sum over the branch axis out = xp.sum(value * router[:, None, :, :, None], axis=3) # (N, G, F, C) - return self.out_proj(from_grid(out)) # (N, D, F, C) + + # === Step 3. Project back to coefficients and mix output channels === + return _project_frames(from_grid(out), self.out_proj, self.n_frames) def serialize(self) -> dict[str, Any]: """Serialize the GridBranch to a dict. 
@@ -473,6 +489,7 @@ def serialize(self) -> dict[str, Any]: "config": { "channels": self.channels, "n_branches": self.n_branches, + "n_frames": self.n_frames, "precision": np.dtype(PRECISION_DICT[self.precision]).name, "trainable": self.trainable, "seed": None, @@ -499,6 +516,7 @@ def deserialize(cls, data: dict[str, Any]) -> GridBranch: obj = cls( channels=int(config["channels"]), n_branches=int(config["n_branches"]), + n_frames=int(config["n_frames"]), precision=str(config["precision"]), trainable=bool(config["trainable"]), seed=config.get("seed"), @@ -616,6 +634,7 @@ def __init__( self.grid_op = GridBranch( channels=self.channels, n_branches=grid_branches, + n_frames=self.n_frames, precision=self.precision, trainable=trainable, seed=child_seed(seed, 1), diff --git a/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py b/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py new file mode 100644 index 0000000000..04c6640ce1 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the frame-aware DPA4 ``GridBranch``. + +These mirror the current pt ``deepmd.pt.model.descriptor.sezm_nn.grid_net`` +``GridBranch``, which packs operands as ``(N, D, F, n_frames * C)``, projects +each Wigner-D frame independently through ``_project_frames``, forms quadratic +grid product branches, and combines them with a scalar-routed softmax (the +router sees invariant scalars only). pt imports live inside the test functions +because ruff TID253 bans module-level ``deepmd.pt`` imports under +``source/tests/common``. + +The ``to_grid``/``from_grid`` callables are supplied as namespace-agnostic +closures (shared with ``test_dpa4_gridmlp_frames.py``'s approach) that reproduce +the pt ``BaseGridNet`` frame-aware projector einsums with random matrices. The +same matrices are fed to both backends, so the closures only need to be +identical, not orthonormal. 
``test_gridbranch_s2_regression`` additionally +checks the ``n_frames == 1`` path against a real S2 projector's grid functions. +""" + +import array_api_compat +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + GridBranch as DPGridBranch, +) + + +def _make_grid_fns(to_mat, from_mat, n_frames): + """Build namespace-agnostic frame-aware ``to_grid``/``from_grid`` closures. + + Parameters + ---------- + to_mat : np.ndarray + Shape ``(G, D, n_frames)``; reproduces ``projector.to_grid_mat``. + from_mat : np.ndarray + Shape ``(D, n_frames, G)``; reproduces ``projector.from_grid_mat``. + n_frames : int + Number of Wigner-D frames. + """ + + def to_grid(coeff): + # einsum "gdk,ndfkc->ngfc": sum over d (coeff dim) and k (frame). + xp = array_api_compat.array_namespace(coeff) + dev = array_api_compat.device(coeff) + mat = xp.asarray(to_mat, device=dev) + if mat.dtype != coeff.dtype: + mat = xp.astype(mat, coeff.dtype) + n_batch, coeff_dim, n_focus, _ = coeff.shape + # (N, D, F, n_frames, C) + cv = xp.reshape(coeff, (n_batch, coeff_dim, n_focus, n_frames, -1)) + # (N, G, D, F, n_frames, C) + prod = mat[None, :, :, None, :, None] * cv[:, None, :, :, :, :] + prod = xp.sum(prod, axis=4) # contract frame -> (N, G, D, F, C) + return xp.sum(prod, axis=2) # contract coeff dim -> (N, G, F, C) + + def from_grid(grid): + # einsum "dkg,ngfc->ndfkc": sum over g (grid point). 
+ xp = array_api_compat.array_namespace(grid) + dev = array_api_compat.device(grid) + mat = xp.asarray(from_mat, device=dev) + if mat.dtype != grid.dtype: + mat = xp.astype(mat, grid.dtype) + n_batch, _, n_focus, n_channels = grid.shape + coeff_dim = from_mat.shape[0] + # (N, F, G, C) so the grid axis lands at position 4 of the 6D product + grid_p = xp.permute_dims(grid, (0, 2, 1, 3)) + # (N, D, F, n_frames, G, C) + prod = mat[None, :, None, :, :, None] * grid_p[:, None, :, None, :, :] + prod = xp.sum(prod, axis=4) # contract grid -> (N, D, F, n_frames, C) + return xp.reshape(prod, (n_batch, coeff_dim, n_focus, n_frames * n_channels)) + + return to_grid, from_grid + + +def _copy_pt_to_dp(pt_branch, dp_branch): + """Copy pt ``GridBranch`` state-dict weights into the dpmodel ``GridBranch``.""" + state = {k: v.detach().cpu().numpy() for k, v in pt_branch.state_dict().items()} + assert set(state) == { + "left_proj.weight", + "right_proj.weight", + "router.weight", + "out_proj.weight", + } + dp_branch.left_proj.weight = state["left_proj.weight"] + dp_branch.right_proj.weight = state["right_proj.weight"] + dp_branch.router.weight = state["router.weight"] + dp_branch.out_proj.weight = state["out_proj.weight"] + + +@pytest.mark.parametrize("n_frames", [1, 2]) # number of Wigner-D frames +@pytest.mark.parametrize("n_branches", [1, 2]) # scalar-routed product branches +def test_gridbranch_parity(n_frames, n_branches) -> None: + """The dpmodel ``GridBranch`` matches pt over identical frame-aware grid fns.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + GridBranch as PTGridBranch, + ) + + channels, n_batch, coeff_dim, n_focus, grid_size = 4, 5, 9, 2, 7 + rng = np.random.default_rng(2026) + # to_grid collapses the n_frames axis but preserves the per-frame channel + # axis, so the projector matrices are channel-count agnostic and the same + # closures work for the (N_branches * C) operand width. 
+ to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + np_to_grid, np_from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + # pin to CPU so torch.from_numpy fp64 inputs and the module agree under the + # CUDA-default-device CI configuration + pt_branch = PTGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=n_frames, + dtype=torch.float64, + trainable=True, + seed=7, + ).to("cpu") + with torch.no_grad(): + for p in pt_branch.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_branch = DPGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=7, + ) + _copy_pt_to_dp(pt_branch, dp_branch) + + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + scalar_pair = rng.normal(size=(n_batch, n_focus, 2 * channels)) + + dp_out = dp_branch.call( + left, right, scalar_pair, to_grid=np_to_grid, from_grid=np_from_grid + ) + pt_out = pt_branch( + torch.from_numpy(left), + torch.from_numpy(right), + torch.from_numpy(scalar_pair), + to_grid=np_to_grid, + from_grid=np_from_grid, + ) + assert dp_out.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) + + +@pytest.mark.parametrize("n_branches", [1, 2]) # scalar-routed product branches +def test_gridbranch_s2_regression(n_branches) -> None: + """``n_frames == 1`` ``GridBranch`` matches pt over a real S2 projector. + + Guards the S2 path: the frame-aware reshape with ``n_frames == 1`` is an + identity, so the output stays byte-identical to the previous S2-only + specialization (which this generalization replaces). 
+ """ + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + GridBranch as PTGridBranch, + S2GridNet as PTS2GridNet, + ) + + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet as DPS2GridNet, + ) + + lmax, channels, n_focus = 2, 4, 1 + # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for + # their (parameter-free, deterministic) _to_grid/_from_grid S2 projectors. + pt_net = PTS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + dtype=torch.float64, + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ).to("cpu") + dp_net = DPS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="glu", + precision="float64", + layout="ndfc", + grid_method="lebedev", + trainable=True, + seed=7, + ) + + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(99) + + pt_branch = PTGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=1, + dtype=torch.float64, + trainable=True, + seed=13, + ).to("cpu") + with torch.no_grad(): + for p in pt_branch.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_branch = DPGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=1, + precision="float64", + trainable=True, + seed=13, + ) + _copy_pt_to_dp(pt_branch, dp_branch) + + left = rng.normal(size=(5, coeff_dim, n_focus, channels)) + right = rng.normal(size=(5, coeff_dim, n_focus, channels)) + scalar_pair = rng.normal(size=(5, n_focus, 2 * channels)) + + dp_out = dp_branch.call( + left, right, scalar_pair, to_grid=dp_net._to_grid, from_grid=dp_net._from_grid + ) + pt_out = pt_branch( + torch.from_numpy(left), + torch.from_numpy(right), + torch.from_numpy(scalar_pair), + to_grid=pt_net._to_grid, + from_grid=pt_net._from_grid, + ) + assert dp_out.shape == (5, coeff_dim, n_focus, channels) + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + 
rtol=1e-12, + atol=1e-12, + ) + + +@pytest.mark.parametrize("n_branches", [1, 2]) # scalar-routed product branches +def test_gridbranch_serialize_roundtrip(n_branches) -> None: + """Serialize -> deserialize -> forward is identical; n_frames in config.""" + channels, n_frames, n_batch, coeff_dim, n_focus, grid_size = 4, 2, 3, 9, 2, 7 + rng = np.random.default_rng(404) + to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + to_grid, from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + branch = DPGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=5, + ) + # perturb to non-default weights + for proj in ( + branch.left_proj, + branch.right_proj, + branch.router, + branch.out_proj, + ): + proj.weight = proj.weight + 0.1 * rng.normal(size=proj.weight.shape) + + data = branch.serialize() + assert data["@version"] == 1 + assert data["config"]["n_frames"] == n_frames + restored = DPGridBranch.deserialize(data) + + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + scalar_pair = rng.normal(size=(n_batch, n_focus, 2 * channels)) + out0 = branch.call(left, right, scalar_pair, to_grid=to_grid, from_grid=from_grid) + out1 = restored.call(left, right, scalar_pair, to_grid=to_grid, from_grid=from_grid) + np.testing.assert_allclose( + np.asarray(out0), np.asarray(out1), rtol=1e-12, atol=1e-12 + ) + + +@pytest.mark.parametrize("n_branches", [1, 2]) # scalar-routed product branches +def test_gridbranch_torch_namespace(n_branches) -> None: + """``GridBranch.call`` on torch input matches the numpy-input result. + + Array-API pitfall guard: the dpmodel forward must work with any namespace. 
+ """ + import torch + + channels, n_frames, n_batch, coeff_dim, n_focus, grid_size = 4, 2, 3, 9, 2, 7 + rng = np.random.default_rng(77) + to_mat = rng.normal(size=(grid_size, coeff_dim, n_frames)) + from_mat = rng.normal(size=(coeff_dim, n_frames, grid_size)) + to_grid, from_grid = _make_grid_fns(to_mat, from_mat, n_frames) + + branch = DPGridBranch( + channels=channels, + n_branches=n_branches, + n_frames=n_frames, + precision="float64", + trainable=True, + seed=9, + ) + left = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + right = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + scalar_pair = rng.normal(size=(n_batch, n_focus, 2 * channels)) + + np_out = branch.call(left, right, scalar_pair, to_grid=to_grid, from_grid=from_grid) + torch_out = branch.call( + torch.from_numpy(left), + torch.from_numpy(right), + torch.from_numpy(scalar_pair), + to_grid=to_grid, + from_grid=from_grid, + ) + np.testing.assert_allclose( + np.asarray(np_out), + torch_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 7970915c9c..8f41ebbccb 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -1645,6 +1645,7 @@ def test_grid_branch(self, n_branches) -> None: dp_mod = DPGridBranch( channels=self.channels, n_branches=n_branches, + n_frames=1, precision="float64", seed=9, ) @@ -1859,7 +1860,7 @@ def test_value_errors(self) -> None: with pytest.raises(ValueError): # flat layout is cross-only DPS2GridNet(**{**common, "layout": "flat"}) with pytest.raises(ValueError): # n_branches must be positive - DPGridBranch(channels=4, n_branches=0, precision="float64") + DPGridBranch(channels=4, n_branches=0, n_frames=1, precision="float64") dp_net = DPS2GridNet(**common) rng = np.random.default_rng(2086) with pytest.raises(ValueError): # wrong query channel count 
From 8e2ca7c3850dbe5efb2bbfdda5fb482b3c9bed24 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:44:50 +0800 Subject: [PATCH 04/18] feat(dpmodel): generalize BaseGridNet (cross/flat/residual/n_frames>1) for DPA4 SO3 grid --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 317 ++++++++++++----- .../dpmodel/test_dpa4_basegridnet_cross.py | 320 ++++++++++++++++++ .../pt/model/test_dpa4_dpmodel_parity.py | 9 +- 3 files changed, 562 insertions(+), 84 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 32e9e5f07c..c82c69e378 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -3,32 +3,28 @@ Grid-space nonlinearities for DPA4/SeZM coefficient tensors. This module is the dpmodel port of -``deepmd.pt.model.descriptor.sezm_nn.grid_net``, restricted to the S2/Lebedev -path used by the core DPA4 configuration. A grid net receives coefficient -tensors, converts them to quadrature values, applies one point-wise grid -operation, and projects the result back to coefficients. The public shapes -are: +``deepmd.pt.model.descriptor.sezm_nn.grid_net``. A grid net receives +coefficient tensors, converts them to quadrature values, applies one +point-wise grid operation, and projects the result back to coefficients. The +public shapes are: * ``mode='self'``: one input ``(N, D, F, 2*C)`` or ``(N, F, D, 2*C)``. -* grid values: ``(N, G, F, C)`` after S2 projection. +* ``mode='cross'``: separate query/context inputs each with ``C`` channels. +* grid values: ``(N, G, F, C)`` after S2 or SO(3) projection. 
-Ported names: ``BaseGridNet`` (``mode='self'``; ``op_type`` -'glu'/'mlp'/'branch'), ``S2GridNet``, ``GridProduct``, ``GridMLP``, +Ported names: ``BaseGridNet`` (``mode`` 'self'/'cross'; ``op_type`` +'glu'/'mlp'/'branch'; ``layout`` 'ndfc'/'nfdc'/'flat'; ``residual_scale_init``; +general ``n_frames``), ``S2GridNet``, ``GridProduct``, ``GridMLP``, ``GridBranch``. -Skipped names, with consumer evidence from the pt sources: - -- ``SO3GridNet``: only constructed by ``so2.py`` (``node_wise_so3``, - ``message_node_so3``) and ``ffn.py`` (``ffn_so3_grid``) — all disabled in - the core DPA4 config. -- ``FrameContract``, ``FrameExpand``, ``_build_frame_degree_index``: only - constructed by ``SO3GridNet`` (``mode='cross'``); the S2 projector always - has ``n_frames == 1``, so the frame machinery is unreachable here. - -Guarded (routable from the shared ``S2GridNet`` entry point but only used by -the disabled ``node_wise_s2``/``message_node_s2`` grid products in -``so2.py``): ``mode='cross'`` (and with it ``layout='flat'``) and -``residual_scale_init is not None`` raise ``NotImplementedError``. +``BaseGridNet`` mirrors the current pt ``BaseGridNet`` for arbitrary +``n_frames`` (the ``_to_grid``/``_from_grid`` frame-axis contraction). The S2 +path (``n_frames == 1``, ``mode='self'``) keeps a dedicated fast branch that is +byte-identical to the previous S2-only specialization. The SO(3) frame +machinery (``SO3GridNet``, ``FrameContract``, ``FrameExpand``) is not ported +here; ``BaseGridNet`` exposes ``frame_expand``/``frame_contract`` seams (kept +``None`` for S2) so a later SO(3) port can plug them in without touching the +shared forward. Serialization contract: the pt ``S2GridNet`` and ``GridBranch`` define no ``serialize()`` (they only appear nested inside larger modules' @@ -543,15 +539,19 @@ def _load_variables(self, variables: dict[str, Any]) -> None: class BaseGridNet(NativeOP): """ - Shared implementation for S2 grid nets (``mode='self'`` only). 
+ Shared implementation for S2 and SO(3) grid nets. ``mode='self'`` expects one input whose last channel axis contains two branches; the first half supplies the SwiGLU gates of the scalar path. - - The pt ``mode='cross'`` path (with ``layout='flat'``, - ``residual_scale_init``, and the SO(3) frame machinery) backs the - ``node_wise_s2``/``message_node_s2`` grid products only, which are - disabled in the core DPA4 config; it is not ported. + ``mode='cross'`` expects separate query and context inputs. + + Mirrors the current pt ``BaseGridNet``: ``mode`` ('self'/'cross'), + ``layout`` ('ndfc'/'nfdc'/'flat'), ``residual_scale_init`` and arbitrary + ``n_frames`` are all supported. The S2 (``n_frames == 1``) path keeps a + dedicated fast branch in ``_to_grid``/``_from_grid``/``_apply_scalar_path`` + that is byte-identical to the previous S2-only specialization. The SO(3) + frame machinery (``frame_expand``/``frame_contract``) is built by the + not-yet-ported ``SO3GridNet``; the seams here stay ``None`` for S2. 
""" def __init__( @@ -567,6 +567,8 @@ def __init__( mlp_bias: bool, trainable: bool = True, grid_branches: int = 1, + frame_expand: NativeOP | None = None, + frame_contract: NativeOP | None = None, residual_scale_init: float | None = None, seed: int | list[int] | None = None, ) -> None: @@ -575,18 +577,9 @@ def __init__( self.channels = int(channels) self.n_focus = int(n_focus) self.n_frames = int(projector.n_frames) - if self.n_frames != 1: - raise ValueError( - "dpmodel BaseGridNet only supports S2 projectors (n_frames == 1)" - ) self.mode = str(mode).lower() if self.mode not in {"self", "cross"}: raise ValueError("`mode` must be either 'self' or 'cross'") - if self.mode == "cross": - raise NotImplementedError( - "mode='cross' (node_wise_s2/message_node_s2 grid products) " - "is not ported to dpmodel" - ) self.op_type = str(op_type).lower() if self.op_type not in {"glu", "mlp", "branch"}: raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") @@ -599,16 +592,36 @@ def __init__( self.mlp_bias = bool(mlp_bias) self.trainable = bool(trainable) self.expanded_channels = self.n_frames * self.channels - self.query_channels = 2 * self.expanded_channels - self.output_channels = self.expanded_channels - self.frame_zero_index = 0 - if residual_scale_init is not None: - raise NotImplementedError( - "`residual_scale_init` is only used by the cross-mode " - "node_wise_s2/message_node_s2 grid products, which are not " - "ported to dpmodel" + # ``frame_expand``/``frame_contract`` are the SO(3) frame machinery + # (built only by ``SO3GridNet`` in cross mode). They stay ``None`` for + # S2 (``n_frames == 1``); the seam below lets a later SO(3) port plug + # them in without touching the shared forward. 
+ self.frame_expand = frame_expand + self.frame_contract = frame_contract + self.query_channels = ( + 2 * self.expanded_channels + if self.mode == "self" + else ( + self.channels + if self.frame_expand is not None + else self.expanded_channels ) - self.residual_scale = None + ) + self.context_channels = ( + self.channels if self.frame_expand is not None else self.expanded_channels + ) + self.output_channels = ( + self.channels if self.frame_contract is not None else self.expanded_channels + ) + self.frame_zero_index = int(getattr(projector, "frame_zero_index", 0)) + self.residual_scale_init = residual_scale_init + if residual_scale_init is None: + self.residual_scale: np.ndarray | None = None + else: + prec = PRECISION_DICT[self.precision.lower()] + self.residual_scale = np.ones( + (self.n_focus, self.output_channels), dtype=prec + ) * float(residual_scale_init) self.scalar_act = SwiGLU() self.scalar_gate = FocusLinear( @@ -647,14 +660,73 @@ def call(self, query: Any, context: Any = None) -> Any: xp = array_api_compat.array_namespace(query) input_dtype = query.dtype compute_dtype = get_xp_precision(xp, self.precision) - query_ndfc = self._to_ndfc(query) - left, right = self._split_self_query(query_ndfc) - scalar_pair = self._make_scalar_pair(left, right, compute_dtype) + query_ndfc, shape_info = self._to_ndfc(query) + left, right, scalar_pair = self._prepare_pair( + query_ndfc, context, compute_dtype + ) coeff_out = self._apply_grid_op(left, right, scalar_pair, compute_dtype) coeff_out = self._apply_scalar_path(coeff_out, scalar_pair) + coeff_out = self._contract_frames(coeff_out) + coeff_out = self._apply_residual_scale(coeff_out) if coeff_out.dtype != input_dtype: coeff_out = xp.astype(coeff_out, input_dtype) - return self._restore_layout(coeff_out) + return self._restore_layout(coeff_out, shape_info) + + def _prepare_pair( + self, query: Any, context: Any, compute_dtype: Any + ) -> tuple[Any, Any, Any]: + if self.mode == "self": + return 
self._prepare_self_pair(query, compute_dtype) + return self._prepare_cross_pair(query, context, compute_dtype) + + def _prepare_self_pair( + self, query: Any, compute_dtype: Any + ) -> tuple[Any, Any, Any]: + left, right = self._split_self_query(query) + scalar_pair = self._make_scalar_pair(left, right, compute_dtype) + return left, right, scalar_pair + + def _prepare_cross_pair( + self, query: Any, context: Any, compute_dtype: Any + ) -> tuple[Any, Any, Any]: + if context is None: + raise ValueError("`context` is required when `mode='cross'`") + context_ndfc, _ = self._to_ndfc(context) + self._check_last_dim(query, self.context_channels, "query") + self._check_last_dim(context_ndfc, self.context_channels, "context") + if self.frame_expand is None: + scalar_pair = self._make_scalar_pair(query, context_ndfc, compute_dtype) + return query, context_ndfc, scalar_pair + # SO(3) frame-expansion seam (built only by a later SO3GridNet port): + # the scalar pair is read from the d=0 slice before expansion, then + # both operands are lifted to the frame-packed width. 
+ xp = array_api_compat.array_namespace(query) + scalar_pair = xp.concat([query[:, 0, :, :], context_ndfc[:, 0, :, :]], axis=-1) + if scalar_pair.dtype != compute_dtype: + scalar_pair = xp.astype(scalar_pair, compute_dtype) + return ( + self.frame_expand(query), + self.frame_expand(context_ndfc), + scalar_pair, + ) + + def _contract_frames(self, coeff: Any) -> Any: + if self.frame_contract is None: + return coeff + return self.frame_contract(coeff) + + def _apply_residual_scale(self, coeff: Any) -> Any: + if self.residual_scale is None: + return coeff + xp = array_api_compat.array_namespace(coeff) + residual_scale = xp_asarray_nodetach( + xp, self.residual_scale[...], device=array_api_compat.device(coeff) + ) + if residual_scale.dtype != coeff.dtype: + residual_scale = xp.astype(residual_scale, coeff.dtype) + return coeff * xp.reshape( + residual_scale, (1, 1, self.n_focus, self.output_channels) + ) def _apply_grid_op( self, @@ -680,11 +752,32 @@ def _apply_scalar_path(self, coeff: Any, scalar_pair: Any) -> Any: xp = array_api_compat.array_namespace(coeff) scalar_out = self.scalar_act(scalar_pair) # (N, F, C) scalar_gate = xp_sigmoid(self.scalar_gate(scalar_pair)) # (N, F, C) - coeff = coeff * scalar_gate[:, None, :, :] - # gradient-safe equivalent of the pt in-place - # ``coeff_view[:, 0, :, 0, :].add_(scalar_out)`` (n_frames == 1) - head = coeff[:, :1, :, :] + scalar_out[:, None, :, :] - return xp.concat([head, coeff[:, 1:, :, :]], axis=1) + if self.n_frames == 1: + # Fast S2 path (byte-identical to the previous specialization). 
+ coeff = coeff * scalar_gate[:, None, :, :] + # gradient-safe equivalent of the pt in-place + # ``coeff_view[:, 0, :, 0, :].add_(scalar_out)`` (n_frames == 1) + head = coeff[:, :1, :, :] + scalar_out[:, None, :, :] + return xp.concat([head, coeff[:, 1:, :, :]], axis=1) + # General frame-packed path mirroring the pt + # ``coeff_view = coeff.reshape(N, D, F, K, C)`` followed by a gated + # multiply and an in-place add into ``[:, 0, :, frame_zero_index, :]``. + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, self.channels) + ) + coeff_view = coeff_view * scalar_gate[:, None, :, None, :] + # gradient-safe in-place add into the d=0, frame_zero_index slice + fzi = self.frame_zero_index + head = coeff_view[:, :1, :, :, :] # (N, 1, F, K, C) + pre = head[:, :, :, :fzi, :] + mid = head[:, :, :, fzi : fzi + 1, :] + scalar_out[:, None, :, None, :] + post = head[:, :, :, fzi + 1 :, :] + head = xp.concat([pre, mid, post], axis=3) + coeff_view = xp.concat([head, coeff_view[:, 1:, :, :, :]], axis=1) + return xp.reshape( + coeff_view, (n_batch, coeff_dim, n_focus, self.expanded_channels) + ) def _split_self_query(self, query: Any) -> tuple[Any, Any]: self._check_last_dim(query, self.query_channels, "query") @@ -708,51 +801,103 @@ def _make_scalar_pair(self, left: Any, right: Any, compute_dtype: Any) -> Any: return scalar_pair def _extract_scalar(self, coeff: Any) -> Any: - # (N, D, F, C) -> the (l=0, m=0) scalar slice (N, F, C); n_frames == 1 - return coeff[:, 0, :, :] + # (N, D, F, K*C) -> the (l=0, m=0) scalar slice (N, F, C). 
+ if self.n_frames == 1: + return coeff[:, 0, :, :] + xp = array_api_compat.array_namespace(coeff) + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, self.channels) + ) + return coeff_view[:, 0, :, self.frame_zero_index, :] def _to_grid(self, coeff: Any) -> Any: - # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul. - # The per-point channel width is inferred so the projector also serves - # widened operands (e.g. a branch hidden width ``n_branches * C``). xp = array_api_compat.array_namespace(coeff) - n_batch, coeff_dim, n_focus, n_channels = coeff.shape to_grid_mat = xp_asarray_nodetach( xp, self.projector.to_grid_mat[...], device=array_api_compat.device(coeff) ) if to_grid_mat.dtype != coeff.dtype: to_grid_mat = xp.astype(to_grid_mat, coeff.dtype) - flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * n_channels)) - out = xp.matmul(to_grid_mat[None, ...], flat) # (N, G, F*C) + if self.n_frames == 1: + # einsum "gd,ndfc->ngfc" as a broadcast batched matmul. The per-point + # channel width is inferred so the projector also serves widened + # operands (e.g. a branch hidden width ``n_branches * C``). + n_batch, coeff_dim, n_focus, n_channels = coeff.shape + flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * n_channels)) + out = xp.matmul(to_grid_mat[None, ...], flat) # (N, G, F*C) + return xp.reshape( + out, (n_batch, self.projector.grid_size, n_focus, n_channels) + ) + # General SO(3) frame-packed path mirroring the pt + # ``einsum("gdk,ndfkc->ngfc", to_grid.reshape(G, D, K), coeff_view)``. + # ``to_grid_mat`` columns are ordered (d outer, k inner), so the operand + # is permuted to the matching ``(d, k)`` flattening before the matmul. 
+ n_batch, coeff_dim, n_focus, last = coeff.shape + n_channels = last // self.n_frames + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, n_channels) + ) + coeff_dk = xp.permute_dims(coeff_view, (0, 1, 3, 2, 4)) # (N, D, K, F, C) + coeff_flat = xp.reshape( + coeff_dk, (n_batch, coeff_dim * self.n_frames, n_focus * n_channels) + ) + out = xp.matmul(to_grid_mat[None, ...], coeff_flat) # (N, G, F*C) return xp.reshape(out, (n_batch, self.projector.grid_size, n_focus, n_channels)) def _from_grid(self, grid: Any) -> Any: - # einsum "dg,ngfc->ndfc" (n_frames == 1) as a broadcast batched matmul. - # The channel width is inferred to match the (possibly widened) grid. xp = array_api_compat.array_namespace(grid) - n_batch, n_grid, n_focus, n_channels = grid.shape - coeff_dim = self.projector.coeff_dim from_grid_mat = xp_asarray_nodetach( xp, self.projector.from_grid_mat[...], device=array_api_compat.device(grid) ) if from_grid_mat.dtype != grid.dtype: from_grid_mat = xp.astype(from_grid_mat, grid.dtype) + if self.n_frames == 1: + # einsum "dg,ngfc->ndfc" as a broadcast batched matmul. The channel + # width is inferred to match the (possibly widened) grid field. + n_batch, n_grid, n_focus, n_channels = grid.shape + coeff_dim = self.projector.coeff_dim + flat = xp.reshape(grid, (n_batch, n_grid, n_focus * n_channels)) + out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D, F*C) + return xp.reshape(out, (n_batch, coeff_dim, n_focus, n_channels)) + # General SO(3) frame-packed path mirroring the pt + # ``einsum("dkg,ngfc->ndfkc", from_grid.reshape(D, K, G), grid)`` then a + # reshape to ``(N, D, F, K*C)``. ``from_grid_mat`` rows are ordered + # (d outer, k inner); the matmul output is reshaped/permuted to match. 
+ n_batch, n_grid, n_focus, n_channels = grid.shape + coeff_dim = self.projector.coeff_dim // self.n_frames flat = xp.reshape(grid, (n_batch, n_grid, n_focus * n_channels)) - out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D, F*C) - return xp.reshape(out, (n_batch, coeff_dim, n_focus, n_channels)) + out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D*K, F*C) + out = xp.reshape(out, (n_batch, coeff_dim, self.n_frames, n_focus, n_channels)) + out = xp.permute_dims(out, (0, 1, 3, 2, 4)) # (N, D, F, K, C) + return xp.reshape( + out, (n_batch, coeff_dim, n_focus, self.n_frames * n_channels) + ) - def _to_ndfc(self, value: Any) -> Any: + def _to_ndfc(self, value: Any) -> tuple[Any, tuple[int, ...]]: + shape_info = tuple(value.shape) if self.layout == "ndfc": - return value - # "nfdc": (N, F, D, C) -> (N, D, F, C); "flat" is cross-only (blocked) + return value, shape_info + if self.layout == "nfdc": + # (N, F, D, C) -> (N, D, F, C) + xp = array_api_compat.array_namespace(value) + return xp.permute_dims(value, (0, 2, 1, 3)), shape_info + # "flat": (N, D, F*k*C) -> (N, D, F, k*C) xp = array_api_compat.array_namespace(value) - return xp.permute_dims(value, (0, 2, 1, 3)) + n_batch, coeff_dim, _ = value.shape + return ( + xp.reshape(value, (n_batch, coeff_dim, self.n_focus, -1)), + shape_info, + ) - def _restore_layout(self, value: Any) -> Any: + def _restore_layout(self, value: Any, shape_info: tuple[int, ...]) -> Any: if self.layout == "ndfc": return value xp = array_api_compat.array_namespace(value) - return xp.permute_dims(value, (0, 2, 1, 3)) + if self.layout == "nfdc": + return xp.permute_dims(value, (0, 2, 1, 3)) + # "flat": (N, D, F, k*C) -> (N, D, F*k*C) + n_batch, coeff_dim, _ = shape_info + return xp.reshape(value, (n_batch, coeff_dim, -1)) def _check_last_dim(self, value: Any, expected: int, name: str) -> None: if value.shape[-1] != expected: @@ -775,14 +920,14 @@ class S2GridNet(BaseGridNet): n_focus : int Number of focus streams. 
mode : str - Pairing mode; only ``"self"`` is ported. + Pairing mode; ``"self"`` or ``"cross"``. op_type : str - Point-wise grid operation; ``"glu"`` or ``"branch"`` (``"mlp"`` is - not ported). + Point-wise grid operation; ``"glu"``, ``"mlp"`` or ``"branch"``. precision : str Parameter precision. layout : str - Tensor layout convention: ``"ndfc"`` or ``"nfdc"``. + Tensor layout convention: ``"ndfc"``, ``"nfdc"`` or ``"flat"`` + (``"flat"`` is cross-only). grid_resolution_list : list[int] | None Lebedev ``[precision, n_points]`` pair; resolved automatically if None. coefficient_layout : str @@ -792,7 +937,8 @@ class S2GridNet(BaseGridNet): grid_branches : int Number of scalar-routed branches when ``op_type='branch'``. residual_scale_init : float | None - Not ported (cross-mode only); must be None. + Initial value of the per-(focus, channel) residual scale; ``None`` + disables the residual scale. mlp_bias : bool Whether to use bias in the scalar gate projection. trainable : bool @@ -865,6 +1011,9 @@ def serialize(self) -> dict[str, Any]: grid_op_data = self.grid_op.serialize()["@variables"] for key, value in grid_op_data.items(): variables[f"grid_op.{key}"] = value + if self.residual_scale is not None: + # pt state-dict key name for the (n_focus, output_channels) parameter + variables["residual_scale"] = to_numpy_array(self.residual_scale) return { "@class": "S2GridNet", "@version": 1, @@ -881,6 +1030,7 @@ def serialize(self) -> dict[str, Any]: "coefficient_layout": self.projector.coefficient_layout, "grid_method": self.grid_method, "grid_branches": self.grid_branches, + "residual_scale_init": self.residual_scale_init, "mlp_bias": self.mlp_bias, "trainable": self.trainable, "seed": None, @@ -912,6 +1062,7 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridNet: coefficient_layout=str(config["coefficient_layout"]), grid_method=str(config["grid_method"]), grid_branches=int(config["grid_branches"]), + residual_scale_init=config.get("residual_scale_init"), 
mlp_bias=bool(config["mlp_bias"]), trainable=bool(config["trainable"]), seed=config.get("seed"), @@ -928,6 +1079,14 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridNet: obj.scalar_gate.bias = np.asarray( variables["scalar_gate.bias"], dtype=prec ).reshape(obj.scalar_gate.bias.shape) + if obj.residual_scale is not None: + residual_scale = np.asarray(variables["residual_scale"], dtype=prec) + if residual_scale.shape != obj.residual_scale.shape: + raise ValueError( + f"residual_scale shape {residual_scale.shape} does not match " + f"the expected shape {obj.residual_scale.shape}" + ) + obj.residual_scale = residual_scale if obj.op_type in {"mlp", "branch"}: obj.grid_op._load_variables( { diff --git a/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py b/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py new file mode 100644 index 0000000000..c097b4293f --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity / equivariance tests for the generalized DPA4 ``BaseGridNet``. + +These cover the ``mode='cross'``, ``layout='flat'`` and ``residual_scale_init`` +paths that ``BaseGridNet`` gained when it was generalized to mirror the current +pt ``deepmd.pt.model.descriptor.sezm_nn.grid_net.BaseGridNet``. All tests use +``S2GridNet`` (``n_frames == 1``); the ``n_frames > 1`` SO(3) frame contraction +in ``_to_grid``/``_from_grid`` is exercised structurally by the SO(3) port +(verified there). pt imports live inside the test functions because ruff TID253 +bans module-level ``deepmd.pt`` imports under ``source/tests/common``; pt +modules are pinned to CPU (``.to("cpu")``) under the CUDA-default-device CI. 
+""" + +import array_api_compat +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet as DPS2GridNet, +) + + +def _grid_op_param_names(op_type): + return { + "glu": (), + "mlp": ("left_proj", "right_proj", "out_proj"), + "branch": ("left_proj", "right_proj", "router", "out_proj"), + }[op_type] + + +def _build_nets( + *, + mode, + op_type, + layout, + residual_scale_init=None, + lmax=2, + channels=4, + n_focus=1, + grid_branches=1, + mlp_bias=False, + grid_resolution_list=None, + seed=7, +): + """Build a pt + dp ``S2GridNet`` with identical (perturbed) weights.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + S2GridNet as PTS2GridNet, + ) + + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": mode, + "op_type": op_type, + "layout": layout, + "grid_resolution_list": grid_resolution_list, + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": grid_branches, + "residual_scale_init": residual_scale_init, + "mlp_bias": mlp_bias, + "trainable": True, + "seed": seed, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + rng = np.random.default_rng(2100) + with torch.no_grad(): + for p in pt_net.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_net = DPS2GridNet(precision="float64", **common) + + state = {k: v.detach().cpu().numpy() for k, v in pt_net.state_dict().items()} + expected = {"scalar_gate.weight"} + if mlp_bias: + expected.add("scalar_gate.bias") + expected |= {f"grid_op.{n}.weight" for n in _grid_op_param_names(op_type)} + if residual_scale_init is not None: + expected.add("residual_scale") + assert set(state) == expected, set(state) + + dp_net.scalar_gate.weight = state["scalar_gate.weight"] + if mlp_bias: + dp_net.scalar_gate.bias = state["scalar_gate.bias"] + for name in _grid_op_param_names(op_type): + getattr(dp_net.grid_op, name).weight = 
state[f"grid_op.{name}.weight"] + if residual_scale_init is not None: + dp_net.residual_scale = state["residual_scale"] + return pt_net, dp_net + + +def _coeff_dim(lmax): + return (lmax + 1) ** 2 + + +def _make_inputs(*, mode, layout, n_batch, lmax, n_focus, channels, rng): + """Build (query, context) for the given mode/layout. context is None for self.""" + coeff_dim = _coeff_dim(lmax) + if mode == "self": + if layout == "nfdc": + query = rng.normal(size=(n_batch, n_focus, coeff_dim, 2 * channels)) + else: # ndfc + query = rng.normal(size=(n_batch, coeff_dim, n_focus, 2 * channels)) + return query, None + # cross + if layout == "flat": + query = rng.normal(size=(n_batch, coeff_dim, n_focus * channels)) + context = rng.normal(size=(n_batch, coeff_dim, n_focus * channels)) + elif layout == "nfdc": + query = rng.normal(size=(n_batch, n_focus, coeff_dim, channels)) + context = rng.normal(size=(n_batch, n_focus, coeff_dim, channels)) + else: # ndfc + query = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + context = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + return query, context + + +def _run(net, query, context, backend): + """Run a net with the given backend; return numpy output.""" + if backend == "pt": + import torch + + q = torch.from_numpy(query) + c = None if context is None else torch.from_numpy(context) + return net(q, c).detach().cpu().numpy() + out = net.call(query, None if context is None else context) + return np.asarray(out) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_s2_self_regression(op_type) -> None: + """mode='self' S2GridNet still matches pt at 1e-12 (guards the self path).""" + lmax, n_focus, n_batch = 2, 2, 5 + pt_net, dp_net = _build_nets( + mode="self", op_type=op_type, layout="ndfc", lmax=lmax, n_focus=n_focus + ) + rng = np.random.default_rng(11) + query, context = _make_inputs( + mode="self", + layout="ndfc", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + 
channels=dp_net.channels, + rng=rng, + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_s2_cross_parity(op_type) -> None: + """mode='cross' S2GridNet matches pt at 1e-12 (separate query/context).""" + lmax, n_focus, n_batch = 2, 2, 5 + pt_net, dp_net = _build_nets( + mode="cross", op_type=op_type, layout="ndfc", lmax=lmax, n_focus=n_focus + ) + rng = np.random.default_rng(22) + query, context = _make_inputs( + mode="cross", + layout="ndfc", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + channels=dp_net.channels, + rng=rng, + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + assert dp_out.shape == query.shape + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +def _rotate_ndfc(x, d_matrix): + """Rotate coefficient-layout tensors (N, D, F, C) by per-batch (N, D, D).""" + return np.einsum("nij,njfc->nifc", d_matrix, x) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_s2_cross_equivariance(op_type) -> None: + """net(rot(q), rot(c)) == rot(net(q, c)) for a shared SO(3) rotation. + + The default Lebedev grid has algebraic precision >= ``3 * lmax``, so the + degree-``2 * lmax`` grid product integrates exactly and the grid net is + equivariant to machine precision. 
+ """ + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + lmax, n_focus, n_batch, channels = 2, 1, 4, 4 + _, dp_net = _build_nets( + mode="cross", + op_type=op_type, + layout="ndfc", + lmax=lmax, + n_focus=n_focus, + channels=channels, + ) + rng = np.random.default_rng(33) + query, context = _make_inputs( + mode="cross", + layout="ndfc", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + channels=channels, + rng=rng, + ) + quat = rng.normal(size=(n_batch, 4)) + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True) + d_matrix = np.asarray(WignerDCalculator(lmax, precision="float64")(quat)[0]) + + y_rot_in = _run( + dp_net, _rotate_ndfc(query, d_matrix), _rotate_ndfc(context, d_matrix), "dp" + ) + y_then_rot = _rotate_ndfc(_run(dp_net, query, context, "dp"), d_matrix) + np.testing.assert_allclose(y_rot_in, y_then_rot, rtol=1e-10, atol=1e-10) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_layout_flat_parity(op_type) -> None: + """mode='cross', layout='flat' matches pt at 1e-12.""" + lmax, n_focus, n_batch = 2, 3, 5 + pt_net, dp_net = _build_nets( + mode="cross", op_type=op_type, layout="flat", lmax=lmax, n_focus=n_focus + ) + rng = np.random.default_rng(44) + query, context = _make_inputs( + mode="cross", + layout="flat", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + channels=dp_net.channels, + rng=rng, + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + assert dp_out.shape == query.shape + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +@pytest.mark.parametrize("residual_scale_init", [None, 0.5]) # residual scale init +def test_residual_scale_parity(residual_scale_init) -> None: + """residual_scale_init parity vs pt at 1e-12; residual_scale (de)serialized.""" + lmax, n_focus, n_batch = 2, 2, 5 + pt_net, dp_net = _build_nets( + mode="cross", + op_type="glu", + layout="ndfc", + lmax=lmax, + 
n_focus=n_focus, + residual_scale_init=residual_scale_init, + ) + rng = np.random.default_rng(55) + query, context = _make_inputs( + mode="cross", + layout="ndfc", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + channels=dp_net.channels, + rng=rng, + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + # serialize -> deserialize keeps residual_scale and the forward output + data = dp_net.serialize() + assert data["config"]["residual_scale_init"] == residual_scale_init + if residual_scale_init is None: + assert "residual_scale" not in data["@variables"] + else: + assert "residual_scale" in data["@variables"] + restored = DPS2GridNet.deserialize(data) + if residual_scale_init is None: + assert restored.residual_scale is None + else: + np.testing.assert_array_equal(restored.residual_scale, dp_net.residual_scale) + np.testing.assert_allclose( + _run(restored, query, context, "dp"), dp_out, rtol=1e-12, atol=1e-12 + ) + + +def test_torch_namespace() -> None: + """cross-mode S2GridNet.call on torch.from_numpy input matches numpy result.""" + import torch + + lmax, n_focus, n_batch, channels = 2, 2, 5, 4 + _, dp_net = _build_nets( + mode="cross", + op_type="mlp", + layout="ndfc", + lmax=lmax, + n_focus=n_focus, + channels=channels, + residual_scale_init=0.7, + ) + rng = np.random.default_rng(66) + query, context = _make_inputs( + mode="cross", + layout="ndfc", + n_batch=n_batch, + lmax=lmax, + n_focus=n_focus, + channels=channels, + rng=rng, + ) + np_out = np.asarray(dp_net.call(query, context)) + torch_out = dp_net.call(torch.from_numpy(query), torch.from_numpy(context)) + assert array_api_compat.is_torch_array(torch_out) + np.testing.assert_allclose( + np_out, torch_out.detach().cpu().numpy(), rtol=1e-12, atol=1e-12 + ) diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 
8f41ebbccb..7b4087b6bc 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -1811,11 +1811,10 @@ def test_not_ported_guards(self) -> None: # "e3nn" default, which dp rejects): default construction works net = DPS2GridNet(**{k: v for k, v in common.items() if k != "grid_method"}) assert net.grid_method == "lebedev" - with pytest.raises(NotImplementedError, match="node_wise_s2"): - # cross mode backs node_wise_s2/message_node_s2 only - DPS2GridNet(**{**common, "mode": "cross"}) - with pytest.raises(NotImplementedError, match="residual_scale_init"): - DPS2GridNet(**common, residual_scale_init=1e-3) + # cross mode and residual_scale_init are now ported (see + # test_dpa4_basegridnet_cross.py for parity coverage); they construct. + DPS2GridNet(**{**common, "mode": "cross"}) + DPS2GridNet(**common, residual_scale_init=1e-3) def test_value_errors(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( From e29a8c1932dbd1d561040ec99b716373afe9b10e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:50:46 +0800 Subject: [PATCH 05/18] feat(dpmodel): port FrameContract/FrameExpand per-degree frame mixers for DPA4 SO3 grid --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 247 +++++++++++++++++ .../common/dpmodel/test_dpa4_frame_mixers.py | 250 ++++++++++++++++++ 2 files changed, 497 insertions(+) create mode 100644 source/tests/common/dpmodel/test_dpa4_frame_mixers.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index c82c69e378..68e0e056b8 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -45,6 +45,8 @@ Any, ) +import math + import array_api_compat import numpy as np @@ -71,6 +73,11 @@ from .activation import ( SwiGLU, ) +from .indexing import ( + build_l_major_index, + build_m_major_l_index, + map_degree_idx, +) from .projection import ( 
BaseGridProjector, S2GridProjector, @@ -124,6 +131,246 @@ def _project_frames(coeff: Any, proj: ChannelLinear, n_frames: int) -> Any: return xp.reshape(projected, (n_batch, coeff_dim, n_focus, -1)) +def _build_frame_degree_index( + *, + lmax: int, + mmax: int, + coefficient_layout: str, +) -> np.ndarray: + """Build the per-coefficient degree index used by the frame channel mixers. + + The pt version's ``device`` parameter is dropped: the output is a static + ``np.int64`` table mapping each coefficient row to its degree ``l`` for the + packed / truncated / m-major layouts. + """ + coefficient_layout = str(coefficient_layout).lower() + if coefficient_layout == "m_major": + return build_m_major_l_index(lmax, mmax) + if coefficient_layout == "packed": + degree_index = map_degree_idx(lmax) + if int(mmax) == int(lmax): + return degree_index + coeff_index = build_l_major_index(lmax, mmax) + return degree_index[coeff_index] + raise ValueError("`coefficient_layout` must be either 'packed' or 'm_major'") + + +class _FrameMixer(NativeOP): + """Shared base for the per-degree frame channel mixers. + + The pt ``FrameContract`` / ``FrameExpand`` are ``nn.Module`` wrappers around + a per-degree weight of shape ``(lmax + 1, in_ch, out_ch)`` selected by a + static degree-index buffer; they realise an + ``einsum("ndfi,dio->ndfo", coeff, weight[degree_index])``. ``mode='self'`` + S2 grid nets have ``n_frames == 1`` and never construct these; they back the + SO(3) cross-mode grid products only. Subclasses set ``in_channels`` / + ``out_channels`` and the init ``bound`` (matching the pt weight init). 
+ """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + in_channels: int, + out_channels: int, + init_bound: float, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + self.lmax = int(lmax) + self.mmax = int(mmax) + self.coefficient_layout = str(coefficient_layout).lower() + self.n_frames = int(n_frames) + self.channels = int(channels) + self.precision = precision + self.trainable = bool(trainable) + # static np.int64 table; rebuilt from config on deserialize (the pt + # degree_index is a non-persistent buffer, not in the state dict) + self.degree_index = _build_frame_degree_index( + lmax=self.lmax, + mmax=self.mmax, + coefficient_layout=self.coefficient_layout, + ) + prec = PRECISION_DICT[self.precision.lower()] + rng = np.random.default_rng(seed) + shape = (self.lmax + 1, int(in_channels), int(out_channels)) + self.weight = rng.uniform(-init_bound, init_bound, size=shape).astype(prec) + + def call(self, coeff: Any) -> Any: + """Apply the per-degree frame/channel map preserving the order index. + + ``einsum("ndfi,dio->ndfo", coeff, weight[degree_index])`` is realised as + a broadcast batched matmul: the gathered weight ``(D, i, o)`` broadcasts + over the leading frame batch dim of ``coeff``. 
+ """ + xp = array_api_compat.array_namespace(coeff) + device = array_api_compat.device(coeff) + weight = xp_asarray_nodetach(xp, self.weight[...], device=device) + if weight.dtype != coeff.dtype: + weight = xp.astype(weight, coeff.dtype) + degree_index = xp_asarray_nodetach(xp, self.degree_index, device=device) + weight = xp.take(weight, degree_index, axis=0) # (D, i, o) + # (N, D, F, i) @ (1, D, i, o) -> (N, D, F, o) + return xp.matmul(coeff, weight[None, ...]) + + def _serialize_config(self) -> dict[str, Any]: + return { + "lmax": self.lmax, + "mmax": self.mmax, + "coefficient_layout": self.coefficient_layout, + "n_frames": self.n_frames, + "channels": self.channels, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, + "seed": None, + } + + @classmethod + def _deserialize(cls, data: dict[str, Any]) -> Any: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != cls.__name__: + raise ValueError(f"Invalid class for {cls.__name__}: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls( + lmax=int(config["lmax"]), + mmax=int(config["mmax"]), + coefficient_layout=str(config["coefficient_layout"]), + n_frames=int(config["n_frames"]), + channels=int(config["channels"]), + precision=str(config["precision"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + prec = PRECISION_DICT[obj.precision.lower()] + weight = np.asarray(variables["weight"], dtype=prec) + if weight.shape != obj.weight.shape: + raise ValueError( + f"weight shape {weight.shape} does not match " + f"the expected shape {obj.weight.shape}" + ) + obj.weight = weight + return obj + + +class FrameContract(_FrameMixer): + """Per-degree frame/channel contraction that preserves the order index. 
+ + Maps ``(N, D, F, K*C) -> (N, D, F, C)`` with a per-degree weight of shape + ``(lmax + 1, K*C, C)`` where ``K`` is ``n_frames``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + n_frames = int(n_frames) + channels = int(channels) + super().__init__( + lmax=lmax, + mmax=mmax, + coefficient_layout=coefficient_layout, + n_frames=n_frames, + channels=channels, + in_channels=n_frames * channels, + out_channels=channels, + init_bound=1.0 / math.sqrt(n_frames * channels), + precision=precision, + trainable=trainable, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the FrameContract to a dict. + + The pt ``FrameContract`` has no ``serialize()``; the ``@variables`` key + (``weight``) matches the pt ``state_dict`` key name. ``degree_index`` is + a non-persistent buffer in pt and is rebuilt from the config. + """ + return { + "@class": "FrameContract", + "@version": 1, + "config": self._serialize_config(), + "@variables": {"weight": to_numpy_array(self.weight)}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FrameContract: + """Deserialize a FrameContract from a dict.""" + return cls._deserialize(data) + + +class FrameExpand(_FrameMixer): + """Per-degree frame/channel expansion that preserves the order index. + + Maps ``(N, D, F, C) -> (N, D, F, K*C)`` with a per-degree weight of shape + ``(lmax + 1, C, K*C)`` where ``K`` is ``n_frames``. 
+ """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + n_frames = int(n_frames) + channels = int(channels) + super().__init__( + lmax=lmax, + mmax=mmax, + coefficient_layout=coefficient_layout, + n_frames=n_frames, + channels=channels, + in_channels=channels, + out_channels=n_frames * channels, + init_bound=1.0 / math.sqrt(channels), + precision=precision, + trainable=trainable, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the FrameExpand to a dict. + + The pt ``FrameExpand`` has no ``serialize()``; the ``@variables`` key + (``weight``) matches the pt ``state_dict`` key name. ``degree_index`` is + a non-persistent buffer in pt and is rebuilt from the config. + """ + return { + "@class": "FrameExpand", + "@version": 1, + "config": self._serialize_config(), + "@variables": {"weight": to_numpy_array(self.weight)}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FrameExpand: + """Deserialize a FrameExpand from a dict.""" + return cls._deserialize(data) + + class GridProduct(NativeOP): """Parameter-free quadratic grid product ``u(g) * v(g)``.""" diff --git a/source/tests/common/dpmodel/test_dpa4_frame_mixers.py b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py new file mode 100644 index 0000000000..9f37c0bd8b --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py @@ -0,0 +1,250 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 SO3-grid per-degree frame mixers. + +These mirror the current pt +``deepmd.pt.model.descriptor.sezm_nn.grid_net`` ``FrameContract`` / +``FrameExpand`` (and the ``_build_frame_degree_index`` helper). 
The pt mixers +realise a per-degree ``einsum("ndfi,dio->ndfo", coeff, weight[degree_index])``; +the dpmodel port realises the same map as a broadcast batched ``xp.matmul``. + +pt imports live inside the test functions because ruff TID253 bans +module-level ``deepmd.pt`` imports under ``source/tests/common``. pt modules +are pinned to CPU so ``torch.from_numpy`` fp64 inputs and the module agree +under the CUDA-default-device CI configuration. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + FrameContract as DPFrameContract, + FrameExpand as DPFrameExpand, + _build_frame_degree_index, +) + +# (lmax, channels, kmax); n_frames K = 2 * kmax + 1 +_CASES = [(2, 4, 1), (3, 2, 2)] + + +def _copy_weight(pt_mod, dp_mod) -> None: + """Copy the pt mixer ``weight`` state-dict entry into the dpmodel mixer.""" + state = {k: v.detach().cpu().numpy() for k, v in pt_mod.state_dict().items()} + assert set(state) == {"weight"}, state.keys() + dp_mod.weight = state["weight"] + + +@pytest.mark.parametrize("lmax,channels,kmax", _CASES) # degree, channels, kmax +def test_frame_contract_parity(lmax, channels, kmax) -> None: + """The dpmodel ``FrameContract`` matches pt with weight-copied fp64 weights.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + FrameContract as PTFrameContract, + ) + + n_frames = 2 * kmax + 1 + coeff_dim = (lmax + 1) ** 2 + n_batch, n_focus = 5, 2 + rng = np.random.default_rng(2026) + + pt_mod = PTFrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + dtype=torch.float64, + trainable=True, + seed=7, + ).to("cpu") + dp_mod = DPFrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=7, + ) + _copy_weight(pt_mod, dp_mod) + + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, n_frames * channels)) + dp_out = 
dp_mod.call(coeff) + pt_out = pt_mod(torch.from_numpy(coeff)) + assert dp_out.shape == (n_batch, coeff_dim, n_focus, channels) + np.testing.assert_allclose( + np.asarray(dp_out), pt_out.detach().cpu().numpy(), rtol=1e-12, atol=1e-12 + ) + + +@pytest.mark.parametrize("lmax,channels,kmax", _CASES) # degree, channels, kmax +def test_frame_expand_parity(lmax, channels, kmax) -> None: + """The dpmodel ``FrameExpand`` matches pt with weight-copied fp64 weights.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + FrameExpand as PTFrameExpand, + ) + + n_frames = 2 * kmax + 1 + coeff_dim = (lmax + 1) ** 2 + n_batch, n_focus = 5, 2 + rng = np.random.default_rng(2027) + + pt_mod = PTFrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + dtype=torch.float64, + trainable=True, + seed=11, + ).to("cpu") + dp_mod = DPFrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=11, + ) + _copy_weight(pt_mod, dp_mod) + + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + dp_out = dp_mod.call(coeff) + pt_out = pt_mod(torch.from_numpy(coeff)) + assert dp_out.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + np.testing.assert_allclose( + np.asarray(dp_out), pt_out.detach().cpu().numpy(), rtol=1e-12, atol=1e-12 + ) + + +@pytest.mark.parametrize("lmax,channels,kmax", _CASES) # degree, channels, kmax +def test_expand_then_contract_shapes(lmax, channels, kmax) -> None: + """Shape round-trip ``(N,D,F,C) -> expand -> (N,D,F,K*C) -> contract -> (N,D,F,C)``.""" + n_frames = 2 * kmax + 1 + coeff_dim = (lmax + 1) ** 2 + n_batch, n_focus = 3, 2 + rng = np.random.default_rng(404) + + expand = DPFrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=1, + ) + contract = 
DPFrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=2, + ) + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + expanded = expand.call(coeff) + assert expanded.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + contracted = contract.call(expanded) + assert contracted.shape == (n_batch, coeff_dim, n_focus, channels) + + +@pytest.mark.parametrize("cls", [DPFrameContract, DPFrameExpand]) # mixer class +def test_serialize_roundtrip(cls) -> None: + """Serialize -> deserialize -> forward is identical; @version == 1.""" + lmax, channels, n_frames = 2, 4, 3 + coeff_dim = (lmax + 1) ** 2 + n_batch, n_focus = 3, 2 + rng = np.random.default_rng(505) + + mod = cls( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=5, + ) + # perturb to non-default weights + mod.weight = mod.weight + 0.1 * rng.normal(size=mod.weight.shape) + + data = mod.serialize() + assert data["@version"] == 1 + assert data["config"]["n_frames"] == n_frames + assert set(data["@variables"]) == {"weight"} + restored = cls.deserialize(data) + np.testing.assert_array_equal(restored.weight, mod.weight) + + in_ch = mod.weight.shape[1] + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, in_ch)) + out0 = mod.call(coeff) + out1 = restored.call(coeff) + np.testing.assert_allclose( + np.asarray(out0), np.asarray(out1), rtol=1e-12, atol=1e-12 + ) + + +@pytest.mark.parametrize("lmax,mmax", [(2, 2), (3, 3), (3, 1)]) # degree, order +@pytest.mark.parametrize("layout", ["packed", "m_major"]) # coefficient layout +def test_degree_index(lmax, mmax, layout) -> None: + """``_build_frame_degree_index`` maps each (l, m) row to its degree l. + + Compared against the pt helper output. 
+ """ + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + _build_frame_degree_index as pt_build, + ) + + dp_idx = _build_frame_degree_index(lmax=lmax, mmax=mmax, coefficient_layout=layout) + pt_idx = pt_build(lmax=lmax, mmax=mmax, coefficient_layout=layout) + np.testing.assert_array_equal(np.asarray(dp_idx), pt_idx.detach().cpu().numpy()) + # explicit (l, m) check for the packed, untruncated case + if layout == "packed" and mmax == lmax: + expected = np.repeat(np.arange(lmax + 1), [2 * l + 1 for l in range(lmax + 1)]) + np.testing.assert_array_equal(np.asarray(dp_idx), expected) + + +@pytest.mark.parametrize("cls", [DPFrameContract, DPFrameExpand]) # mixer class +def test_torch_namespace(cls) -> None: + """Mixer ``call`` on torch input matches the numpy-input result. + + Array-API pitfall guard (no ``np.einsum`` on tensors). + """ + import torch + + lmax, channels, n_frames = 2, 4, 3 + coeff_dim = (lmax + 1) ** 2 + n_batch, n_focus = 3, 2 + rng = np.random.default_rng(606) + + mod = cls( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=9, + ) + in_ch = mod.weight.shape[1] + coeff = rng.normal(size=(n_batch, coeff_dim, n_focus, in_ch)) + np_out = mod.call(coeff) + torch_out = mod.call(torch.from_numpy(coeff)) + np.testing.assert_allclose( + np.asarray(np_out), + torch_out.detach().cpu().numpy(), + rtol=1e-12, + atol=1e-12, + ) From 8f0b63a43700ddfd981a664691b2363627b9c1dd Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 17 Jun 2026 22:34:53 +0800 Subject: [PATCH 06/18] feat(dpmodel): port resolve_so3_grid + _build_so3_frame_set for DPA4 SO3 grid --- .../dpmodel/descriptor/dpa4_nn/projection.py | 52 +++++++++++++++++++ .../dpmodel/test_dpa4_so3_grid_utils.py | 40 ++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py 
b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index e5131b5d12..7666b32c2c 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -370,3 +370,55 @@ def _normalize_s2_grid_resolution( if resolution[0] < 1 or resolution[1] < 1: raise ValueError("grid resolutions must be positive") return resolution + + +def resolve_so3_grid( + lmax: int, + *, + kmax: int = 1, + lebedev_precision: int | None = None, +) -> tuple[int, int, int]: + """ + Resolve the default SO(3) quadrature as Lebedev sphere times gamma samples. + + The Lebedev precision follows the same conservative ``3*lmax`` rule used by + the S2 grid path. The gamma grid is chosen for the quadratic grid products + used by the SO(3) grid nets, whose third-angle frequency can reach + ``k1 + k2 - kout``. + """ + lmax_i = int(lmax) + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + if lebedev_precision is None: + required_precision = 3 * lmax_i + for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): + if precision >= required_precision: + lebedev_precision = precision + lebedev_npoints = n_points + break + else: + raise ValueError( + f"No packaged Lebedev rule has precision >= {required_precision}" + ) + else: + lebedev_precision = int(lebedev_precision) + lebedev_npoints = LEBEDEV_PRECISION_TO_NPOINTS.get(lebedev_precision) + if lebedev_npoints is None: + raise ValueError( + f"Lebedev rule with precision {lebedev_precision} is not packaged" + ) + + # A quadratic product followed by analysis can contain gamma frequencies up + # to ``3*kmax``. A uniform grid with more samples than that frequency + # resolves the integer Fourier modes exactly. 
+ n_gamma = 1 if kmax_i == 0 else 3 * kmax_i + 1 + return int(lebedev_precision), int(lebedev_npoints), int(n_gamma) + + +def _build_so3_frame_set(kmax: int) -> list[int]: + """Build the symmetric frame-index set with zero first.""" + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + return [0, *[frame for kk in range(1, kmax_i + 1) for frame in (-kk, kk)]] diff --git a/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py new file mode 100644 index 0000000000..58335abab7 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 SO(3) grid utility functions. + +Compares the dpmodel ports of ``resolve_so3_grid`` and ``_build_so3_frame_set`` +against the reference pt implementations. +""" + +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.projection import ( + _build_so3_frame_set, + resolve_so3_grid, +) + + +@pytest.mark.parametrize("kmax", [0, 1, 2, 3]) # frame-index half-width +def test_build_so3_frame_set(kmax) -> None: + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + _build_so3_frame_set as pt_build_so3_frame_set, + ) + + assert _build_so3_frame_set(kmax) == pt_build_so3_frame_set(kmax) + if kmax == 2: + assert _build_so3_frame_set(kmax) == [0, -1, 1, -2, 2] + + +@pytest.mark.parametrize( + "lmax,kmax", # max angular momentum, frame-index half-width + [(1, 1), (2, 1), (2, 2), (3, 1), (3, 2)], +) +def test_resolve_so3_grid(lmax, kmax) -> None: + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + dp_result = resolve_so3_grid(lmax, kmax=kmax) + pt_result = pt_resolve_so3_grid(lmax, kmax=kmax) + assert dp_result == pt_result + n_gamma = dp_result[2] + assert n_gamma == (1 if kmax == 0 else 3 * kmax + 1) From 937abc9aa8dbc69e4cea98aa9633cde979fb4938 Mon Sep 17 00:00:00 2001 From: 
Han Wang Date: Wed, 17 Jun 2026 22:36:48 +0800 Subject: [PATCH 07/18] test(dpmodel): cover SO3 grid util edge branches; refresh projection docstring --- .../dpmodel/descriptor/dpa4_nn/projection.py | 9 ++-- .../dpmodel/test_dpa4_so3_grid_utils.py | 41 +++++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index 7666b32c2c..0ebd147c24 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -13,11 +13,10 @@ ``resolve_s2_grid_resolution`` (as-is, both methods — pure arithmetic), and ``_normalize_s2_grid_resolution``. -Skipped names (SO(3) Wigner-D grid machinery; consumed only by -``SO3GridNet`` in pt ``grid_net.py``, which backs the ``node_wise_so3``, -``message_node_so3``, and ``ffn_so3_grid`` paths — all disabled in the core -DPA4 config): ``SO3GridProjector``, ``resolve_so3_grid``, -``_build_so3_frame_set``. +SO(3) Wigner-D grid machinery (consumed by ``SO3GridNet`` in +``grid_net.py``, which backs the ``node_wise_so3``, ``message_node_so3``, +and ``ffn_so3_grid`` paths): ``resolve_so3_grid`` and ``_build_so3_frame_set`` +are ported here; ``SO3GridProjector`` is still pending. Not ported (guarded): the e3nn product-grid branch of ``S2GridProjector`` (``grid_method="e3nn"``, i.e. 
``lebedev_quadrature=False``) raises diff --git a/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py index 58335abab7..1e29eeb4e2 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py @@ -38,3 +38,44 @@ def test_resolve_so3_grid(lmax, kmax) -> None: assert dp_result == pt_result n_gamma = dp_result[2] assert n_gamma == (1 if kmax == 0 else 3 * kmax + 1) + + +def test_resolve_so3_grid_kmax_zero() -> None: + """kmax=0 collapses the gamma grid to a single sample (n_gamma=1).""" + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + dp_result = resolve_so3_grid(2, kmax=0) + assert dp_result == pt_resolve_so3_grid(2, kmax=0) + assert dp_result[2] == 1 + + +def test_resolve_so3_grid_explicit_precision() -> None: + """An explicitly supplied (packaged) Lebedev precision is honored.""" + from deepmd.dpmodel.utils.lebedev import ( + LEBEDEV_PRECISION_TO_NPOINTS, + ) + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + precision = sorted(LEBEDEV_PRECISION_TO_NPOINTS)[3] + dp_result = resolve_so3_grid(2, kmax=1, lebedev_precision=precision) + assert dp_result == pt_resolve_so3_grid(2, kmax=1, lebedev_precision=precision) + assert dp_result[0] == precision + + +def test_resolve_so3_grid_unpackaged_precision_raises() -> None: + """An unpackaged explicit precision raises ValueError.""" + with pytest.raises(ValueError, match="not packaged"): + resolve_so3_grid(2, kmax=1, lebedev_precision=999999) + + +@pytest.mark.parametrize("kmax", [-1, -2]) # negative half-width is invalid +def test_negative_kmax_raises(kmax) -> None: + """Both utilities reject negative kmax.""" + with pytest.raises(ValueError, match="non-negative"): + _build_so3_frame_set(kmax) + with pytest.raises(ValueError, match="non-negative"): + resolve_so3_grid(2, 
kmax=kmax) From 09d3ddda00a9ed65b2b01b007b51d662769d6443 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 17 Jun 2026 22:41:37 +0800 Subject: [PATCH 08/18] feat(dpmodel): port SO3GridProjector (Wigner-D grid quadrature) for DPA4 --- .../dpmodel/descriptor/dpa4_nn/projection.py | 163 +++++++++++++++++- .../common/dpmodel/test_dpa4_so3_projector.py | 142 +++++++++++++++ 2 files changed, 303 insertions(+), 2 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_so3_projector.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index 0ebd147c24..c487e16773 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -15,8 +15,11 @@ SO(3) Wigner-D grid machinery (consumed by ``SO3GridNet`` in ``grid_net.py``, which backs the ``node_wise_so3``, ``message_node_so3``, -and ``ffn_so3_grid`` paths): ``resolve_so3_grid`` and ``_build_so3_frame_set`` -are ported here; ``SO3GridProjector`` is still pending. +and ``ffn_so3_grid`` paths): ``resolve_so3_grid``, ``_build_so3_frame_set``, +and ``SO3GridProjector`` are ported here. The SO(3) projection matrices are +assembled at init time with pure numpy via the Wigner-D quadrature +(``WignerDCalculator``) over a Lebedev x gamma rotation grid, matching the pt +float64 buffers to machine precision. Not ported (guarded): the e3nn product-grid branch of ``S2GridProjector`` (``grid_method="e3nn"``, i.e. ``lebedev_quadrature=False``) raises @@ -65,6 +68,13 @@ from .indexing import ( build_l_major_index, build_m_major_index, + so3_packed_index, +) +from .wignerd import ( + WignerDCalculator, + build_edge_quaternion, + quaternion_multiply, + quaternion_z_rotation, ) @@ -302,6 +312,155 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: ) +class SO3GridProjector(BaseGridProjector): + """ + Project SO(3) coefficients to/from a Wigner-D grid with frame indices. 
+ + The coefficient axis is packed as ``(l, m, k)`` with ordinary SeZM + ``(l, m)`` order outside and the configured frame set inside each row. A + frame index outside ``[-l, l]`` is kept as a zero column/row. This keeps the + tensor layout regular while preserving the exact per-degree frame support. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the coefficient layout. If None, use ``lmax``. + kmax + Frame-index half-width; the frame set is ``{0, -1, 1, ..., -kmax, kmax}``. + precision + Buffer precision used by the projection matrices. + lebedev_precision + Explicit Lebedev rule precision. If None, resolved automatically. + coefficient_layout + Coefficient ordering expected by the caller: + - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. + - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + precision: str = DEFAULT_PRECISION, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + ) -> None: + lmax_i = int(lmax) + mmax_i = int(lmax_i if mmax is None else mmax) + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") + self.frame_set = _build_so3_frame_set(self.kmax) + self.frame_zero_index = self.frame_set.index(0) + self.lebedev_precision, self.lebedev_npoints, self.n_gamma = resolve_so3_grid( + lmax_i, + kmax=self.kmax, + lebedev_precision=lebedev_precision, + ) + super().__init__( + lmax=lmax_i, + mmax=mmax_i, + precision=precision, + n_frames=len(self.frame_set), + coefficient_layout=coefficient_layout, + ) + # plain numpy int64 attribute (becomes a torch buffer in pt_expt later) + self.frame_values = np.asarray(self.frame_set, dtype=np.int64) + + def _build_projection_mats( + self, + coeff_index: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + points, weights = 
load_lebedev_rule(self.lebedev_precision) + points = np.asarray(points, dtype=np.float64) + weights = np.asarray(weights, dtype=np.float64) + gamma = np.arange(self.n_gamma, dtype=np.float64) * ( + 2.0 * math.pi / float(self.n_gamma) + ) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + # torch ``repeat_interleave(n_gamma, dim=0)`` -> numpy ``repeat`` on axis 0 + edge_quaternion = np.repeat(edge_quaternion, self.n_gamma, axis=0) + # torch ``repeat(n_points, 1)`` (tile) -> numpy ``tile`` + gamma_quaternion = np.tile(quaternion_z_rotation(gamma), (points.shape[0], 1)) + grid_quaternion = quaternion_multiply(gamma_quaternion, edge_quaternion) + # WignerDCalculator.__call__ returns ``(D_full, Dt_full)``; take D_full. + wigner_grid = WignerDCalculator(self.lmax, precision="float64")( + grid_quaternion + )[0] + # ``build_edge_quaternion`` follows SeZM's global-to-local convention. + # The transpose below stores the local m=0 column in the same layout as + # ``WignerDCalculator.forward_zonal`` and extends it to k != 0. 
+ wigner_grid = np.ascontiguousarray(np.swapaxes(wigner_grid, -1, -2)) + haar_weight = np.repeat(weights, self.n_gamma) / float(self.n_gamma) + + grid_size = int(grid_quaternion.shape[0]) + n_frames = len(self.frame_set) + coeff_dim = int(coeff_index.shape[0]) * n_frames + to_grid_mat = np.zeros((grid_size, coeff_dim), dtype=np.float64) + from_grid_mat = np.zeros((coeff_dim, grid_size), dtype=np.float64) + + for degree in range(self.lmax + 1): + degree_factor = float(2 * degree + 1) + for m_order in range(-degree, degree + 1): + packed_idx = so3_packed_index(degree, m_order) + # init-time numpy control-flow index; replaces pt's + # ``(coeff_index == packed_idx).nonzero(as_tuple=False)`` + coeff_positions = np.argwhere(coeff_index == packed_idx) + if coeff_positions.size == 0: + continue + coeff_pos = int(coeff_positions[0, 0]) + for frame_pos, frame_order in enumerate(self.frame_set): + flat_idx = coeff_pos * n_frames + frame_pos + if abs(frame_order) > degree: + continue + row = so3_packed_index(degree, m_order) + col = so3_packed_index(degree, frame_order) + values = wigner_grid[:, row, col] + to_grid_mat[:, flat_idx] = values + from_grid_mat[flat_idx, :] = degree_factor * haar_weight * values + return to_grid_mat, from_grid_mat + + def serialize(self) -> dict[str, Any]: + """Serialize the SO3GridProjector to a dict (pt-compatible format).""" + return { + "@class": "SO3GridProjector", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "kmax": self.kmax, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "lebedev_precision": self.lebedev_precision, + "coefficient_layout": self.coefficient_layout, + }, + "@variables": {}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3GridProjector: + """Deserialize an SO3GridProjector from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3GridProjector": + raise ValueError(f"Invalid class for SO3GridProjector: {data_cls}") + version = 
int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + data.pop("@variables", None) + return cls( + lmax=int(config["lmax"]), + mmax=int(config["mmax"]), + kmax=int(config["kmax"]), + precision=str(config["precision"]), + lebedev_precision=int(config["lebedev_precision"]), + coefficient_layout=str(config["coefficient_layout"]), + ) + + def resolve_s2_grid_resolution( lmax: int, mmax: int, diff --git a/source/tests/common/dpmodel/test_dpa4_so3_projector.py b/source/tests/common/dpmodel/test_dpa4_so3_projector.py new file mode 100644 index 0000000000..01274f134c --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_projector.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the dpmodel ``SO3GridProjector`` (Wigner-D grid quadrature). + +Compares the dpmodel port against the reference pt implementation +(``deepmd.pt.model.descriptor.sezm_nn.projection.SO3GridProjector``) and checks +the legal-frame round-trip identity, serialization round-trip, and the kmax=0 +zonal convention. pt imports live inside the test functions to satisfy the +``source/tests/common`` import-isolation rule (ruff TID253). 
+""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.projection import ( + SO3GridProjector, +) + + +def _legal_so3_frame_mask(projector: SO3GridProjector) -> np.ndarray: + """Build the boolean mask of legal ``(l, m, k)`` flat-coefficient slots.""" + mask = np.ones(projector.coeff_dim, dtype=np.bool_) + n_frames = projector.n_frames + for degree in range(projector.lmax + 1): + for m_order in range(-degree, degree + 1): + packed_idx = degree * degree + degree + m_order + for frame_pos, frame_order in enumerate(projector.frame_set): + flat_idx = packed_idx * n_frames + frame_pos + if flat_idx >= projector.coeff_dim: + continue + if abs(frame_order) > degree: + mask[flat_idx] = False + return mask + + +# (lmax, kmax, mmax): max degree, frame-index half-width, retained order +_CASES = [(1, 1, 1), (2, 1, 1), (2, 2, 2), (3, 1, 1)] + + +@pytest.mark.parametrize("lmax,kmax,mmax", _CASES) +def test_projection_matrices_match_pt(lmax, kmax, mmax) -> None: + """Dpmodel projection matrices match the pt buffers at fp64.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + SO3GridProjector as PTSO3GridProjector, + ) + + dp = SO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, precision="float64") + pt = PTSO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, dtype=torch.float64) + + np.testing.assert_allclose( + dp.to_grid_mat, + pt.to_grid_mat.detach().cpu().numpy(), + atol=1e-12, + rtol=1e-12, + ) + np.testing.assert_allclose( + dp.from_grid_mat, + pt.from_grid_mat.detach().cpu().numpy(), + atol=1e-12, + rtol=1e-12, + ) + assert dp.n_frames == pt.n_frames + assert dp.coeff_dim == pt.coeff_dim + assert dp.grid_size == pt.grid_size + assert dp.frame_set == pt.frame_set + + +@pytest.mark.parametrize("lmax", [1, 2, 3, 4, 5, 6]) # max degree +def test_roundtrip_preserves_legal_frame_coeffs(lmax) -> None: + """Project legal-frame coefficients to grid and back; recovery to 1e-12.""" + rng = np.random.default_rng(8100 + lmax) + 
projector = SO3GridProjector(lmax=lmax, kmax=1, precision="float64") + x = rng.standard_normal((2, projector.coeff_dim, 2)).astype(np.float64) + mask = _legal_so3_frame_mask(projector) + x[:, ~mask, :] = 0.0 + y = projector.from_grid(projector.to_grid(x)) + np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-12, rtol=1e-12) + assert float(np.max(np.abs(y[:, ~mask, :]))) < 1e-14 + + +@pytest.mark.parametrize("lmax,kmax,mmax", _CASES) +def test_serialize_roundtrip(lmax, kmax, mmax) -> None: + """Serialize -> deserialize reproduces the matrices and config keys.""" + projector = SO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, precision="float64") + data = projector.serialize() + assert data["@class"] == "SO3GridProjector" + assert data["@version"] == 1 + config = data["config"] + for key in ( + "lmax", + "mmax", + "kmax", + "precision", + "lebedev_precision", + "coefficient_layout", + ): + assert key in config + # the matrices must NOT be serialized (rebuilt at deserialize) + assert "to_grid_mat" not in config + assert "from_grid_mat" not in config + + restored = SO3GridProjector.deserialize(data) + np.testing.assert_array_equal(restored.to_grid_mat, projector.to_grid_mat) + np.testing.assert_array_equal(restored.from_grid_mat, projector.from_grid_mat) + assert restored.frame_set == projector.frame_set + + +def test_kmax_zero_zonal() -> None: + """kmax=0 collapses to a single frame and matches the Wigner zonal column.""" + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + build_edge_quaternion, + ) + from deepmd.dpmodel.utils.lebedev import ( + load_lebedev_rule, + ) + + lmax = 6 + projector = SO3GridProjector(lmax=lmax, kmax=0, precision="float64") + assert projector.n_frames == 1 + + points, _ = load_lebedev_rule(projector.lebedev_precision) + points = np.asarray(points, dtype=np.float64) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + zonal = WignerDCalculator(lmax, precision="float64").forward_zonal( + 
edge_quaternion, lmin=1 + ) + np.testing.assert_allclose( + projector.to_grid_mat[:, 0], np.ones_like(points[:, 0]), atol=1e-14, rtol=1e-14 + ) + np.testing.assert_allclose( + projector.to_grid_mat[:, 1:], zonal, atol=1e-14, rtol=1e-14 + ) + + # the single-frame projector still round-trips legal coefficients + rng = np.random.default_rng(99) + x = rng.standard_normal((2, projector.coeff_dim, 2)).astype(np.float64) + mask = _legal_so3_frame_mask(projector) + x[:, ~mask, :] = 0.0 + y = projector.from_grid(projector.to_grid(x)) + np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-12, rtol=1e-12) From 28269e50409e8e86487ac9b803e9abbbe5e97f99 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 22:59:31 +0800 Subject: [PATCH 09/18] feat(dpmodel): port SO3GridNet (self+cross) for DPA4 SO3 grid --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 259 ++++++++++- .../common/dpmodel/test_dpa4_so3_gridnet.py | 408 ++++++++++++++++++ 2 files changed, 663 insertions(+), 4 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_so3_gridnet.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 68e0e056b8..a5fee5a86f 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -21,10 +21,10 @@ ``n_frames`` (the ``_to_grid``/``_from_grid`` frame-axis contraction). The S2 path (``n_frames == 1``, ``mode='self'``) keeps a dedicated fast branch that is byte-identical to the previous S2-only specialization. The SO(3) frame -machinery (``SO3GridNet``, ``FrameContract``, ``FrameExpand``) is not ported -here; ``BaseGridNet`` exposes ``frame_expand``/``frame_contract`` seams (kept -``None`` for S2) so a later SO(3) port can plug them in without touching the -shared forward. 
+machinery (``SO3GridNet``, ``FrameContract``, ``FrameExpand``) is ported here; +``SO3GridNet`` builds an ``SO3GridProjector`` (``n_frames = 2 * kmax + 1``) and, +in ``mode='cross'``, plugs ``FrameExpand``/``FrameContract`` into the +``BaseGridNet`` ``frame_expand``/``frame_contract`` seams (kept ``None`` for S2). Serialization contract: the pt ``S2GridNet`` and ``GridBranch`` define no ``serialize()`` (they only appear nested inside larger modules' @@ -81,6 +81,7 @@ from .projection import ( BaseGridProjector, S2GridProjector, + SO3GridProjector, ) from .so3 import ( ChannelLinear, @@ -1343,3 +1344,253 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridNet: } ) return obj + + +class SO3GridNet(BaseGridNet): + """Grid net using a Wigner-D SO(3) projector with frame indices. + + dpmodel port of the current pt + ``deepmd.pt.model.descriptor.sezm_nn.grid_net.SO3GridNet``. Unlike + ``S2GridNet`` (``n_frames == 1``), the SO(3) projector packs + ``n_frames = 2 * kmax + 1`` Wigner-D frames along the trailing channel axis, + exercising the general ``n_frames > 1`` ``_to_grid``/``_from_grid`` paths of + ``BaseGridNet``. In ``mode='cross'`` it additionally builds the per-degree + :class:`FrameExpand` / :class:`FrameContract` channel mixers and plugs them + into the ``BaseGridNet`` ``frame_expand``/``frame_contract`` seam: the query + and context are expanded ``C -> n_frames * C`` before the grid product and + contracted ``n_frames * C -> C`` afterwards. + + Parameters + ---------- + lmax : int + Maximum spherical harmonic degree. + mmax : int | None + Maximum order kept in the coefficient layout. If None, use ``lmax``. + kmax : int + Frame band width; ``n_frames = 2 * kmax + 1`` Wigner-D frames. + channels : int + Number of channels per (l, m) coefficient (per frame). + n_focus : int + Number of focus streams. + mode : str + Pairing mode; ``"self"`` or ``"cross"``. + op_type : str + Point-wise grid operation; ``"glu"``, ``"mlp"`` or ``"branch"``. 
+ precision : str + Parameter precision. + layout : str + Tensor layout convention: ``"ndfc"``, ``"nfdc"`` or ``"flat"`` + (``"flat"`` is cross-only). + lebedev_precision : int | None + Lebedev algebraic precision; resolved automatically if None. + coefficient_layout : str + ``"packed"`` or ``"m_major"`` coefficient ordering. + grid_branches : int + Number of scalar-routed branches when ``op_type='branch'``. + residual_scale_init : float | None + Initial value of the per-(focus, channel) residual scale; ``None`` + disables the residual scale. + mlp_bias : bool + Whether to use bias in the scalar gate projection. + trainable : bool + Whether parameters are trainable. + seed : int | list[int] | None + Random seed for weight initialization. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + channels: int, + n_focus: int = 1, + mode: str, + op_type: str, + precision: str = DEFAULT_PRECISION, + layout: str, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + grid_branches: int = 1, + residual_scale_init: float | None = None, + mlp_bias: bool = False, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + projector = SO3GridProjector( + lmax=lmax, + mmax=mmax, + kmax=kmax, + precision=precision, + lebedev_precision=lebedev_precision, + coefficient_layout=coefficient_layout, + ) + self.frames = projector.frame_set + self.kmax = projector.kmax + self.lebedev_precision = projector.lebedev_precision + self.n_gamma = projector.n_gamma + self.grid_branches = int(grid_branches) + frame_expand: FrameExpand | None = None + frame_contract: FrameContract | None = None + if str(mode).lower() == "cross": + # pt builds the frame mixers with child_seed(seed, 4)/(seed, 5); + # ``BaseGridNet`` uses child_seed(seed, 0)/(seed, 1) for scalar_gate + # /grid_op, so these branches never collide. 
+ frame_expand = FrameExpand( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + precision=precision, + trainable=trainable, + seed=child_seed(seed, 4), + ) + frame_contract = FrameContract( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + precision=precision, + trainable=trainable, + seed=child_seed(seed, 5), + ) + super().__init__( + projector=projector, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type=op_type, + precision=precision, + layout=layout, + mlp_bias=mlp_bias, + trainable=trainable, + grid_branches=grid_branches, + frame_expand=frame_expand, + frame_contract=frame_contract, + residual_scale_init=residual_scale_init, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the SO3GridNet to a dict. + + The pt ``SO3GridNet`` has no ``serialize()``; the ``@variables`` keys + here match the pt ``state_dict`` key names (``scalar_gate.weight``, + ``grid_op.*``, ``frame_expand.weight``, ``frame_contract.weight``, + ``residual_scale``) so pt state-dict fragments load directly. The + projector matrices are non-persistent buffers in pt and are rebuilt + from the nested projector config on deserialization. 
+ """ + variables = {"scalar_gate.weight": to_numpy_array(self.scalar_gate.weight)} + if self.mlp_bias: + variables["scalar_gate.bias"] = to_numpy_array(self.scalar_gate.bias) + if self.op_type in {"mlp", "branch"}: + grid_op_data = self.grid_op.serialize()["@variables"] + for key, value in grid_op_data.items(): + variables[f"grid_op.{key}"] = value + if self.frame_expand is not None: + variables["frame_expand.weight"] = to_numpy_array(self.frame_expand.weight) + if self.frame_contract is not None: + variables["frame_contract.weight"] = to_numpy_array( + self.frame_contract.weight + ) + if self.residual_scale is not None: + variables["residual_scale"] = to_numpy_array(self.residual_scale) + return { + "@class": "SO3GridNet", + "@version": 1, + "config": { + "channels": self.channels, + "n_focus": self.n_focus, + "mode": self.mode, + "op_type": self.op_type, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layout": self.layout, + "grid_branches": self.grid_branches, + "residual_scale_init": self.residual_scale_init, + "mlp_bias": self.mlp_bias, + "trainable": self.trainable, + "seed": None, + "projector": self.projector.serialize(), + }, + "@variables": variables, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3GridNet: + """Deserialize an SO3GridNet from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3GridNet": + raise ValueError(f"Invalid class for SO3GridNet: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + projector_config = config["projector"]["config"] + obj = cls( + lmax=int(projector_config["lmax"]), + mmax=int(projector_config["mmax"]), + kmax=int(projector_config["kmax"]), + channels=int(config["channels"]), + n_focus=int(config["n_focus"]), + mode=str(config["mode"]), + op_type=str(config["op_type"]), + precision=str(config["precision"]), + 
layout=str(config["layout"]), + lebedev_precision=int(projector_config["lebedev_precision"]), + coefficient_layout=str(projector_config["coefficient_layout"]), + grid_branches=int(config["grid_branches"]), + residual_scale_init=config.get("residual_scale_init"), + mlp_bias=bool(config["mlp_bias"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + prec = PRECISION_DICT[obj.precision.lower()] + weight = np.asarray(variables["scalar_gate.weight"], dtype=prec) + if weight.shape != obj.scalar_gate.weight.shape: + raise ValueError( + f"scalar_gate.weight shape {weight.shape} does not match " + f"the expected shape {obj.scalar_gate.weight.shape}" + ) + obj.scalar_gate.weight = weight + if obj.mlp_bias: + obj.scalar_gate.bias = np.asarray( + variables["scalar_gate.bias"], dtype=prec + ).reshape(obj.scalar_gate.bias.shape) + if obj.residual_scale is not None: + residual_scale = np.asarray(variables["residual_scale"], dtype=prec) + if residual_scale.shape != obj.residual_scale.shape: + raise ValueError( + f"residual_scale shape {residual_scale.shape} does not match " + f"the expected shape {obj.residual_scale.shape}" + ) + obj.residual_scale = residual_scale + if obj.frame_expand is not None: + expand_weight = np.asarray(variables["frame_expand.weight"], dtype=prec) + if expand_weight.shape != obj.frame_expand.weight.shape: + raise ValueError( + f"frame_expand.weight shape {expand_weight.shape} does not " + f"match the expected shape {obj.frame_expand.weight.shape}" + ) + obj.frame_expand.weight = expand_weight + if obj.frame_contract is not None: + contract_weight = np.asarray(variables["frame_contract.weight"], dtype=prec) + if contract_weight.shape != obj.frame_contract.weight.shape: + raise ValueError( + f"frame_contract.weight shape {contract_weight.shape} does " + f"not match the expected shape {obj.frame_contract.weight.shape}" + ) + obj.frame_contract.weight = contract_weight + if obj.op_type in {"mlp", "branch"}: + obj.grid_op._load_variables( 
+ { + key[len("grid_op.") :]: value + for key, value in variables.items() + if key.startswith("grid_op.") + } + ) + return obj diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py new file mode 100644 index 0000000000..078b7662f4 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity / equivariance tests for the DPA4 ``SO3GridNet``. + +``SO3GridNet`` is the capstone of the SO(3)-grid port: it packs +``n_frames = 2 * kmax + 1`` Wigner-D frames along the trailing channel axis, +exercising the general ``n_frames > 1`` ``_to_grid``/``_from_grid`` paths of +``BaseGridNet`` and, in ``mode='cross'``, the ``FrameExpand``/``FrameContract`` +seam. Tests mirror the pt ``TestSO3GridNet`` in +``source/tests/pt/model/test_descriptor_sezm_grid_projection.py``. + +pt imports live inside the test functions because ruff TID253 bans +module-level ``deepmd.pt`` imports under ``source/tests/common``; pt modules are +pinned to CPU (``.to("cpu")``) under the CUDA-default-device CI. 
+""" + +import array_api_compat +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet as DPS2GridNet, +) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + SO3GridNet as DPSO3GridNet, +) + + +def _grid_op_param_names(op_type): + return { + "glu": (), + "mlp": ("left_proj", "right_proj", "out_proj"), + "branch": ("left_proj", "right_proj", "router", "out_proj"), + }[op_type] + + +def _build_so3_nets( + *, + mode, + op_type, + layout, + lmax=2, + mmax=None, + kmax=1, + channels=4, + n_focus=1, + grid_branches=1, + mlp_bias=False, + lebedev_precision=None, + residual_scale_init=None, + seed=7, +): + """Build a pt + dp ``SO3GridNet`` with identical (perturbed) weights.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + SO3GridNet as PTSO3GridNet, + ) + + common = { + "lmax": lmax, + "mmax": mmax, + "kmax": kmax, + "channels": channels, + "n_focus": n_focus, + "mode": mode, + "op_type": op_type, + "layout": layout, + "lebedev_precision": lebedev_precision, + "coefficient_layout": "packed", + "grid_branches": grid_branches, + "residual_scale_init": residual_scale_init, + "mlp_bias": mlp_bias, + "trainable": True, + "seed": seed, + } + pt_net = PTSO3GridNet(dtype=torch.float64, **common).to("cpu") + rng = np.random.default_rng(2100) + with torch.no_grad(): + for p in pt_net.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_net = DPSO3GridNet(precision="float64", **common) + + state = {k: v.detach().cpu().numpy() for k, v in pt_net.state_dict().items()} + expected = {"scalar_gate.weight"} + if mlp_bias: + expected.add("scalar_gate.bias") + expected |= {f"grid_op.{n}.weight" for n in _grid_op_param_names(op_type)} + if mode == "cross": + expected |= {"frame_expand.weight", "frame_contract.weight"} + if residual_scale_init is not None: + expected.add("residual_scale") + assert set(state) == expected, set(state) + + dp_net.scalar_gate.weight = 
state["scalar_gate.weight"] + if mlp_bias: + dp_net.scalar_gate.bias = state["scalar_gate.bias"] + for name in _grid_op_param_names(op_type): + getattr(dp_net.grid_op, name).weight = state[f"grid_op.{name}.weight"] + if mode == "cross": + dp_net.frame_expand.weight = state["frame_expand.weight"] + dp_net.frame_contract.weight = state["frame_contract.weight"] + if residual_scale_init is not None: + dp_net.residual_scale = state["residual_scale"] + return pt_net, dp_net + + +def _make_so3_inputs(*, dp_net, mode, layout, n_batch, rng): + """Build (query, context) for the given mode/layout; context None for self.""" + # D axis is the per-frame coefficient count (frames packed in channels). + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + n_focus = dp_net.n_focus + if mode == "self": + channels = dp_net.query_channels + query = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + return query, None + # cross: both query and context carry ``context_channels`` (== channels). 
+ channels = dp_net.context_channels + if layout == "flat": + query = rng.normal(size=(n_batch, coeff_dim, n_focus * channels)) + context = rng.normal(size=(n_batch, coeff_dim, n_focus * channels)) + else: # ndfc + query = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + context = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + return query, context + + +def _run(net, query, context, backend): + """Run a net with the given backend; return numpy output.""" + if backend == "pt": + import torch + + q = torch.from_numpy(query) + c = None if context is None else torch.from_numpy(context) + return net(q, c).detach().cpu().numpy() + out = net.call(query, None if context is None else context) + return np.asarray(out) + + +def _rotate_ndfc(x, d_matrix): + """Rotate coefficient-layout tensors (N, D, F, C) by per-batch (N, D, D).""" + return np.einsum("nij,njfc->nifc", d_matrix, x) + + +# === parity ========================================================= + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +@pytest.mark.parametrize("kmax", [1, 2]) # frame band width (n_frames = 2*kmax+1) +def test_so3_self_parity(op_type, kmax) -> None: + """mode='self' SO3GridNet matches pt at 1e-12 (n_frames>1 to/from-grid).""" + lmax, n_focus, n_batch = 2, 2, 5 + pt_net, dp_net = _build_so3_nets( + mode="self", + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + ) + rng = np.random.default_rng(11) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="self", layout="ndfc", n_batch=n_batch, rng=rng + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +@pytest.mark.parametrize("kmax", [1, 2]) # frame band width (n_frames = 2*kmax+1) +def test_so3_self_equivariance(op_type, kmax) -> None: + """net(rot(x)) == 
rot(net(x)) for a shared SO(3) rotation (self mode).""" + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + lmax, n_focus, n_batch = 2, 1, 4 + _, dp_net = _build_so3_nets( + mode="self", + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + grid_branches=2, + ) + rng = np.random.default_rng(33) + query, _ = _make_so3_inputs( + dp_net=dp_net, mode="self", layout="ndfc", n_batch=n_batch, rng=rng + ) + quat = rng.normal(size=(n_batch, 4)) + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True) + d_matrix = np.asarray(WignerDCalculator(lmax, precision="float64")(quat)[0]) + + y_rot_in = _run(dp_net, _rotate_ndfc(query, d_matrix), None, "dp") + y_then_rot = _rotate_ndfc(_run(dp_net, query, None, "dp"), d_matrix) + np.testing.assert_allclose(y_rot_in, y_then_rot, rtol=1e-10, atol=1e-10) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +@pytest.mark.parametrize("kmax", [1, 2]) # frame band width (n_frames = 2*kmax+1) +def test_so3_cross_parity(op_type, kmax) -> None: + """mode='cross' SO3GridNet matches pt at 1e-12 (frame_expand/contract seam).""" + lmax, n_focus, n_batch = 2, 2, 5 + pt_net, dp_net = _build_so3_nets( + mode="cross", + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + ) + rng = np.random.default_rng(22) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="cross", layout="ndfc", n_batch=n_batch, rng=rng + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + assert dp_out.shape == query.shape + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_so3_cross_equivariance(op_type) -> None: + """net(rot(q), rot(c)) == rot(net(q, c)) for a shared SO(3) rotation.""" + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + lmax, n_focus, n_batch, 
kmax = 2, 1, 4, 1 + _, dp_net = _build_so3_nets( + mode="cross", + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + grid_branches=2, + ) + rng = np.random.default_rng(44) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="cross", layout="ndfc", n_batch=n_batch, rng=rng + ) + quat = rng.normal(size=(n_batch, 4)) + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True) + d_matrix = np.asarray(WignerDCalculator(lmax, precision="float64")(quat)[0]) + + y_rot_in = _run( + dp_net, _rotate_ndfc(query, d_matrix), _rotate_ndfc(context, d_matrix), "dp" + ) + y_then_rot = _rotate_ndfc(_run(dp_net, query, context, "dp"), d_matrix) + np.testing.assert_allclose(y_rot_in, y_then_rot, rtol=1e-10, atol=1e-10) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_so3_cross_flat_parity(op_type) -> None: + """mode='cross', layout='flat' SO3GridNet matches pt at 1e-12.""" + lmax, n_focus, n_batch, kmax = 2, 3, 5, 2 + pt_net, dp_net = _build_so3_nets( + mode="cross", + op_type=op_type, + layout="flat", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + ) + rng = np.random.default_rng(55) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="cross", layout="flat", n_batch=n_batch, rng=rng + ) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + assert dp_out.shape == query.shape + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +# === serialize ====================================================== + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation +def test_so3_serialize_roundtrip(mode, op_type) -> None: + """Serialize -> deserialize -> forward identical; @version == 1.""" + lmax, n_focus, n_batch, kmax = 2, 2, 5, 2 + _, dp_net = _build_so3_nets( + mode=mode, + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, 
+ residual_scale_init=0.5, + ) + rng = np.random.default_rng(66) + query, context = _make_so3_inputs( + dp_net=dp_net, mode=mode, layout="ndfc", n_batch=n_batch, rng=rng + ) + dp_out = _run(dp_net, query, context, "dp") + + data = dp_net.serialize() + assert data["@version"] == 1 + assert data["config"]["projector"]["@class"] == "SO3GridProjector" + assert "residual_scale" in data["@variables"] + if mode == "cross": + assert "frame_expand.weight" in data["@variables"] + assert "frame_contract.weight" in data["@variables"] + else: + assert "frame_expand.weight" not in data["@variables"] + + restored = DPSO3GridNet.deserialize(data) + np.testing.assert_array_equal(restored.residual_scale, dp_net.residual_scale) + np.testing.assert_allclose( + _run(restored, query, context, "dp"), dp_out, rtol=1e-12, atol=1e-12 + ) + + +# === torch namespace ================================================ + + +def test_torch_namespace() -> None: + """cross-mode SO3GridNet.call on torch.from_numpy input matches numpy. + + Guards the frame-axis to/from-grid reshape/permute pitfall: a numpy-only + bug there (e.g. ``np.einsum`` on a tensor) would diverge here. 
+ """ + import torch + + lmax, n_focus, n_batch, kmax = 2, 2, 5, 2 + _, dp_net = _build_so3_nets( + mode="cross", + op_type="mlp", + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + residual_scale_init=0.7, + ) + rng = np.random.default_rng(77) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="cross", layout="ndfc", n_batch=n_batch, rng=rng + ) + np_out = np.asarray(dp_net.call(query, context)) + torch_out = dp_net.call(torch.from_numpy(query), torch.from_numpy(context)) + assert array_api_compat.is_torch_array(torch_out) + np.testing.assert_allclose( + np_out, torch_out.detach().cpu().numpy(), rtol=1e-12, atol=1e-12 + ) + + +# === S2 regression (n_frames == 1 path untouched) =================== + + +def _build_s2_nets(*, mode, op_type, layout, lmax=2, channels=4, n_focus=1): + """Build a pt + dp ``S2GridNet`` with identical (perturbed) weights.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + S2GridNet as PTS2GridNet, + ) + + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": mode, + "op_type": op_type, + "layout": layout, + "grid_resolution_list": None, + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": 1, + "residual_scale_init": None, + "mlp_bias": False, + "trainable": True, + "seed": 7, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + rng = np.random.default_rng(2100) + with torch.no_grad(): + for p in pt_net.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + dp_net = DPS2GridNet(precision="float64", **common) + state = {k: v.detach().cpu().numpy() for k, v in pt_net.state_dict().items()} + dp_net.scalar_gate.weight = state["scalar_gate.weight"] + for name in _grid_op_param_names(op_type): + getattr(dp_net.grid_op, name).weight = state[f"grid_op.{name}.weight"] + return pt_net, dp_net + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_s2_regression(mode) -> None: + 
"""An existing S2GridNet (n_frames == 1) still matches pt at 1e-12.""" + lmax, n_focus, n_batch, channels = 2, 2, 5, 4 + pt_net, dp_net = _build_s2_nets( + mode=mode, op_type="mlp", layout="ndfc", lmax=lmax, n_focus=n_focus + ) + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(88) + if mode == "self": + query = rng.normal(size=(n_batch, coeff_dim, n_focus, 2 * channels)) + context = None + else: + query = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + context = rng.normal(size=(n_batch, coeff_dim, n_focus, channels)) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) From 8038577973f7a3b92a601b64f01071c509c9d9c9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 23:07:47 +0800 Subject: [PATCH 10/18] feat(dpmodel): wire SO3GridNet into FFN (un-guard ffn_so3_grid) for DPA4 --- deepmd/dpmodel/descriptor/dpa4_nn/block.py | 3 +- deepmd/dpmodel/descriptor/dpa4_nn/ffn.py | 65 +++--- .../tests/common/dpmodel/test_descrpt_dpa4.py | 4 +- .../tests/common/dpmodel/test_dpa4_ffn_so3.py | 190 ++++++++++++++++++ .../pt/model/test_dpa4_dpmodel_parity.py | 21 +- 5 files changed, 244 insertions(+), 39 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_ffn_so3.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/block.py b/deepmd/dpmodel/descriptor/dpa4_nn/block.py index b17e2c0c4c..82f47cb3c5 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/block.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/block.py @@ -20,8 +20,7 @@ Flags merely forwarded to sub-components keep their guards there (delegated, not duplicated here): ``so2_attn_res``, ``so2_s2_activation``, ``node_wise_s2/so3``, ``message_node_s2/so3``, ``atten_f_mix``, -``atten_v_proj``, ``atten_o_proj`` (raised by ``SO2Convolution``) and -``ffn_so3_grid`` with the grid path active (raised by ``EquivariantFFN``). +``atten_v_proj``, ``atten_o_proj`` (raised by ``SO2Convolution``). 
The pt eval-time activation-checkpoint / nvtx instrumentation (``DP_ACT_INFER``, ``DP_COMPILE_INFER``, ``nvtx_range``) is pt-runtime-only diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py index 673e4ee836..984573ffed 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py @@ -5,12 +5,6 @@ This module is the dpmodel port of ``deepmd.pt.model.descriptor.sezm_nn.ffn``. It defines the full SO(3)-equivariant feed-forward network used inside SeZM interaction blocks. - -Branches guarded with ``NotImplementedError`` (flags unused by the core DPA4 -config): - -- ``ffn_so3_grid=True`` — the pt path instantiates ``SO3GridNet`` - (pt ffn.py:209), which is not ported to dpmodel. """ from __future__ import ( @@ -40,6 +34,7 @@ ) from .grid_net import ( S2GridNet, + SO3GridNet, ) from .projection import ( resolve_s2_grid_resolution, @@ -91,7 +86,7 @@ class EquivariantFFN(NativeOP): s2_activation If True, enable the S2 FFN grid path. ffn_so3_grid - If True, enable the SO3 Wigner-D FFN grid path (not ported). + If True, enable the SO3 Wigner-D FFN grid path. lebedev_quadrature If True, use Lebedev quadrature for the S2 projector in this FFN. 
activation_function @@ -141,10 +136,6 @@ def __init__( self.use_grid_branch = self.grid_branch > 0 self.s2_activation = bool(s2_activation) self.ffn_so3_grid = bool(ffn_so3_grid) - if self.ffn_so3_grid: - raise NotImplementedError( - "ffn_so3_grid=True (SO3GridNet) is not ported to dpmodel" - ) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" base_grid = resolve_s2_grid_resolution( @@ -163,8 +154,7 @@ def __init__( self.precision = precision self.compute_precision = _compute_precision(precision) self.trainable = bool(trainable) - # pt: grid_n_frames = 2 * kmax + 1 only when ffn_so3_grid (guarded above) - self.grid_n_frames = 1 + self.grid_n_frames = 2 * self.kmax + 1 if self.ffn_so3_grid else 1 # === Step 0. Split deterministic seeds at the module top-level === seed_so3_in = child_seed(seed, 0) @@ -199,22 +189,39 @@ def __init__( if self.use_grid_branch else ("mlp" if self.use_grid_mlp else "glu") ) - self.act: NativeOP = S2GridNet( - lmax=self.lmax, - channels=self.hidden_channels, - n_focus=1, - mode="self", - op_type=grid_op, - precision=self.compute_precision, - layout="ndfc", - grid_resolution_list=self.s2_grid_resolution, - coefficient_layout="packed", - grid_method=self.s2_grid_method, - grid_branches=max(1, self.grid_branch), - mlp_bias=self.mlp_bias, - trainable=self.trainable, - seed=seed_act, - ) + self.act: NativeOP + if self.ffn_so3_grid: + self.act = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + precision=self.compute_precision, + layout="ndfc", + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=self.trainable, + seed=seed_act, + ) + else: + self.act = S2GridNet( + lmax=self.lmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + precision=self.compute_precision, + layout="ndfc", + grid_resolution_list=self.s2_grid_resolution, + 
coefficient_layout="packed", + grid_method=self.s2_grid_method, + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=self.trainable, + seed=seed_act, + ) else: self.act = GatedActivation( lmax=self.lmax, diff --git a/source/tests/common/dpmodel/test_descrpt_dpa4.py b/source/tests/common/dpmodel/test_descrpt_dpa4.py index 43c882864a..223c89ca05 100644 --- a/source/tests/common/dpmodel/test_descrpt_dpa4.py +++ b/source/tests/common/dpmodel/test_descrpt_dpa4.py @@ -166,7 +166,6 @@ def test_masked_edge_inertness(self) -> None: ("s2_activation", [True, True]), # so2-side S2 activation ("node_wise_s2", True), # SO(2) cross-grid product ("message_node_so3", True), # SO(2) cross-grid product - ("ffn_so3_grid", True), # SO(3) Wigner-D FFN grid ("atten_f_mix", True), # SO(2) attention focus mix ("atten_v_proj", True), # SO(2) attention value projection ("atten_o_proj", True), # SO(2) attention output projection @@ -186,7 +185,8 @@ def test_not_implemented_guards(self, flag, value) -> None: ("full_attn_res", "none"), ("s2_activation", [False, True]), ("node_wise_s2", False), - ("ffn_so3_grid", False), + ("ffn_so3_grid", False), # SO(3) Wigner-D FFN grid off + ("ffn_so3_grid", True), # SO(3) Wigner-D FFN grid on (now wired) ("use_amp", True), # pt-runtime-only switch: accepted and ignored ("use_amp", False), ], diff --git a/source/tests/common/dpmodel/test_dpa4_ffn_so3.py b/source/tests/common/dpmodel/test_dpa4_ffn_so3.py new file mode 100644 index 0000000000..a726c59c80 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_ffn_so3.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 ``EquivariantFFN`` SO3-grid (``ffn_so3_grid``) path. + +These tests cover the newly un-guarded ``ffn_so3_grid=True`` branch of the +dpmodel ``EquivariantFFN`` which wires ``SO3GridNet(mode='self')`` in place of +the ``S2GridNet`` used by the ``s2_activation`` path. 
They mirror the current pt +``EquivariantFFN`` in ``deepmd.pt.model.descriptor.sezm_nn.ffn``. + +Note: ``EquivariantFFN`` exposes neither ``mmax`` nor ``n_focus`` — the SO3 grid +always uses ``mmax = lmax`` (projector default) and ``n_focus = 1`` internally +(matching pt). The truncated ``mmax < lmax`` path is exercised at the +``SO3GridNet`` level in ``test_dpa4_so3_gridnet.py``. + +pt imports live inside the test functions (ruff TID253 bans module-level +``deepmd.pt`` imports under ``source/tests/common``); pt modules are pinned to +CPU (``.to("cpu")``) under the CUDA-default-device CI. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.ffn import ( + EquivariantFFN as DPFFN, +) + + +def _build_ffn_pair(*, lmax, channels, hidden_channels, ffn_config, seed=7): + """Build a pt + dp ``EquivariantFFN`` sharing identical (perturbed) weights. + + Returns ``(pt_ffn, dp_ffn)``. The weight copy goes pt -> dp via + ``DPFFN.deserialize(pt_ffn.serialize())`` (both share state_dict key names). + Weights are perturbed first because ``so3_linear_2`` is zero-initialised, + which would otherwise make the FFN output identically zero. 
+ """ + import torch + + from deepmd.pt.model.descriptor.sezm_nn.ffn import ( + EquivariantFFN as PTFFN, + ) + + pt_ffn = PTFFN( + lmax=lmax, + channels=channels, + hidden_channels=hidden_channels, + dtype=torch.float64, + trainable=True, + seed=seed, + **ffn_config, + ).to("cpu") + rng = np.random.default_rng(2100) + with torch.no_grad(): + for p in pt_ffn.parameters(): + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + + dp_ffn = DPFFN.deserialize(pt_ffn.serialize()) + return pt_ffn, dp_ffn + + +def _run_pt(ffn, x): + import torch + + return ffn(torch.from_numpy(x)).detach().cpu().numpy() + + +def _run_dp(ffn, x): + return np.asarray(ffn.call(x)) + + +# === SO3 grid parity ================================================ + + +@pytest.mark.parametrize( + "grid_mlp,grid_branch", # grid op: (False,0)->glu, (True,0)->mlp, (False,1)->branch + [(False, 0), (True, 0), (False, 1)], +) +def test_ffn_so3_grid_parity(grid_mlp, grid_branch) -> None: + """ffn_so3_grid=True dp FFN matches pt at 1e-12 (lmax=3, kmax=1, mmax=lmax).""" + lmax, channels, hidden_channels, kmax = 3, 8, 8, 1 + pt_ffn, dp_ffn = _build_ffn_pair( + lmax=lmax, + channels=channels, + hidden_channels=hidden_channels, + ffn_config={ + "kmax": kmax, + "ffn_so3_grid": True, + "grid_mlp": grid_mlp, + "grid_branch": grid_branch, + }, + ) + assert dp_ffn.ffn_so3_grid + assert dp_ffn.grid_n_frames == 2 * kmax + 1 + rng = np.random.default_rng(11) + # (N, D, F, C): D=(lmax+1)^2, F=n_focus=1, C=channels + x = rng.normal(size=(5, (lmax + 1) ** 2, 1, channels)) + dp_out = _run_dp(dp_ffn, x) + pt_out = _run_pt(pt_ffn, x) + assert dp_out.shape == x.shape + # output must be non-trivial (so3_linear_2 perturbed away from zero-init) + assert np.max(np.abs(dp_out)) > 1e-6 + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +def test_ffn_so3_grid_constructs() -> None: + """The dp FFN with ffn_so3_grid=True constructs and runs (no NotImplementedError).""" + from 
deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + SO3GridNet as DPSO3GridNet, + ) + + lmax, channels, hidden_channels, kmax = 3, 8, 8, 1 + ffn = DPFFN( + lmax=lmax, + channels=channels, + hidden_channels=hidden_channels, + kmax=kmax, + ffn_so3_grid=True, + grid_mlp=False, + grid_branch=0, + precision="float64", + trainable=True, + seed=3, + ) + assert isinstance(ffn.act, DPSO3GridNet) + assert ffn.grid_n_frames == 2 * kmax + 1 + # linear1 output channels mirror pt: 2 * grid_n_frames * hidden_channels + assert ffn.so3_linear_1.out_channels == 2 * ffn.grid_n_frames * hidden_channels + assert ffn.so3_linear_2.in_channels == ffn.grid_n_frames * hidden_channels + rng = np.random.default_rng(1) + x = rng.normal(size=(4, (lmax + 1) ** 2, 1, channels)) + out = _run_dp(ffn, x) + assert out.shape == x.shape + + +# === S2 regression (ffn_so3_grid=False, n_frames==1 path untouched) = + + +@pytest.mark.parametrize("grid_mlp", [False, True]) # glu vs mlp grid op +def test_ffn_s2_regression(grid_mlp) -> None: + """The s2_activation FFN path still matches pt at 1e-12 (not broken).""" + lmax, channels, hidden_channels = 2, 4, 4 + pt_ffn, dp_ffn = _build_ffn_pair( + lmax=lmax, + channels=channels, + hidden_channels=hidden_channels, + ffn_config={ + "s2_activation": True, + "ffn_so3_grid": False, + "grid_mlp": grid_mlp, + "lebedev_quadrature": True, + }, + ) + assert not dp_ffn.ffn_so3_grid + assert dp_ffn.grid_n_frames == 1 + rng = np.random.default_rng(22) + x = rng.normal(size=(5, (lmax + 1) ** 2, 1, channels)) + dp_out = _run_dp(dp_ffn, x) + pt_out = _run_pt(pt_ffn, x) + assert np.max(np.abs(dp_out)) > 1e-6 + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +# === serialize roundtrip ============================================ + + +@pytest.mark.parametrize( + "grid_mlp,grid_branch", # grid op: (False,0)->glu, (True,0)->mlp, (False,1)->branch + [(False, 0), (True, 0), (False, 1)], +) +def test_ffn_so3_serialize_roundtrip(grid_mlp, grid_branch) -> None: + 
"""ffn_so3_grid FFN serialize -> deserialize -> forward identical.""" + lmax, channels, hidden_channels, kmax = 3, 8, 8, 1 + _, dp_ffn = _build_ffn_pair( + lmax=lmax, + channels=channels, + hidden_channels=hidden_channels, + ffn_config={ + "kmax": kmax, + "ffn_so3_grid": True, + "grid_mlp": grid_mlp, + "grid_branch": grid_branch, + }, + ) + data = dp_ffn.serialize() + assert data["@class"] == "EquivariantFFN" + assert data["config"]["ffn_so3_grid"] is True + restored = DPFFN.deserialize(data) + rng = np.random.default_rng(33) + x = rng.normal(size=(5, (lmax + 1) ** 2, 1, channels)) + np.testing.assert_allclose( + _run_dp(restored, x), _run_dp(dp_ffn, x), rtol=1e-12, atol=1e-12 + ) diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 7b4087b6bc..c3966fcf2b 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -3184,11 +3184,16 @@ def test_ffn_roundtrip(self, s2_activation) -> None: np.asarray(dp_mod.call(x)), np.asarray(dp_mod2.call(x)) ) - def test_ffn_guards(self) -> None: - from deepmd.dpmodel.descriptor.dpa4_nn.ffn import EquivariantFFN as DPFFN - - with pytest.raises(NotImplementedError, match="ffn_so3_grid"): - DPFFN(**self._ffn_kwargs(ffn_so3_grid=True), precision="float64") + @pytest.mark.parametrize("grid_branch", [0, 1]) # branch mixer off/on + @pytest.mark.parametrize("grid_mlp", [False, True]) # polynomial grid MLP op + def test_ffn_so3_grid(self, grid_mlp, grid_branch) -> None: + # ffn_so3_grid=True wires SO3GridNet(mode='self'); grid_n_frames=2*kmax+1 + pt_mod, dp_mod, kwargs = self._build_ffn_pair( + ffn_so3_grid=True, grid_mlp=grid_mlp, grid_branch=grid_branch + ) + assert dp_mod.ffn_so3_grid + assert dp_mod.grid_n_frames == 2 * kwargs["kmax"] + 1 + self._assert_ffn_parity(pt_mod, dp_mod, kwargs) def test_ffn_errors(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.ffn import EquivariantFFN as DPFFN @@ -3335,6 
+3340,11 @@ def test_block_plain_ffn_act(self) -> None: ) self._assert_block_parity(pt_mod, dp_mod, kwargs) + def test_block_ffn_so3_grid(self) -> None: + # ffn_so3_grid=True: block FFN uses SO3GridNet(mode='self') + pt_mod, dp_mod, kwargs = self._build_block_pair(ffn_so3_grid=True) + self._assert_block_parity(pt_mod, dp_mod, kwargs) + def test_block_real_edge_cache(self) -> None: # end-to-end: REAL pt build_edge_cache vs REAL dp build_edge_cache # feeding the same weight-copied block (no synthetic cache) @@ -3403,7 +3413,6 @@ def test_block_roundtrip(self) -> None: ("so2_s2_activation", True), # delegated to SO2Convolution ("node_wise_s2", True), # delegated to SO2Convolution ("message_node_so3", True), # delegated to SO2Convolution - ("ffn_so3_grid", True), # delegated to EquivariantFFN ], ) def test_block_guards(self, flag, value) -> None: From e4937ff69439f7e715d42296d5abae2c84e7b145 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 23:20:59 +0800 Subject: [PATCH 11/18] feat(dpmodel): wire SO3/S2 cross-mode grid products into SO2Convolution for DPA4 --- deepmd/dpmodel/descriptor/dpa4_nn/block.py | 5 +- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 172 ++++++++++- .../tests/common/dpmodel/test_descrpt_dpa4.py | 4 +- .../common/dpmodel/test_dpa4_so2_grid.py | 267 ++++++++++++++++++ .../pt/model/test_dpa4_dpmodel_parity.py | 6 - 5 files changed, 433 insertions(+), 21 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_so2_grid.py diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/block.py b/deepmd/dpmodel/descriptor/dpa4_nn/block.py index 82f47cb3c5..861a3b34d4 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/block.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/block.py @@ -19,8 +19,9 @@ Flags merely forwarded to sub-components keep their guards there (delegated, not duplicated here): ``so2_attn_res``, ``so2_s2_activation``, -``node_wise_s2/so3``, ``message_node_s2/so3``, ``atten_f_mix``, -``atten_v_proj``, ``atten_o_proj`` (raised by 
``SO2Convolution``). +``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj`` (raised by +``SO2Convolution``). The cross-mode grid products (``node_wise_s2/so3``, +``message_node_s2/so3``) are ported and forwarded to ``SO2Convolution``. The pt eval-time activation-checkpoint / nvtx instrumentation (``DP_ACT_INFER``, ``DP_COMPILE_INFER``, ``nvtx_range``) is pt-runtime-only diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 1414956b13..a74c30ce7e 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -22,8 +22,11 @@ Branches guarded with ``NotImplementedError`` (flags unused by the core DPA4 config): ``so2_attn_res != "none"``, ``layer_scale``, ``s2_activation``, -``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj``, ``node_wise_s2``, -``node_wise_so3``, ``message_node_s2``, ``message_node_so3``. +``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj``. + +The cross-mode SO(3)/S2 grid products (``node_wise_s2``/``node_wise_so3`` and +``message_node_s2``/``message_node_so3``) are ported and wired into the +convolution, mirroring the pt ``SO2Convolution`` forward placement. 
""" from __future__ import ( @@ -64,6 +67,10 @@ from .attention import ( segment_envelope_gated_softmax, ) +from .grid_net import ( + S2GridNet, + SO3GridNet, +) from .indexing import ( build_m_major_index, build_m_major_l_index, @@ -899,15 +906,6 @@ def __init__( self.node_wise_so3 = bool(node_wise_so3) self.message_node_s2 = bool(message_node_s2) self.message_node_so3 = bool(message_node_so3) - if self.node_wise_s2 or self.node_wise_so3: - raise NotImplementedError( - "node_wise_s2/node_wise_so3 grid products are not ported to dpmodel" - ) - if self.message_node_s2 or self.message_node_so3: - raise NotImplementedError( - "message_node_s2/message_node_so3 grid products are not ported " - "to dpmodel" - ) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" self.s2_grid_resolution = resolve_s2_grid_resolution( @@ -915,6 +913,19 @@ def __init__( self.mmax, method=self.s2_grid_method, ) + base_full_grid_resolution = resolve_s2_grid_resolution( + self.lmax, + self.lmax, + method=self.s2_grid_method, + ) + # Mirror pt: the e3nn product-grid branch squares the max resolution. + # dpmodel only ports the Lebedev backend (the e3nn S2GridNet raises), + # so this just preserves the config-recorded resolution for parity. + self.s2_full_grid_resolution = ( + [max(base_full_grid_resolution), max(base_full_grid_resolution)] + if self.s2_grid_method == "e3nn" + else base_full_grid_resolution + ) self.activation_function = str(activation_function) self.attn_n_focus = self.n_focus self.attn_focus_dim = self.so2_focus_dim @@ -964,6 +975,8 @@ def __init__( seed_gate = child_seed(seed, 4) seed_radial_hidden = child_seed(seed, 6) seed_radial_degree = child_seed(seed, 7) + seed_node_wise_s2 = child_seed(seed, 8) + seed_message_node_s2 = child_seed(seed, 9) # === Step 3. Multiple SO2Linear layers === # (s2_activation is guarded above, so out_channels == so2_focus_dim.) 
@@ -1127,6 +1140,100 @@ def __init__( trainable=trainable, ) + # === Step 8.5. Optional cross-mode grid products === + # ``op_type`` selection mirrors pt: ``branch`` (count > 0) takes + # precedence over ``mlp``, else ``glu``. When both ``*_s2`` and + # ``*_so3`` are set the SO(3) branch wins (per the argcheck doc). + node_wise_op = ( + "branch" + if self.node_wise_grid_branch > 0 + else ("mlp" if self.node_wise_grid_mlp else "glu") + ) + node_wise_branches = max(1, self.node_wise_grid_branch) + message_node_op = ( + "branch" + if self.message_node_grid_branch > 0 + else ("mlp" if self.message_node_grid_mlp else "glu") + ) + message_node_branches = max(1, self.message_node_grid_branch) + self.node_wise_grid_product: S2GridNet | SO3GridNet | None = None + if self.node_wise_s2 or self.node_wise_so3: + if self.node_wise_so3: + self.node_wise_grid_product = SO3GridNet( + lmax=self.lmax, + mmax=self.mmax, + kmax=self.kmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + precision=self.compute_precision, + layout="flat", + coefficient_layout="m_major", + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + else: + self.node_wise_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + precision=self.compute_precision, + layout="flat", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="m_major", + grid_method=self.s2_grid_method, + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + self.message_node_grid_product: S2GridNet | SO3GridNet | None = None + if self.message_node_s2 or self.message_node_so3: + if self.message_node_so3: + self.message_node_grid_product = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + 
channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + precision=self.compute_precision, + layout="flat", + coefficient_layout="packed", + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_message_node_s2, + ) + else: + self.message_node_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.lmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + precision=self.compute_precision, + layout="flat", + grid_resolution_list=self.s2_full_grid_resolution, + grid_method=self.s2_grid_method, + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + coefficient_layout="packed", + seed=seed_message_node_s2, + ) + # === Step 9. Pre-focus channel mixing === # This projects the full channel width before the SO(2) focus split. self.pre_focus_mix = SO3Linear( @@ -1220,6 +1327,13 @@ def call( src_idx = xp.astype(xp.reshape(src, (n_edge,)), xp.int64) x_src = xp.take(x, src_idx, axis=0) # (E, D, C_wide) x_local = xp.matmul(D_m_prime, x_src) # (E, D_m, C_wide) + # pt rotates the *destination* node into the same edge frame for the + # node-wise cross-mode grid product (raw, before radial modulation). + x_dst_local: Any = None + if self.node_wise_grid_product is not None: + dst_idx_nw = xp.astype(xp.reshape(dst, (n_edge,)), xp.int64) + x_dst = xp.take(x, dst_idx_nw, axis=0) # (E, D, C_wide) + x_dst_local = xp.matmul(D_m_prime, x_dst) # (E, D_m, C_wide) # === Step 3. 
Select radial/type features for reduced layout === degree_index_m = xp_asarray_nodetach(xp, self.degree_index_m, device=device) @@ -1230,6 +1344,11 @@ def call( x_local = x_local * rad_feat else: x_local = self.radial_degree_mixer(x_local, rad_feat) + # pt Step 3: edge-local cross-mode grid product between the + # radial-fused source (query) and the raw destination (context), + # added as a residual in the reduced m-major layout (E, D_m, C_wide). + if self.node_wise_grid_product is not None: + x_local = x_local + self.node_wise_grid_product(x_local, x_dst_local) rad_feat_l0_focus = xp.reshape( rad_feat[:, 0, :], (n_edge, self.n_focus, self.so2_focus_dim) ) # (E, F, Cf) @@ -1480,6 +1599,13 @@ def apply_bias_correction( x.dtype, ) # (N, D, C_wide) + # === Step 9. Optional message-node grid product === + # pt: post-aggregation packed-layout cross-mode product between the + # aggregated message (query) and the pre-focus-mixed node features + # (context), added as a residual before the final channel mixing. + if self.message_node_grid_product is not None: + out = out + self.message_node_grid_product(out, x) + # === Step 10. Final channel mixing === out = self.post_focus_mix(out[:, :, None, :])[:, :, 0, :] return out # (N, D, C) @@ -1530,6 +1656,15 @@ def _variables(self) -> dict[str, np.ndarray]: if self.radial_degree_mixer is not None: for key, value in self.radial_degree_mixer._variables().items(): variables[f"radial_degree_mixer.{key}"] = value + # Cross-mode grid products: nest each net's @variables under the pt + # state_dict attribute name so ``deserialize(pt.serialize())`` matches. 
+ for name, grid in ( + ("node_wise_grid_product", self.node_wise_grid_product), + ("message_node_grid_product", self.message_node_grid_product), + ): + if grid is not None: + for key, value in grid.serialize()["@variables"].items(): + variables[f"{name}.{key}"] = value for name, mix in ( ("pre_focus_mix", self.pre_focus_mix), ("post_focus_mix", self.post_focus_mix), @@ -1689,6 +1824,21 @@ def sub_vars(prefix: str) -> dict[str, Any]: ) if self.radial_degree_mixer is not None: self.radial_degree_mixer._load_variables(sub_vars("radial_degree_mixer")) + # Grid products have no ``_load_variables``; reuse their config (from a + # fresh ``serialize()``) plus the loaded @variables and re-deserialize + # in place. This exercises the full grid-net serialize round-trip. + if self.node_wise_grid_product is not None: + template = self.node_wise_grid_product.serialize() + template["@variables"] = sub_vars("node_wise_grid_product") + self.node_wise_grid_product = type(self.node_wise_grid_product).deserialize( + template + ) + if self.message_node_grid_product is not None: + template = self.message_node_grid_product.serialize() + template["@variables"] = sub_vars("message_node_grid_product") + self.message_node_grid_product = type( + self.message_node_grid_product + ).deserialize(template) for name, mix in ( ("pre_focus_mix", self.pre_focus_mix), ("post_focus_mix", self.post_focus_mix), diff --git a/source/tests/common/dpmodel/test_descrpt_dpa4.py b/source/tests/common/dpmodel/test_descrpt_dpa4.py index 223c89ca05..e4ab24c064 100644 --- a/source/tests/common/dpmodel/test_descrpt_dpa4.py +++ b/source/tests/common/dpmodel/test_descrpt_dpa4.py @@ -164,8 +164,6 @@ def test_masked_edge_inertness(self) -> None: ("block_attn_res", "dependent"), # DepthAttnRes ("so2_attn_res", "independent"), # SO(2) DepthAttnRes ("s2_activation", [True, True]), # so2-side S2 activation - ("node_wise_s2", True), # SO(2) cross-grid product - ("message_node_so3", True), # SO(2) cross-grid product 
("atten_f_mix", True), # SO(2) attention focus mix ("atten_v_proj", True), # SO(2) attention value projection ("atten_o_proj", True), # SO(2) attention output projection @@ -185,6 +183,8 @@ def test_not_implemented_guards(self, flag, value) -> None: ("full_attn_res", "none"), ("s2_activation", [False, True]), ("node_wise_s2", False), + ("node_wise_so3", True), # SO(2) edge-local SO(3) cross-grid product + ("message_node_so3", True), # SO(2) post-agg SO(3) cross-grid product ("ffn_so3_grid", False), # SO(3) Wigner-D FFN grid off ("ffn_so3_grid", True), # SO(3) Wigner-D FFN grid on (now wired) ("use_amp", True), # pt-runtime-only switch: accepted and ignored diff --git a/source/tests/common/dpmodel/test_dpa4_so2_grid.py b/source/tests/common/dpmodel/test_dpa4_so2_grid.py new file mode 100644 index 0000000000..082c306503 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so2_grid.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 ``SO2Convolution`` cross-mode grid products. + +These tests cover the ``node_wise_s2``/``node_wise_so3`` (edge-local) and +``message_node_s2``/``message_node_so3`` (post-aggregation) grid-product +branches wired into the dpmodel ``SO2Convolution``. Each test builds a pt and a +dpmodel ``SO2Convolution`` with the same config, copies the (perturbed) pt +weights into the dpmodel module via ``DP.deserialize(pt.serialize())`` (which +also exercises the serialize round-trip), runs both on the same random padded / +sparse edge data, and asserts forward parity at fp64 (~1e-12 on CPU). + +pt imports live inside the test functions because ruff TID253 bans module-level +``deepmd.pt`` imports under ``source/tests/common``; pt modules are pinned to +CPU (``.to("cpu")``) under the CUDA-default-device CI. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.so2 import ( + SO2Convolution as DPSO2Conv, +) + +# fp64 weight-copied parity is near-bit on CPU. 
+RTOL, ATOL = 1e-12, 1e-14 + +NLOC = 5 +NNEI = 4 + + +def _to_pt(x): + import torch + + return torch.from_numpy(np.ascontiguousarray(x)).to("cpu") + + +def _assert_parity(a, t, rtol=RTOL, atol=ATOL): + np.testing.assert_allclose( + np.asarray(a), t.detach().cpu().numpy(), rtol=rtol, atol=atol + ) + + +def _base_kwargs(**overrides): + kwargs = { + "lmax": 3, + "mmax": 1, + "kmax": 1, + "channels": 4, + "n_focus": 1, + "focus_dim": 0, + "focus_compete": True, + "so2_norm": False, + "so2_layers": 2, + "so2_attn_res": "none", + "layer_scale": False, + "n_atten_head": 1, + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "lebedev_quadrature": True, + "activation_function": "silu", + "mlp_bias": False, + "eps": 1e-7, + } + kwargs.update(overrides) + return kwargs + + +def _perturb(pt_mod, seed): + import torch + + rng = np.random.default_rng(seed) + with torch.no_grad(): + for p in pt_mod.parameters(): + p += _to_pt(0.1 * rng.normal(size=tuple(p.shape))) + + +def _build_conv_pair(seed=17, perturb_seed=2060, **overrides): + import torch + + from deepmd.pt.model.descriptor.sezm_nn.so2 import ( + SO2Convolution as PTSO2Conv, + ) + + kwargs = _base_kwargs(**overrides) + pt_mod = PTSO2Conv(**kwargs, dtype=torch.float64, seed=seed, trainable=True).to( + "cpu" + ) + # post_focus_mix is zero-initialized; perturb so the output is nonzero and + # the (residual-scaled) grid-product contribution is observable. 
+ _perturb(pt_mod, perturb_seed) + dp_mod = DPSO2Conv.deserialize(pt_mod.serialize()) + return pt_mod, dp_mod, kwargs + + +def _build_edge_data(rng, *, nloc, nnei, lmax, channels, masked="slots"): + """Build matching pt (sparse) and dp (padded) edge caches on CPU.""" + from deepmd.dpmodel.descriptor.dpa4_nn.edge_cache import ( + EdgeCache, + ) + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + build_edge_quaternion, + ) + from deepmd.pt.model.descriptor.sezm_nn.edge_cache import ( + EdgeFeatureCache, + ) + + n_edge = nloc * nnei + dim_full = (lmax + 1) ** 2 + src = np.array( + [(i + 1 + k) % nloc for i in range(nloc) for k in range(nnei)], + dtype=np.int64, + ) + dst = np.repeat(np.arange(nloc, dtype=np.int64), nnei) + mask = np.ones(n_edge, dtype=np.float64) + if masked == "slots": + mask[3] = 0.0 + mask[nnei + 1] = 0.0 + mask[-1] = 0.0 + elif masked != "none": + raise ValueError(f"unknown masked mode {masked}") + valid = mask > 0.5 + n_valid = int(valid.sum()) + + edge_vec = rng.normal(size=(n_edge, 3)) + edge_vec /= np.linalg.norm(edge_vec, axis=-1, keepdims=True) + quat = build_edge_quaternion(edge_vec) + D_full, Dt_full = WignerDCalculator(lmax, precision="float64").call(quat) + D_full = np.asarray(D_full) + Dt_full = np.asarray(Dt_full) + edge_rbf = np.zeros((n_edge, 1)) + edge_env = rng.uniform(0.2, 1.0, size=(n_edge, 1)) + deg = ((edge_env[:, 0] ** 2) * mask).reshape(nloc, nnei).sum(axis=1) + inv_sqrt_deg = (1.0 / np.sqrt(deg + 1.0)).reshape(nloc, 1, 1) + radial = rng.normal(size=(n_edge, lmax + 1, channels)) + x = rng.normal(size=(nloc, dim_full, channels)) + + t = _to_pt + pt_cache = EdgeFeatureCache( + src=t(src[valid]), + dst=t(dst[valid]), + edge_type_feat=t(np.zeros((n_valid, channels))), + edge_vec=t(edge_vec[valid]), + edge_rbf=t(edge_rbf[valid]), + edge_env=t(edge_env[valid]), + deg=t(deg), + inv_sqrt_deg=t(inv_sqrt_deg), + D_full=t(D_full[valid]), + Dt_full=t(Dt_full[valid]), + edge_src_gate=None, + ) + dp_cache = 
EdgeCache( + src=src, + dst=dst, + edge_type_feat=np.zeros((n_edge, channels)), + edge_vec=edge_vec, + edge_rbf=edge_rbf, + edge_env=edge_env, + deg=deg, + inv_sqrt_deg=inv_sqrt_deg, + D_full=D_full, + Dt_full=Dt_full, + edge_src_gate=None, + edge_mask=mask, + ) + return pt_cache, dp_cache, radial, radial[valid], x + + +def _assert_conv_parity(pt_mod, dp_mod, kwargs, *, masked="slots"): + rng = np.random.default_rng(2061) + pt_cache, dp_cache, radial, radial_valid, x = _build_edge_data( + rng, + nloc=NLOC, + nnei=NNEI, + lmax=kwargs["lmax"], + channels=kwargs["channels"], + masked=masked, + ) + out_dp = dp_mod.call(x, dp_cache, radial) + out_pt = pt_mod(_to_pt(x), pt_cache, _to_pt(radial_valid)) + _assert_parity(out_dp, out_pt) + + +@pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern +def test_so2_node_wise_s2_parity(masked) -> None: + # edge-local S2 cross product between source and destination node features + pt_mod, dp_mod, kwargs = _build_conv_pair( + node_wise_s2=True, lebedev_quadrature=True + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) + + +@pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern +@pytest.mark.parametrize("lmax,mmax", [(2, 2), (3, 1)]) # full + truncated SO3 +def test_so2_node_wise_so3_parity(masked, lmax, mmax) -> None: + # edge-local SO(3) Wigner-D cross product; (3, 1) is the example truncation + pt_mod, dp_mod, kwargs = _build_conv_pair( + node_wise_so3=True, lmax=lmax, mmax=mmax, lebedev_quadrature=False + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) + + +@pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern +def test_so2_message_node_s2_parity(masked) -> None: + # post-aggregation packed-layout S2 cross product (message vs node) + pt_mod, dp_mod, kwargs = _build_conv_pair( + message_node_s2=True, lebedev_quadrature=True + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) + + +@pytest.mark.parametrize("masked", ["none", 
"slots"]) # padded-slot pattern +@pytest.mark.parametrize("lmax,mmax", [(2, 2), (3, 1)]) # full + truncated SO3 +def test_so2_message_node_so3_parity(masked, lmax, mmax) -> None: + # post-aggregation SO(3) cross product; (3, 1) mirrors the flagship example + pt_mod, dp_mod, kwargs = _build_conv_pair( + message_node_so3=True, lmax=lmax, mmax=mmax, lebedev_quadrature=False + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) + + +def test_so2_both_so3_parity() -> None: + # node_wise_so3 + message_node_so3 together, example-config-like + # (lmax=3, mmax=1, n_focus=2, so2_layers=3, degree_channel radial). + pt_mod, dp_mod, kwargs = _build_conv_pair( + node_wise_so3=True, + message_node_so3=True, + lmax=3, + mmax=1, + n_focus=2, + so2_layers=3, + lebedev_quadrature=False, + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs) + + +def test_so2_grid_mlp_cross() -> None: + # op_type='mlp' grid product (message_node_grid_mlp=True selects the MLP op) + pt_mod, dp_mod, kwargs = _build_conv_pair( + message_node_so3=True, + message_node_grid_mlp=True, + lmax=3, + mmax=1, + lebedev_quadrature=False, + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs) + + +def test_so2_grid_branch_cross() -> None: + # op_type='branch' grid product (mirrors the example's grid_branch=[1,...]). + pt_mod, dp_mod, kwargs = _build_conv_pair( + message_node_so3=True, + message_node_grid_branch=2, + lmax=3, + mmax=1, + lebedev_quadrature=False, + ) + _assert_conv_parity(pt_mod, dp_mod, kwargs) + + +@pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern +def test_so2_no_grid_regression(masked) -> None: + # all grid flags False: the base SO2Convolution path must still match pt. 
+ pt_mod, dp_mod, kwargs = _build_conv_pair() + assert dp_mod.node_wise_grid_product is None + assert dp_mod.message_node_grid_product is None + _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index c3966fcf2b..751c6a9001 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -2378,10 +2378,6 @@ def test_so2_convolution_roundtrip(self) -> None: ("atten_v_proj", True), # value projection ("atten_o_proj", True), # output projection ("s2_activation", True), # S2-grid SwiGLU non-linearity - ("node_wise_s2", True), # edge-local S2 grid product - ("node_wise_so3", True), # edge-local SO(3) grid product - ("message_node_s2", True), # post-aggregation S2 grid product - ("message_node_so3", True), # post-aggregation SO(3) grid product ], ) def test_so2_convolution_guards(self, flag, value) -> None: @@ -3411,8 +3407,6 @@ def test_block_roundtrip(self) -> None: ("block_attn_res", "dependent"), # block-level DepthAttnRes ("layer_scale", True), # block-level FFN LayerScale ("so2_s2_activation", True), # delegated to SO2Convolution - ("node_wise_s2", True), # delegated to SO2Convolution - ("message_node_so3", True), # delegated to SO2Convolution ], ) def test_block_guards(self, flag, value) -> None: From 6fb89971cd4cf094a3ffec6fdcc5d28d3a3db3b0 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 23:26:51 +0800 Subject: [PATCH 12/18] test(dpmodel): add fp32 grid-path parity for DPA4 SO3 grid --- .../common/dpmodel/test_dpa4_so2_grid.py | 58 +++++++++++++++---- .../common/dpmodel/test_dpa4_so3_gridnet.py | 40 ++++++++++++- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/source/tests/common/dpmodel/test_dpa4_so2_grid.py b/source/tests/common/dpmodel/test_dpa4_so2_grid.py index 082c306503..d69befcd54 100644 --- a/source/tests/common/dpmodel/test_dpa4_so2_grid.py +++ 
b/source/tests/common/dpmodel/test_dpa4_so2_grid.py @@ -71,20 +71,20 @@ def _perturb(pt_mod, seed): rng = np.random.default_rng(seed) with torch.no_grad(): for p in pt_mod.parameters(): - p += _to_pt(0.1 * rng.normal(size=tuple(p.shape))) + p += _to_pt(0.1 * rng.normal(size=tuple(p.shape))).to(p.dtype) -def _build_conv_pair(seed=17, perturb_seed=2060, **overrides): +def _build_conv_pair(seed=17, perturb_seed=2060, dtype=None, **overrides): import torch from deepmd.pt.model.descriptor.sezm_nn.so2 import ( SO2Convolution as PTSO2Conv, ) + if dtype is None: + dtype = torch.float64 kwargs = _base_kwargs(**overrides) - pt_mod = PTSO2Conv(**kwargs, dtype=torch.float64, seed=seed, trainable=True).to( - "cpu" - ) + pt_mod = PTSO2Conv(**kwargs, dtype=dtype, seed=seed, trainable=True).to("cpu") # post_focus_mix is zero-initialized; perturb so the output is nonzero and # the (residual-scaled) grid-product contribution is observable. _perturb(pt_mod, perturb_seed) @@ -92,7 +92,9 @@ def _build_conv_pair(seed=17, perturb_seed=2060, **overrides): return pt_mod, dp_mod, kwargs -def _build_edge_data(rng, *, nloc, nnei, lmax, channels, masked="slots"): +def _build_edge_data( + rng, *, nloc, nnei, lmax, channels, masked="slots", np_dtype=np.float64 +): """Build matching pt (sparse) and dp (padded) edge caches on CPU.""" from deepmd.dpmodel.descriptor.dpa4_nn.edge_cache import ( EdgeCache, @@ -135,11 +137,25 @@ def _build_edge_data(rng, *, nloc, nnei, lmax, channels, masked="slots"): radial = rng.normal(size=(n_edge, lmax + 1, channels)) x = rng.normal(size=(nloc, dim_full, channels)) + # Cast all float caches/inputs to the requested precision; both pt and dp + # then consume bit-identical inputs (the only divergence is the in-module + # accumulation order/precision). 
+ edge_vec = edge_vec.astype(np_dtype) + edge_rbf = edge_rbf.astype(np_dtype) + edge_env = edge_env.astype(np_dtype) + deg = deg.astype(np_dtype) + inv_sqrt_deg = inv_sqrt_deg.astype(np_dtype) + D_full = D_full.astype(np_dtype) + Dt_full = Dt_full.astype(np_dtype) + radial = radial.astype(np_dtype) + x = x.astype(np_dtype) + edge_type_feat_np = np.zeros((n_edge, channels), dtype=np_dtype) + t = _to_pt pt_cache = EdgeFeatureCache( src=t(src[valid]), dst=t(dst[valid]), - edge_type_feat=t(np.zeros((n_valid, channels))), + edge_type_feat=t(edge_type_feat_np[valid]), edge_vec=t(edge_vec[valid]), edge_rbf=t(edge_rbf[valid]), edge_env=t(edge_env[valid]), @@ -152,7 +168,7 @@ def _build_edge_data(rng, *, nloc, nnei, lmax, channels, masked="slots"): dp_cache = EdgeCache( src=src, dst=dst, - edge_type_feat=np.zeros((n_edge, channels)), + edge_type_feat=edge_type_feat_np, edge_vec=edge_vec, edge_rbf=edge_rbf, edge_env=edge_env, @@ -166,7 +182,9 @@ def _build_edge_data(rng, *, nloc, nnei, lmax, channels, masked="slots"): return pt_cache, dp_cache, radial, radial[valid], x -def _assert_conv_parity(pt_mod, dp_mod, kwargs, *, masked="slots"): +def _assert_conv_parity( + pt_mod, dp_mod, kwargs, *, masked="slots", np_dtype=np.float64, rtol=RTOL, atol=ATOL +): rng = np.random.default_rng(2061) pt_cache, dp_cache, radial, radial_valid, x = _build_edge_data( rng, @@ -175,10 +193,11 @@ def _assert_conv_parity(pt_mod, dp_mod, kwargs, *, masked="slots"): lmax=kwargs["lmax"], channels=kwargs["channels"], masked=masked, + np_dtype=np_dtype, ) out_dp = dp_mod.call(x, dp_cache, radial) out_pt = pt_mod(_to_pt(x), pt_cache, _to_pt(radial_valid)) - _assert_parity(out_dp, out_pt) + _assert_parity(out_dp, out_pt, rtol=rtol, atol=atol) @pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern @@ -258,6 +277,25 @@ def test_so2_grid_branch_cross() -> None: _assert_conv_parity(pt_mod, dp_mod, kwargs) +def test_so2_message_node_so3_fp32_parity() -> None: + # fp32 parity for the 
example-config-shaped (lmax=3, mmax=1) message_node_so3 + # cross product. The flagship examples/water/dpa4/input.json runs + # precision: float32; the grid path reduces over many Lebedev points, so the + # right budget is the "computation-in-fp32" one (~1e-4), not fp64 bit-parity. + import torch + + pt_mod, dp_mod, kwargs = _build_conv_pair( + message_node_so3=True, + lmax=3, + mmax=1, + lebedev_quadrature=False, + dtype=torch.float32, + ) + _assert_conv_parity( + pt_mod, dp_mod, kwargs, np_dtype=np.float32, rtol=1e-4, atol=1e-4 + ) + + @pytest.mark.parametrize("masked", ["none", "slots"]) # padded-slot pattern def test_so2_no_grid_regression(masked) -> None: # all grid flags False: the base SO2Convolution path must still match pt. diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py index 078b7662f4..d83aa89a1e 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -47,6 +47,7 @@ def _build_so3_nets( mlp_bias=False, lebedev_precision=None, residual_scale_init=None, + precision="float64", seed=7, ): """Build a pt + dp ``SO3GridNet`` with identical (perturbed) weights.""" @@ -56,6 +57,7 @@ def _build_so3_nets( SO3GridNet as PTSO3GridNet, ) + pt_dtype = {"float64": torch.float64, "float32": torch.float32}[precision] common = { "lmax": lmax, "mmax": mmax, @@ -73,13 +75,13 @@ def _build_so3_nets( "trainable": True, "seed": seed, } - pt_net = PTSO3GridNet(dtype=torch.float64, **common).to("cpu") + pt_net = PTSO3GridNet(dtype=pt_dtype, **common).to("cpu") rng = np.random.default_rng(2100) with torch.no_grad(): for p in pt_net.parameters(): - p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))) + p += torch.from_numpy(0.1 * rng.normal(size=tuple(p.shape))).to(p.dtype) - dp_net = DPSO3GridNet(precision="float64", **common) + dp_net = DPSO3GridNet(precision=precision, **common) state = {k: v.detach().cpu().numpy() for k, v in 
pt_net.state_dict().items()} expected = {"scalar_gate.weight"} @@ -275,6 +277,38 @@ def test_so3_cross_flat_parity(op_type) -> None: np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +@pytest.mark.parametrize("op_type", ["glu", "mlp"]) # grid operation +def test_so3_fp32_parity(mode, op_type) -> None: + """fp32 weight-copied SO3GridNet matches pt at ~1e-4. + + The flagship ``examples/water/dpa4/input.json`` runs ``precision: + float32``. The grid path reduces over many Lebedev quadrature points, so + fp32 accumulation error is far above the 1-2 ulp budget; the right budget + is the "computation-in-fp32" one (rtol/atol ~1e-4), not bit-parity. + """ + lmax, n_focus, n_batch, kmax = 2, 2, 5, 2 + pt_net, dp_net = _build_so3_nets( + mode=mode, + op_type=op_type, + layout="ndfc", + lmax=lmax, + kmax=kmax, + n_focus=n_focus, + precision="float32", + ) + rng = np.random.default_rng(123) + query, context = _make_so3_inputs( + dp_net=dp_net, mode=mode, layout="ndfc", n_batch=n_batch, rng=rng + ) + query = query.astype(np.float32) + if context is not None: + context = context.astype(np.float32) + dp_out = _run(dp_net, query, context, "dp") + pt_out = _run(pt_net, query, context, "pt") + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-4, atol=1e-4) + + # === serialize ====================================================== From 50d9df3f719860c45506e053c2b00429ea46b9b3 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 18 Jun 2026 23:34:02 +0800 Subject: [PATCH 13/18] test(dpmodel): DPA4 SO3-grid consistency rows + descriptor interop + invariance --- .../dpmodel/test_dpa4_grid_descriptor.py | 188 ++++++++++++++++++ .../tests/consistent/descriptor/test_dpa4.py | 34 ++-- 2 files changed, 210 insertions(+), 12 deletions(-) create mode 100644 source/tests/common/dpmodel/test_dpa4_grid_descriptor.py diff --git a/source/tests/common/dpmodel/test_dpa4_grid_descriptor.py 
b/source/tests/common/dpmodel/test_dpa4_grid_descriptor.py new file mode 100644 index 0000000000..92b46dbae6 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_grid_descriptor.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Full-descriptor pt->dpmodel interop and invariance tests for the DPA4 SO(3) grid. + +These exercise the flagship ``examples/water/dpa4/input.json`` grid flags +(``ffn_so3_grid`` + ``message_node_so3`` + ``grid_branch``/``grid_mlp``) at the +descriptor level: + +* Part B (convert-backend interop): build the pt ``DescrptSeZM`` (the class + behind ``type: DPA4``) and the dpmodel ``DescrptDPA4`` with a matching, + example-config-like descriptor block, weight-copy via + ``DescrptDPA4.deserialize(pt.serialize())`` -- which is exactly the schema + path ``dp convert-backend`` uses -- and assert descriptor-output parity at + fp64 1e-10 on CPU. pt weights are deterministically perturbed so the + comparison is non-trivial. +* Part C (invariance): permuting the neighbor order within the nlist, and + appending an extra empty (``-1``) neighbor slot, both leave the per-atom + descriptor output unchanged. + +pt imports live inside the test functions because ruff TID253 bans +module-level ``deepmd.pt`` imports under ``source/tests/common``. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4 import ( + DescrptDPA4, +) + + +def build_neighbor_list_np(coord, rcut, nnei): + """Build a padded, distance-sorted gas-phase neighbor list. + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc, 3); no PBC. + rcut + Cutoff radius. + nnei + Number of neighbor slots; pads with -1. + + Returns + ------- + np.ndarray + Neighbor list with shape (nf, nloc, nnei). 
+ """ + nf, nloc, _ = coord.shape + nlist = -np.ones((nf, nloc, nnei), dtype=np.int64) + for f in range(nf): + dist = np.linalg.norm(coord[f][:, None, :] - coord[f][None, :, :], axis=-1) + for i in range(nloc): + neighbors = [ + (dist[i, j], j) for j in range(nloc) if j != i and dist[i, j] < rcut + ] + neighbors.sort() + for slot, (_, j) in enumerate(neighbors[:nnei]): + nlist[f, i, slot] = j + return nlist + + +def make_inputs(seed=5, nf=2, nloc=6, rcut=6.0, nnei=20, ntypes=2): + rng = np.random.default_rng(seed) + coord = rng.uniform(0.0, 3.5, size=(nf, nloc, 3)) + atype = rng.integers(0, ntypes, size=(nf, nloc)) + nlist = build_neighbor_list_np(coord, rcut, nnei) + return coord, atype, nlist + + +def example_descriptor_kwargs(**overrides) -> dict: + """Small example-config-like (examples/water/dpa4/input.json) descriptor block. + + Sizes are shrunk (channels=8, sel=20, so2_layers=2) for fast fp64 parity, + but the grid-relevant structure (lmax=3, mmax=1, n_focus=2, n_blocks=2, + grid_branch=[1,1,1], ffn_so3_grid + message_node_so3) mirrors the flagship + config. + """ + kwargs = { + "ntypes": 2, + "sel": 20, + "rcut": 6.0, + "channels": 8, + "n_radial": 8, + "lmax": 3, + "mmax": 1, + "n_blocks": 2, + "so2_layers": 2, + "n_focus": 2, + "focus_dim": 0, + "ffn_so3_grid": True, + "grid_mlp": [False, False, False], + "grid_branch": [1, 1, 1], + "message_node_so3": True, + "precision": "float64", + "seed": 42, + } + kwargs.update(overrides) + return kwargs + + +# Part B: pt -> dpmodel convert-backend interop, one case per grid path so a +# bug in either the FFN SO(3) grid or the post-aggregation SO(3) message is +# isolated. 
+@pytest.mark.parametrize( + "grid_flags", + [ + {"ffn_so3_grid": True, "message_node_so3": False}, # FFN SO(3) grid only + {"ffn_so3_grid": False, "message_node_so3": True}, # post-agg SO(3) msg only + {"ffn_so3_grid": True, "message_node_so3": True}, # both (example-config-like) + ], +) +def test_pt_to_dpmodel_interop(grid_flags) -> None: + """``DescrptDPA4.deserialize(pt.serialize())`` reproduces the pt descriptor. + + This proves the ``dp convert-backend`` schema interop for the SO(3) grid + paths at fp64 1e-10 (CPU). + """ + import torch + + from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, + ) + + kwargs = example_descriptor_kwargs(**grid_flags) + + # pin to CPU: torch.from_numpy fp64 inputs and the module must agree under + # the CUDA-default-device CI configuration + pt_dd = DescrptSeZM(**kwargs).to("cpu") + # random init can give near-zero output; perturb deterministically so the + # comparison is non-trivial (asserted below via the magnitude check) + rng = np.random.default_rng(1234) + with torch.no_grad(): + for p in pt_dd.parameters(): + p += torch.from_numpy(0.1 * rng.standard_normal(size=tuple(p.shape))) + + # the convert-backend path: dpmodel reconstructs purely from pt's schema + dp_dd = DescrptDPA4.deserialize(pt_dd.serialize()) + + coord, atype, nlist = make_inputs() + nf = atype.shape[0] + coord_ext = coord.reshape(nf, -1) + + pt_out = ( + pt_dd( + torch.from_numpy(coord_ext), + torch.from_numpy(atype), + torch.from_numpy(nlist), + )[0] + .detach() + .cpu() + .numpy() + ) + dp_out = np.asarray(dp_dd.call(coord_ext, atype, nlist)[0]) + + # non-trivial output (perturbation took effect) + assert np.abs(dp_out).max() > 1e-6 + np.testing.assert_allclose(dp_out, pt_out, rtol=1e-10, atol=1e-10) + + +# Part C: invariance of the example-config descriptor (both grid flags on). 
+def test_permutation_invariance() -> None: + """Permuting the neighbor order within the nlist leaves the output unchanged.""" + dd = DescrptDPA4(**example_descriptor_kwargs()) + coord, atype, nlist = make_inputs() + nf = atype.shape[0] + coord_ext = coord.reshape(nf, -1) + out = np.asarray(dd.call(coord_ext, atype, nlist)[0]) + + rng = np.random.default_rng(7) + perm = rng.permutation(nlist.shape[-1]) + nlist2 = nlist[:, :, perm] + out2 = np.asarray(dd.call(coord_ext, atype, nlist2)[0]) + assert np.abs(out).max() > 1e-6 + np.testing.assert_allclose(out2, out, rtol=1e-10, atol=1e-12) + + +def test_masked_edge_noop() -> None: + """An extra all-(-1) neighbor slot must not change the descriptor.""" + dd = DescrptDPA4(**example_descriptor_kwargs()) + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + coord_ext = coord.reshape(nf, -1) + out = np.asarray(dd.call(coord_ext, atype, nlist)[0]) + + pad = -np.ones((nf, nloc, 1), dtype=nlist.dtype) + nlist2 = np.concatenate([nlist, pad], axis=-1) + out2 = np.asarray(dd.call(coord_ext, atype, nlist2)[0]) + np.testing.assert_allclose(out2, out, rtol=1e-10, atol=1e-12) diff --git a/source/tests/consistent/descriptor/test_dpa4.py b/source/tests/consistent/descriptor/test_dpa4.py index 7cfb642ac2..18b8ac1ee5 100644 --- a/source/tests/consistent/descriptor/test_dpa4.py +++ b/source/tests/consistent/descriptor/test_dpa4.py @@ -46,6 +46,9 @@ "grid_branch", "s2_activation", "basis_type", + "ffn_so3_grid", + "message_node_so3", + "grid_mlp", ) DPA4_BASELINE_CASE = { @@ -53,6 +56,9 @@ "grid_branch": [1, 1, 1], "s2_activation": [False, True], "basis_type": "bessel", + "ffn_so3_grid": False, + "message_node_so3": False, + "grid_mlp": False, } @@ -85,6 +91,14 @@ def dpa4_case(**overrides: Any) -> tuple: s2_activation=[False, False], basis_type="gaussian", ), + # SO(3) Wigner-D FFN grid (example-config flag) + dpa4_case(ffn_so3_grid=True), + # post-aggregation SO(3) cross-grid message (example-config flag) + 
dpa4_case(message_node_so3=True), + # both SO(3) grid paths on (mirrors examples/water/dpa4/input.json) + dpa4_case(ffn_so3_grid=True, message_node_so3=True), + # polynomial grid MLP op (grid_branch=0 so grid_mlp takes effect) + dpa4_case(grid_mlp=True, grid_branch=[0, 0, 0]), ) @@ -97,6 +111,9 @@ def data(self) -> dict: grid_branch, s2_activation, basis_type, + ffn_so3_grid, + message_node_so3, + grid_mlp, ) = self.param return { "ntypes": self.ntypes, @@ -110,6 +127,9 @@ def data(self) -> dict: "n_blocks": 2, "grid_branch": grid_branch, "s2_activation": s2_activation, + "ffn_so3_grid": ffn_so3_grid, + "message_node_so3": message_node_so3, + "grid_mlp": grid_mlp, "random_gamma": False, "precision": precision, "trainable": False, @@ -212,12 +232,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: @property def rtol(self) -> float: """Relative tolerance for comparing the return value.""" - ( - precision, - _grid_branch, - _s2_activation, - _basis_type, - ) = self.param + precision = self.param[0] if precision == "float64": return 1e-10 elif precision == "float32": @@ -228,12 +243,7 @@ def rtol(self) -> float: @property def atol(self) -> float: """Absolute tolerance for comparing the return value.""" - ( - precision, - _grid_branch, - _s2_activation, - _basis_type, - ) = self.param + precision = self.param[0] if precision == "float64": return 1e-10 elif precision == "float32": From 4a7be0e2f6f0e8e2a3f8430bcbcff1f0e651ab55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jun 2026 15:39:44 +0000 Subject: [PATCH 14/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 3 +-- .../dpmodel/test_dpa4_basegridnet_cross.py | 8 ++----- .../tests/common/dpmodel/test_dpa4_ffn_so3.py | 12 +++------- .../common/dpmodel/test_dpa4_frame_mixers.py | 8 +++---- 
.../dpmodel/test_dpa4_gridbranch_frames.py | 19 ++++----------- .../dpmodel/test_dpa4_gridmlp_frames.py | 19 ++++----------- .../dpmodel/test_dpa4_project_frames.py | 23 +++++-------------- .../common/dpmodel/test_dpa4_so2_grid.py | 8 ++----- .../common/dpmodel/test_dpa4_so3_gridnet.py | 16 ++++--------- 9 files changed, 31 insertions(+), 85 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index a5fee5a86f..1e3a6e3990 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -40,13 +40,12 @@ annotations, ) +import math from typing import ( TYPE_CHECKING, Any, ) -import math - import array_api_compat import numpy as np diff --git a/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py b/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py index c097b4293f..03d4e62987 100644 --- a/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py +++ b/source/tests/common/dpmodel/test_dpa4_basegridnet_cross.py @@ -15,9 +15,7 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - S2GridNet as DPS2GridNet, -) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet def _grid_op_param_names(op_type): @@ -45,9 +43,7 @@ def _build_nets( """Build a pt + dp ``S2GridNet`` with identical (perturbed) weights.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - S2GridNet as PTS2GridNet, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet common = { "lmax": lmax, diff --git a/source/tests/common/dpmodel/test_dpa4_ffn_so3.py b/source/tests/common/dpmodel/test_dpa4_ffn_so3.py index a726c59c80..65a2748f85 100644 --- a/source/tests/common/dpmodel/test_dpa4_ffn_so3.py +++ b/source/tests/common/dpmodel/test_dpa4_ffn_so3.py @@ -19,9 +19,7 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.ffn import ( - 
EquivariantFFN as DPFFN, -) +from deepmd.dpmodel.descriptor.dpa4_nn.ffn import EquivariantFFN as DPFFN def _build_ffn_pair(*, lmax, channels, hidden_channels, ffn_config, seed=7): @@ -34,9 +32,7 @@ def _build_ffn_pair(*, lmax, channels, hidden_channels, ffn_config, seed=7): """ import torch - from deepmd.pt.model.descriptor.sezm_nn.ffn import ( - EquivariantFFN as PTFFN, - ) + from deepmd.pt.model.descriptor.sezm_nn.ffn import EquivariantFFN as PTFFN pt_ffn = PTFFN( lmax=lmax, @@ -102,9 +98,7 @@ def test_ffn_so3_grid_parity(grid_mlp, grid_branch) -> None: def test_ffn_so3_grid_constructs() -> None: """The dp FFN with ffn_so3_grid=True constructs and runs (no NotImplementedError).""" - from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - SO3GridNet as DPSO3GridNet, - ) + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import SO3GridNet as DPSO3GridNet lmax, channels, hidden_channels, kmax = 3, 8, 8, 1 ffn = DPFFN( diff --git a/source/tests/common/dpmodel/test_dpa4_frame_mixers.py b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py index 9f37c0bd8b..7432353a71 100644 --- a/source/tests/common/dpmodel/test_dpa4_frame_mixers.py +++ b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py @@ -16,9 +16,9 @@ import numpy as np import pytest +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import FrameContract as DPFrameContract +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import FrameExpand as DPFrameExpand from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - FrameContract as DPFrameContract, - FrameExpand as DPFrameExpand, _build_frame_degree_index, ) @@ -83,9 +83,7 @@ def test_frame_expand_parity(lmax, channels, kmax) -> None: """The dpmodel ``FrameExpand`` matches pt with weight-copied fp64 weights.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - FrameExpand as PTFrameExpand, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import FrameExpand as PTFrameExpand n_frames = 2 * kmax + 1 coeff_dim = (lmax + 1) ** 2 diff 
--git a/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py b/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py index 04c6640ce1..56d73a6162 100644 --- a/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py +++ b/source/tests/common/dpmodel/test_dpa4_gridbranch_frames.py @@ -21,9 +21,7 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - GridBranch as DPGridBranch, -) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import GridBranch as DPGridBranch def _make_grid_fns(to_mat, from_mat, n_frames): @@ -94,9 +92,7 @@ def test_gridbranch_parity(n_frames, n_branches) -> None: """The dpmodel ``GridBranch`` matches pt over identical frame-aware grid fns.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - GridBranch as PTGridBranch, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridBranch as PTGridBranch channels, n_batch, coeff_dim, n_focus, grid_size = 4, 5, 9, 2, 7 rng = np.random.default_rng(2026) @@ -164,14 +160,9 @@ def test_gridbranch_s2_regression(n_branches) -> None: """ import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - GridBranch as PTGridBranch, - S2GridNet as PTS2GridNet, - ) - - from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - S2GridNet as DPS2GridNet, - ) + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridBranch as PTGridBranch + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet lmax, channels, n_focus = 2, 4, 1 # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for diff --git a/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py b/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py index 6af6a71016..19a4f979e2 100644 --- a/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py +++ b/source/tests/common/dpmodel/test_dpa4_gridmlp_frames.py @@ -19,9 +19,7 @@ import numpy as 
np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - GridMLP as DPGridMLP, -) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import GridMLP as DPGridMLP def _make_grid_fns(to_mat, from_mat, n_frames): @@ -90,9 +88,7 @@ def test_gridmlp_parity(n_frames, mode) -> None: """The dpmodel ``GridMLP`` matches pt over identical frame-aware grid fns.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - GridMLP as PTGridMLP, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridMLP as PTGridMLP channels, n_batch, coeff_dim, n_focus, grid_size = 4, 5, 9, 2, 7 rng = np.random.default_rng(2026) @@ -154,14 +150,9 @@ def test_gridmlp_s2_regression(mode) -> None: """ import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - GridMLP as PTGridMLP, - S2GridNet as PTS2GridNet, - ) - - from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - S2GridNet as DPS2GridNet, - ) + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridMLP as PTGridMLP + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet lmax, channels, n_focus = 2, 4, 1 # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for diff --git a/source/tests/common/dpmodel/test_dpa4_project_frames.py b/source/tests/common/dpmodel/test_dpa4_project_frames.py index 485621156e..77687fb580 100644 --- a/source/tests/common/dpmodel/test_dpa4_project_frames.py +++ b/source/tests/common/dpmodel/test_dpa4_project_frames.py @@ -10,15 +10,11 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - GridProduct as DPGridProduct, -) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import GridProduct as DPGridProduct from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( _project_frames, ) -from deepmd.dpmodel.descriptor.dpa4_nn.so3 import ( - ChannelLinear as DPChannelLinear, -) +from 
deepmd.dpmodel.descriptor.dpa4_nn.so3 import ChannelLinear as DPChannelLinear @pytest.mark.parametrize("n_frames", [1, 2, 3]) # number of Wigner-D frames @@ -29,9 +25,7 @@ def test_project_frames_parity(n_frames) -> None: from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( _project_frames as pt_project_frames, ) - from deepmd.pt.model.descriptor.sezm_nn.so3 import ( - ChannelLinear as PTChannelLinear, - ) + from deepmd.pt.model.descriptor.sezm_nn.so3 import ChannelLinear as PTChannelLinear c_in, c_out = 4, 6 # pin to CPU so torch.from_numpy fp64 inputs and the module agree under the @@ -108,14 +102,9 @@ def test_grid_product_parity() -> None: """The dpmodel ``GridProduct`` matches pt over a real S2 projector's grid fns.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - GridProduct as PTGridProduct, - S2GridNet as PTS2GridNet, - ) - - from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - S2GridNet as DPS2GridNet, - ) + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridProduct as PTGridProduct + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet lmax, channels, n_focus = 2, 4, 1 # op_type='glu' makes grid_op a GridProduct; we reuse the nets only for diff --git a/source/tests/common/dpmodel/test_dpa4_so2_grid.py b/source/tests/common/dpmodel/test_dpa4_so2_grid.py index d69befcd54..af1cba2f4b 100644 --- a/source/tests/common/dpmodel/test_dpa4_so2_grid.py +++ b/source/tests/common/dpmodel/test_dpa4_so2_grid.py @@ -17,9 +17,7 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.so2 import ( - SO2Convolution as DPSO2Conv, -) +from deepmd.dpmodel.descriptor.dpa4_nn.so2 import SO2Convolution as DPSO2Conv # fp64 weight-copied parity is near-bit on CPU. 
RTOL, ATOL = 1e-12, 1e-14 @@ -77,9 +75,7 @@ def _perturb(pt_mod, seed): def _build_conv_pair(seed=17, perturb_seed=2060, dtype=None, **overrides): import torch - from deepmd.pt.model.descriptor.sezm_nn.so2 import ( - SO2Convolution as PTSO2Conv, - ) + from deepmd.pt.model.descriptor.sezm_nn.so2 import SO2Convolution as PTSO2Conv if dtype is None: dtype = torch.float64 diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py index d83aa89a1e..49ef28818b 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -17,12 +17,8 @@ import numpy as np import pytest -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - S2GridNet as DPS2GridNet, -) -from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( - SO3GridNet as DPSO3GridNet, -) +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import SO3GridNet as DPSO3GridNet def _grid_op_param_names(op_type): @@ -53,9 +49,7 @@ def _build_so3_nets( """Build a pt + dp ``SO3GridNet`` with identical (perturbed) weights.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - SO3GridNet as PTSO3GridNet, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import SO3GridNet as PTSO3GridNet pt_dtype = {"float64": torch.float64, "float32": torch.float32}[precision] common = { @@ -389,9 +383,7 @@ def _build_s2_nets(*, mode, op_type, layout, lmax=2, channels=4, n_focus=1): """Build a pt + dp ``S2GridNet`` with identical (perturbed) weights.""" import torch - from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( - S2GridNet as PTS2GridNet, - ) + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet common = { "lmax": lmax, From c56cdc26e5f8499cb486726a265e535bed5418c6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 19 Jun 2026 00:33:46 +0800 Subject: [PATCH 15/18] 
fix(dpmodel): address AI-review findings on DPA4 SO3 grid (#5555) - BaseGridNet cross-mode: lift query/context to compute_dtype before FrameExpand so the frame expansion runs in the net's precision (mirrors pt's fp64-weight-forced FrameExpand); _FrameMixer otherwise casts weights down to the operand dtype, expanding fp32 inputs in fp32. (CodeRabbit) - SO3GridNet.deserialize: validate the nested projector @class/@version instead of blindly reading config["projector"]["config"]. (CodeRabbit) - SO2Convolution.deserialize: reject schema-drift keys under the node_wise/message_node grid-product prefixes (loaded @variables key set must match the fresh template). (CodeRabbit) - drop unused n_valid local in test_dpa4_so2_grid.py. (CodeQL) Tests: mixed-precision cross run, deserialize rejects bad projector @class/@version, deserialize rejects grid-product drift key. --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 18 ++++++++- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 21 +++++++++- .../common/dpmodel/test_dpa4_so2_grid.py | 17 +++++++- .../common/dpmodel/test_dpa4_so3_gridnet.py | 39 +++++++++++++++++++ 4 files changed, 91 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 1e3a6e3990..2a927d5fc5 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -951,6 +951,15 @@ def _prepare_cross_pair( scalar_pair = xp.concat([query[:, 0, :, :], context_ndfc[:, 0, :, :]], axis=-1) if scalar_pair.dtype != compute_dtype: scalar_pair = xp.astype(scalar_pair, compute_dtype) + # Lift operands to compute_dtype BEFORE frame expansion so FrameExpand + # runs in the net's precision. ``_FrameMixer.call`` casts its weights to + # the operand dtype, so without this an fp32 input through an fp64 grid + # net would expand in fp32 and only upcast afterward; pt's fp64 + # FrameExpand weights force fp64 expansion, so this matches pt. 
+ if query.dtype != compute_dtype: + query = xp.astype(query, compute_dtype) + if context_ndfc.dtype != compute_dtype: + context_ndfc = xp.astype(context_ndfc, compute_dtype) return ( self.frame_expand(query), self.frame_expand(context_ndfc), @@ -1529,7 +1538,14 @@ def deserialize(cls, data: dict[str, Any]) -> SO3GridNet: check_version_compatibility(version, 1, 1) config = data.pop("config") variables = data.pop("@variables") - projector_config = config["projector"]["config"] + projector_data = config["projector"] + projector_cls = projector_data.get("@class") + if projector_cls != "SO3GridProjector": + raise ValueError( + f"Invalid nested projector class for SO3GridNet: {projector_cls}" + ) + check_version_compatibility(int(projector_data.get("@version", 1)), 1, 1) + projector_config = projector_data["config"] obj = cls( lmax=int(projector_config["lmax"]), mmax=int(projector_config["mmax"]), diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index a74c30ce7e..89d7da8098 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -1827,15 +1827,32 @@ def sub_vars(prefix: str) -> dict[str, Any]: # Grid products have no ``_load_variables``; reuse their config (from a # fresh ``serialize()``) plus the loaded @variables and re-deserialize # in place. This exercises the full grid-net serialize round-trip. + def _grid_product_vars(prefix: str, template: dict) -> dict: + # Reject schema drift: the loaded @variables keys must exactly match + # the fresh template's, so an unexpected ``<prefix>.*`` key fails the + # conversion loudly instead of being silently dropped. 
+ loaded = sub_vars(prefix) + expected = set(template["@variables"]) + if set(loaded) != expected: + raise ValueError( + f"{prefix} @variables keys {sorted(loaded)} do not match " + f"the expected keys {sorted(expected)}" + ) + return loaded + if self.node_wise_grid_product is not None: template = self.node_wise_grid_product.serialize() - template["@variables"] = sub_vars("node_wise_grid_product") + template["@variables"] = _grid_product_vars( + "node_wise_grid_product", template + ) self.node_wise_grid_product = type(self.node_wise_grid_product).deserialize( template ) if self.message_node_grid_product is not None: template = self.message_node_grid_product.serialize() - template["@variables"] = sub_vars("message_node_grid_product") + template["@variables"] = _grid_product_vars( + "message_node_grid_product", template + ) self.message_node_grid_product = type( self.message_node_grid_product ).deserialize(template) diff --git a/source/tests/common/dpmodel/test_dpa4_so2_grid.py b/source/tests/common/dpmodel/test_dpa4_so2_grid.py index af1cba2f4b..4bf13567fe 100644 --- a/source/tests/common/dpmodel/test_dpa4_so2_grid.py +++ b/source/tests/common/dpmodel/test_dpa4_so2_grid.py @@ -118,7 +118,6 @@ def _build_edge_data( elif masked != "none": raise ValueError(f"unknown masked mode {masked}") valid = mask > 0.5 - n_valid = int(valid.sum()) edge_vec = rng.normal(size=(n_edge, 3)) edge_vec /= np.linalg.norm(edge_vec, axis=-1, keepdims=True) @@ -299,3 +298,19 @@ def test_so2_no_grid_regression(masked) -> None: assert dp_mod.node_wise_grid_product is None assert dp_mod.message_node_grid_product is None _assert_conv_parity(pt_mod, dp_mod, kwargs, masked=masked) + + +def test_so2_deserialize_rejects_drift_key() -> None: + """A drift key under a grid-product prefix fails deserialization loudly.""" + pt_mod, _dp_mod, _kwargs = _build_conv_pair( + message_node_so3=True, lmax=3, mmax=1, lebedev_quadrature=False + ) + data = pt_mod.serialize() + var_key = next( + k for k in 
data["@variables"] if k.startswith("message_node_grid_product.") + ) + data["@variables"]["message_node_grid_product.__bogus__"] = data["@variables"][ + var_key + ] + with pytest.raises(ValueError, match="message_node_grid_product"): + DPSO2Conv.deserialize(data) diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py index 49ef28818b..9c248ab6b5 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -13,6 +13,8 @@ pinned to CPU (``.to("cpu")``) under the CUDA-default-device CI. """ +import copy + import array_api_compat import numpy as np import pytest @@ -432,3 +434,40 @@ def test_s2_regression(mode) -> None: dp_out = _run(dp_net, query, context, "dp") pt_out = _run(pt_net, query, context, "pt") np.testing.assert_allclose(dp_out, pt_out, rtol=1e-12, atol=1e-12) + + +def test_so3_cross_mixed_precision_runs() -> None: + """fp32 inputs through an fp64 SO3GridNet cross net run cleanly. + + ``_FrameMixer`` casts its weights to the operand dtype, so operands are + lifted to compute precision before frame expansion (matching pt's fp64 + FrameExpand); the mixed-precision path must run and stay close to the + fp64-input result. 
+ """ + _pt, dp_net = _build_so3_nets( + mode="cross", op_type="glu", layout="ndfc", precision="float64" + ) + rng = np.random.default_rng(909) + query, context = _make_so3_inputs( + dp_net=dp_net, mode="cross", layout="ndfc", n_batch=3, rng=rng + ) + out32 = np.asarray( + dp_net.call(query.astype(np.float32), context.astype(np.float32)) + ) + out64 = np.asarray(dp_net.call(query, context)) + assert np.all(np.isfinite(out32)) + np.testing.assert_allclose(out32, out64, rtol=1e-4, atol=1e-4) + + +def test_so3_deserialize_rejects_bad_projector() -> None: + """SO3GridNet.deserialize validates the nested projector @class/@version.""" + _pt, dp_net = _build_so3_nets(mode="self", op_type="glu", layout="ndfc") + data = dp_net.serialize() + bad_class = copy.deepcopy(data) + bad_class["config"]["projector"]["@class"] = "S2GridProjector" + with pytest.raises(ValueError, match="projector"): + DPSO3GridNet.deserialize(bad_class) + bad_ver = copy.deepcopy(data) + bad_ver["config"]["projector"]["@version"] = 99 + with pytest.raises(Exception): + DPSO3GridNet.deserialize(bad_ver) From 83e0b79daf3b909b6d6633bff05a632bd80aceb2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:35:06 +0000 Subject: [PATCH 16/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/descriptor/dpa4_nn/so2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 89d7da8098..97774155cf 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -1824,6 +1824,7 @@ def sub_vars(prefix: str) -> dict[str, Any]: ) if self.radial_degree_mixer is not None: self.radial_degree_mixer._load_variables(sub_vars("radial_degree_mixer")) + # Grid products have no ``_load_variables``; reuse their config (from a # fresh ``serialize()``) plus the 
loaded @variables and re-deserialize # in place. This exercises the full grid-net serialize round-trip. From 8b8608b9c9fa79c9a3fdcd29adfa0c08e6e79454 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 19 Jun 2026 09:55:05 +0800 Subject: [PATCH 17/18] fix(dpmodel): require nested projector @version; narrow version test (#5555) Follow-up to CodeRabbit re-review: - SO3GridNet.deserialize now requires the nested projector @version key (was silently defaulting a missing version to 1). - the version test asserts ValueError(match="version") instead of a blind Exception (ruff B017), and adds a missing-@version case. --- deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py | 4 +++- source/tests/common/dpmodel/test_dpa4_so3_gridnet.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 2a927d5fc5..407cf75b3f 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -1544,7 +1544,9 @@ def deserialize(cls, data: dict[str, Any]) -> SO3GridNet: raise ValueError( f"Invalid nested projector class for SO3GridNet: {projector_cls}" ) - check_version_compatibility(int(projector_data.get("@version", 1)), 1, 1) + if "@version" not in projector_data: + raise ValueError("nested SO3GridProjector payload is missing '@version'") + check_version_compatibility(int(projector_data["@version"]), 1, 1) projector_config = projector_data["config"] obj = cls( lmax=int(projector_config["lmax"]), diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py index 9c248ab6b5..2abbb80dbe 100644 --- a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -469,5 +469,9 @@ def test_so3_deserialize_rejects_bad_projector() -> None: DPSO3GridNet.deserialize(bad_class) bad_ver = copy.deepcopy(data) 
bad_ver["config"]["projector"]["@version"] = 99 - with pytest.raises(Exception): + with pytest.raises(ValueError, match="version"): DPSO3GridNet.deserialize(bad_ver) + missing_ver = copy.deepcopy(data) + del missing_ver["config"]["projector"]["@version"] + with pytest.raises(ValueError, match="@version"): + DPSO3GridNet.deserialize(missing_ver) From be4e13cfd3b7e8748ec6672b2609dfec9ef0b85f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 19 Jun 2026 12:09:53 +0800 Subject: [PATCH 18/18] ci: re-trigger readthedocs/CI (suspected transient post_install network failure)