diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index cf59af94db..bf41735f89 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,13 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import dataclasses import functools import math from collections.abc import ( Callable, ) from typing import ( + TYPE_CHECKING, Any, ) +if TYPE_CHECKING: + from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + ) + import array_api_compat import numpy as np @@ -303,23 +310,110 @@ def forward_common_atomic( comm_dict=comm_dict, charge_spin=charge_spin, ) - ret_dict = self.apply_out_stat(ret_dict, atype) - - # nf x nloc atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc) + return self._finalize_atomic_ret(ret_dict, atom_mask, atype) + + def forward_common_atomic_graph( + self, + graph: "NeighborGraph", + atype: Array, + fparam: Array | None = None, + aparam: Array | None = None, + charge_spin: Array | None = None, + ) -> dict: + """Graph analogue of :meth:`forward_common_atomic` on the flat node axis. + + The node axis is flat ``(N,)`` (``N = sum(graph.n_node)``); masking and + out-stat operate per node. Reuses :meth:`_finalize_atomic_ret`, so + virtual-atom masking, ``atom_excl`` and ``apply_out_stat`` match the dense + path. Model-level ``pair_exclude_types`` is graph-native: when + ``self.pair_excl is not None``, an edge-keep mask is ANDed into + ``graph.edge_mask`` before the descriptor forward, so excluded type-pairs + contribute zero to the segment_sum. Descriptor-level ``exclude_types`` is + gated by ``uses_graph_lower()==False``. + + Parameters + ---------- + graph + neighbor graph for the local atoms (ghost-free) + atype + flat local atom types. N + fparam + frame parameter. nf x ndf + aparam + atomic parameter. N x nda + charge_spin + charge/spin conditioning. Unused by the dpa1 graph path; accepted so + the interface stays stable for charge/spin-conditioned descriptors. + + Returns + ------- + result_dict + the result dict on the flat node axis, defined by the `FittingOutputDef`. + + """ + xp = array_api_compat.array_namespace(graph.edge_vec) + atype = xp.asarray(atype, device=array_api_compat.device(graph.edge_vec)) + atom_mask = self.make_atom_mask(atype) # (N,) bool + atype_clamped = xp.where(atom_mask, atype, xp.zeros_like(atype)) + if self.pair_excl is not None: + keep = self.pair_excl.build_edge_exclude_mask( + graph.edge_index, atype_clamped + ) + graph = dataclasses.replace( + graph, + edge_mask=graph.edge_mask * xp.astype(keep, graph.edge_mask.dtype), + ) + ret_dict = self.forward_atomic_graph( + graph, + atype_clamped, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + return self._finalize_atomic_ret(ret_dict, atom_mask, atype) + + def _finalize_atomic_ret( + self, ret_dict: dict, atom_mask: Array, atype: Array + ) -> dict: + """Apply out-stat, atom exclusion and virtual-atom zeroing; set ``mask``. + + Shared by the dense (:meth:`forward_common_atomic`, ``(nf, nloc)`` leading + dims) and graph (:meth:`forward_common_atomic_graph`, flat ``(N,)`` leading + dim) wrappers -- leading-dim-agnostic. + + Parameters + ---------- + ret_dict + the raw per-atom result dict from ``forward_atomic``/``forward_atomic_graph`` + atom_mask + the real-atom mask, True for real and False for virtual atoms. leading dims + atype + the local atom types, used for out-stat and ``atom_excl``. leading dims + + Returns + ------- + result_dict + ``ret_dict`` with out-stat applied, virtual and excluded atoms zeroed, + and the integer ``mask`` key set. + + """ + xp = array_api_compat.array_namespace(atype) + ret_dict = self.apply_out_stat(ret_dict, atype) if self.atom_excl is not None: atom_mask = xp.logical_and( atom_mask, self.atom_excl.build_type_exclude_mask(atype) ) - + lead = atom_mask.shape # (nf, nloc) dense | (N,) graph for kk in ret_dict.keys(): - out_shape = ret_dict[kk].shape - out_shape2 = math.prod(out_shape[2:]) - tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) - tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr)) - ret_dict[kk] = xp.reshape(tmp_arr, out_shape) + out = ret_dict[kk] + # explicit trailing product (NOT -1): a zero-atom forward (nloc==0) + # has size 0, and numpy cannot infer -1 for a size-0 array. + trail = math.prod(out.shape[len(lead) :]) + flat = xp.reshape(out, (*lead, trail)) + flat = xp.where(atom_mask[..., None], flat, xp.zeros_like(flat)) + ret_dict[kk] = xp.reshape(flat, out.shape) ret_dict["mask"] = xp.astype(atom_mask, xp.int32) - return ret_dict def call( diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index a2b49f47e3..440eb75284 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -3,9 +3,15 @@ Callable, ) from typing import ( + TYPE_CHECKING, Any, ) +if TYPE_CHECKING: + from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + ) + from deepmd.dpmodel.array_api import ( Array, xp_take_first_n, @@ -248,6 +254,60 @@ def forward_atomic( ) return ret + def forward_atomic_graph( + self, + graph: "NeighborGraph", + atype: Array, + fparam: Array | None = None, + aparam: Array | None = None, + charge_spin: Array | None = None, + ) -> dict[str, Array]: + """Graph analogue of :meth:`forward_atomic` on the flat node axis. + + Runs the descriptor ``call_graph`` then the fitting ``call_graph`` PER NODE + and returns the raw fitting dict on the flat ``(N, *)`` axis (no reduction + or masking; the wrapper handles those). ``fparam`` is gathered to nodes by + ``frame_id`` so each node sees its frame's parameter. + + Parameters + ---------- + graph + neighbor graph for the local atoms (ghost-free) + atype + flat local atom types. N + fparam + frame parameter. nf x ndf + aparam + atomic parameter. N x nda + charge_spin + charge/spin conditioning. Unused by the dpa1 graph path; accepted so + the interface stays stable for charge/spin-conditioned descriptors. + + Returns + ------- + result_dict + the result dict on the flat node axis, defined by the `FittingOutputDef`. + + """ + import array_api_compat + + from deepmd.dpmodel.utils.neighbor_graph import ( + frame_id_from_n_node, + ) + + xp = array_api_compat.array_namespace(graph.edge_vec) + type_embedding = self.descriptor.type_embedding.call() + gg, rot_mat = self.descriptor.call_graph( + graph, atype, type_embedding=type_embedding + ) + fparam_node = None + if fparam is not None: + frame_id = frame_id_from_n_node(graph.n_node) + fparam_node = xp.take(fparam, frame_id, axis=0) # (N, ndf) + return self.fitting_net.call_graph( + gg, atype, gr=rot_mat, g2=None, h2=None, fparam=fparam_node, aparam=aparam + ) + def compute_or_load_stat( self, sampled_func: Callable[[], list[dict]], diff --git a/deepmd/dpmodel/atomic_model/polar_atomic_model.py b/deepmd/dpmodel/atomic_model/polar_atomic_model.py index 76a221de46..293eae62c4 100644 --- a/deepmd/dpmodel/atomic_model/polar_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/polar_atomic_model.py @@ -46,7 +46,6 @@ def apply_out_stat( out_bias, out_std = self._fetch_out_stat(self.bias_keys) if self.fitting_net.shift_diag: - nframes, nloc = atype.shape dtype = out_bias[self.bias_keys[0]].dtype device = array_api_compat.device(out_bias[self.bias_keys[0]]) for kk in self.bias_keys: @@ -57,16 +56,17 @@ def apply_out_stat( ) modified_bias = temp[atype] - # (nframes, nloc, 1) + # (..., 1) -- (nframes, nloc, 1) or (N, 1) modified_bias = ( modified_bias[..., xp.newaxis] * (self.fitting_net.scale[atype]) ) eye = xp.eye(3, dtype=dtype, device=device) - eye = xp.tile(eye, (nframes, nloc, 1, 1)) - # (nframes, nloc, 3, 3) + # leading-dim-agnostic: (nf, nloc) dense or (N,) flat graph path + eye = xp.tile(eye, (*atype.shape, 1, 1)) + # (..., 3, 3) modified_bias = modified_bias[..., xp.newaxis] * eye - # nf x nloc x odims, out_bias: ntypes x odims + # nf x nloc x odims (rect) or N x odims (flat), out_bias: ntypes x odims ret[kk] = ret[kk] + modified_bias return ret diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2311858180..27e2d68bfc 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -4,6 +4,7 @@ Callable, ) from typing import ( + Any, NoReturn, Optional, Union, @@ -417,6 +418,21 @@ def get_numb_attn_layer(self) -> int: """Returns the number of se_atten attention layers.""" return self.se_atten.attn_layer + def uses_graph_lower(self) -> bool: + """Returns whether this descriptor supports the graph-native lower. + + The graph-native energy lower (``call_graph``) currently covers only the + non-attention (``attn_layer == 0``) factorizable path with concat + type-embedding and no type exclusion. Any other config (attention, + ``tebd_input_mode == "strip"``, ``exclude_types``) falls back to the + legacy dense path, so those models keep working unchanged. + """ + return ( + self.se_atten.attn_layer == 0 + and self.se_atten.tebd_input_mode == "concat" + and not self.se_atten.exclude_types + ) + def share_params( self, base_class: "DescrptDPA1", shared_level: int, resume: bool = False ) -> NoReturn: @@ -540,10 +556,142 @@ def call( sw The smooth switch function. """ - del mapping xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + nloc = nlist.shape[1] + nall = xp.reshape(coord_ext, (nlist.shape[0], -1)).shape[1] // 3 + # graph-eligible configs route through the graph-native adapter (decision + # #14: graph = single math source, dense call = thin adapter). Ineligible + # configs (attention, strip tebd, exclude_types) and the ghost case with + # no mapping fall back to the legacy dense body. The graph needs `mapping` + # to fold ghosts to local owners; without it only nall == nloc is valid. + if self.uses_graph_lower() and (mapping is not None or nall == nloc): + return self._call_graph_adapter(coord_ext, atype_ext, nlist, mapping) + else: + return self._call_dense(coord_ext, atype_ext, nlist) + + def _call_graph_adapter( + self, + coord_ext: Array, + atype_ext: Array, + nlist: Array, + mapping: Array | None, + ) -> Array: + """Regime-1 dense->graph adapter (the eligible ``call`` path). + + Builds a NeighborGraph from the dense quartet with the SHAPE-STATIC + converter (``compact=False``, so this is jit/export-traceable -- no + ``nonzero``), runs :meth:`call_graph`, and reconstructs the dense-shaped + ``sw``. Preserves the dense 5-tuple ABI exactly; masked invalid edges + contribute zero in ``call_graph``'s ``segment_sum`` so the output is + identical to the legacy dense body. + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nall x 3) + atype_ext + The extended atom types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + mapping + The index mapping from extended to local region. shape: nf x nall. + ``None`` is allowed only when nall == nloc (identity mapping). + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant single-particle representation. + shape: nf x nloc x ng x 3 + g2 + ``None`` for this descriptor. + h2 + ``None`` for this descriptor. + sw + The smooth switch function. shape: nf x nloc x nnei x 1 + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, + ) + + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + dev = array_api_compat.device(coord_ext) nf, nloc, nnei = nlist.shape nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 + coord_ext_3 = xp.reshape(coord_ext, (nf, nall, 3)) + if mapping is None: + # default identity mapping (ext == loc, e.g. no-PBC nall == nloc) + mapping_g = xp.broadcast_to( + xp.arange(nall, dtype=xp.int64, device=dev)[None, :], (nf, nall) + ) + else: + mapping_g = xp.reshape(mapping, (nf, nall)) + graph = from_dense_quartet( + coord_ext_3, nlist, mapping_g, layout=None, compact=False + ) + # local atom types, flat (nf * nloc,) + atype_local = xp.reshape(xp_take_first_n(atype_ext, 1, nloc), (nf * nloc,)) + grrg_flat, rot_mat_flat = self.call_graph( + graph, + atype_local, + type_embedding=self.type_embedding.call(), + ) + # call_graph returns flat (N, ...) node axis; reshape to (nf, nloc, ...) + # for the dense 5-tuple ABI -- this reshape is LOCAL to the adapter shim. + grrg = xp.reshape(grrg_flat, (nf, nloc, *grrg_flat.shape[1:])) + rot_mat = xp.reshape(rot_mat_flat, (nf, nloc, *rot_mat_flat.shape[1:])) + # reconstruct the dense-shaped sw the dense way (env_mat switch masked + # where nlist == -1; the graph path forbids exclude_types, so nlist_mask + # == nlist != -1, matching DescrptBlockSeAtten.call). A dense-layout + # artifact tied to neighbor slots, which the graph does not carry. + _, _, sw = self.se_atten.env_mat.call( + coord_ext, + atype_ext, + nlist, + self.se_atten.mean[...], + self.se_atten.stddev[...], + ) + nlist_mask = (nlist != -1)[:, :, :, None] + sw = xp.where(nlist_mask, sw, xp.zeros_like(sw)) + sw = xp.reshape(sw, (nf, nloc, nnei, 1)) + return grrg, rot_mat, None, None, sw + + def _call_dense( + self, + coord_ext: Array, + atype_ext: Array, + nlist: Array, + ) -> Array: + """Legacy dense descriptor body (the ineligible ``call`` path: attention, + strip tebd, exclude_types, or the no-mapping ghost case). + + Parameters + ---------- + coord_ext + The extended coordinates of atoms. shape: nf x (nall x 3) + atype_ext + The extended atom types. shape: nf x nall + nlist + The neighbor list. shape: nf x nloc x nnei + + Returns + ------- + descriptor + The descriptor. shape: nf x nloc x (ng x axis_neuron) + gr + The rotationally equivariant single-particle representation. + shape: nf x nloc x ng x 3 + g2 + ``None`` for this descriptor. + h2 + ``None`` for this descriptor. + sw + The smooth switch function. shape: nf x nloc x nnei x 1 + """ + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + nf, nloc = nlist.shape[:2] + nall = xp.reshape(coord_ext, (nf, -1)).shape[1] // 3 type_embedding = self.type_embedding.call() # nf x nall x tebd_dim atype_embd_ext = xp.reshape( @@ -567,6 +715,54 @@ def call( ) return grrg, rot_mat, None, None, sw + def call_graph( + self, + graph: Any, + atype: Array, + type_embedding: Array | None = None, + ) -> tuple[Array, Array]: + """Descriptor-level graph-native forward (``attn_layer == 0``). + + Wraps the block kernel + :meth:`DescrptBlockSeAtten.call_graph`, adds the descriptor-level + ``concat_output_tebd`` step, and returns the outputs on the flat ``(N, + ...)`` node axis (ragged-native; no rectangular ``(nf, nloc)`` + reshape). + + This method is graph-native: it takes no dense quartet inputs and does + not produce the dense ``sw`` (that lives in the dense :meth:`call` + adapter, which has the ``nlist``/``coord_ext`` needed to build it). + + Parameters + ---------- + graph + A :class:`~deepmd.dpmodel.utils.neighbor_graph.NeighborGraph`. + atype + (N,) flat LOCAL atom types where ``N = sum(n_node)``. + type_embedding + (ntypes_with_padding, tebd_dim) type-embedding table. + + Returns + ------- + grrg : Array + (N, ng * axis_neuron [+ tebd_dim]) descriptor, flat node axis. + rot_mat : Array + (N, ng, 3) equivariant single-particle representation, flat node + axis. + """ + xp = array_api_compat.array_namespace(graph.edge_vec) + dev = array_api_compat.device(graph.edge_vec) + grrg, rot_mat = self.se_atten.call_graph( + graph, atype, type_embedding=type_embedding + ) + # FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec. + if self.concat_output_tebd: + tebd = xp.asarray(type_embedding, device=dev) + atype_local = xp.asarray(atype, device=dev) + atype_embd = xp.take(tebd, atype_local, axis=0) # (N, tebd_dim) + grrg = xp.concat([grrg, atype_embd], axis=-1) + return grrg, rot_mat + def serialize(self) -> dict: """Serialize the descriptor to dict.""" obj = self.se_atten @@ -1240,6 +1436,122 @@ def call( xp.reshape(sw, (nf, nloc, nnei, 1)), ) + def call_graph( + self, + graph: Any, + atype: Array, + type_embedding: Array | None = None, + ) -> tuple[Array, Array]: + """Graph-native forward (``attn_layer=0`` only). + + Bit-exact analogue of :meth:`call` on the SAME neighbor list, with the + neighbor-axis reduction replaced by a ``segment_sum`` over edge centers + (``dst``). Geometry enters only through ``graph.edge_vec``. + + Parameters + ---------- + graph + A :class:`~deepmd.dpmodel.utils.neighbor_graph.NeighborGraph` whose + ``edge_index = [src, dst]`` (src = neighbor local owner, dst = center), + ``edge_vec = r_src - r_dst`` and ``edge_mask`` marks real edges. + atype + (N,) flat node atom types (``N = sum(graph.n_node)``). + type_embedding + (ntypes_with_padding, tebd_dim) type-embedding table. + + Returns + ------- + grrg : Array + (N, ng * axis_neuron) per-node descriptor, matching the first output + of :meth:`call` flattened over the (nf, nloc) axes. + rot_mat : Array + (N, ng, 3) per-node equivariant single-particle representation, + matching ``gr[..., 1:]`` of :meth:`call` flattened over (nf, nloc). + + Notes + ----- + Known limitations (NeighborGraph PR-A): + - ``attn_layer == 0`` only (attention lands in PR-D); + - ``tebd_input_mode == "concat"`` only (strip mode lands later); + - ``exclude_types`` is not yet supported and raises (lands in a later PR). + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + edge_env_mat, + segment_sum, + ) + + if self.attn_layer != 0: + raise NotImplementedError( + "graph path supports attn_layer=0 only (NeighborGraph PR-A); " + "attn_layer>0 lands in PR-D" + ) + if self.tebd_input_mode not in ["concat"]: + raise NotImplementedError( + "graph path supports tebd_input_mode='concat' only (NeighborGraph PR-A)" + ) + if self.exclude_types: + raise NotImplementedError( + "graph path does not yet apply exclude_types (NeighborGraph PR-A); " + "type exclusion lands in a later PR" + ) + if type_embedding is None: + raise ValueError("type_embedding is required for the graph path") + xp = array_api_compat.array_namespace(graph.edge_vec) + dev = array_api_compat.device(graph.edge_vec) + # N == sum(graph.n_node) by contract (atype is (N,)); use the static shape + # value so the kernel stays jit/export-traceable (no concretize of n_node). + n_total = atype.shape[0] + src = graph.edge_index[0, :] + dst = graph.edge_index[1, :] + atype = xp.asarray(atype, device=dev) + center_type = xp.take(atype, dst, axis=0) # (E,) + nei_type = xp.take(atype, src, axis=0) # (E,) + # per-edge env-mat 4-vector, normalized by the center (dst) atom type. + # self.mean/self.stddev are slot-independent (ntypes, nnei, 4); slot 0 is + # the canonical per-type vector. + rr = edge_env_mat( + graph.edge_vec, + center_type, + self.mean[:, 0, :], + self.stddev[:, 0, :], + self.rcut, + self.rcut_smth, + protection=self.env_protection, + edge_mask=graph.edge_mask, + ) # (E, 4) + # radial channel + ss = rr[:, 0:1] # (E, 1) + # neighbor / center type embeddings (concat mode); ghost type == owner type + # so gathering by the LOCAL owner (src) reproduces the dense neighbor tebd. + tebd = xp.asarray(type_embedding, device=dev) + atype_embd_nlist = xp.take(tebd, nei_type, axis=0) # (E, tebd_dim) + if not self.type_one_side: + atype_embd_nnei = xp.take(tebd, center_type, axis=0) # (E, tebd_dim) + ss = xp.concat([ss, atype_embd_nlist, atype_embd_nnei], axis=-1) + else: + ss = xp.concat([ss, atype_embd_nlist], axis=-1) + # embedding net (same weights as the dense path); applies on the last axis + gg = self.embeddings[0].call(ss) # (E, ng) + # zero padding/guard edges BEFORE the segment sum + gg = gg * xp.astype(graph.edge_mask[:, None], gg.dtype) + # outer product (replaces the dense gg[:,:,:,None] * rr[:,:,None,:]) + outer = gg[:, :, None] * rr[:, None, :] # (E, ng, 4) + # neighbor-axis reduction -> segment_sum over centers; divide by nnei + gr = segment_sum(outer, dst, n_total) / self.nnei # (N, ng, 4) + gr1 = gr[:, : self.axis_neuron, :] + # nf x nloc x (ng x ng1) + grrg = xp.sum(gr[:, :, None, :] * gr1[:, None, :, :], axis=3) # (N, ng, ng1) + ng = self.neuron[-1] + grrg = xp.astype( + xp.reshape(grrg, (n_total, ng * self.axis_neuron)), + graph.edge_vec.dtype, + ) + # equivariant single-particle representation, dense-ABI slice gr[..., 1:] + # (N, ng, 3); not cast, mirroring the dense block which leaves rot_mat in + # the working precision before the descriptor-level @cast_precision. + rot_mat = gr[:, :, 1:] + return grrg, rot_mat + def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index b9129a4364..4734a6be66 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -790,3 +790,57 @@ def _call_common( if self.eval_return_middle_output and len(self.neuron) > 0: results["middle_output"] = middle_outs return results + + def call_graph( + self, + descriptor: Array, + atype: Array, + gr: Array | None = None, + g2: Array | None = None, + h2: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + ) -> dict[str, Array]: + """Graph-native (flat node axis) fitting forward. + + The node axis is flat ``(N,)``. This reuses the dense forward by treating + the node axis as ``nf'=N`` single-atom frames (``nloc'=1``) -- an internal, + encapsulated workaround, verified bit-identical to the dense call. + + Parameters + ---------- + descriptor + input descriptor. N x nd + atype + the atom type. N + gr + equivariant single-particle representation. N x ng x 3 + g2 + the rotationally invariant pair-partical representation. + unused by this fitting; passed through to the dense call. + h2 + the rotationally equivariant pair-partical representation. + unused by this fitting; passed through to the dense call. + fparam + NODE-level frame parameter (already gathered by frame_id). N x nfp + aparam + atomic parameter. N x nap + + Returns + ------- + result_dict + the fitting result on the flat node axis. each value N x *shape + + """ + import array_api_compat + + xp = array_api_compat.array_namespace(descriptor, atype) + n, nd = descriptor.shape + d1 = xp.reshape(descriptor, (n, 1, nd)) + a1 = xp.reshape(atype, (n, 1)) + g1 = None if gr is None else xp.reshape(gr, (n, 1, gr.shape[-2], 3)) + ap1 = None if aparam is None else xp.reshape(aparam, (n, 1, aparam.shape[-1])) + # fparam: dense API expects (nf, nfp); here nf'=N single-atom frames, so the + # node-level (N, nfp) IS the per-(pseudo)frame param -- tiled over nloc'=1. + ret = self.__call__(d1, a1, gr=g1, g2=g2, h2=h2, fparam=fparam, aparam=ap1) + return {kk: xp.reshape(vv, (n, *vv.shape[2:])) for kk, vv in ret.items()} diff --git a/deepmd/dpmodel/model/edge_transform_output.py b/deepmd/dpmodel/model/edge_transform_output.py new file mode 100644 index 0000000000..f9cc49f874 --- /dev/null +++ b/deepmd/dpmodel/model/edge_transform_output.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Flat-N (ragged-native) graph output transform for the dpmodel backend. + +The graph lower produces per-node outputs on the flat ``(N,)`` node axis +(``N = sum(graph.n_node)``); this reduces every reducible fitting output per +frame via ``segment_sum`` over ``frame_id``. dpmodel is energy-only (no +autograd force on the graph path), so derivative name-holders are ``None`` -- +the pt_expt backend (:mod:`deepmd.pt_expt.model.edge_transform_output`) assembles +force/virial from the same ``NeighborGraph`` via ``edge_energy_deriv``. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, +) + +import array_api_compat + +from deepmd.dpmodel.common import ( + GLOBAL_ENER_FLOAT_PRECISION, +) +from deepmd.dpmodel.output_def import ( + get_deriv_name, + get_reduce_name, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.array_api import ( + Array, + ) + from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ) + from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + ) + + +def fit_output_to_model_output_graph( + fit_ret: dict[str, Array], + fit_output_def: FittingOutputDef, + graph: NeighborGraph, + mask: Array | None = None, +) -> dict[str, Array]: + """Flat-N analogue of :func:`~deepmd.dpmodel.model.transform_output.fit_output_to_model_output`. + + Parameters + ---------- + fit_ret + the raw per-node fitting dict, each value ``(N, *shape)``. + fit_output_def + the fitting output def (drives the per-key reduction). + graph + the :class:`~deepmd.dpmodel.utils.neighbor_graph.NeighborGraph`; only + ``graph.n_node`` is used (the node->frame map for the reduction). + mask + the ``(N,)`` real-node mask for the intensive-output denominator. + + Returns + ------- + model_ret + ``fit_ret`` plus, for each reducible key, ``_redu (nf, *shape)`` via + ``segment_sum`` over ``frame_id`` (intensive ⇒ divide by the per-frame + real-node count); derivative name-holders are ``None``. + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + frame_id_from_n_node, + segment_sum, + ) + + n_node = graph.n_node + xp = array_api_compat.get_namespace(n_node) + nf = n_node.shape[0] + frame_id = frame_id_from_n_node(n_node) + model_ret = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + if not vdef.reducible: + continue + kk_redu = get_reduce_name(kk) + vv_e = xp.astype(vv, GLOBAL_ENER_FLOAT_PRECISION) + redu = segment_sum(vv_e, frame_id, nf) # (nf, *shape) + if vdef.intensive: + if mask is not None: + cnt = segment_sum( + xp.astype(mask, GLOBAL_ENER_FLOAT_PRECISION), frame_id, nf + ) + else: + cnt = xp.astype(n_node, GLOBAL_ENER_FLOAT_PRECISION) + redu = redu / xp.reshape(cnt, (nf, *([1] * (redu.ndim - 1)))) + model_ret[kk_redu] = redu + if vdef.r_differentiable: + kk_derv_r, _ = get_deriv_name(kk) + model_ret[kk_derv_r] = None + if vdef.c_differentiable: + _, kk_derv_c = get_deriv_name(kk) + model_ret[kk_derv_c] = None + return model_ret diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index aba6b9fd48..24cba63558 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -42,10 +42,18 @@ NeighborList, nlist_distinguish_types, ) +from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + build_neighbor_graph, + build_neighbor_graph_ase, +) from deepmd.utils.path import ( DPPath, ) +from .edge_transform_output import ( + fit_output_to_model_output_graph, +) from .transform_output import ( communicate_extended_output, fit_output_to_model_output, @@ -259,6 +267,7 @@ def call_common( coord_corr_for_virial: Array | None = None, charge_spin: Array | None = None, neighbor_list: NeighborList | None = None, + neighbor_graph_method: str | None = None, ) -> dict[str, Array]: """Return model prediction. @@ -281,10 +290,34 @@ def call_common( The coordinates correction for virial. shape: nf x (nloc x 3) neighbor_list - The neighbor-list construction strategy. ``None`` uses the - default all-pairs builder; an alternative strategy (e.g. an O(N) - cell list) may be injected to speed up neighbor-list construction - without changing model outputs. + Neighbor-list construction strategy for the DENSE-nlist path + only. ``None`` uses the default all-pairs builder; an + alternative strategy (e.g. an O(N) cell list) may be injected to + speed up nlist construction without changing model outputs. It + is consumed by the dense lower; supplying it forces the dense + route (see below) and it is rejected together with an explicit + ``neighbor_graph_method``. + neighbor_graph_method + Selects the lower the model routes through. The option strings + refer to the neighbor-GRAPH builder, NOT the legacy dense nlist: + + - ``None`` -- default. dpmodel/jax keep the dense nlist path; + pt_expt default-flips graph-eligible mixed_types descriptors to + the carry-all graph (decision #17). + - ``"legacy"`` -- force the dense nlist path (opt out of the + default-flip). + - ``"dense"`` -- build a carry-all :class:`NeighborGraph` with the + in-tree O(N^2) ALL-PAIRS search (this is NOT the dense nlist + lower; "dense" = the all-pairs graph builder). + - ``"ase"`` -- build the carry-all graph with the O(N) ASE cell + list. + + The graph routes (``"dense"``/``"ase"``, and the pt_expt + default-flip) require a ``mixed_types`` descriptor with a graph + lower (dpa1 ``attn_layer == 0``). At non-binding ``sel`` the + graph matches the dense path exactly; at binding ``sel`` the + carry-all graph keeps neighbors the dense path truncates, so the + energy intentionally differs. Returns ------- @@ -297,23 +330,169 @@ def call_common( coord, box=box, fparam=fparam, aparam=aparam, charge_spin=charge_spin ) del coord, box, fparam, aparam, charge_spin - model_predict = model_call_from_call_lower( - call_lower=self.call_common_lower, - rcut=self.get_rcut(), - sel=self.get_sel(), - mixed_types=self.mixed_types(), - model_output_def=self.model_output_def(), - coord=cc, - atype=atype, - box=bb, + graph_method = self._resolve_graph_method(neighbor_graph_method) + # ``neighbor_list`` is a DENSE-nlist strategy; the graph path cannot + # consume it. Reject an explicit graph+nlist combination, and + # otherwise honor the supplied nlist by taking the dense route + # (don't let the pt_expt default-flip silently ignore it). + if neighbor_list is not None: + if neighbor_graph_method not in (None, "legacy"): + raise ValueError( + "neighbor_list is a dense-nlist strategy and cannot be " + f"combined with neighbor_graph_method={neighbor_graph_method!r}; " + "pass one or the other" + ) + graph_method = None + # the graph lower does not consume charge_spin yet -> keep those + # models on dense (a None check, so it stays jit/export-safe) + if cs is not None: + graph_method = None + if graph_method is not None: + # carry-all NeighborGraph energy forward (Option B / decision #17) + model_predict = self._call_common_graph( + cc, + atype, + bb, + fp, + ap, + graph_method, + do_atomic_virial, + ) + else: + # legacy dense-nlist path (builds the extended quartet) + model_predict = model_call_from_call_lower( + call_lower=self.call_common_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=cc, + atype=atype, + box=bb, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + coord_corr_for_virial=coord_corr_for_virial, + charge_spin=cs, + neighbor_list=neighbor_list, + ) + model_predict = self._output_type_cast(model_predict, input_prec) + return model_predict + + def _resolve_graph_method( + self, neighbor_graph_method: str | None + ) -> str | None: + """Resolve the neighbor-graph method. + + Base (dpmodel/jax): ``None`` => the dense path. These backends compute + force/virial ANALYTICALLY inside ``call_common`` (``energy_derv_r`` in + the output); the carry-all graph lower here is ENERGY-only, so it is + NOT used by default (it would drop force). ``"legacy"`` => dense; + explicit ``"dense"``/``"ase"`` => opt into the (energy-only) graph. + + pt_expt OVERRIDES this so ``None`` defaults graph-eligible mixed_types + descriptors to the carry-all graph (decision #17) -- pt_expt has the + autograd ``forward_common_lower_graph`` that produces force/virial. + + Parameters + ---------- + neighbor_graph_method + The user-requested method: ``None`` (default), ``"legacy"`` + (force dense), or ``"dense"``/``"ase"`` (force the graph builder). + + Returns + ------- + method + The resolved method passed to :meth:`_call_common_graph`, or + ``None`` to take the dense path. + """ + if neighbor_graph_method == "legacy": + return None + return neighbor_graph_method + + def _call_common_graph( + self, + cc: Array, + atype: Array, + bb: Array | None, + fp: Array | None, + ap: Array | None, + method: str, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + """Carry-all graph forward (opt-in, Option B). + + Builds a carry-all :class:`NeighborGraph` from ``cc``/``atype``/``bb`` + and routes the forward through the OUTPUT-AGNOSTIC + :meth:`call_lower_graph`. Input/output type-casting is done by the + caller. + + Parameters + ---------- + cc + coordinates. nf x nloc x 3 (or nf x (nloc x 3)) + atype + the atom types. nf x nloc + bb + the simulation cell. nf x 3 x 3, or ``None`` for non-periodic. + fp + the frame parameter. nf x ndf + ap + the atomic parameter. nf x nloc x nda + method + the carry-all builder, ``"dense"`` or ``"ase"``. + do_atomic_virial + whether to calculate the atomic virial. + + Returns + ------- + model_predict + the standard model dict mirroring the dense ``call_common`` keys + (```` per-atom, ``_redu`` reduced, derivative + name-holders ``None``, plus the int ``mask``). + """ + descriptor = getattr(self.atomic_model, "descriptor", None) + uses_graph_lower = getattr(descriptor, "uses_graph_lower", lambda: False) + if not (self.mixed_types() and uses_graph_lower()): + raise NotImplementedError( + "neighbor_graph_method requires a mixed_types descriptor with a " + "graph lower (e.g. dpa1 attn_layer=0)" + ) + if method == "dense": + ng = build_neighbor_graph(cc, atype, bb, self.get_rcut()) + elif method == "ase": + ng = build_neighbor_graph_ase(cc, atype, bb, self.get_rcut()) + else: + raise ValueError( + f"unknown neighbor_graph_method {method!r}; use 'dense' or 'ase'" + ) + xp = array_api_compat.array_namespace(atype) + nf, nloc = atype.shape[:2] + # OUTPUT-AGNOSTIC standard model dict (````, ``_redu``, + # derivative name-holders ``None``, plus int ``mask``), like the + # dense ``call_common``. ``call_lower_graph`` masks virtual atoms + # (atype<0) and sets the real int mask. + model_predict = self.call_lower_graph( + atype=xp.reshape(atype, (nf * nloc,)), + n_node=ng.n_node, + edge_index=ng.edge_index, + edge_vec=ng.edge_vec, + edge_mask=ng.edge_mask, fparam=fp, aparam=ap, - do_atomic_virial=do_atomic_virial, - coord_corr_for_virial=coord_corr_for_virial, - charge_spin=cs, - neighbor_list=neighbor_list, ) - model_predict = self._output_type_cast(model_predict, input_prec) + # Public ABI is rectangular (nf, nloc, *); the lower is flat + # (N=nf*nloc, *). Unravel per-atom keys here at the boundary. + # public call_common always passes rectangular (nf,nloc) coord/atype (N == nf*nloc), so this unravel always applies; ragged graphs reach call_lower_graph/forward_common_lower_graph directly (no unravel) and stay flat (N,*). + for k in list(model_predict.keys()): + v = model_predict[k] + # per-frame reduced keys (..._redu) keep their (nf, *) shape; only node-level (N,*) keys unravel — guards the nloc==1 case where N == nf. + if ( + v is not None + and not k.endswith("_redu") + and v.shape[:1] == (nf * nloc,) + ): + model_predict[k] = xp.reshape(v, (nf, nloc, *v.shape[1:])) return model_predict def call_common_lower( @@ -423,6 +602,155 @@ def forward_common_atomic( mask=atomic_ret["mask"] if "mask" in atomic_ret else None, ) + def forward_common_atomic_graph( + self, + atype: Array, + n_node: Array, + edge_index: Array, + edge_vec: Array, + edge_mask: Array, + n_local: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + comm_dict: dict | None = None, + charge_spin: Array | None = None, + ) -> dict[str, Array]: + """Model-level graph forward (no type cast). Analogue of the dense + :meth:`forward_common_atomic`. + + Builds a :class:`NeighborGraph` from the flat edge fields, runs the + atomic model's :meth:`forward_common_atomic_graph` (flat ``(N, *)`` + per-node output), then the flat-N output transform (per-frame + ``segment_sum`` reduction; derivative name-holders ``None`` -- + force/virial come from the pt_expt autograd lower). The + ``(nf, nloc)`` unravel for the public ABI happens in the caller + (:meth:`_call_common_graph`). + + Parameters + ---------- + atype + (N,) flat LOCAL atom types, ``N == sum(n_node)``. + n_node + (nf,) per-frame local atom counts. + edge_index + (2, E) ``[src, dst]`` edge endpoints (flat local indices). + edge_vec + (E, 3) neighbor-minus-center edge vectors. + edge_mask + (E,) boolean/0-1 valid-edge mask. + n_local + Per-rank local atom counts for multi-rank inference. Ignored in + PR-A (single-rank); accepted for ABI stability. + fparam + Frame parameter, ``(nf, ndf)``. + aparam + Atomic parameter, ``(N, nda)``. + comm_dict + MPI communication metadata. Ignored in PR-A; accepted for ABI + stability. + charge_spin + charge/spin conditioning. Ignored in PR-A; accepted for ABI + stability with charge/spin-conditioned descriptors. + + Returns + ------- + dict + The standard model dict (```` per-node, ``_redu`` + reduced, derivative name-holders ``None``), matching + :func:`fit_output_to_model_output_graph`. + """ + graph = NeighborGraph( + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) + atomic_ret = self.atomic_model.forward_common_atomic_graph( + graph, atype, fparam=fparam, aparam=aparam, charge_spin=charge_spin + ) + return fit_output_to_model_output_graph( + atomic_ret, + self.atomic_output_def(), + graph, + mask=atomic_ret["mask"] if "mask" in atomic_ret else None, + ) + + def call_common_lower_graph( + self, + atype: Array, + n_node: Array, + edge_index: Array, + edge_vec: Array, + edge_mask: Array, + n_local: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + comm_dict: dict | None = None, + charge_spin: Array | None = None, + ) -> dict[str, Array]: + """Graph-native PUBLIC lower (PR-A: dpa1 ``attn_layer == 0``). + + The PRIMARY directly-callable graph interface (spec decision #14). + Casts inputs/outputs to/from the model precision exactly like the + dense :meth:`call_common_lower` (``edge_vec`` is the geometry, in + place of ``coord``), then runs :meth:`forward_common_atomic_graph`. + OUTPUT-AGNOSTIC: every fitting (energy/dos/dipole/polar/property/...) + flows through with no change on the fitting side; force/virial are + produced by the pt_expt autograd lower. Must match the dense + :meth:`call_common_lower` reduction on the SAME neighbor set. + + Parameters + ---------- + atype + (N,) flat LOCAL atom types, ``N == sum(n_node)``. + n_node + (nf,) per-frame local atom counts. + edge_index + (2, E) ``[src, dst]`` edge endpoints (flat local indices). + edge_vec + (E, 3) neighbor-minus-center edge vectors. + edge_mask + (E,) boolean/0-1 valid-edge mask. + n_local + Per-rank local atom counts for multi-rank inference. Ignored in + PR-A (single-rank); accepted for ABI stability. + fparam + Frame parameter, ``(nf, ndf)``. + aparam + Atomic parameter, ``(N, nda)``. + comm_dict + MPI communication metadata. Ignored in PR-A; accepted for ABI + stability. + charge_spin + charge/spin conditioning. Ignored in PR-A; accepted for ABI + stability with charge/spin-conditioned descriptors. + + Returns + ------- + dict + The standard model dict in the INPUT precision. + """ + edge_vec, _, fparam, aparam, cs, input_prec = self._input_type_cast( + edge_vec, fparam=fparam, aparam=aparam, charge_spin=charge_spin + ) + model_predict = self.forward_common_atomic_graph( + atype, + n_node, + edge_index, + edge_vec, + edge_mask, + n_local=n_local, + fparam=fparam, + aparam=aparam, + comm_dict=comm_dict, + charge_spin=cs, + ) + model_predict = self._output_type_cast(model_predict, input_prec) + return model_predict + + # backward-compat alias (mirrors ``call_lower = call_common_lower``) + call_lower_graph = call_common_lower_graph + call = call_common call_lower = call_common_lower diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index 0179543dd4..66dd8fe3c1 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -22,8 +22,11 @@ GraphLayout, NeighborGraph, build_neighbor_graph, + build_neighbor_graph_ase, + edge_env_mat, edge_force_virial, from_dense_quartet, + neighbor_graph_from_ijs, node_validity_mask, pad_and_guard_edges, segment_mean, @@ -89,8 +92,10 @@ "aggregate", "build_multiple_neighbor_list", "build_neighbor_graph", + "build_neighbor_graph_ase", "build_neighbor_list", "compute_total_numb_batch", + "edge_env_mat", "edge_force_virial", "extend_coord_with_ghosts", "from_dense_quartet", @@ -103,6 +108,7 @@ "make_fitting_network", "make_multilayer_network", "make_neighbor_stat_data", + "neighbor_graph_from_ijs", "nlist_distinguish_types", "node_validity_mask", "normalize_coord", diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index 80153129cc..1cc4ea7479 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -42,25 +42,24 @@ def build_type_exclude_mask( Parameters ---------- atype - The extended aotm types. shape: nf x natom + The atom types. shape: nf x natom (dense) or N (graph / flat) Returns ------- mask - The type exclusion mask for atoms. shape: nf x natom - Element [ff,ii] being 0 if type(ii) is excluded, - otherwise being 1. + The type exclusion mask for atoms, same shape as ``atype``. + Element being 0 if the type is excluded, otherwise being 1. """ xp = array_api_compat.array_namespace(atype) - nf, natom = atype.shape + lead = atype.shape # (nf, natom) dense | (N,) graph return xp.reshape( xp.take( xp.asarray(self.type_mask[...], device=array_api_compat.device(atype)), xp.reshape(atype, (-1,)), axis=0, ), - (nf, natom), + lead, ) @@ -159,5 +158,37 @@ def build_type_exclude_mask( ) return mask + def build_edge_exclude_mask(self, edge_index: Array, atype: Array) -> Array: + """Graph-native pair exclusion: per-edge keep mask (1 keep, 0 exclude). + + Parameters + ---------- + edge_index + (2, E) [src, dst]; src = neighbor, dst = center; into [0, N). + atype + (N,) flat local node types (clamped >= 0). + + Returns + ------- + mask + (E,) int. ``type_mask[atype[dst]*(ntypes+1) + atype[src]]``. + + """ + xp = array_api_compat.array_namespace(atype) + if len(self.exclude_types) == 0: + return xp.ones( + (edge_index.shape[1],), + dtype=xp.int32, + device=array_api_compat.device(atype), + ) + src_t = xp.take(atype, edge_index[0, :], axis=0) + dst_t = xp.take(atype, edge_index[1, :], axis=0) + type_ij = dst_t * (self.ntypes + 1) + src_t + return xp.take( + xp.asarray(self.type_mask[...], device=array_api_compat.device(atype)), + type_ij, + axis=0, + ) + def __contains__(self, item: tuple[int, int]) -> bool: return item in self.exclude_types diff --git a/deepmd/dpmodel/utils/neighbor_graph/__init__.py b/deepmd/dpmodel/utils/neighbor_graph/__init__.py index 08b165f861..6e041805b2 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/__init__.py +++ b/deepmd/dpmodel/utils/neighbor_graph/__init__.py @@ -6,9 +6,12 @@ + edge padding), ``builder`` (the carry-all ``build_neighbor_graph`` dispatcher + the ``from_dense_quartet`` legacy converter), ``segment`` (mask-aware segment-reduction toolkit), and ``derivatives`` (edge force/virial assembly). -See memory/spec_unified_edge_nlist.md. +See the design discussion wanghan-iapcm/deepmd-kit#4. """ +from .ase_builder import ( + build_neighbor_graph_ase, +) from .builder import ( build_neighbor_graph, from_dense_quartet, @@ -16,9 +19,16 @@ from .derivatives import ( edge_force_virial, ) +from .env import ( + edge_env_mat, +) +from .from_ijs import ( + neighbor_graph_from_ijs, +) from .graph import ( GraphLayout, NeighborGraph, + frame_id_from_n_node, node_validity_mask, pad_and_guard_edges, ) @@ -31,8 +41,12 @@ "GraphLayout", "NeighborGraph", "build_neighbor_graph", + "build_neighbor_graph_ase", + "edge_env_mat", "edge_force_virial", + "frame_id_from_n_node", "from_dense_quartet", + "neighbor_graph_from_ijs", "node_validity_mask", "pad_and_guard_edges", "segment_mean", diff --git a/deepmd/dpmodel/utils/neighbor_graph/ase_builder.py b/deepmd/dpmodel/utils/neighbor_graph/ase_builder.py new file mode 100644 index 0000000000..bc7312fcab --- /dev/null +++ b/deepmd/dpmodel/utils/neighbor_graph/ase_builder.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Carry-all NeighborGraph builder backed by ASE's O(N) cell list (optional dep). + +``build_neighbor_graph_ase`` is a carry-all search backend: it uses ASE's +``neighbor_list("ijS", ...)`` to enumerate EVERY neighbor within ``rcut`` (no +``sel`` cutoff), then routes the resulting sparse ``(i, j, S)`` edge list through +:func:`neighbor_graph_from_ijs` so ``edge_vec`` is recomputed differentiably from +``coord``/``box`` -- ASE's own distance vectors are intentionally NOT used, to +keep the geometry convention and autograd leaf consistent with every other +builder. ASE is an OPTIONAL dependency, imported lazily inside the function. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, + Any, +) + +import numpy as np + +from .from_ijs import ( + neighbor_graph_from_ijs, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.array_api import ( + Array, + ) + + from .graph import ( + GraphLayout, + NeighborGraph, + ) + + +def build_neighbor_graph_ase( + coord: Array, + atype: Array, + box: Array | None, + rcut: float, + layout: GraphLayout | None = None, +) -> NeighborGraph: + """Build a CARRY-ALL NeighborGraph using ASE's O(N) cell-list search. + + Per frame, ASE ``neighbor_list("ijS", atoms, rcut)`` returns center ``i``, + neighbor ``j`` and periodic shift ``S`` such that the neighbor image sits at + ``positions[j] + S @ cell``. These map directly to the graph convention + (src=neighbor=j, dst=center=i), and the edge list is fed to + :func:`neighbor_graph_from_ijs` which recomputes ``edge_vec`` from + ``coord``/``box`` (ASE's distance vectors are discarded for convention + + differentiability consistency). + + Parameters + ---------- + coord + (nf, nloc, 3) local coordinates. + atype + (nf, nloc) local atom types (unused for the search; carried for API parity). + box + (nf, 3, 3) simulation cell, or ``None`` for non-periodic. + rcut + cutoff radius. + layout + edge-axis length policy; ``None`` => dynamic (torch) with ``min_edges`` guards. + + Returns + ------- + graph + The carry-all :class:`NeighborGraph` over the LOCAL atoms + (``n_node = nloc`` per frame), with ``edge_vec`` recomputed + differentiably from ``coord``/``box``. + + Raises + ------ + ImportError + if the optional ``ase`` package is not installed. + """ + try: + from ase import ( + Atoms, + ) + from ase.neighborlist import ( + neighbor_list, + ) + except ImportError as e: + raise ImportError( + "build_neighbor_graph_ase requires the optional 'ase' package; " + "install ase or use neighbor-graph method 'dense'." + ) from e + + # The ASE topology search runs on the CPU in numpy; convert safely from a + # CUDA / grad-requiring torch tensor (the original coord/box are still + # passed to neighbor_graph_from_ijs below, which recomputes edge_vec + # differentiably on the native backend/device). + def _to_cpu_numpy(x: Any) -> np.ndarray: + return np.asarray(x.detach().cpu()) if hasattr(x, "detach") else np.asarray(x) + + coord_np = _to_cpu_numpy(coord) + nf, nloc = coord_np.shape[:2] + coord_np = coord_np.reshape(nf, nloc, 3) + box_np = _to_cpu_numpy(box).reshape(nf, 3, 3) if box is not None else None + periodic = box is not None + + i_parts = [] + j_parts = [] + S_parts = [] + nframe_parts = [] + for f in range(nf): + atoms = Atoms( + positions=coord_np[f], + cell=(box_np[f] if periodic else None), + pbc=periodic, + ) + ii, jj, SS = neighbor_list("ijS", atoms, rcut) + i_parts.append(np.asarray(ii, dtype=np.int64)) + j_parts.append(np.asarray(jj, dtype=np.int64)) + S_parts.append(np.asarray(SS, dtype=np.int64).reshape(-1, 3)) + nframe_parts.append(np.full((len(ii),), f, dtype=np.int64)) + + i_all = np.concatenate(i_parts) if i_parts else np.zeros((0,), dtype=np.int64) + j_all = np.concatenate(j_parts) if j_parts else np.zeros((0,), dtype=np.int64) + S_all = np.concatenate(S_parts) if S_parts else np.zeros((0, 3), dtype=np.int64) + nframe_all = ( + np.concatenate(nframe_parts) if nframe_parts else np.zeros((0,), dtype=np.int64) + ) + + return neighbor_graph_from_ijs( + i_all, j_all, S_all, coord, box, nframe_all, nloc, layout=layout + ) diff --git a/deepmd/dpmodel/utils/neighbor_graph/builder.py b/deepmd/dpmodel/utils/neighbor_graph/builder.py index 9a10d3f805..71ca699e1b 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/builder.py +++ b/deepmd/dpmodel/utils/neighbor_graph/builder.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Builders/converters that produce a :class:`NeighborGraph`. -Two distinct groups (see memory/spec_unified_edge_nlist.md decision #17), kept +Two distinct groups (see the design discussion wanghan-iapcm/deepmd-kit#4 decision #17), kept separate so a consumer can never assume completeness while a function silently truncated: @@ -56,13 +56,17 @@ def from_dense_quartet( nlist: Array, mapping: Array, layout: GraphLayout | None = None, + compact: bool = True, ) -> NeighborGraph: """Convert a legacy extended quartet into a ghost-free NeighborGraph (CONVERTER). This is a backward-compat CONVERTER (World 1 -> graph): it performs NO neighbor search and INHERITS the ``sel`` truncation already baked into ``nlist``. Use it only when a caller (an MD code, or the legacy dense path) already holds a - built quartet; for the carry-all graph use :func:`build_neighbor_graph`. + built quartet. In contrast, the carry-all graph builders search from RAW + coordinates and apply NO ``sel`` truncation: :func:`build_neighbor_graph` + (the ``neighbor_graph_method="dense"`` all-pairs route) and + :func:`build_neighbor_graph_ase` (the ``"ase"`` O(N) cell-list route). For each valid neighbor slot it emits one edge with ``src = mapping[neighbor]`` (the neighbor's LOCAL owner -> ghost-free index), ``dst = center`` (local), and @@ -85,6 +89,22 @@ def from_dense_quartet( (nf, nall) extended -> local-owner index (local atoms map to themselves). layout edge-axis length policy; ``None`` => dynamic (torch) with ``min_edges`` guards. + compact + If True (default), COMPACT real edges with ``nonzero`` and pad/guard via + :func:`pad_and_guard_edges` -- the data-dependent output shape breaks + jax.jit / torch.export. If False, emit a SHAPE-STATIC graph: every nlist + slot becomes an edge (``E = nf * nloc * nsel``, a static shape), invalid + slots (``nlist == -1``) get ``edge_mask=False``, zero ``edge_vec`` and a + ``src`` pointing at the center (in-range, masked) -- so no ``nonzero`` is + used and the converter is jit/export-traceable. The masked edges contribute + zero in a downstream ``segment_sum``, so the descriptor output is unchanged. + + Returns + ------- + graph + The :class:`NeighborGraph` over the LOCAL atoms (``n_node = nloc`` per + frame): ``edge_index`` ``[src, dst]`` in local indices, ``edge_vec`` the + neighbor-minus-center displacement, and ``edge_mask`` flagging real edges. """ if layout is None: layout = GraphLayout() @@ -92,43 +112,89 @@ def from_dense_quartet( dev = array_api_compat.device(extended_coord) nf, nloc, nsel = nlist.shape nall = extended_coord.shape[1] - # per-slot (nf, nloc, nsel) index grids, flattened frame-major - ff_grid = xp.broadcast_to( - xp.reshape(xp.arange(nf, dtype=xp.int64, device=dev), (nf, 1, 1)), - (nf, nloc, nsel), - ) - center_grid = xp.broadcast_to( - xp.reshape(xp.arange(nloc, dtype=xp.int64, device=dev), (1, nloc, 1)), - (nf, nloc, nsel), - ) - ff_flat = xp.reshape(ff_grid, (-1,)) - center_flat = xp.reshape(center_grid, (-1,)) - nl_flat = xp.reshape(nlist, (-1,)) - keep = xp.reshape(xp.nonzero(nl_flat >= 0)[0], (-1,)) - ff_k = xp.take(ff_flat, keep, axis=0) - dst_local = xp.take(center_flat, keep, axis=0) # center index in [0, nloc) - j_ext = xp.take(nl_flat, keep, axis=0) # neighbor index in [0, nall) - # cross-frame gathers via flat (frame * nall + idx) indices; centers are the - # first nloc extended atoms (local atoms precede ghosts). - ec_flat = xp.reshape(extended_coord, (nf * nall, 3)) - map_flat = xp.reshape(mapping, (nf * nall,)) - g_nei = ff_k * nall + j_ext - g_cen = ff_k * nall + dst_local - src_local = xp.take(map_flat, g_nei, axis=0) # local owner of the neighbor - edge_vec = xp.take(ec_flat, g_nei, axis=0) - xp.take(ec_flat, g_cen, axis=0) - edge_index = xp.astype( - xp.stack([ff_k * nloc + src_local, ff_k * nloc + dst_local], axis=0), xp.int64 - ) - edge_index, edge_vec, edge_mask = pad_and_guard_edges( - edge_index, edge_vec, layout.edge_capacity, layout.min_edges - ) - n_node = xp.full((nf,), nloc, dtype=xp.int64, device=dev) - return NeighborGraph( - n_node=n_node, - edge_index=edge_index, - edge_vec=edge_vec, - edge_mask=edge_mask, - ) + if not compact: + if layout.edge_capacity is not None: + raise NotImplementedError( + "shape-static from_dense_quartet pads to E=nf*nloc*nsel; " + "edge_capacity unsupported here" + ) + # (E,) flat grids, E = nf*nloc*nsel, row-major (frame, center, slot) + ff = xp.reshape( + xp.broadcast_to( + xp.reshape(xp.arange(nf, dtype=xp.int64, device=dev), (nf, 1, 1)), + (nf, nloc, nsel), + ), + (-1,), + ) + center = xp.reshape( + xp.broadcast_to( + xp.reshape(xp.arange(nloc, dtype=xp.int64, device=dev), (1, nloc, 1)), + (nf, nloc, nsel), + ), + (-1,), + ) + nl = xp.reshape(nlist, (-1,)) # neighbor ext idx or -1 + valid = nl >= 0 # (E,) bool <-- the mask + j_safe = xp.where(valid, nl, xp.zeros_like(nl)) # clamp -1 -> 0 (avoid OOB) + ec_flat = xp.reshape(extended_coord, (nf * nall, 3)) + map_flat = xp.reshape(mapping, (nf * nall,)) + g_nei = ff * nall + j_safe + g_cen = ff * nall + center + src_local = xp.take(map_flat, g_nei, axis=0) + edge_vec = xp.take(ec_flat, g_nei, axis=0) - xp.take(ec_flat, g_cen, axis=0) + edge_vec = edge_vec * xp.astype(valid[:, None], edge_vec.dtype) # zero invalid + src = xp.where(valid, ff * nloc + src_local, ff * nloc + center) # -> center + dst = ff * nloc + center + edge_index = xp.astype(xp.stack([src, dst], axis=0), xp.int64) + edge_mask = valid + n_node = xp.full((nf,), nloc, dtype=xp.int64, device=dev) + return NeighborGraph( + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) + else: + # COMPACT: drop invalid slots via nonzero (dynamic shape -> eager only, + # NOT jit/export-traceable) then pad/guard. + # per-slot (nf, nloc, nsel) index grids, flattened frame-major + ff_grid = xp.broadcast_to( + xp.reshape(xp.arange(nf, dtype=xp.int64, device=dev), (nf, 1, 1)), + (nf, nloc, nsel), + ) + center_grid = xp.broadcast_to( + xp.reshape(xp.arange(nloc, dtype=xp.int64, device=dev), (1, nloc, 1)), + (nf, nloc, nsel), + ) + ff_flat = xp.reshape(ff_grid, (-1,)) + center_flat = xp.reshape(center_grid, (-1,)) + nl_flat = xp.reshape(nlist, (-1,)) + keep = xp.reshape(xp.nonzero(nl_flat >= 0)[0], (-1,)) + ff_k = xp.take(ff_flat, keep, axis=0) + dst_local = xp.take(center_flat, keep, axis=0) # center index in [0, nloc) + j_ext = xp.take(nl_flat, keep, axis=0) # neighbor index in [0, nall) + # cross-frame gathers via flat (frame * nall + idx) indices; centers are + # the first nloc extended atoms (local atoms precede ghosts). + ec_flat = xp.reshape(extended_coord, (nf * nall, 3)) + map_flat = xp.reshape(mapping, (nf * nall,)) + g_nei = ff_k * nall + j_ext + g_cen = ff_k * nall + dst_local + src_local = xp.take(map_flat, g_nei, axis=0) # local owner of the neighbor + edge_vec = xp.take(ec_flat, g_nei, axis=0) - xp.take(ec_flat, g_cen, axis=0) + edge_index = xp.astype( + xp.stack([ff_k * nloc + src_local, ff_k * nloc + dst_local], axis=0), + xp.int64, + ) + edge_index, edge_vec, edge_mask = pad_and_guard_edges( + edge_index, edge_vec, layout.edge_capacity, layout.min_edges + ) + n_node = xp.full((nf,), nloc, dtype=xp.int64, device=dev) + return NeighborGraph( + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) def build_neighbor_graph( diff --git a/deepmd/dpmodel/utils/neighbor_graph/derivatives.py b/deepmd/dpmodel/utils/neighbor_graph/derivatives.py index 1c0bafc234..494e97a0c9 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/derivatives.py +++ b/deepmd/dpmodel/utils/neighbor_graph/derivatives.py @@ -4,7 +4,7 @@ The autograd that produces g_e (grad(E, edge_vec)) is wired in the torch/jax backend later; this pure-array-API assembly is shared by all backends. -Conventions (see memory/spec_unified_edge_nlist.md): +Conventions (see the unified edge-nlist design discussion, wanghan-iapcm/deepmd-kit#4): edge_vec_e = r_src - r_dst ; F_k = sum_{dst=k} g - sum_{src=k} g per-edge virial w_e = -g_e (x) edge_vec_e atom virial attributed FULL-TO-src (canonical TF==pt-legacy convention) @@ -13,16 +13,25 @@ Padding/guard edges (edge_mask == 0) are zeroed before any scatter. """ -import array_api_compat +from __future__ import ( + annotations, +) -from deepmd.dpmodel.array_api import ( - Array, +from typing import ( + TYPE_CHECKING, ) +import array_api_compat + from .segment import ( segment_sum, ) +if TYPE_CHECKING: + from deepmd.dpmodel.array_api import ( + Array, + ) + def edge_force_virial( g_e: Array, @@ -40,6 +49,15 @@ def edge_force_virial( Parameters ---------- + g_e + (E, 3) per-edge gradient ``dE/d(edge_vec)``. + edge_vec + (E, 3) per-edge displacement ``r_src - r_dst``; padding edges are zero. + edge_index + (2, E) ``[src, dst]`` node endpoints of each edge. + edge_mask + (E,) boolean valid-edge mask; padding/guard edges (``False``) are zeroed + before any scatter. n_node (nf,) per-frame REAL node counts. Real nodes occupy the compact prefix ``[0, sum(n_node))`` frame-major; ``nf = n_node.shape[0]``. @@ -62,8 +80,9 @@ def edge_force_virial( frame via the frame of their ``dst`` node. """ xp = array_api_compat.array_namespace(g_e) - n_real = int(xp.sum(n_node)) # real node count - n_out = n_real if node_capacity is None else int(node_capacity) # node-axis size + # node-axis size; when a static ``node_capacity`` is supplied (the jax/export + # path) short-circuit so we never call int() on the traced ``sum(n_node)``. + n_out = int(node_capacity) if node_capacity is not None else int(xp.sum(n_node)) nf = n_node.shape[0] # zero padding/guard contributions; cast mask to g's dtype (array-API pure, # CLAUDE.md mask-multiply guideline — avoids bool*float under array_api_strict) diff --git a/deepmd/dpmodel/utils/neighbor_graph/env.py b/deepmd/dpmodel/utils/neighbor_graph/env.py new file mode 100644 index 0000000000..55bbe1b02f --- /dev/null +++ b/deepmd/dpmodel/utils/neighbor_graph/env.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Per-edge environment-matrix 4-vector, the graph-native analogue of +EnvMat.call (deepmd/dpmodel/utils/env_mat.py). + +Computes, per edge, [1/r, dx/r^2, dy/r^2, dz/r^2] * smooth_weight, then +normalizes by (davg, dstd) indexed by the edge's CENTER (dst) atom type. +Stats are (ntypes, 4) — slot-independent — which is valid because +EnvMatStatSe tiles a single per-type vector across all nnei slots +(``np.tile(davgunit, [nsel, 1])``), so the slot axis carries no information. +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, +) + +import array_api_compat + +from deepmd.dpmodel.utils.env_mat import ( + compute_smooth_weight, +) +from deepmd.dpmodel.utils.safe_gradient import ( + safe_for_vector_norm, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.array_api import ( + Array, + ) + + +def edge_env_mat( + edge_vec: Array, + center_type: Array, + davg: Array, + dstd: Array, + rcut: float, + rcut_smth: float, + protection: float = 0.0, + edge_mask: Array | None = None, +) -> Array: + """Compute the per-edge environment-matrix 4-vector. + + Mirrors the math in ``_make_env_mat`` / ``EnvMat.call`` (env_mat.py) + for a single edge batch instead of a dense (nf, nloc, nnei) tensor. + + Parameters + ---------- + edge_vec + (E, 3) displacement vectors r_src - r_dst (neighbor minus center); + padding edges must have ``edge_vec = 0``. + center_type + (E,) int — atom type of the center (dst) atom for each edge. + davg + (ntypes, 4) per-center-type mean (slot-independent). + dstd + (ntypes, 4) per-center-type inverse-std (slot-independent). + rcut + Outer cutoff radius. + rcut_smth + Inner radius where the smooth switch begins. + protection + Small additive offset to avoid exact division-by-zero on + atoms that are numerically at the same position (default 0). + edge_mask + (E,) boolean valid-edge mask. When provided, the length of INVALID + (padding) edges has 1 added to it before adding ``protection`` --- + matching the dense ``_make_env_mat`` (``length = length + ~mask``), + which guards padding by mask rather than by a length threshold. + When ``None``, fall back to the ``length < 1e-10`` zero-guard + (back-compat for callers without a mask). + + Returns + ------- + Array + (E, 4) normalized environment-matrix vectors. + Padding edges (``edge_vec = 0``) produce nonzero values but are + masked by ``NeighborGraph.edge_mask`` downstream. + """ + xp = array_api_compat.array_namespace(edge_vec) + dev = array_api_compat.device(edge_vec) + + # ── geometry ─────────────────────────────────────────────────────────── + # (E, 1) lengths; safe_for_vector_norm returns 0 for zero vectors + length = safe_for_vector_norm(edge_vec, axis=-1, keepdims=True) + + # Guard against 1/0 on padding edges. When an edge_mask is provided, + # match the dense _make_env_mat exactly: add 1 to the length of INVALID + # (padding) edges by mask (not by a length threshold), so a real edge and + # a padding edge never share the same protection arithmetic. Otherwise + # fall back to the length<1e-10 zero-guard (back-compat). + if edge_mask is not None: + length = length + xp.astype(xp.logical_not(edge_mask)[:, None], length.dtype) + else: + length = xp.where(length < 1e-10, xp.ones_like(length), length) + + denom = length + protection # (E, 1) + t0 = 1.0 / denom # (E, 1) — radial component + t1 = edge_vec / (denom**2) # (E, 3) — angular components + + # ── smooth switch (same polynomial as compute_smooth_weight) ─────────── + # length has shape (E, 1); compute_smooth_weight broadcasts over any shape + sw = compute_smooth_weight(length, rcut_smth, rcut) # (E, 1) + + # ── raw (unnormalized) env-mat ───────────────────────────────────────── + em = xp.concat([t0, t1], axis=-1) * sw # (E, 4) + + # ── per-type normalization (indexed by center-atom type) ─────────────── + # davg/dstd must be asarray'd to ensure device placement when called with + # numpy stats on a torch/jax edge_vec. + avg = xp.take(xp.asarray(davg, device=dev), center_type, axis=0) # (E, 4) + std = xp.take(xp.asarray(dstd, device=dev), center_type, axis=0) # (E, 4) + + return (em - avg) / std diff --git a/deepmd/dpmodel/utils/neighbor_graph/from_ijs.py b/deepmd/dpmodel/utils/neighbor_graph/from_ijs.py new file mode 100644 index 0000000000..0136dd4bf4 --- /dev/null +++ b/deepmd/dpmodel/utils/neighbor_graph/from_ijs.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Sparse ``(i, j, S)`` edge-list converter to :class:`NeighborGraph`. + +``neighbor_graph_from_ijs`` is the canonical sparse converter: it takes an +already-built sparse edge list -- per-edge center ``i``, neighbor ``j`` (both +per-frame LOCAL indices in ``[0, nloc)``) and integer periodic-image shift ``S`` +-- and emits a :class:`NeighborGraph` whose ``edge_vec`` is recomputed +DIFFERENTIABLY from ``coord``/``box`` (it never trusts the builder's distance +vectors). It is the format-conversion step shared by every O(N) search backend +(ASE/vesin/LAMMPS): a backend searches, then hands its ``(i, j, S)`` here. + +Convention (matching :mod:`...graph`): ``edge_index = [src, dst]`` with +``src = j`` (neighbor's local owner), ``dst = i`` (center), and +``edge_vec = r_j + S @ box - r_i`` (neighbor image minus center). +""" + +from __future__ import ( + annotations, +) + +from typing import ( + TYPE_CHECKING, +) + +import array_api_compat + +from .graph import ( + GraphLayout, + NeighborGraph, + pad_and_guard_edges, +) + +if TYPE_CHECKING: + from deepmd.dpmodel.array_api import ( + Array, + ) + + +def neighbor_graph_from_ijs( + i: Array, + j: Array, + S: Array, + coord: Array, + box: Array | None, + nframe_id: Array, + nloc: int, + layout: GraphLayout | None = None, +) -> NeighborGraph: + """Convert a sparse ``(i, j, S)`` edge list into a :class:`NeighborGraph`. + + ``edge_vec`` is recomputed from ``coord``/``box`` (NOT from any distance vector + the search backend may carry), so it is a differentiable function of the input + coordinates and follows the graph convention exactly. + + Parameters + ---------- + i + (E,) int per-edge center, per-frame LOCAL index in ``[0, nloc)``. + j + (E,) int per-edge neighbor, per-frame LOCAL index in ``[0, nloc)``. + S + (E, 3) int periodic-image shift: the neighbor sits at ``coord[j] + S @ box``. + coord + (nf, nloc, 3) local coordinates. + box + (nf, 3, 3) simulation cell, or ``None`` for non-periodic (``S`` ignored). + nframe_id + (E,) int frame index of each edge. + nloc + number of local atoms per frame (used for the frame-major node offset). + layout + edge-axis length policy; ``None`` => dynamic (torch) with ``min_edges`` guards. + + Returns + ------- + NeighborGraph + ``edge_index = [j + nframe_id*nloc, i + nframe_id*nloc]`` (src=neighbor, + dst=center); ``edge_vec = coord[j] + S@box - coord[i]``; ``n_node`` is + ``nloc`` per frame. + """ + if layout is None: + layout = GraphLayout() + xp = array_api_compat.array_namespace(coord) + dev = array_api_compat.device(coord) + nf = coord.shape[0] + coord = xp.reshape(coord, (nf, nloc, 3)) + i = xp.astype(xp.asarray(i, device=dev), xp.int64) + j = xp.astype(xp.asarray(j, device=dev), xp.int64) + nframe_id = xp.astype(xp.asarray(nframe_id, device=dev), xp.int64) + # flat frame-major node indices + i_flat = i + nframe_id * nloc + j_flat = j + nframe_id * nloc + coord_flat = xp.reshape(coord, (nf * nloc, 3)) + r_i = xp.take(coord_flat, i_flat, axis=0) + r_j = xp.take(coord_flat, j_flat, axis=0) + edge_vec = r_j - r_i + if box is not None: + box = xp.asarray(box, device=dev) + box = xp.reshape(box, (nf, 3, 3)) + box_per_edge = xp.take(box, nframe_id, axis=0) # (E, 3, 3) + S = xp.astype(xp.asarray(S, device=dev), box.dtype) + # S @ box per edge via broadcast sum (NEVER np.einsum, which breaks on torch): + # shift[e, b] = sum_a S[e, a] * box[e, a, b] + shift = xp.sum(S[:, :, None] * box_per_edge, axis=1) # (E, 3) + edge_vec = edge_vec + shift + edge_index = xp.stack([j_flat, i_flat], axis=0) + edge_index, edge_vec, edge_mask = pad_and_guard_edges( + edge_index, edge_vec, layout.edge_capacity, layout.min_edges + ) + n_node = xp.full((nf,), nloc, dtype=xp.int64, device=dev) + return NeighborGraph( + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) diff --git a/deepmd/dpmodel/utils/neighbor_graph/graph.py b/deepmd/dpmodel/utils/neighbor_graph/graph.py index 232145bda0..e527a84bf0 100644 --- a/deepmd/dpmodel/utils/neighbor_graph/graph.py +++ b/deepmd/dpmodel/utils/neighbor_graph/graph.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Backend-agnostic edge-graph neighbor-list contract (NeighborGraph) and its -length policy (GraphLayout). See memory/spec_unified_edge_nlist.md. +length policy (GraphLayout). See the design discussion wanghan-iapcm/deepmd-kit#4. Node validity (real vs padding) is NOT a stored field: it is derived as ``arange(N) < sum(n_node)`` because ``n_node`` already encodes the real-node @@ -75,11 +75,32 @@ def pad_and_guard_edges( """Append padding/guard edges as a contiguous suffix and build edge_mask. Real edges (``edge_index``/``edge_vec``) stay at the front (compact layout). - - ``capacity is None`` (torch dynamic): append exactly ``min_edges`` masked - dummy edges so the edge axis has a known lower bound and shape-stable - guards for export. - - ``capacity`` set (jax static): pad to ``E_max = capacity``; raise on overflow. Dummy edges point at node ``pad_value`` (in-range) with zero ``edge_vec``. + + Parameters + ---------- + edge_index + (2, E_real) ``[src, dst]`` node endpoints of the real edges. + edge_vec + (E_real, 3) per-edge displacement of the real edges. + capacity + Target edge-axis length ``E_max``. ``None`` (torch dynamic) appends + exactly ``min_edges`` masked dummy edges so the axis has a known lower + bound and shape-stable guards for export; an int (jax static) pads to + ``E_max = capacity`` and raises ``ValueError`` on overflow. + min_edges + Number of dummy edges appended when ``capacity is None``. + pad_value + Node index the dummy edges point at (must be in range). + + Returns + ------- + edge_index + (2, target) padded edge endpoints. + edge_vec + (target, 3) padded edge displacements (dummy rows zero). + edge_mask + (target,) boolean mask, ``True`` for the real-edge prefix. """ xp = array_api_compat.array_namespace(edge_index) dev = array_api_compat.device(edge_index) @@ -102,11 +123,62 @@ def pad_and_guard_edges( return ei, ev, edge_mask +def frame_id_from_n_node(n_node: Array, n_total: int | None = None) -> Array: + """Node->frame map for a flat node axis: ``repeat(arange(nf), n_node)``. + + Implemented via ``searchsorted(cumulative_sum(n_node), arange(N), side="right")`` + -- the same primitives used in ``edge_force_virial`` for per-frame virial. + + Parameters + ---------- + n_node + Per-frame node counts. Shape ``(nf,)``. + n_total + Size of the (possibly padded) flat node axis ``N``. ``None`` (the + numpy/eager default) falls back to ``int(sum(n_node))``; pass a STATIC + value to keep the function trace-friendly under jax.jit / export, where + ``int()`` on the traced sum is not allowed (mirrors + :func:`node_validity_mask`). Padding nodes ``[sum(n_node), n_total)`` + are CLAMPED to the last frame (``nf - 1``) so a downstream + ``segment_sum(..., num_segments=nf)`` stays in range; they carry no real + edge, so this assignment is unused downstream. + + Returns + ------- + frame_id + Frame index of each flat node, compact-prefix frame-major. + Shape ``(n_total,)`` int64 (``n_total = sum(n_node)`` when not padded). + """ + xp = array_api_compat.array_namespace(n_node) + dev = array_api_compat.device(n_node) + if n_total is None: + n_total = int(xp.sum(n_node)) + nf = n_node.shape[0] + idx = xp.arange(n_total, dtype=n_node.dtype, device=dev) + boundaries = xp.cumulative_sum(n_node) # (nf,) upper bounds, exclusive + frame_id = xp.astype(xp.searchsorted(boundaries, idx, side="right"), xp.int64) + # padding nodes (idx >= sum(n_node)) land at frame ``nf`` (OOB); clamp them to + # the last real frame so the per-frame scatter never indexes out of range. + return xp.minimum(frame_id, xp.asarray(nf - 1, dtype=xp.int64, device=dev)) + + def node_validity_mask(n_node: Array, n_total: int) -> Array: """Derive the (n_total,) real-vs-padding node mask from per-frame counts. Compact-prefix layout: the first ``sum(n_node)`` nodes are real, the rest are padding. jit-safe (no Python ``int`` cast on the traced sum). + + Parameters + ---------- + n_node + (nf,) per-frame REAL node counts. + n_total + Size of the (possibly padded) flat node axis ``N``. + + Returns + ------- + mask + (n_total,) boolean mask, ``True`` for the real-node compact prefix. """ xp = array_api_compat.array_namespace(n_node) idx = xp.arange(n_total, dtype=n_node.dtype, device=array_api_compat.device(n_node)) diff --git a/deepmd/pt_expt/model/edge_transform_output.py b/deepmd/pt_expt/model/edge_transform_output.py new file mode 100644 index 0000000000..565e155157 --- /dev/null +++ b/deepmd/pt_expt/model/edge_transform_output.py @@ -0,0 +1,197 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Autograd assembly: graph energy -> force/virial/atom_virial via grad(E, edge_vec). + +torch-only. The pure-array scatter (edge_force_virial) is shared with dpmodel; +this module supplies the single backward pass that produces g_e = dE/d(edge_vec). +""" + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + get_deriv_name, + get_reduce_name, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + edge_force_virial, + frame_id_from_n_node, + segment_sum, +) +from deepmd.pt.utils import ( + env, +) + + +def edge_energy_deriv( + energy: torch.Tensor, + edge_vec: torch.Tensor, + edge_index: torch.Tensor, + edge_mask: torch.Tensor, + n_node: torch.Tensor, + do_atomic_virial: bool = False, + create_graph: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: + """Return (force, atom_virial_or_None, virial) from a graph energy. + + g_e = dE/d(edge_vec) via one torch.autograd.grad, then the shared + edge_force_virial scatter. + + Parameters + ---------- + energy + the reduced per-frame energy to differentiate. ``(nf,)`` (or scalar). + edge_vec + (E, 3) per-edge displacement; the autograd leaf of ``energy``. + edge_index + (2, E) ``[src, dst]`` edge endpoints. + edge_mask + (E,) valid-edge mask. + n_node + (nf,) per-frame node counts. + do_atomic_virial + whether to materialize the per-atom virial (else ``None`` is returned). + create_graph + whether the backward retains a graph (training, for second-order grad). + + Returns + ------- + force + (N, 3) per-node force. + atom_virial + (N, 3, 3) per-node virial when ``do_atomic_virial`` else ``None``. + virial + (nf, 3, 3) per-frame virial (always computed). + """ + (g_e,) = torch.autograd.grad( + energy.sum() if energy.dim() else energy, + edge_vec, + create_graph=create_graph, + retain_graph=True, + ) + force, atom_virial, virial = edge_force_virial( + g_e, edge_vec, edge_index, edge_mask, n_node + ) + return force, (atom_virial if do_atomic_virial else None), virial + + +def fit_output_to_model_output_graph( + fit_ret: dict[str, torch.Tensor], + fit_output_def: FittingOutputDef, + graph: NeighborGraph, + do_atomic_virial: bool = False, + create_graph: bool = True, + mask: torch.Tensor | None = None, +) -> dict[str, torch.Tensor]: + """Graph analogue of the dense pt_expt ``fit_output_to_model_output``. + + OUTPUT-AGNOSTIC: reduces EVERY reducible fitting output (cast to energy + precision, summed/averaged per frame via ``segment_sum`` over ``frame_id``) + and, for every reducible + ``r_differentiable`` output, assembles + per-component force / virial / (optional) atom-virial from + :func:`edge_energy_deriv` (one ``grad`` w.r.t. ``edge_vec`` per scalar + component, then the shared full-to-``src`` scatter). + + All per-atom outputs stay FLAT with leading dimension ``N = sum(n_node)``: + ```` is ``(N, *shape)``, ``_derv_r`` is ``(N, *shape, 3)``, + ``_derv_c`` is ``(N, *shape, 9)``. Per-frame reductions have leading + dimension ``nf``: ``_redu`` is ``(nf, *shape)``, + ``_derv_c_redu`` is ``(nf, *shape, 9)``. + + Parameters + ---------- + fit_ret + Raw flat fitting output, ``(N, *shape)`` per key (``N = sum(n_node)``). + fit_output_def + The fitting output definition. + graph + the :class:`~deepmd.dpmodel.utils.neighbor_graph.NeighborGraph`. Its + ``edge_vec`` MUST be the autograd leaf for ``fit_ret`` (the force backward + differentiates the reduced energy w.r.t. it); ``edge_index``/``edge_mask`` + define the scatter, ``n_node`` the node->frame map. + do_atomic_virial + Whether to also assemble the per-atom virial ``_derv_c``. + create_graph + Whether the backward retains a graph (training). + mask + (N,) flat realness mask; used only for intensive-output reduction. + + Returns + ------- + model_ret + ``fit_ret`` plus, for each reducible key, the per-frame reduction + ``_redu`` ``(nf, *shape)`` and -- for ``r_differentiable`` keys -- + the FLAT per-atom force ``_derv_r`` ``(N, *shape, 3)``, the + per-frame virial ``_derv_c_redu`` ``(nf, *shape, 9)``, and (when + ``do_atomic_virial``) the per-atom virial ``_derv_c`` + ``(N, *shape, 9)``. + """ + edge_vec = graph.edge_vec + edge_index = graph.edge_index + edge_mask = graph.edge_mask + n_node = graph.n_node + redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION + nf = int(n_node.shape[0]) + N = int(n_node.sum()) + frame_id = frame_id_from_n_node(n_node) # (N,) int64 frame index per atom + model_ret: dict[str, torch.Tensor] = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + shap = vdef.shape + if not vdef.reducible: + continue + kk_redu = get_reduce_name(kk) + # segment_sum reduces axis 0 (the flat atom axis) per frame + vv_e = vv.to(redu_prec) # (N, *shape) + redu = segment_sum(vv_e, frame_id, nf) # (nf, *shape) + if vdef.intensive: + if mask is not None: + # real-atom count per frame: segment_sum of the mask + cnt = segment_sum(mask.to(redu_prec), frame_id, nf) # (nf,) + # broadcast cnt to (nf, 1, ..., 1) to match redu shape + cnt = cnt.reshape(nf, *([1] * (redu.ndim - 1))) + else: + cnt = n_node.to(redu_prec).reshape(nf, *([1] * (redu.ndim - 1))) + redu = redu / cnt + model_ret[kk_redu] = redu + if not vdef.r_differentiable: + continue + kk_derv_r, kk_derv_c = get_deriv_name(kk) + size = 1 + for ii in shap: + size *= ii + # split the reduced output into ``size`` per-frame scalar components. + svv = model_ret[kk_redu].reshape(nf, size) + ff_list: list[torch.Tensor] = [] + av_list: list[torch.Tensor] = [] + vir_list: list[torch.Tensor] = [] + for c in range(size): + force, atom_vir, vir = edge_energy_deriv( + svv[:, c], + edge_vec, + edge_index, + edge_mask, + n_node, + do_atomic_virial=(vdef.c_differentiable and do_atomic_virial), + create_graph=create_graph, + ) + # force (N, 3) -> (N, 1, 3) [flat; caller unravels at I/O boundary] + ff_list.append(force.reshape(N, 1, 3)) + if vdef.c_differentiable: + # virial (nf, 3, 3) -> (nf, 1, 9) + vir_list.append(vir.reshape(nf, 1, 9)) + if do_atomic_virial: + assert atom_vir is not None + # atom_virial (N, 3, 3) -> (N, 1, 9) [flat] + av_list.append(atom_vir.reshape(N, 1, 9)) + # (N, size, 3) -> (N, *shape, 3) + model_ret[kk_derv_r] = torch.cat(ff_list, dim=-2).reshape([N, *shap, 3]) + if vdef.c_differentiable: + # (nf, size, 9) -> (nf, *shape, 9) + model_ret[kk_derv_c + "_redu"] = torch.cat(vir_list, dim=-2).reshape( + [nf, *shap, 9] + ) + if do_atomic_virial: + # (N, size, 9) -> (N, *shape, 9) + model_ret[kk_derv_c] = torch.cat(av_list, dim=-2).reshape([N, *shap, 9]) + return model_ret diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 878ed21a38..50ede240e4 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -24,6 +24,9 @@ torch_module, ) +from .edge_transform_output import ( + fit_output_to_model_output_graph, +) from .transform_output import ( fit_output_to_model_output, ) @@ -277,6 +280,215 @@ def forward_common_lower( """Forward common lower delegates to call_common_lower().""" return self.call_common_lower(*args, **kwargs) + def forward_common_lower_graph( + self, + atype: torch.Tensor, + n_node: torch.Tensor, + edge_index: torch.Tensor, + edge_vec: torch.Tensor, + edge_mask: torch.Tensor, + do_atomic_virial: bool = False, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Graph-native lower with autograd force/virial (PR-A: dpa1 ``attn_layer==0``). + + OUTPUT-AGNOSTIC: runs the graph descriptor + fitting forward with + ``edge_vec`` as the autograd leaf (via the inherited + :meth:`forward_common_atomic_graph`), then routes the raw flat + ``atomic_ret`` through :func:`fit_output_to_model_output_graph`, which + reduces EVERY reducible output via ``segment_sum`` and assembles + force / per-frame virial / (optional) atom-virial for every + ``r_differentiable`` output from a backward pass w.r.t. ``edge_vec`` + (the shared full-to-``src`` scatter). This makes any fitting + (energy/dos/dipole/polar/property/...) flow through the graph path + with no change on the fitting side. + + All per-atom outputs stay FLAT with leading dimension + ``N = sum(n_node)``; per-frame reductions have leading dimension + ``nf``. Callers that need rectangular ``(nf, nloc, *)`` output + (e.g. :meth:`_call_common_graph` where ``atype`` is rectangular) + unravel at the public I/O boundary. + + Parameters + ---------- + atype + (N,) flat LOCAL atom types, ``N == sum(n_node)``. + n_node + (nf,) per-frame local atom counts. + edge_index + (2, E) ``[src, dst]`` edge endpoints (flat local indices). + edge_vec + (E, 3) neighbor-minus-center edge vectors. + edge_mask + (E,) valid-edge mask. + do_atomic_virial + Whether to also return the per-atom virial ``_derv_c``. + fparam + Frame parameter, ``(nf, ndf)``. + aparam + Atomic parameter, ``(nf, nloc, nda)``. + charge_spin + charge/spin conditioning. Ignored in PR-A; accepted for ABI + stability with charge/spin-conditioned descriptors. + + Returns + ------- + dict + Flat model dict: ```` (N, *shape), ``_redu`` + (nf, *shape), and -- for ``r_differentiable`` outputs -- + ``_derv_r`` (N, *shape, 3), ``_derv_c_redu`` + (nf, *shape, 9), and -- when ``do_atomic_virial`` -- + ``_derv_c`` (N, *shape, 9). + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + ) + + # make edge_vec the autograd leaf for the energy backward + edge_vec = edge_vec.detach().requires_grad_(True) + graph = NeighborGraph( + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) + atomic_ret = self.atomic_model.forward_common_atomic_graph( + graph, + atype, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + ) + # ``forward_common_atomic_graph`` returns flat ``(N, *)`` output. + # Pass directly to the flat-N transform; no rectangular reshape needed. + return fit_output_to_model_output_graph( + atomic_ret, + self.atomic_output_def(), + graph, + do_atomic_virial=do_atomic_virial, + create_graph=self.training, + mask=atomic_ret["mask"] if "mask" in atomic_ret else None, + ) + + def _resolve_graph_method( + self, neighbor_graph_method: str | None + ) -> str | None: + """pt_expt default-flip (decision #17): ``None`` => carry-all graph for + graph-eligible mixed_types descriptors, else dense. Unlike dpmodel/jax, + pt_expt has the autograd ``forward_common_lower_graph`` that produces + force/virial on the graph, so the graph can be the DEFAULT here. + ``"legacy"`` forces dense; explicit ``"dense"``/``"ase"`` force the graph. + + Parameters + ---------- + neighbor_graph_method + The user-requested method: ``None`` (default-flip), ``"legacy"`` + (force dense), or ``"dense"``/``"ase"`` (force the graph builder). + + Returns + ------- + method + The resolved method passed to :meth:`_call_common_graph`, or + ``None`` to take the dense path. + """ + if neighbor_graph_method == "legacy": + return None + if neighbor_graph_method is not None: + return neighbor_graph_method + # Linear/ZBL atomic models have no single ``descriptor`` -> dense. + descriptor = getattr(self.atomic_model, "descriptor", None) + uses_graph_lower = getattr(descriptor, "uses_graph_lower", lambda: False) + if self.mixed_types() and uses_graph_lower(): + return "dense" + return None + + def _call_common_graph( + self, + cc: torch.Tensor, + atype: torch.Tensor, + bb: torch.Tensor | None, + fp: torch.Tensor | None, + ap: torch.Tensor | None, + method: str, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + """Carry-all graph forward with autograd force/virial (pt_expt override). + + Builds the carry-all :class:`NeighborGraph` in TORCH (the array-API + builder runs natively and yields a differentiable ``edge_vec``), then + routes through :meth:`forward_common_lower_graph` so force / virial / + (optional) atom-virial are produced via autograd. + + Parameters + ---------- + cc + coordinates. nf x nloc x 3 (or nf x (nloc x 3)) + atype + the atom types. nf x nloc + bb + the simulation cell. nf x 3 x 3, or ``None`` for non-periodic. + fp + the frame parameter. nf x ndf + ap + the atomic parameter. nf x nloc x nda + method + the carry-all builder, ``"dense"`` or ``"ase"``. + do_atomic_virial + whether to calculate the atomic virial. + + Returns + ------- + model_predict + the standard model dict using the SAME internal key names as the + legacy dense :meth:`call_common` output (``energy``, + ``energy_redu``, ``energy_derv_r``, ``energy_derv_c_redu``, and + ``energy_derv_c`` when ``do_atomic_virial``). + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + build_neighbor_graph_ase, + ) + + rcut = self.get_rcut() + if method == "dense": + ng = build_neighbor_graph(cc, atype, bb, rcut) + elif method == "ase": + ng = build_neighbor_graph_ase(cc, atype, bb, rcut) + else: + raise ValueError( + f"unknown neighbor_graph_method {method!r}; use 'dense' or 'ase'" + ) + nf, nloc = atype.shape[:2] + atype_flat = atype.reshape(nf * nloc) + model_predict = self.forward_common_lower_graph( + atype_flat, + ng.n_node, + ng.edge_index, + ng.edge_vec, + ng.edge_mask, + do_atomic_virial=do_atomic_virial, + fparam=fp, + aparam=ap, + ) + # ``forward_common_lower_graph`` returns flat ``(N, *)`` per-atom + # outputs (N = nf * nloc for a carry-all rectangular graph). + # Unravel to rectangular ``(nf, nloc, *)`` at the public I/O boundary + # so that callers receive the same shape as the dense ``call_common``. + N = nf * nloc + # public call_common always passes rectangular (nf,nloc) coord/atype (N == nf*nloc), so this unravel always applies; ragged graphs reach call_lower_graph/forward_common_lower_graph directly (no unravel) and stay flat (N,*). + for k in list(model_predict.keys()): + v = model_predict[k] + # per-frame reduced keys (..._redu) keep their (nf, *) shape; only node-level (N,*) keys unravel — guards the nloc==1 case where N == nf. + if ( + v is not None + and not k.endswith("_redu") + and v.shape[:1] == torch.Size([N]) + ): + model_predict[k] = v.reshape(nf, nloc, *v.shape[1:]) + return model_predict + def forward_common_atomic( self, extended_coord: torch.Tensor, diff --git a/source/tests/common/dpmodel/case_single_frame_with_nlist.py b/source/tests/common/dpmodel/case_single_frame_with_nlist.py index 7ec92a1de1..3995bd20f4 100644 --- a/source/tests/common/dpmodel/case_single_frame_with_nlist.py +++ b/source/tests/common/dpmodel/case_single_frame_with_nlist.py @@ -69,7 +69,7 @@ def setUp(self) -> None: [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) self.mapping = np.concatenate( - [self.mapping, self.mapping[:, self.perm]], axis=0 + [self.mapping, inv_perm[self.mapping[:, self.perm]]], axis=0 ) # permute the nlist nlist1 = self.nlist[:, self.perm[: self.nloc], :] diff --git a/source/tests/common/dpmodel/test_call_lower_graph.py b/source/tests/common/dpmodel/test_call_lower_graph.py new file mode 100644 index 0000000000..29dd87c55a --- /dev/null +++ b/source/tests/common/dpmodel/test_call_lower_graph.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Energy-level parity between the graph-native model lower +(``CM.call_lower_graph``) and the dense ``EnergyModel.call_lower`` on the SAME +neighbor list (regime-1: ``from_dense_quartet`` reproduces the nlist neighbors). + +PR-A is dpa1(attn_layer=0) energy-only; force/virial come from pt_expt autograd +in a later task, so this only checks ``energy`` (reduced per-frame) and +``atom_energy`` (per-atom). +""" + +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, +) +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +class TestCallLowerGraph(unittest.TestCase): + def _make_model(self): + ds = DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=[30], + ntypes=2, + attn_layer=0, + axis_neuron=2, + neuron=[6, 12], + ) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + return EnergyModel(ds, ft, type_map=["foo", "bar"]) + + def setUp(self) -> None: + rng = np.random.default_rng(2) + self.nloc = 4 + self.coord = rng.normal(size=(1, self.nloc, 3)) * 1.5 + self.atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + + def test_graph_lower_matches_dense_lower(self) -> None: + """Graph model lower energy/atom_energy match the dense lower on the same nlist.""" + model = self._make_model() + ( + ext_coord, + ext_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + model.get_rcut(), + model.get_sel(), + mixed_types=model.mixed_types(), + box=None, + ) + + # dense ``call_common_lower`` returns the INTERNAL model_output_def keys + # (``energy`` per-atom, ``energy_redu`` reduced), matching the + # OUTPUT-AGNOSTIC graph lower. + dense = model.call_common_lower(ext_coord, ext_atype, nlist, mapping) + + ng = from_dense_quartet(ext_coord, nlist, mapping) + nloc = nlist.shape[1] + out = model.call_lower_graph( + atype=ext_atype.reshape(-1)[:nloc], + n_node=ng.n_node, + edge_index=ng.edge_index, + edge_vec=ng.edge_vec, + edge_mask=ng.edge_mask, + ) + + # reduced per-frame energy + np.testing.assert_allclose( + out["energy_redu"], dense["energy_redu"], rtol=1e-12, atol=1e-12 + ) + # per-atom energy + np.testing.assert_allclose( + out["energy"].reshape(dense["energy"].shape), + dense["energy"], + rtol=1e-12, + atol=1e-12, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/common/dpmodel/test_dpa1_call_graph_block.py b/source/tests/common/dpmodel/test_dpa1_call_graph_block.py new file mode 100644 index 0000000000..e8930101dd --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa1_call_graph_block.py @@ -0,0 +1,119 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Bit-exact parity between the graph-native ``DescrptBlockSeAtten.call_graph`` +(attn_layer=0) and the legacy dense ``DescrptBlockSeAtten.call`` on the SAME +neighbor list, for binding AND non-binding ``sel``. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, +) +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +class TestDpa1BlockCallGraph: + def _make(self, sel, type_one_side=False): + return DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=sel, + ntypes=2, + attn_layer=0, + axis_neuron=2, + neuron=[6, 12], + type_one_side=type_one_side, + ) + + def setup_method(self) -> None: + rng = np.random.default_rng(1) + self.nloc = 4 + self.coord = rng.normal(size=(1, self.nloc, 3)) * 1.5 + self.atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + + @pytest.mark.parametrize("type_one_side", [False, True]) # tebd concat branch + @pytest.mark.parametrize("sel", [[20], [3]]) # non-binding AND binding + def test_block_graph_equals_dense_any_sel(self, sel, type_one_side) -> None: + """Graph block output is bit-exact with the dense block on the same nlist. + + ``type_one_side`` toggles the concat branch in the block: when True the + per-edge feature concatenates only the NEIGHBOR tebd (no center tebd), + so both the graph and dense paths must agree for either branch. + """ + dd = self._make(sel, type_one_side=type_one_side) + blk = dd.se_atten + # build the dense nlist exactly as the descriptor would + ( + ext_coord, + ext_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=None, + ) + # type embedding as both paths use it + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + dense_g, *_ = blk.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph_g, _rot_mat = blk.call_graph( + ng, + np.reshape(ext_atype, (-1,)), + type_embedding=tebd, + ) + np.testing.assert_allclose( + graph_g.reshape(dense_g.shape), + dense_g, + rtol=1e-12, + atol=1e-12, + ) + + def test_attn_layer_gt0_raises(self) -> None: + """The graph block kernel fail-fasts for attn_layer > 0 (unsupported).""" + dd = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[20], ntypes=2, attn_layer=2) + with pytest.raises(NotImplementedError): + dd.se_atten.call_graph(None, np.array([0], dtype=np.int64)) + + def test_exclude_types_raises(self) -> None: + """The graph block kernel fail-fasts for exclude_types (not yet applied).""" + # the graph path does not yet apply type exclusion; it must fail-fast + # rather than silently diverge from the dense path (which masks edges). + dd = DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=[20], + ntypes=2, + attn_layer=0, + exclude_types=[(0, 1)], + ) + ng = from_dense_quartet( + self.coord, + -np.ones((1, self.nloc, 1), dtype=np.int64), # any graph; guard fires first + np.arange(self.nloc, dtype=np.int64)[None], + ) + with pytest.raises(NotImplementedError): + dd.se_atten.call_graph( + ng, self.atype.reshape(-1), type_embedding=dd.type_embedding.call() + ) diff --git a/source/tests/common/dpmodel/test_dpa1_call_graph_descriptor.py b/source/tests/common/dpmodel/test_dpa1_call_graph_descriptor.py new file mode 100644 index 0000000000..dc1d51da91 --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa1_call_graph_descriptor.py @@ -0,0 +1,202 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Full 5-tuple ABI parity between the graph-routed ``DescrptDPA1.call`` +(attn_layer=0, which now goes ``from_dense_quartet -> call_graph``) and the +legacy dense descriptor output captured BEFORE the swap, for binding AND +non-binding ``sel``. + +The dense reference is reconstructed by calling the BLOCK directly +(``dd.se_atten.call``) and applying the descriptor-level ``concat_output_tebd`` +step by hand (mirroring dpa1.py), because ``dd.call`` itself now routes through +the graph for ``attn_layer == 0``. +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +class TestDpa1DescriptorCallGraph: + def _make(self, sel): + return DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=sel, + ntypes=2, + attn_layer=0, + axis_neuron=2, + neuron=[6, 12], + ) + + def setup_method(self) -> None: + rng = np.random.default_rng(2) + self.nloc = 4 + self.coord = rng.normal(size=(1, self.nloc, 3)) * 1.5 + self.atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + + def _dense_reference(self, dd, ext_coord, ext_atype, nlist): + """Reconstruct the original dense descriptor 5-tuple (pre-swap).""" + tebd = dd.type_embedding.call() + nf, nall = ext_atype.shape + atype_embd_ext = np.reshape( + np.take(tebd, np.reshape(ext_atype, (-1,)), axis=0), + (nf, nall, dd.tebd_dim), + ) + grrg, g2, h2, rot_mat, sw = dd.se_atten.call( + nlist, + ext_coord, + ext_atype, + atype_embd_ext=atype_embd_ext, + mapping=None, + type_embedding=tebd, + ) + nloc = nlist.shape[1] + # descriptor-level concat_output_tebd (mirror dpa1.py) + atype_embd = atype_embd_ext[:, :nloc, :] + if dd.concat_output_tebd: + grrg = np.concatenate( + [grrg, np.reshape(atype_embd, (nf, nloc, dd.tebd_dim))], axis=-1 + ) + return grrg, rot_mat, None, None, sw + + @pytest.mark.parametrize("sel", [[30], [4]]) # non-binding AND binding + def test_descriptor_graph_equals_dense_full_tuple(self, sel) -> None: + """Graph-routed dd.call() returns the identical dense 5-tuple ABI.""" + dd = self._make(sel) + ( + ext_coord, + ext_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=None, + ) + # dense reference captured via the block (pre-swap behaviour) + ref = self._dense_reference(dd, ext_coord, ext_atype, nlist) + # the swapped public ABI: routes through the graph + out = dd.call(ext_coord, ext_atype, nlist, mapping=mapping) + assert len(out) == 5 + # grrg + np.testing.assert_allclose(out[0], ref[0], rtol=1e-12, atol=1e-12) + # rot_mat + np.testing.assert_allclose(out[1], ref[1], rtol=1e-12, atol=1e-12) + # positions [2], [3] are always None for this descriptor + assert out[2] is None + assert out[3] is None + # sw + np.testing.assert_allclose(out[4], ref[4], rtol=1e-12, atol=1e-12) + + @pytest.mark.parametrize( + "kwargs", + [ + {"tebd_input_mode": "strip"}, # strip tebd: graph unsupported -> dense + {"exclude_types": [(0, 1)]}, # type exclusion: graph unsupported -> dense + ], + ) + def test_ineligible_config_falls_back_to_dense(self, kwargs) -> None: + """attn_layer=0 configs the graph can't handle (strip tebd, exclude_types) + must report uses_graph_lower()=False and run the dense body without + raising (regression: Task-3 routing previously raised NotImplementedError). + """ + dd = DescrptDPA1( + rcut=4.0, rcut_smth=0.5, sel=[30], ntypes=2, attn_layer=0, **kwargs + ) + assert dd.uses_graph_lower() is False + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=None, + ) + out = dd.call(ext_coord, ext_atype, nlist, mapping=mapping) # must not raise + assert len(out) == 5 + + def test_eligible_no_mapping_with_ghosts_falls_back(self) -> None: + """An eligible (concat) attn_layer=0 descriptor called with mapping=None + on a PERIODIC system (nall > nloc ghosts) must fall back to the dense + body and match it (regression: the graph needs mapping for ghosts, the + identity-mapping default previously indexed out of range). + """ + dd = self._make([30]) + box = np.eye(3, dtype=np.float64)[None] * 6.0 + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=box, + ) + assert ext_atype.shape[1] > self.nloc # ghosts present + ref = self._dense_reference(dd, ext_coord, ext_atype, nlist) + out = dd.call(ext_coord, ext_atype, nlist, mapping=None) # must not IndexError + np.testing.assert_allclose(out[0], ref[0], rtol=1e-12, atol=1e-12) + + def test_single_rank_extension_keeps_type_invariant(self) -> None: + """The ghost-free graph types a neighbor as ``atype[mapping[neighbor]]`` + (its local owner). This is correct because a real single-rank extension + is type-consistent: ``extend_coord_with_ghosts`` tiles the local atype, so + ``atype_ext[k] == atype[mapping[k]]`` for every extended atom -- a ghost is + a periodic image of its owner and shares its type. This test pins that + invariant (an inconsistent ``mapping`` like the universal fixture's old + buggy permutation is NOT a valid single-rank extension) and confirms the + graph-routed ``call`` matches dense on the resulting quartet. + """ + dd = self._make([30]) + box = np.eye(3, dtype=np.float64)[None] * 6.0 + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=box, + ) + assert ext_atype.shape[1] > self.nloc # ghosts present + # the single-rank type invariant the ghost-free graph relies on + nf, nall = ext_atype.shape + for f in range(nf): + np.testing.assert_array_equal( + ext_atype[f], ext_atype[f][mapping[f]] + ) # atype_ext[k] == atype[mapping[k]] + ref = self._dense_reference(dd, ext_coord, ext_atype, nlist) + out = dd.call(ext_coord, ext_atype, nlist, mapping=mapping) + np.testing.assert_allclose(out[0], ref[0], rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(out[1], ref[1], rtol=1e-12, atol=1e-12) + np.testing.assert_allclose(out[4], ref[4], rtol=1e-12, atol=1e-12) + + def test_call_graph_returns_flat_node_axis(self) -> None: + """call_graph output lives on the flat (N,) node axis, not (nf, nloc).""" + from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, + ) + + dd = self._make([30]) + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + self.coord, + self.atype, + dd.get_rcut(), + dd.get_sel(), + mixed_types=dd.mixed_types(), + box=None, + ) + graph = from_dense_quartet(ext_coord, nlist, mapping, compact=True) + atype_local = self.atype.reshape(-1) + grrg, rot_mat = dd.call_graph( + graph, atype_local, type_embedding=dd.type_embedding.call() + ) + n = atype_local.shape[0] + assert grrg.shape[0] == n and grrg.ndim == 2 + assert rot_mat.shape[0] == n and rot_mat.ndim == 3 diff --git a/source/tests/common/dpmodel/test_dpa1_graph_model_energy.py b/source/tests/common/dpmodel/test_dpa1_graph_model_energy.py new file mode 100644 index 0000000000..37cb18808e --- /dev/null +++ b/source/tests/common/dpmodel/test_dpa1_graph_model_energy.py @@ -0,0 +1,269 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Carry-all graph energy forward via ``neighbor_graph_method`` (Option B). + +``CM.call_common`` routes a graph-eligible dpa1(``attn_layer == 0``) ENERGY +forward through the carry-all graph builder + ``call_lower_graph``. Per the +default-flip (decision #17) this is now the DEFAULT for eligible models; +``neighbor_graph_method="legacy"`` opts out to the truncating dense nlist path, +and ``"dense"``/``"ase"`` force the carry-all graph with that builder. + +Option-B behavior (decision #17 / spec_unified_edge_nlist): + +* non-binding ``sel`` -- the carry-all graph and the legacy dense path see the + SAME neighbors, so ``energy``/``atom_energy`` are EXACTLY equal; +* binding ``sel`` -- the carry-all graph keeps neighbors the legacy dense path + truncates, so energy DIFFERS (intended). +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) + + +def _make_model(sel): + ds = DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=sel, + ntypes=2, + attn_layer=0, + axis_neuron=2, + neuron=[6, 12], + ) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + return EnergyModel(ds, ft, type_map=["foo", "bar"]) + + +@pytest.mark.parametrize("method", ["dense", "ase"]) # in-tree carry-all AND ase +@pytest.mark.parametrize("periodic", [True, False]) # PBC and non-PBC +def test_energy_parity_non_binding_sel(method, periodic) -> None: + """At non-binding sel the carry-all graph and the dense path see the SAME + neighbors, so model energy is exactly equal. + """ + if method == "ase": + pytest.importorskip("ase") + rng = np.random.default_rng(0) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + box = None + if periodic: + # large box so the cell is essentially non-periodic for rcut=4.0 + box = np.eye(3).reshape(1, 9) * 20.0 + # LARGE sel -> non-binding (no truncation) + model = _make_model([200]) + + dense = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + graph = model.call_common(coord, atype, box, neighbor_graph_method=method) + + # dense energy keys: ``energy_redu`` (reduced, nf x 1) and ``energy`` + # (per-atom, nf x nloc x 1). Compare matching keys. + np.testing.assert_allclose( + graph["energy_redu"], dense["energy_redu"], rtol=1e-12, atol=1e-12 + ) + np.testing.assert_allclose(graph["energy"], dense["energy"], rtol=1e-12, atol=1e-12) + # mask must match the dense all-ones (nf, nloc) int mask + np.testing.assert_array_equal(graph["mask"], dense["mask"]) + + +@pytest.mark.parametrize("method", ["dense", "ase"]) # in-tree carry-all AND ase +def test_energy_parity_multiframe_periodic(method) -> None: + """Multi-frame (nf=2) PERIODIC energy parity at non-binding sel. + + Exercises the nf>1 graph reductions (``frame_id = repeat(arange(nf), + n_node)`` energy segment-sum and the ``frame * nloc`` node offsetting in + ``from_dense_quartet``) with DIFFERENT per-frame coordinates and a box. + At non-binding sel the carry-all graph and the dense path see the SAME + neighbors, so ``energy_redu``/``energy`` are EXACTLY equal per frame. + """ + if method == "ase": + pytest.importorskip("ase") + rng = np.random.default_rng(3) + nf, nloc = 2, 6 + # distinct coordinates per frame (not a broadcast of one frame) + coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]] * nf, dtype=np.int64) + # large box so the cell is essentially non-periodic for rcut=4.0 + box = np.tile(np.eye(3).reshape(1, 9) * 20.0, (nf, 1)) + # LARGE sel -> non-binding (no truncation) + model = _make_model([200]) + + dense = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + graph = model.call_common(coord, atype, box, neighbor_graph_method=method) + + np.testing.assert_allclose( + graph["energy_redu"], dense["energy_redu"], rtol=1e-12, atol=1e-12 + ) + np.testing.assert_allclose(graph["energy"], dense["energy"], rtol=1e-12, atol=1e-12) + np.testing.assert_array_equal(graph["mask"], dense["mask"]) + # the two frames must produce DIFFERENT energies (genuine nf>1 test, not a + # broadcast of one frame); they differ here by ~1e-5. + assert not np.array_equal(dense["energy_redu"][0], dense["energy_redu"][1]) + + +def test_virtual_atom_masked() -> None: + """A virtual atom (``atype == -1``) must contribute ZERO energy and have a + ZERO mask in the carry-all graph path, matching the dense path exactly. + + Regression for the leak where the graph path fed the raw (negative) atype + to the descriptor/fitting and stamped an all-ones mask, so virtual atoms + picked up a type-embedding + bias energy that the dense path masks out. + + Uses the in-tree ``"dense"`` builder, which shares the EXACT same quartet + neighbor list as the ``"legacy"`` dense path, so the parity is bit-tight + (the ``"ase"`` builder has its own near-cutoff boundary quirks, covered by + the other tests). + """ + method = "dense" + rng = np.random.default_rng(7) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + # one local virtual atom (atype == -1); the rest are real + atype = np.array([[0, 1, -1, 1, 0, 1]], dtype=np.int64) + box = None + # LARGE sel -> non-binding (no truncation) so dense == graph on real atoms + model = _make_model([200]) + + dense = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + graph = model.call_common(coord, atype, box, neighbor_graph_method=method) + + # graph energy (reduced + per-atom) must match the dense path exactly + np.testing.assert_allclose( + graph["energy_redu"], dense["energy_redu"], rtol=1e-12, atol=1e-12 + ) + np.testing.assert_allclose(graph["energy"], dense["energy"], rtol=1e-12, atol=1e-12) + # the virtual atom (index 2) contributes ZERO per-atom energy + np.testing.assert_allclose(graph["energy"][0, 2], 0.0, rtol=0, atol=0) + # mask must be 0 at the virtual atom and match the dense int mask + assert int(graph["mask"][0, 2]) == 0 + np.testing.assert_array_equal(graph["mask"], dense["mask"]) + expected_mask = np.array([[1, 1, 0, 1, 1, 1]], dtype=np.int32) + np.testing.assert_array_equal(graph["mask"], expected_mask) + + +def test_binding_sel_carries_more_than_dense() -> None: + """At binding sel the carry-all graph includes neighbors the dense path + truncates, so energy DIFFERS (intended, decision #17 / Option B). + """ + rng = np.random.default_rng(1) + nloc = 14 + # a dense cluster: many atoms well within rcut=4.0 of each other + coord = rng.normal(size=(1, nloc, 3)) * 0.8 + atype = np.array([[0, 1] * 7], dtype=np.int64) + box = None + # binding sel -> dense path truncates to 4 neighbors per atom + model = _make_model([4]) + + dense = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + graph = model.call_common(coord, atype, box, neighbor_graph_method="dense") + + assert not np.allclose(graph["energy_redu"], dense["energy_redu"]) + + +def test_neighbor_list_conflicts_with_graph_method() -> None: + """An explicit ``neighbor_list`` (a dense-nlist strategy) cannot be combined + with an explicit graph ``neighbor_graph_method``; passing both raises. + """ + from deepmd.dpmodel.utils.default_neighbor_list import ( + DefaultNeighborList, + ) + + rng = np.random.default_rng(2) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + model = _make_model([200]) + + with pytest.raises(ValueError, match="cannot be combined"): + model.call_common( + coord, + atype, + None, + neighbor_list=DefaultNeighborList(), + neighbor_graph_method="dense", + ) + + +def test_neighbor_list_takes_dense_route() -> None: + """Supplying ``neighbor_list`` (without an explicit graph method) takes the + dense route -- it is NOT silently ignored by the graph path. With the + default builder the result matches the legacy dense path exactly. + """ + from deepmd.dpmodel.utils.default_neighbor_list import ( + DefaultNeighborList, + ) + + rng = np.random.default_rng(3) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + box = np.eye(3).reshape(1, 9) * 20.0 + model = _make_model([200]) + + legacy = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + with_nlist = model.call_common( + coord, atype, box, neighbor_list=DefaultNeighborList() + ) + np.testing.assert_allclose( + with_nlist["energy_redu"], legacy["energy_redu"], rtol=1e-12, atol=1e-12 + ) + + +def test_graph_lower_invariant_to_charge_spin() -> None: + """dpa1 does NOT consume charge_spin (``get_dim_chg_spin() == 0``); the dense + atomic model passes ``None`` to the dpa1 descriptor regardless. The graph + lower accepts ``charge_spin`` only for ABI stability with charge/spin + descriptors (dpa3/dpa4, PR-G), so its output must be INVARIANT to it. + + Combined with the graph==dense parity at non-binding sel + (:func:`test_energy_parity_non_binding_sel`), this gives the full claim: + ``graph(charge_spin) == graph(None) == dense``. + """ + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + ) + + rng = np.random.default_rng(4) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + box = np.eye(3).reshape(1, 9) * 20.0 + model = _make_model([200]) + assert model.get_descriptor().get_dim_chg_spin() == 0 # dpa1: no chg/spin + + ng = build_neighbor_graph(coord, atype, box, model.get_rcut()) + atype_flat = atype.reshape(-1) + base = model.call_common_lower_graph( + atype_flat, ng.n_node, ng.edge_index, ng.edge_vec, ng.edge_mask + ) + # arbitrary non-None charge/spin -> must NOT change the dpa1 graph output + cs = np.array([[1.0, 2.0]], dtype=coord.dtype) + with_cs = model.call_common_lower_graph( + atype_flat, + ng.n_node, + ng.edge_index, + ng.edge_vec, + ng.edge_mask, + charge_spin=cs, + ) + assert set(base) == set(with_cs) + for k, v in base.items(): + if v is None: + assert with_cs[k] is None + else: + np.testing.assert_array_equal(with_cs[k], v) diff --git a/source/tests/common/dpmodel/test_edge_env_mat.py b/source/tests/common/dpmodel/test_edge_env_mat.py new file mode 100644 index 0000000000..fd7a51deeb --- /dev/null +++ b/source/tests/common/dpmodel/test_edge_env_mat.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import pytest + +from deepmd.dpmodel.utils.env_mat import ( + EnvMat, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + edge_env_mat, + from_dense_quartet, +) + + +class TestEdgeEnvMat(unittest.TestCase): + def setUp(self) -> None: + rng = np.random.default_rng(0) + self.rcut, self.rcut_smth = 4.0, 0.5 + self.nf, self.nloc, self.nnei = 1, 4, 6 + self.ext_coord = rng.normal(size=(self.nf, self.nloc, 3)) * 1.5 + self.atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + nlist = -np.ones((self.nf, self.nloc, self.nnei), dtype=np.int64) + for i in range(self.nloc): + ns = [j for j in range(self.nloc) if j != i][: self.nnei] + nlist[0, i, : len(ns)] = ns + self.nlist = nlist + self.mapping = np.arange(self.nloc, dtype=np.int64)[None] + self.nt = 2 + self.davg = rng.normal(size=(self.nt, 4)) + self.dstd = np.abs(rng.normal(size=(self.nt, 4))) + 0.5 + + def test_matches_envmat_slice(self) -> None: + davg_dense = np.broadcast_to( + self.davg[:, None, :], (self.nt, self.nnei, 4) + ).copy() + dstd_dense = np.broadcast_to( + self.dstd[:, None, :], (self.nt, self.nnei, 4) + ).copy() + dmat, _, _ = EnvMat(self.rcut, self.rcut_smth).call( + self.ext_coord, self.atype, self.nlist, davg_dense, dstd_dense + ) + + ng = from_dense_quartet(self.ext_coord, self.nlist, self.mapping) + center_type = self.atype.reshape(-1)[ng.edge_index[1]] + em = edge_env_mat( + ng.edge_vec, center_type, self.davg, self.dstd, self.rcut, self.rcut_smth + ) + + ei = ng.edge_index[:, ng.edge_mask] + for k in range(ei.shape[1]): + src, dst = int(ei[0, k]), int(ei[1, k]) + slot = list(self.nlist[0, dst]).index(src) + np.testing.assert_allclose( + em[k], dmat[0, dst, slot], rtol=1e-12, atol=1e-12 + ) + + def test_slot_broadcast_stats(self) -> None: + """After compute_input_stats, DescrptBlockSeAtten stats must be + slot-uniform: mean[:, k, :] == mean[:, 0, :] for all slots k. + This property is what allows edge_env_mat to use (ntypes, 4) stats + instead of (ntypes, nnei, 4) stats. + """ + from deepmd.dpmodel.descriptor import ( + DescrptDPA1, + ) + + rng = np.random.default_rng(42) + nloc = 6 + nf = 3 + rcut = 4.0 + rcut_smth = 0.5 + ntypes = 2 + sel = [6, 6] + + coord = rng.normal(size=(nf, nloc, 3)).astype(np.float64) + # scale so atoms are within rcut of each other + coord = coord * 1.2 + atype = np.array([[0, 1, 0, 1, 0, 1]] * nf, dtype=np.int64) + # non-periodic: box=None + data = [ + { + "coord": coord, + "atype": atype, + "box": None, + } + ] + + dpa1 = DescrptDPA1(rcut, rcut_smth, sel, ntypes=ntypes) + dpa1.compute_input_stats(data) + block = dpa1.se_atten + + nnei = block.nnei + for k in range(1, nnei): + np.testing.assert_allclose( + block.mean[:, 0, :], + block.mean[:, k, :], + rtol=0, + atol=0, + err_msg=f"mean slot {k} != slot 0", + ) + np.testing.assert_allclose( + block.stddev[:, 0, :], + block.stddev[:, k, :], + rtol=0, + atol=0, + err_msg=f"stddev slot {k} != slot 0", + ) + + +# ── Protection parity (Task 6) ──────────────────────────────────────────────── + + +@pytest.mark.parametrize("protection", [0.0, 1e-2]) # env-mat protection offset +def test_edge_env_mat_protection_parity(protection): + """edge_env_mat(protection=p, edge_mask=...) must match EnvMat(protection=p).call slice.""" + rng = np.random.default_rng(7) + rcut, rcut_smth = 4.0, 0.5 + nf, nloc, nnei = 1, 4, 6 + nt = 2 + + ext_coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + + # Build nlist with at most 3 valid neighbors per atom; slots 3-5 are padding (-1). + nlist = -np.ones((nf, nloc, nnei), dtype=np.int64) + for i in range(nloc): + ns = [j for j in range(nloc) if j != i][:nnei] + nlist[0, i, : len(ns)] = ns + mapping = np.arange(nloc, dtype=np.int64)[None] + + davg = rng.normal(size=(nt, 4)) + dstd = np.abs(rng.normal(size=(nt, 4))) + 0.5 + + # ── dense reference (EnvMat.call) ────────────────────────────────────── + davg_dense = np.broadcast_to(davg[:, None, :], (nt, nnei, 4)).copy() + dstd_dense = np.broadcast_to(dstd[:, None, :], (nt, nnei, 4)).copy() + dmat, _, _ = EnvMat(rcut, rcut_smth, protection=protection).call( + ext_coord, atype, nlist, davg_dense, dstd_dense + ) + + # ── graph path (edge_env_mat with edge_mask) ─────────────────────────── + ng = from_dense_quartet(ext_coord, nlist, mapping) + center_type = atype.reshape(-1)[ng.edge_index[1]] + em = edge_env_mat( + ng.edge_vec, + center_type, + davg, + dstd, + rcut, + rcut_smth, + protection=protection, + edge_mask=ng.edge_mask, + ) + + # Compare valid edges only, matched to their dense (frame, dst, slot) position. + ei = ng.edge_index[:, ng.edge_mask] + for k in range(ei.shape[1]): + src, dst = int(ei[0, k]), int(ei[1, k]) + slot = list(nlist[0, dst]).index(src) + np.testing.assert_allclose( + em[ng.edge_mask][k], + dmat[0, dst, slot], + rtol=1e-12, + atol=1e-12, + err_msg=f"protection={protection}, edge {k} (src={src}, dst={dst}, slot={slot})", + ) + + +def test_protection_actually_changes_env_mat() -> None: + """Guard against vacuous pass: verify that changing protection parameter + actually modifies the edge_env_mat output. If this test fails (outputs are + identical for protection=0 and protection=1e-2), it means protection is + silently ignored and the parity test cannot validate the protection path. + """ + rng = np.random.default_rng(7) + rcut, rcut_smth = 4.0, 0.5 + nf, nloc, nnei = 1, 4, 6 + nt = 2 + + ext_coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1]], dtype=np.int64) + + # Build nlist with at most 3 valid neighbors per atom; slots 3-5 are padding (-1). + nlist = -np.ones((nf, nloc, nnei), dtype=np.int64) + for i in range(nloc): + ns = [j for j in range(nloc) if j != i][:nnei] + nlist[0, i, : len(ns)] = ns + mapping = np.arange(nloc, dtype=np.int64)[None] + + davg = rng.normal(size=(nt, 4)) + dstd = np.abs(rng.normal(size=(nt, 4))) + 0.5 + + # Build the graph once + ng = from_dense_quartet(ext_coord, nlist, mapping) + center_type = atype.reshape(-1)[ng.edge_index[1]] + + # Evaluate edge_env_mat with two different protection values + em_p0 = edge_env_mat( + ng.edge_vec, + center_type, + davg, + dstd, + rcut, + rcut_smth, + protection=0.0, + edge_mask=ng.edge_mask, + ) + em_p1 = edge_env_mat( + ng.edge_vec, + center_type, + davg, + dstd, + rcut, + rcut_smth, + protection=1e-2, + edge_mask=ng.edge_mask, + ) + + # Assert they differ: protection must affect the output + assert not np.allclose(em_p0, em_p1), ( + "protection parameter has no effect on edge_env_mat output; " + "parity test cannot validate protection path" + ) diff --git a/source/tests/common/dpmodel/test_fitting_call_graph.py b/source/tests/common/dpmodel/test_fitting_call_graph.py new file mode 100644 index 0000000000..2e143046eb --- /dev/null +++ b/source/tests/common/dpmodel/test_fitting_call_graph.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""GeneralFitting.call_graph is the graph-native (flat-N) fitting API. Its result +must be bit-identical to the dense __call__ raveled over (nf, nloc) -- it reuses +the dense net via the (N,1,nd) single-atom-frame workaround. fparam is node-level +(N, ndf) (the caller gathers per-frame fparam by frame_id). +""" + +import numpy as np +import pytest + +from deepmd.dpmodel.fitting import ( + InvarFitting, +) + + +@pytest.mark.parametrize("ndf", [0, 3]) # numb_fparam: no-fparam AND fparam +def test_call_graph_matches_dense_raveled(ndf): + rng = np.random.default_rng(0) + nf, nloc, nd, ntypes, ng = 2, 4, 8, 2, 5 + ft = InvarFitting("energy", ntypes, nd, 1, mixed_types=True, numb_fparam=ndf) + desc = rng.normal(size=(nf, nloc, nd)) + atype = rng.integers(0, ntypes, size=(nf, nloc)) + gr = rng.normal(size=(nf, nloc, ng, 3)) + fparam = rng.normal(size=(nf, ndf)) if ndf else None + dense = ft(desc, atype, gr=gr, fparam=fparam)["energy"] # (nf, nloc, 1) + N = nf * nloc + frame_id = np.repeat(np.arange(nf), nloc) + fparam_node = fparam[frame_id] if ndf else None # (N, ndf) + flat = ft.call_graph( + desc.reshape(N, nd), + atype.reshape(N), + gr=gr.reshape(N, ng, 3), + fparam=fparam_node, + )["energy"] # (N, 1) + assert flat.shape == (N, 1) + np.testing.assert_allclose(flat, dense.reshape(N, 1), rtol=1e-12, atol=1e-12) diff --git a/source/tests/common/dpmodel/test_from_ijs.py b/source/tests/common/dpmodel/test_from_ijs.py new file mode 100644 index 0000000000..bab616e452 --- /dev/null +++ b/source/tests/common/dpmodel/test_from_ijs.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import pytest + +from deepmd.dpmodel.utils.neighbor_graph import ( + neighbor_graph_from_ijs, +) + + +class TestFromIjs(unittest.TestCase): + def test_edge_vec_and_index(self) -> None: + """src=j, dst=i, edge_vec = coord[j] + S@box - coord[i] (single frame, S=0).""" + coord = np.array([[[0.0, 0, 0], [1.0, 0, 0], [0, 2.0, 0]]]) # (1,3,3) + box = np.eye(3)[None] * 6.0 + i = np.array([0, 1]) # center + j = np.array([1, 0]) # neighbor + S = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int64) + ng = neighbor_graph_from_ijs( + i, j, S, coord, box, nframe_id=np.zeros(2, np.int64), nloc=3 + ) + np.testing.assert_array_equal(ng.edge_index[0][ng.edge_mask], j) # src + np.testing.assert_array_equal(ng.edge_index[1][ng.edge_mask], i) # dst + np.testing.assert_allclose( + ng.edge_vec[ng.edge_mask][0], coord[0, 1] - coord[0, 0] + ) + + def test_periodic_shift_in_edge_vec(self) -> None: + """A nonzero S contributes S@box to edge_vec (image neighbor).""" + coord = np.array([[[0.5, 0, 0], [5.5, 0, 0]]]) # (1,2,3) + box = np.eye(3)[None] * 6.0 + i = np.array([0]) + j = np.array([1]) + S = np.array([[-1, 0, 0]], dtype=np.int64) + ng = neighbor_graph_from_ijs( + i, j, S, coord, box, nframe_id=np.zeros(1, np.int64), nloc=2 + ) + # coord[1] + (-1,0,0)@box - coord[0] = 5.5 - 6 - 0.5 = -1.0 + np.testing.assert_allclose( + ng.edge_vec[ng.edge_mask][0], np.array([-1.0, 0.0, 0.0]) + ) + + +class TestAseCarryAll(unittest.TestCase): + def _sets(self, ng, nloc): + # per-center set of (src, rounded edge_vec); real edges only + ei = ng.edge_index[:, ng.edge_mask] + ev = ng.edge_vec[ng.edge_mask] + s = [set() for _ in range(nloc)] + for k in range(ei.shape[1]): + s[int(ei[1, k])].add((int(ei[0, k]), tuple(np.round(ev[k], 6)))) + return s + + def test_ase_matches_intree_carry_all(self) -> None: + """ASE carry-all builder yields the SAME neighbor set as the in-tree + carry-all build_neighbor_graph (both carry ALL neighbors in rcut). + """ + pytest.importorskip("ase") + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + build_neighbor_graph_ase, + ) + + rng = np.random.default_rng(3) + coord = rng.normal(size=(1, 8, 3)) * 2.0 + atype = np.array([[0, 1] * 4], dtype=np.int64) + box = np.eye(3)[None] * 8.0 + ng_ase = build_neighbor_graph_ase(coord, atype, box, rcut=4.0) + ng_ref = build_neighbor_graph(coord, atype, box, rcut=4.0) + self.assertEqual(self._sets(ng_ase, 8), self._sets(ng_ref, 8)) + + def test_ase_matches_intree_carry_all_nonperiodic(self) -> None: + """Non-periodic (box=None): ASE carry-all == in-tree carry-all.""" + pytest.importorskip("ase") + from deepmd.dpmodel.utils.neighbor_graph import ( + build_neighbor_graph, + build_neighbor_graph_ase, + ) + + rng = np.random.default_rng(7) + coord = rng.normal(size=(1, 6, 3)) * 2.0 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + ng_ase = build_neighbor_graph_ase(coord, atype, None, rcut=4.0) + ng_ref = build_neighbor_graph(coord, atype, None, rcut=4.0) + self.assertEqual(self._sets(ng_ase, 6), self._sets(ng_ref, 6)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/common/dpmodel/test_graph_atomic_parity.py b/source/tests/common/dpmodel/test_graph_atomic_parity.py new file mode 100644 index 0000000000..7de084a25f --- /dev/null +++ b/source/tests/common/dpmodel/test_graph_atomic_parity.py @@ -0,0 +1,308 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import pytest + +from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, +) +from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +def _atomic_model(sel=(30,), **kw): + ds = DescrptDPA1( + rcut=4.0, rcut_smth=0.5, sel=list(sel), ntypes=2, attn_layer=0, **kw + ) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + return DPAtomicModel(ds, ft, type_map=["a", "b"]) + + +def test_forward_atomic_graph_matches_dense(): + rng = np.random.default_rng(0) + coord = rng.normal(size=(1, 5, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0]], dtype=np.int64) + am = _atomic_model() + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, 4.0, [30], mixed_types=True, box=None + ) + dense = am.forward_atomic(ext_coord, ext_atype, nlist, mapping=mapping) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph = am.forward_atomic_graph(ng, atype.reshape(-1)) + np.testing.assert_allclose( + graph["energy"], dense["energy"].reshape(-1, 1), rtol=1e-12, atol=1e-12 + ) + + +def test_forward_atomic_graph_flat_shape_and_parity(): + """Flat (N, *) output, matching dense forward_atomic raveled over (nf, nloc).""" + rng = np.random.default_rng(0) + coord = rng.normal(size=(1, 5, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0]], dtype=np.int64) + am = _atomic_model() + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, 4.0, [30], mixed_types=True, box=None + ) + dense = am.forward_atomic(ext_coord, ext_atype, nlist, mapping=mapping) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph = am.forward_atomic_graph(ng, atype.reshape(-1)) + assert graph["energy"].shape == (5, 1) # FLAT (N, 1) + np.testing.assert_allclose( + graph["energy"], dense["energy"].reshape(5, 1), rtol=1e-12, atol=1e-12 + ) + + +def test_forward_common_atomic_graph_matches_dense(): + rng = np.random.default_rng(1) + coord = rng.normal(size=(1, 5, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0]], dtype=np.int64) + am = _atomic_model() + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, 4.0, [30], mixed_types=True, box=None + ) + dense = am.forward_common_atomic(ext_coord, ext_atype, nlist, mapping=mapping) + ng = from_dense_quartet(ext_coord, nlist, mapping) + graph = am.forward_common_atomic_graph(ng, atype.reshape(-1)) + # graph returns flat (N,*); reshape dense (nf,nloc,*) -> flat for comparison + for k in ("energy", "mask"): + g_arr = np.asarray(graph[k]) + d_arr = np.asarray(dense[k]).reshape(g_arr.shape) + np.testing.assert_allclose(g_arr, d_arr, rtol=1e-12, atol=1e-12) + + +# ── Feature-flag parity matrix (Task 6) ────────────────────────────────────── + + +def _ener_model(sel, type_one_side=False, exclude_types=None): + ds = DescrptDPA1( + rcut=4.0, + rcut_smth=0.5, + sel=list(sel), + ntypes=2, + attn_layer=0, + type_one_side=type_one_side, + exclude_types=exclude_types or [], + ) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + return EnergyModel(ds, ft, type_map=["a", "b"]) + + +@pytest.mark.parametrize("virtual", [False, True]) # one local atype == -1 +@pytest.mark.parametrize("type_one_side", [False, True]) # tebd concat content +@pytest.mark.parametrize("nf", [1, 2]) # single- and multi-frame +def test_graph_matches_dense_over_flags(virtual, type_one_side, nf): + rng = np.random.default_rng(2) + nloc = 6 + coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.tile(np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64), (nf, 1)) + if virtual: + atype[:, -1] = -1 # mark one local atom virtual + box = np.tile(np.eye(3).reshape(1, 9) * 20.0, (nf, 1)) + model = _ener_model([200], type_one_side=type_one_side) # non-binding sel + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + for k in ("energy", "energy_redu", "mask"): + np.testing.assert_allclose( + np.asarray(g[k]), np.asarray(d[k]), rtol=1e-12, atol=1e-12 + ) + if virtual: + assert int(np.asarray(g["mask"])[0, -1]) == 0 # virtual atom masked + + +def test_pair_exclude_types_falls_back_to_dense(): + """Pair exclude_types is unsupported on the graph -> uses_graph_lower False.""" + m = _ener_model([30], exclude_types=[(0, 1)]) + assert m.atomic_model.descriptor.uses_graph_lower() is False + + +def test_model_pair_exclude_types_graph_matches_dense(): + """Model-level pair_exclude_types is now graph-native (edge mask): graph == + dense at 1e-12 (was: gated to dense / raises NotImplementedError). + """ + rng = np.random.default_rng(4) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + box = np.eye(3).reshape(1, 9) * 20.0 + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[200], ntypes=2, attn_layer=0) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + model = EnergyModel(ds, ft, type_map=["a", "b"], pair_exclude_types=[(0, 1)]) + assert model.atomic_model.pair_excl is not None + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + for k in ("energy", "energy_redu", "mask"): + np.testing.assert_allclose( + np.asarray(g[k]), np.asarray(d[k]), rtol=1e-12, atol=1e-12 + ) + # non-vacuous: toggle pair exclusion OFF on the SAME model (same weights), + # so any energy difference is due solely to the exclusion (not weights). + g_excl = model.call_common(coord, atype, box, neighbor_graph_method="dense") + model.atomic_model.reinit_pair_exclude([]) # clear pair exclusion + assert model.atomic_model.pair_excl is None + g_noexcl = model.call_common(coord, atype, box, neighbor_graph_method="dense") + # tight tolerance: the excluded (0,1) pairs contribute a small but real + # amount; default rtol=1e-5 is too loose to register it. + assert not np.allclose( + np.asarray(g_excl["energy_redu"]), + np.asarray(g_noexcl["energy_redu"]), + rtol=1e-9, + atol=1e-9, + ), "pair exclusion must change the graph energy (same weights)" + + +def test_graph_matches_dense_with_fparam(): + """Frame parameter is gathered to nodes by frame_id in forward_atomic_graph + and fed to the fitting's call_graph; the graph path must match dense at 1e-12 + with a non-zero fparam (exercises the frame_id gather + xp.take dispatch). + """ + rng = np.random.default_rng(7) + nf, nloc, ndf = 2, 5, 3 + coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.tile(np.array([[0, 1, 0, 1, 0]], dtype=np.int64), (nf, 1)) + box = np.tile(np.eye(3).reshape(1, 9) * 20.0, (nf, 1)) + fparam = rng.normal(size=(nf, ndf)) # per-frame, differs across frames + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[200], ntypes=2, attn_layer=0) + ft = InvarFitting( + "energy", 2, ds.get_dim_out(), 1, mixed_types=True, numb_fparam=ndf + ) + model = EnergyModel(ds, ft, type_map=["a", "b"]) + g = model.call_common( + coord, atype, box, fparam=fparam, neighbor_graph_method="dense" + ) + d = model.call_common( + coord, atype, box, fparam=fparam, neighbor_graph_method="legacy" + ) + for k in ("energy", "energy_redu"): + np.testing.assert_allclose( + np.asarray(g[k]), np.asarray(d[k]), rtol=1e-12, atol=1e-12 + ) + # non-vacuous: each frame's fparam differs, so a mis-gathered fparam (e.g. + # every node given frame 0's fparam) would make the two frames' energies equal. + assert not np.allclose( + np.asarray(g["energy_redu"][0]), np.asarray(g["energy_redu"][1]) + ) + + +def test_graph_matches_dense_with_atom_exclude(): + """Model-level atom_exclude_types IS supported on the graph path (applied + via _finalize_atomic_ret's atom_excl). Graph == dense at rtol/atol 1e-12. + Also proves atom-level exclusion is correctly inherited and non-vacuous. + """ + rng = np.random.default_rng(11) + nloc = 6 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0, 1]], dtype=np.int64) + box = np.eye(3).reshape(1, 9) * 20.0 + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[200], ntypes=2, attn_layer=0) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + am = DPAtomicModel(ds, ft, type_map=["a", "b"], atom_exclude_types=[0]) + model = EnergyModel(atomic_model_=am) + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + for k in ("energy", "energy_redu", "mask"): + g_arr = np.asarray(g[k]) + d_arr = np.asarray(d[k]) + max_diff = float(np.max(np.abs(g_arr - d_arr))) + np.testing.assert_allclose( + g_arr, + d_arr, + rtol=1e-12, + atol=1e-12, + err_msg=f"graph vs dense mismatch for '{k}': max_diff={max_diff}", + ) + # non-vacuous: type-0 atoms have zero energy (excluded), type-1 have nonzero + g_energy = np.asarray(g["energy"]) + g_mask = np.asarray(g["mask"]) + type0_indices = atype[0] == 0 + assert np.allclose(g_energy[0, type0_indices], 0.0), ( + "excluded type-0 atoms must have zero energy" + ) + assert not np.allclose(g_energy[0, ~type0_indices], 0.0), ( + "non-excluded type-1 atoms must have nonzero energy" + ) + # also check mask: excluded type-0 atoms should have mask==0 + assert np.all(g_mask[0, type0_indices] == 0), ( + "excluded type-0 atoms must have mask==0" + ) + + +def test_forward_common_atomic_graph_flat_shape(): + rng = np.random.default_rng(1) + coord = rng.normal(size=(1, 5, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0]], dtype=np.int64) + am = _atomic_model() + ext_coord, ext_atype, mapping, nlist = extend_input_and_build_neighbor_list( + coord, atype, 4.0, [30], mixed_types=True, box=None + ) + ng = from_dense_quartet(ext_coord, nlist, mapping) + out = am.forward_common_atomic_graph(ng, atype.reshape(-1)) + assert out["energy"].shape == (5, 1) # flat (N, 1) + assert out["mask"].shape == (5,) # flat (N,) + + +def test_graph_nloc1_unravel_shapes(): + """Regression: when nloc==1, N==nf so per-frame _redu keys must NOT be + reshaped to (nf,1,*). Before the fix, energy_redu came out (nf,1,1) instead + of (nf,1). Checks both shapes and value parity against the dense (legacy) path. + """ + nf = 2 + rng = np.random.default_rng(42) + coord = rng.normal(size=(nf, 1, 3)) * 1.5 + atype = np.zeros((nf, 1), dtype=np.int64) + box = np.tile(np.eye(3).reshape(1, 9) * 20.0, (nf, 1)) + model = _ener_model([200]) # non-binding sel + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + # shape assertions — the critical regression check + assert g["energy"].shape == (nf, 1, 1), f"energy shape {g['energy'].shape}" + assert g["energy_redu"].shape == (nf, 1), ( + f"energy_redu shape {g['energy_redu'].shape}" + ) + assert g["mask"].shape == (nf, 1), f"mask shape {g['mask'].shape}" + # value parity with the dense path + for k in ("energy", "energy_redu", "mask"): + np.testing.assert_allclose( + np.asarray(g[k]), + np.asarray(d[k]), + rtol=1e-12, + atol=1e-12, + err_msg=f"graph vs legacy mismatch for key '{k}'", + ) + + +def test_graph_matches_dense_with_out_bias(): + """The graph path applies apply_out_stat (per-type out-bias) identically + to the dense path. With a non-zero bias, graph == dense at 1e-12, and the + bias actually shifts the graph energy (non-vacuous). + """ + rng = np.random.default_rng(3) + nloc = 5 + coord = rng.normal(size=(1, nloc, 3)) * 1.5 + atype = np.array([[0, 1, 0, 1, 0]], dtype=np.int64) + box = np.eye(3).reshape(1, 9) * 20.0 + model = _ener_model([200]) + # energy BEFORE setting bias (zero out-bias), graph path + g_zero = model.call_common(coord, atype, box, neighbor_graph_method="dense") + # set a non-zero per-type energy out-bias + model.atomic_model.out_bias[0, :, 0] = np.array([0.3, -0.7]) + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + # graph applies out-stat exactly like dense + for k in ("energy", "energy_redu"): + np.testing.assert_allclose( + np.asarray(g[k]), np.asarray(d[k]), rtol=1e-12, atol=1e-12 + ) + # non-vacuous: the bias actually shifted the graph energy + assert not np.allclose(np.asarray(g["energy"]), np.asarray(g_zero["energy"])) diff --git a/source/tests/common/dpmodel/test_graph_ragged.py b/source/tests/common/dpmodel/test_graph_ragged.py new file mode 100644 index 0000000000..a651d245ea --- /dev/null +++ b/source/tests/common/dpmodel/test_graph_ragged.py @@ -0,0 +1,255 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Flat-N ragged-native graph path: nodes on a flat (N,) axis, N = sum(n_node); +per-frame reductions use segment_sum over frame_id. UNEQUAL per-frame node counts +(ragged) -- the case the old rectangular (nf,nloc) path could not represent. +""" + +import numpy as np + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.ener_model import ( + EnergyModel, +) +from deepmd.dpmodel.utils.neighbor_graph import ( + frame_id_from_n_node, +) + + +def test_frame_id_ragged(): + fid = frame_id_from_n_node(np.array([3, 5, 2], dtype=np.int64)) # N=10 + np.testing.assert_array_equal( + fid, np.array([0, 0, 0, 1, 1, 1, 1, 1, 2, 2], dtype=fid.dtype) + ) + + +def test_frame_id_static_n_total(): + """A static ``n_total`` (jax/export trace-friendly path) matches the default + ``int(sum(n_node))`` path exactly when ``n_total == sum(n_node)``; a PADDED + ``n_total`` assigns the padding tail to the last frame. + """ + n_node = np.array([3, 5, 2], dtype=np.int64) # sum = 10 + # exact n_total reproduces the default (None) path + np.testing.assert_array_equal( + frame_id_from_n_node(n_node, n_total=10), + frame_id_from_n_node(n_node), + ) + # padded n_total -> padding nodes [10, 12) map to the last frame (nf-1 == 2) + np.testing.assert_array_equal( + frame_id_from_n_node(n_node, n_total=12), + np.array([0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2], dtype=np.int64), + ) + + +def test_forward_common_atomic_graph_ragged(): + """Two frames with DIFFERENT node counts (3 and 2) share one flat node axis. + + The old rectangular path (nloc = N // nf) could not represent this. + """ + import numpy as np + + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.dpmodel.fitting import ( + InvarFitting, + ) + from deepmd.dpmodel.utils.neighbor_graph import ( + NeighborGraph, + ) + + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[30], ntypes=2, attn_layer=0) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + am = DPAtomicModel(ds, ft, type_map=["a", "b"]) + n_node = np.array([3, 2], dtype=np.int64) # RAGGED, N=5 + atype = np.array([0, 1, 0, 1, 0], dtype=np.int64) + edge_index = np.array([[1, 0, 4], [0, 1, 3]], dtype=np.int64) # within-frame + edge_vec = np.array([[1.0, 0, 0], [-1.0, 0, 0], [0.5, 0, 0]], dtype=np.float64) + edge_mask = np.array([True, True, True]) + g = NeighborGraph( + n_node=n_node, edge_index=edge_index, edge_vec=edge_vec, edge_mask=edge_mask + ) + out = am.forward_common_atomic_graph(g, atype) + assert out["energy"].shape == (5, 1) and out["mask"].shape == (5,) + assert np.all(np.isfinite(out["energy"])) + + +def test_frame_id_rectangular(): + fid = frame_id_from_n_node(np.array([4, 4], dtype=np.int64)) + np.testing.assert_array_equal( + fid, np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=fid.dtype) + ) + + +def test_call_lower_graph_ragged_energy_reduction(): + """Per-frame energy_redu = segment_sum of the frame's atom energies; ragged.""" + import numpy as np + + from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.dpmodel.fitting import ( + InvarFitting, + ) + from deepmd.dpmodel.model.ener_model import ( + EnergyModel, + ) + + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=[30], ntypes=2, attn_layer=0) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + m = EnergyModel(ds, ft, type_map=["a", "b"]) + n_node = np.array([3, 2], dtype=np.int64) + atype = np.array([0, 1, 0, 1, 0], dtype=np.int64) + edge_index = np.array([[1, 0, 4], [0, 1, 3]], dtype=np.int64) + edge_vec = np.array([[1.0, 0, 0], [-1.0, 0, 0], [0.5, 0, 0]], dtype=np.float64) + edge_mask = np.array([True, True, True]) + out = m.call_lower_graph( + atype=atype, + n_node=n_node, + edge_index=edge_index, + edge_vec=edge_vec, + edge_mask=edge_mask, + ) + assert out["energy"].shape == (5, 1) # flat node energy + assert out["energy_redu"].shape == (2, 1) # per-FRAME reduced + np.testing.assert_allclose( + out["energy_redu"][0, 0], + out["energy"][0:3, 0].sum(), + rtol=1e-12, + atol=1e-12, + ) + np.testing.assert_allclose( + out["energy_redu"][1, 0], + out["energy"][3:5, 0].sum(), + rtol=1e-12, + atol=1e-12, + ) + + +def _ener_model_ragged(sel=(200,)): + """Build a dpa1(attn_layer=0) EnergyModel for gate tests.""" + ds = DescrptDPA1(rcut=4.0, rcut_smth=0.5, sel=list(sel), ntypes=2, attn_layer=0) + ft = InvarFitting("energy", 2, ds.get_dim_out(), 1, mixed_types=True) + return EnergyModel(ds, ft, type_map=["a", "b"]) + + +def test_rectangular_free_view_equivalence(): + """GATE: rectangular nf=2, nloc=5 graph path == legacy dense path bit-identical. + + Proves the flat-N rewrite does not perturb the rectangular special case. + public call_common with neighbor_graph_method='dense' must match 'legacy' + on energy / energy_redu / mask at rtol/atol 1e-12 (non-binding sel=[200]). + """ + nf, nloc = 2, 5 + rng = np.random.default_rng(7) + coord = rng.normal(size=(nf, nloc, 3)) * 1.5 + atype = np.tile(np.array([[0, 1, 0, 1, 0]], dtype=np.int64), (nf, 1)) + box = np.tile(np.eye(3).reshape(1, 9) * 20.0, (nf, 1)) # large PBC box + model = _ener_model_ragged(sel=[200]) # non-binding sel + g = model.call_common(coord, atype, box, neighbor_graph_method="dense") + d = model.call_common(coord, atype, box, neighbor_graph_method="legacy") + for k in ("energy", "energy_redu", "mask"): + np.testing.assert_allclose( + np.asarray(g[k]), + np.asarray(d[k]), + rtol=1e-12, + atol=1e-12, + err_msg=f"graph vs legacy mismatch for key '{k}'", + ) + + +def test_ragged_frames_independent(): + """GATE: ragged n_node=[3,2] per-frame energies equal two single-frame runs. + + Proves frames do not leak through segment_sum on the flat axis: the ragged + energy_redu[i] must match running the i-th frame's atoms+edges in isolation + through call_lower_graph. The SAME model weights are used for all three + calls so the comparison is meaningful. + + Frame 0: nodes 0-2 (atype [0,1,0]), edges 0<->1, 1<->2. + Frame 1: nodes 3-4 (atype [1,0]), edges 3<->4 (global) = 0<->1 (local). + """ + model = _ener_model_ragged() + + # ── Ragged graph (both frames in one flat call) ──────────────────────── + atype5 = np.array([0, 1, 0, 1, 0], dtype=np.int64) + # frame-0 edges (global indices 0,1,2): 0↔1, 1↔2 + # frame-1 edges (global indices 3,4): 3↔4 + edge_index_rag = np.array([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]], dtype=np.int64) + edge_vec_rag = np.array( + [ + [1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0], + [1.5, 0.0, 0.0], + [-1.5, 0.0, 0.0], + [0.5, 0.0, 0.0], + [-0.5, 0.0, 0.0], + ], + dtype=np.float64, + ) + edge_mask_rag = np.ones(6, dtype=bool) + ragged = model.call_lower_graph( + atype=atype5, + n_node=np.array([3, 2], dtype=np.int64), + edge_index=edge_index_rag, + edge_vec=edge_vec_rag, + edge_mask=edge_mask_rag, + ) + + # ── Single-frame 0 (nodes 0-2) ───────────────────────────────────────── + atype_f0 = atype5[:3] # [0, 1, 0] + edge_index_f0 = np.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=np.int64) + edge_vec_f0 = np.array( + [ + [1.0, 0.0, 0.0], + [-1.0, 0.0, 0.0], + [1.5, 0.0, 0.0], + [-1.5, 0.0, 0.0], + ], + dtype=np.float64, + ) + edge_mask_f0 = np.ones(4, dtype=bool) + f0 = model.call_lower_graph( + atype=atype_f0, + n_node=np.array([3], dtype=np.int64), + edge_index=edge_index_f0, + edge_vec=edge_vec_f0, + edge_mask=edge_mask_f0, + ) + + # ── Single-frame 1 (nodes 3-4, remapped to local indices 0-1) ────────── + atype_f1 = atype5[3:] # [1, 0] (atype of global nodes 3,4) + # global edge 3→4 becomes local 0→1; global 4→3 becomes local 1→0 + edge_index_f1 = np.array([[0, 1], [1, 0]], dtype=np.int64) + edge_vec_f1 = np.array([[0.5, 0.0, 0.0], [-0.5, 0.0, 0.0]], dtype=np.float64) + edge_mask_f1 = np.ones(2, dtype=bool) + f1 = model.call_lower_graph( + atype=atype_f1, + n_node=np.array([2], dtype=np.int64), + edge_index=edge_index_f1, + edge_vec=edge_vec_f1, + edge_mask=edge_mask_f1, + ) + + # ── Gate assertions ──────────────────────────────────────────────────── + np.testing.assert_allclose( + np.asarray(ragged["energy_redu"][0]), + np.asarray(f0["energy_redu"][0]), + rtol=1e-12, + atol=1e-12, + err_msg="ragged frame-0 energy_redu must equal single-frame-0 energy_redu", + ) + np.testing.assert_allclose( + np.asarray(ragged["energy_redu"][1]), + np.asarray(f1["energy_redu"][0]), + rtol=1e-12, + atol=1e-12, + err_msg="ragged frame-1 energy_redu must equal single-frame-1 energy_redu", + ) diff --git a/source/tests/common/test_mixins.py b/source/tests/common/test_mixins.py index e311baf5cf..5dd907ded4 100644 --- a/source/tests/common/test_mixins.py +++ b/source/tests/common/test_mixins.py @@ -54,7 +54,7 @@ def setUp(self) -> None: [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) self.mapping = np.concatenate( - [self.mapping, self.mapping[:, self.perm]], axis=0 + [self.mapping, inv_perm[self.mapping[:, self.perm]]], axis=0 ) # permute the nlist diff --git a/source/tests/pd/model/test_env_mat.py b/source/tests/pd/model/test_env_mat.py index bbdb7c75a3..b5b9e0bee6 100644 --- a/source/tests/pd/model/test_env_mat.py +++ b/source/tests/pd/model/test_env_mat.py @@ -63,7 +63,7 @@ def setUp(self) -> None: [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) self.mapping = np.concatenate( - [self.mapping, self.mapping[:, self.perm]], axis=0 + [self.mapping, inv_perm[self.mapping[:, self.perm]]], axis=0 ) # permute the nlist diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index c5a2ed57a6..d7d2718e67 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -252,6 +252,65 @@ def fn(coord_ext, atype_ext, nlist): atol=atol, ) + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx_graph(self, prec) -> None: + """make_fx (export-readiness) of the attn_layer=0 GRAPH forward. + + For ``attn_layer == 0`` the dense ``forward`` routes through the + graph-native path (``from_dense_quartet -> call_graph``). This proves + that graph forward + ``autograd.grad`` is fx-traceable (full .pt2 + export is PR-B). + """ + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=0, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + # the attn_layer=0 graph adapter (from_dense_quartet) maps every ghost + # neighbor to its LOCAL owner via ``mapping``; the mixin's nall(4) > nloc(3) + # so a real mapping is required (identity mapping would index out of range). + mapping = torch.tensor(self.mapping, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist, mapping): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist, mapping)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist, mapping) + traced = make_fx(fn)(coord_ext, atype_ext, nlist, mapping) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist, mapping) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level def test_share_params(self, shared_level) -> None: """share_params level 0: share all; level 1: share type_embedding only.""" diff --git a/source/tests/pt_expt/model/test_dos_graph.py b/source/tests/pt_expt/model/test_dos_graph.py new file mode 100644 index 0000000000..dc8cfc0685 --- /dev/null +++ b/source/tests/pt_expt/model/test_dos_graph.py @@ -0,0 +1,161 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""The OUTPUT-AGNOSTIC graph lower supports ANY fitting, not just energy. + +A non-energy model with a graph-eligible descriptor (dpa1 ``attn_layer==0``) +routes into the graph path by default. Before the general output transform this +KeyError'd on ``"energy"``; now every fitting (dos/dipole/polar/property/...) +flows through :func:`fit_output_to_model_output_graph` with no change on the +fitting side. Each model's graph forward (default ``neighbor_graph_method``) +must match the dense path (``neighbor_graph_method="legacy"``) on every shared +key (carry-all graph at non-binding ``sel`` reproduces the dense neighbor set). +""" + +import pytest +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.fitting import ( + DipoleFitting, + DOSFittingNet, + PolarFitting, + PropertyFittingNet, +) +from deepmd.pt_expt.model import ( + DipoleModel, + DOSModel, + PolarModel, + PropertyModel, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +def _make_descriptor() -> DescrptDPA1: + return DescrptDPA1( + 4.0, + 0.5, + 20, # non-binding mixed-type single-int sel -> graph == dense neighbors + 2, + attn_layer=0, # graph lower only supports attn_layer == 0 + precision="float64", + seed=GLOBAL_SEED, + ).to(env.DEVICE) + + +def _make_dos(ds: DescrptDPA1): + return DOSModel( + ds, + DOSFittingNet( + 2, ds.get_dim_out(), 5, mixed_types=ds.mixed_types(), seed=GLOBAL_SEED + ).to(env.DEVICE), + type_map=["a", "b"], + ).to(env.DEVICE) + + +def _make_dipole(ds: DescrptDPA1): + return DipoleModel( + ds, + DipoleFitting( + 2, + ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE), + type_map=["a", "b"], + ).to(env.DEVICE) + + +def _make_polar(ds: DescrptDPA1): + return PolarModel( + ds, + PolarFitting( + 2, + ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE), + type_map=["a", "b"], + ).to(env.DEVICE) + + +def _make_property(ds: DescrptDPA1): + return PropertyModel( + ds, + PropertyFittingNet( + 2, + ds.get_dim_out(), + task_dim=3, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE), + type_map=["a", "b"], + ).to(env.DEVICE) + + +class TestNonEnergyGraph: + def setup_method(self) -> None: + generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) + self.coord = torch.rand( + 1, 5, 3, dtype=torch.float64, device=env.DEVICE, generator=generator + ) + self.atype = torch.tensor([[0, 1, 0, 1, 0]], device=env.DEVICE) + + def test_dos_repro(self) -> None: + """The exact bug repro: a DOS model's default forward used to KeyError + on ``"energy"`` in the graph path; now it succeeds. + """ + ds = _make_descriptor() + ft = DOSFittingNet(2, ds.get_dim_out(), 5, mixed_types=ds.mixed_types()).to( + env.DEVICE + ) + m = DOSModel(ds, ft, type_map=["a", "b"]).to(env.DEVICE) + out = m(self.coord, self.atype, box=None) + # standard DOS model keys (no KeyError) + assert set(out.keys()) >= {"atom_dos", "dos", "mask"} + assert out["atom_dos"].shape == (1, 5, 5) + assert out["dos"].shape == (1, 5) + + @pytest.mark.parametrize( + "make_model", + [_make_dos, _make_dipole, _make_polar, _make_property], + ) # one builder per fitting kind + def test_graph_matches_dense(self, make_model) -> None: + """Graph (default) output matches the dense (``legacy``) path on every + shared key, including derivatives for r/c-differentiable fittings. + """ + tol = ( + {"rtol": 1e-11, "atol": 1e-11} + if env.DEVICE.type == "cpu" + else {"rtol": 1e-9, "atol": 1e-9} + ) + ds = _make_descriptor() + m = make_model(ds) + graph = m.call_common(self.coord, self.atype, None, do_atomic_virial=True) + # the dense path differentiates w.r.t. coord -> needs a coord leaf. + dense = m.call_common( + self.coord.detach().requires_grad_(True), + self.atype, + None, + do_atomic_virial=True, + neighbor_graph_method="legacy", + ) + shared = [ + k + for k in graph + if k in dense and graph[k] is not None and dense[k] is not None + ] + # at least the reduced + per-atom output must be present and shared + assert len(shared) >= 2 + for k in shared: + torch.testing.assert_close( + graph[k].to(torch.float64), dense[k].to(torch.float64), **tol + ) diff --git a/source/tests/pt_expt/model/test_dpa1_graph_lower.py b/source/tests/pt_expt/model/test_dpa1_graph_lower.py new file mode 100644 index 0000000000..e274a1bcec --- /dev/null +++ b/source/tests/pt_expt/model/test_dpa1_graph_lower.py @@ -0,0 +1,232 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Parity: graph lower (forward_common_lower_graph) vs legacy dense lower. + +Builds a same-weights pt_expt dpa1(attn_layer=0) EnergyModel and a small +extended system, then compares the graph-native lower (energy/force/virial/ +atom_virial assembled from ``edge_energy_deriv``) against the legacy dense +``forward_common_lower`` on the SAME neighbor set (the graph is built REGIME-1 +from the same extended quartet via ``from_dense_quartet``). + +The graph lower is inherently LOCAL (ghost-free): its force/atom_virial live on +``nloc`` nodes, while the legacy lower returns EXTENDED (``nall``) force/ +atom_virial. The two are reconciled by folding the legacy extended force/ +atom_virial onto local atoms via ``mapping`` (a scatter-add on the atom axis, +identical to ``communicate_extended_output``). Energy, reduced energy and the +reduced (per-frame) virial are frame/local quantities and compare directly. +""" + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.utils.neighbor_graph import ( + from_dense_quartet, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +def _fold_extended_to_local( + ext: torch.Tensor, mapping: torch.Tensor, nloc: int +) -> torch.Tensor: + """Scatter-add an extended (nf, nall, 1, K) tensor onto local atoms. + + Mirrors ``communicate_extended_output``: ``local[mapping[j]] += ext[j]`` + along the atom axis (dim 1). + """ + nf, nall = mapping.shape + K = ext.shape[-1] + out = torch.zeros(nf, nloc, 1, K, dtype=ext.dtype, device=ext.device) + idx = mapping.view(nf, nall, 1, 1).expand(nf, nall, 1, K) + out.scatter_add_(1, idx, ext) + return out + + +class TestDpa1GraphLower: + def setup_method(self) -> None: + self.device = env.DEVICE + self.natoms = 5 + self.rcut = 4.0 + self.rcut_smth = 0.5 + self.sel = 20 # mixed-type single int sel + self.nt = 2 + self.type_map = ["foo", "bar"] + + generator = torch.Generator(device=self.device).manual_seed(GLOBAL_SEED) + cell = torch.rand( + [3, 3], dtype=torch.float64, device=self.device, generator=generator + ) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device=self.device) + self.cell = cell.unsqueeze(0) # [1, 3, 3] + coord = torch.rand( + [self.natoms, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + coord = torch.matmul(coord, cell) + self.coord = coord.unsqueeze(0).to(self.device) # [1, natoms, 3] + self.atype = torch.tensor( + [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device + ) + + def _make_model(self) -> EnergyModel: + ds = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + neuron=[3, 6], + axis_neuron=2, + attn=4, + attn_layer=0, # graph lower only supports attn_layer == 0 + attn_dotr=True, + attn_mask=False, + activation_function="tanh", + set_davg_zero=False, + type_one_side=True, + precision="float64", + seed=GLOBAL_SEED, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + precision="float64", + seed=GLOBAL_SEED, + ).to(self.device) + return EnergyModel(ds, ft, type_map=self.type_map).to(self.device) + + def _prepare_lower_inputs(self, periodic: bool): + """Build extended coords, atype, nlist, mapping as torch tensors.""" + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + if periodic: + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, self.natoms, 3), + cell_np.reshape(1, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + cell_np, + self.rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + [self.sel], + distinguish_types=False, + ) + extended_coord = extended_coord.reshape(1, -1, 3) + else: + extended_coord = coord_np.reshape(1, self.natoms, 3) + extended_atype = atype_np.reshape(1, self.natoms) + mapping = np.arange(self.natoms, dtype=np.int64).reshape(1, self.natoms) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + [self.sel], + distinguish_types=False, + ) + ext_coord = torch.tensor( + extended_coord, dtype=torch.float64, device=self.device + ) + ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=self.device) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) + return ext_coord, ext_atype, nlist_t, mapping_t + + @pytest.mark.parametrize("periodic", [True, False]) # PBC vs non-PBC + @pytest.mark.parametrize("do_av", [False, True]) # atom-virial off / on + def test_force_virial_parity_vs_legacy(self, periodic, do_av) -> None: + """Graph lower energy/force/virial/atom_virial == legacy dense lower on + the SAME neighbor set (regime-1 graph from from_dense_quartet). + """ + model = self._make_model() + model.eval() + tol = ( + {"rtol": 1e-12, "atol": 1e-12} + if self.device.type == "cpu" + else {"rtol": 1e-10, "atol": 1e-10} + ) + ext_coord, ext_atype, nlist, mapping = self._prepare_lower_inputs(periodic) + nf = ext_coord.shape[0] + nloc = self.natoms + + legacy = model.forward_common_lower( + ext_coord.clone().requires_grad_(True), + ext_atype, + nlist, + mapping, + do_atomic_virial=do_av, + ) + + # build the regime-1 graph from the SAME extended quartet. + # from_dense_quartet is array-API; feed torch tensors so the + # returned edge_vec is already a torch tensor on env.DEVICE. + ng = from_dense_quartet(ext_coord, nlist, mapping) + atype_local = ext_atype[:, :nloc].reshape(nf * nloc) + graph = model.forward_common_lower_graph( + atype_local, + ng.n_node, + ng.edge_index, + ng.edge_vec, + ng.edge_mask, + do_atomic_virial=do_av, + ) + + # forward_common_lower_graph returns flat (N = nf * nloc, *) per-atom + # outputs. Reshape to (nf, nloc, *) to compare against the dense lower. + + # per-atom energy: flat (N, 1) -> (nf, nloc, 1) + graph_energy = graph["energy"].reshape(nf, nloc, 1) + torch.testing.assert_close(graph_energy, legacy["energy"], **tol) + + # reduced energy and virial: already per-frame (nf, *) + torch.testing.assert_close(graph["energy_redu"], legacy["energy_redu"], **tol) + torch.testing.assert_close( + graph["energy_derv_c_redu"], legacy["energy_derv_c_redu"], **tol + ) + + # force: graph is flat (N, 1, 3); fold legacy extended (nall) -> local (nloc) + legacy_force_local = _fold_extended_to_local( + legacy["energy_derv_r"], mapping, nloc + ) + graph_force = graph["energy_derv_r"].reshape(nf, nloc, 1, 3) + torch.testing.assert_close(graph_force, legacy_force_local, **tol) + + if do_av: + legacy_av_local = _fold_extended_to_local( + legacy["energy_derv_c"], mapping, nloc + ) + graph_av = graph["energy_derv_c"].reshape(nf, nloc, 1, 9) + torch.testing.assert_close(graph_av, legacy_av_local, **tol) diff --git a/source/tests/pt_expt/model/test_edge_energy_deriv.py b/source/tests/pt_expt/model/test_edge_energy_deriv.py new file mode 100644 index 0000000000..fafc8ac180 --- /dev/null +++ b/source/tests/pt_expt/model/test_edge_energy_deriv.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt_expt.model.edge_transform_output import ( + edge_energy_deriv, +) + + +class TestEdgeEnergyDeriv(unittest.TestCase): + def test_force_matches_autograd_wrt_node_coords(self) -> None: + """The graph force equals -dE/d(node coord): build edge_vec from node + coords, so force from edge_energy_deriv == -autograd.grad(E, coords). + """ + torch.manual_seed(0) + N, nf = 5, 1 + n_node = torch.tensor([N], dtype=torch.int64, device=env.DEVICE) + coord = torch.randn( + N, 3, dtype=torch.float64, device=env.DEVICE, requires_grad=True + ) + # a connected edge set (both directions), all real + src = torch.tensor([0, 1, 1, 2, 3, 4], device=env.DEVICE) + dst = torch.tensor([1, 0, 2, 1, 4, 3], device=env.DEVICE) + edge_index = torch.stack([src, dst], 0) + edge_mask = torch.ones(src.shape[0], dtype=torch.bool, device=env.DEVICE) + edge_vec = coord[src] - coord[dst] # differentiable wrt coord + energy = (torch.sin(edge_vec).sum(-1) ** 2).sum() # toy scalar energy + force, av, gv = edge_energy_deriv( + energy, edge_vec, edge_index, edge_mask, n_node, do_atomic_virial=True + ) + # reference physical force = -dE/d(coord) + f_ref = -torch.autograd.grad(energy, coord, retain_graph=True)[0] + torch.testing.assert_close(force, f_ref, rtol=1e-10, atol=1e-10) + # atom-virial sums (per frame) to the global virial + torch.testing.assert_close(av.sum(0), gv[0], rtol=1e-10, atol=1e-10) + self.assertEqual(gv.shape, (nf, 3, 3)) + + def test_padding_edges_contribute_nothing(self) -> None: + """A masked guard edge with a huge edge_vec must not change force/virial.""" + torch.manual_seed(1) + N = 4 + n_node = torch.tensor([N], dtype=torch.int64, device=env.DEVICE) + coord = torch.randn( + N, 3, dtype=torch.float64, device=env.DEVICE, requires_grad=True + ) + src = torch.tensor([0, 1, 2], device=env.DEVICE) + dst = torch.tensor([1, 2, 3], device=env.DEVICE) + ev = coord[src] - coord[dst] + # append a masked guard edge with a huge vec + guard = torch.tensor( + [[99.0, 99.0, 99.0]], dtype=torch.float64, device=env.DEVICE + ) + edge_vec = torch.cat([ev, guard], 0).detach().requires_grad_(True) + edge_index = torch.tensor([[0, 1, 2, 0], [1, 2, 3, 0]], device=env.DEVICE) + edge_mask = torch.tensor([True, True, True, False], device=env.DEVICE) + energy = (edge_vec**2).sum() + force, av, gv = edge_energy_deriv( + energy, edge_vec, edge_index, edge_mask, n_node, do_atomic_virial=True + ) + # run again with ONLY the real edges; results must match + ev2 = edge_vec[:3].detach().requires_grad_(True) + e2 = (ev2**2).sum() + f2, av2, gv2 = edge_energy_deriv( + e2, ev2, edge_index[:, :3], edge_mask[:3], n_node, do_atomic_virial=True + ) + torch.testing.assert_close(force, f2, rtol=1e-12, atol=1e-12) + torch.testing.assert_close(gv, gv2, rtol=1e-12, atol=1e-12) + + def test_atom_virial_optional(self) -> None: + """do_atomic_virial=False returns None for atom_virial; force+virial still computed.""" + N = 3 + n_node = torch.tensor([N], dtype=torch.int64, device=env.DEVICE) + coord = torch.randn( + N, 3, dtype=torch.float64, device=env.DEVICE, requires_grad=True + ) + edge_index = torch.tensor([[0, 1], [1, 0]], device=env.DEVICE) + edge_mask = torch.ones(2, dtype=torch.bool, device=env.DEVICE) + edge_vec = coord[edge_index[0]] - coord[edge_index[1]] + energy = (edge_vec**2).sum() + force, av, gv = edge_energy_deriv( + energy, edge_vec, edge_index, edge_mask, n_node, do_atomic_virial=False + ) + self.assertIsNone(av) + self.assertEqual(force.shape, (N, 3)) + self.assertEqual(gv.shape, (1, 3, 3)) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_graph_ragged.py b/source/tests/pt_expt/model/test_graph_ragged.py new file mode 100644 index 0000000000..efe2ffeaec --- /dev/null +++ b/source/tests/pt_expt/model/test_graph_ragged.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Ragged n_node test for forward_common_lower_graph. + +Verifies that the flat-N graph transform correctly handles ragged frames +(n_node=[3,2], N=5): energy shape (5,1), energy_redu shape (2,1), +energy_derv_r leading dim 5. All entries must be finite. +""" + +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) + +from ...seed import ( + GLOBAL_SEED, +) + +_RCUT = 3.0 +_NT = 2 + + +def _make_model() -> EnergyModel: + ds = DescrptDPA1( + _RCUT, + 0.5, + 10, + _NT, + neuron=[3, 6], + axis_neuron=2, + attn=4, + attn_layer=0, + attn_dotr=True, + attn_mask=False, + activation_function="tanh", + set_davg_zero=True, + type_one_side=True, + precision="float64", + seed=GLOBAL_SEED, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + _NT, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + precision="float64", + seed=GLOBAL_SEED, + ).to(env.DEVICE) + return EnergyModel(ds, ft, type_map=["A", "B"]).to(env.DEVICE) + + +def _make_ragged_graph(device: torch.device) -> tuple: + """Build a ragged graph with n_node=[3,2] (N=5). + + Frame 0: atoms 0,1,2 — fully connected (6 directed edges within rcut). + Frame 1: atoms 3,4 — fully connected (2 directed edges within rcut). + Edge vectors are chosen to be small enough to fall within _RCUT. + """ + rng = torch.Generator(device=device).manual_seed(GLOBAL_SEED) + # flat atom types (N=5) + atype = torch.tensor([0, 1, 0, 1, 0], dtype=torch.int64, device=device) + # n_node per frame + n_node = torch.tensor([3, 2], dtype=torch.int64, device=device) + # edge_index: all pairs within each frame (flat indices into [0,4]) + # frame 0: 0↔1, 0↔2, 1↔2 (both directions = 6 edges) + # frame 1: 3↔4 (both directions = 2 edges) + src = torch.tensor([0, 1, 0, 2, 1, 2, 3, 4], dtype=torch.int64, device=device) + dst = torch.tensor([1, 0, 2, 0, 2, 1, 4, 3], dtype=torch.int64, device=device) + edge_index = torch.stack([src, dst], dim=0) # (2, 8) + # edge_vec: random small vectors well within rcut + edge_vec = ( + torch.rand(8, 3, dtype=torch.float64, device=device, generator=rng) * 0.5 + ).detach() + edge_mask = torch.ones(8, dtype=torch.bool, device=device) + return atype, n_node, edge_index, edge_vec, edge_mask + + +class TestGraphRagged: + def setup_method(self) -> None: + self.model = _make_model() + self.model.eval() + self.device = env.DEVICE + self.atype, self.n_node, self.edge_index, self.edge_vec, self.edge_mask = ( + _make_ragged_graph(self.device) + ) + + def test_flat_energy_shapes(self) -> None: + """forward_common_lower_graph returns flat (N,1) energy, (nf,1) energy_redu.""" + ret = self.model.forward_common_lower_graph( + self.atype, + self.n_node, + self.edge_index, + self.edge_vec, + self.edge_mask, + do_atomic_virial=False, + ) + N = int(self.n_node.sum()) # 5 + nf = int(self.n_node.shape[0]) # 2 + # per-atom energy: flat (N, *shap) = (5, 1) + assert ret["energy"].shape == (N, 1), ( + f"expected (5,1) got {ret['energy'].shape}" + ) + # reduced energy: per-frame (nf, *shap) = (2, 1) + assert ret["energy_redu"].shape == (nf, 1), ( + f"expected (2,1) got {ret['energy_redu'].shape}" + ) + # force: flat leading dim N + assert ret["energy_derv_r"].shape[0] == N, ( + f"expected leading dim 5 got {ret['energy_derv_r'].shape}" + ) + # all finite + assert torch.isfinite(ret["energy"]).all() + assert torch.isfinite(ret["energy_redu"]).all() + assert torch.isfinite(ret["energy_derv_r"]).all() + + def test_flat_atom_virial_shapes(self) -> None: + """With do_atomic_virial=True, atom_virial is also flat (N,1,9).""" + ret = self.model.forward_common_lower_graph( + self.atype, + self.n_node, + self.edge_index, + self.edge_vec, + self.edge_mask, + do_atomic_virial=True, + ) + N = int(self.n_node.sum()) # 5 + nf = int(self.n_node.shape[0]) # 2 + assert ret["energy"].shape == (N, 1) + assert ret["energy_redu"].shape == (nf, 1) + assert ret["energy_derv_r"].shape[0] == N + assert ret["energy_derv_c"].shape[0] == N + assert ret["energy_derv_c_redu"].shape[0] == nf + assert torch.isfinite(ret["energy_derv_c"]).all() + assert torch.isfinite(ret["energy_derv_c_redu"]).all() + + def test_invariant_to_charge_spin(self) -> None: + """dpa1 does NOT consume charge_spin (``get_dim_chg_spin() == 0``); + forward_common_lower_graph accepts it only for ABI stability with + charge/spin descriptors (dpa3/dpa4, PR-G), so energy / force / virial / + atom-virial must be INVARIANT to it. + """ + assert self.model.get_descriptor().get_dim_chg_spin() == 0 # dpa1 + args = ( + self.atype, + self.n_node, + self.edge_index, + self.edge_vec, + self.edge_mask, + ) + base = self.model.forward_common_lower_graph(*args, do_atomic_virial=True) + nf = int(self.n_node.shape[0]) + # arbitrary non-None charge/spin -> must NOT change any dpa1 graph output + cs = torch.tensor([[1.0, 2.0]] * nf, dtype=torch.float64, device=self.device) + with_cs = self.model.forward_common_lower_graph( + *args, do_atomic_virial=True, charge_spin=cs + ) + assert set(base) == set(with_cs) + for k, v in base.items(): + if v is None: + assert with_cs[k] is None + else: + torch.testing.assert_close(with_cs[k], v, rtol=1e-12, atol=1e-12) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 6e3f0b97a7..45061c084a 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -1352,7 +1352,9 @@ def _make_varying_config( config = normalize(config) return config - def _check_varying_natoms(self, descriptor: dict | None = None) -> None: + def _check_varying_natoms( + self, descriptor: dict | None = None, force_legacy_descriptor: bool = False + ) -> None: """Per-step compiled-vs-uncompiled comparison for the given descriptor. The loss config has ``start_pref_f=1000`` and ``start_pref_v=1.0``, @@ -1367,6 +1369,18 @@ def _check_varying_natoms(self, descriptor: dict | None = None) -> None: ``atol=rtol=1e-10`` tolerance; if a descriptor's compiled path cannot meet that on float64 the descriptor has a real numerical problem (see the DPA1 limitation note where this happened). + + ``force_legacy_descriptor`` makes a graph-eligible descriptor (dpa1 + ``attn_layer==0``) take the legacy *dense* (env-mat) path on BOTH the + compiled and uncompiled sides, so this stays a true compile-correctness + check (same computation, compiled vs eager). The pt_expt eager default + for such a descriptor is the carry-all GRAPH forward while the compiled + ``forward_lower`` is the sel-capped DENSE forward; those are two + *different* force computations whose parameter gradients agree only to + fp64 accumulation (~1e-12), which the optimizer then amplifies into a + diverging training trajectory. Making the compiled GRAPH lower (so + eager==compiled) is tracked for PR-B; until then this test exercises the + dense path it actually compiles. """ from deepmd.pt_expt.train.training import ( _CompiledModel, @@ -1386,6 +1400,16 @@ def _check_varying_natoms(self, descriptor: dict | None = None) -> None: compiled_model = trainer_c.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) + if force_legacy_descriptor: + # Pin BOTH sides to the legacy dense (env-mat) path so the + # uncompiled reference matches the dense ``forward_lower`` + # that gets compiled (must happen before the first forward, + # i.e. before the lazy compile trace). See the docstring / + # PR-B note: the graph forward vs dense forward differ in the + # backward at fp64 precision, which the optimizer amplifies. + for _m in (trainer_uc.model, compiled_model.original_model): + _m.get_descriptor().uses_graph_lower = lambda: False + # Sync weights so predictions can be compared exactly compiled_model.original_model.load_state_dict( trainer_uc.model.state_dict() @@ -1458,14 +1482,25 @@ def test_compiled_matches_uncompiled_varying_natoms_dpa3(self) -> None: self._check_varying_natoms(_DESCRIPTOR_DPA3) def test_compiled_matches_uncompiled_varying_natoms_dpa1_no_attn(self) -> None: - """DPA1 (attn_layer=0): compiled vs uncompiled match. + """DPA1 (attn_layer=0): compiled vs uncompiled match (dense path). + + ``force_legacy_descriptor=True`` pins both sides to the legacy dense + (env-mat) forward -- the path the compiled ``forward_lower`` actually + uses. The pt_expt eager default for dpa1(attn_layer=0) is the carry-all + GRAPH forward, a *different* force computation from the compiled dense + forward; their backward gradients agree only to fp64 accumulation, which + the optimizer amplifies, so comparing graph-vs-dense through training is + ill-posed. Making the compiled path the GRAPH lower (eager==compiled) + is tracked for PR-B (graph .pt2/export). DPA1 with attention layers is intentionally not covered: the compiled se_atten path is hardware-sensitive on multi-threaded CPUs (parallel reduction order diverges from eager above the 1e-10 tolerance). ``_compile_model`` warns the user instead. """ - self._check_varying_natoms(_DESCRIPTOR_DPA1_NO_ATTN) + self._check_varying_natoms( + _DESCRIPTOR_DPA1_NO_ATTN, force_legacy_descriptor=True + ) def test_compile_warns_dpa1_with_attention(self) -> None: """DPA1 (attn_layer>0) under compile must emit a warning. diff --git a/source/tests/universal/common/cases/cases.py b/source/tests/universal/common/cases/cases.py index d625ef0d35..421cbf4556 100644 --- a/source/tests/universal/common/cases/cases.py +++ b/source/tests/universal/common/cases/cases.py @@ -51,7 +51,7 @@ def setUp(self) -> None: [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 ) self.mapping = np.concatenate( - [self.mapping, self.mapping[:, self.perm]], axis=0 + [self.mapping, inv_perm[self.mapping[:, self.perm]]], axis=0 ) self.mock_descriptor = np.concatenate( [self.mock_descriptor, self.mock_descriptor[:, self.perm[: self.nloc], :]],