Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 44 additions & 35 deletions deepmd/dpmodel/descriptor/dpa4_nn/ffn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@
It defines the full SO(3)-equivariant feed-forward network used inside SeZM
interaction blocks.

Branches guarded with ``NotImplementedError`` (flags unused by the core DPA4
config):

- ``ffn_so3_grid=True`` — the pt path instantiates ``SO3GridNet``
(pt ffn.py:209), which is not ported to dpmodel.
- ``grid_mlp=True`` together with the grid path active selects
``op_type='mlp'`` for ``S2GridNet`` (pt ffn.py:206); the delegate
``S2GridNet`` constructor raises for that op type (``GridMLP`` is not
ported), so no duplicate guard is added here.
When the grid path is active the equivariant nonlinearity is a grid net
(``S2GridNet`` for the S2 path, ``SO3GridNet`` for ``ffn_so3_grid=True``); the
``grid_mlp``/``grid_branch`` flags pick the point-wise grid op type
(``branch`` if ``grid_branch>0`` else ``mlp`` if ``grid_mlp`` else ``glu``),
matching the pt reference (pt ffn.py:201-239).
"""

from __future__ import (
Expand Down Expand Up @@ -44,6 +40,7 @@
)
from .grid_net import (
S2GridNet,
SO3GridNet,
)
from .projection import (
resolve_s2_grid_resolution,
Expand Down Expand Up @@ -86,16 +83,15 @@ class EquivariantFFN(NativeOP):
Maximum Wigner-D frame order (|k|) used by the SO3 Wigner-D FFN grid.
grid_mlp
If True, select the polynomial grid MLP operation when the
block-internal FFN grid path is enabled. Not ported: the delegate
``S2GridNet`` raises ``NotImplementedError`` for ``op_type='mlp'``.
block-internal FFN grid path is enabled.
grid_branch
Number of scalar-routed polynomial product branches used when the
block-internal FFN grid path is enabled. ``0`` disables this branch
mixer. Positive values take precedence over ``grid_mlp``.
s2_activation
If True, enable the S2 FFN grid path.
ffn_so3_grid
If True, enable the SO3 Wigner-D FFN grid path (not ported).
If True, enable the SO3 Wigner-D FFN grid path (``SO3GridNet``).
lebedev_quadrature
If True, use Lebedev quadrature for the S2 projector in this FFN.
activation_function
Expand Down Expand Up @@ -145,10 +141,6 @@ def __init__(
self.use_grid_branch = self.grid_branch > 0
self.s2_activation = bool(s2_activation)
self.ffn_so3_grid = bool(ffn_so3_grid)
if self.ffn_so3_grid:
raise NotImplementedError(
"ffn_so3_grid=True (SO3GridNet) is not ported to dpmodel"
)
self.lebedev_quadrature = bool(lebedev_quadrature)
self.s2_grid_method = "lebedev" if self.lebedev_quadrature else "e3nn"
base_grid = resolve_s2_grid_resolution(
Expand All @@ -167,8 +159,8 @@ def __init__(
self.precision = precision
self.compute_precision = _compute_precision(precision)
self.trainable = bool(trainable)
# pt: grid_n_frames = 2 * kmax + 1 only when ffn_so3_grid (guarded above)
self.grid_n_frames = 1
# pt: grid_n_frames = 2 * kmax + 1 only when ffn_so3_grid
self.grid_n_frames = 2 * self.kmax + 1 if self.ffn_so3_grid else 1

# === Step 0. Split deterministic seeds at the module top-level ===
seed_so3_in = child_seed(seed, 0)
Expand Down Expand Up @@ -203,23 +195,40 @@ def __init__(
if self.use_grid_branch
else ("mlp" if self.use_grid_mlp else "glu")
)
# op_type='mlp' raises NotImplementedError inside S2GridNet
self.act: NativeOP = S2GridNet(
lmax=self.lmax,
channels=self.hidden_channels,
n_focus=1,
mode="self",
op_type=grid_op,
precision=self.compute_precision,
layout="ndfc",
grid_resolution_list=self.s2_grid_resolution,
coefficient_layout="packed",
grid_method=self.s2_grid_method,
grid_branches=max(1, self.grid_branch),
mlp_bias=self.mlp_bias,
trainable=self.trainable,
seed=seed_act,
)
self.act: NativeOP
if self.ffn_so3_grid:
self.act = SO3GridNet(
lmax=self.lmax,
kmax=self.kmax,
channels=self.hidden_channels,
n_focus=1,
mode="self",
op_type=grid_op,
precision=self.compute_precision,
layout="ndfc",
coefficient_layout="packed",
grid_branches=max(1, self.grid_branch),
mlp_bias=self.mlp_bias,
trainable=self.trainable,
seed=seed_act,
)
else:
self.act = S2GridNet(
lmax=self.lmax,
channels=self.hidden_channels,
n_focus=1,
mode="self",
op_type=grid_op,
precision=self.compute_precision,
layout="ndfc",
grid_resolution_list=self.s2_grid_resolution,
coefficient_layout="packed",
grid_method=self.s2_grid_method,
grid_branches=max(1, self.grid_branch),
mlp_bias=self.mlp_bias,
trainable=self.trainable,
seed=seed_act,
)
else:
self.act = GatedActivation(
lmax=self.lmax,
Expand Down
Loading
Loading