Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 225 additions & 10 deletions deepmd/pt/entrypoints/freeze_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@
annotations,
)

import ctypes
import json
import logging
import os
import tempfile
import zipfile
from copy import (
deepcopy,
Expand Down Expand Up @@ -229,6 +232,7 @@ def _collect_metadata(
output_keys: list[str],
is_spin: bool | None = None,
do_atomic_virial: bool = False,
has_comm_artifact: bool = False,
) -> dict:
"""Assemble the flat metadata dict expected by :class:`DeepPotPTExpt`.

Expand Down Expand Up @@ -272,7 +276,7 @@ def _collect_metadata(
"dim_chg_spin": int(model.get_dim_chg_spin()),
"mixed_types": bool(model.mixed_types()),
"has_message_passing": _model_has_message_passing(model),
"has_comm_artifact": False,
"has_comm_artifact": bool(has_comm_artifact),
"do_atomic_virial": exports_atomic_virial,
"nnei": int(sum(model.get_sel())),
"has_default_fparam": bool(model.has_default_fparam()),
Expand All @@ -291,16 +295,25 @@ def _collect_metadata(
return metadata


def _make_sample_inputs(
# The trace-time sendlist for the with-comm artifact embeds the address of a
# numpy array (``int**`` contract of ``border_op``). The array must outlive the
# trace + export call; the exported graph never reads it at runtime (the op is
# opaque), so a module-level keepalive is sufficient.
_TRACE_SENDLIST_KEEPALIVE: list[np.ndarray] = []


def _build_sample_extended(
model: torch.nn.Module,
nframes: int,
nloc: int,
device: torch.device,
has_spin: bool = False,
has_spin: bool,
) -> tuple[torch.Tensor | None, ...]:
"""Build representative ``forward_common_lower`` inputs for tracing.
"""Build the extended-region sample tensors shared by the lower builders.

Tensors are float64 / int64 (matching the ``.pt2`` I/O contract).
Returns ``(ext_coord, ext_atype, nlist, mapping, ext_spin, fparam, aparam,
charge_spin)``; tensors are float64 / int64 (matching the ``.pt2`` I/O
contract). ``ext_spin`` is ``None`` unless ``has_spin``.
"""
rcut = float(model.get_rcut())
sel = list(model.get_sel())
Expand Down Expand Up @@ -353,6 +366,7 @@ def _make_sample_inputs(
ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=device)
nlist_t = torch.tensor(nlist, dtype=torch.int64, device=device)
mapping_t = torch.tensor(mapping, dtype=torch.int64, device=device)
ext_spin = None
if has_spin:
extended_spin = np.take_along_axis(spin_np, mapping[..., None], axis=1)
ext_spin = torch.tensor(extended_spin, dtype=torch.float64, device=device)
Expand All @@ -366,11 +380,45 @@ def _make_sample_inputs(
if dim_aparam > 0
else None
)
charge_spin = None
if dim_chg_spin > 0:
charge_spin = torch.zeros(
nframes, dim_chg_spin, dtype=torch.float64, device=device
)
charge_spin = (
torch.zeros(nframes, dim_chg_spin, dtype=torch.float64, device=device)
if dim_chg_spin > 0
else None
)
return (
ext_coord,
ext_atype,
nlist_t,
mapping_t,
ext_spin,
fparam,
aparam,
charge_spin,
)


def _make_sample_inputs(
model: torch.nn.Module,
nframes: int,
nloc: int,
device: torch.device,
has_spin: bool = False,
) -> tuple[torch.Tensor | None, ...]:
"""Build representative ``forward_common_lower`` inputs for tracing.

The spin path returns the nlist lower signature; the energy path returns the
single-domain edge schema (folded ``edge_index``, extended scatter indices).
"""
(
ext_coord,
ext_atype,
nlist_t,
mapping_t,
ext_spin,
fparam,
aparam,
charge_spin,
) = _build_sample_extended(model, nframes, nloc, device, has_spin)
if has_spin:
return (
ext_coord,
Expand Down Expand Up @@ -402,6 +450,83 @@ def _make_sample_inputs(
)


def _make_edge_comm_tensors(
    mapping: torch.Tensor,
    nloc: int,
    device: torch.device,
) -> tuple[torch.Tensor, ...]:
    """Assemble the eight communication tensors of a one-swap, self-send plan.

    The plan exists only so the with-comm trace can execute ``border_op`` with
    in-range indices; the real per-swap plan is supplied by LAMMPS at inference
    time. The send list is filled with the owner indices
    ``mapping[0, nloc:nall]`` of the ghost slots.

    Parameters
    ----------
    mapping
        Index map whose frame 0 is read; ``mapping.shape[1]`` is taken as
        ``nall``.
    nloc
        Number of local atoms; ``nall - nloc`` is used as the ghost count.
    device
        Device on which every returned tensor is created.

    Returns
    -------
    tuple[torch.Tensor, ...]
        ``(send_list, send_proc, recv_proc, send_num, recv_num, communicator,
        nlocal, nghost)``, in that order.

    Notes
    -----
    The index buffer is appended to ``_TRACE_SENDLIST_KEEPALIVE`` because
    ``send_list`` stores only its raw address, not a reference to the array.
    """
    nall = int(mapping.shape[1])
    nghost = nall - nloc
    # Always send at least one entry, even when there are no ghost atoms.
    send_count = nghost if nghost > 1 else 1

    ghost_owners = mapping[0, nloc:nall].to(dtype=torch.int32).cpu().numpy()
    send_indices = np.resize(ghost_owners, send_count).astype(np.int32)
    send_indices = np.ascontiguousarray(send_indices)
    # Pin the buffer: only its address travels inside ``send_list``.
    _TRACE_SENDLIST_KEEPALIVE.append(send_indices)
    buffer_addr = send_indices.ctypes.data_as(ctypes.c_void_p).value

    as_int32 = {"dtype": torch.int32, "device": device}
    as_int64 = {"dtype": torch.int64, "device": device}

    send_list = torch.tensor([buffer_addr], **as_int64)  # int** contract
    send_proc = torch.zeros(1, **as_int32)  # self
    recv_proc = torch.zeros(1, **as_int32)  # self
    send_num = torch.tensor([send_count], **as_int32)
    recv_num = torch.tensor([send_count], **as_int32)
    communicator = torch.zeros(1, **as_int64)
    nlocal_t = torch.tensor(nloc, **as_int32)
    nghost_t = torch.tensor(nghost, **as_int32)
    return (
        send_list,
        send_proc,
        recv_proc,
        send_num,
        recv_num,
        communicator,
        nlocal_t,
        nghost_t,
    )


def _make_comm_sample_inputs(
    model: torch.nn.Module,
    nloc: int,
    device: torch.device,
) -> tuple[torch.Tensor | None, ...]:
    """Create the trace inputs for the parallel (with-comm) ``.pt2`` artifact.

    The parallel path addresses the extended node set directly, so the same
    extended ``edge_scatter_index`` is passed for both the ``edge_index`` and
    the ``edge_scatter_index`` slots; ghost features are refreshed through
    ``border_op`` instead of a folded mapping. A single frame is always built,
    matching LAMMPS single-frame inference.

    Parameters
    ----------
    model
        Model providing ``format_nlist`` and the quantities read by
        ``_build_sample_extended``.
    nloc
        Number of local atoms in the sample system.
    device
        Device on which the sample tensors are created.

    Returns
    -------
    tuple[torch.Tensor | None, ...]
        Seven node/edge tensors, then ``fparam``, ``aparam`` and
        ``charge_spin`` (each possibly ``None``), then the eight communication
        tensors from ``_make_edge_comm_tensors``.
    """
    extended_sample = _build_sample_extended(
        model, nframes=1, nloc=nloc, device=device, has_spin=False
    )
    # The spin slot (index 4) is not used on this path.
    ext_coord, ext_atype, nlist_t, mapping_t = extended_sample[:4]
    fparam, aparam, charge_spin = extended_sample[5:]

    formatted_nlist: torch.Tensor = model.format_nlist(ext_coord, ext_atype, nlist_t)
    local_atype = ext_atype[:, :nloc]
    schema = edge_schema_from_extended(
        ext_coord,
        local_atype,
        formatted_nlist,
        mapping_t,
    )

    node_and_edge_inputs = (
        schema.coord,  # (1, nall, 3)
        schema.atype,  # (1, nloc)
        ext_atype,  # (1, nall)
        schema.edge_scatter_index,  # edge_index: extended (2, E)
        schema.edge_vec,
        schema.edge_scatter_index,  # edge_scatter_index: extended (2, E)
        schema.edge_mask,
    )
    optional_inputs = (fparam, aparam, charge_spin)
    comm_inputs = _make_edge_comm_tensors(mapping_t, nloc, device)
    return (*node_and_edge_inputs, *optional_inputs, *comm_inputs)


def _resolve_nframes(
model: torch.nn.Module,
nloc: int,
Expand Down Expand Up @@ -489,6 +614,79 @@ def _build_dynamic_shapes(
return shapes


def _build_with_comm_dynamic_shapes(
sample_inputs: tuple[torch.Tensor | None, ...],
) -> tuple:
"""Build dynamic-shape constraints for the parallel with-comm lower input.

The frame axis is fixed at one (LAMMPS single-frame inference), so only
``nall``, ``nloc`` and ``nedge`` vary. The eight communication tensors are
static: ``nswap`` is fixed at LAMMPS init and the graph carries no variation
across its value (``border_op`` is opaque to the exported program).
"""
nall_dim = torch.export.Dim("nall", min=1)
nloc_dim = torch.export.Dim("nloc", min=1)
nedge_dim = torch.export.Dim("nedge", min=2)
fparam = sample_inputs[7]
aparam = sample_inputs[8]
charge_spin = sample_inputs[9]
base = (
{1: nall_dim}, # coord: (1, nall, 3)
{1: nloc_dim}, # atype: (1, nloc)
{1: nall_dim}, # extended_atype: (1, nall)
{1: nedge_dim}, # edge_index: (2, nedge)
{0: nedge_dim}, # edge_vec: (nedge, 3)
{1: nedge_dim}, # edge_scatter_index: (2, nedge)
{0: nedge_dim}, # edge_mask: (nedge,)
None if fparam is None else {}, # fparam: (1, ndf) static
None if aparam is None else {1: nloc_dim}, # aparam: (1, nloc, nda)
None if charge_spin is None else {}, # charge_spin: (1, nchg) static
)
return (*base, *((None,) * 8))
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def _export_with_comm_artifact(
    model: torch.nn.Module,
    *,
    target_device: torch.device,
    compile_options: dict[str, Any],
) -> bytes:
    """Produce the compiled parallel with-comm ``.pt2`` package as raw bytes.

    The artifact mirrors the regular edge graph but exchanges ghost node
    features across ranks via ``border_op``. The model is traced and exported
    on CPU with a fixed sample of seven local atoms, moved to ``target_device``
    when that is not the CPU, and then compiled into a temporary package file
    whose contents are returned. The bytes are meant to be nested under
    ``model/extra/forward_lower_with_comm.pt2``.

    Parameters
    ----------
    model
        Model exposing ``forward_common_lower_exportable_with_comm``.
    target_device
        Device the exported program is moved to before compilation.
    compile_options
        Inductor config overrides; ``"triton.max_tiles"`` is always forced to 1
        on top of them.

    Returns
    -------
    bytes
        Contents of the compiled ``forward_lower_with_comm.pt2`` package.
    """
    from torch._inductor import (
        aoti_compile_and_package,
    )
    from torch._inductor import config as inductor_config

    trace_device = torch.device("cpu")
    sample_inputs = _make_comm_sample_inputs(model, nloc=7, device=trace_device)
    traced = model.forward_common_lower_exportable_with_comm(*sample_inputs)
    dynamic_shapes = _build_with_comm_dynamic_shapes(sample_inputs)
    exported = torch.export.export(
        traced,
        sample_inputs,
        dynamic_shapes=dynamic_shapes,
        strict=False,
        prefer_deferred_runtime_asserts_over_guards=True,
    )
    _strip_shape_assertions(exported.graph_module)

    needs_move = target_device.type != "cpu"
    if needs_move:
        from torch.export.passes import (
            move_to_device_pass,
        )

        exported = move_to_device_pass(exported, target_device)

    inductor_overrides = {**compile_options, "triton.max_tiles": 1}
    with tempfile.TemporaryDirectory() as workdir:
        package_path = os.path.join(workdir, "forward_lower_with_comm.pt2")
        with inductor_config.patch(inductor_overrides):
            aoti_compile_and_package(exported, package_path=package_path)
        # Read the package before the temporary directory is removed.
        with open(package_path, "rb") as package_file:
            package_bytes = package_file.read()
    return package_bytes


def freeze_sezm_to_pt2(
ckpt_path: str,
out_path: str,
Expand Down Expand Up @@ -637,14 +835,31 @@ def freeze_sezm_to_pt2(
with inductor_config.patch({**compile_options, "triton.max_tiles": 1}):
aoti_compile_and_package(exported, package_path=out_path_str)

# Second artifact: the LAMMPS multi-rank with-comm graph. It threads the
# eight border_op communication tensors so cross-rank ghost features are
# exchanged between interaction blocks. Excluded for spin (nlist lower
# interface) and bridging models (Source Freeze Propagation is not
# rank-decomposable); those fall back to single-rank inference.
with_comm = (not is_spin) and model.supports_edge_parallel()
with_comm_bytes: bytes | None = None
if with_comm:
with_comm_bytes = _export_with_comm_artifact(
model,
target_device=target_device,
compile_options=compile_options,
)

metadata = _collect_metadata(
model,
output_keys=output_keys,
is_spin=is_spin,
do_atomic_virial=atomic_virial,
has_comm_artifact=with_comm,
)
with zipfile.ZipFile(out_path_str, "a") as zf:
zf.writestr("model/extra/metadata.json", json.dumps(metadata))
if with_comm_bytes is not None:
zf.writestr("model/extra/forward_lower_with_comm.pt2", with_comm_bytes)
# The raw training params are preserved so `dp change-bias` and
# other downstream tooling can recover the exact training config.
# ``default=str`` is a safety net for exotic nested values.
Expand Down
Loading
Loading