Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
b87ac50
feat(pt_expt): forward_common_lower_graph_exportable trace target for…
Jun 29, 2026
ee8db1b
fix(pt_expt): B1.1 review — test forward_lower_graph_exportable (both…
Jun 29, 2026
7437d54
test(dpmodel): codify static edge_capacity contract for build_neighbo…
Jun 29, 2026
1051a0d
feat(pt_expt): graph .pt2 export branch + lower_input_kind metadata
Jun 29, 2026
148fa0e
fix(pt_expt): B1.3 review — persist static edge_capacity (E_max) in g…
Jun 29, 2026
ce2fd12
test(pt_expt): graph .pt2 DeepEval parity vs eager dense dpa1 (pbc+no…
Jun 29, 2026
e35fc38
fix(pt_expt): compiled training runs the graph lower (eager==compiled…
Jun 29, 2026
47fb700
docs(pt_expt): B1 final-review minors — document nloc==1 unravel-skip…
Jun 29, 2026
b046874
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2026
0d3860e
feat(pt_expt): graph .pt2 uses a dynamic edge axis (Dim(nedge)); drop…
Jun 29, 2026
40487c4
docs(pt_expt): B2.0 review — _strip_shape_assertions now documents th…
Jun 29, 2026
9a8727b
test(infer): extend gen_dpa1.py with graph-eligible dpa1(attn_layer=0…
Jun 29, 2026
f97129c
feat(api_cc): graph-schema .pt2 ingestion in DeepPotPTExpt (single-rank)
Jun 29, 2026
074b3ff
fix(api_cc): cache mapping vector as member to fix OOB on ago>0 graph…
Jun 29, 2026
26b2c9d
test(api_cc): dpa1 graph .pt2 single-rank parity + fix graph output e…
Jun 29, 2026
af92be1
docs(infer): B2 final-review — correct gen_dpa1 graph-reference docst…
Jun 29, 2026
b25fdfc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 29, 2026
10d82a6
fix(api_cc): guard graph remap single-rank-only + atomic-overload gte…
Jun 30, 2026
7d37319
feat(api_cc): non-MP multi-rank graph path (extended region + reverse…
Jun 30, 2026
92c35a6
test(lammps): dpa1 graph .pt2 single + multi-rank (mpirun -n 2, local)
Jun 30, 2026
e2e07f1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2026
942de1f
refactor(pt_expt): group node_capacity with graph-shape args, make ed…
Jun 30, 2026
7658091
fix(dpmodel): clamp graph edge-scatter indices in-bounds (CUDA device…
Jun 30, 2026
0f43731
fix(dpmodel): export-safe modulo clamp for graph edge-scatter indices
Jun 30, 2026
afda4c7
fix(pt_expt): address AI review (CodeQL + CodeRabbit) on #5604
Jun 30, 2026
0ea2c34
feat(pt_expt): dp freeze --lower-kind {nlist,graph} for graph .pt2 ex…
Jun 30, 2026
7a50e60
test(api_cc): add deeppot_dpa1_graph.pt2 to universal/variant battery
Jun 30, 2026
b4c0b49
fix(pt_expt): address iProzd review — graph freeze defaults to .pt2 +…
Jul 1, 2026
4a76f6d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2026
c095e12
docs(dpmodel): document the graph edge-scatter modulo as the permanen…
Jul 1, 2026
3348c80
refactor(pt_expt): consolidate graph trace/sample builders (OutisLi r…
Jul 1, 2026
282f641
perf(api_cc): cache graph edge topology across steps + guard empty ra…
Jul 1, 2026
70b02fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,13 @@ def call_graph(
)
# FLAT node axis (N, ...): no (nf, nloc) reshape -- ragged-native, spec.
if self.concat_output_tebd:
tebd = xp.asarray(type_embedding, device=dev)
# Use type_embedding directly (mirrors the dense path's
# ``xp.take(type_embedding, ...)``): ``xp.asarray(..., device=dev)``
# DETACHES under torch, silently severing the type-embedding weight
# gradient so the tebd net never trains; type_embedding already lives
# on the model device, so the device cast was redundant anyway.
atype_local = xp.asarray(atype, device=dev)
atype_embd = xp.take(tebd, atype_local, axis=0) # (N, tebd_dim)
atype_embd = xp.take(type_embedding, atype_local, axis=0) # (N, tebd_dim)
grrg = xp.concat([grrg, atype_embd], axis=-1)
return grrg, rot_mat

Expand Down Expand Up @@ -1523,7 +1527,10 @@ def call_graph(
ss = rr[:, 0:1] # (E, 1)
# neighbor / center type embeddings (concat mode); ghost type == owner type
# so gathering by the LOCAL owner (src) reproduces the dense neighbor tebd.
tebd = xp.asarray(type_embedding, device=dev)
# NB: do NOT wrap in ``xp.asarray(..., device=dev)`` -- that DETACHES under
# torch and severs the type-embedding weight gradient (the tebd net would
# never train); type_embedding already lives on the model device.
tebd = type_embedding
atype_embd_nlist = xp.take(tebd, nei_type, axis=0) # (E, tebd_dim)
if not self.type_one_side:
atype_embd_nnei = xp.take(tebd, center_type, axis=0) # (E, tebd_dim)
Expand Down
41 changes: 35 additions & 6 deletions deepmd/dpmodel/utils/neighbor_graph/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,42 @@ def edge_force_virial(
frame via the frame of their ``dst`` node.
"""
xp = array_api_compat.array_namespace(g_e)
# node-axis size; when a static ``node_capacity`` is supplied (the jax/export
# path) short-circuit so we never call int() on the traced ``sum(n_node)``.
n_out = int(node_capacity) if node_capacity is not None else int(xp.sum(n_node))
# node-axis size; when a ``node_capacity`` is supplied (the jax/export path)
# use it AS-IS so we never call int() on the traced ``sum(n_node)`` -- and,
# crucially, never on ``node_capacity`` itself: under symbolic make_fx /
# torch.export it is a SymInt (``atype.shape[0]``); ``int(SymInt)`` would
# SPECIALIZE the node axis to the trace-time sample size, baking a constant
# ``N`` into the scatter and breaking dynamic-``N`` inference.
n_out = node_capacity if node_capacity is not None else int(xp.sum(n_node))
nf = n_node.shape[0]
# zero padding/guard contributions; cast mask to g's dtype (array-API pure,
# CLAUDE.md mask-multiply guideline — avoids bool*float under array_api_strict)
g = g_e * xp.astype(edge_mask[:, None], g_e.dtype)
src = edge_index[0]
dst = edge_index[1]
# Wrap node indices into ``[0, n_out)`` so every scatter address is provably
# in-bounds. For a well-formed graph every real edge already has
# ``index < n_out`` (== ``atype.shape[0]``), so this modulo is the IDENTITY on
# real edges (pinned by test_modulo_clamp_leaves_real_edges_unchanged) -- a
# correctness-preserving guard, not a value fixup.
#
# Why it is needed (root cause, GPU-confirmed): under the dynamic-edge graph
# ``torch.export`` path the node count is traced as several equal-but-distinct
# symbols (``atype.shape[0]``, ``fit_ret.shape[0]``, ...), tied only by
# ``aten._assert_scalar(Eq(...))`` nodes. ``_strip_shape_assertions``
# (pt_expt/utils/serialization.py) neutralises ALL such asserts so export can
# trace -- which also drops those node-count equalities, so inductor can no
# longer prove the scatter index and its bound ``ks0 == n_out`` share a symbol
# and emits ``tl.device_assert(idx < ks0)`` (fatal on CUDA; unchecked on CPU,
# which is why all CPU dev/CI was green). ``% n_out`` discharges that guard
# unconditionally. This is the PERMANENT fix: the upstream alternative --
# making the SHARED, spin-export-critical ``_strip_shape_assertions``
# selective -- risks re-triggering the torch.export bugs it exists to bypass
# and the spin ``.pt2`` path, so it is deliberately NOT taken.
#
# Pure arithmetic => torch.export-safe, unlike ``xp.clip`` (SymInt bound
# breaks array_api_compat's clip) and unlike a mask-multiply (which misses the
# ``edge_mask == 1`` indices the stripped guard mis-bounds).
src = edge_index[0] % n_out
dst = edge_index[1] % n_out
# force (output sized to the node axis, incl. any padding tail)
force = segment_sum(g, dst, n_out) - segment_sum(g, src, n_out)
# per-edge virial w_e[k, j] = -g_e[k] * edge_vec[j] (broadcast, no einsum)
Expand All @@ -101,6 +128,8 @@ def edge_force_virial(
boundaries = xp.cumulative_sum(n_node) # (nf,) per-frame node upper bounds
edge_frame = xp.astype(
xp.searchsorted(boundaries, dst, side="right"), xp.int64
) # (E,) in [0, nf)
) # (E,) in [0, nf]
# wrap into [0, nf) for the same CUDA-bounds reason (export-safe modulo)
edge_frame = edge_frame % nf
virial = segment_sum(w_edge, edge_frame, nf) # (nf, 3, 3)
return force, atom_virial, virial
9 changes: 7 additions & 2 deletions deepmd/dpmodel/utils/neighbor_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,18 @@ def frame_id_from_n_node(n_node: Array, n_total: int | None = None) -> Array:
dev = array_api_compat.device(n_node)
if n_total is None:
n_total = int(xp.sum(n_node))
nf = n_node.shape[0]
idx = xp.arange(n_total, dtype=n_node.dtype, device=dev)
boundaries = xp.cumulative_sum(n_node) # (nf,) upper bounds, exclusive
frame_id = xp.astype(xp.searchsorted(boundaries, idx, side="right"), xp.int64)
# padding nodes (idx >= sum(n_node)) land at frame ``nf`` (OOB); clamp them to
# the last real frame so the per-frame scatter never indexes out of range.
return xp.minimum(frame_id, xp.asarray(nf - 1, dtype=xp.int64, device=dev))
# Derive ``nf - 1`` as a RUNTIME 0-d tensor (sum of ones over the frame axis)
# rather than ``xp.asarray(n_node.shape[0] - 1)``: under symbolic make_fx /
# torch.export, ``shape[0]`` is a SymInt and materializing it into a constant
# tensor SPECIALIZES the frame axis -- baking the trace-time frame count into
# every downstream per-frame reduction and breaking dynamic-``nf`` inference.
last_frame = xp.sum(xp.ones_like(n_node)) - 1 # 0-d int == nf - 1
return xp.minimum(frame_id, xp.astype(last_frame, xp.int64))


def node_validity_mask(n_node: Array, n_total: int) -> Array:
Expand Down
10 changes: 10 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,16 @@ def main_parser() -> argparse.ArgumentParser:
type=str,
help="(Supported backend: PyTorch) Task head (alias: model branch) to freeze if in multi-task mode.",
)
parser_frz.add_argument(
"--lower-kind",
default="nlist",
type=str,
choices=["nlist", "graph"],
help="(Supported backend: PyTorch Exportable) Lower-level export form of the "
"frozen .pt2: 'nlist' (default, dense neighbor-list lower) or 'graph' "
"(NeighborGraph edge-list lower; only for graph-eligible models, currently "
"dpa1 with attn_layer=0). 'graph' selects the C++ graph inference path.",
)

# * test script ********************************************************************
parser_tst = subparsers.add_parser(
Expand Down
46 changes: 43 additions & 3 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def freeze(
model: str,
output: str = "frozen_model.pte",
head: str | None = None,
lower_kind: str = "nlist",
) -> None:
"""Freeze a pt_expt checkpoint into a .pte exported model.

Expand All @@ -398,6 +399,13 @@ def freeze(
Path for the output .pte file.
head : str or None
Head to freeze in multi-task mode.
lower_kind : str
Lower-level export form: ``"nlist"`` (default, dense neighbor-list lower)
or ``"graph"`` (NeighborGraph edge-list lower). ``"graph"`` is only valid
for graph-eligible models (``mixed_types`` and ``uses_graph_lower``,
currently dpa1 with ``attn_layer == 0``) and selects the C++ graph
inference path; the per-atom virial is enabled for it (near-free in the
graph path: one extra scatter off the shared single backward).
"""
import torch

Expand Down Expand Up @@ -458,12 +466,34 @@ def freeze(
single_model_params = model_params

m.eval()

# The graph lower is opt-in and only valid for graph-eligible models (dpa1
# attn_layer==0 today). Fail fast with a clear message rather than emitting a
# broken .pt2. Enable the per-atom virial for the graph form -- it is
# near-free there (one extra scatter off the single shared backward).
do_atomic_virial = False
if lower_kind == "graph":
from deepmd.pt_expt.train.training import (
_model_uses_graph_lower,
)

if not _model_uses_graph_lower(m):
raise ValueError(
"lower_kind='graph' requires a graph-eligible model "
"(mixed_types and a descriptor exposing uses_graph_lower()==True, "
"currently dpa1 with attn_layer==0). Use lower_kind='nlist' for "
"this model."
)
do_atomic_virial = True

model_dict_serialized = m.serialize()
deserialize_to_file(
output,
{"model": model_dict_serialized, "model_def_script": single_model_params},
do_atomic_virial=do_atomic_virial,
lower_kind=lower_kind,
)
log.info("Saved frozen model to %s", output)
log.info("Saved frozen model to %s (lower_kind=%s)", output, lower_kind)


def change_bias(
Expand Down Expand Up @@ -701,9 +731,19 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
f"Checkpoint path '{model_path}' does not exist."
)
FLAGS.model = str(model_path)
_lower_kind = getattr(FLAGS, "lower_kind", "nlist")
if not FLAGS.output.endswith((".pte", ".pt2")):
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
freeze(model=FLAGS.model, output=FLAGS.output, head=FLAGS.head)
# Default suffix: .pt2 for the graph export (an AOTI .pt2 archive is
# what the C++ graph path consumes), .pte otherwise. Explicit user
# .pte / .pt2 suffixes are preserved for both.
_default_suffix = ".pt2" if _lower_kind == "graph" else ".pte"
FLAGS.output = str(Path(FLAGS.output).with_suffix(_default_suffix))
freeze(
model=FLAGS.model,
output=FLAGS.output,
head=FLAGS.head,
lower_kind=_lower_kind,
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
elif FLAGS.command == "change-bias":
change_bias(
input_file=FLAGS.INPUT,
Expand Down
113 changes: 113 additions & 0 deletions deepmd/pt_expt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,20 @@
import ase.neighborlist


# Public output keys emitted by the graph-form AOTI forward
# (``forward_lower_graph_exportable``) keyed by the output-variable category that
# ``request_defs`` carries. The graph path is LOCAL-only (``N == sum(n_node)``
# nodes, no ghosts), so its outputs are already at local-atom resolution -- no
# ``communicate_extended_output`` fold-back is needed.
#
# ``_eval_model_graph`` looks each requested category up here with ``.get``: a
# category that has no entry in this table, or whose key is missing from the
# dict returned by the forward, is reported as an all-NaN array of the expected
# output shape rather than raising.
_GRAPH_CATEGORY_TO_KEY = {
OutputVariableCategory.OUT: "atom_energy",
OutputVariableCategory.REDU: "energy",
OutputVariableCategory.DERV_R: "force",
OutputVariableCategory.DERV_C_REDU: "virial",
OutputVariableCategory.DERV_C: "atom_virial",
}


def _reshape_charge_spin(
charge_spin: np.ndarray, nframes: int, dim_chg_spin: int
) -> np.ndarray:
Expand Down Expand Up @@ -1423,6 +1437,10 @@ def _eval_model(
request_defs: list[OutputVariableDef],
charge_spin: np.ndarray | None = None,
) -> tuple[np.ndarray, ...]:
if self.metadata.get("lower_input_kind") == "graph":
return self._eval_model_graph(
coords, cells, atom_types, fparam, aparam, request_defs, charge_spin
)
model_inputs, mapping_t, nframes, natoms = self._prepare_inputs(
coords, cells, atom_types, fparam, aparam, charge_spin
)
Expand Down Expand Up @@ -1621,6 +1639,101 @@ def _eval_model_spin(
)
return tuple(results)

def _eval_model_graph(
    self,
    coords: np.ndarray,
    cells: np.ndarray | None,
    atom_types: np.ndarray,
    fparam: np.ndarray | None,
    aparam: np.ndarray | None,
    request_defs: list[OutputVariableDef],
    charge_spin: np.ndarray | None = None,
) -> tuple[np.ndarray, ...]:
    """Run inference through a graph-schema export (``lower_input_kind == "graph"``).

    A carry-all
    :class:`~deepmd.dpmodel.utils.neighbor_graph.NeighborGraph` is built from
    the evaluated system at its exact (tight) edge count, and the exported
    forward is called with the positional schema
    ``(atype, n_node, edge_index, edge_vec, edge_mask, fparam, aparam,
    charge_spin)``. The edge axis of the AOTI artifact is DYNAMIC (B2.0), so
    the graph is not padded to an ``edge_capacity``. Because the forward
    already returns the LOCAL public keys, the outputs are only reshaped --
    ``communicate_extended_output`` is not involved.

    Parameters
    ----------
    coords : np.ndarray
        Coordinates; reshaped to ``(nframes, natoms, 3)`` with
        ``nframes = coords.shape[0]``.
    cells : np.ndarray or None
        Cell vectors, reshaped to ``(nframes, 9)``; ``None`` is forwarded to
        the graph builder unchanged.
    atom_types : np.ndarray
        Atom types, either 1-D (shared by all frames and tiled to
        ``(nframes, natoms)``) or already one row per frame.
    fparam : np.ndarray or None
        Frame parameters, forwarded to ``_prepare_optional_lower_inputs``.
    aparam : np.ndarray or None
        Atomic parameters, forwarded to ``_prepare_optional_lower_inputs``.
    request_defs : list[OutputVariableDef]
        Requested outputs, in the order the results are returned.
    charge_spin : np.ndarray or None
        Optional charge/spin input, forwarded to ``_make_charge_spin_input``.

    Returns
    -------
    tuple[np.ndarray, ...]
        One array per entry of ``request_defs``. An output the graph forward
        does not provide is returned as an all-NaN array of the expected
        shape.
    """
    from deepmd.dpmodel.utils.neighbor_graph import (
        build_neighbor_graph,
    )
    from deepmd.pt_expt.utils.env import (
        DEVICE,
    )

    def _on_device(values, dtype):
        # Every graph input goes through the same numpy -> torch hand-off.
        return torch.tensor(np.asarray(values), dtype=dtype, device=DEVICE)

    nframes = coords.shape[0]
    if len(atom_types.shape) != 1:
        # Per-frame types were supplied: take the atom count from the first row.
        natoms = len(atom_types[0])
    else:
        # A single type vector is shared by every frame: replicate it.
        natoms = len(atom_types)
        atom_types = np.tile(atom_types, nframes).reshape(nframes, -1)

    frame_coords = coords.reshape(nframes, natoms, 3)
    frame_boxes = None if cells is None else cells.reshape(nframes, 9)
    # Dynamic edge axis (B2.0): the carry-all graph keeps its exact edge count
    # (no static padding) since the AOTI artifact accepts any E.
    nbr_graph = build_neighbor_graph(
        frame_coords,
        atom_types,
        frame_boxes,
        self._rcut,
    )

    flat_atype = _on_device(np.asarray(atom_types).reshape(-1), torch.int64)
    node_counts = _on_device(nbr_graph.n_node, torch.int64)
    edge_endpoints = _on_device(nbr_graph.edge_index, torch.int64)
    edge_vectors = _on_device(nbr_graph.edge_vec, torch.float64)
    edge_valid = _on_device(nbr_graph.edge_mask, torch.bool)

    fparam_in, aparam_in = self._prepare_optional_lower_inputs(
        fparam, aparam, nframes, natoms, DEVICE
    )
    charge_spin_in = self._make_charge_spin_input(nframes, charge_spin)

    # Positional order must match the exported forward's input schema.
    forward_args = (
        flat_atype,
        node_counts,
        edge_endpoints,
        edge_vectors,
        edge_valid,
        fparam_in,
        aparam_in,
        charge_spin_in,
    )
    runner = self._pt2_runner if self._is_pt2 else self.exported_module
    model_ret = runner(*forward_args)

    outputs = []
    for odef in request_defs:
        out_shape = self._get_output_shape(odef, nframes, natoms)
        public_key = _GRAPH_CATEGORY_TO_KEY.get(odef.category)
        tensor = None if public_key is None else model_ret.get(public_key)
        if tensor is None:
            # Not produced by the graph forward: report NaN in the right shape.
            outputs.append(
                np.full(np.abs(out_shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION)
            )
        else:
            outputs.append(tensor.detach().cpu().numpy().reshape(out_shape))
    return tuple(outputs)

def _get_output_shape(
self, odef: OutputVariableDef, nframes: int, natoms: int
) -> list[int]:
Expand Down
Loading
Loading