diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py index a54cf7c4eb..18da3a0a76 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/ffn.py @@ -6,15 +6,11 @@ It defines the full SO(3)-equivariant feed-forward network used inside SeZM interaction blocks. -Branches guarded with ``NotImplementedError`` (flags unused by the core DPA4 -config): - -- ``ffn_so3_grid=True`` — the pt path instantiates ``SO3GridNet`` - (pt ffn.py:209), which is not ported to dpmodel. -- ``grid_mlp=True`` together with the grid path active selects - ``op_type='mlp'`` for ``S2GridNet`` (pt ffn.py:206); the delegate - ``S2GridNet`` constructor raises for that op type (``GridMLP`` is not - ported), so no duplicate guard is added here. +When the grid path is active the equivariant nonlinearity is a grid net +(``S2GridNet`` for the S2 path, ``SO3GridNet`` for ``ffn_so3_grid=True``); the +``grid_mlp``/``grid_branch`` flags pick the point-wise grid op type +(``branch`` if ``grid_branch>0`` else ``mlp`` if ``grid_mlp`` else ``glu``), +matching the pt reference (pt ffn.py:201-239). """ from __future__ import ( @@ -44,6 +40,7 @@ ) from .grid_net import ( S2GridNet, + SO3GridNet, ) from .projection import ( resolve_s2_grid_resolution, @@ -86,8 +83,7 @@ class EquivariantFFN(NativeOP): Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid. grid_mlp If True, select the polynomial grid MLP operation when the - block-internal FFN grid path is enabled. Not ported: the delegate - ``S2GridNet`` raises ``NotImplementedError`` for ``op_type='mlp'``. + block-internal FFN grid path is enabled. grid_branch Number of scalar-routed polynomial product branches used when the block-internal FFN grid path is enabled. ``0`` disables this branch @@ -95,7 +91,7 @@ class EquivariantFFN(NativeOP): s2_activation If True, enable the S2 FFN grid path. ffn_so3_grid - If True, enable the SO3 Wigner-D FFN grid path (not ported). + If True, enable the SO3 Wigner-D FFN grid path (``SO3GridNet``). lebedev_quadrature If True, use Lebedev quadrature for the S2 projector in this FFN. activation_function @@ -145,10 +141,6 @@ def __init__( self.use_grid_branch = self.grid_branch > 0 self.s2_activation = bool(s2_activation) self.ffn_so3_grid = bool(ffn_so3_grid) - if self.ffn_so3_grid: - raise NotImplementedError( - "ffn_so3_grid=True (SO3GridNet) is not ported to dpmodel" - ) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" base_grid = resolve_s2_grid_resolution( @@ -167,8 +159,8 @@ def __init__( self.precision = precision self.compute_precision = _compute_precision(precision) self.trainable = bool(trainable) - # pt: grid_n_frames = 2 * kmax + 1 only when ffn_so3_grid (guarded above) - self.grid_n_frames = 1 + # pt: grid_n_frames = 2 * kmax + 1 only when ffn_so3_grid + self.grid_n_frames = 2 * self.kmax + 1 if self.ffn_so3_grid else 1 # === Step 0. Split deterministic seeds at the module top-level === seed_so3_in = child_seed(seed, 0) @@ -203,23 +195,40 @@ def __init__( if self.use_grid_branch else ("mlp" if self.use_grid_mlp else "glu") ) - # op_type='mlp' raises NotImplementedError inside S2GridNet - self.act: NativeOP = S2GridNet( - lmax=self.lmax, - channels=self.hidden_channels, - n_focus=1, - mode="self", - op_type=grid_op, - precision=self.compute_precision, - layout="ndfc", - grid_resolution_list=self.s2_grid_resolution, - coefficient_layout="packed", - grid_method=self.s2_grid_method, - grid_branches=max(1, self.grid_branch), - mlp_bias=self.mlp_bias, - trainable=self.trainable, - seed=seed_act, - ) + self.act: NativeOP + if self.ffn_so3_grid: + self.act = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + precision=self.compute_precision, + layout="ndfc", + coefficient_layout="packed", + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=self.trainable, + seed=seed_act, + ) + else: + self.act = S2GridNet( + lmax=self.lmax, + channels=self.hidden_channels, + n_focus=1, + mode="self", + op_type=grid_op, + precision=self.compute_precision, + layout="ndfc", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="packed", + grid_method=self.s2_grid_method, + grid_branches=max(1, self.grid_branch), + mlp_bias=self.mlp_bias, + trainable=self.trainable, + seed=seed_act, + ) else: self.act = GatedActivation( lmax=self.lmax, diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py index 70db705bb9..1e31981538 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/grid_net.py @@ -3,34 +3,23 @@ Grid-space nonlinearities for DPA4/SeZM coefficient tensors. This module is the dpmodel port of -``deepmd.pt.model.descriptor.sezm_nn.grid_net``, restricted to the S2/Lebedev -path used by the core DPA4 configuration. A grid net receives coefficient -tensors, converts them to quadrature values, applies one point-wise grid -operation, and projects the result back to coefficients. The public shapes -are: +``deepmd.pt.model.descriptor.sezm_nn.grid_net``, covering the S2/Lebedev and +SO(3)/Wigner-D quadrature paths. A grid net receives coefficient tensors, +converts them to quadrature values, applies one point-wise grid operation, and +projects the result back to coefficients. The public shapes are: * ``mode='self'``: one input ``(N, D, F, 2*C)`` or ``(N, F, D, 2*C)``. -* grid values: ``(N, G, F, C)`` after S2 projection. +* grid values: ``(N, G, F, C)`` after S2 or SO(3) projection. -Ported names: ``BaseGridNet`` (``mode='self'``; ``op_type`` 'glu'/'branch'), -``S2GridNet``, ``GridBranch``. +Ported names: ``BaseGridNet`` (``mode='self'``/``'cross'``; ``op_type`` +'glu'/'mlp'/'branch'), ``S2GridNet``, ``SO3GridNet``, ``GridBranch``, +``GridMLP``, ``FrameExpand``, ``FrameContract``. -Skipped names, with consumer evidence from the pt sources: - -- ``SO3GridNet``: only constructed by ``so2.py`` (``node_wise_so3``, - ``message_node_so3``) and ``ffn.py`` (``ffn_so3_grid``) — all disabled in - the core DPA4 config. -- ``FrameContract``, ``FrameExpand``, ``_build_frame_degree_index``: only - constructed by ``SO3GridNet`` (``mode='cross'``); the S2 projector always - has ``n_frames == 1``, so the frame machinery is unreachable here. -- ``GridMLP``: only selected via ``op_type='mlp'`` (``grid_mlp=True`` paths); - the core config has ``grid_mlp=[False, False, False]``. ``BaseGridNet`` - raises ``NotImplementedError`` for ``op_type='mlp'``. - -Guarded (routable from the shared ``S2GridNet`` entry point but only used by -the disabled ``node_wise_s2``/``message_node_s2`` grid products in -``so2.py``): ``mode='cross'`` (and with it ``layout='flat'``) and -``residual_scale_init is not None`` raise ``NotImplementedError``. +``mode='cross'`` (with ``layout='flat'`` and ``residual_scale_init``) is +supported for both projectors. For the SO(3) projector (``n_frames > 1``) the +frame axis is created by ``FrameExpand`` (``channels -> n_frames*channels``) +and collapsed by ``FrameContract``; for the S2 projector (``n_frames == 1``) +there is no frame machinery and query/context stay separate. Serialization contract: the pt ``S2GridNet`` and ``GridBranch`` define no ``serialize()`` (they only appear nested inside larger modules' @@ -46,6 +35,7 @@ annotations, ) +import math from typing import ( Any, ) @@ -76,9 +66,15 @@ from .activation import ( SwiGLU, ) +from .indexing import ( + build_l_major_index, + build_m_major_l_index, + map_degree_idx, +) from .projection import ( BaseGridProjector, S2GridProjector, + SO3GridProjector, ) from .so3 import ( ChannelLinear, @@ -93,6 +89,185 @@ def _softmax_last_axis(x: Any) -> Any: return e_x / xp.sum(e_x, axis=-1, keepdims=True) +def _build_frame_degree_index( + *, + lmax: int, + mmax: int, + coefficient_layout: str, +) -> np.ndarray: + """Build the per-coefficient degree index used by frame channel mixers. + + The torch version's ``device`` parameter is dropped: the output is a static + ``np.int64`` table mapping each coefficient row to its degree ``l``. + """ + coefficient_layout = str(coefficient_layout).lower() + if coefficient_layout == "m_major": + return build_m_major_l_index(lmax, mmax) + if coefficient_layout == "packed": + degree_index = map_degree_idx(lmax) + if int(mmax) == int(lmax): + return degree_index + coeff_index = build_l_major_index(lmax, mmax) + return degree_index[coeff_index] + raise ValueError("`coefficient_layout` must be either 'packed' or 'm_major'") + + +class GridMLP(NativeOP): + """ + Polynomial point-wise MLP applied independently at every grid point. + + The op is a pure quadratic channel product with no nonlinearity: two + channel-linear projections of the grid fields are multiplied and projected + back. In ``self`` mode both projections see ``concat(query, context)`` and + can form self and cross quadratic channel terms; in ``cross`` mode the + query and context roles stay separate (``(W_q query) * (W_c context)``). + + Parameters + ---------- + channels : int + Number of channels per grid point. + mode : str + Pairing mode; ``"self"`` or ``"cross"``. + precision : str + Parameter precision. + trainable : bool + Whether parameters are trainable. + seed : int | list[int] | None + Random seed for weight initialization. + + Notes + ----- + Like the pt ``GridMLP``, the three channel-linear projections are always + bias-free (the net-level ``mlp_bias`` flag only affects the scalar gate, + not the grid op). + """ + + def __init__( + self, + *, + channels: int, + mode: str, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + self.channels = int(channels) + self.mode = str(mode).lower() + if self.mode not in {"self", "cross"}: + raise ValueError("`mode` must be either 'self' or 'cross'") + self.precision = precision + self.trainable = bool(trainable) + self.input_channels = ( + 2 * self.channels if self.mode == "self" else self.channels + ) + self.hidden_channels = 2 * self.channels + self.left_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 0), + ) + self.right_proj = ChannelLinear( + in_channels=self.input_channels, + out_channels=self.hidden_channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 1), + ) + self.out_proj = ChannelLinear( + in_channels=self.hidden_channels, + out_channels=self.channels, + precision=precision, + bias=False, + trainable=trainable, + seed=child_seed(seed, 2), + ) + + def call(self, query_grid: Any, context_grid: Any) -> Any: + """ + Apply the point-wise polynomial MLP to ``(N, G, F, C)`` grid fields. + + Parameters + ---------- + query_grid + First grid source with shape ``(N, G, F, C)``. + context_grid + Second grid source with shape ``(N, G, F, C)``. + """ + xp = array_api_compat.array_namespace(query_grid) + if self.mode == "self": + grid = xp.concat([query_grid, context_grid], axis=-1) + left = self.left_proj(grid) + right = self.right_proj(grid) + else: + left = self.left_proj(query_grid) + right = self.right_proj(context_grid) + return self.out_proj(left * right) # (N, G, F, C) + + def serialize(self) -> dict[str, Any]: + """Serialize the GridMLP to a dict. + + The pt ``GridMLP`` has no ``serialize()``; the ``@variables`` keys + here match the pt ``state_dict`` key names. + """ + variables = { + "left_proj.weight": to_numpy_array(self.left_proj.weight), + "right_proj.weight": to_numpy_array(self.right_proj.weight), + "out_proj.weight": to_numpy_array(self.out_proj.weight), + } + return { + "@class": "GridMLP", + "@version": 1, + "config": { + "channels": self.channels, + "mode": self.mode, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, + "seed": None, + }, + "@variables": variables, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> GridMLP: + """Deserialize a GridMLP from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "GridMLP": + raise ValueError(f"Invalid class for GridMLP: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls( + channels=int(config["channels"]), + mode=str(config["mode"]), + precision=str(config["precision"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + obj._load_variables(variables) + return obj + + def _load_variables(self, variables: dict[str, Any]) -> None: + prec = PRECISION_DICT[self.precision.lower()] + for name, proj in ( + ("left_proj", self.left_proj), + ("right_proj", self.right_proj), + ("out_proj", self.out_proj), + ): + weight = np.asarray(variables[f"{name}.weight"], dtype=prec) + if weight.shape != proj.weight.shape: + raise ValueError( + f"{name}.weight shape {weight.shape} does not match " + f"the expected shape {proj.weight.shape}" + ) + proj.weight = weight + + class GridBranch(NativeOP): """ Scalar-routed polynomial mixer over grid product branches. @@ -256,17 +431,233 @@ def _load_variables(self, variables: dict[str, Any]) -> None: proj.weight = weight +class _FrameMixer(NativeOP): + """Shared base for the per-degree frame channel mixers. + + The pt ``FrameContract``/``FrameExpand`` are unparameterised ``nn.Module`` + wrappers around a per-degree weight of shape ``(lmax + 1, in_ch, out_ch)`` + selected by a static degree-index buffer. ``mode='self'`` S2 grid nets + have ``n_frames == 1`` and never construct these; they back the SO(3) + cross-mode grid products only. Subclasses set ``in_channels`` / + ``out_channels`` and the init ``bound`` (the pt weight init). + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + in_channels: int, + out_channels: int, + init_bound: float, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + self.lmax = int(lmax) + self.mmax = int(mmax) + self.coefficient_layout = str(coefficient_layout).lower() + self.n_frames = int(n_frames) + self.channels = int(channels) + self.precision = precision + self.trainable = bool(trainable) + self.degree_index = _build_frame_degree_index( + lmax=self.lmax, + mmax=self.mmax, + coefficient_layout=self.coefficient_layout, + ) + prec = PRECISION_DICT[self.precision.lower()] + rng = np.random.default_rng(seed) + shape = (self.lmax + 1, int(in_channels), int(out_channels)) + self.weight = rng.uniform(-init_bound, init_bound, size=shape).astype(prec) + + def call(self, coeff: Any) -> Any: + """Apply the per-degree frame/channel map preserving the order index. + + ``einsum("ndfi,dio->ndfo", coeff, weight[degree_index])`` is realised as + a broadcast batched matmul (the gathered weight ``(D, i, o)`` broadcasts + over the leading frame batch dim of ``coeff``). + """ + xp = array_api_compat.array_namespace(coeff) + device = array_api_compat.device(coeff) + weight = xp_asarray_nodetach(xp, self.weight[...], device=device) + if weight.dtype != coeff.dtype: + weight = xp.astype(weight, coeff.dtype) + degree_index = xp_asarray_nodetach(xp, self.degree_index, device=device) + weight = xp.take(weight, degree_index, axis=0) # (D, i, o) + # (N, D, F, i) @ (1, D, i, o) -> (N, D, F, o) + return xp.matmul(coeff, weight[None, ...]) + + def _serialize_config(self) -> dict[str, Any]: + return { + "lmax": self.lmax, + "mmax": self.mmax, + "coefficient_layout": self.coefficient_layout, + "n_frames": self.n_frames, + "channels": self.channels, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "trainable": self.trainable, + "seed": None, + } + + @classmethod + def _deserialize(cls, data: dict[str, Any]) -> Any: + data = data.copy() + data_cls = data.pop("@class") + if data_cls != cls.__name__: + raise ValueError(f"Invalid class for {cls.__name__}: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls( + lmax=int(config["lmax"]), + mmax=int(config["mmax"]), + coefficient_layout=str(config["coefficient_layout"]), + n_frames=int(config["n_frames"]), + channels=int(config["channels"]), + precision=str(config["precision"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + prec = PRECISION_DICT[obj.precision.lower()] + weight = np.asarray(variables["weight"], dtype=prec) + if weight.shape != obj.weight.shape: + raise ValueError( + f"weight shape {weight.shape} does not match " + f"the expected shape {obj.weight.shape}" + ) + obj.weight = weight + return obj + + +class FrameContract(_FrameMixer): + """Per-degree frame/channel contraction that preserves the order index. + + Maps ``(N, D, F, K*C) -> (N, D, F, C)`` with a per-degree weight of shape + ``(lmax + 1, K*C, C)`` where ``K`` is ``n_frames``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + n_frames = int(n_frames) + channels = int(channels) + super().__init__( + lmax=lmax, + mmax=mmax, + coefficient_layout=coefficient_layout, + n_frames=n_frames, + channels=channels, + in_channels=n_frames * channels, + out_channels=channels, + init_bound=1.0 / math.sqrt(n_frames * channels), + precision=precision, + trainable=trainable, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the FrameContract to a dict. + + The pt ``FrameContract`` has no ``serialize()``; the ``@variables`` + key (``weight``) matches the pt ``state_dict`` key name. ``degree_index`` + is a non-persistent buffer in pt and is rebuilt from the config. + """ + return { + "@class": "FrameContract", + "@version": 1, + "config": self._serialize_config(), + "@variables": {"weight": to_numpy_array(self.weight)}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FrameContract: + """Deserialize a FrameContract from a dict.""" + return cls._deserialize(data) + + +class FrameExpand(_FrameMixer): + """Per-degree frame/channel expansion that preserves the order index. + + Maps ``(N, D, F, C) -> (N, D, F, K*C)`` with a per-degree weight of shape + ``(lmax + 1, C, K*C)`` where ``K`` is ``n_frames``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int, + coefficient_layout: str, + n_frames: int, + channels: int, + precision: str = DEFAULT_PRECISION, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + n_frames = int(n_frames) + channels = int(channels) + super().__init__( + lmax=lmax, + mmax=mmax, + coefficient_layout=coefficient_layout, + n_frames=n_frames, + channels=channels, + in_channels=channels, + out_channels=n_frames * channels, + init_bound=1.0 / math.sqrt(channels), + precision=precision, + trainable=trainable, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the FrameExpand to a dict. + + The pt ``FrameExpand`` has no ``serialize()``; the ``@variables`` + key (``weight``) matches the pt ``state_dict`` key name. ``degree_index`` + is a non-persistent buffer in pt and is rebuilt from the config. + """ + return { + "@class": "FrameExpand", + "@version": 1, + "config": self._serialize_config(), + "@variables": {"weight": to_numpy_array(self.weight)}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> FrameExpand: + """Deserialize a FrameExpand from a dict.""" + return cls._deserialize(data) + + class BaseGridNet(NativeOP): """ - Shared implementation for S2 grid nets (``mode='self'`` only). + Shared implementation for S2 (``n_frames == 1``) and SO(3) grid nets. ``mode='self'`` expects one input whose last channel axis contains two branches; the first half supplies the SwiGLU gates of the scalar path. + ``mode='cross'`` expects separate query and context inputs and supports + ``layout='flat'`` and ``residual_scale_init``. - The pt ``mode='cross'`` path (with ``layout='flat'``, - ``residual_scale_init``, and the SO(3) frame machinery) backs the - ``node_wise_s2``/``message_node_s2`` grid products only, which are - disabled in the core DPA4 config; it is not ported. + The SO(3) frame machinery (FrameExpand/FrameContract, ``n_frames > 1``) + that backs ``SO3GridNet`` cross-mode is wired through the optional + ``frame_expand``/``frame_contract`` modules. When both are ``None`` (the + S2 path), the query/context/output widths collapse to + ``expanded_channels`` and the ``n_frames == 1`` fast paths are used. """ def __init__( @@ -282,6 +673,8 @@ def __init__( mlp_bias: bool, trainable: bool = True, grid_branches: int = 1, + frame_expand: _FrameMixer | None = None, + frame_contract: _FrameMixer | None = None, residual_scale_init: float | None = None, seed: int | list[int] | None = None, ) -> None: @@ -290,25 +683,12 @@ def __init__( self.channels = int(channels) self.n_focus = int(n_focus) self.n_frames = int(projector.n_frames) - if self.n_frames != 1: - raise ValueError( - "dpmodel BaseGridNet only supports S2 projectors (n_frames == 1)" - ) self.mode = str(mode).lower() if self.mode not in {"self", "cross"}: raise ValueError("`mode` must be either 'self' or 'cross'") - if self.mode == "cross": - raise NotImplementedError( - "mode='cross' (node_wise_s2/message_node_s2 grid products) " - "is not ported to dpmodel" - ) self.op_type = str(op_type).lower() if self.op_type not in {"glu", "mlp", "branch"}: raise ValueError("`op_type` must be one of 'glu', 'mlp', or 'branch'") - if self.op_type == "mlp": - raise NotImplementedError( - "op_type='mlp' (grid_mlp=True paths) is not ported to dpmodel" - ) self.precision = precision self.layout = str(layout).lower() if self.layout not in {"ndfc", "nfdc", "flat"}: @@ -318,16 +698,40 @@ def __init__( self.mlp_bias = bool(mlp_bias) self.trainable = bool(trainable) self.expanded_channels = self.n_frames * self.channels - self.query_channels = 2 * self.expanded_channels - self.output_channels = self.expanded_channels - self.frame_zero_index = 0 - if residual_scale_init is not None: - raise NotImplementedError( - "`residual_scale_init` is only used by the cross-mode " - "node_wise_s2/message_node_s2 grid products, which are not " - "ported to dpmodel" + self.frame_expand = frame_expand + self.frame_contract = frame_contract + # With frame_expand present (SO(3) cross), the external query/context + # widths are ``channels`` and the frame axis is created internally; with + # frame_contract present, the output collapses back to ``channels``. + # When both are None (the S2 path / SO(3) self path) all widths are + # ``expanded_channels``. + self.query_channels = ( + 2 * self.expanded_channels + if self.mode == "self" + else ( + self.channels + if self.frame_expand is not None + else self.expanded_channels + ) + ) + self.context_channels = ( + self.channels if self.frame_expand is not None else self.expanded_channels + ) + self.output_channels = ( + self.channels if self.frame_contract is not None else self.expanded_channels + ) + self.frame_zero_index = int(getattr(self.projector, "frame_zero_index", 0)) + self.residual_scale_init = ( + None if residual_scale_init is None else float(residual_scale_init) + ) + if self.residual_scale_init is None: + self.residual_scale = None + else: + prec = PRECISION_DICT[self.precision.lower()] + self.residual_scale = ( + np.ones((self.n_focus, self.output_channels), dtype=prec) + * self.residual_scale_init ) - self.residual_scale = None self.scalar_act = SwiGLU() self.scalar_gate = FocusLinear( @@ -340,8 +744,18 @@ def __init__( seed=child_seed(seed, 0), init_std=0.01, ) - if self.op_type == "branch": - self.grid_op: GridBranch | None = GridBranch( + if self.op_type == "mlp": + # GridMLP projections are bias-free (mirrors pt); the net-level + # mlp_bias only affects the scalar gate, not the grid op. + self.grid_op: GridMLP | GridBranch | None = GridMLP( + channels=self.channels, + mode=self.mode, + precision=self.precision, + trainable=trainable, + seed=child_seed(seed, 1), + ) + elif self.op_type == "branch": + self.grid_op = GridBranch( channels=self.channels, n_branches=grid_branches, precision=self.precision, @@ -358,15 +772,58 @@ def call(self, query: Any, context: Any = None) -> Any: input_dtype = query.dtype compute_dtype = get_xp_precision(xp, self.precision) query_ndfc = self._to_ndfc(query) - left, right = self._split_self_query(query_ndfc) - scalar_pair = self._make_scalar_pair(left, right, compute_dtype) + left, right, scalar_pair = self._prepare_pair( + query_ndfc, context, compute_dtype + ) grid_out = self._apply_grid_op(left, right, scalar_pair, compute_dtype) coeff_out = self._from_grid(grid_out) coeff_out = self._apply_scalar_path(coeff_out, scalar_pair) + coeff_out = self._contract_frames(coeff_out) + coeff_out = self._apply_residual_scale(coeff_out) if coeff_out.dtype != input_dtype: coeff_out = xp.astype(coeff_out, input_dtype) return self._restore_layout(coeff_out) + def _prepare_pair( + self, + query: Any, + context: Any, + compute_dtype: Any, + ) -> tuple[Any, Any, Any]: + if self.mode == "self": + left, right = self._split_self_query(query) + scalar_pair = self._make_scalar_pair(left, right, compute_dtype) + return left, right, scalar_pair + return self._prepare_cross_pair(query, context, compute_dtype) + + def _prepare_cross_pair( + self, + query: Any, + context: Any, + compute_dtype: Any, + ) -> tuple[Any, Any, Any]: + if context is None: + raise ValueError("`context` is required when `mode='cross'`") + xp = array_api_compat.array_namespace(query) + context_ndfc = self._to_ndfc(context) + self._check_last_dim(query, self.context_channels, "query") + self._check_last_dim(context_ndfc, self.context_channels, "context") + if self.frame_expand is None: + # S2 path: query and context keep their incoming (N, D, F, C) shape. + scalar_pair = self._make_scalar_pair(query, context_ndfc, compute_dtype) + return query, context_ndfc, scalar_pair + # SO(3) frame_expand path: the external query/context width is + # ``channels``, so the (l=0) scalar slice is the full leading row; the + # frame axis is created by frame_expand (channels -> n_frames*channels). + scalar_pair = xp.concat([query[:, 0, :, :], context_ndfc[:, 0, :, :]], axis=-1) + if scalar_pair.dtype != compute_dtype: + scalar_pair = xp.astype(scalar_pair, compute_dtype) + return ( + self.frame_expand(query), + self.frame_expand(context_ndfc), + scalar_pair, + ) + def _apply_grid_op( self, left: Any, @@ -383,17 +840,61 @@ def _apply_grid_op( right_grid = self._to_grid(right) if self.op_type == "glu": return left_grid * right_grid + if self.op_type == "mlp": + return self.grid_op(left_grid, right_grid) return self.grid_op(left_grid, right_grid, scalar_pair) + def _contract_frames(self, coeff: Any) -> Any: + # SO(3) cross-mode collapses the per-degree frame axis back to + # ``channels``; the S2 path leaves the coefficient unchanged. + if self.frame_contract is None: + return coeff + return self.frame_contract(coeff) + + def _apply_residual_scale(self, coeff: Any) -> Any: + if self.residual_scale is None: + return coeff + xp = array_api_compat.array_namespace(coeff) + scale = xp_asarray_nodetach( + xp, + self.residual_scale[...], + device=array_api_compat.device(coeff), + ) + if scale.dtype != coeff.dtype: + scale = xp.astype(scale, coeff.dtype) + # broadcast (n_focus, output_channels) over (N, D, F, C) + return coeff * xp.reshape(scale, (1, 1, self.n_focus, self.output_channels)) + def _apply_scalar_path(self, coeff: Any, scalar_pair: Any) -> Any: xp = array_api_compat.array_namespace(coeff) scalar_out = self.scalar_act(scalar_pair) # (N, F, C) scalar_gate = xp_sigmoid(self.scalar_gate(scalar_pair)) # (N, F, C) - coeff = coeff * scalar_gate[:, None, :, :] + if self.n_frames == 1: + coeff = coeff * scalar_gate[:, None, :, :] + # gradient-safe equivalent of the pt in-place + # ``coeff_view[:, 0, :, 0, :].add_(scalar_out)`` (n_frames == 1) + head = coeff[:, :1, :, :] + scalar_out[:, None, :, :] + return xp.concat([head, coeff[:, 1:, :, :]], axis=1) + # n_frames > 1: reshape the channel axis into (n_frames, channels), + # gate every frame, then add scalar_out only at (d=0, k=frame_zero). + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, self.channels) + ) + coeff_view = coeff_view * scalar_gate[:, None, :, None, :] # gradient-safe equivalent of the pt in-place - # ``coeff_view[:, 0, :, 0, :].add_(scalar_out)`` (n_frames == 1) - head = coeff[:, :1, :, :] + scalar_out[:, None, :, :] - return xp.concat([head, coeff[:, 1:, :, :]], axis=1) + # ``coeff_view[:, 0, :, frame_zero_index, :].add_(scalar_out)``: build a + # constant (D, K) one-hot placement and broadcast scalar_out onto it. + place = np.zeros((coeff_dim, self.n_frames), dtype=np.float64) + place[0, self.frame_zero_index] = 1.0 + place_xp = xp_asarray_nodetach(xp, place, device=array_api_compat.device(coeff)) + if place_xp.dtype != coeff_view.dtype: + place_xp = xp.astype(place_xp, coeff_view.dtype) + add_term = scalar_out[:, None, :, None, :] * place_xp[None, :, None, :, None] + coeff_view = coeff_view + add_term + return xp.reshape( + coeff_view, (n_batch, coeff_dim, n_focus, self.expanded_channels) + ) def _split_self_query(self, query: Any) -> tuple[Any, Any]: self._check_last_dim(query, self.query_channels, "query") @@ -418,10 +919,18 @@ def _make_scalar_pair(self, left: Any, right: Any, compute_dtype: Any) -> Any: def _extract_scalar(self, coeff: Any) -> Any: # (N, D, F, C) -> the (l=0, m=0) scalar slice (N, F, C); n_frames == 1 - return coeff[:, 0, :, :] + if self.n_frames == 1: + return coeff[:, 0, :, :] + # n_frames > 1: split the channel axis into (n_frames, channels) and + # pick the zero-frame of the (l=0) row. + xp = array_api_compat.array_namespace(coeff) + n_batch, coeff_dim, n_focus, _ = coeff.shape + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, self.channels) + ) + return coeff_view[:, 0, :, self.frame_zero_index, :] def _to_grid(self, coeff: Any) -> Any: - # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul xp = array_api_compat.array_namespace(coeff) n_batch, coeff_dim, n_focus, _ = coeff.shape to_grid_mat = xp_asarray_nodetach( @@ -429,38 +938,77 @@ def _to_grid(self, coeff: Any) -> Any: ) if to_grid_mat.dtype != coeff.dtype: to_grid_mat = xp.astype(to_grid_mat, coeff.dtype) - flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * self.channels)) + if self.n_frames == 1: + # einsum "gd,ndfc->ngfc" (n_frames == 1) as a broadcast batched matmul + flat = xp.reshape(coeff, (n_batch, coeff_dim, n_focus * self.channels)) + out = xp.matmul(to_grid_mat[None, ...], flat) # (N, G, F*C) + return xp.reshape( + out, (n_batch, self.projector.grid_size, n_focus, self.channels) + ) + # einsum "gdk,ndfkc->ngfc": flatten (d, k) into the projector's J axis + # (J = d*K + k matches the projector's flat_idx ordering) and matmul + # against to_grid_mat (G, J). + coeff_view = xp.reshape( + coeff, (n_batch, coeff_dim, n_focus, self.n_frames, self.channels) + ) + # (N, D, F, K, C) -> (N, D, K, F, C) -> (N, D*K, F*C) + coeff_perm = xp.permute_dims(coeff_view, (0, 1, 3, 2, 4)) + flat = xp.reshape( + coeff_perm, + (n_batch, coeff_dim * self.n_frames, n_focus * self.channels), + ) out = xp.matmul(to_grid_mat[None, ...], flat) # (N, G, F*C) return xp.reshape( out, (n_batch, self.projector.grid_size, n_focus, self.channels) ) def _from_grid(self, grid: Any) -> Any: - # einsum "dg,ngfc->ndfc" (n_frames == 1) as a broadcast batched matmul xp = array_api_compat.array_namespace(grid) n_batch, n_grid, n_focus, _ = grid.shape - coeff_dim = self.projector.coeff_dim from_grid_mat = xp_asarray_nodetach( xp, self.projector.from_grid_mat[...], device=array_api_compat.device(grid) ) if from_grid_mat.dtype != grid.dtype: from_grid_mat = xp.astype(from_grid_mat, grid.dtype) flat = xp.reshape(grid, (n_batch, n_grid, n_focus * self.channels)) - out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D, F*C) + if self.n_frames == 1: + # einsum "dg,ngfc->ndfc" (n_frames == 1) as a broadcast batched matmul + out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D, F*C) + return xp.reshape( + out, + (n_batch, self.projector.coeff_dim, n_focus, self.expanded_channels), + ) + # einsum "dkg,ngfc->ndfkc": matmul against from_grid_mat (J=D*K, G) then + # split J back into (D, K) and move the frame axis next to the channel. + coeff_dim = self.projector.coeff_dim // self.n_frames + out = xp.matmul(from_grid_mat[None, ...], flat) # (N, D*K, F*C) + out = xp.reshape( + out, (n_batch, coeff_dim, self.n_frames, n_focus, self.channels) + ) + # (N, D, K, F, C) -> (N, D, F, K, C) + out = xp.permute_dims(out, (0, 1, 3, 2, 4)) return xp.reshape(out, (n_batch, coeff_dim, n_focus, self.expanded_channels)) def _to_ndfc(self, value: Any) -> Any: if self.layout == "ndfc": return value - # "nfdc": (N, F, D, C) -> (N, D, F, C); "flat" is cross-only (blocked) xp = array_api_compat.array_namespace(value) - return xp.permute_dims(value, (0, 2, 1, 3)) + if self.layout == "nfdc": + # (N, F, D, C) -> (N, D, F, C) + return xp.permute_dims(value, (0, 2, 1, 3)) + # "flat" (cross-only): (N, D, F*C) -> (N, D, F, C) + n_batch, coeff_dim, _ = value.shape + return xp.reshape(value, (n_batch, coeff_dim, self.n_focus, -1)) def _restore_layout(self, value: Any) -> Any: if self.layout == "ndfc": return value xp = array_api_compat.array_namespace(value) - return xp.permute_dims(value, (0, 2, 1, 3)) + if self.layout == "nfdc": + return xp.permute_dims(value, (0, 2, 1, 3)) + # "flat" (cross-only): (N, D, F, C) -> (N, D, F*C) + n_batch, coeff_dim, _, _ = value.shape + return xp.reshape(value, (n_batch, coeff_dim, -1)) def _check_last_dim(self, value: Any, expected: int, name: str) -> None: if value.shape[-1] != expected: @@ -468,6 +1016,56 @@ def _check_last_dim(self, value: Any, expected: int, name: str) -> None: f"`{name}` last dimension must be {expected}, got {value.shape[-1]}" ) + def _load_variables(self, variables: dict[str, Any]) -> None: + """Load variables keyed by the pt ``state_dict`` key names. + + Handles ``scalar_gate``, the optional ``grid_op`` (branch/mlp), the + optional SO(3) ``frame_expand``/``frame_contract`` per-degree mixers, + and the optional ``residual_scale`` parameter. ``S2GridNet`` leaves the + frame mixers as ``None`` so the corresponding keys are absent. + """ + prec = PRECISION_DICT[self.precision.lower()] + weight = np.asarray(variables["scalar_gate.weight"], dtype=prec) + if weight.shape != self.scalar_gate.weight.shape: + raise ValueError( + f"scalar_gate.weight shape {weight.shape} does not match " + f"the expected shape {self.scalar_gate.weight.shape}" + ) + self.scalar_gate.weight = weight + if self.mlp_bias: + self.scalar_gate.bias = np.asarray( + variables["scalar_gate.bias"], dtype=prec + ).reshape(self.scalar_gate.bias.shape) + if self.op_type in {"branch", "mlp"}: + self.grid_op._load_variables( + { + key[len("grid_op.") :]: value + for key, value in variables.items() + if key.startswith("grid_op.") + } + ) + for name, mixer in ( + ("frame_expand", self.frame_expand), + ("frame_contract", self.frame_contract), + ): + if mixer is None: + continue + mixer_weight = np.asarray(variables[f"{name}.weight"], dtype=prec) + if mixer_weight.shape != mixer.weight.shape: + raise ValueError( + f"{name}.weight shape {mixer_weight.shape} does not match " + f"the expected shape {mixer.weight.shape}" + ) + mixer.weight = mixer_weight + if self.residual_scale is not None: + residual_scale = np.asarray(variables["residual_scale"], dtype=prec) + if residual_scale.shape != self.residual_scale.shape: + raise ValueError( + f"residual_scale shape {residual_scale.shape} does not match " + f"the expected shape {self.residual_scale.shape}" + ) + self.residual_scale = residual_scale + class S2GridNet(BaseGridNet): """Grid net using an S2 spherical-harmonic projector (Lebedev only). @@ -483,14 +1081,14 @@ class S2GridNet(BaseGridNet): n_focus : int Number of focus streams. mode : str - Pairing mode; only ``"self"`` is ported. + Pairing mode; ``"self"`` or ``"cross"``. op_type : str - Point-wise grid operation; ``"glu"`` or ``"branch"`` (``"mlp"`` is - not ported). + Point-wise grid operation; ``"glu"``, ``"mlp"``, or ``"branch"``. precision : str Parameter precision. layout : str - Tensor layout convention: ``"ndfc"`` or ``"nfdc"``. + Tensor layout convention: ``"ndfc"``, ``"nfdc"``, or ``"flat"`` + (``"flat"`` is cross-mode only). grid_resolution_list : list[int] | None Lebedev ``[precision, n_points]`` pair; resolved automatically if None. coefficient_layout : str @@ -500,7 +1098,8 @@ class S2GridNet(BaseGridNet): grid_branches : int Number of scalar-routed branches when ``op_type='branch'``. residual_scale_init : float | None - Not ported (cross-mode only); must be None. + If set, scales the grid-net output by a per-(focus, channel) parameter + initialised to this value (cross-mode use). ``None`` disables it. mlp_bias : bool Whether to use bias in the scalar gate projection. trainable : bool @@ -569,10 +1168,14 @@ def serialize(self) -> dict[str, Any]: variables = {"scalar_gate.weight": to_numpy_array(self.scalar_gate.weight)} if self.mlp_bias: variables["scalar_gate.bias"] = to_numpy_array(self.scalar_gate.bias) - if self.op_type == "branch": + if self.op_type in {"branch", "mlp"}: grid_op_data = self.grid_op.serialize()["@variables"] for key, value in grid_op_data.items(): variables[f"grid_op.{key}"] = value + # ``residual_scale`` is an nn.Parameter directly on the pt module, so the + # @variables key matches its pt state_dict key. + if self.residual_scale is not None: + variables["residual_scale"] = to_numpy_array(self.residual_scale) return { "@class": "S2GridNet", "@version": 1, @@ -589,6 +1192,7 @@ def serialize(self) -> dict[str, Any]: "coefficient_layout": self.projector.coefficient_layout, "grid_method": self.grid_method, "grid_branches": self.grid_branches, + "residual_scale_init": self.residual_scale_init, "mlp_bias": self.mlp_bias, "trainable": self.trainable, "seed": None, @@ -620,28 +1224,211 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridNet: coefficient_layout=str(config["coefficient_layout"]), grid_method=str(config["grid_method"]), grid_branches=int(config["grid_branches"]), + residual_scale_init=config.get("residual_scale_init"), mlp_bias=bool(config["mlp_bias"]), trainable=bool(config["trainable"]), seed=config.get("seed"), ) - prec = PRECISION_DICT[obj.precision.lower()] - weight = np.asarray(variables["scalar_gate.weight"], dtype=prec) - if weight.shape != obj.scalar_gate.weight.shape: - raise ValueError( - f"scalar_gate.weight shape {weight.shape} does not match " - f"the expected shape {obj.scalar_gate.weight.shape}" + obj._load_variables(variables) + return obj + + +class SO3GridNet(BaseGridNet): + """Grid net using a Wigner-D SO(3) projector with frame indices. + + The dpmodel port of the pt ``SO3GridNet``. ``mode='self'`` keeps the frame + axis inside the channel (width ``n_frames * channels``); ``mode='cross'`` + expands the external ``channels``-wide query/context to the frame axis via + ``FrameExpand`` and collapses the output back via ``FrameContract``. + + Parameters + ---------- + lmax : int + Maximum spherical harmonic degree. + mmax : int | None + Maximum order kept in the coefficient layout. If None, use ``lmax``. + kmax : int + Frame-index half-width; the frame set is ``{0, -1, 1, ..., -kmax, kmax}``. + channels : int + Number of channels per (l, m) coefficient. + n_focus : int + Number of focus streams. + mode : str + Pairing mode; ``"self"`` or ``"cross"``. + op_type : str + Point-wise grid operation; ``"glu"``, ``"mlp"``, or ``"branch"``. + precision : str + Parameter precision. + layout : str + Tensor layout convention: ``"ndfc"``, ``"nfdc"``, or ``"flat"`` + (``"flat"`` is cross-mode only). + lebedev_precision : int | None + Explicit Lebedev rule precision. If None, resolved automatically. + coefficient_layout : str + ``"packed"`` or ``"m_major"`` coefficient ordering. + grid_branches : int + Number of scalar-routed branches when ``op_type='branch'``. + residual_scale_init : float | None + If set, scales the grid-net output by a per-(focus, channel) parameter + initialised to this value (cross-mode use). ``None`` disables it. + mlp_bias : bool + Whether to use bias in the scalar gate projection. + trainable : bool + Whether parameters are trainable. + seed : int | list[int] | None + Random seed for weight initialization. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + channels: int, + n_focus: int = 1, + mode: str, + op_type: str, + precision: str = DEFAULT_PRECISION, + layout: str, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + grid_branches: int = 1, + residual_scale_init: float | None = None, + mlp_bias: bool = False, + trainable: bool = True, + seed: int | list[int] | None = None, + ) -> None: + projector = SO3GridProjector( + lmax=lmax, + mmax=mmax, + kmax=kmax, + precision=precision, + lebedev_precision=lebedev_precision, + coefficient_layout=coefficient_layout, + ) + self.frames = projector.frame_set + self.kmax = projector.kmax + self.lebedev_precision = projector.lebedev_precision + self.n_gamma = projector.n_gamma + self.grid_branches = int(grid_branches) + frame_expand: FrameExpand | None = None + frame_contract: FrameContract | None = None + if str(mode).lower() == "cross": + frame_expand = FrameExpand( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + precision=precision, + trainable=trainable, + seed=child_seed(seed, 4), ) - obj.scalar_gate.weight = weight - if obj.mlp_bias: - obj.scalar_gate.bias = np.asarray( - variables["scalar_gate.bias"], dtype=prec - ).reshape(obj.scalar_gate.bias.shape) - if obj.op_type == "branch": - obj.grid_op._load_variables( - { - key[len("grid_op.") :]: value - for key, value in variables.items() - if key.startswith("grid_op.") - } + frame_contract = FrameContract( + lmax=lmax, + mmax=projector.mmax, + coefficient_layout=coefficient_layout, + n_frames=projector.n_frames, + channels=channels, + precision=precision, + trainable=trainable, + seed=child_seed(seed, 5), ) + super().__init__( + projector=projector, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type=op_type, + precision=precision, + layout=layout, + mlp_bias=mlp_bias, + trainable=trainable, + grid_branches=grid_branches, + frame_expand=frame_expand, + frame_contract=frame_contract, + residual_scale_init=residual_scale_init, + seed=seed, + ) + + def serialize(self) -> dict[str, Any]: + """Serialize the SO3GridNet to a dict. + + The pt ``SO3GridNet`` has no ``serialize()``; the ``@variables`` keys + here match the pt ``state_dict`` key names (the projector matrices are + non-persistent buffers in pt and are rebuilt from the config). The + ``frame_expand``/``frame_contract`` per-degree weights are emitted as + ``frame_expand.weight``/``frame_contract.weight`` (their pt state_dict + keys); their ``degree_index`` buffers are non-persistent and rebuilt. + """ + variables = {"scalar_gate.weight": to_numpy_array(self.scalar_gate.weight)} + if self.mlp_bias: + variables["scalar_gate.bias"] = to_numpy_array(self.scalar_gate.bias) + if self.op_type in {"branch", "mlp"}: + grid_op_data = self.grid_op.serialize()["@variables"] + for key, value in grid_op_data.items(): + variables[f"grid_op.{key}"] = value + if self.frame_expand is not None: + variables["frame_expand.weight"] = to_numpy_array(self.frame_expand.weight) + if self.frame_contract is not None: + variables["frame_contract.weight"] = to_numpy_array( + self.frame_contract.weight + ) + if self.residual_scale is not None: + variables["residual_scale"] = to_numpy_array(self.residual_scale) + return { + "@class": "SO3GridNet", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.projector.mmax, + "kmax": self.kmax, + "channels": self.channels, + "n_focus": self.n_focus, + "mode": self.mode, + "op_type": self.op_type, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layout": self.layout, + "lebedev_precision": self.lebedev_precision, + "coefficient_layout": self.projector.coefficient_layout, + "grid_branches": self.grid_branches, + "residual_scale_init": self.residual_scale_init, + "mlp_bias": self.mlp_bias, + "trainable": self.trainable, + "seed": None, + }, + "@variables": variables, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3GridNet: + """Deserialize an SO3GridNet from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3GridNet": + raise ValueError(f"Invalid class for SO3GridNet: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + variables = data.pop("@variables") + obj = cls( + lmax=int(config["lmax"]), + mmax=int(config["mmax"]), + kmax=int(config["kmax"]), + channels=int(config["channels"]), + n_focus=int(config["n_focus"]), + mode=str(config["mode"]), + op_type=str(config["op_type"]), + precision=str(config["precision"]), + layout=str(config["layout"]), + lebedev_precision=int(config["lebedev_precision"]), + coefficient_layout=str(config["coefficient_layout"]), + grid_branches=int(config["grid_branches"]), + residual_scale_init=config.get("residual_scale_init"), + mlp_bias=bool(config["mlp_bias"]), + trainable=bool(config["trainable"]), + seed=config.get("seed"), + ) + obj._load_variables(variables) return obj diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py index e5131b5d12..c487e16773 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/projection.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/projection.py @@ -13,11 +13,13 @@ ``resolve_s2_grid_resolution`` (as-is, both methods — pure arithmetic), and ``_normalize_s2_grid_resolution``. -Skipped names (SO(3) Wigner-D grid machinery; consumed only by -``SO3GridNet`` in pt ``grid_net.py``, which backs the ``node_wise_so3``, -``message_node_so3``, and ``ffn_so3_grid`` paths — all disabled in the core -DPA4 config): ``SO3GridProjector``, ``resolve_so3_grid``, -``_build_so3_frame_set``. +SO(3) Wigner-D grid machinery (consumed by ``SO3GridNet`` in +``grid_net.py``, which backs the ``node_wise_so3``, ``message_node_so3``, +and ``ffn_so3_grid`` paths): ``resolve_so3_grid``, ``_build_so3_frame_set``, +and ``SO3GridProjector`` are ported here. The SO(3) projection matrices are +assembled at init time with pure numpy via the Wigner-D quadrature +(``WignerDCalculator``) over a Lebedev x gamma rotation grid, matching the pt +float64 buffers to machine precision. Not ported (guarded): the e3nn product-grid branch of ``S2GridProjector`` (``grid_method="e3nn"``, i.e. ``lebedev_quadrature=False``) raises @@ -66,6 +68,13 @@ from .indexing import ( build_l_major_index, build_m_major_index, + so3_packed_index, +) +from .wignerd import ( + WignerDCalculator, + build_edge_quaternion, + quaternion_multiply, + quaternion_z_rotation, ) @@ -303,6 +312,155 @@ def deserialize(cls, data: dict[str, Any]) -> S2GridProjector: ) +class SO3GridProjector(BaseGridProjector): + """ + Project SO(3) coefficients to/from a Wigner-D grid with frame indices. + + The coefficient axis is packed as ``(l, m, k)`` with ordinary SeZM + ``(l, m)`` order outside and the configured frame set inside each row. A + frame index outside ``[-l, l]`` is kept as a zero column/row. This keeps the + tensor layout regular while preserving the exact per-degree frame support. + + Parameters + ---------- + lmax + Maximum spherical harmonic degree. + mmax + Maximum order kept in the coefficient layout. If None, use ``lmax``. + kmax + Frame-index half-width; the frame set is ``{0, -1, 1, ..., -kmax, kmax}``. + precision + Buffer precision used by the projection matrices. + lebedev_precision + Explicit Lebedev rule precision. If None, resolved automatically. + coefficient_layout + Coefficient ordering expected by the caller: + - ``"packed"``: packed ``(l, m)`` order, optionally truncated by ``mmax``. + - ``"m_major"``: reduced m-major order used inside ``SO2Convolution``. + """ + + def __init__( + self, + *, + lmax: int, + mmax: int | None = None, + kmax: int = 1, + precision: str = DEFAULT_PRECISION, + lebedev_precision: int | None = None, + coefficient_layout: str = "packed", + ) -> None: + lmax_i = int(lmax) + mmax_i = int(lmax_i if mmax is None else mmax) + self.kmax = int(kmax) + if self.kmax < 0: + raise ValueError("`kmax` must be non-negative") + self.frame_set = _build_so3_frame_set(self.kmax) + self.frame_zero_index = self.frame_set.index(0) + self.lebedev_precision, self.lebedev_npoints, self.n_gamma = resolve_so3_grid( + lmax_i, + kmax=self.kmax, + lebedev_precision=lebedev_precision, + ) + super().__init__( + lmax=lmax_i, + mmax=mmax_i, + precision=precision, + n_frames=len(self.frame_set), + coefficient_layout=coefficient_layout, + ) + # plain numpy int64 attribute (becomes a torch buffer in pt_expt later) + self.frame_values = np.asarray(self.frame_set, dtype=np.int64) + + def _build_projection_mats( + self, + coeff_index: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + points, weights = load_lebedev_rule(self.lebedev_precision) + points = np.asarray(points, dtype=np.float64) + weights = np.asarray(weights, dtype=np.float64) + gamma = np.arange(self.n_gamma, dtype=np.float64) * ( + 2.0 * math.pi / float(self.n_gamma) + ) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + # torch ``repeat_interleave(n_gamma, dim=0)`` -> numpy ``repeat`` on axis 0 + edge_quaternion = np.repeat(edge_quaternion, self.n_gamma, axis=0) + # torch ``repeat(n_points, 1)`` (tile) -> numpy ``tile`` + gamma_quaternion = np.tile(quaternion_z_rotation(gamma), (points.shape[0], 1)) + grid_quaternion = quaternion_multiply(gamma_quaternion, edge_quaternion) + # WignerDCalculator.__call__ returns ``(D_full, Dt_full)``; take D_full. + wigner_grid = WignerDCalculator(self.lmax, precision="float64")( + grid_quaternion + )[0] + # ``build_edge_quaternion`` follows SeZM's global-to-local convention. + # The transpose below stores the local m=0 column in the same layout as + # ``WignerDCalculator.forward_zonal`` and extends it to k != 0. + wigner_grid = np.ascontiguousarray(np.swapaxes(wigner_grid, -1, -2)) + haar_weight = np.repeat(weights, self.n_gamma) / float(self.n_gamma) + + grid_size = int(grid_quaternion.shape[0]) + n_frames = len(self.frame_set) + coeff_dim = int(coeff_index.shape[0]) * n_frames + to_grid_mat = np.zeros((grid_size, coeff_dim), dtype=np.float64) + from_grid_mat = np.zeros((coeff_dim, grid_size), dtype=np.float64) + + for degree in range(self.lmax + 1): + degree_factor = float(2 * degree + 1) + for m_order in range(-degree, degree + 1): + packed_idx = so3_packed_index(degree, m_order) + # init-time numpy control-flow index; replaces pt's + # ``(coeff_index == packed_idx).nonzero(as_tuple=False)`` + coeff_positions = np.argwhere(coeff_index == packed_idx) + if coeff_positions.size == 0: + continue + coeff_pos = int(coeff_positions[0, 0]) + for frame_pos, frame_order in enumerate(self.frame_set): + flat_idx = coeff_pos * n_frames + frame_pos + if abs(frame_order) > degree: + continue + row = so3_packed_index(degree, m_order) + col = so3_packed_index(degree, frame_order) + values = wigner_grid[:, row, col] + to_grid_mat[:, flat_idx] = values + from_grid_mat[flat_idx, :] = degree_factor * haar_weight * values + return to_grid_mat, from_grid_mat + + def serialize(self) -> dict[str, Any]: + """Serialize the SO3GridProjector to a dict (pt-compatible format).""" + return { + "@class": "SO3GridProjector", + "@version": 1, + "config": { + "lmax": self.lmax, + "mmax": self.mmax, + "kmax": self.kmax, + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "lebedev_precision": self.lebedev_precision, + "coefficient_layout": self.coefficient_layout, + }, + "@variables": {}, + } + + @classmethod + def deserialize(cls, data: dict[str, Any]) -> SO3GridProjector: + """Deserialize an SO3GridProjector from a dict.""" + data = data.copy() + data_cls = data.pop("@class") + if data_cls != "SO3GridProjector": + raise ValueError(f"Invalid class for SO3GridProjector: {data_cls}") + version = int(data.pop("@version")) + check_version_compatibility(version, 1, 1) + config = data.pop("config") + data.pop("@variables", None) + return cls( + lmax=int(config["lmax"]), + mmax=int(config["mmax"]), + kmax=int(config["kmax"]), + precision=str(config["precision"]), + lebedev_precision=int(config["lebedev_precision"]), + coefficient_layout=str(config["coefficient_layout"]), + ) + + def resolve_s2_grid_resolution( lmax: int, mmax: int, @@ -370,3 +528,55 @@ def _normalize_s2_grid_resolution( if resolution[0] < 1 or resolution[1] < 1: raise ValueError("grid resolutions must be positive") return resolution + + +def resolve_so3_grid( + lmax: int, + *, + kmax: int = 1, + lebedev_precision: int | None = None, +) -> tuple[int, int, int]: + """ + Resolve the default SO(3) quadrature as Lebedev sphere times gamma samples. + + The Lebedev precision follows the same conservative ``3*lmax`` rule used by + the S2 grid path. The gamma grid is chosen for the quadratic grid products + used by the SO(3) grid nets, whose third-angle frequency can reach + ``k1 + k2 - kout``. + """ + lmax_i = int(lmax) + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + if lebedev_precision is None: + required_precision = 3 * lmax_i + for precision, n_points in LEBEDEV_PRECISION_TO_NPOINTS.items(): + if precision >= required_precision: + lebedev_precision = precision + lebedev_npoints = n_points + break + else: + raise ValueError( + f"No packaged Lebedev rule has precision >= {required_precision}" + ) + else: + lebedev_precision = int(lebedev_precision) + lebedev_npoints = LEBEDEV_PRECISION_TO_NPOINTS.get(lebedev_precision) + if lebedev_npoints is None: + raise ValueError( + f"Lebedev rule with precision {lebedev_precision} is not packaged" + ) + + # A quadratic product followed by analysis can contain gamma frequencies up + # to ``3*kmax``. A uniform grid with more samples than that frequency + # resolves the integer Fourier modes exactly. + n_gamma = 1 if kmax_i == 0 else 3 * kmax_i + 1 + return int(lebedev_precision), int(lebedev_npoints), int(n_gamma) + + +def _build_so3_frame_set(kmax: int) -> list[int]: + """Build the symmetric frame-index set with zero first.""" + kmax_i = int(kmax) + if kmax_i < 0: + raise ValueError("`kmax` must be non-negative") + return [0, *[frame for kk in range(1, kmax_i + 1) for frame in (-kk, kk)]] diff --git a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py index 1414956b13..3753d2ee21 100644 --- a/deepmd/dpmodel/descriptor/dpa4_nn/so2.py +++ b/deepmd/dpmodel/descriptor/dpa4_nn/so2.py @@ -20,10 +20,15 @@ projections, and the radial modulation) is identical to pt, just evaluated over the padded edge axis. +The cross-mode grid products (``node_wise_s2``/``node_wise_so3`` and +``message_node_s2``/``message_node_so3``) are wired in: when the matching +flags are set, ``SO2Convolution`` builds an ``S2GridNet``/``SO3GridNet`` +sub-module (SO3 wins when both are set) and applies it as a residual update in +the forward pass. + Branches guarded with ``NotImplementedError`` (flags unused by the core DPA4 config): ``so2_attn_res != "none"``, ``layer_scale``, ``s2_activation``, -``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj``, ``node_wise_s2``, -``node_wise_so3``, ``message_node_s2``, ``message_node_so3``. +``atten_f_mix``, ``atten_v_proj``, ``atten_o_proj``. """ from __future__ import ( @@ -64,6 +69,10 @@ from .attention import ( segment_envelope_gated_softmax, ) +from .grid_net import ( + S2GridNet, + SO3GridNet, +) from .indexing import ( build_m_major_index, build_m_major_l_index, @@ -899,15 +908,6 @@ def __init__( self.node_wise_so3 = bool(node_wise_so3) self.message_node_s2 = bool(message_node_s2) self.message_node_so3 = bool(message_node_so3) - if self.node_wise_s2 or self.node_wise_so3: - raise NotImplementedError( - "node_wise_s2/node_wise_so3 grid products are not ported to dpmodel" - ) - if self.message_node_s2 or self.message_node_so3: - raise NotImplementedError( - "message_node_s2/message_node_so3 grid products are not ported " - "to dpmodel" - ) self.lebedev_quadrature = bool(lebedev_quadrature) self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn" self.s2_grid_resolution = resolve_s2_grid_resolution( @@ -915,6 +915,19 @@ def __init__( self.mmax, method=self.s2_grid_method, ) + # Full-mmax resolution for the message-node S2 grid product (which + # uses ``mmax = lmax``); the node-wise S2 grid uses the m-major + # ``self.s2_grid_resolution`` above (pt so2.py:999-1008). + base_full_grid_resolution = resolve_s2_grid_resolution( + self.lmax, + self.lmax, + method=self.s2_grid_method, + ) + self.s2_full_grid_resolution = ( + [max(base_full_grid_resolution), max(base_full_grid_resolution)] + if self.s2_grid_method == "e3nn" + else base_full_grid_resolution + ) self.activation_function = str(activation_function) self.attn_n_focus = self.n_focus self.attn_focus_dim = self.so2_focus_dim @@ -964,6 +977,8 @@ def __init__( seed_gate = child_seed(seed, 4) seed_radial_hidden = child_seed(seed, 6) seed_radial_degree = child_seed(seed, 7) + seed_node_wise_s2 = child_seed(seed, 8) + seed_message_node_s2 = child_seed(seed, 9) # === Step 3. Multiple SO2Linear layers === # (s2_activation is guarded above, so out_channels == so2_focus_dim.) @@ -1127,6 +1142,101 @@ def __init__( trainable=trainable, ) + # === Step 8b. Optional cross-mode grid products === + # ``node_wise``: query=source-rotated, context=destination-rotated. + # ``message_node``: query=aggregated message, context=node features. + # When both ``*_s2`` and ``*_so3`` are set, the SO3 net wins (pt + # argcheck doc + pt so2.py:1424-1500). + node_wise_op = ( + "branch" + if self.node_wise_grid_branch > 0 + else ("mlp" if self.node_wise_grid_mlp else "glu") + ) + node_wise_branches = max(1, self.node_wise_grid_branch) + message_node_op = ( + "branch" + if self.message_node_grid_branch > 0 + else ("mlp" if self.message_node_grid_mlp else "glu") + ) + message_node_branches = max(1, self.message_node_grid_branch) + self.node_wise_grid_product: S2GridNet | SO3GridNet | None = None + if self.node_wise_s2 or self.node_wise_so3: + if self.node_wise_so3: + self.node_wise_grid_product = SO3GridNet( + lmax=self.lmax, + mmax=self.mmax, + kmax=self.kmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + precision=self.compute_precision, + layout="flat", + coefficient_layout="m_major", + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + else: + self.node_wise_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.mmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=node_wise_op, + precision=self.compute_precision, + layout="flat", + grid_resolution_list=self.s2_grid_resolution, + coefficient_layout="m_major", + grid_method=self.s2_grid_method, + grid_branches=node_wise_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_node_wise_s2, + ) + self.message_node_grid_product: S2GridNet | SO3GridNet | None = None + if self.message_node_s2 or self.message_node_so3: + if self.message_node_so3: + self.message_node_grid_product = SO3GridNet( + lmax=self.lmax, + kmax=self.kmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + precision=self.compute_precision, + layout="flat", + coefficient_layout="packed", + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + seed=seed_message_node_s2, + ) + else: + self.message_node_grid_product = S2GridNet( + lmax=self.lmax, + mmax=self.lmax, + channels=self.so2_focus_dim, + n_focus=self.n_focus, + mode="cross", + op_type=message_node_op, + precision=self.compute_precision, + layout="flat", + grid_resolution_list=self.s2_full_grid_resolution, + grid_method=self.s2_grid_method, + grid_branches=message_node_branches, + mlp_bias=self.mlp_bias, + residual_scale_init=1e-3, + trainable=trainable, + coefficient_layout="packed", + seed=seed_message_node_s2, + ) + # === Step 9. Pre-focus channel mixing === # This projects the full channel width before the SO(2) focus split. self.pre_focus_mix = SO3Linear( @@ -1220,6 +1330,13 @@ def call( src_idx = xp.astype(xp.reshape(src, (n_edge,)), xp.int64) x_src = xp.take(x, src_idx, axis=0) # (E, D, C_wide) x_local = xp.matmul(D_m_prime, x_src) # (E, D_m, C_wide) + # ``node_wise`` grid product needs the destination-rotated features + # (query=source-rotated, context=destination-rotated); pt so2.py:1584. + x_dst_local: Any = None + if self.node_wise_grid_product is not None: + dst_idx_nw = xp.astype(xp.reshape(dst, (n_edge,)), xp.int64) + x_dst = xp.take(x, dst_idx_nw, axis=0) # (E, D, C_wide) + x_dst_local = xp.matmul(D_m_prime, x_dst) # (E, D_m, C_wide) # === Step 3. Select radial/type features for reduced layout === degree_index_m = xp_asarray_nodetach(xp, self.degree_index_m, device=device) @@ -1230,6 +1347,10 @@ def call( x_local = x_local * rad_feat else: x_local = self.radial_degree_mixer(x_local, rad_feat) + # ``node_wise`` cross-mode grid product (pt so2.py:1597-1601): a residual + # update in the m-major reduced "flat" layout (E, D_m, C_wide). + if self.node_wise_grid_product is not None: + x_local = x_local + self.node_wise_grid_product(x_local, x_dst_local) rad_feat_l0_focus = xp.reshape( rad_feat[:, 0, :], (n_edge, self.n_focus, self.so2_focus_dim) ) # (E, F, Cf) @@ -1480,6 +1601,13 @@ def apply_bias_correction( x.dtype, ) # (N, D, C_wide) + # === Step 9. Optional message-node grid product === + # Cross-mode grid product in the full packed SO(3) "flat" node layout + # (N, D, C_wide): query=aggregated message, context=node features + # (pt so2.py:1881-1883). + if self.message_node_grid_product is not None: + out = out + self.message_node_grid_product(out, x) + # === Step 10. Final channel mixing === out = self.post_focus_mix(out[:, :, None, :])[:, :, 0, :] return out # (N, D, C) @@ -1530,6 +1658,13 @@ def _variables(self) -> dict[str, np.ndarray]: if self.radial_degree_mixer is not None: for key, value in self.radial_degree_mixer._variables().items(): variables[f"radial_degree_mixer.{key}"] = value + for name, grid_product in ( + ("node_wise_grid_product", self.node_wise_grid_product), + ("message_node_grid_product", self.message_node_grid_product), + ): + if grid_product is not None: + for key, value in grid_product.serialize()["@variables"].items(): + variables[f"{name}.{key}"] = value for name, mix in ( ("pre_focus_mix", self.pre_focus_mix), ("post_focus_mix", self.post_focus_mix), @@ -1689,6 +1824,12 @@ def sub_vars(prefix: str) -> dict[str, Any]: ) if self.radial_degree_mixer is not None: self.radial_degree_mixer._load_variables(sub_vars("radial_degree_mixer")) + for name, grid_product in ( + ("node_wise_grid_product", self.node_wise_grid_product), + ("message_node_grid_product", self.message_node_grid_product), + ): + if grid_product is not None: + grid_product._load_variables(sub_vars(name)) for name, mix in ( ("pre_focus_mix", self.pre_focus_mix), ("post_focus_mix", self.post_focus_mix), diff --git a/source/tests/common/dpmodel/test_descrpt_dpa4.py b/source/tests/common/dpmodel/test_descrpt_dpa4.py index 43c882864a..a4e41887d8 100644 --- a/source/tests/common/dpmodel/test_descrpt_dpa4.py +++ b/source/tests/common/dpmodel/test_descrpt_dpa4.py @@ -164,9 +164,6 @@ def test_masked_edge_inertness(self) -> None: ("block_attn_res", "dependent"), # DepthAttnRes ("so2_attn_res", "independent"), # SO(2) DepthAttnRes ("s2_activation", [True, True]), # so2-side S2 activation - ("node_wise_s2", True), # SO(2) cross-grid product - ("message_node_so3", True), # SO(2) cross-grid product - ("ffn_so3_grid", True), # SO(3) Wigner-D FFN grid ("atten_f_mix", True), # SO(2) attention focus mix ("atten_v_proj", True), # SO(2) attention value projection ("atten_o_proj", True), # SO(2) attention output projection diff --git a/source/tests/common/dpmodel/test_dpa4_frame_mixers.py b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py new file mode 100644 index 0000000000..d01ef73eed --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_frame_mixers.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 per-degree frame channel mixers. + +Compares the dpmodel ports of ``FrameContract``, ``FrameExpand`` and +``_build_frame_degree_index`` against the reference pt implementations using +weight-copied fp64 parity. All pt imports are kept inside the test functions +(ruff TID253). +""" + +import numpy as np +import pytest +from numpy.testing import ( + assert_allclose, +) + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + FrameContract, + FrameExpand, + _build_frame_degree_index, +) + +# (lmax, channels, kmax) combinations; n_frames = 2 * kmax + 1 +PARAMS = [ + (2, 4, 1), + (3, 2, 2), +] + + +def _copy_weight(dp_obj, pt_obj) -> None: + """Copy the pt weight tensor into the dpmodel numpy attribute.""" + dp_obj.weight = pt_obj.weight.detach().cpu().numpy().astype(np.float64) + + +@pytest.mark.parametrize("lmax,channels,kmax", PARAMS) # lmax, channels, kmax +def test_frame_contract_parity(lmax, channels, kmax) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + FrameContract as PTFrameContract, + ) + + n_frames = 2 * kmax + 1 + pt_obj = PTFrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + dtype=torch.float64, + trainable=True, + seed=0, + ).to("cpu") + dp_obj = FrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=0, + ) + _copy_weight(dp_obj, pt_obj) + + n_batch, n_focus = 3, 2 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(123) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, n_frames * channels)).astype( + np.float64 + ) + + dp_out = dp_obj.call(x) + pt_out = pt_obj(torch.from_numpy(x)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, channels) + assert_allclose(np.asarray(dp_out), pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("lmax,channels,kmax", PARAMS) # lmax, channels, kmax +def test_frame_expand_parity(lmax, channels, kmax) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import FrameExpand as PTFrameExpand + + n_frames = 2 * kmax + 1 + pt_obj = PTFrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + dtype=torch.float64, + trainable=True, + seed=0, + ).to("cpu") + dp_obj = FrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=0, + ) + _copy_weight(dp_obj, pt_obj) + + n_batch, n_focus = 3, 2 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(321) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + + dp_out = dp_obj.call(x) + pt_out = pt_obj(torch.from_numpy(x)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + assert_allclose(np.asarray(dp_out), pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("lmax,channels,kmax", PARAMS) # lmax, channels, kmax +def test_expand_then_contract_shapes(lmax, channels, kmax) -> None: + n_frames = 2 * kmax + 1 + expand = FrameExpand( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=0, + ) + contract = FrameContract( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=1, + ) + n_batch, n_focus = 3, 2 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(7) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + expanded = expand.call(x) + assert expanded.shape == (n_batch, coeff_dim, n_focus, n_frames * channels) + contracted = contract.call(expanded) + assert contracted.shape == (n_batch, coeff_dim, n_focus, channels) + + +@pytest.mark.parametrize("cls", [FrameContract, FrameExpand]) # mixer class +def test_serialize_roundtrip(cls) -> None: + lmax, channels, kmax = 2, 4, 1 + n_frames = 2 * kmax + 1 + obj = cls( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=5, + ) + data = obj.serialize() + assert data["@version"] == 1 + obj2 = cls.deserialize(data) + assert_allclose(obj2.weight, obj.weight, atol=1e-12, rtol=1e-12) + np.testing.assert_array_equal(obj2.degree_index, obj.degree_index) + + n_batch, n_focus = 2, 2 + coeff_dim = (lmax + 1) ** 2 + in_ch = channels if cls is FrameExpand else n_frames * channels + rng = np.random.default_rng(11) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, in_ch)).astype(np.float64) + assert_allclose( + np.asarray(obj2.call(x)), + np.asarray(obj.call(x)), + atol=1e-12, + rtol=1e-12, + ) + + +@pytest.mark.parametrize("lmax", [1, 2, 3]) # max angular momentum +def test_degree_index(lmax) -> None: + from deepmd.pt.model.descriptor.sezm_nn.grid_net import ( + _build_frame_degree_index as pt_build_frame_degree_index, + ) + + dp_idx = _build_frame_degree_index( + lmax=lmax, mmax=lmax, coefficient_layout="packed" + ) + pt_idx = ( + pt_build_frame_degree_index(lmax=lmax, mmax=lmax, coefficient_layout="packed") + .detach() + .cpu() + .numpy() + ) + assert dp_idx.shape == ((lmax + 1) ** 2,) + np.testing.assert_array_equal(dp_idx, pt_idx) + # each (l, m) row maps to degree l: row d has degree dp_idx[d] + expected = np.concatenate( + [np.full(2 * l + 1, l, dtype=np.int64) for l in range(lmax + 1)] + ) + np.testing.assert_array_equal(dp_idx, expected) + + +@pytest.mark.parametrize("cls", [FrameContract, FrameExpand]) # mixer class +def test_torch_namespace_smoke(cls) -> None: + import torch + + lmax, channels, kmax = 2, 4, 1 + n_frames = 2 * kmax + 1 + obj = cls( + lmax=lmax, + mmax=lmax, + coefficient_layout="packed", + n_frames=n_frames, + channels=channels, + precision="float64", + trainable=True, + seed=9, + ) + n_batch, n_focus = 2, 2 + coeff_dim = (lmax + 1) ** 2 + in_ch = channels if cls is FrameExpand else n_frames * channels + rng = np.random.default_rng(13) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, in_ch)).astype(np.float64) + np_out = np.asarray(obj.call(x)) + torch_out = obj.call(torch.from_numpy(x)).detach().cpu().numpy() + assert_allclose(torch_out, np_out, atol=1e-12, rtol=1e-12) diff --git a/source/tests/common/dpmodel/test_dpa4_grid_mlp.py b/source/tests/common/dpmodel/test_dpa4_grid_mlp.py new file mode 100644 index 0000000000..7a36297aaa --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_grid_mlp.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 ``GridMLP`` grid op and the ``op_type='mlp'`` path. + +Compares the dpmodel port of ``GridMLP`` (and a full ``S2GridNet`` with +``op_type='mlp'``) against the reference pt implementations using weight-copied +fp64 parity, plus an SO(3) equivariance check for the mlp S2 grid net. All pt +imports are kept inside the test functions (ruff TID253). +""" + +import numpy as np +import pytest +from numpy.testing import ( + assert_allclose, +) + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + GridMLP, + S2GridNet, +) + + +def _rotate_ndfc(x: np.ndarray, d_matrix: np.ndarray) -> np.ndarray: + """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" + return np.einsum("nij,njfc->nifc", d_matrix, x) + + +def _random_quaternion(n_batch: int, seed: int) -> np.ndarray: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + rng = np.random.default_rng(seed) + q = rng.standard_normal((n_batch, 4)).astype(np.float64) + return q / np.sqrt(np.sum(q * q, axis=-1, keepdims=True)) + + +def _copy_grid_mlp(dp_obj: GridMLP, pt_obj) -> None: + """Copy pt ``GridMLP`` projection weights into the dpmodel attributes.""" + sd = pt_obj.state_dict() + for name in ("left_proj", "right_proj", "out_proj"): + getattr(dp_obj, name).weight = ( + sd[f"{name}.weight"].detach().cpu().numpy().astype(np.float64) + ) + + +def _copy_s2gridnet_mlp(dp_net: S2GridNet, pt_net) -> None: + """Copy a pt mlp ``S2GridNet`` scalar gate + grid op weights into dpmodel.""" + sd = pt_net.state_dict() + dp_net.scalar_gate.weight = ( + sd["scalar_gate.weight"].detach().cpu().numpy().astype(np.float64) + ) + for name in ("left_proj", "right_proj", "out_proj"): + getattr(dp_net.grid_op, name).weight = ( + sd[f"grid_op.{name}.weight"].detach().cpu().numpy().astype(np.float64) + ) + + +@pytest.mark.parametrize("channels", [2, 4]) # channels per grid point +def test_grid_mlp_parity_self(channels) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridMLP as PTGridMLP + + pt_obj = PTGridMLP( + channels=channels, + mode="self", + dtype=torch.float64, + trainable=True, + seed=0, + ).to("cpu") + dp_obj = GridMLP( + channels=channels, + mode="self", + precision="float64", + trainable=True, + seed=0, + ) + _copy_grid_mlp(dp_obj, pt_obj) + + n_batch, n_grid, n_focus = 3, 5, 2 + rng = np.random.default_rng(123) + q = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + + dp_out = dp_obj.call(q, c) + pt_out = pt_obj(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, n_grid, n_focus, channels) + assert_allclose(np.asarray(dp_out), pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("channels", [2, 4]) # channels per grid point +def test_grid_mlp_parity_cross(channels) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import GridMLP as PTGridMLP + + pt_obj = PTGridMLP( + channels=channels, + mode="cross", + dtype=torch.float64, + trainable=True, + seed=1, + ).to("cpu") + dp_obj = GridMLP( + channels=channels, + mode="cross", + precision="float64", + trainable=True, + seed=1, + ) + _copy_grid_mlp(dp_obj, pt_obj) + + n_batch, n_grid, n_focus = 3, 5, 2 + rng = np.random.default_rng(321) + q = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + + dp_out = dp_obj.call(q, c) + pt_out = pt_obj(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, n_grid, n_focus, channels) + assert_allclose(np.asarray(dp_out), pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_grid_mlp_serialize_roundtrip(mode) -> None: + channels = 4 + obj = GridMLP( + channels=channels, + mode=mode, + precision="float64", + trainable=True, + seed=5, + ) + data = obj.serialize() + assert data["@version"] == 1 + obj2 = GridMLP.deserialize(data) + + n_batch, n_grid, n_focus = 2, 4, 2 + rng = np.random.default_rng(11) + q = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + assert_allclose( + np.asarray(obj2.call(q, c)), + np.asarray(obj.call(q, c)), + atol=1e-12, + rtol=1e-12, + ) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_grid_mlp_torch_namespace(mode) -> None: + import torch + + channels = 4 + obj = GridMLP( + channels=channels, + mode=mode, + precision="float64", + trainable=True, + seed=9, + ) + n_batch, n_grid, n_focus = 2, 4, 2 + rng = np.random.default_rng(13) + q = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, n_grid, n_focus, channels)).astype(np.float64) + np_out = np.asarray(obj.call(q, c)) + torch_out = ( + obj.call(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + ) + assert_allclose(torch_out, np_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("lmax,channels", [(2, 2), (3, 2)]) # lmax, channels +def test_s2gridnet_op_type_mlp_parity(lmax, channels) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + n_focus = 1 + pt_net = PTS2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="mlp", + dtype=torch.float64, + layout="ndfc", + coefficient_layout="packed", + grid_method="lebedev", + trainable=False, + seed=17 + lmax, + ).to("cpu") + dp_net = S2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="mlp", + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_method="lebedev", + trainable=False, + seed=17 + lmax, + ) + _copy_s2gridnet_mlp(dp_net, pt_net) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(2024 + lmax) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, 2 * channels)).astype( + np.float64 + ) + + dp_out = np.asarray(dp_net.call(x)) + pt_out = pt_net(torch.from_numpy(x)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("lmax,channels", [(2, 2), (3, 2)]) # lmax, channels +def test_s2gridnet_op_type_mlp_equivariance(lmax, channels) -> None: + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + n_focus = 1 + dp_net = S2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type="mlp", + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_method="lebedev", + trainable=False, + seed=31 + lmax, + ) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(4096 + lmax) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, 2 * channels)).astype( + np.float64 + ) + + quat = _random_quaternion(n_batch, seed=77 + lmax) + d_matrix, _ = WignerDCalculator(lmax=lmax, precision="float64").call(quat) + d_matrix = np.asarray(d_matrix) + + y_rotated_input = np.asarray(dp_net.call(_rotate_ndfc(x, d_matrix))) + y_then_rotated = _rotate_ndfc(np.asarray(dp_net.call(x)), d_matrix) + max_error = float(np.max(np.abs(y_rotated_input - y_then_rotated))) + assert max_error <= 1e-10, f"equivariance error {max_error}" diff --git a/source/tests/common/dpmodel/test_dpa4_grid_wiring.py b/source/tests/common/dpmodel/test_dpa4_grid_wiring.py new file mode 100644 index 0000000000..44977a6922 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_grid_wiring.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for wiring the SO3/S2 grid nets into the DPA4 FFN + SO2Convolution. + +These exercise the previously-guarded grid paths: + +- ``ffn_so3_grid=True`` -> ``EquivariantFFN`` builds an ``SO3GridNet`` (self + mode) for the equivariant nonlinearity. +- ``node_wise_so3``/``message_node_so3`` -> ``SO2Convolution`` builds a + cross-mode ``SO3GridNet`` (SO3 wins over S2 when both are set). +- ``node_wise_s2``/``message_node_s2`` -> cross-mode ``S2GridNet``. + +The flagship config (``examples/water/dpa4/input.json``) uses ``lmax=3, +mmax=1`` with both ``ffn_so3_grid`` and ``message_node_so3`` on, so the +SO3 grid path is exercised with a truncated ``mmax``. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4 import ( + DescrptDPA4, +) + + +def build_neighbor_list_np(coord, rcut, nnei): + """Build a padded, distance-sorted gas-phase (local) neighbor list.""" + nf, nloc, _ = coord.shape + nlist = -np.ones((nf, nloc, nnei), dtype=np.int64) + for f in range(nf): + dist = np.linalg.norm(coord[f][:, None, :] - coord[f][None, :, :], axis=-1) + for i in range(nloc): + neighbors = [ + (dist[i, j], j) for j in range(nloc) if j != i and dist[i, j] < rcut + ] + neighbors.sort() + for slot, (_, j) in enumerate(neighbors[:nnei]): + nlist[f, i, slot] = j + return nlist + + +def make_inputs(seed=5, nf=2, nloc=6, rcut=6.0, nnei=8, ntypes=2): + rng = np.random.default_rng(seed) + coord = rng.uniform(0.0, 4.0, size=(nf, nloc, 3)) + atype = rng.integers(0, ntypes, size=(nf, nloc)) + nlist = build_neighbor_list_np(coord, rcut, nnei) + return coord, atype, nlist + + +# Small flagship-shaped config: keep lmax=3, mmax=1 so the truncated-mmax +# SO3 grid path is exercised; shrink channels/sel/blocks for speed. +def _base_kwargs(**overrides): + kwargs = { + "ntypes": 2, + "sel": 8, + "rcut": 6.0, + "channels": 16, + "n_radial": 8, + "lmax": 3, + "mmax": 1, + "n_blocks": 2, + "so2_layers": 2, + "n_focus": 2, + "radial_so2_mode": "degree_channel", + "radial_so2_rank": 1, + "n_atten_head": 1, + "grid_mlp": [False, False, False], + "grid_branch": [1, 1, 1], + "lebedev_quadrature": True, + "precision": "float64", + "seed": 42, + } + kwargs.update(overrides) + return kwargs + + +def make_descriptor(**overrides) -> DescrptDPA4: + return DescrptDPA4(**_base_kwargs(**overrides)) + + +class TestGridWiringConstructsAndRuns: + @pytest.mark.parametrize( + "flags", + [ + {"ffn_so3_grid": True}, # SO3 self-mode FFN grid only + {"message_node_so3": True}, # SO3 cross-mode message-node grid only + {"ffn_so3_grid": True, "message_node_so3": True}, # both + {"node_wise_so3": True}, # SO3 cross-mode node-wise grid only + ], + ) + def test_descriptor_constructs_and_runs(self, flags) -> None: + dd = make_descriptor(**flags) + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + out = dd.call(coord.reshape(nf, -1), atype, nlist, mapping=None) + assert out[0].shape == (nf, nloc, dd.get_dim_out()) + assert out[1:] == (None, None, None, None) + assert np.isfinite(np.asarray(out[0])).all() + + def test_serialize_roundtrip(self) -> None: + dd = make_descriptor(ffn_so3_grid=True, message_node_so3=True) + data = dd.serialize() + dd2 = DescrptDPA4.deserialize(data) + coord, atype, nlist = make_inputs() + nf = atype.shape[0] + out1 = np.asarray(dd.call(coord.reshape(nf, -1), atype, nlist)[0]) + out2 = np.asarray(dd2.call(coord.reshape(nf, -1), atype, nlist)[0]) + np.testing.assert_array_equal(out1, out2) + + +class TestS2CrossPath: + @pytest.mark.parametrize( + "flags", + [ + {"node_wise_s2": True}, # S2 cross-mode node-wise grid (m-major) + {"message_node_s2": True}, # S2 cross-mode message-node grid (packed) + {"node_wise_s2": True, "message_node_s2": True}, # both S2 paths + ], + ) + def test_s2_cross_constructs_and_runs(self, flags) -> None: + dd = make_descriptor(**flags) + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + out = dd.call(coord.reshape(nf, -1), atype, nlist, mapping=None) + assert out[0].shape == (nf, nloc, dd.get_dim_out()) + assert np.isfinite(np.asarray(out[0])).all() + + def test_so3_wins_over_s2(self) -> None: + # When both s2 and so3 are set for a path, the SO3 net must be built. + from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + SO3GridNet, + ) + + dd = make_descriptor(message_node_s2=True, message_node_so3=True) + block = dd.blocks[0] + assert isinstance(block.so2_conv.message_node_grid_product, SO3GridNet) + + +# all three grid paths at once: S2 cross-mode node-wise grid, SO3 cross-mode +# message-node grid, SO3 self-mode FFN grid. +_ALL_GRID_FLAGS = { + "node_wise_s2": True, + "message_node_so3": True, + "ffn_so3_grid": True, +} + + +class TestGridInvariance: + """Physical invariances of the descriptor with the grid paths enabled. + + Exercises all three grid paths at once so a missing edge mask or a + neighbor-order dependence in any of them surfaces here. + """ + + def test_permutation_invariance(self) -> None: + # permuting the neighbor (slot) order within each atom's nlist row must + # not change the per-atom descriptor: the grid products aggregate over + # neighbors symmetrically. + dd = make_descriptor(**_ALL_GRID_FLAGS) + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + out = np.asarray(dd.call(coord.reshape(nf, -1), atype, nlist)[0]) + rng = np.random.default_rng(17) + nnei = nlist.shape[-1] + perm = rng.permutation(nnei) + nlist_perm = nlist[:, :, perm] + out_perm = np.asarray(dd.call(coord.reshape(nf, -1), atype, nlist_perm)[0]) + np.testing.assert_allclose(out_perm, out, rtol=1e-10, atol=1e-12) + + def test_masked_edge_noop(self) -> None: + # an extra all-(-1) neighbor column must not change the descriptor with + # the grid paths on (the cross-mode grid products must respect the edge + # mask of the padded layout). + dd = make_descriptor(**_ALL_GRID_FLAGS) + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + out = np.asarray(dd.call(coord.reshape(nf, -1), atype, nlist)[0]) + pad = -np.ones((nf, nloc, 1), dtype=nlist.dtype) + nlist2 = np.concatenate([nlist, pad], axis=-1) + out2 = np.asarray(dd.call(coord.reshape(nf, -1), atype, nlist2)[0]) + np.testing.assert_allclose(out2, out, rtol=1e-10, atol=1e-12) + + +class TestDescriptorParityVsPt: + """Weight-copy parity: build pt ``DescrptSeZM``, copy into dpmodel via + ``DescrptDPA4.deserialize(pt.serialize())``, compare descriptor outputs. + """ + + def _build_pair(self, perturb_seed=2130, **overrides): + import torch + + from deepmd.pt.model.descriptor.sezm import ( + DescrptSeZM, + ) + from deepmd.pt.utils import env as pt_env + + kwargs = _base_kwargs(**overrides) + pt_mod = DescrptSeZM(**kwargs).double().eval() + rng = np.random.default_rng(perturb_seed) + with torch.no_grad(): + for p in pt_mod.parameters(): + p += torch.from_numpy(0.05 * rng.normal(size=tuple(p.shape))).to( + pt_env.DEVICE + ) + dp_mod = DescrptDPA4.deserialize(pt_mod.serialize()) + return pt_mod, dp_mod + + def _assert_parity(self, pt_mod, dp_mod) -> None: + import torch + + from deepmd.pt.utils import env as pt_env + + # CPU: pt fp64 == numpy fp64 to ~1 ulp -> rtol 1e-10; CUDA index_add_ + # atomics are nondeterministic -> still 1e-10 is well below any logic bug. + coord, atype, nlist = make_inputs() + nf, nloc = atype.shape + out_dp = dp_mod.call(coord.reshape(nf, -1), atype, nlist, mapping=None) + out_pt = pt_mod( + torch.from_numpy(coord).to(pt_env.DEVICE), + torch.from_numpy(atype).to(pt_env.DEVICE), + torch.from_numpy(nlist).to(pt_env.DEVICE), + mapping=None, + ) + assert tuple(out_dp[0].shape) == tuple(out_pt[0].shape) + np.testing.assert_allclose( + np.asarray(out_dp[0]), + out_pt[0].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-12, + ) + + def test_parity_ffn_so3_grid(self) -> None: + pt_mod, dp_mod = self._build_pair(ffn_so3_grid=True) + self._assert_parity(pt_mod, dp_mod) + + def test_parity_message_node_so3(self) -> None: + pt_mod, dp_mod = self._build_pair(message_node_so3=True) + self._assert_parity(pt_mod, dp_mod) + + def test_parity_both_so3(self) -> None: + pt_mod, dp_mod = self._build_pair(ffn_so3_grid=True, message_node_so3=True) + self._assert_parity(pt_mod, dp_mod) + + def test_parity_node_wise_s2(self) -> None: + pt_mod, dp_mod = self._build_pair(node_wise_s2=True) + self._assert_parity(pt_mod, dp_mod) diff --git a/source/tests/common/dpmodel/test_dpa4_gridnet_cross.py b/source/tests/common/dpmodel/test_dpa4_gridnet_cross.py new file mode 100644 index 0000000000..079245972a --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_gridnet_cross.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity / equivariance tests for the DPA4 ``BaseGridNet`` cross-mode path. + +Covers the dpmodel port of ``mode='cross'`` (with ``layout='flat'`` and +``residual_scale_init``) using the S2 projector (``n_frames == 1``). All pt +imports are kept inside the test functions (ruff TID253). + +Test menu: + +* ``test_s2_cross_parity`` -- weight-copied fp64 forward parity (glu/mlp/branch). +* ``test_s2_cross_equivariance`` -- rotate query & context, SO(3) equivariance. +* ``test_layout_flat_parity`` -- ``layout='flat'`` parity. +* ``test_residual_scale_parity`` -- ``residual_scale_init`` parity + serialize. +* ``test_self_mode_regression`` -- the existing self-mode path still matches pt. +* ``test_torch_namespace`` -- cross-mode ``.call`` on torch inputs matches numpy. +""" + +import numpy as np +import pytest +from numpy.testing import ( + assert_allclose, +) + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet, +) + + +def _rotate_ndfc(x: np.ndarray, d_matrix: np.ndarray) -> np.ndarray: + """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" + return np.einsum("nij,njfc->nifc", d_matrix, x) + + +def _random_quaternion(n_batch: int, seed: int) -> np.ndarray: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + rng = np.random.default_rng(seed) + q = rng.standard_normal((n_batch, 4)).astype(np.float64) + return q / np.sqrt(np.sum(q * q, axis=-1, keepdims=True)) + + +def _copy_s2gridnet(dp_net: S2GridNet, pt_net) -> None: + """Copy a pt ``S2GridNet`` scalar gate + grid op (+ residual) into dpmodel.""" + sd = pt_net.state_dict() + + def _np(key): + return sd[key].detach().cpu().numpy().astype(np.float64) + + dp_net.scalar_gate.weight = _np("scalar_gate.weight") + if "scalar_gate.bias" in sd: + dp_net.scalar_gate.bias = _np("scalar_gate.bias").reshape( + dp_net.scalar_gate.bias.shape + ) + if dp_net.op_type == "mlp": + for name in ("left_proj", "right_proj", "out_proj"): + getattr(dp_net.grid_op, name).weight = _np(f"grid_op.{name}.weight") + elif dp_net.op_type == "branch": + for name in ("left_proj", "right_proj", "router", "out_proj"): + getattr(dp_net.grid_op, name).weight = _np(f"grid_op.{name}.weight") + if dp_net.residual_scale is not None: + dp_net.residual_scale = _np("residual_scale").reshape( + dp_net.residual_scale.shape + ) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +@pytest.mark.parametrize("lmax,channels", [(2, 2), (3, 2)]) # lmax, channels +def test_s2_cross_parity(op_type, lmax, channels) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + n_focus = 1 + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": "cross", + "op_type": op_type, + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": 2, + "trainable": False, + "seed": 17 + lmax, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = S2GridNet(precision="float64", **common) + _copy_s2gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(2024 + lmax) + q = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + + dp_out = np.asarray(dp_net.call(q, c)) + pt_out = pt_net(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +@pytest.mark.parametrize("lmax,channels", [(2, 2), (3, 2)]) # lmax, channels +def test_s2_cross_equivariance(op_type, lmax, channels) -> None: + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + n_focus = 1 + dp_net = S2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="cross", + op_type=op_type, + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_method="lebedev", + grid_branches=2, + trainable=False, + seed=31 + lmax, + ) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(4096 + lmax) + q = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + + quat = _random_quaternion(n_batch, seed=77 + lmax) + d_matrix, _ = WignerDCalculator(lmax=lmax, precision="float64").call(quat) + d_matrix = np.asarray(d_matrix) + + y_rotated_input = np.asarray( + dp_net.call(_rotate_ndfc(q, d_matrix), _rotate_ndfc(c, d_matrix)) + ) + y_then_rotated = _rotate_ndfc(np.asarray(dp_net.call(q, c)), d_matrix) + max_error = float(np.max(np.abs(y_rotated_input - y_then_rotated))) + assert max_error <= 1e-10, f"equivariance error {max_error}" + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +def test_layout_flat_parity(op_type) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + lmax, channels, n_focus = 2, 2, 2 + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": "cross", + "op_type": op_type, + "layout": "flat", + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": 2, + "trainable": False, + "seed": 53, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = S2GridNet(precision="float64", **common) + _copy_s2gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(909) + # flat layout: (N, D, F * C) + q = rng.standard_normal((n_batch, coeff_dim, n_focus * channels)).astype(np.float64) + c = rng.standard_normal((n_batch, coeff_dim, n_focus * channels)).astype(np.float64) + + dp_out = np.asarray(dp_net.call(q, c)) + pt_out = pt_net(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus * channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize( + "residual_scale_init", [None, 0.5] +) # residual-scale initial value +def test_residual_scale_parity(residual_scale_init) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + lmax, channels, n_focus = 2, 2, 1 + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": "cross", + "op_type": "glu", + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_method": "lebedev", + "residual_scale_init": residual_scale_init, + "trainable": False, + "seed": 71, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = S2GridNet(precision="float64", **common) + _copy_s2gridnet(dp_net, pt_net) + + if residual_scale_init is None: + assert dp_net.residual_scale is None + else: + assert dp_net.residual_scale is not None + assert dp_net.residual_scale.shape == (n_focus, channels) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(606) + q = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + + dp_out = np.asarray(dp_net.call(q, c)) + pt_out = pt_net(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + # residual_scale must survive serialize/deserialize. + data = dp_net.serialize() + if residual_scale_init is None: + assert "residual_scale" not in data["@variables"] + else: + assert "residual_scale" in data["@variables"] + dp_net2 = S2GridNet.deserialize(data) + if residual_scale_init is None: + assert dp_net2.residual_scale is None + else: + assert dp_net2.residual_scale is not None + assert_allclose( + np.asarray(dp_net2.residual_scale), + np.asarray(dp_net.residual_scale), + atol=1e-12, + rtol=1e-12, + ) + assert_allclose( + np.asarray(dp_net2.call(q, c)), + dp_out, + atol=1e-12, + rtol=1e-12, + ) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +def test_self_mode_regression(op_type) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + lmax, channels, n_focus = 2, 2, 1 + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": "self", + "op_type": op_type, + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": 2, + "trainable": False, + "seed": 19, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = S2GridNet(precision="float64", **common) + _copy_s2gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(135) + x = rng.standard_normal((n_batch, coeff_dim, n_focus, 2 * channels)).astype( + np.float64 + ) + + dp_out = np.asarray(dp_net.call(x)) + pt_out = pt_net(torch.from_numpy(x)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +def test_torch_namespace(op_type) -> None: + import torch + + lmax, channels, n_focus = 2, 2, 1 + dp_net = S2GridNet( + lmax=lmax, + channels=channels, + n_focus=n_focus, + mode="cross", + op_type=op_type, + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_method="lebedev", + grid_branches=2, + residual_scale_init=0.5, + trainable=False, + seed=23, + ) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(246) + q = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + c = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype(np.float64) + + np_out = np.asarray(dp_net.call(q, c)) + torch_out = ( + dp_net.call(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + ) + assert_allclose(torch_out, np_out, atol=1e-12, rtol=1e-12) diff --git a/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py new file mode 100644 index 0000000000..1e29eeb4e2 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_grid_utils.py @@ -0,0 +1,81 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the DPA4 SO(3) grid utility functions. + +Compares the dpmodel ports of ``resolve_so3_grid`` and ``_build_so3_frame_set`` +against the reference pt implementations. +""" + +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.projection import ( + _build_so3_frame_set, + resolve_so3_grid, +) + + +@pytest.mark.parametrize("kmax", [0, 1, 2, 3]) # frame-index half-width +def test_build_so3_frame_set(kmax) -> None: + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + _build_so3_frame_set as pt_build_so3_frame_set, + ) + + assert _build_so3_frame_set(kmax) == pt_build_so3_frame_set(kmax) + if kmax == 2: + assert _build_so3_frame_set(kmax) == [0, -1, 1, -2, 2] + + +@pytest.mark.parametrize( + "lmax,kmax", # max angular momentum, frame-index half-width + [(1, 1), (2, 1), (2, 2), (3, 1), (3, 2)], +) +def test_resolve_so3_grid(lmax, kmax) -> None: + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + dp_result = resolve_so3_grid(lmax, kmax=kmax) + pt_result = pt_resolve_so3_grid(lmax, kmax=kmax) + assert dp_result == pt_result + n_gamma = dp_result[2] + assert n_gamma == (1 if kmax == 0 else 3 * kmax + 1) + + +def test_resolve_so3_grid_kmax_zero() -> None: + """kmax=0 collapses the gamma grid to a single sample (n_gamma=1).""" + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + dp_result = resolve_so3_grid(2, kmax=0) + assert dp_result == pt_resolve_so3_grid(2, kmax=0) + assert dp_result[2] == 1 + + +def test_resolve_so3_grid_explicit_precision() -> None: + """An explicitly supplied (packaged) Lebedev precision is honored.""" + from deepmd.dpmodel.utils.lebedev import ( + LEBEDEV_PRECISION_TO_NPOINTS, + ) + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + resolve_so3_grid as pt_resolve_so3_grid, + ) + + precision = sorted(LEBEDEV_PRECISION_TO_NPOINTS)[3] + dp_result = resolve_so3_grid(2, kmax=1, lebedev_precision=precision) + assert dp_result == pt_resolve_so3_grid(2, kmax=1, lebedev_precision=precision) + assert dp_result[0] == precision + + +def test_resolve_so3_grid_unpackaged_precision_raises() -> None: + """An unpackaged explicit precision raises ValueError.""" + with pytest.raises(ValueError, match="not packaged"): + resolve_so3_grid(2, kmax=1, lebedev_precision=999999) + + +@pytest.mark.parametrize("kmax", [-1, -2]) # negative half-width is invalid +def test_negative_kmax_raises(kmax) -> None: + """Both utilities reject negative kmax.""" + with pytest.raises(ValueError, match="non-negative"): + _build_so3_frame_set(kmax) + with pytest.raises(ValueError, match="non-negative"): + resolve_so3_grid(2, kmax=kmax) diff --git a/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py new file mode 100644 index 0000000000..8af5e7cb9e --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_gridnet.py @@ -0,0 +1,425 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity / equivariance tests for the DPA4 ``SO3GridNet`` (self + cross). + +Covers the dpmodel port of the SO(3) Wigner-D grid net, including the +``mode='cross'`` frame machinery (``FrameExpand``/``FrameContract``, +``n_frames > 1``) and the ``layout='flat'`` frame-width path. All pt imports +are kept inside the test functions (ruff TID253). + +Test menu: + +* ``test_so3_self_parity`` -- self-mode weight-copied fp64 parity (glu/mlp/branch, kmax 1/2). +* ``test_so3_self_equivariance`` -- rotate input, SO(3) equivariance. +* ``test_so3_cross_parity`` -- cross-mode weight-copied fp64 parity. +* ``test_so3_cross_equivariance`` -- rotate query & context, SO(3) equivariance. +* ``test_so3_cross_flat_parity`` -- ``layout='flat'`` frame-width parity. +* ``test_so3_serialize_roundtrip`` -- serialize/deserialize forward identical. +* ``test_torch_namespace`` -- ``.call`` on torch inputs matches numpy. +* ``test_s2_regression`` -- the existing S2GridNet self+cross still matches pt. +""" + +import numpy as np +import pytest +from numpy.testing import ( + assert_allclose, +) + +from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( + S2GridNet, + SO3GridNet, +) + + +def _rotate_ndfc(x: np.ndarray, d_matrix: np.ndarray) -> np.ndarray: + """Rotate coefficient-layout tensors with shape ``(N, D, F, C)``.""" + return np.einsum("nij,njfc->nifc", d_matrix, x) + + +def _random_quaternion(n_batch: int, seed: int) -> np.ndarray: + """Sample normalized quaternions in ``(w, x, y, z)`` order.""" + rng = np.random.default_rng(seed) + q = rng.standard_normal((n_batch, 4)).astype(np.float64) + return q / np.sqrt(np.sum(q * q, axis=-1, keepdims=True)) + + +def _copy_so3gridnet(dp_net: SO3GridNet, pt_net) -> None: + """Copy a pt ``SO3GridNet`` state-dict into the dpmodel net.""" + sd = pt_net.state_dict() + + def _np(key): + return sd[key].detach().cpu().numpy().astype(np.float64) + + dp_net.scalar_gate.weight = _np("scalar_gate.weight") + if "scalar_gate.bias" in sd: + dp_net.scalar_gate.bias = _np("scalar_gate.bias").reshape( + dp_net.scalar_gate.bias.shape + ) + if dp_net.op_type == "mlp": + for name in ("left_proj", "right_proj", "out_proj"): + getattr(dp_net.grid_op, name).weight = _np(f"grid_op.{name}.weight") + elif dp_net.op_type == "branch": + for name in ("left_proj", "right_proj", "router", "out_proj"): + getattr(dp_net.grid_op, name).weight = _np(f"grid_op.{name}.weight") + if dp_net.frame_expand is not None: + dp_net.frame_expand.weight = _np("frame_expand.weight") + if dp_net.frame_contract is not None: + dp_net.frame_contract.weight = _np("frame_contract.weight") + if dp_net.residual_scale is not None: + dp_net.residual_scale = _np("residual_scale").reshape( + dp_net.residual_scale.shape + ) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +@pytest.mark.parametrize("kmax", [1, 2]) # frame-index half-width +def test_so3_self_parity(op_type, kmax) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import SO3GridNet as PTSO3GridNet + + lmax, channels, n_focus = 2, 2, 1 + common = { + "lmax": lmax, + "kmax": kmax, + "channels": channels, + "n_focus": n_focus, + "mode": "self", + "op_type": op_type, + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_branches": 2, + "trainable": False, + "seed": 17 + kmax, + } + pt_net = PTSO3GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = SO3GridNet(precision="float64", **common) + _copy_so3gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(2024 + kmax) + x = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.query_channels) + ).astype(np.float64) + + dp_out = np.asarray(dp_net.call(x)) + pt_out = pt_net(torch.from_numpy(x)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, dp_net.output_channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp"]) # grid op +@pytest.mark.parametrize("kmax", [1, 2]) # frame-index half-width +def test_so3_self_equivariance(op_type, kmax) -> None: + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + lmax, channels, n_focus = 2, 2, 1 + dp_net = SO3GridNet( + lmax=lmax, + kmax=kmax, + channels=channels, + n_focus=n_focus, + mode="self", + op_type=op_type, + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_branches=2, + trainable=False, + seed=31 + kmax, + ) + + n_batch = 2 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(4096 + kmax) + x = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.query_channels) + ).astype(np.float64) + + quat = _random_quaternion(n_batch, seed=77 + kmax) + d_matrix, _ = WignerDCalculator(lmax=lmax, precision="float64").call(quat) + d_matrix = np.asarray(d_matrix) + + y_rotated_input = np.asarray(dp_net.call(_rotate_ndfc(x, d_matrix))) + y_then_rotated = _rotate_ndfc(np.asarray(dp_net.call(x)), d_matrix) + max_error = float(np.max(np.abs(y_rotated_input - y_then_rotated))) + assert max_error <= 1e-10, f"equivariance error {max_error}" + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +@pytest.mark.parametrize("kmax", [1, 2]) # frame-index half-width +def test_so3_cross_parity(op_type, kmax) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import SO3GridNet as PTSO3GridNet + + lmax, channels, n_focus = 2, 2, 1 + common = { + "lmax": lmax, + "kmax": kmax, + "channels": channels, + "n_focus": n_focus, + "mode": "cross", + "op_type": op_type, + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_branches": 2, + "trainable": False, + "seed": 41 + kmax, + } + pt_net = PTSO3GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = SO3GridNet(precision="float64", **common) + _copy_so3gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + rng = np.random.default_rng(606 + kmax) + q = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + c = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + + dp_out = np.asarray(dp_net.call(q, c)) + pt_out = pt_net(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus, dp_net.output_channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("op_type", ["glu", "mlp"]) # grid op +@pytest.mark.parametrize("kmax", [1, 2]) # frame-index half-width +def test_so3_cross_equivariance(op_type, kmax) -> None: + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + ) + + lmax, channels, n_focus = 2, 2, 1 + dp_net = SO3GridNet( + lmax=lmax, + kmax=kmax, + channels=channels, + n_focus=n_focus, + mode="cross", + op_type=op_type, + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_branches=2, + trainable=False, + seed=51 + kmax, + ) + + n_batch = 2 + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + rng = np.random.default_rng(8192 + kmax) + q = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + c = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + + quat = _random_quaternion(n_batch, seed=99 + kmax) + d_matrix, _ = WignerDCalculator(lmax=lmax, precision="float64").call(quat) + d_matrix = np.asarray(d_matrix) + + y_rotated_input = np.asarray( + dp_net.call(_rotate_ndfc(q, d_matrix), _rotate_ndfc(c, d_matrix)) + ) + y_then_rotated = _rotate_ndfc(np.asarray(dp_net.call(q, c)), d_matrix) + max_error = float(np.max(np.abs(y_rotated_input - y_then_rotated))) + assert max_error <= 1e-10, f"equivariance error {max_error}" + + +@pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid op +def test_so3_cross_flat_parity(op_type) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import SO3GridNet as PTSO3GridNet + + lmax, channels, n_focus, kmax = 2, 2, 2, 1 + common = { + "lmax": lmax, + "kmax": kmax, + "channels": channels, + "n_focus": n_focus, + "mode": "cross", + "op_type": op_type, + "layout": "flat", + "coefficient_layout": "packed", + "grid_branches": 2, + "trainable": False, + "seed": 67, + } + pt_net = PTSO3GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = SO3GridNet(precision="float64", **common) + _copy_so3gridnet(dp_net, pt_net) + + n_batch = 3 + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + rng = np.random.default_rng(909) + # flat layout: (N, D, F * context_channels) + q = rng.standard_normal( + (n_batch, coeff_dim, n_focus * dp_net.context_channels) + ).astype(np.float64) + c = rng.standard_normal( + (n_batch, coeff_dim, n_focus * dp_net.context_channels) + ).astype(np.float64) + + dp_out = np.asarray(dp_net.call(q, c)) + pt_out = pt_net(torch.from_numpy(q), torch.from_numpy(c)).detach().cpu().numpy() + assert dp_out.shape == (n_batch, coeff_dim, n_focus * dp_net.output_channels) + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_so3_serialize_roundtrip(mode) -> None: + lmax, channels, n_focus, kmax = 2, 2, 1, 2 + dp_net = SO3GridNet( + lmax=lmax, + kmax=kmax, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type="branch", + precision="float64", + layout="ndfc", + coefficient_layout="packed", + grid_branches=2, + residual_scale_init=0.5, + trainable=False, + seed=73, + ) + data = dp_net.serialize() + assert data["@version"] == 1 + if mode == "cross": + assert "frame_expand.weight" in data["@variables"] + assert "frame_contract.weight" in data["@variables"] + dp_net2 = SO3GridNet.deserialize(data) + + n_batch = 3 + if mode == "self": + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(135) + x = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.query_channels) + ).astype(np.float64) + args = (x,) + else: + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + rng = np.random.default_rng(246) + q = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + c = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + args = (q, c) + + assert_allclose( + np.asarray(dp_net2.call(*args)), + np.asarray(dp_net.call(*args)), + atol=1e-12, + rtol=1e-12, + ) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_torch_namespace(mode) -> None: + import torch + + lmax, channels, n_focus, kmax = 2, 2, 1, 2 + dp_net = SO3GridNet( + lmax=lmax, + kmax=kmax, + channels=channels, + n_focus=n_focus, + mode=mode, + op_type="mlp", + precision="float64", + layout="ndfc", + coefficient_layout="packed", + trainable=False, + seed=23, + ) + + n_batch = 3 + if mode == "self": + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(246) + x = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.query_channels) + ).astype(np.float64) + args = (x,) + else: + coeff_dim = dp_net.projector.coeff_dim // dp_net.n_frames + rng = np.random.default_rng(357) + q = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + c = rng.standard_normal( + (n_batch, coeff_dim, n_focus, dp_net.context_channels) + ).astype(np.float64) + args = (q, c) + + np_out = np.asarray(dp_net.call(*args)) + torch_args = tuple(torch.from_numpy(a) for a in args) + torch_out = dp_net.call(*torch_args).detach().cpu().numpy() + assert_allclose(torch_out, np_out, atol=1e-12, rtol=1e-12) + + +@pytest.mark.parametrize("mode", ["self", "cross"]) # pairing mode +def test_s2_regression(mode) -> None: + import torch + + from deepmd.pt.model.descriptor.sezm_nn.grid_net import S2GridNet as PTS2GridNet + + lmax, channels, n_focus = 2, 2, 1 + common = { + "lmax": lmax, + "channels": channels, + "n_focus": n_focus, + "mode": mode, + "op_type": "branch", + "layout": "ndfc", + "coefficient_layout": "packed", + "grid_method": "lebedev", + "grid_branches": 2, + "trainable": False, + "seed": 19, + } + pt_net = PTS2GridNet(dtype=torch.float64, **common).to("cpu") + dp_net = S2GridNet(precision="float64", **common) + sd = pt_net.state_dict() + + def _np(key): + return sd[key].detach().cpu().numpy().astype(np.float64) + + dp_net.scalar_gate.weight = _np("scalar_gate.weight") + for name in ("left_proj", "right_proj", "router", "out_proj"): + getattr(dp_net.grid_op, name).weight = _np(f"grid_op.{name}.weight") + + n_batch = 3 + coeff_dim = (lmax + 1) ** 2 + rng = np.random.default_rng(135) + if mode == "self": + x = rng.standard_normal((n_batch, coeff_dim, n_focus, 2 * channels)).astype( + np.float64 + ) + args = (x,) + else: + q = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype( + np.float64 + ) + c = rng.standard_normal((n_batch, coeff_dim, n_focus, channels)).astype( + np.float64 + ) + args = (q, c) + + dp_out = np.asarray(dp_net.call(*args)) + torch_args = tuple(torch.from_numpy(a) for a in args) + pt_out = pt_net(*torch_args).detach().cpu().numpy() + assert_allclose(dp_out, pt_out, atol=1e-12, rtol=1e-12) diff --git a/source/tests/common/dpmodel/test_dpa4_so3_projector.py b/source/tests/common/dpmodel/test_dpa4_so3_projector.py new file mode 100644 index 0000000000..01274f134c --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa4_so3_projector.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for the dpmodel ``SO3GridProjector`` (Wigner-D grid quadrature). + +Compares the dpmodel port against the reference pt implementation +(``deepmd.pt.model.descriptor.sezm_nn.projection.SO3GridProjector``) and checks +the legal-frame round-trip identity, serialization round-trip, and the kmax=0 +zonal convention. pt imports live inside the test functions to satisfy the +``source/tests/common`` import-isolation rule (ruff TID253). +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa4_nn.projection import ( + SO3GridProjector, +) + + +def _legal_so3_frame_mask(projector: SO3GridProjector) -> np.ndarray: + """Build the boolean mask of legal ``(l, m, k)`` flat-coefficient slots.""" + mask = np.ones(projector.coeff_dim, dtype=np.bool_) + n_frames = projector.n_frames + for degree in range(projector.lmax + 1): + for m_order in range(-degree, degree + 1): + packed_idx = degree * degree + degree + m_order + for frame_pos, frame_order in enumerate(projector.frame_set): + flat_idx = packed_idx * n_frames + frame_pos + if flat_idx >= projector.coeff_dim: + continue + if abs(frame_order) > degree: + mask[flat_idx] = False + return mask + + +# (lmax, kmax, mmax): max degree, frame-index half-width, retained order +_CASES = [(1, 1, 1), (2, 1, 1), (2, 2, 2), (3, 1, 1)] + + +@pytest.mark.parametrize("lmax,kmax,mmax", _CASES) +def test_projection_matrices_match_pt(lmax, kmax, mmax) -> None: + """Dpmodel projection matrices match the pt buffers at fp64.""" + import torch + + from deepmd.pt.model.descriptor.sezm_nn.projection import ( + SO3GridProjector as PTSO3GridProjector, + ) + + dp = SO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, precision="float64") + pt = PTSO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, dtype=torch.float64) + + np.testing.assert_allclose( + dp.to_grid_mat, + pt.to_grid_mat.detach().cpu().numpy(), + atol=1e-12, + rtol=1e-12, + ) + np.testing.assert_allclose( + dp.from_grid_mat, + pt.from_grid_mat.detach().cpu().numpy(), + atol=1e-12, + rtol=1e-12, + ) + assert dp.n_frames == pt.n_frames + assert dp.coeff_dim == pt.coeff_dim + assert dp.grid_size == pt.grid_size + assert dp.frame_set == pt.frame_set + + +@pytest.mark.parametrize("lmax", [1, 2, 3, 4, 5, 6]) # max degree +def test_roundtrip_preserves_legal_frame_coeffs(lmax) -> None: + """Project legal-frame coefficients to grid and back; recovery to 1e-12.""" + rng = np.random.default_rng(8100 + lmax) + projector = SO3GridProjector(lmax=lmax, kmax=1, precision="float64") + x = rng.standard_normal((2, projector.coeff_dim, 2)).astype(np.float64) + mask = _legal_so3_frame_mask(projector) + x[:, ~mask, :] = 0.0 + y = projector.from_grid(projector.to_grid(x)) + np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-12, rtol=1e-12) + assert float(np.max(np.abs(y[:, ~mask, :]))) < 1e-14 + + +@pytest.mark.parametrize("lmax,kmax,mmax", _CASES) +def test_serialize_roundtrip(lmax, kmax, mmax) -> None: + """Serialize -> deserialize reproduces the matrices and config keys.""" + projector = SO3GridProjector(lmax=lmax, mmax=mmax, kmax=kmax, precision="float64") + data = projector.serialize() + assert data["@class"] == "SO3GridProjector" + assert data["@version"] == 1 + config = data["config"] + for key in ( + "lmax", + "mmax", + "kmax", + "precision", + "lebedev_precision", + "coefficient_layout", + ): + assert key in config + # the matrices must NOT be serialized (rebuilt at deserialize) + assert "to_grid_mat" not in config + assert "from_grid_mat" not in config + + restored = SO3GridProjector.deserialize(data) + np.testing.assert_array_equal(restored.to_grid_mat, projector.to_grid_mat) + np.testing.assert_array_equal(restored.from_grid_mat, projector.from_grid_mat) + assert restored.frame_set == projector.frame_set + + +def test_kmax_zero_zonal() -> None: + """kmax=0 collapses to a single frame and matches the Wigner zonal column.""" + from deepmd.dpmodel.descriptor.dpa4_nn.wignerd import ( + WignerDCalculator, + build_edge_quaternion, + ) + from deepmd.dpmodel.utils.lebedev import ( + load_lebedev_rule, + ) + + lmax = 6 + projector = SO3GridProjector(lmax=lmax, kmax=0, precision="float64") + assert projector.n_frames == 1 + + points, _ = load_lebedev_rule(projector.lebedev_precision) + points = np.asarray(points, dtype=np.float64) + edge_quaternion = build_edge_quaternion(points, eps=1e-14) + zonal = WignerDCalculator(lmax, precision="float64").forward_zonal( + edge_quaternion, lmin=1 + ) + np.testing.assert_allclose( + projector.to_grid_mat[:, 0], np.ones_like(points[:, 0]), atol=1e-14, rtol=1e-14 + ) + np.testing.assert_allclose( + projector.to_grid_mat[:, 1:], zonal, atol=1e-14, rtol=1e-14 + ) + + # the single-frame projector still round-trips legal coefficients + rng = np.random.default_rng(99) + x = rng.standard_normal((2, projector.coeff_dim, 2)).astype(np.float64) + mask = _legal_so3_frame_mask(projector) + x[:, ~mask, :] = 0.0 + y = projector.from_grid(projector.to_grid(x)) + np.testing.assert_allclose(y[:, mask, :], x[:, mask, :], atol=1e-12, rtol=1e-12) diff --git a/source/tests/consistent/descriptor/test_dpa4.py b/source/tests/consistent/descriptor/test_dpa4.py index 7cfb642ac2..9faecbf330 100644 --- a/source/tests/consistent/descriptor/test_dpa4.py +++ b/source/tests/consistent/descriptor/test_dpa4.py @@ -46,6 +46,9 @@ "grid_branch", "s2_activation", "basis_type", + "ffn_so3_grid", + "message_node_so3", + "grid_mlp", ) DPA4_BASELINE_CASE = { @@ -53,6 +56,9 @@ "grid_branch": [1, 1, 1], "s2_activation": [False, True], "basis_type": "bessel", + "ffn_so3_grid": False, + "message_node_so3": False, + "grid_mlp": False, } @@ -85,6 +91,15 @@ def dpa4_case(**overrides: Any) -> tuple: s2_activation=[False, False], basis_type="gaussian", ), + # SO(3) Wigner-D grid in the block-internal FFN + dpa4_case(ffn_so3_grid=True), + # SO(3) Wigner-D grid in the post-aggregation message-node path + dpa4_case(message_node_so3=True), + # both SO(3) grid paths enabled together + dpa4_case(ffn_so3_grid=True, message_node_so3=True), + # grid MLP point-wise op (op_type='mlp'); needs grid_branch=0 on the path, + # since positive grid_branch entries take precedence over grid_mlp + dpa4_case(grid_branch=[0, 0, 0], grid_mlp=True), ) @@ -97,6 +112,9 @@ def data(self) -> dict: grid_branch, s2_activation, basis_type, + ffn_so3_grid, + message_node_so3, + grid_mlp, ) = self.param return { "ntypes": self.ntypes, @@ -110,6 +128,9 @@ def data(self) -> dict: "n_blocks": 2, "grid_branch": grid_branch, "s2_activation": s2_activation, + "ffn_so3_grid": ffn_so3_grid, + "message_node_so3": message_node_so3, + "grid_mlp": grid_mlp, "random_gamma": False, "precision": precision, "trainable": False, @@ -217,6 +238,9 @@ def rtol(self) -> float: _grid_branch, _s2_activation, _basis_type, + _ffn_so3_grid, + _message_node_so3, + _grid_mlp, ) = self.param if precision == "float64": return 1e-10 @@ -233,6 +257,9 @@ def atol(self) -> float: _grid_branch, _s2_activation, _basis_type, + _ffn_so3_grid, + _message_node_so3, + _grid_mlp, ) = self.param if precision == "float64": return 1e-10 diff --git a/source/tests/pt/model/test_dpa4_dpmodel_parity.py b/source/tests/pt/model/test_dpa4_dpmodel_parity.py index 08d78e7320..2922dd23fe 100644 --- a/source/tests/pt/model/test_dpa4_dpmodel_parity.py +++ b/source/tests/pt/model/test_dpa4_dpmodel_parity.py @@ -1430,11 +1430,13 @@ def _build_grid_nets( lmax, op_type, layout, + mode="self", mlp_bias=False, n_focus=1, mmax=None, coefficient_layout="packed", grid_branches=1, + residual_scale_init=None, seed=7, ): """Build a pt S2GridNet, perturb its params, and copy them into dp.""" @@ -1446,13 +1448,14 @@ def _build_grid_nets( mmax=mmax, channels=self.channels, n_focus=n_focus, - mode="self", + mode=mode, op_type=op_type, dtype=torch.float64, layout=layout, coefficient_layout=coefficient_layout, grid_method="lebedev", grid_branches=grid_branches, + residual_scale_init=residual_scale_init, mlp_bias=mlp_bias, trainable=True, seed=seed, @@ -1466,13 +1469,14 @@ def _build_grid_nets( mmax=mmax, channels=self.channels, n_focus=n_focus, - mode="self", + mode=mode, op_type=op_type, precision="float64", layout=layout, coefficient_layout=coefficient_layout, grid_method="lebedev", grid_branches=grid_branches, + residual_scale_init=residual_scale_init, mlp_bias=mlp_bias, trainable=True, seed=seed, @@ -1483,20 +1487,28 @@ def _build_grid_nets( expected_keys = {"scalar_gate.weight"} if mlp_bias: expected_keys.add("scalar_gate.bias") - if op_type == "branch": + if op_type == "mlp": + expected_keys |= { + "grid_op.left_proj.weight", + "grid_op.right_proj.weight", + "grid_op.out_proj.weight", + } + elif op_type == "branch": expected_keys |= { "grid_op.left_proj.weight", "grid_op.right_proj.weight", "grid_op.router.weight", "grid_op.out_proj.weight", } + if residual_scale_init is not None: + expected_keys.add("residual_scale") assert set(state) == expected_keys - dp_net.scalar_gate.weight = state["scalar_gate.weight"] - if mlp_bias: - dp_net.scalar_gate.bias = state["scalar_gate.bias"] - if op_type == "branch": - for name in ("left_proj", "right_proj", "router", "out_proj"): - getattr(dp_net.grid_op, name).weight = state[f"grid_op.{name}.weight"] + # load through the dp serialize() schema (the pt state_dict key names + # match the dp @variables keys exactly), which copies the scalar gate, + # the grid op weights and (when set) residual_scale in one shot. + ser = dp_net.serialize() + ser["@variables"] = state + dp_net = DPS2GridNet.deserialize(ser) return pt_net, dp_net # ------------------------------------------------- (a) projector constants @@ -1707,7 +1719,86 @@ def test_grid_branch_deserialize_wrong_class(self) -> None: with pytest.raises(ValueError): DPGridBranch.deserialize({"@class": "Nope", "@version": 1}) - # ------------------------------------------------ (d) not-ported guards + # --------------------------------------------- (d) cross-mode parity + def _cross_inputs(self, rng, dp_net, *, n_batch=11, dtype=np.float64): + """Build matching cross-mode (query, context) inputs for an S2GridNet. + + For the S2 path (no frame machinery) the query/context channel widths + both collapse to ``channels``; ``layout='flat'`` flattens the focus + axis into the last dim while ``ndfc`` keeps it explicit. + """ + n_coeff = dp_net.projector.coeff_dim + if dp_net.layout == "flat": + shape = (n_batch, n_coeff, dp_net.n_focus * self.channels) + else: + shape = (n_batch, n_coeff, dp_net.n_focus, self.channels) + query = rng.normal(size=shape).astype(dtype) + context = rng.normal(size=shape).astype(dtype) + return query, context + + @pytest.mark.parametrize("op_type", ["glu", "mlp", "branch"]) # grid operation + @pytest.mark.parametrize("layout", ["ndfc", "flat"]) # flat is cross-only + def test_s2_grid_net_cross(self, op_type, layout) -> None: + # cross mode backs node_wise_s2 / message_node_s2; op_type='mlp' is the + # GridMLP path and layout='flat' is the cross-only flattened layout. + # lmax=3, mmax=1 (mmax None: + # residual_scale_init scales the cross-mode output by a per-(focus, + # channel) parameter; copied through the dp @variables schema. + pt_net, dp_net = self._build_grid_nets( + lmax=2, + op_type="glu", + layout="ndfc", + mode="cross", + residual_scale_init=0.5, + ) + assert dp_net.residual_scale is not None + rng = np.random.default_rng(2088) + query, context = self._cross_inputs(rng, dp_net) + assert_parity(dp_net.call(query, context), pt_net(to_pt(query), to_pt(context))) + + def test_s2_grid_net_cross_fp32(self) -> None: + # fp32 input through a float64-precision cross-mode net: the S2 grid + # reduction sums over many Lebedev points, so fp32 accumulation can + # drift; a loose ~1e-4 budget (device-conditional) accounts for it + # rather than the fp64 near-bit gate. + pt_net, dp_net = self._build_grid_nets( + lmax=3, + mmax=1, + op_type="branch", + layout="ndfc", + mode="cross", + coefficient_layout="m_major", + grid_branches=2, + ) + rng = np.random.default_rng(2089) + query, context = self._cross_inputs(rng, dp_net, dtype=np.float32) + dp_out = dp_net.call(query, context) + pt_out = pt_net(to_pt(query), to_pt(context)) + assert dp_out.dtype == np.float32 + assert pt_out.dtype == torch.float32 + fp32_tol = 1e-4 if _ON_CPU else 1e-3 + np.testing.assert_allclose( + np.asarray(dp_out), + pt_out.detach().cpu().numpy(), + rtol=fp32_tol, + atol=fp32_tol, + ) + + # ------------------------------------------------ (e) not-ported guards def test_not_ported_guards(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import S2GridNet as DPS2GridNet from deepmd.dpmodel.descriptor.dpa4_nn.projection import ( @@ -1723,8 +1814,11 @@ def test_not_ported_guards(self) -> None: "layout": "ndfc", "grid_method": "lebedev", } + # The e3nn product grid (lebedev_quadrature=False) remains the only + # not-ported path; mode='cross', op_type='mlp', layout='flat' and + # residual_scale_init are all now supported (see the parity tests + # above). with pytest.raises(NotImplementedError, match="lebedev_quadrature"): - # e3nn product grid (lebedev_quadrature=False) is not ported DPS2GridProjector(lmax=2, precision="float64", grid_method="e3nn") with pytest.raises(NotImplementedError, match="lebedev_quadrature"): DPS2GridNet(**{**common, "grid_method": "e3nn"}) @@ -1732,14 +1826,6 @@ def test_not_ported_guards(self) -> None: # "e3nn" default, which dp rejects): default construction works net = DPS2GridNet(**{k: v for k, v in common.items() if k != "grid_method"}) assert net.grid_method == "lebedev" - with pytest.raises(NotImplementedError, match="grid_mlp"): - # GridMLP (grid_mlp=True) is not ported - DPS2GridNet(**{**common, "op_type": "mlp"}) - with pytest.raises(NotImplementedError, match="node_wise_s2"): - # cross mode backs node_wise_s2/message_node_s2 only - DPS2GridNet(**{**common, "mode": "cross"}) - with pytest.raises(NotImplementedError, match="residual_scale_init"): - DPS2GridNet(**common, residual_scale_init=1e-3) def test_value_errors(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.grid_net import ( @@ -2292,6 +2378,36 @@ def test_so2_convolution_roundtrip(self) -> None: out2 = np.asarray(dp_mod2.call(x, dp_cache, radial)) np.testing.assert_array_equal(out1, out2) + @pytest.mark.parametrize( + "flag,grid_mlp_flag", + [ + ("node_wise_s2", "node_wise_grid_mlp"), # edge-local S2 grid product + ("node_wise_so3", "node_wise_grid_mlp"), # edge-local SO(3) grid product + ("message_node_s2", "message_node_grid_mlp"), # post-agg S2 grid product + ("message_node_so3", "message_node_grid_mlp"), # post-agg SO(3) grid prod + ], + ) + @pytest.mark.parametrize("grid_mlp", [False, True]) # glu (False) vs GridMLP (True) + def test_so2_convolution_grid(self, flag, grid_mlp_flag, grid_mlp) -> None: + # cross-mode grid products wired into SO2Convolution. ``node_wise_*`` + # is the edge-local (query/context split) path; ``message_node_*`` is + # the post-aggregation path. The default conv config is lmax=3, mmax=1 + # (mmax None: + # kmax=2 widens the SO(3) frame set to {0,-1,1,-2,2}; default tests use + # kmax=1, so this closes the kmax>1 coverage gap for an SO3 grid path. + pt_mod, dp_mod, kwargs = self._build_conv_pair(**{flag: True}, kmax=2) + self._assert_conv_parity(pt_mod, dp_mod, kwargs) + @pytest.mark.parametrize( "flag,value", [ @@ -2303,10 +2419,6 @@ def test_so2_convolution_roundtrip(self) -> None: ("atten_v_proj", True), # value projection ("atten_o_proj", True), # output projection ("s2_activation", True), # S2-grid SwiGLU non-linearity - ("node_wise_s2", True), # edge-local S2 grid product - ("node_wise_so3", True), # edge-local SO(3) grid product - ("message_node_s2", True), # post-aggregation S2 grid product - ("message_node_so3", True), # post-aggregation SO(3) grid product ], ) def test_so2_convolution_guards(self, flag, value) -> None: @@ -3109,15 +3221,35 @@ def test_ffn_roundtrip(self, s2_activation) -> None: np.asarray(dp_mod.call(x)), np.asarray(dp_mod2.call(x)) ) + @pytest.mark.parametrize("grid_mlp", [False, True]) # SO3 glu vs GridMLP op + @pytest.mark.parametrize("kmax", [1, 2]) # SO(3) frame half-width (core gap=2) + def test_ffn_so3_grid(self, grid_mlp, kmax) -> None: + # ffn_so3_grid enables the SO3 Wigner-D FFN grid path (SO3GridNet, + # self-mode). grid_mlp=True selects the GridMLP point-wise op; kmax=2 + # widens the frame set, closing the kmax>1 coverage gap for SO3 grids. + pt_mod, dp_mod, kwargs = self._build_ffn_pair( + ffn_so3_grid=True, grid_mlp=grid_mlp, kmax=kmax + ) + self._assert_ffn_parity(pt_mod, dp_mod, kwargs) + + @pytest.mark.parametrize("grid_mlp", [False, True]) # S2 glu vs GridMLP op + def test_ffn_s2_grid_mlp(self, grid_mlp) -> None: + # the S2 grid path (s2_activation=True) with grid_mlp selecting the + # GridMLP point-wise op (grid_mlp=True) vs the default glu (False). + pt_mod, dp_mod, kwargs = self._build_ffn_pair( + s2_activation=True, grid_mlp=grid_mlp + ) + self._assert_ffn_parity(pt_mod, dp_mod, kwargs) + def test_ffn_guards(self) -> None: from deepmd.dpmodel.descriptor.dpa4_nn.ffn import EquivariantFFN as DPFFN - with pytest.raises(NotImplementedError, match="ffn_so3_grid"): - DPFFN(**self._ffn_kwargs(ffn_so3_grid=True), precision="float64") - # grid_mlp guard is delegated to S2GridNet's op_type='mlp' NIE - with pytest.raises(NotImplementedError, match="mlp"): + # The e3nn product grid (lebedev_quadrature=False) on the S2 grid path + # remains the only not-ported FFN path; ffn_so3_grid and grid_mlp are + # now supported (see test_ffn_so3_grid / the grid_mlp parity below). + with pytest.raises(NotImplementedError, match="lebedev_quadrature"): DPFFN( - **self._ffn_kwargs(s2_activation=True, grid_mlp=True), + **self._ffn_kwargs(s2_activation=True, lebedev_quadrature=False), precision="float64", ) @@ -3323,6 +3455,21 @@ def test_block_roundtrip(self) -> None: out2 = np.asarray(dp_mod2.call(x, dp_cache, radial)[0]) np.testing.assert_array_equal(out1, out2) + @pytest.mark.parametrize( + "flag", + [ + "node_wise_s2", # delegated to SO2Convolution (cross-mode S2 grid) + "message_node_so3", # delegated to SO2Convolution (cross-mode SO3 grid) + "ffn_so3_grid", # delegated to EquivariantFFN (SO3GridNet self-mode) + ], + ) + def test_block_grid(self, flag) -> None: + # block-level wiring of the grid flags through to SO2Convolution / + # EquivariantFFN. Default block config is lmax=3, mmax=1 (mmax None: ("block_attn_res", "dependent"), # block-level DepthAttnRes ("layer_scale", True), # block-level FFN LayerScale ("so2_s2_activation", True), # delegated to SO2Convolution - ("node_wise_s2", True), # delegated to SO2Convolution - ("message_node_so3", True), # delegated to SO2Convolution - ("ffn_so3_grid", True), # delegated to EquivariantFFN ], ) def test_block_guards(self, flag, value) -> None: