# SPDX-License-Identifier: LGPL-3.0-or-later
"""Optional bridge between DeePMD-kit SeZM / DPA-4 models and NVIDIA's
``nvalchemi-toolkit`` molecular-dynamics framework.

``nvalchemi-toolkit`` is an optional dependency. Importing this subpackage
without it installed raises a clear, actionable error instead of an opaque
``ModuleNotFoundError`` deep in the import chain.

Example
-------
::

    from deepmd.pt.nvalchemi import DPA4Wrapper
    from nvalchemi.data import AtomicData, Batch
    from nvalchemi.neighbors import compute_neighbors

    model = DPA4Wrapper.from_checkpoint("model.pt", device="cuda")
    batch = Batch.from_data_list([data], device="cuda")
    compute_neighbors(batch, config=model.model_config.neighbor_config)
    out = model(batch)  # {"energy": (B, 1), "forces": (N, 3), ...}
"""

from __future__ import (
    annotations,
)

try:
    import nvalchemi  # noqa: F401
except ModuleNotFoundError as e:  # pragma: no cover - exercised only without the dep
    # Translate the failure only when ``nvalchemi`` itself is missing. If the
    # package is installed but one of *its* imports is broken, the install hint
    # below would be misleading, so the original error is left to propagate.
    if e.name != "nvalchemi":
        raise
    raise ImportError(
        "deepmd.pt.nvalchemi requires the optional `nvalchemi-toolkit` package. "
        "Install it with `pip install deepmd-kit[nvalchemi]` "
        "(or `pip install nvalchemi-toolkit`)."
    ) from e

from .dpa4wrapper import (
    DPA4Wrapper,
    SeZMWrapper,
)

__all__ = ["DPA4Wrapper", "SeZMWrapper"]
+ +:class:`DPA4Wrapper` adapts a trained DeePMD-kit SeZM / DPA-4 model to the +``nvalchemi-toolkit`` model interface (:class:`~nvalchemi.models.base.BaseModelMixin`) +so it can be driven by any ``nvalchemi`` dynamics engine (NVE, NVT, NPT, FIRE). +The underlying model is used unchanged; the wrapper only translates between the +``nvalchemi`` graph batch and SeZM's sparse edge-list interface, and maps the +outputs back to ``nvalchemi``'s ``energy`` / ``forces`` / ``stress``. + +Two backends are supported through :meth:`DPA4Wrapper.from_checkpoint`: + +* a ``.pt`` training checkpoint, run eagerly as a :class:`SeZMModel`. Set + ``DP_COMPILE_INFER=1`` (optionally ``DP_TRITON_INFER=1``) before loading to + enable SeZM's compiled-inference path. +* a frozen ``.pt2`` AOTInductor package, run through its precompiled callable + (float64 I/O, device-locked to the host it was frozen on). + +Neighbour-list and geometry conventions +--------------------------------------- +``nvalchemi`` supplies a COO neighbour list whose rows are ``[source, target]`` +(``source`` is the centre atom, ``target`` the neighbour), with an integer image +``neighbor_list_shifts`` belonging to ``target``. The per-edge displacement is +``r = positions[target] - positions[source] + shifts @ cell``. SeZM consumes +``edge_index = [src, dst]`` with ``edge_vec = r_src - r_dst`` and aggregates +messages onto ``dst``, so the wrapper maps ``dst = source`` and ``src = target``. + +The reported stress is the Cauchy stress (virial divided by the cell volume), +which matches ``nvalchemi``'s sign convention. A whole batch is presented to +SeZM as a single frame; per-graph energy and virial are recovered by +segment-summing the per-atom outputs with ``batch_idx``. 
+""" + +from __future__ import ( + annotations, +) + +from collections import ( + OrderedDict, +) +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from nvalchemi.data import ( + AtomicData, + Batch, +) +from nvalchemi.models.base import ( + BaseModelMixin, + ModelConfig, + NeighborConfig, + NeighborListFormat, +) +from torch import ( + nn, +) + +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, + SeZMModel, +) + +if TYPE_CHECKING: + from pathlib import ( + Path, + ) + + from nvalchemi._typing import ( + ModelOutputs, + ) + +__all__ = ["DPA4Wrapper", "SeZMWrapper"] + +# Location of the metadata written into a SeZM/DPA-4 ``.pt2`` archive. +_PT2_METADATA_ENTRY = "model/extra/metadata.json" + + +class DPA4Wrapper(nn.Module, BaseModelMixin): + """Wrap a trained SeZM / DPA-4 model as an ``nvalchemi`` model. + + Use :meth:`from_checkpoint` to load from a ``.pt`` checkpoint or a frozen + ``.pt2`` package, or construct directly from an in-memory :class:`SeZMModel`. + + Parameters + ---------- + model + An instantiated SeZM / DPA-4 energy model (:class:`SeZMModel`). It is put + into ``eval`` mode so the sparse-edge path is deterministic. + atomic_number_to_type + Optional explicit mapping ``{atomic_number: type_index}``. When ``None`` + the mapping is derived from the model ``type_map`` via the periodic + table, which requires the type map to contain element symbols. Provide + this override for non-element type maps. + compute_stress + If ``True``, ``"stress"`` is added to the active outputs so NPT / NPH + barostats receive the Cauchy stress. Stress requires a periodic ``cell``. + default_charge_spin + Optional global ``[charge, spin]`` forwarded to SeZM models built with + ``add_chg_spin_ebd=True``. It applies to the whole batch. + + Attributes + ---------- + model + The wrapped :class:`SeZMModel`, or ``None`` for a ``.pt2`` backend. 
+ model_config + Mutable :class:`~nvalchemi.models.base.ModelConfig` controlling which + outputs are produced and describing the required neighbour list. + """ + + model: SeZMModel | None + + def __init__( + self, + model: SeZMModel, + *, + atomic_number_to_type: dict[int, int] | None = None, + compute_stress: bool = False, + default_charge_spin: list[float] | None = None, + ) -> None: + super().__init__() + self.model = model.eval() + self._aoti_runner: Any | None = None + self._aoti_dim_fparam = 0 + self._aoti_dim_aparam = 0 + self._dtype = next(model.parameters()).dtype + self._descriptor_dim: int | None = int( + model.atomic_model.descriptor.get_dim_out() + ) + self._configure( + rcut=float(model.get_rcut()), + type_map=list(model.get_type_map()), + device=next(model.parameters()).device, + atomic_number_to_type=atomic_number_to_type, + compute_stress=compute_stress, + default_charge_spin=default_charge_spin, + ) + + def _configure( + self, + *, + rcut: float, + type_map: list[str], + device: torch.device, + atomic_number_to_type: dict[int, int] | None, + compute_stress: bool, + default_charge_spin: list[float] | None, + ) -> None: + """Set up the shared configuration common to both backends.""" + self.default_charge_spin = default_charge_spin + self.rcut = float(rcut) + + # Memoized species mapping, keyed on the atomic-number tensor identity so + # a dynamics run validates types once instead of on every step. + self._atype_cache: tuple[tuple[int, int, torch.device], torch.Tensor] | None = ( + None + ) + # Pre-build the optional charge/spin condition once so the hot path does + # not rebuild it from a Python list (a host-to-device copy) every step. 
+ charge_spin = ( + None + if default_charge_spin is None + else torch.tensor( + default_charge_spin, dtype=self._dtype, device=device + ).view(1, 2) + ) + self.register_buffer("_charge_spin_buf", charge_spin, persistent=False) + + z_to_type = self._build_z_to_type(type_map, atomic_number_to_type, device) + # persistent=False: derived from the type map, excluded from the state + # dict but kept in sync with ``.to()`` device moves. + self.register_buffer("_z_to_type", z_to_type, persistent=False) + + active: set[str] = {"energy", "forces"} + if compute_stress: + active.add("stress") + # Forces and stress are produced by the model itself (SeZM's internal + # ``edge_vec`` autograd), so they are returned directly rather than via + # an ``nvalchemi`` autograd pass through ``positions``. + self.model_config = ModelConfig( + outputs=frozenset({"energy", "forces", "stress"}), + active_outputs=active, + autograd_outputs=frozenset(), + autograd_inputs=frozenset({"positions"}), + required_inputs=frozenset(), + optional_inputs=frozenset({"cell", "neighbor_list_shifts"}), + supports_pbc=True, + needs_pbc=False, + neighbor_config=NeighborConfig( + cutoff=self.rcut, + format=NeighborListFormat.COO, + half_list=False, + ), + ) + + # ------------------------------------------------------------------ + # Construction helpers + # ------------------------------------------------------------------ + + @staticmethod + def _build_z_to_type( + type_map: list[str], + atomic_number_to_type: dict[int, int] | None, + device: torch.device, + ) -> torch.Tensor: + """Build a dense ``atomic_number -> type_index`` lookup tensor. + + Atomic numbers absent from the mapping map to ``-1`` so the forward pass + can raise a clear error instead of silently mislabelling atoms. 
+ """ + if atomic_number_to_type is None: + atomic_number_to_type = {} + for type_index, symbol in enumerate(type_map): + z = ELEMENT_TO_Z.get(symbol) + if z is None: + raise ValueError( + f"Cannot map type map entry {symbol!r} to an atomic " + "number. Pass an explicit `atomic_number_to_type` " + "mapping for non-element type maps." + ) + atomic_number_to_type[z] = type_index + if not atomic_number_to_type: + raise ValueError("`atomic_number_to_type` resolved to an empty mapping.") + max_z = max(atomic_number_to_type) + table = torch.full((max_z + 1,), -1, dtype=torch.long, device=device) + for z, type_index in atomic_number_to_type.items(): + table[int(z)] = int(type_index) + return table + + # ------------------------------------------------------------------ + # BaseModelMixin required members + # ------------------------------------------------------------------ + + @property + def embedding_shapes(self) -> dict[str, tuple[int, ...]]: + """Per-atom and per-graph descriptor embedding widths.""" + if self._descriptor_dim is None: + raise NotImplementedError( + "Embeddings are only available for the `.pt` backend, not a " + "frozen `.pt2` package." + ) + return { + "node_embeddings": (self._descriptor_dim,), + "graph_embeddings": (self._descriptor_dim,), + } + + # ------------------------------------------------------------------ + # Input / output adaptation + # ------------------------------------------------------------------ + + def _atype(self, atomic_numbers: torch.Tensor) -> torch.Tensor: + """Map atomic numbers to SeZM type indices via the lookup table. + + The result is memoized on the identity of *atomic_numbers* (storage + pointer, length and device). A dynamics run reuses the same + ``atomic_numbers`` tensor for every step while only mutating positions, + so the species mapping -- and the two host-device synchronizations its + validation needs -- run once on the first step and are skipped + afterwards, keeping the MD hot path free of stream stalls. 
+ """ + key = (atomic_numbers.data_ptr(), atomic_numbers.numel(), atomic_numbers.device) + cached = self._atype_cache + if cached is not None and cached[0] == key: + return cached[1] + + z = atomic_numbers.long() + if z.numel() and int(z.max()) >= self._z_to_type.shape[0]: + raise ValueError("Encountered an atomic number outside the model type map.") + atype = self._z_to_type.index_select(0, z.clamp_min(0)) + if bool((atype < 0).any()): + missing = sorted({int(v) for v in z[atype < 0].tolist()}) + raise ValueError( + f"Atomic numbers {missing} are not present in the model type map." + ) + self._atype_cache = (key, atype) + return atype + + def _edge_schema( + self, data: Batch + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Translate the COO neighbour list into SeZM's edge schema. + + Returns + ------- + edge_index + ``(2, E)`` with rows ``[src=neighbour, dst=centre]`` in flattened + local-atom space. + edge_vec + ``(E, 3)`` displacement ``r_src - r_dst`` (PBC images folded in). + edge_mask + ``(E,)`` all-true validity mask for the real edges. + """ + neighbor_list = getattr(data, "neighbor_list", None) + if neighbor_list is None: + raise KeyError( + "Batch has no `neighbor_list`. Run `compute_neighbors(batch, " + "config=model.model_config.neighbor_config)` or register a " + "COO NeighborListHook before calling the model." + ) + dtype = self._dtype + device = data.positions.device + positions = data.positions.to(dtype) + + neighbor_list = neighbor_list.long() + source = neighbor_list[:, 0] + target = neighbor_list[:, 1] + n_edge = neighbor_list.shape[0] + + cell = getattr(data, "cell", None) + neighbor_list_shifts = getattr(data, "neighbor_list_shifts", None) + if cell is not None and neighbor_list_shifts is None: + raise ValueError( + "A periodic `cell` was provided without `neighbor_list_shifts`; " + "PBC image shifts are required for correct edge vectors. Run " + "`compute_neighbors` with the cell so the shifts are attached." 
+ ) + if neighbor_list_shifts is not None and cell is not None: + cell = cell.to(dtype) + shifts = neighbor_list_shifts.to(dtype) + if cell.shape[0] == 1: + # Single-frame fast path (the common MD case): every edge shares + # the one cell, so a single (E,3)x(3,3) matmul replaces gathering + # an (E,3,3) per-edge cell and the einsum over it. + shift_vec = shifts @ cell[0] + else: + graph_per_edge = data.batch_idx.long().index_select(0, source) + cell_per_edge = cell.index_select(0, graph_per_edge) + shift_vec = torch.einsum("eb,ebc->ec", shifts, cell_per_edge) + else: + shift_vec = torch.zeros(n_edge, 3, dtype=dtype, device=device) + + edge_vec = ( + positions.index_select(0, target) + - positions.index_select(0, source) + + shift_vec + ) + edge_index = torch.stack([target, source], dim=0) + edge_mask = torch.ones(n_edge, dtype=torch.bool, device=device) + return edge_index, edge_vec, edge_mask + + def adapt_input(self, data: AtomicData | Batch, **kwargs: Any) -> dict[str, Any]: + """Build the lower-interface inputs for the wrapped model. + + The batch is presented as a single frame with ``nloc`` equal to the + total number of atoms; the COO neighbour list already carries the global + node offsets, so heterogeneous multi-graph batches need no special + handling here. + """ + del kwargs + if isinstance(data, AtomicData): + data = Batch.from_data_list([data]) + dtype = self._dtype + n_node = data.num_nodes + + coord = data.positions.to(dtype).view(1, n_node, 3) + atype = self._atype(data.atomic_numbers).view(1, n_node) + edge_index, edge_vec, edge_mask = self._edge_schema(data) + return { + "coord": coord, + "atype": atype, + "edge_index": edge_index, + "edge_vec": edge_vec, + "edge_scatter_index": edge_index, + "edge_mask": edge_mask, + "charge_spin": self._charge_spin_buf, + } + + def adapt_output( + self, model_output: dict[str, torch.Tensor], data: Batch + ) -> ModelOutputs: + """Map the lower-interface outputs to the ``nvalchemi`` output dict. 
+ + Per-atom energy and virial are segment-summed with ``batch_idx`` so each + graph gets its own total, even though the batch is run as a single frame. + ``model_output`` carries the normalized keys ``atom_energy``, + ``extended_force`` and ``extended_virial``. + + When ``stress`` is active the cell must be non-degenerate (positive + volume): the stress is ``virial / |det(cell)|``, so a singular cell would + yield non-finite values. + """ + batch_idx = data.batch_idx.long() + n_graph = data.num_graphs + n_node = data.num_nodes + out_dtype = data.positions.dtype + + atom_energy = model_output["atom_energy"].reshape(n_node) + energy = torch.zeros( + n_graph, dtype=atom_energy.dtype, device=atom_energy.device + ).index_add_(0, batch_idx, atom_energy) + + output: ModelOutputs = OrderedDict() + output["energy"] = energy.unsqueeze(-1).to(out_dtype) + + active = self.model_config.active_outputs + if "forces" in active: + output["forces"] = ( + model_output["extended_force"].reshape(n_node, 3).to(out_dtype) + ) + if "stress" in active: + cell = getattr(data, "cell", None) + if cell is None: + raise ValueError( + "stress output requires a periodic `cell` for the volume." 
+ ) + atom_virial = model_output["extended_virial"].reshape(n_node, 9) + virial = torch.zeros( + n_graph, 9, dtype=atom_virial.dtype, device=atom_virial.device + ).index_add_(0, batch_idx, atom_virial) + volume = torch.det(cell.to(virial.dtype)).abs().view(n_graph, 1, 1) + output["stress"] = (virial.view(n_graph, 3, 3) / volume).to(out_dtype) + return output + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + def forward(self, data: AtomicData | Batch, **kwargs: Any) -> ModelOutputs: + """Run the model on a batch and return the ``nvalchemi`` output dict.""" + if isinstance(data, AtomicData): + data = Batch.from_data_list([data]) + model_inputs = self.adapt_input(data, **kwargs) + if self._aoti_runner is not None: + model_ret = self._run_pt2(model_inputs) + else: + need_virial = "stress" in self.model_config.active_outputs + model_ret = self.model.forward_lower( + model_inputs["coord"], + model_inputs["atype"], + model_inputs["edge_index"], + model_inputs["edge_vec"], + model_inputs["edge_scatter_index"], + model_inputs["edge_mask"], + do_atomic_virial=need_virial, + charge_spin=model_inputs["charge_spin"], + ) + return self.adapt_output(model_ret, data) + + def _run_pt2(self, model_inputs: dict[str, Any]) -> dict[str, torch.Tensor]: + """Run the frozen ``.pt2`` callable and normalize its output keys. + + The AOTInductor package traces ``forward_common_lower`` and returns a + dict keyed by the raw DeePMD names; here they are renamed to the keys + :meth:`adapt_output` expects. ``None`` trailing arguments are filtered + by the loader, matching the export-time signature. + """ + if self._aoti_dim_fparam or self._aoti_dim_aparam: + raise NotImplementedError( + "The `.pt2` backend does not support models requiring frame or " + "atomic parameters (fparam / aparam) through nvalchemi." 
+ ) + ret = self._aoti_runner( + model_inputs["coord"], + model_inputs["atype"], + model_inputs["edge_index"], + model_inputs["edge_vec"], + model_inputs["edge_scatter_index"], + model_inputs["edge_mask"], + None, + None, + model_inputs["charge_spin"], + ) + return { + "atom_energy": ret["energy"], + "extended_force": ret["energy_derv_r"], + "extended_virial": ret["energy_derv_c"], + } + + # ------------------------------------------------------------------ + # Embeddings + # ------------------------------------------------------------------ + + def compute_embeddings( + self, data: AtomicData | Batch, **kwargs: Any + ) -> AtomicData | Batch: + """Attach per-atom / per-graph descriptor embeddings to *data*. + + Writes ``node_embeddings`` (``[N, descriptor_dim]``) and + ``graph_embeddings`` (``[B, descriptor_dim]``, sum-pooled over atoms) in + place. Only supported for the ``.pt`` backend. + """ + del kwargs + if self.model is None: + raise NotImplementedError( + "Embeddings are only available for the `.pt` backend, not a " + "frozen `.pt2` package." 
+ ) + if isinstance(data, AtomicData): + data = Batch.from_data_list([data]) + model_inputs = self.adapt_input(data) + ret = self.model.forward_common_lower( + model_inputs["coord"], + model_inputs["atype"], + model_inputs["edge_index"], + model_inputs["edge_vec"], + model_inputs["edge_scatter_index"], + model_inputs["edge_mask"], + charge_spin=model_inputs["charge_spin"], + embedding_only=True, + ) + n_node = data.num_nodes + node_embeddings = ret["descriptor"].reshape(n_node, self._descriptor_dim) + + atoms_group = data._atoms_group + if atoms_group is not None: + atoms_group["node_embeddings"] = node_embeddings + else: + data.node_embeddings = node_embeddings + + graph_embeddings = torch.zeros( + data.num_graphs, + self._descriptor_dim, + dtype=node_embeddings.dtype, + device=node_embeddings.device, + ) + graph_embeddings.index_add_(0, data.batch_idx.long(), node_embeddings) + data.graph_embeddings = graph_embeddings + return data + + # ------------------------------------------------------------------ + # Checkpoint loading + # ------------------------------------------------------------------ + + @classmethod + def from_checkpoint( + cls, + checkpoint_path: Path | str, + device: torch.device | str = "cpu", + *, + head: str | None = None, + **wrapper_kwargs: Any, + ) -> DPA4Wrapper: + """Load a DeePMD-kit SeZM / DPA-4 model into a wrapper. + + Parameters + ---------- + checkpoint_path + Either a ``.pt`` training checkpoint or a frozen ``.pt2`` + AOTInductor package. + device + Target device. For ``.pt2`` the package is device-locked to its + freeze host, so this must match. + head + Multi-task branch name; required for a multi-task ``.pt`` checkpoint. + **wrapper_kwargs + Forwarded to :class:`DPA4Wrapper` (e.g. ``compute_stress``, + ``atomic_number_to_type``). 
+ """ + device = torch.device(device) if isinstance(device, str) else device + if str(checkpoint_path).endswith(".pt2"): + if head is not None: + raise NotImplementedError( + "Head selection is not supported for a frozen `.pt2` package; " + "freeze the desired head instead." + ) + return cls._from_pt2(checkpoint_path, device, **wrapper_kwargs) + return cls._from_pt(checkpoint_path, device, head=head, **wrapper_kwargs) + + @classmethod + def _from_pt( + cls, + checkpoint_path: Path | str, + device: torch.device, + *, + head: str | None, + **wrapper_kwargs: Any, + ) -> DPA4Wrapper: + """Load a ``.pt`` training checkpoint into a SeZM-backed wrapper.""" + from deepmd.pt.model.model import ( + get_model, + ) + from deepmd.pt.train.wrapper import ( + ModelWrapper, + ) + + state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True) + if "model" in state_dict: + state_dict = state_dict["model"] + model_params = state_dict["_extra_state"]["model_params"] + + if "model_dict" in model_params: + if head is None: + raise ValueError( + "`head` must be specified for a multi-task checkpoint. " + f"Available heads: {list(model_params['model_dict'])}." + ) + if head not in model_params["model_dict"]: + raise ValueError( + f"Unknown head {head!r} for this multi-task checkpoint. " + f"Available heads: {list(model_params['model_dict'])}." + ) + model_params = model_params["model_dict"][head] + head_state = {"_extra_state": state_dict["_extra_state"]} + for key, value in state_dict.items(): + if f"model.{head}." 
in key: + head_state[key.replace(f"model.{head}.", "model.Default.")] = ( + value.clone() + ) + state_dict = head_state + + model_params.pop("hessian_mode", None) + model = get_model(model_params).to(device) + model_wrapper = ModelWrapper(model) + model_wrapper.load_state_dict(state_dict) + sezm_model = model_wrapper.model["Default"] + if not isinstance(sezm_model, SeZMModel): + raise TypeError( + "Checkpoint does not contain a SeZM / DPA-4 model; got " + f"{type(sezm_model).__name__}." + ) + return cls(sezm_model.to(device), **wrapper_kwargs) + + @classmethod + def _from_pt2( + cls, + package_path: Path | str, + device: torch.device, + **wrapper_kwargs: Any, + ) -> DPA4Wrapper: + """Load a frozen ``.pt2`` AOTInductor package into a wrapper.""" + import json + import zipfile + + from torch._inductor import ( + aoti_load_package, + ) + + with zipfile.ZipFile(package_path, "r") as archive: + if _PT2_METADATA_ENTRY not in archive.namelist(): + raise ValueError( + f"{package_path!s} is missing {_PT2_METADATA_ENTRY!r}; it does " + "not look like a SeZM / DPA-4 `.pt2` archive." + ) + metadata = json.loads(archive.read(_PT2_METADATA_ENTRY).decode("utf-8")) + + self = cls.__new__(cls) + nn.Module.__init__(self) + self.model = None + self._aoti_runner = aoti_load_package(str(package_path)) + self._aoti_dim_fparam = int(metadata.get("dim_fparam", 0)) + self._aoti_dim_aparam = int(metadata.get("dim_aparam", 0)) + # `.pt2` packages are compiled with float64 I/O. + self._dtype = torch.float64 + self._descriptor_dim = None + self._configure( + rcut=float(metadata["rcut"]), + type_map=list(metadata["type_map"]), + device=device, + atomic_number_to_type=wrapper_kwargs.pop("atomic_number_to_type", None), + compute_stress=wrapper_kwargs.pop("compute_stress", False), + default_charge_spin=wrapper_kwargs.pop("default_charge_spin", None), + ) + if wrapper_kwargs: + raise TypeError( + f"Unexpected keyword arguments for a `.pt2` wrapper: " + f"{sorted(wrapper_kwargs)}." 
+ ) + return self + + +# The descriptor is registered as both SeZM and DPA4; this alias keeps the +# SeZM name available. +SeZMWrapper = DPA4Wrapper diff --git a/doc/third-party/index.rst b/doc/third-party/index.rst index cd0726a4bb..fcf1b42160 100644 --- a/doc/third-party/index.rst +++ b/doc/third-party/index.rst @@ -11,4 +11,5 @@ Note that the model for inference is required to be compatible with the DeePMD-k lammps-command ipi gromacs + nvalchemi out-of-deepmd-kit diff --git a/doc/third-party/nvalchemi.md b/doc/third-party/nvalchemi.md new file mode 100644 index 0000000000..ac2af9846a --- /dev/null +++ b/doc/third-party/nvalchemi.md @@ -0,0 +1,212 @@ +# Molecular dynamics with nvalchemi-toolkit + +[`nvalchemi-toolkit`](https://github.com/NVIDIA/nvalchemi-toolkit) is NVIDIA's +GPU-accelerated framework for batched molecular dynamics and structure +optimization with machine-learning interatomic potentials. DeePMD-kit ships a +thin adapter, `DPA4Wrapper`, that exposes a trained DPA-4 / SeZM model to any +`nvalchemi` dynamics engine (NVE, NVT, NPT, FIRE, ...). The model itself runs +unmodified; the wrapper only translates between the `nvalchemi` graph batch and +the model's internal interface. + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }}, for DPA-4 / SeZM energy +models. `nvalchemi-toolkit` is an optional dependency and must be installed +separately. A CUDA device is recommended, since `nvalchemi`'s neighbour-list and +integrator kernels are GPU-accelerated. +::: + +## Installation + +Install the optional toolkit through the DeePMD-kit extra: + +```bash +pip install deepmd-kit[nvalchemi] +``` + +This pulls in the `nvalchemi-toolkit` package; equivalently, install it directly +with `pip install nvalchemi-toolkit`. Refer to the `nvalchemi-toolkit` +documentation for the build that matches your Python, platform, and CUDA +environment. 
+ +The DeePMD-kit adapter lives in `deepmd.pt.nvalchemi`; importing it without +`nvalchemi-toolkit` present raises an actionable error. + +## Loading a model + +A trained DeePMD-kit checkpoint (`.pt`) is loaded and wrapped in one call: + +```python +import torch +from deepmd.pt.nvalchemi import DPA4Wrapper + +model = DPA4Wrapper.from_checkpoint( + "model.ckpt.pt", + device=torch.device("cuda"), + compute_stress=True, # enable the Cauchy stress output (needs a periodic cell) +) +``` + +For a multi-task checkpoint, pass the branch name with `head="..."`. An +already-instantiated model can be wrapped directly with `DPA4Wrapper(model)`. + +`from_checkpoint` also accepts a frozen `.pt2` (AOTInductor) package produced by +`dp --pt freeze`; it is loaded through its precompiled callable (float64 I/O, +and device-locked to the host it was frozen on). + +### Performance + +The model runs eagerly by default. To use DeePMD-kit's compiled inference path, +set the environment variables **before** loading the model: + +- `DP_COMPILE_INFER=1` — compile the model. The first call pays a one-time + compile cost (~1–2 min); subsequent steps are roughly 3x faster, and the + dynamic-shape graph handles the changing neighbour count during MD without + recompiling. +- `DP_TRITON_INFER=1` — additionally enable the Triton inference kernels for a + further speedup on larger cells. + +A frozen `.pt2` package bakes the compilation in — and, when you run +`dp --pt freeze` with `DP_TRITON_INFER=1` set, the Triton kernels too — so it +skips the warm-up at the cost of being device-locked. + +## Single-point evaluation + +Build an `AtomicData` object, batch it, compute a neighbour list, and call the +model. 
The wrapper returns a dictionary with `energy` (shape `(B, 1)`), +`forces` (shape `(N, 3)`), and, when enabled, `stress` (shape `(B, 3, 3)`): + +```python +from nvalchemi.data import AtomicData, Batch +from nvalchemi.neighbors import compute_neighbors + +data = AtomicData( + atomic_numbers=atomic_numbers, # (N,) integer atomic numbers + positions=positions, # (N, 3) in Angstrom + cell=cell, # (1, 3, 3) lattice vectors, or omit for a cluster + pbc=pbc, # (1, 3) booleans, or omit for a cluster +) +batch = Batch.from_data_list([data], device="cuda") +compute_neighbors(batch, config=model.model_config.neighbor_config) + +out = model(batch) +energy = out["energy"] # (B, 1) eV +forces = out["forces"] # (N, 3) eV/A +stress = out["stress"] # (B, 3, 3) eV/A^3 (Cauchy stress = virial / volume) +``` + +Forces and stress are computed conservatively inside the model and returned +directly, so no gradient bookkeeping is required on the caller side. + +## Molecular dynamics + +For dynamics, register a neighbour-list hook so the list is rebuilt before each +force evaluation, then drive the batch with an integrator. The following snippet +runs canonical (NVT) dynamics with a Langevin thermostat: + +```python +from nvalchemi.data import AtomicData, Batch +from nvalchemi.dynamics import initialize_velocities +from nvalchemi.dynamics.base import DynamicsStage +from nvalchemi.dynamics.integrators import NVTLangevin +from nvalchemi.hooks import NeighborListHook +from nvalchemi.neighbors import compute_neighbors + +# ``forces`` and ``energy`` are pre-allocated so the integrator can read forces +# and the engine can write results back into the batch in place. 
+data = AtomicData( + atomic_numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=pbc, + forces=torch.zeros_like(positions), + energy=torch.zeros((1, 1), dtype=positions.dtype, device=positions.device), +) +batch = Batch.from_data_list([data], device="cuda") + +# Seed Maxwell-Boltzmann velocities at the target temperature. +temperature = torch.full( + (batch.num_graphs,), 330.0, dtype=positions.dtype, device="cuda" +) +initialize_velocities( + batch.velocities, batch.atomic_masses, temperature, batch.batch_idx.int() +) + +nl_hook = NeighborListHook( + model.model_config.neighbor_config, stage=DynamicsStage.BEFORE_COMPUTE +) +nvt = NVTLangevin(model, dt=0.5, temperature=330.0, friction=0.01, hooks=[nl_hook]) + +# Prime the neighbour list and forces, then integrate. +compute_neighbors(batch, config=model.model_config.neighbor_config) +nvt.compute(batch) +batch = nvt.run(batch, n_steps=1000) +``` + +Switching ensemble is a one-line change: use `NVE(model, dt=...)` for the +microcanonical ensemble or `NPT(...)` for constant pressure (which consumes the +`stress` output). `nvalchemi` also provides logging and monitoring hooks (e.g. +`LoggingHook`, `EnergyDriftMonitorHook`) that attach to the same engine. + +## Geometry optimization + +The same model drives the FIRE optimizer. Convergence is controlled by a +maximum-force criterion: + +```python +from nvalchemi.dynamics.base import ConvergenceHook +from nvalchemi.dynamics.optimizers import FIRE2 + +opt = FIRE2( + model, + dt=1.0, + hooks=[nl_hook], + convergence_hook=ConvergenceHook.from_fmax(threshold=0.05), +) +compute_neighbors(batch, config=model.model_config.neighbor_config) +opt.compute(batch) +batch = opt.run(batch, n_steps=200) # stops early once fmax <= 0.05 eV/A +``` + +## Outputs and configuration + +The wrapper advertises its capabilities through `model.model_config` +(an `nvalchemi` `ModelConfig`): + +- `outputs` — `energy`, `forces`, and `stress`. 
+- `active_outputs` — the subset computed on each call. `energy` and `forces` are + active by default; `stress` is added when `compute_stress=True` (or via + `model.set_config("active_outputs", {"energy", "forces", "stress"})`). +- `neighbor_config` — the cutoff and neighbour-list format the model requires. + Pass it to `compute_neighbors` or `NeighborListHook` as shown above. + +## Heterogeneous batches + +A single `Batch` may contain several structures of different sizes and cells. +The wrapper evaluates the whole batch in one pass and returns per-structure +energy and stress (`(B, 1)` and `(B, 3, 3)`) together with the concatenated +per-atom forces (`(N, 3)`), making it straightforward to evaluate many systems +at once. + +## Units and conventions + +- Lengths are in Angstrom, energies in eV, masses in amu, and time in + femtoseconds. +- Atomic numbers are mapped to model types using the checkpoint `type_map`. Pass + `atomic_number_to_type={Z: type_index, ...}` to `DPA4Wrapper` to override this + for non-standard type maps. +- The reported `stress` is the Cauchy stress, equal to the virial divided by the + cell volume; it requires a periodic cell. + +## Limitations + +- Only DPA-4 / SeZM energy models are supported. +- Acceleration uses DeePMD-kit's own compiled inference (`DP_COMPILE_INFER` or a + frozen `.pt2`); `nvalchemi`'s `FusedStage` `torch.compile` is not used. +- Embeddings (`compute_embeddings`) require the `.pt` backend, not `.pt2`. +- Charge / spin conditioning is applied as a single global value per batch. + +## Examples + +Complete, runnable scripts for single-point evaluation, NVE, NVT, and geometry +optimization are provided in +[`examples/water/dpa4/nvalchemi/`](https://github.com/deepmodeling/deepmd-kit/tree/master/examples/water/dpa4/nvalchemi). 
diff --git a/examples/water/dpa4/nvalchemi/README.md b/examples/water/dpa4/nvalchemi/README.md new file mode 100644 index 0000000000..5c2b2dc3ac --- /dev/null +++ b/examples/water/dpa4/nvalchemi/README.md @@ -0,0 +1,76 @@ +# Running DPA-4 / SeZM with nvalchemi-toolkit + +This directory contains runnable examples for driving a trained DPA-4 / SeZM +model with NVIDIA's [`nvalchemi-toolkit`](https://github.com/NVIDIA/nvalchemi-toolkit) +molecular-dynamics framework. The model is loaded through +`deepmd.pt.nvalchemi.DPA4Wrapper`, a thin adapter that exposes a DeePMD-kit +PyTorch model to any `nvalchemi` dynamics engine. + +For a conceptual overview and the full API reference, see the user guide at +`doc/third-party/nvalchemi.md`. + +## Prerequisites + +- A DeePMD-kit installation with the PyTorch backend. +- The optional `nvalchemi-toolkit` package (`pip install deepmd-kit[nvalchemi]`, + or `pip install nvalchemi-toolkit`; see its documentation for the build matching + your Python, platform, and CUDA environment). The `nvalchemi` extra only resolves on Linux with Python 3.11-3.13; elsewhere it installs nothing. A CUDA device is + recommended, since `nvalchemi`'s neighbour-list and integrator kernels are + GPU-accelerated. +- A trained DPA-4 / SeZM checkpoint (`.pt`, or a frozen `.pt2`). The examples + default to the smoke-test checkpoint shipped at `../lmp/pretrained.pt`; replace + it with your own model for production runs. + +For faster inference, export `DP_COMPILE_INFER=1` (optionally `DP_TRITON_INFER=1`) +before running any script to enable the compiled path, or pass a frozen `.pt2` +package as `--model`. + +The example structures are read from the bundled water dataset +(`../../data/data_0`), which provides a 192-atom periodic liquid-water cell. + +## Examples + +Each script is self-contained and documented; run any of them with `--help` to +see all options. 
+ +| Script | Description | +| ----------------- | ----------------------------------------------------------------------- | +| `single_point.py` | Evaluate potential energy, atomic forces, and the Cauchy stress tensor. | +| `run_nve.py` | Microcanonical (NVE) MD; reports total-energy conservation. | +| `run_nvt.py` | Canonical (NVT) MD with a Langevin thermostat at a target temperature. | +| `relax.py` | Fixed-cell geometry optimization with the FIRE2 optimizer. | + +## Quick start + +```bash +cd examples/water/dpa4/nvalchemi + +# Single-point energy / forces / stress +python single_point.py + +# 200-step NVE trajectory seeded at 300 K +python run_nve.py --steps 200 --dt 0.5 --temperature 300 + +# NVT at 330 K with a Langevin thermostat +python run_nvt.py --steps 300 --temperature 330 --friction 0.01 + +# Relax to a maximum force of 0.05 eV/A +python relax.py --fmax 0.05 --max-steps 200 +``` + +To use your own model and structure: + +```bash +python run_nvt.py --model /path/to/model.ckpt.pt --data /path/to/deepmd/system +``` + +## Notes + +- **Units** follow the standard atomistic-MD convention: lengths in Angstrom, + energies in eV, masses in amu, and time in femtoseconds. +- **Element mapping** is derived automatically from the model `type_map`. Atoms + whose element is absent from the type map raise a clear error. +- **Stress** requires a periodic cell. The Cauchy stress equals the virial + divided by the cell volume. +- The shipped `pretrained.pt` is a 500-step smoke-test model, not a + production-quality water potential; use it only to verify the workflow. diff --git a/examples/water/dpa4/nvalchemi/relax.py b/examples/water/dpa4/nvalchemi/relax.py new file mode 100644 index 0000000000..bb2f0aadd3 --- /dev/null +++ b/examples/water/dpa4/nvalchemi/relax.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Geometry optimization (FIRE) with a DPA-4 / SeZM model. 
+ +This example relaxes atomic positions toward a local energy minimum using the +FIRE2 optimizer in ``nvalchemi``, driven by a trained DPA-4 / SeZM potential. +The cell is held fixed; only atomic coordinates move. + +Convergence is controlled by a force criterion: the optimizer stops once the +maximum per-atom force norm (``fmax``) drops below the threshold. This is the +same criterion ASE's optimizers use, so the threshold is directly comparable. + +Usage +----- +:: + + python relax.py \ + --model ../lmp/pretrained.pt \ + --data ../../data/data_0 \ + --fmax 0.05 --max-steps 200 --dt 1.0 +""" + +from __future__ import ( + annotations, +) + +import argparse +from pathlib import ( + Path, +) + +import numpy as np +import torch +from nvalchemi.data import ( + AtomicData, + Batch, +) +from nvalchemi.dynamics.base import ( + ConvergenceHook, + DynamicsStage, +) +from nvalchemi.dynamics.optimizers import ( + FIRE2, +) +from nvalchemi.hooks import ( + NeighborListHook, +) +from nvalchemi.neighbors import ( + compute_neighbors, +) + +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, +) +from deepmd.pt.nvalchemi import ( + DPA4Wrapper, +) + + +def load_frame( + data_dir: str | Path, + frame: int = 0, + dtype: torch.dtype = torch.float64, + device: torch.device | str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Load one periodic frame from a DeePMD-kit ``npy`` data system.""" + data_dir = Path(data_dir) + set_dir = sorted(data_dir.glob("set.*"))[0] + coord = np.load(set_dir / "coord.npy")[frame].reshape(-1, 3) + box = np.load(set_dir / "box.npy")[frame].reshape(3, 3) + type_index = np.loadtxt(data_dir / "type.raw", dtype=int).reshape(-1) + type_map = (data_dir / "type_map.raw").read_text().split() + z = np.array([ELEMENT_TO_Z[type_map[t]] for t in type_index], dtype=np.int64) + return ( + torch.tensor(z, dtype=torch.long, device=device), + torch.tensor(coord, dtype=dtype, device=device), + torch.tensor(box, dtype=dtype, 
device=device).reshape(1, 3, 3), + ) + + +def fmax(batch: Batch) -> float: + """Maximum per-atom force norm in eV/A.""" + return batch.forces.norm(dim=-1).max().item() + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="../lmp/pretrained.pt") + parser.add_argument("--data", default="../../data/data_0") + parser.add_argument("--frame", type=int, default=0) + parser.add_argument( + "--fmax", type=float, default=0.05, help="force convergence threshold (eV/A)" + ) + parser.add_argument( + "--max-steps", type=int, default=200, help="maximum optimizer steps" + ) + parser.add_argument("--dt", type=float, default=1.0, help="initial FIRE step (fs)") + parser.add_argument("--log-every", type=int, default=20) + parser.add_argument( + "--device", default="cuda" if torch.cuda.is_available() else "cpu" + ) + args = parser.parse_args() + if args.log_every <= 0: + parser.error("--log-every must be a positive integer") + device = torch.device(args.device) + + model = DPA4Wrapper.from_checkpoint(args.model, device=device) + model.eval() + + atomic_numbers, positions, cell = load_frame( + args.data, frame=args.frame, device=device + ) + n_atoms = atomic_numbers.shape[0] + # FIRE reuses the ``velocities`` field as its internal velocity (starting + # from rest); ``forces`` and ``energy`` are written back by ``compute()``. 
+ data = AtomicData( + atomic_numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=torch.ones(1, 3, dtype=torch.bool, device=device), + forces=torch.zeros_like(positions), + energy=torch.zeros((1, 1), dtype=positions.dtype, device=device), + ) + batch = Batch.from_data_list([data], device=device) + + nl_hook = NeighborListHook( + model.model_config.neighbor_config, stage=DynamicsStage.BEFORE_COMPUTE + ) + opt = FIRE2( + model, + dt=args.dt, + hooks=[nl_hook], + convergence_hook=ConvergenceHook.from_fmax(threshold=args.fmax), + ) + + compute_neighbors(batch, config=model.model_config.neighbor_config) + opt.compute(batch) + + e0 = batch.energy.item() + print(f"model : {args.model} (rcut={model.rcut} A)") + print(f"system : {n_atoms} atoms, fmax target={args.fmax} eV/A") + print(f"{'step':>8} {'E_pot[eV]':>14} {'fmax[eV/A]':>12}") + print(f"{0:>8} {e0:>14.4f} {fmax(batch):>12.5f}") + + converged = False + step = 0 + while step < args.max_steps: + chunk = min(args.log_every, args.max_steps - step) + batch = opt.run(batch, n_steps=chunk) + step += chunk + print(f"{step:>8} {batch.energy.item():>14.4f} {fmax(batch):>12.5f}") + if fmax(batch) <= args.fmax: + converged = True + break + + e1 = batch.energy.item() + status = "converged" if converged else f"not converged in {args.max_steps} steps" + print(f"\n{status}: fmax={fmax(batch):.5f} eV/A") + print(f"energy change: {(e1 - e0) / n_atoms * 1e3:.4f} meV/atom") + + +if __name__ == "__main__": + main() diff --git a/examples/water/dpa4/nvalchemi/run_nve.py b/examples/water/dpa4/nvalchemi/run_nve.py new file mode 100644 index 0000000000..aab322c167 --- /dev/null +++ b/examples/water/dpa4/nvalchemi/run_nve.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Microcanonical (NVE) molecular dynamics with a DPA-4 / SeZM model. + +This example runs velocity-Verlet NVE dynamics through ``nvalchemi`` driven by a +trained DPA-4 / SeZM potential. 
It demonstrates the full molecular-dynamics +wiring: + +* wrap the model with :class:`deepmd.pt.nvalchemi.DPA4Wrapper`; +* build a periodic ``nvalchemi`` batch with allocated ``forces`` / ``velocities``; +* draw initial velocities from the Maxwell-Boltzmann distribution; +* register a COO :class:`~nvalchemi.hooks.NeighborListHook` so the neighbour + list is rebuilt before each force evaluation; +* integrate with :class:`~nvalchemi.dynamics.integrators.NVE` and monitor the + conserved total energy ``E_pot + E_kin``. + +NVE conserves the total energy, so the drift over the run is a direct measure of +the integration quality for the given timestep. + +Usage +----- +:: + + python run_nve.py \ + --model ../lmp/pretrained.pt \ + --data ../../data/data_0 \ + --steps 200 --dt 0.5 --temperature 300 +""" + +from __future__ import ( + annotations, +) + +import argparse +from pathlib import ( + Path, +) + +import numpy as np +import torch +from nvalchemi.data import ( + AtomicData, + Batch, +) +from nvalchemi.dynamics import ( + initialize_velocities, +) +from nvalchemi.dynamics.base import ( + DynamicsStage, +) +from nvalchemi.dynamics.integrators import ( + NVE, +) +from nvalchemi.hooks import ( + NeighborListHook, +) +from nvalchemi.neighbors import ( + compute_neighbors, +) + +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, +) +from deepmd.pt.nvalchemi import ( + DPA4Wrapper, +) + +# Boltzmann constant in eV/K (positions in A, masses in amu, energy in eV). 
+_KB_EV = 8.617333262e-5 + + +def load_frame( + data_dir: str | Path, + frame: int = 0, + dtype: torch.dtype = torch.float64, + device: torch.device | str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Load one periodic frame from a DeePMD-kit ``npy`` data system.""" + data_dir = Path(data_dir) + set_dir = sorted(data_dir.glob("set.*"))[0] + coord = np.load(set_dir / "coord.npy")[frame].reshape(-1, 3) + box = np.load(set_dir / "box.npy")[frame].reshape(3, 3) + type_index = np.loadtxt(data_dir / "type.raw", dtype=int).reshape(-1) + type_map = (data_dir / "type_map.raw").read_text().split() + z = np.array([ELEMENT_TO_Z[type_map[t]] for t in type_index], dtype=np.int64) + return ( + torch.tensor(z, dtype=torch.long, device=device), + torch.tensor(coord, dtype=dtype, device=device), + torch.tensor(box, dtype=dtype, device=device).reshape(1, 3, 3), + ) + + +def thermo(batch: Batch, n_atoms: int) -> tuple[float, float, float, float]: + """Return (potential energy, kinetic energy, temperature, max force). + + Kinetic energy is ``0.5 * sum(m v^2)`` in eV (the integrator's internal + velocity unit makes this expression directly an energy), and the + temperature follows from equipartition with ``3N`` degrees of freedom. 
+ """ + ke = (0.5 * batch.atomic_masses * (batch.velocities**2).sum(-1)).sum().item() + pe = batch.energy.item() + temperature = 2.0 * ke / (3.0 * n_atoms * _KB_EV) + fmax = batch.forces.norm(dim=-1).max().item() + return pe, ke, temperature, fmax + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="../lmp/pretrained.pt") + parser.add_argument("--data", default="../../data/data_0") + parser.add_argument("--frame", type=int, default=0) + parser.add_argument("--steps", type=int, default=200, help="number of MD steps") + parser.add_argument("--dt", type=float, default=0.5, help="timestep in fs") + parser.add_argument( + "--temperature", type=float, default=300.0, help="initial temperature in K" + ) + parser.add_argument("--log-every", type=int, default=20) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--device", default="cuda" if torch.cuda.is_available() else "cpu" + ) + args = parser.parse_args() + if args.log_every <= 0: + parser.error("--log-every must be a positive integer") + device = torch.device(args.device) + + model = DPA4Wrapper.from_checkpoint(args.model, device=device) + model.eval() + + atomic_numbers, positions, cell = load_frame( + args.data, frame=args.frame, device=device + ) + n_atoms = atomic_numbers.shape[0] + # ``forces`` and ``energy`` are pre-allocated so the integrator can read + # forces and ``compute()`` can write energy / forces back in place. + data = AtomicData( + atomic_numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=torch.ones(1, 3, dtype=torch.bool, device=device), + forces=torch.zeros_like(positions), + energy=torch.zeros((1, 1), dtype=positions.dtype, device=device), + ) + batch = Batch.from_data_list([data], device=device) + + # Draw Maxwell-Boltzmann velocities at the target temperature (in-place). 
+ temperature = torch.full( + (batch.num_graphs,), args.temperature, dtype=positions.dtype, device=device + ) + initialize_velocities( + batch.velocities, + batch.atomic_masses, + temperature, + batch.batch_idx.int(), + random_seed=args.seed, + ) + + nl_hook = NeighborListHook( + model.model_config.neighbor_config, stage=DynamicsStage.BEFORE_COMPUTE + ) + nve = NVE(model, dt=args.dt, hooks=[nl_hook]) + + # Prime the neighbour list and forces so the first half-kick is exact. + compute_neighbors(batch, config=model.model_config.neighbor_config) + nve.compute(batch) + + pe0, ke0, t0, fmax0 = thermo(batch, n_atoms) + e_tot0 = pe0 + ke0 + print(f"model : {args.model} (rcut={model.rcut} A)") + print(f"system : {n_atoms} atoms, dt={args.dt} fs, T0={args.temperature} K") + print(f"{'step':>8} {'E_pot[eV]':>14} {'E_tot[eV]':>14} {'T[K]':>9} {'fmax':>9}") + print(f"{0:>8} {pe0:>14.4f} {e_tot0:>14.4f} {t0:>9.2f} {fmax0:>9.4f}") + + step = 0 + while step < args.steps: + chunk = min(args.log_every, args.steps - step) + batch = nve.run(batch, n_steps=chunk) + step += chunk + pe, ke, temperature_now, fmax = thermo(batch, n_atoms) + print( + f"{step:>8} {pe:>14.4f} {pe + ke:>14.4f} {temperature_now:>9.2f} " + f"{fmax:>9.4f}" + ) + + pe, ke, _, _ = thermo(batch, n_atoms) + drift = (pe + ke - e_tot0) / n_atoms + print(f"\ntotal-energy drift: {drift * 1e3:.4f} meV/atom over {args.steps} steps") + + +if __name__ == "__main__": + main() diff --git a/examples/water/dpa4/nvalchemi/run_nvt.py b/examples/water/dpa4/nvalchemi/run_nvt.py new file mode 100644 index 0000000000..213642ba06 --- /dev/null +++ b/examples/water/dpa4/nvalchemi/run_nvt.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Canonical (NVT) molecular dynamics with a DPA-4 / SeZM model. + +This example runs Langevin NVT dynamics through ``nvalchemi`` driven by a trained +DPA-4 / SeZM potential. 
A Langevin thermostat couples the system to a heat bath, +so the instantaneous temperature fluctuates around the target value rather than +being conserved (as in NVE). + +The wiring mirrors :mod:`run_nve` -- wrap the model, build a periodic batch, +seed Maxwell-Boltzmann velocities, register a neighbour-list hook -- but uses +:class:`~nvalchemi.dynamics.integrators.NVTLangevin`, which additionally takes a +target ``temperature`` and a ``friction`` coefficient. + +Usage +----- +:: + + python run_nvt.py \ + --model ../lmp/pretrained.pt \ + --data ../../data/data_0 \ + --steps 300 --dt 0.5 --temperature 330 --friction 0.01 +""" + +from __future__ import ( + annotations, +) + +import argparse +from pathlib import ( + Path, +) + +import numpy as np +import torch +from nvalchemi.data import ( + AtomicData, + Batch, +) +from nvalchemi.dynamics import ( + initialize_velocities, +) +from nvalchemi.dynamics.base import ( + DynamicsStage, +) +from nvalchemi.dynamics.integrators import ( + NVTLangevin, +) +from nvalchemi.hooks import ( + NeighborListHook, +) +from nvalchemi.neighbors import ( + compute_neighbors, +) + +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, +) +from deepmd.pt.nvalchemi import ( + DPA4Wrapper, +) + +# Boltzmann constant in eV/K (positions in A, masses in amu, energy in eV). 
+_KB_EV = 8.617333262e-5 + + +def load_frame( + data_dir: str | Path, + frame: int = 0, + dtype: torch.dtype = torch.float64, + device: torch.device | str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Load one periodic frame from a DeePMD-kit ``npy`` data system.""" + data_dir = Path(data_dir) + set_dir = sorted(data_dir.glob("set.*"))[0] + coord = np.load(set_dir / "coord.npy")[frame].reshape(-1, 3) + box = np.load(set_dir / "box.npy")[frame].reshape(3, 3) + type_index = np.loadtxt(data_dir / "type.raw", dtype=int).reshape(-1) + type_map = (data_dir / "type_map.raw").read_text().split() + z = np.array([ELEMENT_TO_Z[type_map[t]] for t in type_index], dtype=np.int64) + return ( + torch.tensor(z, dtype=torch.long, device=device), + torch.tensor(coord, dtype=dtype, device=device), + torch.tensor(box, dtype=dtype, device=device).reshape(1, 3, 3), + ) + + +def temperature_kelvin(batch: Batch, n_atoms: int) -> float: + """Instantaneous kinetic temperature from ``T = 2 KE / (3 N k_B)``.""" + ke = (0.5 * batch.atomic_masses * (batch.velocities**2).sum(-1)).sum().item() + return 2.0 * ke / (3.0 * n_atoms * _KB_EV) + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="../lmp/pretrained.pt") + parser.add_argument("--data", default="../../data/data_0") + parser.add_argument("--frame", type=int, default=0) + parser.add_argument("--steps", type=int, default=300, help="number of MD steps") + parser.add_argument("--dt", type=float, default=0.5, help="timestep in fs") + parser.add_argument( + "--temperature", type=float, default=330.0, help="target temperature in K" + ) + parser.add_argument( + "--friction", type=float, default=0.01, help="Langevin friction in 1/fs" + ) + parser.add_argument("--log-every", type=int, default=50) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--device", default="cuda" if torch.cuda.is_available() else "cpu" + ) + args = 
parser.parse_args() + if args.log_every <= 0: + parser.error("--log-every must be a positive integer") + device = torch.device(args.device) + + model = DPA4Wrapper.from_checkpoint(args.model, device=device) + model.eval() + + atomic_numbers, positions, cell = load_frame( + args.data, frame=args.frame, device=device + ) + n_atoms = atomic_numbers.shape[0] + data = AtomicData( + atomic_numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=torch.ones(1, 3, dtype=torch.bool, device=device), + forces=torch.zeros_like(positions), + energy=torch.zeros((1, 1), dtype=positions.dtype, device=device), + ) + batch = Batch.from_data_list([data], device=device) + + temperature = torch.full( + (batch.num_graphs,), args.temperature, dtype=positions.dtype, device=device + ) + initialize_velocities( + batch.velocities, + batch.atomic_masses, + temperature, + batch.batch_idx.int(), + random_seed=args.seed, + ) + + nl_hook = NeighborListHook( + model.model_config.neighbor_config, stage=DynamicsStage.BEFORE_COMPUTE + ) + nvt = NVTLangevin( + model, + dt=args.dt, + temperature=args.temperature, + friction=args.friction, + random_seed=args.seed, + hooks=[nl_hook], + ) + + # Prime the neighbour list and forces before the first half-kick. 
+ compute_neighbors(batch, config=model.model_config.neighbor_config) + nvt.compute(batch) + + print(f"model : {args.model} (rcut={model.rcut} A)") + print( + f"system : {n_atoms} atoms, dt={args.dt} fs, " + f"T_target={args.temperature} K, friction={args.friction}/fs" + ) + print(f"{'step':>8} {'E_pot[eV]':>14} {'T[K]':>9} {'fmax':>9}") + print( + f"{0:>8} {batch.energy.item():>14.4f} " + f"{temperature_kelvin(batch, n_atoms):>9.2f} " + f"{batch.forces.norm(dim=-1).max().item():>9.4f}" + ) + + step = 0 + while step < args.steps: + chunk = min(args.log_every, args.steps - step) + batch = nvt.run(batch, n_steps=chunk) + step += chunk + print( + f"{step:>8} {batch.energy.item():>14.4f} " + f"{temperature_kelvin(batch, n_atoms):>9.2f} " + f"{batch.forces.norm(dim=-1).max().item():>9.4f}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/water/dpa4/nvalchemi/single_point.py b/examples/water/dpa4/nvalchemi/single_point.py new file mode 100644 index 0000000000..19cdad5faa --- /dev/null +++ b/examples/water/dpa4/nvalchemi/single_point.py @@ -0,0 +1,155 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Single-point energy, forces, and stress with a DPA-4 / SeZM model. + +This example loads a trained DPA-4 / SeZM checkpoint, wraps it with +:class:`deepmd.pt.nvalchemi.DPA4Wrapper`, builds an ``nvalchemi`` batch from a +DeePMD-kit data frame, computes a neighbour list, and evaluates the potential +energy, atomic forces, and the Cauchy stress tensor for one configuration. + +It is the smallest complete example of the nvalchemi inference path and a good +starting point before running molecular dynamics. 
+ +Usage +----- +:: + + python single_point.py \ + --model ../lmp/pretrained.pt \ + --data ../../data/data_0 +""" + +from __future__ import ( + annotations, +) + +import argparse +from pathlib import ( + Path, +) + +import numpy as np +import torch +from nvalchemi.data import ( + AtomicData, + Batch, +) +from nvalchemi.neighbors import ( + compute_neighbors, +) + +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, +) +from deepmd.pt.nvalchemi import ( + DPA4Wrapper, +) + +# 1 eV/A^3 expressed in GPa, for reporting the pressure in familiar units. +_EV_PER_A3_TO_GPA = 160.21766208 + + +def load_frame( + data_dir: str | Path, + frame: int = 0, + dtype: torch.dtype = torch.float64, + device: torch.device | str = "cpu", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Load one frame from a DeePMD-kit ``npy`` data system. + + Parameters + ---------- + data_dir + Directory holding ``type.raw``, ``type_map.raw``, and a ``set.*`` + sub-directory with ``coord.npy`` and ``box.npy``. + frame + Frame index to load. + dtype + Floating-point dtype for positions and cell. + device + Target device. + + Returns + ------- + atomic_numbers + ``(N,)`` long tensor of atomic numbers. + positions + ``(N, 3)`` Cartesian coordinates in Angstrom. + cell + ``(1, 3, 3)`` lattice vectors (rows) in Angstrom. 
+ """ + data_dir = Path(data_dir) + set_dir = sorted(data_dir.glob("set.*"))[0] + coord = np.load(set_dir / "coord.npy")[frame].reshape(-1, 3) + box = np.load(set_dir / "box.npy")[frame].reshape(3, 3) + type_index = np.loadtxt(data_dir / "type.raw", dtype=int).reshape(-1) + type_map = (data_dir / "type_map.raw").read_text().split() + z = np.array([ELEMENT_TO_Z[type_map[t]] for t in type_index], dtype=np.int64) + + atomic_numbers = torch.tensor(z, dtype=torch.long, device=device) + positions = torch.tensor(coord, dtype=dtype, device=device) + cell = torch.tensor(box, dtype=dtype, device=device).reshape(1, 3, 3) + return atomic_numbers, positions, cell + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--model", default="../lmp/pretrained.pt", help="DPA-4 / SeZM checkpoint (.pt)" + ) + parser.add_argument( + "--data", default="../../data/data_0", help="DeePMD-kit data system directory" + ) + parser.add_argument("--frame", type=int, default=0, help="frame index to evaluate") + parser.add_argument( + "--device", + default="cuda" if torch.cuda.is_available() else "cpu", + help="torch device", + ) + args = parser.parse_args() + device = torch.device(args.device) + + # Wrap the trained model. ``compute_stress=True`` adds the Cauchy stress to + # the active outputs (it requires a periodic cell). + model = DPA4Wrapper.from_checkpoint(args.model, device=device, compute_stress=True) + model.eval() + + atomic_numbers, positions, cell = load_frame( + args.data, frame=args.frame, device=device + ) + n_atoms = atomic_numbers.shape[0] + data = AtomicData( + atomic_numbers=atomic_numbers, + positions=positions, + cell=cell, + pbc=torch.ones(1, 3, dtype=torch.bool, device=device), + ) + batch = Batch.from_data_list([data], device=device) + + # Populate ``batch.neighbor_list`` / ``batch.neighbor_list_shifts`` with the + # cutoff the model declares in its ModelConfig. 
+ compute_neighbors(batch, config=model.model_config.neighbor_config) + out = model(batch) + + energy = out["energy"].item() + forces = out["forces"] + stress = out["stress"][0] + pressure = -torch.diagonal(stress).mean().item() # -tr(sigma)/3 + + print(f"model : {args.model} (rcut={model.rcut} A)") + print(f"system : {n_atoms} atoms, {batch.num_edges} edges") + print(f"energy : {energy:.6f} eV ({energy / n_atoms:.6f} eV/atom)") + print(f"max |force| : {forces.norm(dim=-1).max().item():.6f} eV/A") + print( + f"rms |force| : {forces.norm(dim=-1).pow(2).mean().sqrt().item():.6f} eV/A" + ) + print( + f"pressure : {pressure:.6e} eV/A^3 " + f"({pressure * _EV_PER_A3_TO_GPA:.4f} GPa)" + ) + print("stress (eV/A^3) :") + for row in stress.tolist(): + print(" " + " ".join(f"{v: .6e}" for v in row)) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 35fc0fdb18..b5af916203 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ test = [ "pytest-split", "pytest-timeout", "dpgui", + 'nvalchemi-toolkit; python_version >= "3.11" and python_version < "3.14" and platform_system == "Linux"', # to support Array API 2024.12 'array-api-strict>=2.2;python_version>="3.9"', ] @@ -120,6 +121,9 @@ lmp = [ ipi = [ "ipi", ] +nvalchemi = [ + 'nvalchemi-toolkit; python_version >= "3.11" and python_version < "3.14" and platform_system == "Linux"', +] gui = [ "dpgui", ] @@ -464,6 +468,8 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "**/tests/**/test_*.py" = ["ANN"] "**/tests/**/*_test.py" = ["ANN"] "**/*.ipynb" = ["T20"] # printing in a nb file is expected +# Example / demo scripts are run directly: top-level imports and prints are fine. 
+"examples/**/*.py" = ["T20", "TID253"] [tool.pytest.ini_options] markers = "run" diff --git a/source/tests/pt/model/test_sezm_nvalchemi.py b/source/tests/pt/model/test_sezm_nvalchemi.py new file mode 100644 index 0000000000..f6c2c8c426 --- /dev/null +++ b/source/tests/pt/model/test_sezm_nvalchemi.py @@ -0,0 +1,430 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Unit tests for the SeZM / DPA-4 ``nvalchemi-toolkit`` wrapper. + +The wrapper (:class:`deepmd.pt.nvalchemi.DPA4Wrapper`) drives a SeZM model +through its sparse-edge lower interface from an ``nvalchemi`` graph batch. The +gold-standard correctness check is parity with the model's own neighbour-list +``forward``: feeding an identical structure through both paths must yield the +same energy, forces, and virial. These tests pin that parity (periodic, +non-periodic, and heterogeneous batched) plus the embedding and type-mapping +paths. + +The whole suite is skipped when ``nvalchemi-toolkit`` is not installed. +""" + +from __future__ import ( + annotations, +) + +import contextlib +import unittest +from typing import ( + TYPE_CHECKING, +) + +import torch + +from deepmd.pt.model.model import ( + get_sezm_model, +) +from deepmd.pt.model.model.sezm_model import ( + ELEMENT_TO_Z, +) +from deepmd.pt.utils import ( + env, +) + +if TYPE_CHECKING: + from collections.abc import ( + Iterator, + ) + + +@contextlib.contextmanager +def _clear_default_device() -> Iterator[None]: + """Disable the pt-test ``cuda:9999999`` sentinel default device. + + ``source/tests/pt/__init__.py`` sets an invalid default device so tests + that rely on implicit placement fail loudly. ``nvalchemi`` / ``tensordict`` + allocate unnamed tensors without an explicit device (both at import and at + runtime), so this guard temporarily restores the real default. Matches the + pattern in ``test_sezm_export.py``. 
+ """ + saved = torch.get_default_device() + torch.set_default_device(None) + try: + yield + finally: + torch.set_default_device(saved) + + +try: + with _clear_default_device(): + from nvalchemi.data import ( + AtomicData, + Batch, + ) + from nvalchemi.neighbors import ( + compute_neighbors, + ) + + from deepmd.pt.nvalchemi import ( + DPA4Wrapper, + ) + + NVALCHEMI_AVAILABLE = True +except ImportError: + NVALCHEMI_AVAILABLE = False + +TYPE_MAP = ["O", "H"] +RCUT = 4.0 + + +class _ClearDefaultDeviceTestCase(unittest.TestCase): + """Run a test class while the pt default-device sentinel is disabled.""" + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls._default_device_ctx = _clear_default_device() + cls._default_device_ctx.__enter__() + + @classmethod + def tearDownClass(cls) -> None: + try: + super().tearDownClass() + finally: + ctx = getattr(cls, "_default_device_ctx", None) + if ctx is not None: + ctx.__exit__(None, None, None) + delattr(cls, "_default_device_ctx") + + +@unittest.skipUnless(NVALCHEMI_AVAILABLE, "nvalchemi-toolkit is not installed") +class TestSeZMNVAlchemiWrapper(_ClearDefaultDeviceTestCase): + """Parity of the nvalchemi wrapper against the native SeZM forward.""" + + def setUp(self) -> None: + self.device = env.DEVICE + self.model = self._build_model() + + # ------------------------------------------------------------------ + # Fixtures + # ------------------------------------------------------------------ + + def _build_model(self) -> torch.nn.Module: + """A tiny float64 SeZM model with randomized (non-trivial) weights.""" + params = { + "type": "SeZM", + "type_map": TYPE_MAP, + "descriptor": { + "type": "SeZM", + "sel": [80, 80], + "rcut": RCUT, + "channels": 8, + "n_focus": 1, + "n_radial": 4, + "radial_mlp": [8], + "use_env_seed": True, + "l_schedule": [2, 1], + "mmax": 1, + "so2_layers": 1, + "n_atten_head": 1, + "sandwich_norm": [True, False, True, False], + "ffn_neurons": 8, + "ffn_blocks": 1, + "s2_activation": 
[False, True], + "mlp_bias": False, + "layer_scale": False, + "use_amp": False, + "activation_function": "silu", + "glu_activation": True, + "precision": "float64", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float64", + "seed": 7, + }, + "use_compile": False, + } + model = get_sezm_model(params).to(self.device) + torch.manual_seed(1234) + with torch.no_grad(): + for p in model.parameters(): + p.copy_(torch.randn_like(p) * 0.1) + model.eval() + return model + + def _system(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """A periodic system with a tight cluster and a cross-boundary pair. + + Distances are kept either well below ``RCUT`` or well above it, so the + native ghost-atom neighbour list and the nvalchemi COO list select the + same edge set (the C3 envelope makes any near-cutoff edge negligible). + """ + coord = torch.tensor( + [ + [4.0, 4.0, 4.0], + [4.85, 4.20, 4.10], + [4.10, 4.90, 3.85], + [3.80, 4.15, 4.80], + [0.40, 1.20, 1.20], + [8.75, 1.25, 1.15], + ], + dtype=torch.float64, + device=self.device, + ) + atype = torch.tensor([0, 1, 1, 1, 0, 1], dtype=torch.int64, device=self.device) + box = torch.eye(3, dtype=torch.float64, device=self.device) * 9.0 + return coord, atype, box + + def _second_system(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """A differently sized periodic system with a different cell.""" + coord = torch.tensor( + [ + [2.0, 2.0, 2.0], + [2.9, 2.1, 1.9], + [2.1, 2.8, 2.2], + [1.85, 1.95, 2.85], + ], + dtype=torch.float64, + device=self.device, + ) + atype = torch.tensor([0, 1, 1, 0], dtype=torch.int64, device=self.device) + box = torch.eye(3, dtype=torch.float64, device=self.device) * 8.0 + return coord, atype, box + + def _atype_to_z(self, atype: torch.Tensor) -> torch.Tensor: + z_of_type = torch.tensor( + [ELEMENT_TO_Z[s] for s in TYPE_MAP], + dtype=torch.long, + device=atype.device, + ) + return z_of_type.index_select(0, atype.long()) + + def 
_native( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + """Energy / forces / virial from the native neighbour-list forward.""" + nloc = coord.shape[0] + out = self.model( + coord.view(1, nloc, 3), + atype.view(1, nloc), + box=None if box is None else box.reshape(1, 9), + do_atomic_virial=True, + ) + return { + "energy": out["energy"].reshape(1).detach(), + "forces": out["force"].reshape(nloc, 3).detach(), + "virial": out["virial"].reshape(3, 3).detach(), + } + + def _data( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None, + ) -> AtomicData: + fields = { + "atomic_numbers": self._atype_to_z(atype), + "positions": coord.clone(), + } + if box is not None: + fields["cell"] = box.reshape(1, 3, 3) + fields["pbc"] = torch.ones(1, 3, dtype=torch.bool, device=self.device) + return AtomicData(**fields) + + def _wrapper_batch_out( + self, + wrapper: DPA4Wrapper, + batch: Batch, + ) -> dict[str, torch.Tensor]: + compute_neighbors(batch, config=wrapper.model_config.neighbor_config) + return wrapper(batch) + + # ------------------------------------------------------------------ + # Tests + # ------------------------------------------------------------------ + + def test_parity_periodic(self) -> None: + """Energy / forces / virial match native forward for a periodic cell.""" + coord, atype, box = self._system() + wrapper = DPA4Wrapper(self.model, compute_stress=True) + ref = self._native(coord, atype, box) + + batch = Batch.from_data_list( + [self._data(coord, atype, box)], device=self.device + ) + out = self._wrapper_batch_out(wrapper, batch) + volume = torch.det(box).abs() + + torch.testing.assert_close( + out["energy"].reshape(1), ref["energy"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["forces"].reshape(-1, 3), ref["forces"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["stress"].reshape(3, 3) * volume, ref["virial"], atol=1e-7, rtol=1e-6 + ) 
+ + def test_stress_matches_strain_finite_difference(self) -> None: + """Nvalchemi stress follows its ``W / V`` convention. + + This is intentionally non-circular: the finite difference uses only + wrapper energies under a small affine cell strain, not the native virial + output. nvalchemi expects ``stress = W / V`` with + ``W = -dE/d(strain)``; this differs by sign from ASE's tensile-stress + convention. + """ + coord, atype, box = self._system() + wrapper = DPA4Wrapper(self.model, compute_stress=True) + + batch = Batch.from_data_list( + [self._data(coord, atype, box)], device=self.device + ) + out = self._wrapper_batch_out(wrapper, batch) + volume = torch.det(box).abs() + virial_from_stress = out["stress"].reshape(3, 3) * volume + + torch.manual_seed(0) + sym = torch.randn(3, 3, dtype=torch.float64, device=self.device) + strain = 1.0e-4 * (sym + sym.transpose(0, 1)) + eye = torch.eye(3, dtype=torch.float64, device=self.device) + + def wrapper_energy(sign: float) -> torch.Tensor: + transform = (eye + sign * strain).transpose(0, 1) + coord_d = coord @ transform + box_d = box @ transform + batch_d = Batch.from_data_list( + [self._data(coord_d, atype, box_d)], device=self.device + ) + return self._wrapper_batch_out(wrapper, batch_d)["energy"].reshape(()) + + e_plus = wrapper_energy(1.0) + e_minus = wrapper_energy(-1.0) + lhs = (virial_from_stress * strain).sum() + rhs = -(e_plus - e_minus) / 2.0 + torch.testing.assert_close(lhs, rhs, atol=1.0e-8, rtol=1.0e-4) + + def test_parity_nonperiodic(self) -> None: + """Energy / forces match native forward for an open-boundary cluster.""" + coord, atype, _ = self._system() + wrapper = DPA4Wrapper(self.model, compute_stress=False) + ref = self._native(coord, atype, None) + + batch = Batch.from_data_list( + [self._data(coord, atype, None)], device=self.device + ) + out = self._wrapper_batch_out(wrapper, batch) + + torch.testing.assert_close( + out["energy"].reshape(1), ref["energy"], atol=1e-7, rtol=1e-6 + ) + 
torch.testing.assert_close( + out["forces"].reshape(-1, 3), ref["forces"], atol=1e-7, rtol=1e-6 + ) + self.assertNotIn("stress", out) + + def test_parity_batched_heterogeneous(self) -> None: + """Per-graph outputs match native runs for a two-graph batch. + + The two graphs differ in size *and* cell, exercising the ``batch_idx`` + segment reduction and the global ``neighbor_list`` node offsets. + """ + coord_a, atype_a, box_a = self._system() + coord_b, atype_b, box_b = self._second_system() + n_a = coord_a.shape[0] + wrapper = DPA4Wrapper(self.model, compute_stress=True) + + ref_a = self._native(coord_a, atype_a, box_a) + ref_b = self._native(coord_b, atype_b, box_b) + + batch = Batch.from_data_list( + [self._data(coord_a, atype_a, box_a), self._data(coord_b, atype_b, box_b)], + device=self.device, + ) + out = self._wrapper_batch_out(wrapper, batch) + vol_a = torch.det(box_a).abs() + vol_b = torch.det(box_b).abs() + + torch.testing.assert_close( + out["energy"][0], ref_a["energy"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["energy"][1], ref_b["energy"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["forces"][:n_a], ref_a["forces"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["forces"][n_a:], ref_b["forces"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["stress"][0] * vol_a, ref_a["virial"], atol=1e-7, rtol=1e-6 + ) + torch.testing.assert_close( + out["stress"][1] * vol_b, ref_b["virial"], atol=1e-7, rtol=1e-6 + ) + + def test_compute_embeddings_shapes(self) -> None: + """Embeddings have the advertised per-atom / per-graph descriptor width.""" + coord, atype, box = self._system() + wrapper = DPA4Wrapper(self.model) + dim = wrapper.embedding_shapes["node_embeddings"][0] + + batch = Batch.from_data_list( + [self._data(coord, atype, box)], device=self.device + ) + compute_neighbors(batch, config=wrapper.model_config.neighbor_config) + out = wrapper.compute_embeddings(batch) + + 
self.assertEqual(tuple(out.node_embeddings.shape), (coord.shape[0], dim)) + self.assertEqual(tuple(out.graph_embeddings.shape), (1, dim)) + self.assertTrue(torch.isfinite(out.node_embeddings).all()) + + def test_custom_type_mapping(self) -> None: + """An explicit atomic-number map reproduces the type-map default.""" + coord, atype, box = self._system() + ref = DPA4Wrapper(self.model) + override = DPA4Wrapper( + self.model, + atomic_number_to_type={ELEMENT_TO_Z["O"]: 0, ELEMENT_TO_Z["H"]: 1}, + ) + + batch_ref = Batch.from_data_list( + [self._data(coord, atype, box)], device=self.device + ) + batch_ovr = Batch.from_data_list( + [self._data(coord, atype, box)], device=self.device + ) + out_ref = self._wrapper_batch_out(ref, batch_ref) + out_ovr = self._wrapper_batch_out(override, batch_ovr) + torch.testing.assert_close(out_ref["energy"], out_ovr["energy"]) + + def test_unknown_atomic_number_raises(self) -> None: + """An atomic number outside the type map raises a clear error.""" + coord, atype, box = self._system() + wrapper = DPA4Wrapper(self.model) + data = AtomicData( + atomic_numbers=torch.full( + (coord.shape[0],), 6, dtype=torch.long, device=self.device + ), # carbon: not in the O/H type map + positions=coord.clone(), + cell=box.reshape(1, 3, 3), + pbc=torch.ones(1, 3, dtype=torch.bool, device=self.device), + ) + batch = Batch.from_data_list([data], device=self.device) + compute_neighbors(batch, config=wrapper.model_config.neighbor_config) + with self.assertRaises(ValueError): + wrapper(batch) + + +if __name__ == "__main__": + unittest.main()