diff --git a/deepmd/pt/entrypoints/freeze_pt2.py b/deepmd/pt/entrypoints/freeze_pt2.py index c97c4a07ad..78daf3ef65 100644 --- a/deepmd/pt/entrypoints/freeze_pt2.py +++ b/deepmd/pt/entrypoints/freeze_pt2.py @@ -23,8 +23,11 @@ annotations, ) +import ctypes import json import logging +import os +import tempfile import zipfile from copy import ( deepcopy, @@ -229,6 +232,7 @@ def _collect_metadata( output_keys: list[str], is_spin: bool | None = None, do_atomic_virial: bool = False, + has_comm_artifact: bool = False, ) -> dict: """Assemble the flat metadata dict expected by :class:`DeepPotPTExpt`. @@ -272,7 +276,7 @@ def _collect_metadata( "dim_chg_spin": int(model.get_dim_chg_spin()), "mixed_types": bool(model.mixed_types()), "has_message_passing": _model_has_message_passing(model), - "has_comm_artifact": False, + "has_comm_artifact": bool(has_comm_artifact), "do_atomic_virial": exports_atomic_virial, "nnei": int(sum(model.get_sel())), "has_default_fparam": bool(model.has_default_fparam()), @@ -291,16 +295,25 @@ def _collect_metadata( return metadata -def _make_sample_inputs( +# The trace-time sendlist for the with-comm artifact embeds the address of a +# numpy array (``int**`` contract of ``border_op``). The array must outlive the +# trace + export call; the exported graph never reads it at runtime (the op is +# opaque), so a module-level keepalive is sufficient. +_TRACE_SENDLIST_KEEPALIVE: list[np.ndarray] = [] + + +def _build_sample_extended( model: torch.nn.Module, nframes: int, nloc: int, device: torch.device, - has_spin: bool = False, + has_spin: bool, ) -> tuple[torch.Tensor | None, ...]: - """Build representative ``forward_common_lower`` inputs for tracing. + """Build the extended-region sample tensors shared by the lower builders. - Tensors are float64 / int64 (matching the ``.pt2`` I/O contract). + Returns ``(ext_coord, ext_atype, nlist, mapping, ext_spin, fparam, aparam, + charge_spin)``; tensors are float64 / int64 (matching the ``.pt2`` I/O + contract). 
``ext_spin`` is ``None`` unless ``has_spin``. """ rcut = float(model.get_rcut()) sel = list(model.get_sel()) @@ -353,6 +366,7 @@ def _make_sample_inputs( ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=device) nlist_t = torch.tensor(nlist, dtype=torch.int64, device=device) mapping_t = torch.tensor(mapping, dtype=torch.int64, device=device) + ext_spin = None if has_spin: extended_spin = np.take_along_axis(spin_np, mapping[..., None], axis=1) ext_spin = torch.tensor(extended_spin, dtype=torch.float64, device=device) @@ -366,11 +380,45 @@ def _make_sample_inputs( if dim_aparam > 0 else None ) - charge_spin = None - if dim_chg_spin > 0: - charge_spin = torch.zeros( - nframes, dim_chg_spin, dtype=torch.float64, device=device - ) + charge_spin = ( + torch.zeros(nframes, dim_chg_spin, dtype=torch.float64, device=device) + if dim_chg_spin > 0 + else None + ) + return ( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + ext_spin, + fparam, + aparam, + charge_spin, + ) + + +def _make_sample_inputs( + model: torch.nn.Module, + nframes: int, + nloc: int, + device: torch.device, + has_spin: bool = False, +) -> tuple[torch.Tensor | None, ...]: + """Build representative ``forward_common_lower`` inputs for tracing. + + The spin path returns the nlist lower signature; the energy path returns the + single-domain edge schema (folded ``edge_index``, extended scatter indices). + """ + ( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + ext_spin, + fparam, + aparam, + charge_spin, + ) = _build_sample_extended(model, nframes, nloc, device, has_spin) if has_spin: return ( ext_coord, @@ -402,6 +450,83 @@ def _make_sample_inputs( ) +def _make_edge_comm_tensors( + mapping: torch.Tensor, + nloc: int, + device: torch.device, +) -> tuple[torch.Tensor, ...]: + """Build a single self-send swap so the with-comm trace runs ``border_op``. 
+ + A LAMMPS run supplies the real per-swap communication plan at inference time; + the trace only needs valid in-range indices so the eager output-key probe can + execute the opaque op. Ghost slot ``k`` copies its owner's local index + ``mapping[nloc + k]``. + """ + nall = int(mapping.shape[1]) + nghost = nall - nloc + send_count = max(1, nghost) + owner = mapping[0, nloc:nall].to(dtype=torch.int32).cpu().numpy() + indices = np.ascontiguousarray(np.resize(owner, send_count).astype(np.int32)) + _TRACE_SENDLIST_KEEPALIVE.append(indices) + addr = indices.ctypes.data_as(ctypes.c_void_p).value + return ( + torch.tensor([addr], dtype=torch.int64, device=device), # send_list (int**) + torch.zeros(1, dtype=torch.int32, device=device), # send_proc (self) + torch.zeros(1, dtype=torch.int32, device=device), # recv_proc (self) + torch.tensor([send_count], dtype=torch.int32, device=device), # send_num + torch.tensor([send_count], dtype=torch.int32, device=device), # recv_num + torch.zeros(1, dtype=torch.int64, device=device), # communicator + torch.tensor(nloc, dtype=torch.int32, device=device), # nlocal + torch.tensor(nghost, dtype=torch.int32, device=device), # nghost + ) + + +def _make_comm_sample_inputs( + model: torch.nn.Module, + nloc: int, + device: torch.device, +) -> tuple[torch.Tensor | None, ...]: + """Build with-comm edge inputs for tracing the parallel ``.pt2`` artifact. + + The parallel path indexes the extended node set directly, so ``edge_index`` + coincides with ``edge_scatter_index`` (both extended) and ghost features are + refreshed via ``border_op`` rather than gathered through a folded mapping. + The frame axis is fixed at one, matching LAMMPS single-frame inference. 
+ """ + ( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + _ext_spin, + fparam, + aparam, + charge_spin, + ) = _build_sample_extended( + model, nframes=1, nloc=nloc, device=device, has_spin=False + ) + formatted_nlist: torch.Tensor = model.format_nlist(ext_coord, ext_atype, nlist_t) + edge_schema = edge_schema_from_extended( + ext_coord, + ext_atype[:, :nloc], + formatted_nlist, + mapping_t, + ) + return ( + edge_schema.coord, # (1, nall, 3) + edge_schema.atype, # (1, nloc) + ext_atype, # (1, nall) + edge_schema.edge_scatter_index, # edge_index: extended (2, E) + edge_schema.edge_vec, + edge_schema.edge_scatter_index, # edge_scatter_index: extended (2, E) + edge_schema.edge_mask, + fparam, + aparam, + charge_spin, + *_make_edge_comm_tensors(mapping_t, nloc, device), + ) + + def _resolve_nframes( model: torch.nn.Module, nloc: int, @@ -489,6 +614,79 @@ def _build_dynamic_shapes( return shapes +def _build_with_comm_dynamic_shapes( + sample_inputs: tuple[torch.Tensor | None, ...], +) -> tuple: + """Build dynamic-shape constraints for the parallel with-comm lower input. + + The frame axis is fixed at one (LAMMPS single-frame inference), so only + ``nall``, ``nloc`` and ``nedge`` vary. The eight communication tensors are + static: ``nswap`` is fixed at LAMMPS init and the graph carries no variation + across its value (``border_op`` is opaque to the exported program). 
+ """ + nall_dim = torch.export.Dim("nall", min=1) + nloc_dim = torch.export.Dim("nloc", min=1) + nedge_dim = torch.export.Dim("nedge", min=2) + fparam = sample_inputs[7] + aparam = sample_inputs[8] + charge_spin = sample_inputs[9] + base = ( + {1: nall_dim}, # coord: (1, nall, 3) + {1: nloc_dim}, # atype: (1, nloc) + {1: nall_dim}, # extended_atype: (1, nall) + {1: nedge_dim}, # edge_index: (2, nedge) + {0: nedge_dim}, # edge_vec: (nedge, 3) + {1: nedge_dim}, # edge_scatter_index: (2, nedge) + {0: nedge_dim}, # edge_mask: (nedge,) + None if fparam is None else {}, # fparam: (1, ndf) static + None if aparam is None else {1: nloc_dim}, # aparam: (1, nloc, nda) + None if charge_spin is None else {}, # charge_spin: (1, nchg) static + ) + return (*base, *((None,) * 8)) + + +def _export_with_comm_artifact( + model: torch.nn.Module, + *, + target_device: torch.device, + compile_options: dict[str, Any], +) -> bytes: + """Trace, export and compile the parallel with-comm ``.pt2`` artifact. + + The artifact mirrors the regular edge graph but exchanges ghost node + features across ranks via ``border_op``. Returns the compiled package bytes + for nesting under ``model/extra/forward_lower_with_comm.pt2``; tracing runs + on CPU and the package is moved to ``target_device`` before compilation. 
+ """ + from torch._inductor import ( + aoti_compile_and_package, + ) + from torch._inductor import config as inductor_config + + sample_inputs = _make_comm_sample_inputs(model, nloc=7, device=torch.device("cpu")) + traced = model.forward_common_lower_exportable_with_comm(*sample_inputs) + exported = torch.export.export( + traced, + sample_inputs, + dynamic_shapes=_build_with_comm_dynamic_shapes(sample_inputs), + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + _strip_shape_assertions(exported.graph_module) + if target_device.type != "cpu": + from torch.export.passes import ( + move_to_device_pass, + ) + + exported = move_to_device_pass(exported, target_device) + with tempfile.TemporaryDirectory() as td: + wc_path = os.path.join(td, "forward_lower_with_comm.pt2") + with inductor_config.patch({**compile_options, "triton.max_tiles": 1}): + aoti_compile_and_package(exported, package_path=wc_path) + with open(wc_path, "rb") as fh: + return fh.read() + + def freeze_sezm_to_pt2( ckpt_path: str, out_path: str, @@ -637,14 +835,31 @@ def freeze_sezm_to_pt2( with inductor_config.patch({**compile_options, "triton.max_tiles": 1}): aoti_compile_and_package(exported, package_path=out_path_str) + # Second artifact: the LAMMPS multi-rank with-comm graph. It threads the + # eight border_op communication tensors so cross-rank ghost features are + # exchanged between interaction blocks. Excluded for spin (nlist lower + # interface) and bridging models (Source Freeze Propagation is not + # rank-decomposable); those fall back to single-rank inference. 
+ with_comm = (not is_spin) and model.supports_edge_parallel() + with_comm_bytes: bytes | None = None + if with_comm: + with_comm_bytes = _export_with_comm_artifact( + model, + target_device=target_device, + compile_options=compile_options, + ) + metadata = _collect_metadata( model, output_keys=output_keys, is_spin=is_spin, do_atomic_virial=atomic_virial, + has_comm_artifact=with_comm, ) with zipfile.ZipFile(out_path_str, "a") as zf: zf.writestr("model/extra/metadata.json", json.dumps(metadata)) + if with_comm_bytes is not None: + zf.writestr("model/extra/forward_lower_with_comm.pt2", with_comm_bytes) # The raw training params are preserved so `dp change-bias` and # other downstream tooling can recover the exact training config. # ``default=str`` is a safety net for exotic nested values. diff --git a/deepmd/pt/model/descriptor/sezm.py b/deepmd/pt/model/descriptor/sezm.py index 165fef0aaa..eba6ae5b4e 100644 --- a/deepmd/pt/model/descriptor/sezm.py +++ b/deepmd/pt/model/descriptor/sezm.py @@ -1267,16 +1267,30 @@ def forward_with_edges( edge_mask: torch.Tensor, force_embedding: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + nloc: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Compute the descriptor from a sparse edge list. + Two node-set conventions share this path. In the single-domain path + (``comm_dict`` is ``None``) the nodes are exactly the local atoms and + ``edge_index`` source/destination both index ``[0, nf*nloc)``. In the + parallel (LAMMPS multi-rank) path the nodes span the extended region + (local owners followed by ghosts), ``edge_index`` indexes the extended + atoms directly, and each interaction block refreshes ghost-node features + from their owner ranks at the SO(2) convolution input (see + :func:`~deepmd.pt.model.descriptor.sezm_nn.block.exchange_ghost_features`). 
+ Parameters ---------- extended_coord - Coordinates with shape (nf, nloc*3) or (nf, nloc, 3) in Å. + Coordinates with shape (nf, n*3) or (nf, n, 3) in Å, where ``n`` is + ``nloc`` in the single-domain path and ``nall`` in the parallel path. extended_atype - Atom types with shape (nf, nloc). + Atom types with shape (nf, n). In the parallel path this spans the + extended region so ghost type embeddings are available for the + edge-type and environment-seed features. edge_index Edge indices with shape (2, E). edge_vec @@ -1290,6 +1304,13 @@ def forward_with_edges( initial SO(3) backbone state before the interaction blocks. charge_spin Frame-level charge and spin conditions with shape (nf, 2). + comm_dict + Border-exchange tensors for parallel inference. When provided, the + node set spans the extended region and ghost features are exchanged + via ``deepmd_export::border_op`` between interaction blocks. + nloc + Number of owned (local) atoms per frame. Required when ``comm_dict`` + is provided; the final scalar read-out is restricted to these atoms. Returns ------- @@ -1298,13 +1319,32 @@ def forward_with_edges( final equivariant latent with shape ``(nf * nloc, D_final, 1, channels)``. """ # === Step 1. Setup dimensions === + # ``n_per_frame`` is the per-frame node count: ``nloc`` in the + # single-domain path and ``nall`` in the parallel path. ``out_nloc`` is + # the owned-atom count used for the final local read-out. extended_coord = extended_coord.to(self.compute_dtype) - nf, nloc = extended_atype.shape[:2] + nf, n_per_frame = extended_atype.shape[:2] + parallel = comm_dict is not None + if parallel: + # The border exchange and the owned-atom read-out assume one MPI + # rank's single-frame extended layout (LAMMPS, the with-comm export + # trace, and the parity tests all provide it). nf > 1 would silently + # mix frames into wrong forces, so it is rejected outright. 
+ if nf != 1: + raise ValueError("parallel `comm_dict` inference requires nf == 1") + # Imported lazily so plain pt inference never pulls the custom-op + # registration module onto its import path. + from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() + out_nloc = nloc if parallel else n_per_frame + atype_flat = extended_atype.reshape(-1) # (N,) # === Step 2. Type embedding (l=0) === with nvtx_range("type_embedding"): - atype_loc = extended_atype[:, :nloc] # (nf, nloc) - type_ebed = self.type_embedding(atype_loc).reshape( + type_ebed = self.type_embedding(extended_atype).reshape( -1, self.channels ) # (N, C) if self.charge_spin_embedding is not None: @@ -1312,7 +1352,7 @@ def forward_with_edges( type_ebed, charge_spin, nf=nf, - nloc=nloc, + nloc=n_per_frame, ) n_nodes = type_ebed.shape[0] @@ -1320,7 +1360,7 @@ def forward_with_edges( with nvtx_range("build_edge_cache"): edge_cache = build_edge_cache_from_edges( type_ebed=type_ebed, - atype_flat=atype_loc.reshape(-1), + atype_flat=atype_flat, edge_index=edge_index, edge_vec=edge_vec, edge_mask=edge_mask, @@ -1359,7 +1399,6 @@ def forward_with_edges( # === Step 5. Env FiLM conditioning (optional, fp32+) === with nvtx_range("env_film"): if self.use_env_seed: - atype_flat = atype_loc.reshape(-1) # (N,) film = self.env_seed_embedding( edge_cache=edge_cache, atype_flat=atype_flat, @@ -1407,9 +1446,20 @@ def forward_with_edges( x = x + force_embedding.to(dtype=self.dtype) edge_cache = edge_cache_to_dtype(edge_cache, self.dtype) with self._compute_mode_ctx(extended_coord.device): - x = self._forward_blocks(x, edge_cache, rad_feat_per_block) + x = self._forward_blocks( + x, edge_cache, rad_feat_per_block, comm_dict=comm_dict + ) + + # === Step 10. Keep the owned-atom rows for the read-out === + # ``n_out_nodes`` is the owned-node count in the flattened layout + # (``nf * nloc``). 
Single-domain: ``out_nloc == n_per_frame``, so this + # equals the whole node set and the slice is a no-op. Parallel + # (single-frame): it drops the trailing ghost rows that only fed message + # passing -- LAMMPS orders owned atoms before ghosts, so they lead. + n_out_nodes = nf * out_nloc + x = x[:n_out_nodes] - # === Step 10. Final l=0 output mixing === + # === Step 11. Final l=0 output mixing === # ``none`` feeds the l=0 slice only; ``glu``/``mlp`` feed the full # (N, D, 1, C) node tensor so the SO(3) grid folds l>0 into l=0. The # residual is added on the full coefficient tensor before extracting @@ -1418,7 +1468,7 @@ def forward_with_edges( with nvtx_range("output_ffn"): ffn_in = ( x[:, 0:1, :, :] - .reshape(n_nodes, 1, 1, self.channels) + .reshape(n_out_nodes, 1, 1, self.channels) .to(dtype=self.compute_dtype) if self.so3_readout == "none" # truncate to the final node degree: the empty-edge path @@ -1428,8 +1478,8 @@ def forward_with_edges( ) x_scalar = (ffn_in + self.output_ffn(ffn_in))[:, 0:1, :, :] - # === Step 11. Reshape to (nf, nloc, channels) and return === - descriptor = x_scalar.reshape(nf, nloc, self.channels) # (nf, nloc, C) + # === Step 12. Reshape to (nf, nloc, channels) and return === + descriptor = x_scalar.reshape(nf, out_nloc, self.channels) # (nf, nloc, C) return descriptor.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), x.contiguous() def _forward_blocks( @@ -1437,6 +1487,7 @@ def _forward_blocks( x: torch.Tensor, edge_cache: EdgeFeatureCache, radial_feat_per_block: list[torch.Tensor], + comm_dict: dict[str, torch.Tensor] | None = None, ) -> torch.Tensor: """ Run the interaction blocks with optional depth attention. @@ -1449,6 +1500,12 @@ def _forward_blocks( Per-edge cache. radial_feat_per_block List of per-block radial features already truncated to l_schedule[i]+1. + comm_dict + Border-exchange tensors for parallel inference, forwarded to each + block. 
The block refreshes ghost rows at the SO(2) convolution + input — the descriptor's only cross-node operation — so message + passing always reads up-to-date neighbours regardless of the + (per-node) attention-residual history. Returns ------- @@ -1461,7 +1518,12 @@ def _forward_blocks( x = x[:, : self.node_ebed_dims[i], :, :] blk_radial = radial_feat_per_block[i] with nvtx_range(f"block_{i}"): - x, _, _, _ = block(x, edge_cache, blk_radial) + x, _, _, _ = block( + x, + edge_cache, + blk_radial, + comm_dict=self._block_comm(i, comm_dict), + ) return x n_node = x.shape[0] @@ -1488,6 +1550,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: edge_cache, blk_radial, unit_history=truncated_unit_history, + comm_dict=self._block_comm(i, comm_dict), ) unit_history.append(so2_unit_output) unit_history.extend(ffn_unit_outputs) @@ -1520,6 +1583,7 @@ def node_l0_extractor(v: torch.Tensor) -> torch.Tensor: edge_cache, blk_radial, unit_history=truncated_block_history, + comm_dict=self._block_comm(i, comm_dict), ) block_history.append(block_summary) x = block_output @@ -1798,6 +1862,29 @@ def _canonicalize_charge_spin( raise ValueError("`charge_spin` first dimension must match nframes.") return charge_spin + def _block_comm( + self, + block_idx: int, + comm_dict: dict[str, torch.Tensor] | None, + ) -> dict[str, torch.Tensor] | None: + """Return the border-exchange tensors block ``block_idx`` actually needs. + + Only the SO(2) convolution reads neighbour features, so a block needs the + ghost exchange exactly when its neighbour rows cannot be rebuilt locally. + Block 0 reads the initial node state: a rank reproduces its ghost rows + from ``extended_atype`` (type embedding) unless env-seed / GIE folds + neighbour-environment information into them. Every later block reads a + previous block's output, which a rank cannot reproduce for ghosts (they + receive no messages locally). 
Returning ``None`` skips the exchange, so a + purely local model (``use_env_seed=False`` with a single block) runs + multi-rank with no communication at all. + """ + if comm_dict is None: + return None + if block_idx == 0 and not self.use_env_seed: + return None + return comm_dict + @contextmanager def _compute_mode_ctx(self, device: torch.device) -> Generator[None, None, None]: """ @@ -1884,7 +1971,22 @@ def mixed_types(self) -> bool: return True def has_message_passing(self) -> bool: - return bool(len(self.blocks) > 0 and self.lmax > 0) + # SeZM resolves ghost neighbours through the atom-map fold (single + # domain) or border_op exchange (parallel) instead of reading them + # directly, so its lower path always needs message-passing handling. + return True + + def has_message_passing_across_ranks(self) -> bool: + """Whether multi-rank inference needs cross-rank ghost-feature exchange. + + SeZM reads ghost-neighbour features at every interaction block, so a + domain-decomposed run must exchange them through ``border_op``. Source + Freeze Propagation bridging is excluded: its per-node gate folds a + node's entire outgoing-edge set, which a single rank cannot observe for + ghost owners, so the edge-based with-comm artifact is not exported for + bridging models and multi-rank inference fails fast instead. + """ + return self.bridging_switch is None def need_sorted_nlist_for_lower(self) -> bool: return False diff --git a/deepmd/pt/model/descriptor/sezm_nn/block.py b/deepmd/pt/model/descriptor/sezm_nn/block.py index 5857f173a9..c38fc78a5b 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/block.py +++ b/deepmd/pt/model/descriptor/sezm_nn/block.py @@ -63,6 +63,59 @@ ) +def exchange_ghost_features( + x: torch.Tensor, + comm_dict: dict[str, torch.Tensor], +) -> torch.Tensor: + """ + Refresh ghost-node features from their owner ranks via MPI border exchange. 
+ + SeZM node features are SO(3) coefficients expressed in the shared global + frame, so a ghost atom and its owner carry identical features and the + per-row owner->ghost copy is exact and equivariance-preserving. The opaque + ``deepmd_export::border_op`` performs the exchange and carries a registered + backward (reverse communication of gradients), so a single + ``autograd.grad(energy, edge_vec)`` accumulates cross-rank force + contributions when every rank runs the exchange in lockstep. + + This is applied to the SO(2) convolution input — the descriptor's only + cross-node operation — so ghost rows are correct exactly where message + passing reads them, regardless of how the (per-node) attention-residual + history that produced the input populated its ghost rows. + + Parameters + ---------- + x + Extended node features with shape (nall, D, 1, channels). Owned-atom rows + hold up-to-date values; ghost rows are overwritten by this call. + comm_dict + Border-exchange tensors ``send_list``, ``send_proc``, ``recv_proc``, + ``send_num``, ``recv_num``, ``communicator``, ``nlocal``, ``nghost``. + + Returns + ------- + torch.Tensor + Node features with ghost rows filled, same shape as ``x``. + """ + n_nodes, ebed_dim, n_focus, channels = x.shape + # border_op exchanges whole rows by raw pointer arithmetic, so the buffer + # must be contiguous; a degree-truncated node tensor reshapes to a strided + # view that would otherwise corrupt the exchange. + g1 = x.reshape(n_nodes, ebed_dim * n_focus * channels).contiguous() + g1 = torch.ops.deepmd_export.border_op( + comm_dict["send_list"], + comm_dict["send_proc"], + comm_dict["recv_proc"], + comm_dict["send_num"], + comm_dict["recv_num"], + g1, + comm_dict["communicator"], + comm_dict["nlocal"], + comm_dict["nghost"], + ) + return g1.reshape(n_nodes, ebed_dim, n_focus, channels) + + class SeZMInteractionBlock(nn.Module): """ SeZM interaction block with SO(2) message passing and equivariant FFN stack. 
@@ -586,6 +639,7 @@ def forward( edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, unit_history: list[torch.Tensor] | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -606,6 +660,12 @@ def forward( `full_attn_res != "none"`, it is interpreted as completed unit history. When `block_attn_res != "none"`, it is interpreted as completed block history. + comm_dict + Border-exchange tensors for parallel (LAMMPS multi-rank) inference. + When provided, the SO(2) convolution input has its ghost rows + refreshed from owner ranks; the depth-attention history may carry + stale ghost rows because the exchange happens at the convolution + input, after the (per-node) aggregation that consumes it. Returns ------- @@ -619,7 +679,7 @@ def forward( - full AttnRes path returns `(block_output, None, so2_unit_output, ffn_unit_outputs)` - block AttnRes path returns `(block_output, block_summary, None, None)` """ - return self._forward_impl(x, edge_cache, radial_feat, unit_history) + return self._forward_impl(x, edge_cache, radial_feat, unit_history, comm_dict) def _extract_l0_from_canonical(self, value: torch.Tensor) -> torch.Tensor: """ @@ -657,6 +717,7 @@ def _run_so2_unit( x: torch.Tensor, edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> torch.Tensor: """ Run the SO(2) unit without an outer block-level residual shortcut. @@ -669,12 +730,19 @@ def _run_so2_unit( Edge cache. radial_feat Per-edge radial features with shape (E, lmax+1, C). + comm_dict + Border-exchange tensors for parallel inference. When provided, the + convolution input's ghost rows are refreshed from owner ranks + immediately before the only cross-node operation in the block, so + owned destinations gather up-to-date neighbour features. Returns ------- torch.Tensor SO(2) unit output with shape `(N, D, 1, C)`. 
""" + if comm_dict is not None: + x = exchange_ghost_features(x, comm_dict) if self._use_infer_activation_checkpoint(x, radial_feat): edge_cache_no_proj = edge_cache._replace( D_to_m_cache=None, @@ -759,6 +827,7 @@ def _forward_with_residual_shortcuts( edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, unit_history: list[torch.Tensor] | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -778,6 +847,10 @@ def _forward_with_residual_shortcuts( Per-edge radial features with shape (E, lmax+1, C). unit_history Unused in the residual-connected path. + comm_dict + Border-exchange tensors for parallel inference, forwarded to the + SO(2) unit. The owned-atom residual reads the original ``x``, which + is already correct on owned rows. Returns ------- @@ -785,7 +858,7 @@ def _forward_with_residual_shortcuts( Tuple `(block_output, None, None, None)`. """ with nvtx_range("so2_conv"): - so2_unit_output = self._run_so2_unit(x, edge_cache, radial_feat) + so2_unit_output = self._run_so2_unit(x, edge_cache, radial_feat, comm_dict) so2_state = x + so2_unit_output with nvtx_range("ffn"): @@ -803,6 +876,7 @@ def _forward_with_full_attn_res( edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, unit_history: list[torch.Tensor] | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -823,6 +897,11 @@ def _forward_with_full_attn_res( unit_history Truncated history in canonical node layout. Each source has shape `(N, D, 1, C)`. + comm_dict + Border-exchange tensors for parallel inference, forwarded to the + SO(2) unit. The attention-residual aggregation is per-node, so the + ghost exchange at the convolution input restores ghost correctness + even when the history sources carry stale ghost rows. 
Returns ------- @@ -836,7 +915,9 @@ def _forward_with_full_attn_res( scalar_extractor=self._extract_l0_from_canonical, current_x=x, ) - so2_unit_output = self._run_so2_unit(so2_input, edge_cache, radial_feat) + so2_unit_output = self._run_so2_unit( + so2_input, edge_cache, radial_feat, comm_dict + ) with nvtx_range("ffn"): completed_units = [*unit_history, so2_unit_output] @@ -863,6 +944,7 @@ def _forward_with_block_attn_res( edge_cache: EdgeFeatureCache, radial_feat: torch.Tensor, unit_history: list[torch.Tensor] | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -883,6 +965,11 @@ def _forward_with_block_attn_res( unit_history Truncated block history in canonical node layout. Each source has shape `(N, D, 1, C)`. + comm_dict + Border-exchange tensors for parallel inference, forwarded to the + SO(2) unit. The attention-residual aggregation is per-node, so the + ghost exchange at the convolution input restores ghost correctness + even when the history sources carry stale ghost rows. 
Returns ------- @@ -896,7 +983,9 @@ def _forward_with_block_attn_res( scalar_extractor=self._extract_l0_from_canonical, current_x=x, ) - so2_unit_output = self._run_so2_unit(so2_input, edge_cache, radial_feat) + so2_unit_output = self._run_so2_unit( + so2_input, edge_cache, radial_feat, comm_dict + ) with nvtx_range("ffn"): partial_block = so2_unit_output diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index 6537f9d84e..c1bc494eec 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -1048,6 +1048,7 @@ def forward_common_lower( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, input_prec: torch.dtype | None = None, @@ -1061,8 +1062,15 @@ def forward_common_lower( descriptor. ``edge_scatter_index`` defines the force/virial scatter domain, which may be local atoms for Python inference or local-plus-ghost slots for LAMMPS. + + ``comm_dict`` selects the parallel (LAMMPS multi-rank) path: the + descriptor runs on the extended node set and exchanges ghost features via + ``border_op`` between blocks. ``extended_atype`` (shape ``(nf, nall)``) + is then required so the descriptor can embed ghost atom types; ``atype`` + stays local ``(nf, nloc)`` for fitting and the energy read-out. The + parallel path always runs eager: its compiled artifact is produced + separately by the ``.pt2`` with-comm export, not by the runtime cache. 
""" - del comm_dict coord, _, fp, ap, inferred_input_prec = self._input_type_cast( coord, fparam=fparam, @@ -1073,6 +1081,8 @@ def forward_common_lower( if coord.ndim == 2: coord = coord.reshape(atype.shape[0], -1, 3) atype = atype.to(device=coord.device, dtype=torch.long) + if extended_atype is not None: + extended_atype = extended_atype.to(device=coord.device, dtype=torch.long) edge_index = edge_index.to(device=coord.device, dtype=torch.long) edge_vec = edge_vec.to(device=coord.device, dtype=coord.dtype) edge_scatter_index = edge_scatter_index.to( @@ -1089,6 +1099,13 @@ def forward_common_lower( should_compile = ( self.should_use_compile() if use_compile is None else use_compile ) + if comm_dict is not None: + if extended_atype is None: + raise ValueError( + "`extended_atype` (nf, nall) is required when `comm_dict` " + "is provided." + ) + should_compile = False charge_spin = self.convert_charge_spin( charge_spin, nf=nf, @@ -1199,6 +1216,8 @@ def forward_common_lower( fparam=fp, aparam=ap, charge_spin=charge_spin, + comm_dict=comm_dict, + extended_atype=extended_atype, extended_coord_corr=extended_coord_corr, embedding_only=embedding_only, ) @@ -1341,6 +1360,7 @@ def core_compute( aparam: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, + extended_atype: torch.Tensor | None = None, extended_coord_corr: torch.Tensor | None = None, embedding_only: bool = False, ) -> dict[str, torch.Tensor]: @@ -1373,7 +1393,13 @@ def core_compute( charge_spin Frame-level charge and spin conditions with shape `(nf, 2)`. comm_dict - Communication data for parallel inference. Currently unused. + Border-exchange tensors for parallel (LAMMPS multi-rank) inference. + When provided, the descriptor runs on the extended node set and + exchanges ghost features via ``border_op`` between blocks. + extended_atype + Extended atom types with shape ``(nf, nall)``. 
Required when + ``comm_dict`` is provided so the descriptor can embed ghost types; + unused otherwise. extended_coord_corr Coordinates correction for virial with shape ``(nf, nscatter, 3)`` or ``None``. @@ -1391,7 +1417,6 @@ def core_compute( ``descriptor`` (nf, nloc, d), ``atomic_feature`` (nf, nloc, h), and ``structural_feature`` (nf, h). """ - del comm_dict nf, nloc = atype.shape[:2] nscatter = coord.shape[1] descriptor_model = self.atomic_model.descriptor @@ -1408,14 +1433,22 @@ def core_compute( edge_vec = edge_vec.detach().requires_grad_(True) # === Step 2. Descriptor forward === + # ``extended_atype`` spans the extended region on the parallel path and + # reduces to ``atype`` (owned atoms) on the single-domain path; the + # descriptor returns the per-owner descriptor ``(nf, nloc, channels)`` + # either way. ``comm_dict`` (possibly ``None``) and ``nloc`` are + # forwarded unconditionally -- ``forward_with_edges`` ignores ``nloc`` + # without ``comm_dict``, and ``extended_coord`` only supplies the device. with nvtx_range("SeZM/descriptor"): descriptor, _ = descriptor_model.forward_with_edges( - extended_coord=coord[:, :nloc, :], - extended_atype=atype, + extended_coord=coord, + extended_atype=extended_atype if comm_dict is not None else atype, edge_index=edge_index, edge_vec=edge_vec, edge_mask=edge_mask, charge_spin=charge_spin, + comm_dict=comm_dict, + nloc=nloc, ) # === Atom mask === @@ -1632,6 +1665,7 @@ def forward_lower( do_atomic_virial: bool = False, comm_dict: dict[str, torch.Tensor] | None = None, charge_spin: torch.Tensor | None = None, + extended_atype: Int[Tensor, "nf nall"] | None = None, ) -> dict[str, torch.Tensor]: """ Lower-level public forward using the compact-edge contract. @@ -1658,9 +1692,13 @@ def forward_lower( do_atomic_virial Whether to compute atomic virial. comm_dict - Communication dict forwarded to `forward_common_lower()`. + Communication dict forwarded to `forward_common_lower()`. 
When + present, selects the parallel (LAMMPS multi-rank) path. charge_spin Frame-level charge and spin conditions with shape `(nf, 2)`. + extended_atype + Extended atom types with shape (nf, nall). Required on the parallel + path so ghost atom types can be embedded; ignored otherwise. Returns ------- @@ -1692,6 +1730,7 @@ def forward_lower( fparam=fparam, aparam=aparam, comm_dict=comm_dict, + extended_atype=extended_atype, charge_spin=charge_spin, ) if self.get_fitting_net() is not None: @@ -2375,6 +2414,130 @@ def fn( *trace_inputs, ) + def forward_common_lower_exportable_with_comm( + self, + coord: torch.Tensor, + atype: torch.Tensor, + extended_atype: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_scatter_index: torch.Tensor, + edge_mask: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + charge_spin: torch.Tensor | None, + send_list: torch.Tensor, + send_proc: torch.Tensor, + recv_proc: torch.Tensor, + send_num: torch.Tensor, + recv_num: torch.Tensor, + communicator: torch.Tensor, + nlocal: torch.Tensor, + nghost: torch.Tensor, + ) -> torch.nn.Module: + """Trace the parallel lower interface into an exportable FX ``GraphModule``. + + This mirrors :meth:`forward_common_lower_exportable` but threads the + extended atom types and the eight ``border_op`` communication tensors as + explicit positional inputs, fixing the C++ ABI that ``DeepPotPTExpt`` + uses for the multi-rank with-comm ``.pt2`` artifact. The opaque + ``deepmd_export::border_op`` is registered for tracing before + ``make_fx`` runs so its forward and reverse exchanges enter the graph as + external calls. + """ + if self.get_active_mode() == "dens": + raise NotImplementedError( + "SeZM export supports only the conservative `ener` path." + ) + # Imported lazily so plain pt inference never pulls the custom-op + # registration module onto its import path. 
+ from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, + ) + + ensure_comm_registered() + + model = self + + def fn( + coord_: torch.Tensor, + atype_: torch.Tensor, + extended_atype_: torch.Tensor, + edge_index_: torch.Tensor, + edge_vec_: torch.Tensor, + edge_scatter_index_: torch.Tensor, + edge_mask_: torch.Tensor, + fparam_: torch.Tensor | None, + aparam_: torch.Tensor | None, + charge_spin_: torch.Tensor | None, + send_list_: torch.Tensor, + send_proc_: torch.Tensor, + recv_proc_: torch.Tensor, + send_num_: torch.Tensor, + recv_num_: torch.Tensor, + communicator_: torch.Tensor, + nlocal_: torch.Tensor, + nghost_: torch.Tensor, + ) -> dict[str, torch.Tensor]: + # Detach inside the closure so the exported graph roots at the + # per-edge ``edge_vec`` leaf created in ``core_compute``, never at + # the upstream LAMMPS coordinate tensor. + coord_ = coord_.detach() + edge_vec_ = edge_vec_.detach() + comm_dict = { + "send_list": send_list_, + "send_proc": send_proc_, + "recv_proc": recv_proc_, + "send_num": send_num_, + "recv_num": recv_num_, + "communicator": communicator_, + "nlocal": nlocal_, + "nghost": nghost_, + } + return model.forward_common_lower( + coord_, + atype_, + edge_index_, + edge_vec_, + edge_scatter_index_, + edge_mask_, + fparam=fparam_, + aparam=aparam_, + comm_dict=comm_dict, + extended_atype=extended_atype_, + charge_spin=charge_spin_, + use_compile=False, + ) + + if self.get_dim_chg_spin() > 0: + charge_spin = self.convert_charge_spin( + charge_spin, + nf=atype.shape[0], + dtype=coord.dtype, + device=coord.device, + ) + trace_inputs = ( + coord, + atype, + extended_atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fparam, + aparam, + charge_spin, + send_list, + send_proc, + recv_proc, + send_num, + recv_num, + communicator, + nlocal, + nghost, + ) + return self._trace_lower_exportable(fn, *trace_inputs) + # ========================================================================= # Neighbor List Construction # 
========================================================================= @@ -2726,6 +2889,20 @@ def has_message_passing(self) -> bool: """Return whether the descriptor performs message passing.""" return self.atomic_model.has_message_passing() + def supports_edge_parallel(self) -> bool: + """Whether the edge-based LAMMPS multi-rank with-comm artifact applies. + + Cross-rank ghost-feature exchange is well-defined only for the + conservative non-bridging path: analytical ZBL bridging and its Source + Freeze Propagation gate fold each node's full outgoing-edge set, which a + single rank cannot observe for ghost owners. Spin models use the nlist + lower interface and are gated separately by the freeze entry point. + """ + if self.inter_potential is not None: + return False + descriptor = self.atomic_model.descriptor + return bool(descriptor.has_message_passing_across_ranks()) + # ========================================================================= # Mode Management # ========================================================================= diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 6dc9949f42..2901b2ac2d 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -5,45 +5,55 @@ ::: DPA4 is the DeePMD-kit implementation of the SeZM (Smooth Equivariant -Zone-bridging Model) architecture. Use `model.type: "dpa4"` in new input -files. The aliases `DPA4`, `SeZM`, and `sezm` are accepted for the same -implementation. The DPA4 model scaffold uses the SeZM descriptor and the -`dpa4_ener` fitting network, so `descriptor.type` and `fitting_net.type` -may be omitted in ordinary DPA4 inputs. +Zone-bridging Model) architecture: an SO(3)-equivariant message-passing model +for conservative interatomic potentials. The aliases `DPA4`, `SeZM`, and +`sezm` all select the same implementation. 
+ +`model.type: "dpa4"` is a convenience scaffold that fixes the SeZM descriptor +and the `dpa4_ener` energy fitting network, so `descriptor.type` and +`fitting_net.type` may be omitted. A new input then needs only the model type, +`type_map`, and a few descriptor options. Reference: [DPA4 paper](https://arxiv.org/abs/2606.02419). -Training example: `examples/water/dpa4/input.json`. +## Quick start -Quick start: +DPA4 is a PyTorch-only model. Train it with the standard `dp --pt` workflow: ```bash cd examples/water/dpa4 dp --pt train input.json ``` +`examples/water/dpa4/input.json` is a complete, compact training input you can +copy and adapt. See [training energy models](train-energy.md) for the general +training workflow shared by all energy models. + ## Overview -DPA4/SeZM is an SO(3)-equivariant message-passing model for conservative -interatomic potentials. It predicts atomic energies and obtains forces and -virials by differentiating the energy, following the same conservative -formulation used by standard DeePMD energy models: +DPA4/SeZM predicts atomic energies and obtains forces and virials by +differentiating the energy, the same conservative formulation used by standard +DeePMD energy models: ```math \mathbf{F}_i = -\frac{\partial E}{\partial \mathbf{r}_i}. ``` -The model keeps vector and higher-order angular information while building -the descriptor. Only the final descriptor sent to the fitting network is -scalar. This separates geometric representation from energy prediction: -equivariant layers encode local environments, and the fitting network maps -the resulting scalar features to atomic energies. +Internally the model keeps vector and higher-order angular (SO(3)-equivariant) +information while building the descriptor, and only the final descriptor sent +to the fitting network is a scalar. 
This separates geometric representation +(equivariant message-passing layers that encode local environments) from +energy prediction (the fitting network that maps scalar features to atomic +energies). The architecture targets a favorable accuracy–cost trade-off; if you +want the design details, see [Architecture details](#architecture-details) at +the end of this page. -## Model scaffold +## Configuration -The DPA4 model type is a convenience scaffold around the SeZM descriptor and -the `dpa4_ener` energy fitting network. A minimal input therefore only needs -the model type, `type_map`, and descriptor settings such as `sel` and `rcut`: +### Minimal input + +A minimal DPA4 model only needs the model type, `type_map`, and the descriptor +cutoff radius `rcut`. Every other option uses its documented default. ```json { @@ -54,237 +64,59 @@ the model type, `type_map`, and descriptor settings such as `sel` and `rcut`: "H" ], "descriptor": { - "sel": 120, "rcut": 6.0 } } } ``` -Options that are not written in the input use their documented defaults. -The neighbor selection `sel` may be an integer total neighbor limit, a -per-type list, or `auto` / `auto:factor`. - -Internally, the PyTorch model builds a standard DeePMD neighbor list for the -public forward path. When `use_compile` is enabled, the model additionally -uses a compact sparse-edge path for compiled training. Both paths share the -same descriptor and fitting definitions. - -## Descriptor construction - -For each frame, DPA4/SeZM first builds a local neighbor graph within cutoff -radius `rcut`. Each edge stores the displacement vector, smooth cutoff -weights, radial basis features, and the rotation between the global frame and -an edge-aligned local frame. These edge features are built once per forward -call and reused by all interaction blocks. - -One DPA4/SeZM interaction block consists of the following operations: - -1. Gather source-atom equivariant features on each edge. -1. 
Rotate them into the edge-local frame. -1. Apply SO(2)-equivariant convolution on the retained angular orders. -1. Rotate messages back to the global frame. -1. Aggregate messages at destination atoms with smooth envelope weights or - attention weights. -1. Update atom features with an equivariant feed-forward block. - -After the last block, DPA4/SeZM keeps the `l = 0` scalar channels: - -```math -\mathcal{D}_i = \mathrm{Scalar}\left(\mathbf{h}_i^{(L)}\right), -``` - -where $\mathbf{h}_i^{(L)}$ is the final equivariant feature of atom `i`. - -## Angular representation - -DPA4/SeZM stores intermediate features as SO(3)-equivariant coefficients. A -feature block with maximum degree `lmax` contains all degrees -`l = 0, ..., lmax`, and each degree has `2l + 1` angular components. - -The model reduces angular cost by working in a local frame on each edge. In -that frame, rotations around the edge axis become SO(2) operations. The SO(2) -convolution retains orders `|m| <= mmax`, or the per-block value specified by -`m_schedule`, while preserving the required equivariant transformation -behavior. - -Two schedules control the angular width: - -- `l_schedule` sets the SO(3) degree used by each block. A schedule such as - `[3, 3, 2]` uses higher degrees in early blocks and truncates them in later - blocks. -- `mmax` or `m_schedule` sets how many SO(2) orders are retained in the - edge-local convolution. +DPA4/SeZM defaults to `float32` +({ref}`precision `); double precision is +unnecessary and not recommended (see [Hardware selection](#hardware-selection)). -The angular schedule is one of the primary accuracy-cost controls in -DPA4/SeZM. Larger angular spaces can represent more complex local chemistry, -but the cost grows quickly with `lmax`. For many systems, a non-increasing -`l_schedule` provides a practical compromise. - -## Radial basis and smooth cutoff - -Every edge uses a radial basis multiplied by a smooth envelope. 
The default -basis is Bessel-like, and a Gaussian basis is also available through -`basis_type`. The cutoff envelope is constructed so that its value and first -three derivatives vanish at `rcut`. This smoothness is important for -molecular dynamics because nonsmooth descriptor cutoffs would be inherited by -force derivatives. - -DPA4/SeZM uses two envelope exponents through `env_exp`: - -- the first exponent controls the radial basis envelope, -- the second exponent controls message-passing edge weights. - -Increasing an exponent keeps the corresponding envelope closer to one for -more of the cutoff range before it drops near `rcut`. - -## Attention and focus streams - -DPA4/SeZM can aggregate edge messages either by envelope-weighted scatter or -by attention. When attention is enabled with `n_atten_head > 0`, the cutoff -envelope also participates in the softmax normalization. Edges near the -cutoff are therefore smoothly suppressed in both the numerator and the -denominator, avoiding nonsmooth contributions from the normalization term. - -The SO(2) convolution can also use multiple focus streams through `n_focus`. -These streams process the same edge geometry in parallel and are then -combined through scalar weights. This design is not a sparse mixture of -experts: all focus streams are evaluated before soft reweighting. The -additional capacity helps the convolution distinguish different local -patterns while preserving equivariance. - -## Grid nonlinearities - -Several DPA4/SeZM branches can use sphere-grid or SO(3)-grid nonlinearities -inside the equivariant network. The most commonly used public switches are: - -- `s2_activation`, which enables S2-grid nonlinearities for the SO(2) branch - and/or the block-internal feed-forward branch. -- `ffn_so3_grid`, which uses an SO(3) Wigner-D grid in the block-internal - feed-forward path. -- `lebedev_quadrature`, which selects packaged Lebedev quadrature rules for - enabled S2-grid branches. 
-- `grid_mlp` and `grid_branch`, which select the polynomial point-wise MLP or - the scalar-routed polynomial branch mixer for each grid path. Each is either - a single value applied to every path or a list - `[node_wise, message_node, ffn]`. - -These options affect the expressiveness and cost of the equivariant -nonlinearity. The final `l = 0` output descriptor remains a scalar feature -tensor consumed by the fitting network. - -## Environment-seeded initial features - -When `use_env_seed` is enabled, DPA4/SeZM seeds the initial node state from -the local environment before the equivariant message-passing blocks. The -scalar seed uses a DeePMD-style local environment matrix with radial -information and normalized directions, then produces FiLM-like scale and -shift values for the first scalar features. When non-scalar degrees are -present, the same switch also enables the geometric initial embedding. - -When `use_env_seed` is disabled, the initial node state contains only -atom-local scalar features before message passing. This keeps a one-block -model closed over the one-hop neighbor shell. - -## Zone bridging and ZBL - -DPA4/SeZM includes an optional short-range bridge for analytical repulsion. -The typical use case is ZBL: - -```math -E_i = E_i^{\mathrm{DPA4/SeZM}} + E_i^{\mathrm{ZBL}}. -``` - -The purpose of zone bridging is to combine the analytical short-range -repulsion with the learned model while preventing uncontrolled learned forces -in the same protected region. - -Zone bridging has two pieces: - -1. Distances below `bridging_r_inner` are clamped before they enter the - descriptor. Between `bridging_r_inner` and `bridging_r_outer`, a smooth - polynomial transitions back to the true distance. -1. A source gate suppresses message propagation from atoms involved in frozen - short-range pairs. This blocks multi-hop leakage, where a third atom could - otherwise carry information about the frozen pair back into the learned - energy. 
- -This gives a controlled decomposition in the protected region: - -```math -E_\mathrm{total}(r) = E_\mathrm{ZBL}(r) + E_\mathrm{model}(\tilde r), -``` - -where $r$ is the true distance and $\tilde r$ is the clamped distance seen by -the descriptor. - -Enable zone bridging with: - -```json -{ - "model": { - "bridging_method": "zbl", - "bridging_r_inner": 0.5, - "bridging_r_outer": 0.8 - } -} -``` - -When ZBL bridging is enabled, set `training.training_data.min_pair_dist` to -the same value as `bridging_r_inner` so that frames with shorter atom pairs -are excluded from training. See `examples/water/dpa4/input-zbl.json` for a -complete ZBL input example. - -## Fitting network - -DPA4/SeZM uses the `dpa4_ener` energy fitting implementation. It is selected -automatically by the DPA4 model scaffold and maps scalar descriptors to atomic -energies. - -The fitting network uses the same common keys as DeePMD's standard energy -fitting network: - -- `neuron` -- `activation_function` -- `precision` -- `seed` -- `numb_fparam` -- `numb_aparam` - -The hidden layers use GLU-style transformations. If `neuron` is `[0]`, the -fitting network uses a direct projection from descriptor channels to atomic -energy. This compact setting is useful for small examples and quick -validation tests. - -For shared-fitting multitask training, DPA4/SeZM supports case embeddings. -With `case_film_embd: true`, the case vector modulates the fitting network -instead of being concatenated directly to the descriptor. This keeps the -descriptor case-independent while allowing the energy map to depend on the -task branch. - -## Configuration - -For a complete training input, see `examples/water/dpa4/input.json`. The -example uses a compact water setup with the DPA4 model type, SeZM descriptor -options, `dpa4_ener` fitting settings, and the standard conservative energy -loss. Its structure is closer to a DPA4-Neo-style compact configuration than -to the DPA4-Air pretrained configuration. 
- -Common descriptor controls include: - -- `sel` and `rcut` for the neighbor list. -- `channels`, `n_radial`, and `basis_type` for feature width and radial - resolution. -- `lmax`, `l_schedule`, `mmax`, and `m_schedule` for angular resolution. -- `n_blocks`, `so2_layers`, and `ffn_blocks` for network depth. -- `n_focus` and `n_atten_head` for focus streams and attention aggregation. -- `use_env_seed`, `s2_activation`, `ffn_so3_grid`, and `message_node_so3` for - the main geometric feature paths. -- `use_amp` and `precision` for training precision. - -## Training modes +:::{note} +{ref}`sel ` behaves differently from classic +descriptors. On the conservative **energy** path it is only an initial +neighbor-search capacity that grows on demand, so it never truncates the +neighbor list and you do not need to size it to the true maximum neighbor count. +Only the denoising (`dens`) and spin paths cap the list at `sum(sel)`. You can +also set `sel` to `auto` or `auto:factor` to size it from the training data. +::: -The recommended training objective is the standard conservative energy loss: +### Main options + +Every descriptor option, with its default and full description, is listed in +the {ref}`argument reference `. The options worth +tuning first group into four accuracy–cost levers: + +- **Angular width** — the primary control. {ref}`lmax ` + with the per-block pyramid {ref}`l_schedule ` + (which overrides `lmax` and `n_blocks`), and the SO(2) order + {ref}`mmax ` / + {ref}`m_schedule `. Cost grows + quickly with `lmax`; a non-increasing `l_schedule` is often a good compromise. +- **Depth** — {ref}`n_blocks `, + {ref}`so2_layers `, + {ref}`ffn_blocks `. +- **Width** — {ref}`channels `, + {ref}`n_radial `. +- **Aggregation** — {ref}`n_focus `, + {ref}`n_atten_head ` (`0` falls + back to a plain envelope-weighted scatter). 
+ +The neighbor list is set by {ref}`rcut ` and +{ref}`sel `, the initial node features by +{ref}`use_env_seed `, and the energy +head by the fitting `neuron` list (`[0]` is a direct projection). The quickest +starting point is to copy `examples/water/dpa4/input.json` and adjust the +levers above. + +## Training + +### Energy training (default) + +The recommended objective is the standard conservative energy loss. The model +predicts energies and forces are obtained by autograd: ```json { @@ -294,11 +126,11 @@ The recommended training objective is the standard conservative energy loss: } ``` -In this mode, the model predicts energies, and forces are computed by -autograd. See [training energy models](train-energy.md) for the general -energy-training workflow. +See [training energy models](train-energy.md) for the general workflow. -DPA4/SeZM also has an experimental direct-force denoising mode selected by: +### Direct-force denoising (`dens`, experimental) + +DPA4/SeZM has an experimental direct-force denoising head: ```json { @@ -308,14 +140,13 @@ DPA4/SeZM also has an experimental direct-force denoising mode selected by: } ``` -Use `dens` only when the direct-force denoising head is required. It is not -the default training path. See `examples/water/dpa4/input_dens.json` for an -example input. +Use `dens` only when the direct-force denoising head is required; it is not the +default training path. See `examples/water/dpa4/input_dens.json` for an example. -## Spin +### Spin -DPA4/SeZM supports the DeePMD-kit spin convention in the PyTorch backend. -Keep the DPA4/SeZM type string and add the standard `model.spin` block: +DPA4/SeZM supports the DeePMD-kit spin convention. Keep the model type and add +the standard `model.spin` block: ```json { @@ -342,237 +173,212 @@ Keep the DPA4/SeZM type string and add the standard `model.spin` block: } ``` -The spin path supports the conservative `ener_spin` loss. The direct-force -denoising mode is not used together with spin. 
See -[training spin energy models](train-energy-spin.md) for the common spin -training settings, and `examples/water/dpa4/input-spin.json` for a DPA4-style -input example. +The spin path uses the conservative `ener_spin` loss and is not combined with +the `dens` mode. See [training spin energy models](train-energy-spin.md) and +`examples/water/dpa4/input-spin.json`. -## Performance and hardware recommendations +### Multi-task / shared fitting -### bfloat16 automatic mixed precision +DPA4/SeZM supports shared-fitting multitask training. With +`case_film_embd: true`, the case vector modulates the fitting network instead +of being concatenated to the descriptor, which keeps the descriptor +case-independent while letting the energy map depend on the task branch. See +[multi-task training](../train/multi-task-training.md) for the workflow and +`examples/water/dpa4/input_multitask.json` for an example. -DPA4/SeZM supports automatic mixed precision (AMP) during training through the -descriptor option `use_amp`, whose default value is `true`. This option uses -bfloat16 (bf16) autocast for eligible CUDA operations. In typical DPA4/SeZM -workloads, bf16 AMP reduces memory usage and may improve throughput while -preserving fitted accuracy; no visible accuracy degradation is expected in -normal DPA4/SeZM training. Numerically sensitive geometric operations are kept -in promoted precision. 
+### LoRA fine-tuning -When the GPU provides native bf16 support, enabling `use_amp` is recommended: +DPA4/SeZM supports LoRA adapters on its SO(3) and SO(2) linear layers, intended +for single-task fine-tuning: ```json { "model": { - "descriptor": { - "use_amp": true - } - } -} -``` - -On GPUs without native bf16 support, explicitly set `use_amp` to `false` to -avoid runtime errors or additional conversion overhead: - -```json -{ - "model": { - "descriptor": { - "use_amp": false + "type": "dpa4", + "lora": { + "rank": 16, + "alpha": 16.0 } } } ``` -On NVIDIA hardware, native bf16 support starts with the Ampere generation, -including A100-series accelerators and RTX 30-series GPUs, and continues on -newer architectures. - -### Experimental `torch.compile` path +Fine-tune from a checkpoint: -DPA4/SeZM can train through an experimental `torch.compile` path: - -```json -{ - "model": { - "use_compile": true - } -} +```bash +dp --pt train lora_ft.json --finetune pretrained.pt ``` -This path is useful for force-loss training, where differentiating the force -loss requires higher-order derivatives through the conservative -energy-gradient path. DPA4/SeZM traces this path before passing it to -Inductor. - -This path is experimental and may expose PyTorch compiler issues. It currently -requires `torch==2.11`; other PyTorch versions are not supported for this -compiled DPA4/SeZM training path. On NVIDIA GPUs, CUDA must be >= 12.6. Apple -Silicon Macs are also supported. It has been tested with Python 3.13. If the -compiled path fails or produces unexpected behavior, please report the issue -with the PyTorch version, CUDA version, GPU model, and a minimal input file. - -### Inference environment variables +Best checkpoints fold the LoRA deltas back into the base weights, producing +plain DPA4/SeZM checkpoints suitable for deployment. See +`examples/water/dpa4/lora_ft.json`. -DPA4/SeZM reads inference-related environment variables when the PyTorch model -is constructed. 
If these variables are already exported in the shell, they -take precedence over values written in the input file. Changing them after -model construction does not affect that model instance. +## Zone bridging (ZBL) -`DP_COMPILE_INFER` controls whether evaluation and inference forwards use the -DPA4/SeZM compile path: +DPA4/SeZM can add an analytical short-range repulsion (typically ZBL) to the +learned energy in a protected region: -```bash -export DP_COMPILE_INFER=1 +```math +E_i = E_i^{\mathrm{DPA4/SeZM}} + E_i^{\mathrm{ZBL}}. ``` -Accepted true values are `1`, `true`, `yes`, and `on`; accepted false values -are `0`, `false`, `no`, and `off`. Enabling this path has the same PyTorch -version requirements as `model.use_compile`. - -During training validation, the same setting can be requested in the input -file: +Below `bridging_r_inner` the distance seen by the descriptor is clamped, with a +smooth transition back to the true distance up to `bridging_r_outer`; a source +gate additionally blocks the learned model from leaking information about the +frozen short-range pairs. Enable it with: ```json { - "validating": { - "compiled_infer": true + "model": { + "bridging_method": "zbl", + "bridging_r_inner": 0.5, + "bridging_r_outer": 0.8 } } ``` -The trainer translates this option into `DP_COMPILE_INFER=1` before model -construction, unless the shell environment already defines `DP_COMPILE_INFER`. +When ZBL bridging is enabled, set `training.training_data.min_pair_dist` to the +same value as `bridging_r_inner` so frames with shorter atom pairs are excluded +from training. See `examples/water/dpa4/input-zbl.json` for a complete example. + +## Performance and precision + +### Training-time settings + +Three options control training precision and the compiled path: + +- {ref}`use_amp ` — bf16 automatic mixed + precision on CUDA. Reduces memory and often improves throughput with no + expected accuracy loss. 
Recommended on GPUs with native bf16 (NVIDIA Ampere + and newer, e.g. A100 / RTX 30-series); set it off on GPUs without native bf16 + to avoid runtime errors and conversion overhead. +- {ref}`enable_tf32 ` — TF32 matmul precision for CUDA + **training** forwards (independent of `use_compile`, and separate from the + inference TF32 control below). +- {ref}`use_compile ` — experimental `torch.compile` + training path. Useful for force-loss training (higher-order derivatives + through the energy gradient) and can speed training markedly on supported + setups. + +### Inference and deployment settings + +Inference behavior is controlled by environment variables, each with an +equivalent input-file option used during training validation: + +| Environment variable | Input-file option | Default | Effect | +| -------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `DP_COMPILE_INFER` | `validating.compiled_infer` | off | Use the compile path for evaluation/inference. Same `torch==2.11` / CUDA ≥ 12.6 requirements as `model.use_compile`. | +| `DP_TF32_INFER` | `validating.tf32_infer` | `0` (highest) | float32 matmul precision for inference: `0` highest, `1` high, `2` medium. Higher values improve throughput but make the potential energy surface less smooth. | +| `DP_TRITON_INFER` | — | off | Fused block-diagonal Triton kernels for the SO(2) Wigner-D rotation (CUDA eval only). Lower latency and peak memory, numerically equivalent to the dense path with full float32 accumulation. Compatible with `DP_COMPILE_INFER`. | + +Accepted boolean values are `1`/`true`/`yes`/`on` and `0`/`false`/`no`/`off`. 
+Shell exports take precedence over the equivalent input-file options. The +variables are read when the model is constructed, so changing them afterward +has no effect. + +For molecular dynamics and other workflows sensitive to the smoothness of the +potential energy surface, keep `DP_TF32_INFER=0`. `DP_TRITON_INFER=1` retains +full float32 accumulation regardless of `DP_TF32_INFER` and is therefore safe +for those workflows. + +:::{important} +Set these variables **before** running `dp --pt freeze`. The exported `.pt2` is +an AOTInductor artifact, so the SO(2) rotation branch (`DP_TRITON_INFER`) and +the matmul precision (`DP_TF32_INFER`) are captured into the graph at export +time and are **not** re-evaluated when the `.pt2` is later loaded by ASE or +LAMMPS. A frozen `.pt2` runs a forward-only package, so training-time +memory-saving switches do not apply to it. +::: -`DP_TF32_INFER` controls the float32 matmul precision used by evaluation and -inference forwards on CUDA: +### Hardware selection -- `0`: use PyTorch `highest` precision. This is the default. -- `1`: use PyTorch `high` precision. -- `2`: use PyTorch `medium` precision. +DPA4/SeZM is designed for fp32 training and inference, so prefer GPUs with high +fp32 throughput and native bf16 support rather than strong fp64 performance. +Because bf16 AMP substantially reduces the activation memory footprint, very +large device memory is usually less important than fp32 FLOPS and bf16 support +once the target system and batch size fit. -During training validation, the input option -`validating.tf32_infer: true` is translated into `DP_TF32_INFER=1` before -model construction, again without overriding an explicitly exported -environment variable. Training forwards are controlled separately by -`model.enable_tf32`, independently of whether `model.use_compile` selects the -compiled or eager training path. 
+## Export and running in LAMMPS -For molecular dynamics and other workflows that are sensitive to potential -energy surface smoothness, keep `DP_TF32_INFER=0`. Enabling TF32 inference may -leave energy and force MAE nearly unchanged while making the potential energy -surface less smooth. For less smoothness-sensitive evaluation or screening -workloads, `DP_TF32_INFER=1` or `2` may be useful for improving throughput. +### Freeze to `.pt2` -`DP_TRITON_INFER` enables fused block-diagonal Triton kernels for the SO(2) -Wigner-D rotation. It applies to evaluation and inference on CUDA in eval mode -only and is disabled by default: +DPA4/SeZM checkpoints use the PyTorch `.pt2` (AOTInductor) export path; the +ordinary TorchScript freeze path is not used. Run the standard freeze command: ```bash -export DP_TRITON_INFER=1 +dp --pt freeze -c model.ckpt.pt -o frozen_model ``` -The kernels operate on the block-diagonal (by degree `l`) structure of the -Wigner-D matrix and are numerically equivalent to the default dense rotation up -to floating-point rounding. They retain full float32 accumulation regardless of -`DP_TF32_INFER` and are therefore appropriate for smoothness-sensitive -workflows. They are compatible with the compile path (`DP_COMPILE_INFER=1`) and -reduce both latency and peak memory. - -When exporting DPA4/SeZM to `.pt2`, set inference environment variables before -running `dp --pt freeze`. The exported package is an AOTInductor artifact, so -graph-level choices and compiler precision settings are fixed during export and -are not re-evaluated when the `.pt2` file is later loaded by ASE or LAMMPS. -In particular, `DP_TRITON_INFER` selects the SO(2) rotation branch that is -captured into the exported graph, and `DP_TF32_INFER` should be set before -export if TF32 inference is desired. 
`DP_ACT_INFER` is not a runtime control for -`.pt2` inference: activation checkpointing is a Python/autograd memory-saving -strategy, while `.pt2` inference runs a forward-only AOTI package whose force -and virial computations have already been lowered into the exported graph. - -### Hardware selection - -DPA4/SeZM is designed for fp32 training and inference. Hardware selection -should therefore be based primarily on fp32 throughput rather than fp64 -throughput. In contrast to workloads dominated by double-precision linear -algebra, DPA4/SeZM does not require GPUs with especially strong fp64 -performance. - -For practical training, prefer GPUs that combine high fp32 FLOPS with native -bf16 support. Native bf16 enables the recommended AMP path, lowering memory -usage and often improving throughput. Because AMP can substantially reduce the -activation memory footprint, DPA4/SeZM training usually does not require -unusually large-memory GPUs once the target system and batch size fit. In that -regime, native bf16 support and fp32 FLOPS are usually more important -selection criteria than maximum device memory. - -## LoRA fine-tuning +The PyTorch backend detects DPA4/SeZM and writes `frozen_model.pt2`. -DPA4/SeZM supports LoRA adapters on its SO(3) and SO(2) linear layers. This -mode is intended for single-task fine-tuning. A typical input block is: +### Single GPU -```json -{ - "model": { - "type": "dpa4", - "lora": { - "rank": 16, - "alpha": 16.0 - } - } -} -``` - -Then fine-tune from a checkpoint: +Use the frozen `.pt2` with the `deepmd` pair style. A small example is in +`examples/water/dpa4/lmp/`. -```bash -dp --pt train lora_ft.json --finetune pretrained.pt +```lammps +pair_style deepmd frozen_model.pt2 +pair_coeff * * O H ``` -See `examples/water/dpa4/lora_ft.json` for a complete example. +### Multi-GPU (MPI) inference -## Export +The exported `.pt2` runs across multiple GPUs in LAMMPS using MPI domain +decomposition. 
Multi-GPU support is built into the package by `dp --pt freeze`, +so no extra freeze options are needed and the same `.pt2` file serves both +single- and multi-GPU runs. -DPA4/SeZM checkpoints use the PyTorch `.pt2` export path. Run the standard -freeze command: +Launch LAMMPS with one MPI rank per GPU and make the target devices visible: ```bash -dp --pt freeze -c model.ckpt.pt -o frozen_model +CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -np 4 lmp -in in.lammps ``` -The PyTorch backend detects DPA4/SeZM and writes `frozen_model.pt2`. Use this -`.pt2` file with LAMMPS: +Each MPI rank uses at most one GPU, so `CUDA_VISIBLE_DEVICES` must list every +GPU the run may use. If only one device is visible, all ranks share it: results +stay correct, but the GPU work is serialized and that device's memory grows +with the rank count. + +DPA4/SeZM exchanges neighbor information across the domain boundary, so the +LAMMPS atom map must be enabled: ```lammps +atom_modify map yes pair_style deepmd frozen_model.pt2 pair_coeff * * O H ``` -The ordinary TorchScript freeze path is not used for DPA4/SeZM checkpoints. -A small LAMMPS example is in `examples/water/dpa4/lmp/`. +Two settings improve multi-GPU runs: + +- For fast GPU-to-GPU exchange, build the C++ interface against a + [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library; + otherwise the cross-rank exchange falls back to a slower CPU path. +- Use a non-zero neighbor skin, e.g. `neighbor 2.0 bin`, to keep per-step GPU + memory stable. A zero skin rebuilds the neighbor list every step and can + substantially increase memory use. + +Multi-GPU inference applies to the plain energy model. ZBL zone bridging and +spin models run on a single MPI rank. ## Embedding extraction A trained DPA4/SeZM model can export learned representations for downstream -analysis with `dp embed`. 
A single forward pass (no force or virial -computation) produces three embeddings per system: - -- `descriptor`: the per-atom local-environment representation, with shape - (nframes, natoms, dim_descriptor). -- `atomic_feature`: the per-atom activation after the last fitting hidden layer, - with shape (nframes, natoms, dim_hidden). -- `structural_feature`: a whole-structure summary obtained by summing - `atomic_feature` over atoms, with shape (nframes, dim_hidden). +analysis with `dp embed`. A single forward pass (no force or virial) produces +three embeddings per system: -A typical invocation operates on the PyTorch checkpoint (`.pt`): +- `descriptor`: per-atom local-environment representation, shape + `(nframes, natoms, dim_descriptor)`. +- `atomic_feature`: per-atom activation after the last fitting hidden layer, + shape `(nframes, natoms, dim_hidden)`. +- `structural_feature`: whole-structure summary obtained by summing + `atomic_feature` over atoms, shape `(nframes, dim_hidden)`. ```bash dp embed -m model.ckpt.pt -s /path/to/system -o embedding.hdf5 ``` The results are written to a single HDF5 file in which each system is a group -holding the three float32 datasets above. They can be read back with `h5py`: +holding the three float32 datasets above: ```python import h5py @@ -585,23 +391,92 @@ with h5py.File("embedding.hdf5", "r") as f: structural_feature = group["structural_feature"][:] ``` -This command is available for DPA4/SeZM energy models in the PyTorch backend and -honors both `DP_COMPILE_INFER` and `DP_TRITON_INFER`. It operates on the training checkpoint (`.pt`); the -frozen `.pt2` package is not supported. See +This command operates on the training checkpoint (`.pt`), not the frozen +`.pt2`, and honors both `DP_COMPILE_INFER` and `DP_TRITON_INFER`. See [model embeddings](../inference/embedding.md) for the full description. ## Data format -DPA4/SeZM uses the [standard DeePMD-kit data format](../data/system.md). 
Keep -the `type_map` order consistent across the dataset, input file, and any -downstream `pair_coeff` mapping. +DPA4/SeZM uses the [standard DeePMD-kit data format](../data/system.md) and +also supports the [mixed-type data format](../data/system.md#mixed-type), which +is convenient for datasets that mix many element combinations (and is the usual +choice for multi-task training). Keep the `type_map` order consistent across the +dataset, the input file, and any downstream `pair_coeff` mapping. + +## Architecture details + +Optional background on how the descriptor works, linking each part to the +options that control it. Skip it unless you are tuning those options. + +### Equivariant representation and the l = 0 read-out + +DPA4/SeZM stores intermediate features as SO(3)-equivariant coefficients: a +feature block of maximum degree `lmax` holds all degrees `l = 0, …, lmax`, each +with `2l + 1` angular components (controlled by `lmax` / `l_schedule`). + +For each frame the model first builds a local neighbor graph within `rcut`. +Each edge stores the displacement vector, smooth cutoff weights, radial basis +features, and the rotation between the global frame and an edge-aligned local +frame; these are built once and reused by all blocks. One interaction block +then (1) gathers source-atom features on each edge, (2) rotates them into the +edge-local frame, (3) applies an SO(2)-equivariant convolution on the retained +angular orders, (4) rotates the messages back, (5) aggregates them at +destination atoms with envelope or attention weights, and (6) updates atom +features with an equivariant feed-forward block. + +Working in the edge-local frame turns rotations around the edge axis into SO(2) +operations, so the cost scales linearly with `lmax` rather than cubically. The SO(2) +convolution retains orders `|m| ≤ mmax` (or `m_schedule`).
After the last block, +only the `l = 0` scalar channels are read out and passed to the fitting network: + +```math +\mathcal{D}_i = \mathrm{Scalar}\left(\mathbf{h}_i^{(L)}\right). +``` + +### Radial basis and smooth cutoff + +Every edge uses a radial basis (`basis_type`, with `n_radial` functions) +multiplied by a smooth envelope whose value and first three derivatives vanish +at `rcut`. This smoothness matters for MD because nonsmooth descriptor cutoffs +would be inherited by the force derivatives. The two `env_exp` exponents control +the radial-basis envelope and the message-passing edge weights respectively; +larger values keep an envelope closer to one for more of the cutoff range. + +### Attention and focus streams + +Messages are aggregated either by envelope-weighted scatter or by attention +(`n_atten_head > 0`). With attention, the cutoff envelope participates in the +softmax normalization, so edges near `rcut` are smoothly suppressed in both the +numerator and the denominator. The SO(2) convolution can also use multiple +`n_focus` streams that process the same edge geometry in parallel and are +combined by scalar weights, adding capacity while preserving equivariance. + +### Grid nonlinearities + +Several branches can use sphere-grid (S2) or SO(3) Wigner-D grid +nonlinearities. The main switches are `s2_activation` (S2-grid nonlinearity for +the SO(2) and/or FFN branch), `ffn_so3_grid` (SO(3) grid in the block-internal +FFN), `lebedev_quadrature` (Lebedev rules for enabled S2 branches), and +`grid_mlp` / `grid_branch` (point-wise polynomial MLP or scalar-routed branch +mixer per grid path). These trade expressiveness for cost; the final `l = 0` +output remains a scalar. 
+ +### Environment-seeded initial features + +With `use_env_seed` enabled, the initial node state is seeded from the local +environment: a DeePMD-style environment matrix produces FiLM-like scale and +shift values for the first scalar features, and the geometric initial embedding +is enabled when non-scalar degrees are present. With it disabled, the initial +state contains only atom-local scalar features, which keeps a one-block model +closed over the one-hop neighbor shell. ## Limitations -- DPA4/SeZM is currently implemented for the PyTorch backend. +- DPA4/SeZM is implemented for the PyTorch backend only. +- Export uses `.pt2` (AOTInductor); the TorchScript freeze path is not used. - Model compression is not supported. -- Export uses `.pt2`; the ordinary TorchScript freeze path is not used for - DPA4/SeZM checkpoints. +- Multi-GPU (MPI) LAMMPS inference is supported for the plain energy model; + ZBL zone bridging and spin models run on a single MPI rank. ## Citation diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index 669fa99e71..68a553e29c 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -418,6 +418,30 @@ class DeepPotPTExpt : public DeepPotBackend { const torch::Tensor& charge_spin, const std::vector& comm_tensors); + /** + * @brief Run the with-comm edge (SeZM) ``.pt2`` artifact with comm tensors. + * + * The edge schema indexes the extended node set, so ``edge_index`` and + * ``edge_scatter_index`` coincide. ``atype`` carries owned atoms (fitting, + * energy read-out) while ``extended_atype`` embeds ghost neighbours. + * + * @param[in] comm_tensors 8 comm tensors in canonical positional order: + * send_list, send_proc, recv_proc, send_num, recv_num, + * communicator, nlocal, nghost. 
+ */ + std::vector run_model_edges_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& extended_atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const torch::Tensor& fparam, + const torch::Tensor& aparam, + const torch::Tensor& charge_spin, + const std::vector& comm_tensors); + /** * @brief Extract outputs from flat tensor list using output_keys. */ diff --git a/source/api_cc/include/commonPT.h b/source/api_cc/include/commonPT.h index 637ade70b4..643e53974a 100644 --- a/source/api_cc/include/commonPT.h +++ b/source/api_cc/include/commonPT.h @@ -173,6 +173,11 @@ struct EdgeTensorPack { * topology and compacts it on-device every step, so it passes ``false``. * The returned edge_vec and edge_mask are left undefined in that case. * @param row_centers Optional center atom index for each neighbor-list row. + * @param fold_to_local Whether edge_index folds ghost neighbours onto their + * local owners via ``mapping`` (single-domain message passing). When false, + * edge_index indexes the extended atoms directly and coincides with + * edge_index_ext; this is the multi-rank with-comm convention where ghost + * node features are exchanged across ranks rather than gathered locally. 
*/ template inline EdgeTensorPack createEdgeTensors( @@ -183,7 +188,8 @@ inline EdgeTensorPack createEdgeTensors( const int nall, const torch::Device& device, const bool with_geometry = true, - const std::vector* row_centers = nullptr) { + const std::vector* row_centers = nullptr, + const bool fold_to_local = true) { std::vector src; std::vector dst; std::vector src_ext; @@ -217,9 +223,18 @@ inline EdgeTensorPack createEdgeTensors( if (jj < 0 || jj >= nall) { continue; } - const std::int64_t src_local = mapping[static_cast(jj)]; - if (src_local < 0 || src_local >= nloc) { - continue; + // edge_index source: the local owner (folded) for single-domain message + // passing, or the extended atom itself for the multi-rank with-comm + // convention where ghost features are exchanged across ranks. + std::int64_t src_node; + if (fold_to_local) { + const std::int64_t src_local = mapping[static_cast(jj)]; + if (src_local < 0 || src_local >= nloc) { + continue; + } + src_node = src_local; + } else { + src_node = jj; } const size_t neighbor_offset = static_cast(jj) * 3; const VALUETYPE dx = coord[neighbor_offset] - coord[center_offset]; @@ -231,7 +246,7 @@ inline EdgeTensorPack createEdgeTensors( if (rr <= static_cast(1e-10)) { continue; } - src.push_back(src_local); + src.push_back(src_node); dst.push_back(center); src_ext.push_back(jj); dst_ext.push_back(center); diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index 673468a172..96033fcab4 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -330,6 +330,52 @@ std::vector DeepPotPTExpt::run_model_with_comm( return with_comm_loader->run(inputs); } +std::vector DeepPotPTExpt::run_model_edges_with_comm( + const torch::Tensor& coord, + const torch::Tensor& atype, + const torch::Tensor& extended_atype, + const torch::Tensor& edge_index, + const torch::Tensor& edge_vec, + const torch::Tensor& edge_scatter_index, + const torch::Tensor& edge_mask, + const 
torch::Tensor& fparam, + const torch::Tensor& aparam, + const torch::Tensor& charge_spin, + const std::vector& comm_tensors) { + if (!with_comm_loader) { + throw deepmd::deepmd_exception( + "run_model_edges_with_comm called but the with-comm artifact is not " + "available. Either the .pt2 file has no with-comm artifact compiled " + "(programming error: the caller should check has_comm_artifact_ " + "before invoking this path), or the artifact was present in the " + ".pt2 metadata but failed to load at init time (see earlier stderr " + "log). Multi-rank LAMMPS requires a working with-comm artifact."); + } + if (comm_tensors.size() != 8) { + throw deepmd::deepmd_exception( + "run_model_edges_with_comm: comm_tensors must contain exactly 8 " + "tensors (send_list, send_proc, recv_proc, send_num, recv_num, " + "communicator, nlocal, nghost). Got " + + std::to_string(comm_tensors.size()) + "."); + } + std::vector inputs = {coord, atype, extended_atype, + edge_index, edge_vec, edge_scatter_index, + edge_mask}; + if (dfparam > 0) { + inputs.push_back(fparam); + } + if (daparam > 0) { + inputs.push_back(aparam); + } + if (dchgspin > 0) { + inputs.push_back(charge_spin); + } + for (const auto& t : comm_tensors) { + inputs.push_back(t); + } + return with_comm_loader->run(inputs); +} + void DeepPotPTExpt::extract_outputs( std::map& output_map, const std::vector& flat_outputs) { @@ -499,9 +545,15 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, // model-input dummy edges are rebuilt on-device from current coordinates // every step, so rcut-crossing skin atoms are handled without carrying // out-of-cutoff edges into the exported graph. 
- const auto edge_tensors = - createEdgeTensors(nlist_data.jlist, dcoord, mapping, nloc, nall_real, - device, /*with_geometry=*/false, &nlist_data.ilist); + // + // Multi-rank inference indexes the extended node set directly + // (``fold_to_local=false``): ghost neighbours stay as distinct nodes so + // their features can be exchanged across ranks via border_op, instead of + // being folded onto a local owner that this rank does not own. + const auto edge_tensors = createEdgeTensors( + nlist_data.jlist, dcoord, mapping, nloc, nall_real, device, + /*with_geometry=*/false, /*row_centers=*/&nlist_data.ilist, + /*fold_to_local=*/!use_with_comm); edge_index_tensor = edge_tensors.edge_index; edge_index_ext_tensor = edge_tensors.edge_index_ext; } else { @@ -601,16 +653,57 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, std::vector> remapped_sendlist; std::vector remapped_sendlist_ptrs; std::vector remapped_sendnum, remapped_recvnum; + // Empty-subdomain phantom padding (edge with-comm path only): a rank that + // owns zero local atoms would feed nloc==0 / nedge==0 into the with-comm + // artifact, which is traced with nloc_min=1 / nedge_min=2 and may be lowered + // by inductor under an even stricter nloc>=2 assumption -- so a 0-atom rank + // can SIGFPE or silently corrupt instead of throwing. Mirror the spin + // phantom-atom workaround (DeepSpinPTExpt): run the graph with two phantom + // local atoms and two masked self-edges, then report the empty rank's true + // (zero) contribution. Setting the comm ``nlocal`` to ``phantom_n`` makes + // border_op write received ghost features past the phantom slots, so + // collective communication stays in lockstep with non-empty ranks. + // Gated on ``use_with_comm`` so the strip-back below never touches outputs + // that the regular (non-comm) path produced without phantom padding. + int phantom_n = (use_with_comm && lower_input_is_edge_ && nloc == 0) ? 
2 : 0; if (use_with_comm) { - if (lower_input_is_edge_) { - throw deepmd::deepmd_exception( - "SeZM edge-schema .pt2 inference requires the regular single-rank " - "AOTInductor artifact. Multi-rank inference must use an artifact " - "whose lower input schema includes explicit communication tensors."); - } bool has_null_atoms = (nall_real < nall); std::vector comm_tensors; - if (has_null_atoms) { + if (phantom_n > 0) { + // Empty rank: the phantom prefix shifts every node index by ``phantom_n`` + // (received ghost features land at [phantom_n, nall)), so the forwarded + // send indices shift with it. Build the send-list in the real-atom node + // space -- fwd_map remap when NULL-type atoms were filtered, raw LAMMPS + // indices otherwise -- then offset every entry by ``phantom_n``. Without + // the offset border_op forwards the zeroed phantom slots instead of the + // relayed ghost features, corrupting neighbour ranks under subdomains + // small enough for an empty rank to relay (sendnum > 0). 
+ if (has_null_atoms) { + deepmd::remap_comm_sendlist(remapped_sendlist, remapped_sendnum, + remapped_recvnum, lmp_list, fwd_map); + } else { + remapped_sendlist.resize(lmp_list.nswap); + remapped_sendnum.assign(lmp_list.sendnum, + lmp_list.sendnum + lmp_list.nswap); + remapped_recvnum.assign(lmp_list.recvnum, + lmp_list.recvnum + lmp_list.nswap); + for (int iswap = 0; iswap < lmp_list.nswap; ++iswap) { + remapped_sendlist[iswap].assign( + lmp_list.sendlist[iswap], + lmp_list.sendlist[iswap] + lmp_list.sendnum[iswap]); + } + } + remapped_sendlist_ptrs.resize(lmp_list.nswap); + for (int iswap = 0; iswap < lmp_list.nswap; ++iswap) { + for (int& idx : remapped_sendlist[iswap]) { + idx += phantom_n; + } + remapped_sendlist_ptrs[iswap] = remapped_sendlist[iswap].data(); + } + comm_tensors = deepmd::ptexpt::build_comm_tensors_positional( + lmp_list, remapped_sendlist_ptrs.data(), remapped_sendnum.data(), + remapped_recvnum.data(), phantom_n, nghost_real); + } else if (has_null_atoms) { comm_tensors = deepmd::ptexpt::build_comm_tensors_positional_with_virtual_atoms( lmp_list, fwd_map, nloc, nghost_real, remapped_sendlist, @@ -620,9 +713,54 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, lmp_list, lmp_list.sendlist, lmp_list.sendnum, lmp_list.recvnum, nloc, nghost_real); } - flat_outputs = run_model_with_comm( - coord_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, - fparam_tensor, aparam_tensor, charge_spin_tensor, comm_tensors); + if (lower_input_is_edge_) { + if (phantom_n > 0) { + // Prepend ``phantom_n`` type-0 local atoms to the extended node set and + // supply two masked self-edges (edge_mask=false), so the graph runs at + // nloc>=2 / nedge>=2 with zero physical contribution. Real ghost + // features still arrive via border_op at slots [phantom_n, nall); the + // phantom prefix is stripped from the outputs below. 
+ const auto bool_option = + torch::TensorOptions().device(torch::kCPU).dtype(torch::kBool); + at::Tensor ph_coord = torch::cat( + {torch::zeros({1, phantom_n, 3}, options).to(device), coord_Tensor}, + 1); + at::Tensor ph_ext_atype = torch::cat( + {torch::zeros({1, phantom_n}, int_option).to(device), atype_Tensor}, + 1); + at::Tensor ph_loc_atype = + torch::zeros({1, phantom_n}, int_option).to(device); + at::Tensor ph_edge_index = torch::zeros({2, 2}, int_option).to(device); + at::Tensor ph_edge_vec = torch::zeros({2, 3}, options).to(device); + at::Tensor ph_edge_mask = torch::zeros({2}, bool_option).to(device); + at::Tensor ph_aparam = + (daparam > 0) + ? torch::zeros({1, phantom_n, daparam}, options).to(device) + : aparam_tensor; + flat_outputs = run_model_edges_with_comm( + ph_coord, ph_loc_atype, ph_ext_atype, ph_edge_index, ph_edge_vec, + ph_edge_index, ph_edge_mask, fparam_tensor, ph_aparam, + charge_spin_tensor, comm_tensors); + } else { + // SeZM edge schema: edge_index already indexes the extended node set + // (fold_to_local=false above), so it doubles as the force-scatter + // index. Ghost node features are exchanged between blocks inside the + // with-comm graph via border_op; the local atom types feed fitting + // while the extended types embed ghost neighbours. 
+ const auto edge_tensors = + compactEdgeTensors(edge_index_tensor, edge_index_ext_tensor, + coord_Tensor, static_cast(rcut)); + flat_outputs = run_model_edges_with_comm( + coord_Tensor, atype_Tensor.slice(1, 0, nloc), atype_Tensor, + edge_tensors.edge_index, edge_tensors.edge_vec, + edge_tensors.edge_index_ext, edge_tensors.edge_mask, fparam_tensor, + aparam_tensor, charge_spin_tensor, comm_tensors); + } + } else { + flat_outputs = run_model_with_comm( + coord_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, + fparam_tensor, aparam_tensor, charge_spin_tensor, comm_tensors); + } } else { if (lower_input_is_edge_) { const auto edge_tensors = @@ -644,6 +782,25 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, std::map output_map; extract_outputs(output_map, flat_outputs); + if (phantom_n > 0) { + // Strip the phantom local prefix and zero the empty rank's energy. The + // phantom atoms carry no edges, so their force / per-atom virial are + // already zero and the global virial reduces to zero; only the per-type + // bias in energy_redu must be cleared so it does not enter the MPI-reduced + // total. After stripping, the extended force / atom virial regain their + // original (nf, nall_real, ...) extent and the local atom energy becomes + // empty, matching the real rank's nloc == 0. 
+ output_map["energy_redu"] = torch::zeros_like(output_map["energy_redu"]); + output_map["energy_derv_r"] = + output_map["energy_derv_r"].slice(1, phantom_n).contiguous(); + if (atomic) { + output_map["energy"] = + output_map["energy"].slice(1, phantom_n).contiguous(); + output_map["energy_derv_c"] = + output_map["energy_derv_c"].slice(1, phantom_n).contiguous(); + } + } + // Extract energy: energy_redu (nf, 1) torch::Tensor flat_energy_ = output_map["energy_redu"].view({-1}).to(torch::kCPU); diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index b2adbd63d1..516ccce964 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -593,23 +593,36 @@ inline std::vector build_comm_tensors_positional( auto int64_option = torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64); + // The with-comm AOTInductor artifact is compiled assuming 16-byte-aligned + // inputs (the freeze-time sample comm tensors are torch-allocated). LAMMPS' + // raw send/recv arrays and the MPI handle carry only their natural element + // alignment, so wrapping them with ``from_blob`` would force AOTInductor to + // copy each input to an aligned buffer on every step (a per-step warning and + // copy). ``clone`` materialises them in torch-allocated aligned storage; the + // pointer values inside ``sendlist`` are copied verbatim and still address + // the live LAMMPS swap buffers. The clones are tiny (``nswap`` elements), so + // the per-call copy is negligible.
at::Tensor sendlist_tensor = - torch::from_blob(static_cast(sendlist), {nswap}, int64_option); + torch::from_blob(static_cast(sendlist), {nswap}, int64_option) + .clone(); at::Tensor sendproc_tensor = - torch::from_blob(lmp_list.sendproc, {nswap}, int32_option); + torch::from_blob(lmp_list.sendproc, {nswap}, int32_option).clone(); at::Tensor recvproc_tensor = - torch::from_blob(lmp_list.recvproc, {nswap}, int32_option); - at::Tensor sendnum_tensor = torch::from_blob(sendnum, {nswap}, int32_option); - at::Tensor recvnum_tensor = torch::from_blob(recvnum, {nswap}, int32_option); + torch::from_blob(lmp_list.recvproc, {nswap}, int32_option).clone(); + at::Tensor sendnum_tensor = + torch::from_blob(sendnum, {nswap}, int32_option).clone(); + at::Tensor recvnum_tensor = + torch::from_blob(recvnum, {nswap}, int32_option).clone(); - static std::int64_t null_communicator = 0; + std::int64_t null_communicator = 0; at::Tensor communicator_tensor; if (lmp_list.world == nullptr) { communicator_tensor = - torch::from_blob(&null_communicator, {1}, int64_option); + torch::from_blob(&null_communicator, {1}, int64_option).clone(); } else { communicator_tensor = - torch::from_blob(const_cast(lmp_list.world), {1}, int64_option); + torch::from_blob(const_cast(lmp_list.world), {1}, int64_option) + .clone(); } at::Tensor nlocal_tensor = torch::tensor(nlocal, int32_option); diff --git a/source/tests/pt/model/test_sezm_export.py b/source/tests/pt/model/test_sezm_export.py index a233ef8406..2082b85293 100644 --- a/source/tests/pt/model/test_sezm_export.py +++ b/source/tests/pt/model/test_sezm_export.py @@ -33,12 +33,65 @@ from deepmd.pt.entrypoints.freeze_pt2 import ( _build_dynamic_shapes, + _build_with_comm_dynamic_shapes, _collect_metadata, + _make_comm_sample_inputs, _make_sample_inputs, _resolve_nframes, freeze_sezm_to_pt2, is_sezm_checkpoint, ) +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) + +_COMM_KEYS = ( + "send_list", + "send_proc", + "recv_proc", + 
"send_num", + "recv_num", + "communicator", + "nlocal", + "nghost", +) + + +def _eager_parallel_forward( + model: torch.nn.Module, + comm_sample: tuple, +) -> dict[str, torch.Tensor]: + """Reference parallel lower forward driven by the with-comm sample tensors.""" + ( + coord, + atype, + extended_atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fparam, + aparam, + charge_spin, + ) = comm_sample[:10] + comm_dict = dict(zip(_COMM_KEYS, comm_sample[10:18], strict=True)) + eager_coord = coord.detach().clone().requires_grad_(True) + return model.forward_common_lower( + eager_coord, + atype, + edge_index, + edge_vec.detach(), + edge_scatter_index, + edge_mask, + fparam=fparam, + aparam=aparam, + comm_dict=comm_dict, + extended_atype=extended_atype, + charge_spin=charge_spin, + use_compile=False, + ) + + from deepmd.pt.model.model import ( get_model, ) @@ -305,11 +358,15 @@ def _assert_dict_allclose( *, context: str, ) -> None: - test_pairs = ( - list(test_dict.items()) - if hasattr(test_dict, "items") - else list(zip(ref.keys(), test_dict, strict=True)) - ) + if hasattr(test_dict, "items"): + self.assertEqual( + set(test_dict.keys()), + set(ref.keys()), + msg=f"{context}: exported output keys do not match the reference", + ) + test_pairs = list(test_dict.items()) + else: + test_pairs = list(zip(ref.keys(), test_dict, strict=True)) for key, test_val in test_pairs: self.assertIn(key, ref, msg=f"{context}: unexpected output key {key!r}") ref_val = ref[key] @@ -402,6 +459,102 @@ def test_loaded_pte_matches_dense(self) -> None: ) +@unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) +class TestSeZMWithCommExportPipeline(_ClearDefaultDeviceTestCase): + """Trace / export / ``.pte`` parity for the parallel with-comm lower graph. + + The with-comm artifact threads the eight ``border_op`` communication tensors + so cross-rank ghost exchange is captured as opaque external calls. 
A + single-process self-send plan reduces the exchange to an owner->ghost copy, + so the exported program must reproduce the eager parallel forward to fp64 + round-off, on both the trace shape and a different owned-atom count. + """ + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + try: + ensure_comm_registered() + cls.model = _build_tiny_sezm_model() + cls.sample_inputs = _make_comm_sample_inputs(cls.model, nloc=7, device=_CPU) + traced = cls.model.forward_common_lower_exportable_with_comm( + *cls.sample_inputs + ) + exported = torch.export.export( + traced, + cls.sample_inputs, + dynamic_shapes=_build_with_comm_dynamic_shapes(cls.sample_inputs), + strict=False, + prefer_deferred_runtime_asserts_over_guards=True, + ) + cls.traced = traced + cls._pte_tmp = tempfile.NamedTemporaryFile(suffix=".pte", delete=True) + torch.export.save(exported, cls._pte_tmp.name) + cls.loaded = torch.export.load(cls._pte_tmp.name).module() + except Exception: + super().tearDownClass() + raise + + @classmethod + def tearDownClass(cls) -> None: + try: + for attr in ("loaded", "traced", "model", "sample_inputs"): + if hasattr(cls, attr): + delattr(cls, attr) + if hasattr(cls, "_pte_tmp"): + cls._pte_tmp.close() + delattr(cls, "_pte_tmp") + finally: + super().tearDownClass() + + def _assert_dict_allclose( + self, + ref: dict[str, torch.Tensor], + test_dict: dict[str, torch.Tensor] | object, + *, + context: str, + ) -> None: + if hasattr(test_dict, "items"): + self.assertEqual( + set(test_dict.keys()), + set(ref.keys()), + msg=f"{context}: exported output keys do not match the reference", + ) + test_pairs = list(test_dict.items()) + else: + test_pairs = list(zip(ref.keys(), test_dict, strict=True)) + for key, test_val in test_pairs: + np.testing.assert_allclose( + ref[key].detach().cpu().numpy(), + test_val.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"{context}: {key}", + ) + + def test_traced_matches_eager(self) -> None: + eager = 
_eager_parallel_forward(self.model, self.sample_inputs) + traced_out = self.traced(*self.sample_inputs) + self._assert_dict_allclose( + eager, traced_out, context="with-comm traced vs eager" + ) + + def test_loaded_pte_matches_eager(self) -> None: + eager = _eager_parallel_forward(self.model, self.sample_inputs) + loaded_out = self.loaded(*self.sample_inputs) + self._assert_dict_allclose(eager, loaded_out, context="with-comm .pte vs eager") + + def test_loaded_pte_matches_eager_different_nloc(self) -> None: + # nloc=11 retargets the owned-atom symbol away from the trace value (7); + # nall/nedge follow from the geometry and nswap stays 1. + infer_inputs = _make_comm_sample_inputs(self.model, nloc=11, device=_CPU) + eager = _eager_parallel_forward(self.model, infer_inputs) + loaded_out = self.loaded(*infer_inputs) + self._assert_dict_allclose( + eager, loaded_out, context="with-comm .pte vs eager (infer shape)" + ) + + class _FrozenPt2Fixture(_ClearDefaultDeviceTestCase): """Shared setUp/tearDown: freeze a tiny SeZM checkpoint to ``.pt2`` once. @@ -837,6 +990,7 @@ def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): freeze_sezm_to_pt2(str(ckpt_path), str(out), device=_CPU) with zipfile.ZipFile(str(out), "r") as zf: + names = zf.namelist() metadata = json.loads( zf.read("model/extra/metadata.json").decode("utf-8") ) @@ -851,6 +1005,40 @@ def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): self.assertEqual(metadata["ntypes_spin"], 1) self.assertIn("energy_derv_r_mag", metadata["output_keys"]) self.assertIn("energy_derv_c_redu", metadata["output_keys"]) + # Spin uses the nlist lower interface; the edge-based with-comm artifact + # does not apply, so multi-rank inference fails fast in C++. 
+ self.assertFalse(metadata["has_comm_artifact"]) + self.assertNotIn("model/extra/forward_lower_with_comm.pt2", names) + + @unittest.skipIf(_SKIP_OFF_COMPILE_TORCH, _SKIP_OFF_COMPILE_TORCH_REASON) + def test_freeze_embeds_with_comm_artifact(self) -> None: + """A plain SeZM checkpoint ships the nested multi-rank with-comm artifact.""" + + def fake_compile(_exported: torch.export.ExportedProgram, package_path: str): + with zipfile.ZipFile(package_path, "w") as zf: + zf.writestr("model/data.pkl", b"") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + params = _tiny_sezm_model_params() + ckpt_path = _write_tiny_sezm_checkpoint(tmp_path, params) + out = tmp_path / "with_comm.pt2" + + with mock.patch( + "torch._inductor.aoti_compile_and_package", + side_effect=fake_compile, + ): + freeze_sezm_to_pt2(str(ckpt_path), str(out), device=_CPU) + + with zipfile.ZipFile(str(out), "r") as zf: + names = zf.namelist() + metadata = json.loads( + zf.read("model/extra/metadata.json").decode("utf-8") + ) + + self.assertTrue(metadata["has_comm_artifact"]) + self.assertTrue(metadata["has_message_passing"]) + self.assertIn("model/extra/forward_lower_with_comm.pt2", names) if __name__ == "__main__": diff --git a/source/tests/pt/model/test_sezm_parallel.py b/source/tests/pt/model/test_sezm_parallel.py new file mode 100644 index 0000000000..4b30b1a5e5 --- /dev/null +++ b/source/tests/pt/model/test_sezm_parallel.py @@ -0,0 +1,446 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity tests for SeZM LAMMPS multi-rank (edge-based) inference. + +The parallel path expands the descriptor node set to the extended region and +refreshes ghost-node features through ``deepmd_export::border_op`` between +interaction blocks. 
A single process can emulate one MPI rank by driving +``border_op`` with a self-send swap whose send-list maps each ghost slot to its +local owner; the exchange then reduces to an exact owner->ghost copy, so the +parallel path must reproduce the single-domain (folded) path bit-for-bit on the +owned atoms. These tests pin that equivalence end-to-end (descriptor and model) +and guard the export-capability predicate used to gate the with-comm artifact. +""" + +from __future__ import ( + annotations, +) + +import ctypes +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.pt.model.descriptor.sezm_nn import block as sezm_block +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt_expt.utils.comm import ( + ensure_comm_registered, +) + +# Self-send send-lists embed the address of a numpy array; the arrays must +# outlive the eager ``border_op`` calls that dereference them. 
_SENDLIST_KEEPALIVE: list[np.ndarray] = []


def _tiny_parallel_model_params(**overrides) -> dict:
    """Return the small float64 SeZM model configuration used by these tests.

    Parameters
    ----------
    **overrides
        Top-level configuration overrides. The optional ``descriptor`` entry
        is merged into the default descriptor section; every other entry
        replaces (or adds) the top-level key of the same name.

    Returns
    -------
    dict
        A freshly built configuration dictionary.
    """
    descriptor_overrides = overrides.pop("descriptor", {})
    descriptor = dict(
        type="SeZM",
        sel=[8, 8],
        rcut=4.0,
        channels=8,
        n_focus=1,
        n_radial=4,
        radial_mlp=[8],
        use_env_seed=True,
        l_schedule=[1, 1, 0],
        mmax=1,
        so2_layers=2,
        n_atten_head=1,
        ffn_neurons=16,
        ffn_blocks=1,
        mlp_bias=False,
        use_amp=False,
        random_gamma=False,
        precision="float64",
        seed=11,
    )
    descriptor.update(descriptor_overrides)

    config = dict(
        type="SeZM",
        type_map=["A", "B"],
        descriptor=descriptor,
        fitting_net=dict(
            neuron=[16],
            activation_function="silu",
            precision="float64",
            seed=11,
        ),
        use_compile=False,
    )
    # Whatever is left in ``overrides`` targets the top level of the config.
    config.update(overrides)
    return config


def _build_model(device: torch.device, **overrides) -> torch.nn.Module:
    """Instantiate the tiny SeZM model, switch it to eval mode, move it to ``device``.

    ``overrides`` are forwarded unchanged to :func:`_tiny_parallel_model_params`.
    """
    config = _tiny_parallel_model_params(**overrides)
    net = get_model(config)
    net.eval()
    net.to(device)
    return net


def _build_extended_system(
    model: torch.nn.Module,
    device: torch.device,
    *,
    nloc: int = 6,
    seed: int = 3,
) -> dict[str, torch.Tensor]:
    """Create one random periodic frame, its ghost extension and its edge schema.

    Parameters
    ----------
    model
        Model queried for ``rcut``, ``sel``, the type map, ``mixed_types`` and
        ``format_nlist``.
    device
        Device on which the returned tensors are allocated.
    nloc
        Number of owned (local) atoms in the frame.
    seed
        Seed of the NumPy generator that draws the coordinates.

    Returns
    -------
    dict
        The folded edge schema fields (``coord``, ``atype``, ``edge_index``,
        ``edge_vec``, ``edge_scatter_index``, ``edge_mask``), the extended atom
        types, the extended-to-local ``mapping`` tensor, and the plain integers
        ``nloc`` / ``nall``.
    """
    cutoff = float(model.get_rcut())
    sel = list(model.get_sel())
    n_types = len(model.get_type_map())
    # Cubic cell with an edge of 2.5 cutoffs.
    edge_len = cutoff * 2.5
    cell = np.eye(3, dtype=np.float64) * edge_len

    # A single draw from the seeded generator fixes the whole geometry.
    rng = np.random.default_rng(seed)
    local_coord = rng.random((1, nloc, 3), dtype=np.float64) * edge_len
    # Atom types cycle through the available types.
    local_type = (np.arange(nloc, dtype=np.int32) % n_types).reshape(1, nloc)

    frame_cell = np.tile(cell.reshape(1, 3, 3), (1, 1, 1))
    wrapped_coord = normalize_coord(local_coord, frame_cell)
    ghost_coord, ghost_type, ghost_map = extend_coord_with_ghosts(
        wrapped_coord, local_type, cell.reshape(1, 9), cutoff
    )
    neighbors = build_neighbor_list(
        ghost_coord,
        ghost_type,
        nloc,
        cutoff,
        sel,
        distinguish_types=not model.mixed_types(),
    )
    ghost_coord = ghost_coord.reshape(1, -1, 3)

    def _to_tensor(array: np.ndarray, dtype: torch.dtype) -> torch.Tensor:
        """Copy a NumPy array onto ``device`` with the requested dtype."""
        return torch.tensor(array, dtype=dtype, device=device)

    ext_coord = _to_tensor(ghost_coord, torch.float64)
    ext_atype = _to_tensor(ghost_type, torch.int64)
    nlist_t = _to_tensor(neighbors, torch.int64)
    mapping_t = _to_tensor(ghost_map, torch.int64)

    formatted_nlist = model.format_nlist(ext_coord, ext_atype, nlist_t)
    from deepmd.pt_expt.utils.edge_schema import (
        edge_schema_from_extended,
    )

    local_atype = ext_atype[:, :nloc]
    schema = edge_schema_from_extended(
        ext_coord, local_atype, formatted_nlist, mapping_t
    )

    system = {
        "coord": schema.coord,
        "atype": schema.atype,
        "extended_atype": ext_atype,
        "edge_index": schema.edge_index,
        "edge_vec": schema.edge_vec,
        "edge_scatter_index": schema.edge_scatter_index,
        "edge_mask": schema.edge_mask,
        "mapping": mapping_t,
        "nloc": nloc,
        "nall": ext_coord.shape[1],
    }
    return system


def _self_comm_dict(
    mapping: torch.Tensor,
    nloc: int,
    nall: int,
) -> dict[str, torch.Tensor]:
    """Assemble a one-swap, self-addressed communication plan for ``border_op``.

    The send-list holds, for every ghost slot ``k``, the local owner index
    ``mapping[0, nloc + k]``, so a self-send copies each owner row into its
    ghost slot. All control tensors are created on the CPU. The send-list
    tensor stores the raw address of a NumPy ``int32`` array; that array is
    appended to ``_SENDLIST_KEEPALIVE`` so the address stays valid.

    Parameters
    ----------
    mapping
        Extended-to-local index tensor; only frame ``0`` is read.
    nloc
        Number of owned atoms.
    nall
        Number of extended atoms (owned plus ghosts).

    Returns
    -------
    dict
        The ``send_list``, ``send_proc``, ``recv_proc``, ``send_num``,
        ``recv_num``, ``communicator``, ``nlocal`` and ``nghost`` tensors.
    """
    host = torch.device("cpu")
    ghost_count = nall - nloc
    # At least one entry is always sent, even when there are no ghosts.
    n_send = max(1, ghost_count)

    owners = mapping[0, nloc:nall].to(dtype=torch.int32).cpu().numpy()
    send_indices = np.ascontiguousarray(np.resize(owners, n_send).astype(np.int32))
    # Pin the buffer: only its address travels inside the returned tensors.
    _SENDLIST_KEEPALIVE.append(send_indices)
    pointer = send_indices.ctypes.data_as(ctypes.c_void_p).value

    plan = {}
    plan["send_list"] = torch.tensor([pointer], dtype=torch.int64, device=host)
    plan["send_proc"] = torch.zeros(1, dtype=torch.int32, device=host)
    plan["recv_proc"] = torch.zeros(1, dtype=torch.int32, device=host)
    plan["send_num"] = torch.tensor([n_send], dtype=torch.int32, device=host)
    plan["recv_num"] = torch.tensor([n_send], dtype=torch.int32, device=host)
    plan["communicator"] = torch.zeros(1, dtype=torch.int64, device=host)
    plan["nlocal"] = torch.tensor(nloc, dtype=torch.int32, device=host)
    plan["nghost"] = torch.tensor(ghost_count, dtype=torch.int32, device=host)
    return plan


def _perturb_descriptor(descriptor: torch.nn.Module, *, seed: int = 0) -> None:
    """Add seeded Gaussian noise (scaled by 0.5) to every descriptor parameter.

    SeZM initializes its interaction blocks close to identity, so a freshly
    built model has near-zero message-passing contributions. That would hide
    ghost-exchange bugs whose error scales with the convolution output.
    Shifting every parameter mimics a trained model and makes the parity tests
    sensitive to such bugs.

    The noise is always drawn on the CPU from one generator, then moved to the
    parameter's device, and is applied in place without tracking gradients.
    """
    cpu_rng = torch.Generator(device="cpu").manual_seed(seed)
    with torch.no_grad():
        for weight in descriptor.parameters():
            delta = torch.randn(
                weight.shape,
                generator=cpu_rng,
                dtype=weight.dtype,
                device="cpu",
            )
            delta = delta.to(weight.device)
            weight.add_(delta * 0.5)


class TestSeZMSelfCommParity(unittest.TestCase):
    """Model-level parity: self-send ``border_op`` versus the folded path.

    All cases run on perturbed weights. The attention-residual variants in
    particular feed their depth history into the SO(2) convolution, so only a
    model with non-zero message passing can expose a stale-ghost regression.
    """

    @classmethod
    def setUpClass(cls) -> None:
        ensure_comm_registered()

    def _assert_parity(
        self,
        device: torch.device,
        rtol: float,
        atol: float,
        *,
        descriptor_overrides: dict | None = None,
    ) -> None:
        """Compare the parallel (with ``comm_dict``) and folded forward passes.

        Builds a perturbed model on ``device``, evaluates ``forward_lower``
        once per path and requires energy, forces and virials to agree within
        ``rtol`` / ``atol``. ``descriptor_overrides`` is merged into the
        default descriptor configuration.
        """
        overrides = descriptor_overrides or {}
        model = _build_model(device, descriptor=overrides)
        # An untrained SeZM model yields identically zero forces regardless of
        # geometry, so any ghost-exchange implementation would pass the check
        # below. Perturbing the descriptor makes the forces depend on the
        # ghost features and the comparison meaningful.
        _perturb_descriptor(model.atomic_model.descriptor)
        frame = _build_extended_system(model, device)
        comm = _self_comm_dict(frame["mapping"], frame["nloc"], frame["nall"])

        # Positional arguments shared by both calls, after ``edge_index``.
        edge_tail = (
            frame["edge_vec"],
            frame["edge_scatter_index"],
            frame["edge_mask"],
        )
        ref = model.forward_lower(
            frame["coord"],
            frame["atype"],
            frame["edge_index"],
            *edge_tail,
            do_atomic_virial=True,
        )
        par = model.forward_lower(
            frame["coord"],
            frame["atype"],
            # The parallel path addresses the extended node set directly, so
            # the extended scatter index also serves as the edge_index.
            frame["edge_scatter_index"],
            *edge_tail,
            do_atomic_virial=True,
            comm_dict=comm,
            extended_atype=frame["extended_atype"],
        )

        # Guard against a vacuous pass: the reference forces must be far
        # larger than the absolute tolerance used for the comparison.
        peak_force = ref["extended_force"].abs().max().item()
        self.assertGreater(
            peak_force,
            atol * 1e3,
            msg="reference forces are ~0; the parity check would be vacuous",
        )
        compared_keys = ("energy", "extended_force", "virial", "extended_virial")
        for key in compared_keys:
            torch.testing.assert_close(
                par[key], ref[key], rtol=rtol, atol=atol, msg=f"mismatch in {key}"
            )

    def test_parity_cpu(self) -> None:
        cpu = torch.device("cpu")
        self._assert_parity(cpu, rtol=1e-8, atol=1e-9)

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
    def test_parity_cuda(self) -> None:
        # CUDA atomic scatter reorders the accumulation, hence a looser
        # tolerance than on the deterministic CPU path.
        self._assert_parity(env.DEVICE, rtol=1e-6, atol=1e-7)

    def test_parity_full_attn_res_cpu(self) -> None:
        overrides = {"full_attn_res": "dependent", "so2_layers": 2}
        self._assert_parity(
            torch.device("cpu"),
            rtol=1e-8,
            atol=1e-9,
            descriptor_overrides=overrides,
        )

    def test_parity_block_attn_res_cpu(self) -> None:
        overrides = {"block_attn_res": "dependent", "so2_layers": 2}
        self._assert_parity(
            torch.device("cpu"),
            rtol=1e-8,
            atol=1e-9,
            descriptor_overrides=overrides,
        )

    def test_parity_no_env_seed_single_block_cpu(self) -> None:
        # Without env-seed and with a single block, block 0 only reads the
        # neighbour type embedding, which a rank can rebuild from
        # ``extended_atype``. The exchange is then redundant but must remain
        # exact, since it copies identical type embeddings.
        overrides = {"l_schedule": [1], "use_env_seed": False}
        self._assert_parity(
            torch.device("cpu"),
            rtol=1e-8,
            atol=1e-9,
            descriptor_overrides=overrides,
        )

    def test_parity_no_env_seed_cpu(self) -> None:
        # Several blocks without env-seed: ghost features carry block outputs
        # a rank cannot recompute, so the exchange really matters here.
        overrides = {"use_env_seed": False}
        self._assert_parity(
            torch.device("cpu"),
            rtol=1e-8,
            atol=1e-9,
            descriptor_overrides=overrides,
        )


class TestSeZMDescriptorSelfCommParity(unittest.TestCase):
    """Descriptor-only parity, separating the ghost exchange from the force scatter."""

    @classmethod
    def setUpClass(cls) -> None:
        ensure_comm_registered()

    def test_descriptor_parity_cpu(self) -> None:
        cpu = torch.device("cpu")
        model = _build_model(cpu, descriptor={"full_attn_res": "dependent"})
        descriptor = model.atomic_model.descriptor
        _perturb_descriptor(descriptor)
        frame = _build_extended_system(model, cpu)
        n_owned = frame["nloc"]
        comm = _self_comm_dict(frame["mapping"], n_owned, frame["nall"])

        # Edge geometry and mask are identical for both evaluations.
        edge_kwargs = {
            "edge_vec": frame["edge_vec"],
            "edge_mask": frame["edge_mask"],
        }
        # Folded reference: owned rows only, folded edge_index.
        ref, _ = descriptor.forward_with_edges(
            extended_coord=frame["coord"][:, :n_owned, :],
            extended_atype=frame["atype"],
            edge_index=frame["edge_index"],
            **edge_kwargs,
        )
        # Parallel path: extended node set with the self-send plan.
        par, _ = descriptor.forward_with_edges(
            extended_coord=frame["coord"],
            extended_atype=frame["extended_atype"],
            edge_index=frame["edge_scatter_index"],
            **edge_kwargs,
            comm_dict=comm,
            nloc=n_owned,
        )
        torch.testing.assert_close(par, ref, rtol=1e-8, atol=1e-9)


class TestSeZMEdgeParallelCapability(unittest.TestCase):
    """Checks the predicate that decides whether the with-comm export applies."""

    def test_plain_model_supports_edge_parallel(self) -> None:
        model = _build_model(torch.device("cpu"))
        descriptor = model.atomic_model.descriptor
        self.assertTrue(model.supports_edge_parallel())
        self.assertTrue(descriptor.has_message_passing_across_ranks())

    def test_bridging_model_fails_fast(self) -> None:
        # ZBL's analytical pair potential needs real element symbols.
        bridging_config = {
            "type_map": ["O", "H"],
            "bridging_method": "ZBL",
            "bridging_r_inner": 0.5,
            "bridging_r_outer": 1.0,
        }
        model = _build_model(torch.device("cpu"), **bridging_config)
        self.assertFalse(model.supports_edge_parallel())


class TestSeZMExchangeSchedule(unittest.TestCase):
    """Ghost exchanges are scheduled block by block rather than unconditionally.

    A block exchanges only if its SO(2) convolution reads neighbour rows the
    local rank cannot rebuild. Block 0 needs that only with env-seed/GIE, which
    fold the neighbour environment into the initial state; every later block
    always needs it because it reads the previous block's output. A purely
    local model (``use_env_seed=False`` and a single block) must therefore
    perform no exchange at all.
    """

    @classmethod
    def setUpClass(cls) -> None:
        ensure_comm_registered()

    def _count_exchanges(self, descriptor_overrides: dict) -> int:
        """Return how often ``exchange_ghost_features`` runs in one forward pass.

        The function is temporarily replaced on ``sezm_block`` by a wrapper
        that records each call and delegates to the original; the original is
        restored afterwards even if the forward pass raises.
        """
        cpu = torch.device("cpu")
        model = _build_model(cpu, descriptor=descriptor_overrides)
        frame = _build_extended_system(model, cpu)
        comm = _self_comm_dict(frame["mapping"], frame["nloc"], frame["nall"])

        original = sezm_block.exchange_ghost_features
        calls: list[None] = []

        def spy(*args, **kwargs):
            calls.append(None)
            return original(*args, **kwargs)

        sezm_block.exchange_ghost_features = spy
        try:
            model.forward_lower(
                frame["coord"],
                frame["atype"],
                frame["edge_scatter_index"],
                frame["edge_vec"],
                frame["edge_scatter_index"],
                frame["edge_mask"],
                comm_dict=comm,
                extended_atype=frame["extended_atype"],
            )
        finally:
            sezm_block.exchange_ghost_features = original
        return len(calls)

    def test_local_single_block_skips_all_comm(self) -> None:
        n_calls = self._count_exchanges({"l_schedule": [1], "use_env_seed": False})
        self.assertEqual(n_calls, 0)

    def test_env_seed_single_block_exchanges_once(self) -> None:
        n_calls = self._count_exchanges({"l_schedule": [1], "use_env_seed": True})
        self.assertEqual(n_calls, 1)

    def test_no_env_seed_multi_block_skips_first(self) -> None:
        n_calls = self._count_exchanges(
            {"l_schedule": [1, 1, 0], "use_env_seed": False}
        )
        self.assertEqual(n_calls, 2)

    def test_env_seed_multi_block_exchanges_every_block(self) -> None:
        n_calls = self._count_exchanges(
            {"l_schedule": [1, 1, 0], "use_env_seed": True}
        )
        self.assertEqual(n_calls, 3)


if __name__ == "__main__":
    unittest.main()