Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
1267b1f
feat(dpmodel): per-edge env_mat 4-vector (graph-native EnvMat)
Jun 25, 2026
3c75daf
feat(dpmodel): DescrptBlockSeAtten.call_graph (attn_layer=0, segment_…
Jun 25, 2026
836000d
feat(dpmodel): fail-fast on exclude_types in DescrptBlockSeAtten.call…
Jun 25, 2026
01beb47
feat(dpmodel): DescrptDPA1 dense call -> from-quartet->call_graph ada…
Jun 25, 2026
2ac1306
feat(dpmodel): model.call_lower_graph (energy/atom-energy via segment…
Jun 25, 2026
0a00978
test(dpmodel): remove PR-A import smoke test (no smoke tests in repo)
Jun 25, 2026
61668ef
refactor(dpmodel): single public DescrptDPA1.call_graph; private bloc…
Jun 25, 2026
c22bc13
feat(dpmodel): neighbor_graph_from_ijs + ASE carry-all builder (optio…
Jun 25, 2026
178c174
feat(dpmodel): opt-in carry-all graph energy forward via neighbor_gra…
Jun 25, 2026
09c8b33
refactor(dpmodel): explicit if/else for neighbor_graph_method routing…
Jun 25, 2026
cfebef9
feat(pt_expt): edge_energy_deriv (autograd grad(E,edge_vec) -> edge_f…
Jun 25, 2026
57202ae
feat(pt_expt): forward_common_lower_graph (force/virial via edge_ener…
Jun 25, 2026
6e97423
test(pt_expt): move test_edge_energy_deriv into model/ (mirrors deepm…
Jun 25, 2026
4e426af
test(pt_expt): dpa1(attn_layer=0) graph-path serialize round-trip + i…
Jun 25, 2026
ca77ac3
test(pt_expt): move dpa1 graph serialize test into descriptor/
Jun 25, 2026
48c0ea4
test(pt_expt): drop redundant dpa1 graph serialize test
Jun 25, 2026
33284bb
test: dpa1 graph attn_layer=0 make_fx + type_one_side + multi-frame c…
Jun 25, 2026
f85b0f6
fix(dpmodel): only route graph-ELIGIBLE configs through call_graph; f…
Jun 25, 2026
01f84a1
test(dpmodel): regression-lock graph-ineligible (strip/exclude/no-map…
Jun 25, 2026
5d88e3d
style: fix D209 docstring in dpa1 fallback regression test
Jun 25, 2026
912f054
feat: default eligible dpa1(attn_layer=0) to carry-all graph forward …
Jun 25, 2026
291c4b0
fix: guard descriptor-less atomic models in graph auto-resolve; pt_ex…
Jun 25, 2026
18e26c5
fix: gate carry-all default-flip to pt_expt only; dpmodel/jax keep de…
Jun 25, 2026
b59f5dd
feat(dpmodel): shape-static from_dense_quartet + static n_total -> de…
Jun 25, 2026
663cca6
refactor(dpmodel): explicit if/else for graph-vs-dense routing in Des…
Jun 25, 2026
37f9d4a
fix(dpmodel): address OutisLi #5581 review (spec refs, jax int-sum, A…
Jun 25, 2026
c2e0d96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 25, 2026
c9c8c21
test(pt_expt): parametrize periodic/do_av in test_dpa1_graph_lower (d…
Jun 25, 2026
f7e84fc
refactor(dpmodel): explicit if/else for compact vs shape-static in fr…
Jun 25, 2026
5a487c7
refactor(dpmodel): extract DescrptDPA1.call into thin dispatcher + _c…
Jun 25, 2026
7cfb2a8
feat: general output transform for the graph path (support all fittin…
Jun 25, 2026
58ef3fb
fix: address CodeRabbit review on #5583 (charge_spin/virtual-atom fal…
Jun 25, 2026
99c707a
fix(dpmodel): graph dense-bridge uses neighbor's actual extended type
Jun 25, 2026
53ec9a0
fix(tests): correct invalid permuted mapping; revert graph nei_type t…
Jun 26, 2026
48fe4b1
refactor(dpmodel): extract _finalize_atomic_ret from forward_common_a…
Jun 26, 2026
d136eea
feat(dpmodel): add DPAtomicModel.forward_atomic_graph
Jun 26, 2026
10af472
feat(dpmodel): add BaseAtomicModel.forward_common_atomic_graph
Jun 26, 2026
e182ed9
refactor(dpmodel): call_lower_graph reuses forward_common_atomic_graph
Jun 26, 2026
f597547
refactor(pt_expt): forward_common_lower_graph reuses forward_common_a…
Jun 26, 2026
15fa245
test(dpmodel): feature-flag graph-vs-dense parity matrix + protection…
Jun 26, 2026
52997c1
test(dpmodel): guard protection env-mat test against vacuous pass
Jun 26, 2026
5b0ec90
test(dpmodel): graph applies out-stat (out-bias) identically to dense
Jun 26, 2026
2c82daf
fix: gate model-level pair_exclude_types out of the graph path + tests
Jun 26, 2026
e04bdc2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2026
051f66b
docs(dpmodel): align graph atomic-wrapper docstrings with NumPy style
Jun 26, 2026
8d3a882
feat(dpmodel): frame_id_from_n_node node->frame map for flat-N graph …
Jun 26, 2026
deb524a
refactor(dpmodel): call_graph returns flat (N,...) node axis
Jun 26, 2026
117cfe5
feat(dpmodel): GeneralFitting.call_graph graph-native (flat-N) fittin…
Jun 26, 2026
5076c3a
refactor(dpmodel): forward_atomic_graph uses fitting.call_graph on th…
Jun 26, 2026
613756f
refactor(dpmodel): forward_common_atomic_graph + _finalize_atomic_ret…
Jun 26, 2026
1103264
refactor(dpmodel): flat-N graph output transform (segment_sum over fr…
Jun 26, 2026
1a91746
refactor(pt_expt): flat-N graph output transform + lower; reshape onl…
Jun 26, 2026
71fcb41
fix(dpmodel,pt_expt): graph I/O unravel skips _redu keys (nloc==1 N==…
Jun 26, 2026
1320ca1
test(dpmodel): ragged-native gate + rectangular free-view equivalence
Jun 26, 2026
32a9d4c
feat(dpmodel): graph-native pair-exclude edge mask; graph supports pa…
Jun 26, 2026
3d3d65c
style(dpmodel): hoist dataclasses import to module top (Task 9 review)
Jun 26, 2026
22ba509
style(dpmodel): drop blank line after forward_common_atomic_graph doc…
Jun 26, 2026
c247a74
fix(dpmodel): _finalize_atomic_ret zero-atom reshape (explicit traili…
Jun 26, 2026
e8e9885
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2026
fd9c158
test+refactor(dpmodel): fparam-through-graph test; stronger pair-excl…
Jun 26, 2026
a95e397
refactor(dpmodel): rename DescrptBlockSeAtten._call_graph -> call_gra…
Jun 27, 2026
ff70bac
refactor(dpmodel,pt_expt): graph output transform takes NeighborGraph…
Jun 27, 2026
be0fc97
refactor(dpmodel): symmetric public graph lower (casting + model wrap…
Jun 27, 2026
6987b50
fix(pt_expt): dpa1 varying-natoms compile test compares dense-vs-dense
Jun 27, 2026
8d21609
docs(dpmodel,pt_expt): conform graph-lower docstrings to NumPy conven…
Jun 27, 2026
5271edb
fix(dpmodel,pt_expt): address iProzd review on #5583
Jun 28, 2026
83583a3
test(dpmodel,pt_expt): pin dpa1 graph lower invariance to charge_spin
Jun 28, 2026
52e6c8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2026
3789dbd
test(pt_expt): drop unused N in test_dpa1_graph_lower (CodeQL)
Jun 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 104 additions & 10 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import dataclasses
import functools
import math
from collections.abc import (
Callable,
)
from typing import (
TYPE_CHECKING,
Any,
)

if TYPE_CHECKING:
from deepmd.dpmodel.utils.neighbor_graph import (
NeighborGraph,
)

import array_api_compat
import numpy as np

Expand Down Expand Up @@ -303,23 +310,110 @@ def forward_common_atomic(
comm_dict=comm_dict,
charge_spin=charge_spin,
)
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc)
return self._finalize_atomic_ret(ret_dict, atom_mask, atype)

def forward_common_atomic_graph(
self,
graph: "NeighborGraph",
atype: Array,
fparam: Array | None = None,
aparam: Array | None = None,
charge_spin: Array | None = None,
) -> dict:
"""Graph analogue of :meth:`forward_common_atomic` on the flat node axis.

The node axis is flat ``(N,)`` (``N = sum(graph.n_node)``); masking and
out-stat operate per node. Reuses :meth:`_finalize_atomic_ret`, so
virtual-atom masking, ``atom_excl`` and ``apply_out_stat`` match the dense
path. Model-level ``pair_exclude_types`` is graph-native: when
``self.pair_excl is not None``, an edge-keep mask is ANDed into
``graph.edge_mask`` before the descriptor forward, so excluded type-pairs
contribute zero to the segment_sum. Descriptor-level ``exclude_types`` is
gated by ``uses_graph_lower()==False``.

Parameters
----------
graph
neighbor graph for the local atoms (ghost-free)
atype
flat local atom types. N
fparam
frame parameter. nf x ndf
aparam
atomic parameter. N x nda
charge_spin
charge/spin conditioning. Unused by the dpa1 graph path; accepted so
the interface stays stable for charge/spin-conditioned descriptors.

Returns
-------
result_dict
the result dict on the flat node axis, defined by the `FittingOutputDef`.

"""
xp = array_api_compat.array_namespace(graph.edge_vec)
atype = xp.asarray(atype, device=array_api_compat.device(graph.edge_vec))
atom_mask = self.make_atom_mask(atype) # (N,) bool
atype_clamped = xp.where(atom_mask, atype, xp.zeros_like(atype))
if self.pair_excl is not None:
keep = self.pair_excl.build_edge_exclude_mask(
graph.edge_index, atype_clamped
)
graph = dataclasses.replace(
graph,
edge_mask=graph.edge_mask * xp.astype(keep, graph.edge_mask.dtype),
)
ret_dict = self.forward_atomic_graph(
graph,
atype_clamped,
fparam=fparam,
aparam=aparam,
charge_spin=charge_spin,
)
return self._finalize_atomic_ret(ret_dict, atom_mask, atype)

def _finalize_atomic_ret(
self, ret_dict: dict, atom_mask: Array, atype: Array
) -> dict:
"""Apply out-stat, atom exclusion and virtual-atom zeroing; set ``mask``.

Shared by the dense (:meth:`forward_common_atomic`, ``(nf, nloc)`` leading
dims) and graph (:meth:`forward_common_atomic_graph`, flat ``(N,)`` leading
dim) wrappers -- leading-dim-agnostic.

Parameters
----------
ret_dict
the raw per-atom result dict from ``forward_atomic``/``forward_atomic_graph``
atom_mask
the real-atom mask, True for real and False for virtual atoms. leading dims
atype
the local atom types, used for out-stat and ``atom_excl``. leading dims

Returns
-------
result_dict
``ret_dict`` with out-stat applied, virtual and excluded atoms zeroed,
and the integer ``mask`` key set.

"""
xp = array_api_compat.array_namespace(atype)
ret_dict = self.apply_out_stat(ret_dict, atype)
if self.atom_excl is not None:
atom_mask = xp.logical_and(
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
)

lead = atom_mask.shape # (nf, nloc) dense | (N,) graph
for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = math.prod(out_shape[2:])
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
out = ret_dict[kk]
# explicit trailing product (NOT -1): a zero-atom forward (nloc==0)
# has size 0, and numpy cannot infer -1 for a size-0 array.
trail = math.prod(out.shape[len(lead) :])
flat = xp.reshape(out, (*lead, trail))
flat = xp.where(atom_mask[..., None], flat, xp.zeros_like(flat))
ret_dict[kk] = xp.reshape(flat, out.shape)
ret_dict["mask"] = xp.astype(atom_mask, xp.int32)

return ret_dict

def call(
Expand Down
60 changes: 60 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
Callable,
)
from typing import (
TYPE_CHECKING,
Any,
)

if TYPE_CHECKING:
from deepmd.dpmodel.utils.neighbor_graph import (
NeighborGraph,
)

from deepmd.dpmodel.array_api import (
Array,
xp_take_first_n,
Expand Down Expand Up @@ -248,6 +254,60 @@ def forward_atomic(
)
return ret

def forward_atomic_graph(
self,
graph: "NeighborGraph",
atype: Array,
fparam: Array | None = None,
aparam: Array | None = None,
charge_spin: Array | None = None,
) -> dict[str, Array]:
Comment thread
wanghan-iapcm marked this conversation as resolved.
"""Graph analogue of :meth:`forward_atomic` on the flat node axis.

Runs the descriptor ``call_graph`` then the fitting ``call_graph`` PER NODE
and returns the raw fitting dict on the flat ``(N, *)`` axis (no reduction
or masking; the wrapper handles those). ``fparam`` is gathered to nodes by
``frame_id`` so each node sees its frame's parameter.

Parameters
----------
graph
neighbor graph for the local atoms (ghost-free)
atype
flat local atom types. N
fparam
frame parameter. nf x ndf
aparam
atomic parameter. N x nda
charge_spin
charge/spin conditioning. Unused by the dpa1 graph path; accepted so
the interface stays stable for charge/spin-conditioned descriptors.

Returns
-------
result_dict
the result dict on the flat node axis, defined by the `FittingOutputDef`.

"""
import array_api_compat

from deepmd.dpmodel.utils.neighbor_graph import (
frame_id_from_n_node,
)

xp = array_api_compat.array_namespace(graph.edge_vec)
type_embedding = self.descriptor.type_embedding.call()
gg, rot_mat = self.descriptor.call_graph(
graph, atype, type_embedding=type_embedding
)
fparam_node = None
if fparam is not None:
frame_id = frame_id_from_n_node(graph.n_node)
fparam_node = xp.take(fparam, frame_id, axis=0) # (N, ndf)
return self.fitting_net.call_graph(
gg, atype, gr=rot_mat, g2=None, h2=None, fparam=fparam_node, aparam=aparam
)

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict]],
Expand Down
10 changes: 5 additions & 5 deletions deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def apply_out_stat(
out_bias, out_std = self._fetch_out_stat(self.bias_keys)

if self.fitting_net.shift_diag:
nframes, nloc = atype.shape
dtype = out_bias[self.bias_keys[0]].dtype
device = array_api_compat.device(out_bias[self.bias_keys[0]])
for kk in self.bias_keys:
Expand All @@ -57,16 +56,17 @@ def apply_out_stat(
)
modified_bias = temp[atype]

# (nframes, nloc, 1)
# (..., 1) -- (nframes, nloc, 1) or (N, 1)
modified_bias = (
modified_bias[..., xp.newaxis] * (self.fitting_net.scale[atype])
)

eye = xp.eye(3, dtype=dtype, device=device)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
# leading-dim-agnostic: (nf, nloc) dense or (N,) flat graph path
eye = xp.tile(eye, (*atype.shape, 1, 1))
# (..., 3, 3)
modified_bias = modified_bias[..., xp.newaxis] * eye

# nf x nloc x odims, out_bias: ntypes x odims
# nf x nloc x odims (rect) or N x odims (flat), out_bias: ntypes x odims
ret[kk] = ret[kk] + modified_bias
return ret
Loading
Loading