Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 45 additions & 10 deletions deepmd/pt/model/descriptor/sezm.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,16 @@ class DescrptSeZM(BaseDescriptor, nn.Module):
message_node_so3
If True, use the corresponding post-aggregation SO(3) Wigner-D grid-net
branch. The message is the query and the node state is the context.
so3_readout
Read-out FFN mode for the final ``l=0`` descriptor. ``"none"`` applies a
degree-0 scalar FFN to the ``l=0`` slice only; ``l>0`` coefficients are
discarded before the read-out. ``"glu"`` and ``"mlp"`` apply a full
equivariant FFN whose degree equals the node degree of the last
interaction block, driven by the SO(3) Wigner-D grid, so ``l>0`` geometry
is folded into ``l=0`` before the scalar is extracted. The value selects
the quadratic grid product (``"glu"``) or the polynomial point-wise grid
MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. The residual
stays on the ``l=0`` channel.
lebedev_quadrature
Either one boolean applied to both S2 branches, or two booleans
``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If
Expand Down Expand Up @@ -425,6 +435,7 @@ def __init__(
node_wise_so3: bool = False,
message_node_s2: bool = False,
message_node_so3: bool = False,
so3_readout: str = "none",
lebedev_quadrature: bool | list[bool] | None = True,
activation_function: str = "silu",
glu_activation: bool = True,
Expand Down Expand Up @@ -512,6 +523,9 @@ def __init__(
self.node_wise_so3 = bool(node_wise_so3)
self.message_node_s2 = bool(message_node_s2)
self.message_node_so3 = bool(message_node_so3)
self.so3_readout = str(so3_readout).lower()
if self.so3_readout not in {"none", "glu", "mlp"}:
raise ValueError("`so3_readout` must be one of 'none', 'glu', or 'mlp'")
if lebedev_quadrature is None:
lebedev_quadrature = [True, True]
elif isinstance(lebedev_quadrature, bool):
Expand Down Expand Up @@ -932,13 +946,21 @@ def __init__(
)

# === Final FFN for l=0 output mixing ===
# ``so3_readout="none"`` runs a degree-0 scalar FFN on the l=0 slice.
# ``"glu"``/``"mlp"`` run a full FFN at the last block's node degree whose
# SO(3) Wigner-D grid folds l>0 geometry into l=0; the value selects the
# quadratic grid product or the point-wise grid MLP.
readout_lmax = self.node_l_schedule[-1]
self.output_ffn = EquivariantFFN(
lmax=0,
lmax=0 if self.so3_readout == "none" else readout_lmax,
channels=self.channels,
hidden_channels=self.out_ffn_neurons,
grid_mlp=False,
kmax=min(self.kmax, readout_lmax),
grid_mlp=self.so3_readout == "mlp",
grid_branch=0,
dtype=self.compute_dtype,
s2_activation=False,
ffn_so3_grid=self.so3_readout != "none",
activation_function=self.out_activation_function,
glu_activation=self.out_glu_activation,
mlp_bias=self.mlp_bias,
Expand Down Expand Up @@ -1205,15 +1227,20 @@ def forward(
x = self._forward_blocks(x, edge_cache, rad_feat_per_block)

# === Step 11. Final l=0 output mixing ===
# Extract l=0 scalar features and apply FFN in promoted dtype.
# Residual keeps the output close to identity with zero-initialized FFN output.
# ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full
# (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The
# residual is added on the full coefficient tensor before extracting
# l=0: slicing the summed tensor rather than the FFN output keeps the
# saved degree-axis stride static under torch.compile dynamic shapes.
with nvtx_range("output_ffn"):
x_scalar = (
ffn_in = (
x[:, 0:1, :, :]
.reshape(n_nodes, 1, 1, self.channels)
.to(dtype=self.compute_dtype)
) # (N, 1, 1, C)
x_scalar = x_scalar + self.output_ffn(x_scalar)
if self.so3_readout == "none"
else x.to(dtype=self.compute_dtype)
)
x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :]

# === Step 12. Reshape to (nf, nloc, channels) and return ===
descriptor = rearrange(
Expand Down Expand Up @@ -1380,13 +1407,20 @@ def forward_with_edges(
x = self._forward_blocks(x, edge_cache, rad_feat_per_block)

# === Step 10. Final l=0 output mixing ===
# ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full
# (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The
# residual is added on the full coefficient tensor before extracting
# l=0: slicing the summed tensor rather than the FFN output keeps the
# saved degree-axis stride static under torch.compile dynamic shapes.
with nvtx_range("output_ffn"):
x_scalar = (
ffn_in = (
x[:, 0:1, :, :]
.reshape(n_nodes, 1, 1, self.channels)
.to(dtype=self.compute_dtype)
) # (N, 1, 1, C)
x_scalar = x_scalar + self.output_ffn(x_scalar)
if self.so3_readout == "none"
else x.to(dtype=self.compute_dtype)
)
x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :]

# === Step 11. Reshape to (nf, nloc, channels) and return ===
descriptor = x_scalar.reshape(nf, nloc, self.channels) # (nf, nloc, C)
Expand Down Expand Up @@ -2043,6 +2077,7 @@ def serialize(self) -> dict[str, Any]:
"node_wise_so3": self.node_wise_so3,
"message_node_s2": self.message_node_s2,
"message_node_so3": self.message_node_so3,
"so3_readout": self.so3_readout,
"lebedev_quadrature": self.lebedev_quadrature,
"activation_function": self.activation_function,
"glu_activation": self.glu_activation,
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/model/sezm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def __init__(
f"DP_TF32_INFER must be one of 0/1/2, got {tf32_infer_env!r}"
)
self._tf32_infer_precision = _TF32_INFER_PRECISION_CHOICES[tf32_infer_env]
if self.use_compile or self._env_use_compile_infer is True:
if self._env_use_compile_infer is True:
check_compile_torch_version()

# === Bridging (optional short-range zone bridging) ===
Expand Down Expand Up @@ -1615,6 +1615,7 @@ def trace_and_compile(
compiled callable is stored outside the ``nn.Module`` tree so
FSDP/DDP cannot see or shard its duplicated parameters.
"""
check_compile_torch_version()
from torch._decomp import (
get_decompositions,
)
Expand Down Expand Up @@ -2046,6 +2047,7 @@ def compiled(*args: Any, _fn: Any = _compiled_flat) -> dict[str, Any]:

def compile_dens(self) -> None:
"""Compile the direct-force `dens` path."""
check_compile_torch_version()
from torch._inductor import config as inductor_config

log.info("SeZM: start compiling dens path")
Expand Down Expand Up @@ -2091,6 +2093,7 @@ def _trace_lower_exportable(
*sample_inputs: torch.Tensor | None,
) -> torch.nn.Module:
"""Trace a lower-interface closure into an exportable FX graph."""
check_compile_torch_version()
from torch._decomp import (
get_decompositions,
)
Expand Down
73 changes: 67 additions & 6 deletions deepmd/pt/utils/nv_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import contextlib
import logging
import os
import sys
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -36,6 +38,10 @@

NV_CELL_LIST_THRESHOLD = 1024
NV_NONPERIODIC_CELL_LIST_THRESHOLD = 4096
# CPU has far less parallelism than CUDA, so the O(N^2) ``batch_naive`` method
# is overtaken by the O(N) ``batch_cell_list`` at a much smaller atom count;
# switch over early regardless of periodicity.
NV_CPU_CELL_LIST_THRESHOLD = 128

log = logging.getLogger(__name__)

Expand All @@ -45,32 +51,76 @@
)


@contextlib.contextmanager
def _suppress_native_stderr() -> Iterator[None]:
    """Redirect the process ``stderr`` file descriptor to ``os.devnull``.

    ``nvalchemiops`` initializes NVIDIA Warp on first import, which probes for a
    CUDA driver and prints a native ``Warp CUDA error 100`` line straight to the
    ``stderr`` fd on CPU-only hosts. That line bypasses Python logging, so the
    only way to mute it is at the descriptor level around the triggering import.

    The descriptor that is redirected is the one reported by
    ``sys.stderr.fileno()``. When ``sys.stderr`` has no usable descriptor the
    body runs without any suppression.

    Yields
    ------
    None
        Control is handed to the ``with`` body while the descriptor points at
        ``os.devnull``; the original target is restored on exit, including
        when the body raises.
    """
    try:
        stderr_fd = sys.stderr.fileno()
    except (AttributeError, OSError, ValueError):
        # stderr is not a real file descriptor (e.g. captured in tests); the
        # native chatter cannot be redirected, so import without suppression.
        yield
        return

    # Push out text still sitting in Python's buffer so it reaches the real
    # stderr instead of being written to devnull once the fd is swapped.
    with contextlib.suppress(OSError, ValueError):
        sys.stderr.flush()

    saved_fd = os.dup(stderr_fd)
    try:
        # ``dup2`` copies the descriptor, so the devnull handle itself can be
        # closed as soon as the redirect is in place.
        with open(os.devnull, "w") as devnull:
            os.dup2(devnull.fileno(), stderr_fd)
        try:
            yield
        finally:
            os.dup2(saved_fd, stderr_fd)
    finally:
        # Always release the duplicate, even when setting up or undoing the
        # redirect fails; otherwise each failure leaks one descriptor.
        os.close(saved_fd)
Comment thread
OutisLi marked this conversation as resolved.


def is_nv_available() -> bool:
"""Whether the ``nvalchemiops`` Toolkit-Ops neighbor list is importable."""
# Warp's one-time CUDA probe prints to the native stderr on CPU-only hosts;
# mute it there without hiding diagnostics on machines that have a GPU.
import_ctx = (
_suppress_native_stderr()
if not torch.cuda.is_available()
else contextlib.nullcontext()
)
try:
import nvalchemiops.torch.neighbors # noqa: F401
with import_ctx:
import nvalchemiops.torch.neighbors # noqa: F401
except (ImportError, OSError, RuntimeError) as err:
log.debug("nvalchemiops Toolkit-Ops neighbor list is unavailable: %s", err)
return False
return True


def choose_nv_nlist_method(nloc: int, *, periodic: bool = True) -> str:
def choose_nv_nlist_method(
nloc: int, *, periodic: bool = True, device: torch.device | None = None
) -> str:
"""Choose the Toolkit-Ops neighbor method for a homogeneous batch.

Parameters
----------
nloc
Number of local atoms per frame.
periodic
Whether the batch is periodic.
device
Target device. CPU uses a lower cell-list threshold than CUDA because
the ``batch_naive`` method does not parallelize well there.

Returns
-------
str
Toolkit-Ops method name.
"""
threshold = (
NV_CELL_LIST_THRESHOLD if periodic else NV_NONPERIODIC_CELL_LIST_THRESHOLD
)
if device is not None and device.type == "cpu":
threshold = NV_CPU_CELL_LIST_THRESHOLD
elif periodic:
threshold = NV_CELL_LIST_THRESHOLD
else:
threshold = NV_NONPERIODIC_CELL_LIST_THRESHOLD
if nloc >= threshold:
return "batch_cell_list"
return "batch_naive"
Expand Down Expand Up @@ -133,7 +183,17 @@ def build(
nf, dtype=torch.int32, device=device
).repeat_interleave(nloc)
batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc
method = choose_nv_nlist_method(nloc, periodic=periodic)
method = choose_nv_nlist_method(nloc, periodic=periodic, device=device)

# ``batch_naive`` otherwise derives ``max_atoms_per_system`` from
# ``batch_ptr`` with a ``.max().item()`` device->host sync on every
# call. Our batches are homogeneous (``nloc`` atoms per frame), so the
# value is known on the host; passing it explicitly removes that
# per-call sync. ``batch_cell_list`` neither accepts the argument nor
# has a ``**kwargs`` catch-all, so the override is guarded on method.
extra_nl_kwargs: dict[str, Any] = {}
if method == "batch_naive":
extra_nl_kwargs["max_atoms_per_system"] = int(nloc)

# Grow the search capacity until all neighbors fit so the distance-sort
# below selects the true nearest ``sum(sel)``.
Expand All @@ -149,6 +209,7 @@ def build(
max_neighbors=int(search_capacity),
return_neighbor_list=False,
wrap_positions=False,
**extra_nl_kwargs,
)
if len(nlist_result) == 2:
neighbor_matrix, num_neighbors = nlist_result
Expand Down
19 changes: 19 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,16 @@ def descrpt_se_zm_args() -> list[Argument]:
"context. When enabled together with `message_node_s2`, the SO(3) "
"branch is used for this path."
)
doc_so3_readout = (
"Read-out FFN mode for the final l=0 descriptor. `none` applies a "
"degree-0 scalar FFN to the l=0 slice only; l>0 coefficients are "
"discarded before the read-out. `glu` and `mlp` apply a full equivariant "
"FFN on the SO(3) Wigner-D grid so l>0 geometry is folded into l=0 "
"before the scalar is extracted; the value selects the quadratic grid "
"product (`glu`) or the polynomial point-wise grid MLP (`mlp`). The "
"read-out degree equals the node degree of the last interaction block; "
"the Wigner-D frame order follows `kmax`."
)
doc_lebedev_quadrature = (
"Either one boolean applied to both S2 branches, or two booleans "
"`[so2_enabled, ffn_enabled]` aligned with `s2_activation`. If a branch "
Expand Down Expand Up @@ -881,6 +891,15 @@ def descrpt_se_zm_args() -> list[Argument]:
default=False,
doc=doc_only_pt_supported + doc_message_node_so3,
),
Argument(
"so3_readout",
str,
optional=True,
default="none",
extra_check=lambda x: x in ("none", "glu", "mlp"),
extra_check_errmsg="must be one of 'none', 'glu', or 'mlp'",
doc=doc_only_pt_supported + doc_so3_readout,
),
Argument(
"lebedev_quadrature",
[bool, list[bool]],
Expand Down
1 change: 1 addition & 0 deletions examples/water/dpa4/input.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"n_atten_head": 1,
"ffn_neurons": 0,
"ffn_so3_grid": true,
"so3_readout": "mlp",
"grid_mlp": [
false,
false,
Expand Down
Loading