Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 160 additions & 19 deletions deepmd/dpmodel/descriptor/repflows.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import warnings
from collections.abc import (
Callable,
)
Expand Down Expand Up @@ -196,6 +197,11 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock):
Whether the block is trainable
"""

# Internal-only switch used by backends that cannot export/trace arrays with
# runtime-sized edge/angle dimensions. It is deliberately not exposed in
# input JSON or serialization: users still only see ``use_dynamic_sel``.
_use_static_dynamic_sel: bool = False

def __init__(
self,
e_rcut: float,
Expand Down Expand Up @@ -267,6 +273,17 @@ def __init__(
self.edge_init_use_dist = edge_init_use_dist
self.use_exp_switch = use_exp_switch
self.use_dynamic_sel = use_dynamic_sel
# Snapshot the class-level backend choice on the instance. This keeps
# descriptor/layer behavior stable even if a backend temporarily changes
# the class attribute while constructing a model.
self._use_static_dynamic_sel = type(self)._use_static_dynamic_sel
if self.use_dynamic_sel and self._use_static_dynamic_sel:
warnings.warn(
"The JAX exportable dynamic-selection layout materializes "
"fixed angle capacity nf * nloc * a_sel * a_sel. Keep a_sel "
"modest for exportable DPA-3 models.",
stacklevel=2,
)
Comment thread
njzjz marked this conversation as resolved.
self.use_loc_mapping = use_loc_mapping
self.sel_reduce_factor = sel_reduce_factor
self.sequential_update = sequential_update
Expand Down Expand Up @@ -348,6 +365,10 @@ def __init__(
)
)
self.layers = layers
# RepFlowLayer has the same internal switch; keep all layers in the same
# layout mode as the descriptor block that owns them.
for layer in self.layers:
layer._use_static_dynamic_sel = self._use_static_dynamic_sel

wanted_shape = (self.ntypes, self.nnei, 4)
self.env_mat_edge = EnvMat(
Expand Down Expand Up @@ -632,28 +653,58 @@ def call(

if self.use_dynamic_sel:
# get graph index
edge_index, angle_index = get_graph_index(
nlist,
nlist_mask,
a_nlist_mask,
nall,
use_loc_mapping=self.use_loc_mapping,
)
# flat all the tensors
# n_edge x 1
edge_input = edge_input[nlist_mask]
# n_edge x 3
h2 = h2[nlist_mask]
# n_edge x 1
sw = sw[nlist_mask]
if self._use_static_dynamic_sel:
# Keep the dynamic-selection math but use fixed capacities:
# n_edge = nf * nloc * e_sel
# n_angle = nf * nloc * a_sel * a_sel
# Invalid padded slots have sw=0 (or pair a_sw=0), so for every
# owner i:
# sum_{j=1}^{e_sel} sw_ij m_ij
# = sum_{j in N(i)} sw_ij m_ij.
# Thus the output matches compact dynamic selection while the
# tensor shapes remain trace/export friendly.
edge_index, angle_index = _get_static_graph_index(
nlist,
a_nlist_mask,
nall,
use_loc_mapping=self.use_loc_mapping,
)
# fixed-size flattened layout: (nf * nloc * sel) x ...
edge_input = xp.reshape(edge_input, (-1, edge_input.shape[-1]))
h2 = xp.reshape(h2, (-1, h2.shape[-1]))
sw = xp.reshape(sw, (-1,))
else:
edge_index, angle_index = get_graph_index(
nlist,
nlist_mask,
a_nlist_mask,
nall,
use_loc_mapping=self.use_loc_mapping,
)
# flat all the tensors
# n_edge x 1
edge_input = edge_input[nlist_mask]
# n_edge x 3
h2 = h2[nlist_mask]
# n_edge x 1
sw = sw[nlist_mask]
# nb x nloc x a_nnei x a_nnei
a_nlist_mask = xp.logical_and(
a_nlist_mask[:, :, :, None], a_nlist_mask[:, :, None, :]
)
# n_angle x 1
angle_input = angle_input[a_nlist_mask]
# n_angle x 1
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
if self._use_static_dynamic_sel:
# The angle graph keeps all (j, k) pairs in the a_sel square.
# Pairs involving padded angle neighbors are zeroed by
# a_sw_ij * a_sw_ik before reduction, preserving the compact
# dynamic result.
# fixed-size flattened layout: (nf * nloc * a_sel * a_sel) x ...
angle_input = xp.reshape(angle_input, (-1, angle_input.shape[-1]))
a_sw = xp.reshape(a_sw[:, :, :, None] * a_sw[:, :, None, :], (-1,))
else:
# n_angle x 1
angle_input = angle_input[a_nlist_mask]
# n_angle x 1
a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask]
else:
edge_index = xp.zeros(
[2, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist)
Expand Down Expand Up @@ -765,6 +816,11 @@ def deserialize(cls, data: dict) -> "DescrptBlockRepflows":
obj.edge_embd = edge_embd
obj.angle_embd = angle_embd
obj.layers = layers
# ``_use_static_dynamic_sel`` is intentionally not serialized. After a
# backend-specific class deserializes the block, propagate the backend's
# class-level default to the restored layers.
for layer in obj.layers:
layer._use_static_dynamic_sel = obj._use_static_dynamic_sel
obj.env_mat_edge = env_mat_edge
obj.env_mat_angle = env_mat_angle
obj.mean = davg
Expand Down Expand Up @@ -821,6 +877,77 @@ def serialize(self) -> dict:
}


def _get_static_graph_index(
nlist: Array,
a_nlist_mask: Array,
nall: int,
use_loc_mapping: bool = True,
) -> tuple[Array, Array]:
"""Build graph indices with fixed edge/angle capacities.

This mirrors ``get_graph_index`` but keeps all padded slots instead of
compacting with masks, so the first dimension is shape-static.

Edge slot formula, for frame ``f``, local atom ``i``, edge slot ``j``:

``p = (f * nloc + i) * nnei + j``
``n2e[p] = f * nloc + i``
``n_ext2e[p] = f * nall + nlist[f, i, j]``

If ``use_loc_mapping`` is true, ``nlist`` has already been remapped to local
indices, so the frame stride is ``nloc`` instead of ``nall``.

Angle slot formula, for angle pair ``(j, k)``:

``q = ((f * nloc + i) * a_sel + j) * a_sel + k``
``n2a[q] = f * nloc + i``
``eij2a[q] = p(f, i, j)``
``eik2a[q] = p(f, i, k)``

Padded edge/angle slots are intentionally included. Their messages vanish
because the corresponding switch weights are zero before owner reductions.
"""
xp = array_api_compat.array_namespace(nlist, a_nlist_mask)

nf, nloc, nnei = nlist.shape
_, _, a_nnei = a_nlist_mask.shape
dev = array_api_compat.device(nlist)

nlist_loc_index = xp.arange(nf * nloc, dtype=nlist.dtype, device=dev)
n2e_index = xp.broadcast_to(
xp.reshape(nlist_loc_index, (nf, nloc, 1)), (nf, nloc, nnei)
)
n2e_index = xp.reshape(n2e_index, (-1,))

frame_shift = xp.arange(nf, dtype=nlist.dtype, device=dev) * (
nall if not use_loc_mapping else nloc
)
shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis]
n_ext2e_index = xp.reshape(shifted_nlist, (-1,))

n2a_index = xp.broadcast_to(
xp.reshape(nlist_loc_index, (nf, nloc, 1, 1)), (nf, nloc, a_nnei, a_nnei)
)
n2a_index = xp.reshape(n2a_index, (-1,))

edge_id = xp.reshape(
xp.arange(nf * nloc * nnei, dtype=nlist.dtype, device=dev),
(nf, nloc, nnei),
)[:, :, :a_nnei]
eij2a_index = xp.broadcast_to(
edge_id[:, :, :, xp.newaxis], (nf, nloc, a_nnei, a_nnei)
)
eij2a_index = xp.reshape(eij2a_index, (-1,))
eik2a_index = xp.broadcast_to(
edge_id[:, :, xp.newaxis, :], (nf, nloc, a_nnei, a_nnei)
)
eik2a_index = xp.reshape(eik2a_index, (-1,))

edge_index_result = xp.stack([n2e_index, n_ext2e_index], axis=0)
angle_index_result = xp.stack([n2a_index, eij2a_index, eik2a_index], axis=0)
return edge_index_result, angle_index_result


def _cal_hg_dynamic(
flat_edge_ebd: Array,
flat_h2: Array,
Expand Down Expand Up @@ -934,6 +1061,10 @@ def symmetrization_op_dynamic(


class RepFlowLayer(NativeOP):
# Mirrors the descriptor-block internal switch. The owning block writes the
# instance value during construction/deserialization.
_use_static_dynamic_sel: bool = False

def __init__(
self,
e_rcut: float,
Expand Down Expand Up @@ -1004,6 +1135,9 @@ def __init__(
self.optim_update = optim_update
self.smooth_edge_update = smooth_edge_update
self.use_dynamic_sel = use_dynamic_sel
# Default instance value for standalone RepFlowLayer construction.
# DescrptBlockRepflows overwrites it for regular DPA-3 use.
self._use_static_dynamic_sel = type(self)._use_static_dynamic_sel
self.sel_reduce_factor = sel_reduce_factor
self.sequential_update = sequential_update
self.dynamic_e_sel = self.nnei / self.sel_reduce_factor
Expand Down Expand Up @@ -1476,9 +1610,16 @@ def call(
)
nb, nloc, nnei = nlist.shape
nall = node_ebd_ext.shape[1]
# In compact dynamic mode ``n_edge`` is the number of real edges. In
# static dynamic mode it is the fixed edge capacity; padded entries carry
# zero switch weights, so owner reductions are unchanged.
# int cannot jit; do not run it when self.use_dynamic_sel == False
n_edge = (
int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0
h2.shape[0]
if self.use_dynamic_sel and self._use_static_dynamic_sel
else int(xp.sum(xp.astype(nlist_mask, xp.int32)))
if self.use_dynamic_sel
else 0
)
node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc)
assert (nb, nloc) == node_ebd.shape[:2]
Expand Down
30 changes: 21 additions & 9 deletions deepmd/dpmodel/utils/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,20 +1275,32 @@ def aggregate( # noqa: ANN201
output: [num_owner, feature_dim]
"""
xp = array_api_compat.array_namespace(data, owners)
bin_count = xp_bincount(owners)
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)

dev = array_api_compat.device(data)
if num_owner is not None and bin_count.shape[0] != num_owner:
difference = num_owner - bin_count.shape[0]
bin_count = xp.concat(
[bin_count, xp.ones(difference, dtype=bin_count.dtype, device=dev)]
)
if num_owner is None or average:
# Averaging needs the owner population:
# avg[o] = sum_{r: owners[r] = o} data[r] / count[o].
# If num_owner is omitted, bincount also determines the output length.
bin_count = xp_bincount(owners)
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)
if num_owner is not None and bin_count.shape[0] != num_owner:
difference = num_owner - bin_count.shape[0]
bin_count = xp.concat(
[bin_count, xp.ones(difference, dtype=bin_count.dtype, device=dev)]
)
else:
num_owner = bin_count.shape[0]
else:
# Sum-reduction with a known owner count only needs:
# out[o] = sum_{r: owners[r] = o} data[r].
# Skipping bincount here is mathematically identical and avoids JAX
# tracing failures where jnp.bincount requires concrete owner values.
bin_count = None

output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype, device=dev)
output = xp.zeros((num_owner, data.shape[1]), dtype=data.dtype, device=dev)
output = xp_add_at(output, owners, data)

if average:
assert bin_count is not None
output = xp.transpose(xp.transpose(output) / bin_count)

return output
Expand Down
12 changes: 11 additions & 1 deletion deepmd/jax/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,14 @@

@flax_module
class DescrptBlockRepflows(DescrptBlockRepflowsDP):
pass
# JAX/jax2tf export cannot represent the compact dynamic layout where
# boolean indexing creates runtime-sized ``n_edge``/``n_angle`` arrays.
# Use the fixed-capacity dynamic layout instead:
# edges = nf * nloc * e_sel
# angles = nf * nloc * a_sel * a_sel
# Invalid slots are still masked by switch weights, so DPA-3 outputs match
# the compact dynamic implementation.
_use_static_dynamic_sel = True
Comment thread
njzjz marked this conversation as resolved.


@flax_module
Expand All @@ -26,3 +33,6 @@ class RepFlowLayer(RepFlowLayerDP):
"e_residual",
"a_residual",
}
# Keep the layer-level graph operations in the same fixed-capacity layout
# selected by the owning descriptor block.
_use_static_dynamic_sel = True
5 changes: 4 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2077,7 +2077,8 @@ def dpa3_repflow_args() -> list[Argument]:
doc_a_rcut_smth = "Where to start smoothing for angle. For example the 1/r term is smoothed from `rcut` to `rcut_smth`."
doc_a_sel = 'Maximally possible number of selected angle neighbors. It can be:\n\n\
- `int`. The maximum number of neighbor atoms to be considered. We recommend it to be less than 200. \n\n\
- `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors within the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally, the number is rounded up to a multiple of 4. The option "auto" is equivalent to "auto:1.1".'
- `str`. Can be "auto:factor" or "auto". "factor" is a float number larger than 1. This option will automatically determine the `sel`. In detail it counts the maximal number of neighbors within the cutoff radius for each type of neighbor, then multiply the maximum by the "factor". Finally, the number is rounded up to a multiple of 4. The option "auto" is equivalent to "auto:1.1". \n\n\
For JAX export, DPA3 uses static shapes and materializes angle-pair work arrays with size proportional to `nf * nloc * a_sel^2`. Keep `a_sel` no larger than needed, including when `use_dynamic_sel` is enabled.'
doc_a_compress_rate = (
"The compression rate for angular messages. The default value is 0, indicating no compression. "
" If a non-zero integer c is provided, the node and edge dimensions will be compressed "
Expand Down Expand Up @@ -2151,6 +2152,8 @@ def dpa3_repflow_args() -> list[Argument]:
"without padding to a fixed selection numbers. "
"When enabled, users can safely set larger values for `e_sel` or `a_sel` (e.g., 1200 or 300, respectively) "
"to guarantee capturing all neighbors within the cutoff radius. "
"For JAX export, the static upper bound still controls memory use; in particular, angle-pair work arrays "
"scale as `nf * nloc * a_sel^2`. "
"Note that when using dynamic selection, the `smooth_edge_update` must be True. "
)
doc_sel_reduce_factor = (
Expand Down
46 changes: 46 additions & 0 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
NativeLayer,
NativeNet,
NetworkCollection,
aggregate,
load_dp_model,
save_dp_model,
)
Expand Down Expand Up @@ -478,6 +479,51 @@ def test_zero_dim(self) -> None:
)


class TestAggregate(unittest.TestCase):
def setUp(self) -> None:
self.data = np.array(
[
[1.0, 2.0],
[3.0, 4.0],
[5.0, 6.0],
],
dtype=np.float64,
)
self.owners = np.array([0, 2, 2], dtype=np.int64)

def test_average_with_explicit_num_owner(self) -> None:
output = aggregate(self.data, self.owners, average=True, num_owner=4)

np.testing.assert_allclose(
output,
np.array(
[
[1.0, 2.0],
[0.0, 0.0],
[4.0, 5.0],
[0.0, 0.0],
],
dtype=np.float64,
),
)

def test_sum_with_explicit_num_owner(self) -> None:
output = aggregate(self.data, self.owners, average=False, num_owner=4)

np.testing.assert_allclose(
output,
np.array(
[
[1.0, 2.0],
[0.0, 0.0],
[8.0, 10.0],
[0.0, 0.0],
],
dtype=np.float64,
),
)


class TestSaveLoadDPModel(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
Expand Down
Loading
Loading