diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index ba3ef38e65..9a56d7b1c4 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -320,6 +320,16 @@ class DescrptSeZM(BaseDescriptor, nn.Module): message_node_so3 If True, use the corresponding post-aggregation SO(3) Wigner-D grid-net branch. The message is the query and the node state is the context. + so3_readout + Read-out FFN mode for the final ``l=0`` descriptor. ``"none"`` applies a + degree-0 scalar FFN to the ``l=0`` slice only; ``l>0`` coefficients are + discarded before the read-out. ``"glu"`` and ``"mlp"`` apply a full + equivariant FFN whose degree equals the node degree of the last + interaction block, driven by the SO(3) Wigner-D grid, so ``l>0`` geometry + is folded into ``l=0`` before the scalar is extracted. The value selects + the quadratic grid product (``"glu"``) or the polynomial point-wise grid + MLP (``"mlp"``). The Wigner-D frame order follows ``kmax``. The residual + stays on the ``l=0`` channel. lebedev_quadrature Either one boolean applied to both S2 branches, or two booleans ``[so2_enabled, ffn_enabled]`` aligned with ``s2_activation``. If @@ -425,6 +435,7 @@ def __init__( node_wise_so3: bool = False, message_node_s2: bool = False, message_node_so3: bool = False, + so3_readout: str = "none", lebedev_quadrature: bool | list[bool] | None = True, activation_function: str = "silu", glu_activation: bool = True, @@ -512,6 +523,9 @@ def __init__( self.node_wise_so3 = bool(node_wise_so3) self.message_node_s2 = bool(message_node_s2) self.message_node_so3 = bool(message_node_so3) + self.so3_readout = str(so3_readout).lower() + if self.so3_readout not in {"none", "glu", "mlp"}: + raise ValueError("`so3_readout` must be one of 'none', 'glu', or 'mlp'") if lebedev_quadrature is None: lebedev_quadrature = [True, True] elif isinstance(lebedev_quadrature, bool): @@ -932,13 +946,21 @@ def __init__( ) # === Final FFN for l=0 output mixing === + # ``so3_readout="none"`` runs a degree-0 scalar FFN on the l=0 slice. + # ``"glu"``/``"mlp"`` run a full FFN at the last block's node degree whose + # SO(3) Wigner-D grid folds l>0 geometry into l=0; the value selects the + # quadratic grid product or the point-wise grid MLP. + readout_lmax = self.node_l_schedule[-1] self.output_ffn = EquivariantFFN( - lmax=0, + lmax=0 if self.so3_readout == "none" else readout_lmax, channels=self.channels, hidden_channels=self.out_ffn_neurons, - grid_mlp=False, + kmax=min(self.kmax, readout_lmax), + grid_mlp=self.so3_readout == "mlp", + grid_branch=0, dtype=self.compute_dtype, s2_activation=False, + ffn_so3_grid=self.so3_readout != "none", activation_function=self.out_activation_function, glu_activation=self.out_glu_activation, mlp_bias=self.mlp_bias, @@ -1205,15 +1227,20 @@ def forward( x = self._forward_blocks(x, edge_cache, rad_feat_per_block) # === Step 11. Final l=0 output mixing === - # Extract l=0 scalar features and apply FFN in promoted dtype. - # Residual keeps the output close to identity with zero-initialized FFN output. + # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full + # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The + # residual is added on the full coefficient tensor before extracting + # l=0: slicing the summed tensor rather than the FFN output keeps the + # saved degree-axis stride static under torch.compile dynamic shapes. with nvtx_range("output_ffn"): - x_scalar = ( + ffn_in = ( x[:, 0:1, :, :] .reshape(n_nodes, 1, 1, self.channels) .to(dtype=self.compute_dtype) - ) # (N, 1, 1, C) - x_scalar = x_scalar + self.output_ffn(x_scalar) + if self.so3_readout == "none" + else x.to(dtype=self.compute_dtype) + ) + x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] # === Step 12. Reshape to (nf, nloc, channels) and return === descriptor = rearrange( @@ -1380,13 +1407,20 @@ def forward_with_edges( x = self._forward_blocks(x, edge_cache, rad_feat_per_block) # === Step 10. Final l=0 output mixing === + # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full + # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The + # residual is added on the full coefficient tensor before extracting + # l=0: slicing the summed tensor rather than the FFN output keeps the + # saved degree-axis stride static under torch.compile dynamic shapes. with nvtx_range("output_ffn"): - x_scalar = ( + ffn_in = ( x[:, 0:1, :, :] .reshape(n_nodes, 1, 1, self.channels) .to(dtype=self.compute_dtype) - ) # (N, 1, 1, C) - x_scalar = x_scalar + self.output_ffn(x_scalar) + if self.so3_readout == "none" + else x.to(dtype=self.compute_dtype) + ) + x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] # === Step 11. Reshape to (nf, nloc, channels) and return === descriptor = x_scalar.reshape(nf, nloc, self.channels) # (nf, nloc, C) @@ -2043,6 +2077,7 @@ def serialize(self) -> dict[str, Any]: "node_wise_so3": self.node_wise_so3, "message_node_s2": self.message_node_s2, "message_node_so3": self.message_node_so3, + "so3_readout": self.so3_readout, "lebedev_quadrature": self.lebedev_quadrature, "activation_function": self.activation_function, "glu_activation": self.glu_activation, diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index a6cb1f538d..2dae2c2912 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -715,7 +715,7 @@ def __init__( f"DP_TF32_INFER must be one of 0/1/2, got {tf32_infer_env!r}" ) self._tf32_infer_precision = _TF32_INFER_PRECISION_CHOICES[tf32_infer_env] - if self.use_compile or self._env_use_compile_infer is True: + if self._env_use_compile_infer is True: check_compile_torch_version() # === Bridging (optional short-range zone bridging) === @@ -1615,6 +1615,7 @@ def trace_and_compile( compiled callable is stored outside the ``nn.Module`` tree so FSDP/DDP cannot see or shard its duplicated parameters. """ + check_compile_torch_version() from torch._decomp import ( get_decompositions, ) @@ -2046,6 +2047,7 @@ def compiled(*args: Any, _fn: Any = _compiled_flat) -> dict[str, Any]: def compile_dens(self) -> None: """Compile the direct-force `dens` path.""" + check_compile_torch_version() from torch._inductor import config as inductor_config log.info("SeZM: start compiling dens path") @@ -2091,6 +2093,7 @@ def _trace_lower_exportable( *sample_inputs: torch.Tensor | None, ) -> torch.nn.Module: """Trace a lower-interface closure into an exportable FX graph.""" + check_compile_torch_version() from torch._decomp import ( get_decompositions, ) diff --git a/deepmd/pt/utils/nv_nlist.py b/deepmd/pt/utils/nv_nlist.py index 188d4f557d..5748618dbc 100644 --- a/deepmd/pt/utils/nv_nlist.py +++ b/deepmd/pt/utils/nv_nlist.py @@ -20,6 +20,8 @@ import contextlib import logging +import os +import sys from typing import ( TYPE_CHECKING, Any, @@ -36,6 +38,10 @@ NV_CELL_LIST_THRESHOLD = 1024 NV_NONPERIODIC_CELL_LIST_THRESHOLD = 4096 +# CPU has far less parallelism than CUDA, so the O(N^2) ``batch_naive`` method +# is overtaken by the O(N) ``batch_cell_list`` at a much smaller atom count; +# switch over early regardless of periodicity. +NV_CPU_CELL_LIST_THRESHOLD = 128 log = logging.getLogger(__name__) @@ -45,32 +51,76 @@ ) +@contextlib.contextmanager +def _suppress_native_stderr() -> Iterator[None]: + """Redirect the process ``stderr`` file descriptor to ``os.devnull``. + + ``nvalchemiops`` initializes NVIDIA Warp on first import, which probes for a + CUDA driver and prints a native ``Warp CUDA error 100`` line straight to the + ``stderr`` fd on CPU-only hosts. That line bypasses Python logging, so the + only way to mute it is at the descriptor level around the triggering import. + """ + try: + stderr_fd = sys.stderr.fileno() + except (AttributeError, OSError, ValueError): + # stderr is not a real file descriptor (e.g. captured in tests); the + # native chatter cannot be redirected, so import without suppression. + yield + return + saved_fd = os.dup(stderr_fd) + with open(os.devnull, "w") as devnull: + os.dup2(devnull.fileno(), stderr_fd) + try: + yield + finally: + os.dup2(saved_fd, stderr_fd) + os.close(saved_fd) + + def is_nv_available() -> bool: """Whether the ``nvalchemiops`` Toolkit-Ops neighbor list is importable.""" + # Warp's one-time CUDA probe prints to the native stderr on CPU-only hosts; + # mute it there without hiding diagnostics on machines that have a GPU. + import_ctx = ( + _suppress_native_stderr() + if not torch.cuda.is_available() + else contextlib.nullcontext() + ) try: - import nvalchemiops.torch.neighbors # noqa: F401 + with import_ctx: + import nvalchemiops.torch.neighbors # noqa: F401 except (ImportError, OSError, RuntimeError) as err: log.debug("nvalchemiops Toolkit-Ops neighbor list is unavailable: %s", err) return False return True -def choose_nv_nlist_method(nloc: int, *, periodic: bool = True) -> str: +def choose_nv_nlist_method( + nloc: int, *, periodic: bool = True, device: torch.device | None = None +) -> str: """Choose the Toolkit-Ops neighbor method for a homogeneous batch. Parameters ---------- nloc Number of local atoms per frame. + periodic + Whether the batch is periodic. + device + Target device. CPU uses a lower cell-list threshold than CUDA because + the ``batch_naive`` method does not parallelize well there. Returns ------- str Toolkit-Ops method name. """ - threshold = ( - NV_CELL_LIST_THRESHOLD if periodic else NV_NONPERIODIC_CELL_LIST_THRESHOLD - ) + if device is not None and device.type == "cpu": + threshold = NV_CPU_CELL_LIST_THRESHOLD + elif periodic: + threshold = NV_CELL_LIST_THRESHOLD + else: + threshold = NV_NONPERIODIC_CELL_LIST_THRESHOLD if nloc >= threshold: return "batch_cell_list" return "batch_naive" @@ -133,7 +183,17 @@ def build( nf, dtype=torch.int32, device=device ).repeat_interleave(nloc) batch_ptr = torch.arange(nf + 1, dtype=torch.int32, device=device) * nloc - method = choose_nv_nlist_method(nloc, periodic=periodic) + method = choose_nv_nlist_method(nloc, periodic=periodic, device=device) + + # ``batch_naive`` otherwise derives ``max_atoms_per_system`` from + # ``batch_ptr`` with a ``.max().item()`` device->host sync on every + # call. Our batches are homogeneous (``nloc`` atoms per frame), so the + # value is known on the host; passing it explicitly removes that + # per-call sync. ``batch_cell_list`` neither accepts the argument nor + # has a ``**kwargs`` catch-all, so the override is guarded on method. + extra_nl_kwargs: dict[str, Any] = {} + if method == "batch_naive": + extra_nl_kwargs["max_atoms_per_system"] = int(nloc) # Grow the search capacity until all neighbors fit so the distance-sort # below selects the true nearest ``sum(sel)``. @@ -149,6 +209,7 @@ def build( max_neighbors=int(search_capacity), return_neighbor_list=False, wrap_positions=False, + **extra_nl_kwargs, ) if len(nlist_result) == 2: neighbor_matrix, num_neighbors = nlist_result diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 19fbe8cebd..ca53bf0c06 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -551,6 +551,16 @@ def descrpt_se_zm_args() -> list[Argument]: "context. When enabled together with `message_node_s2`, the SO(3) " "branch is used for this path." ) + doc_so3_readout = ( + "Read-out FFN mode for the final l=0 descriptor. `none` applies a " + "degree-0 scalar FFN to the l=0 slice only; l>0 coefficients are " + "discarded before the read-out. `glu` and `mlp` apply a full equivariant " + "FFN on the SO(3) Wigner-D grid so l>0 geometry is folded into l=0 " + "before the scalar is extracted; the value selects the quadratic grid " + "product (`glu`) or the polynomial point-wise grid MLP (`mlp`). The " + "read-out degree equals the node degree of the last interaction block; " + "the Wigner-D frame order follows `kmax`." + ) doc_lebedev_quadrature = ( "Either one boolean applied to both S2 branches, or two booleans " "`[so2_enabled, ffn_enabled]` aligned with `s2_activation`. If a branch " @@ -881,6 +891,15 @@ def descrpt_se_zm_args() -> list[Argument]: default=False, doc=doc_only_pt_supported + doc_message_node_so3, ), + Argument( + "so3_readout", + str, + optional=True, + default="none", + extra_check=lambda x: x in ("none", "glu", "mlp"), + extra_check_errmsg="must be one of 'none', 'glu', or 'mlp'", + doc=doc_only_pt_supported + doc_so3_readout, + ), Argument( "lebedev_quadrature", [bool, list[bool]], diff --git a/examples/water/dpa4/input.json b/examples/water/dpa4/input.json index 34e316086a..415f6a5be0 100644 --- a/examples/water/dpa4/input.json +++ b/examples/water/dpa4/input.json @@ -23,6 +23,7 @@ "n_atten_head": 1, "ffn_neurons": 0, "ffn_so3_grid": true, + "so3_readout": "mlp", "grid_mlp": [ false, false,