Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Mapping,
)
from copy import (
deepcopy,
)
from itertools import (
pairwise,
)
from typing import (
Any,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.array_api import (
Array,
)
Expand All @@ -14,6 +23,7 @@
)
from deepmd.dpmodel.common import (
NativeOP,
to_numpy_array,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
Expand Down Expand Up @@ -92,6 +102,7 @@ def call(
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
mixed_batch: Mapping[str, Array] | None = None,
charge_spin: Array | None = None,
neighbor_list: NeighborList | None = None,
) -> dict[str, Array]:
Expand All @@ -110,6 +121,19 @@ def call(
injected to accelerate neighbor-list construction without changing
the model outputs.
"""
if mixed_batch is not None:
return self.call_flat(
coord=coord,
atype=atype,
box=box,
fparam=fparam,
aparam=aparam,
charge_spin=charge_spin,
do_atomic_virial=do_atomic_virial,
mixed_batch=mixed_batch,
neighbor_list=neighbor_list,
)

model_ret = self.call_common(
coord,
atype,
Expand All @@ -135,6 +159,112 @@ def call(
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3)
return model_predict

def call_flat(
    self,
    coord: Array,
    atype: Array,
    mixed_batch: Mapping[str, Array],
    box: Array | None = None,
    fparam: Array | None = None,
    aparam: Array | None = None,
    charge_spin: Array | None = None,
    do_atomic_virial: bool = False,
    neighbor_list: NeighborList | None = None,
) -> dict[str, Array]:
    """Evaluate a flattened mixed-nloc batch with the dpmodel backend.

    The dpmodel backend reuses the regular one-frame call path for each
    segment described by ``ptr`` and merges the translated outputs back into
    the flat mixed-batch layout.

    Parameters
    ----------
    coord
        Coordinates of all atoms of all frames, indexed by atom along the
        first axis. Each frame slice is reshaped to ``(1, nloc * 3)``.
    atype
        Atom types, indexed by atom along the first axis. Each frame slice
        is reshaped to ``(1, nloc)``.
    mixed_batch
        Mapping that must provide ``"batch"`` and ``"ptr"``. ``ptr`` is a 1D
        array of atom offsets: frame ``i`` owns atoms ``ptr[i]:ptr[i + 1]``.
        ``batch`` is only checked for having one entry per atom.
    box, fparam, charge_spin
        Optional per-frame inputs; entry ``i`` along the first axis is
        forwarded to frame ``i``.
    aparam
        Optional per-atom input, sliced along the first axis with the same
        offsets as ``coord``.
    do_atomic_virial
        Forwarded unchanged to every per-frame call.
    neighbor_list
        Forwarded unchanged to every per-frame call.

    Returns
    -------
    dict[str, Array]
        Per-frame outputs merged by ``_merge_flat_frame_outputs``.

    Raises
    ------
    ValueError
        If ``batch`` or ``ptr`` is missing, if ``ptr`` is not a 1D array of
        at least two entries that starts at 0, ends at the number of atoms
        and never decreases, or if ``batch`` does not have one entry per
        atom.
    NotImplementedError
        If the model has Hessian output enabled.
    """
    batch = mixed_batch.get("batch")
    ptr = mixed_batch.get("ptr")
    if batch is None or ptr is None:
        raise ValueError("mixed_batch must contain both batch and ptr.")
    if self._enable_hessian:
        raise NotImplementedError(
            "Hessian is not implemented for dpmodel mixed-batch flat calls."
        )

    xp = array_api_compat.array_namespace(coord, atype)
    ptr_np = to_numpy_array(ptr)
    if ptr_np is None:
        raise ValueError("ptr is required for mixed batches.")
    ptr_np = np.asarray(ptr_np, dtype=np.int64)
    if ptr_np.ndim != 1 or ptr_np.size < 2:
        raise ValueError("ptr must be a 1D array with at least two entries.")

    total_atoms = coord.shape[0]
    if ptr_np[0] != 0 or ptr_np[-1] != total_atoms:
        raise ValueError("ptr must start at 0 and end at the number of atoms.")
    if batch.shape[0] != total_atoms:
        raise ValueError("batch length must match the number of atoms.")
    # Checking only the end points lets a shuffled ptr through; it would
    # yield a negative atom count and an opaque reshape failure below.
    if np.any(ptr_np[1:] < ptr_np[:-1]):
        raise ValueError("ptr must be non-decreasing.")

    # Plain Python ints are valid slice bounds in every array namespace,
    # which NumPy integer scalars are not guaranteed to be.
    offsets = ptr_np.tolist()

    frame_outputs = []
    for frame_idx, (start, end) in enumerate(pairwise(offsets)):
        nloc = end - start
        frame_coord = xp.reshape(coord[start:end], (1, nloc * 3))
        frame_atype = xp.reshape(atype[start:end], (1, nloc))
        # Length-one slices keep the leading frame axis on per-frame inputs.
        frame_box = box[frame_idx : frame_idx + 1] if box is not None else None
        frame_fparam = (
            fparam[frame_idx : frame_idx + 1] if fparam is not None else None
        )
        frame_aparam = (
            xp.reshape(aparam[start:end], (1, nloc, *aparam.shape[1:]))
            if aparam is not None
            else None
        )
        frame_charge_spin = (
            charge_spin[frame_idx : frame_idx + 1]
            if charge_spin is not None
            else None
        )
        # mixed_batch is deliberately not forwarded, so each call takes the
        # regular (non-flat) path.
        frame_outputs.append(
            self.call(
                frame_coord,
                frame_atype,
                box=frame_box,
                fparam=frame_fparam,
                aparam=frame_aparam,
                charge_spin=frame_charge_spin,
                do_atomic_virial=do_atomic_virial,
                neighbor_list=neighbor_list,
            )
        )

    return self._merge_flat_frame_outputs(frame_outputs)

@staticmethod
def _merge_flat_frame_outputs(
    frame_outputs: list[dict[str, Array]],
) -> dict[str, Array]:
    """Concatenate per-frame output dicts into one flat output dict.

    The keys of the first frame decide which entries are merged; every
    frame is expected to provide each of them. For every key the per-frame
    arrays are concatenated along axis 0 after an optional reshape:

    * ``"energy"`` and ``"virial"`` are concatenated as they are.
    * ``"mask"`` is flattened to 1D first.
    * Any other entry with three or more dimensions has its first two axes
      collapsed into one; entries of lower rank are flattened to 1D.

    Parameters
    ----------
    frame_outputs
        One output dict per frame, in frame order.

    Returns
    -------
    dict[str, Array]
        The merged outputs, keyed like the first frame's dict.

    Raises
    ------
    ValueError
        If ``frame_outputs`` is empty.
    """
    if not frame_outputs:
        raise ValueError("mixed-batch input must contain at least one frame.")

    per_frame_names = frozenset(("energy", "virial"))
    merged: dict[str, Array] = {}
    for name in frame_outputs[0]:
        pieces = [single[name] for single in frame_outputs]
        ns = array_api_compat.array_namespace(pieces[0])
        if name in per_frame_names:
            # Already laid out with one leading entry per frame.
            parts = pieces
        elif name == "mask":
            parts = [ns.reshape(piece, (-1,)) for piece in pieces]
        else:
            # Fold the first two axes together and keep any trailing axes.
            parts = [
                ns.reshape(
                    piece,
                    (-1, *(piece.shape[2:] if piece.ndim >= 3 else ())),
                )
                for piece in pieces
            ]
        merged[name] = ns.concat(parts, axis=0)
    return merged

def call_lower(
self,
extended_coord: Array,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
PairExcludeMask,
)
from .lmdb_data import (
DistributedMixedBatchSampler,
DistributedSameNlocBatchSampler,
LmdbDataReader,
LmdbTestData,
LmdbTestDataNlocView,
MixedBatchSampler,
SameNlocBatchSampler,
is_lmdb,
make_neighbor_stat_data,
Expand Down Expand Up @@ -71,6 +73,7 @@
__all__ = [
"AtomExcludeMask",
"DefaultNeighborList",
"DistributedMixedBatchSampler",
"DistributedSameNlocBatchSampler",
"EmbeddingNet",
"EnvMat",
Expand All @@ -79,6 +82,7 @@
"LmdbDataReader",
"LmdbTestData",
"LmdbTestDataNlocView",
"MixedBatchSampler",
"NativeLayer",
"NativeNet",
"NeighborGraph",
Expand Down
Loading
Loading