diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index b9593a0c61..c7a7f4a8d6 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import warnings from collections.abc import ( Callable, ) @@ -196,6 +197,11 @@ class DescrptBlockRepflows(NativeOP, DescriptorBlock): Whether the block is trainable """ + # Internal-only switch used by backends that cannot export/trace arrays with + # runtime-sized edge/angle dimensions. It is deliberately not exposed in + # input JSON or serialization: users still only see ``use_dynamic_sel``. + _use_static_dynamic_sel: bool = False + def __init__( self, e_rcut: float, @@ -267,6 +273,17 @@ def __init__( self.edge_init_use_dist = edge_init_use_dist self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel + # Snapshot the class-level backend choice on the instance. This keeps + # descriptor/layer behavior stable even if a backend temporarily changes + # the class attribute while constructing a model. + self._use_static_dynamic_sel = type(self)._use_static_dynamic_sel + if self.use_dynamic_sel and self._use_static_dynamic_sel: + warnings.warn( + "The JAX exportable dynamic-selection layout materializes " + "fixed angle capacity nf * nloc * a_sel * a_sel. Keep a_sel " + "modest for exportable DPA-3 models.", + stacklevel=2, + ) self.use_loc_mapping = use_loc_mapping self.sel_reduce_factor = sel_reduce_factor self.sequential_update = sequential_update @@ -348,6 +365,10 @@ def __init__( ) ) self.layers = layers + # RepFlowLayer has the same internal switch; keep all layers in the same + # layout mode as the descriptor block that owns them. 
+ for layer in self.layers: + layer._use_static_dynamic_sel = self._use_static_dynamic_sel wanted_shape = (self.ntypes, self.nnei, 4) self.env_mat_edge = EnvMat( @@ -632,28 +653,58 @@ def call( if self.use_dynamic_sel: # get graph index - edge_index, angle_index = get_graph_index( - nlist, - nlist_mask, - a_nlist_mask, - nall, - use_loc_mapping=self.use_loc_mapping, - ) - # flat all the tensors - # n_edge x 1 - edge_input = edge_input[nlist_mask] - # n_edge x 3 - h2 = h2[nlist_mask] - # n_edge x 1 - sw = sw[nlist_mask] + if self._use_static_dynamic_sel: + # Keep the dynamic-selection math but use fixed capacities: + # n_edge = nf * nloc * e_sel + # n_angle = nf * nloc * a_sel * a_sel + # Invalid padded slots have sw=0 (or pair a_sw=0), so for every + # owner i: + # sum_{j=1}^{e_sel} sw_ij m_ij + # = sum_{j in N(i)} sw_ij m_ij. + # Thus the output matches compact dynamic selection while the + # tensor shapes remain trace/export friendly. + edge_index, angle_index = _get_static_graph_index( + nlist, + a_nlist_mask, + nall, + use_loc_mapping=self.use_loc_mapping, + ) + # fixed-size flattened layout: (nf * nloc * sel) x ... + edge_input = xp.reshape(edge_input, (-1, edge_input.shape[-1])) + h2 = xp.reshape(h2, (-1, h2.shape[-1])) + sw = xp.reshape(sw, (-1,)) + else: + edge_index, angle_index = get_graph_index( + nlist, + nlist_mask, + a_nlist_mask, + nall, + use_loc_mapping=self.use_loc_mapping, + ) + # flat all the tensors + # n_edge x 1 + edge_input = edge_input[nlist_mask] + # n_edge x 3 + h2 = h2[nlist_mask] + # n_edge x 1 + sw = sw[nlist_mask] # nb x nloc x a_nnei x a_nnei a_nlist_mask = xp.logical_and( a_nlist_mask[:, :, :, None], a_nlist_mask[:, :, None, :] ) - # n_angle x 1 - angle_input = angle_input[a_nlist_mask] - # n_angle x 1 - a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] + if self._use_static_dynamic_sel: + # The angle graph keeps all (j, k) pairs in the a_sel square. 
def _get_static_graph_index(
    nlist: Array,
    a_nlist_mask: Array,
    nall: int,
    use_loc_mapping: bool = True,
) -> tuple[Array, Array]:
    """Build edge and angle index tables whose lengths depend only on shapes.

    Unlike ``get_graph_index``, nothing is compacted with a boolean mask:
    every neighbor slot (padded or not) gets an entry, so the leading
    dimension of both results is known from the input shapes alone.

    Parameters
    ----------
    nlist
        Neighbor list of shape ``(nf, nloc, nnei)``.
    a_nlist_mask
        Angle-neighbor mask of shape ``(nf, nloc, a_nnei)``; only its shape
        is read here.
    nall
        Number of extended atoms per frame; used as the per-frame stride of
        the neighbor index when ``use_loc_mapping`` is false.
    use_loc_mapping
        When true, the per-frame stride is ``nloc`` instead of ``nall``.

    Returns
    -------
    edge_index
        Shape ``(2, nf * nloc * nnei)``. For edge slot
        ``p = (f * nloc + i) * nnei + j``:
        row 0 is the owner ``f * nloc + i`` and row 1 is
        ``f * stride + nlist[f, i, j]``.
    angle_index
        Shape ``(3, nf * nloc * a_nnei * a_nnei)``. For angle slot
        ``q = ((f * nloc + i) * a_nnei + j) * a_nnei + k``:
        row 0 is the owner ``f * nloc + i``, row 1 is the edge slot
        ``p(f, i, j)`` and row 2 is the edge slot ``p(f, i, k)``.

    Notes
    -----
    Padded slots are kept on purpose. The caller is expected to cancel their
    contribution through zero switch weights before reducing over owners.
    """
    xp = array_api_compat.array_namespace(nlist, a_nlist_mask)

    nf, nloc, nnei = nlist.shape
    _, _, a_nnei = a_nlist_mask.shape
    device = array_api_compat.device(nlist)
    index_dtype = nlist.dtype
    edge_shape = (nf, nloc, nnei)
    angle_shape = (nf, nloc, a_nnei, a_nnei)

    def _flatten(arr: Array) -> Array:
        # Collapse every axis into a single leading one.
        return xp.reshape(arr, (-1,))

    # Flat owner id f * nloc + i for every local atom of every frame.
    owner = xp.arange(nf * nloc, dtype=index_dtype, device=device)

    # Edge table, row 0: owner repeated once per neighbor slot.
    n2e = _flatten(xp.broadcast_to(xp.reshape(owner, (nf, nloc, 1)), edge_shape))

    # Edge table, row 1: neighbor index shifted into its own frame.
    stride = nloc if use_loc_mapping else nall
    frame_offset = xp.arange(nf, dtype=index_dtype, device=device) * stride
    n_ext2e = _flatten(nlist + frame_offset[:, xp.newaxis, xp.newaxis])

    # Angle table, row 0: owner repeated once per (j, k) pair.
    n2a = _flatten(
        xp.broadcast_to(xp.reshape(owner, (nf, nloc, 1, 1)), angle_shape)
    )

    # Angle table, rows 1 and 2: flat edge slot ids of the first a_nnei
    # neighbors, broadcast along k (for edge ij) and along j (for edge ik).
    edge_slot = xp.reshape(
        xp.arange(nf * nloc * nnei, dtype=index_dtype, device=device),
        edge_shape,
    )
    angle_edge_slot = edge_slot[:, :, :a_nnei]
    eij2a = _flatten(
        xp.broadcast_to(angle_edge_slot[:, :, :, xp.newaxis], angle_shape)
    )
    eik2a = _flatten(
        xp.broadcast_to(angle_edge_slot[:, :, xp.newaxis, :], angle_shape)
    )

    edge_index_result = xp.stack([n2e, n_ext2e], axis=0)
    angle_index_result = xp.stack([n2a, eij2a, eik2a], axis=0)
    return edge_index_result, angle_index_result
The owning block writes the + # instance value during construction/deserialization. + _use_static_dynamic_sel: bool = False + def __init__( self, e_rcut: float, @@ -1004,6 +1135,9 @@ def __init__( self.optim_update = optim_update self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel + # Default instance value for standalone RepFlowLayer construction. + # DescrptBlockRepflows overwrites it for regular DPA-3 use. + self._use_static_dynamic_sel = type(self)._use_static_dynamic_sel self.sel_reduce_factor = sel_reduce_factor self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor @@ -1476,9 +1610,16 @@ def call( ) nb, nloc, nnei = nlist.shape nall = node_ebd_ext.shape[1] + # In compact dynamic mode ``n_edge`` is the number of real edges. In + # static dynamic mode it is the fixed edge capacity; padded entries carry + # zero switch weights, so owner reductions are unchanged. # int cannot jit; do not run it when self.use_dynamic_sel == False n_edge = ( - int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 + h2.shape[0] + if self.use_dynamic_sel and self._use_static_dynamic_sel + else int(xp.sum(xp.astype(nlist_mask, xp.int32))) + if self.use_dynamic_sel + else 0 ) node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc) assert (nb, nloc) == node_ebd.shape[:2] diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index c7a12ba3e2..be98f2b67b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1275,20 +1275,32 @@ def aggregate( # noqa: ANN201 output: [num_owner, feature_dim] """ xp = array_api_compat.array_namespace(data, owners) - bin_count = xp_bincount(owners) - bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count) - dev = array_api_compat.device(data) - if num_owner is not None and bin_count.shape[0] != num_owner: - difference = num_owner - bin_count.shape[0] - bin_count = xp.concat( - [bin_count, 
@flax_module
class DescrptBlockRepflows(DescrptBlockRepflowsDP):
    # Select the fixed-capacity variant of dynamic selection for this backend.
    # The compact variant builds ``n_edge``/``n_angle`` arrays through boolean
    # indexing, so their sizes are only known at run time, which JAX/jax2tf
    # export cannot represent. With the flag set the sizes are
    #     edges  = nf * nloc * e_sel
    #     angles = nf * nloc * a_sel * a_sel
    # and padded slots are cancelled by their switch weights, so the DPA-3
    # outputs are intended to match the compact dynamic implementation.
    _use_static_dynamic_sel = True
class TestAggregate(unittest.TestCase):
    """Owner-wise sum and mean reductions performed by ``aggregate``.

    The fixture has three data rows owned by ``[0, 2, 2]``: owner 1 receives
    no row and must therefore come out as zeros in every mode.
    """

    def setUp(self) -> None:
        self.data = np.array(
            [
                [1.0, 2.0],
                [3.0, 4.0],
                [5.0, 6.0],
            ],
            dtype=np.float64,
        )
        self.owners = np.array([0, 2, 2], dtype=np.int64)

    def test_average_with_explicit_num_owner(self) -> None:
        # num_owner exceeds the largest owner id, so a trailing empty owner
        # row is appended and stays zero.
        output = aggregate(self.data, self.owners, average=True, num_owner=4)

        np.testing.assert_allclose(
            output,
            np.array(
                [
                    [1.0, 2.0],
                    [0.0, 0.0],
                    [4.0, 5.0],
                    [0.0, 0.0],
                ],
                dtype=np.float64,
            ),
        )

    def test_sum_with_explicit_num_owner(self) -> None:
        # Sum-reduction with a known owner count takes the branch that skips
        # the bincount entirely.
        output = aggregate(self.data, self.owners, average=False, num_owner=4)

        np.testing.assert_allclose(
            output,
            np.array(
                [
                    [1.0, 2.0],
                    [0.0, 0.0],
                    [8.0, 10.0],
                    [0.0, 0.0],
                ],
                dtype=np.float64,
            ),
        )

    def test_average_without_num_owner(self) -> None:
        # Without num_owner the output length is inferred from the owners
        # (largest id + 1 == 3 here).
        output = aggregate(self.data, self.owners, average=True, num_owner=None)

        self.assertEqual(output.shape, (3, 2))
        np.testing.assert_allclose(
            output,
            np.array(
                [
                    [1.0, 2.0],
                    [0.0, 0.0],
                    [4.0, 5.0],
                ],
                dtype=np.float64,
            ),
        )

    def test_sum_without_num_owner(self) -> None:
        # Same inferred length as above, but rows are summed, not averaged.
        output = aggregate(self.data, self.owners, average=False, num_owner=None)

        self.assertEqual(output.shape, (3, 2))
        np.testing.assert_allclose(
            output,
            np.array(
                [
                    [1.0, 2.0],
                    [0.0, 0.0],
                    [8.0, 10.0],
                ],
                dtype=np.float64,
            ),
        )
class TestDPA3StaticDynamicSelDP(unittest.TestCase, TestCaseSingleFrameWithNlist):
    """Check that the internal static dynamic layout is value-equivalent.

    Two DPA-3 descriptors are built from identical parameters, one with the
    compact ("packed") dynamic-selection layout and one with the
    fixed-capacity ("static") layout, and their outputs are compared.
    """

    def setUp(self) -> None:
        TestCaseSingleFrameWithNlist.setUp(self)

    def _make_dpa3(
        self,
        use_static_dynamic_sel: bool,
        *,
        use_loc_mapping: bool,
    ) -> DescrptDPA3:
        """Build a DPA-3 descriptor with the requested internal layout.

        The switch is intentionally class-level and internal, so it is
        toggled only around construction and the previous value is restored
        afterwards, even if construction raises.
        """
        old_use_static_dynamic_sel = DescrptBlockRepflows._use_static_dynamic_sel
        DescrptBlockRepflows._use_static_dynamic_sel = use_static_dynamic_sel
        try:
            return DescrptDPA3(
                **DescriptorParamDPA3(
                    self.nt,
                    self.rcut,
                    self.rcut_smth,
                    self.sel,
                    ["O", "H"],
                    smooth_edge_update=True,
                    use_dynamic_sel=True,
                    use_loc_mapping=use_loc_mapping,
                )
            )
        finally:
            DescrptBlockRepflows._use_static_dynamic_sel = old_use_static_dynamic_sel

    def test_static_dynamic_sel_matches_packed_dynamic_sel(self) -> None:
        for use_loc_mapping in (True, False):
            # subTest labels a failure with the parameter value and lets the
            # remaining value still run.
            with self.subTest(use_loc_mapping=use_loc_mapping):
                packed = self._make_dpa3(
                    False,
                    use_loc_mapping=use_loc_mapping,
                )
                static = self._make_dpa3(
                    True,
                    use_loc_mapping=use_loc_mapping,
                )

                packed_out = packed(
                    self.coord_ext,
                    self.atype_ext,
                    self.nlist,
                    mapping=self.mapping,
                )
                static_out = static(
                    self.coord_ext,
                    self.atype_ext,
                    self.nlist,
                    mapping=self.mapping,
                )

                # Outputs 0 and 1 have the same shape in both layouts and
                # must agree directly.
                np.testing.assert_allclose(
                    packed_out[0], static_out[0], atol=self.atol
                )
                np.testing.assert_allclose(
                    packed_out[1], static_out[1], atol=self.atol
                )

                # Static dynamic selection keeps all edge slots. Masking out
                # padding should recover the compact dynamic edge/h2/sw
                # tensors exactly:
                #     compact_edges == static_edges[nlist != -1].
                valid_edge_mask = np.reshape(self.nlist != -1, (-1,))
                # assertEqual instead of a bare assert: it survives
                # ``python -O`` and reports both values on failure.
                self.assertEqual(
                    static_out[2].shape[0],
                    self.nf * self.nloc * sum(self.sel),
                )
                for out_idx in (2, 3, 4):
                    np.testing.assert_allclose(
                        packed_out[out_idx],
                        static_out[out_idx][valid_edge_mask],
                        atol=self.atol,
                    )