From fea615fcedfcbfcbadb06ac6b55e0868081c3fbb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 01:01:05 +0800 Subject: [PATCH 1/6] feat(dpmodel): add descriptor compression --- deepmd/backend/dpmodel.py | 9 +- deepmd/dpmodel/descriptor/dpa1.py | 329 ++++++++++++++++-- deepmd/dpmodel/descriptor/se_atten_v2.py | 6 + deepmd/dpmodel/descriptor/se_e2_a.py | 234 ++++++++++++- deepmd/dpmodel/descriptor/se_r.py | 169 ++++++++- deepmd/dpmodel/entrypoints/__init__.py | 2 + deepmd/dpmodel/entrypoints/compress.py | 93 +++++ deepmd/dpmodel/entrypoints/main.py | 58 +++ deepmd/jax/entrypoints/compress.py | 118 +++++++ deepmd/jax/entrypoints/main.py | 23 ++ deepmd/jax/utils/serialization.py | 79 ++++- deepmd/main.py | 6 +- .../common/dpmodel/test_model_compression.py | 231 ++++++++++++ source/tests/jax/test_model_compression.py | 305 ++++++++++++++++ 14 files changed, 1613 insertions(+), 49 deletions(-) create mode 100644 deepmd/dpmodel/entrypoints/__init__.py create mode 100644 deepmd/dpmodel/entrypoints/compress.py create mode 100644 deepmd/dpmodel/entrypoints/main.py create mode 100644 deepmd/jax/entrypoints/compress.py create mode 100644 source/tests/common/dpmodel/test_model_compression.py create mode 100644 source/tests/jax/test_model_compression.py diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py index 31585aa7a6..0e6e0964f3 100644 --- a/deepmd/backend/dpmodel.py +++ b/deepmd/backend/dpmodel.py @@ -34,7 +34,10 @@ class DPModelBackend(Backend): name = "DPModel" """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( - Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO ) """The features of the backend.""" suffixes: ClassVar[list[str]] = [".dp", ".yaml", ".yml"] @@ -59,7 +62,9 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]: Callable[[Namespace], None] The entry point hook of the backend. """ - raise NotImplementedError(f"Unsupported backend: {self.name}") + from deepmd.dpmodel.entrypoints.main import main as deepmd_main + + return deepmd_main @property def deep_eval(self) -> type["DeepEvalBackend"]: diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2311858180..c813dc444a 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import math +import warnings from collections.abc import ( Callable, ) @@ -64,6 +65,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -345,6 +349,8 @@ def __init__( self.concat_output_tebd = concat_output_tebd self.trainable = trainable self.precision = precision + self.tebd_compress = False + self.geo_compress = False self.compress = False def get_rcut(self) -> float: @@ -567,6 +573,89 @@ def call( ) return grrg, rot_mat, None, None, sw + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression. + + For DPA-1, compression is available for stripped type embeddings. The + type embedding branch is always precomputed; the radial embedding table + is enabled only when there is no attention layer, matching the PT/TF + compression semantics. + """ + if self.compress: + raise ValueError("Compression is already enabled.") + if self.se_atten.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.se_atten.resnet_dt: + raise RuntimeError( + "Model compression error: descriptor resnet_dt must be false!" + ) + for tt in self.se_atten.exclude_types: + if (tt[0] not in range(self.se_atten.ntypes)) or ( + tt[1] not in range(self.se_atten.ntypes) + ): + raise RuntimeError( + "exclude types" + + str(tt) + + " must within the number of atomic types " + + str(self.se_atten.ntypes) + + "!" + ) + if ( + self.se_atten.ntypes * self.se_atten.ntypes + - len(self.se_atten.exclude_types) + == 0 + ): + raise RuntimeError( + "Empty embedding-nets are not supported in model compression!" + ) + + self.se_atten.type_embedding_compression(self.type_embedding) + self.type_embd_data = self.se_atten.type_embd_data + self.tebd_compress = True + self.compress = True + + if self.se_atten.attn_layer == 0: + table = DPTabulate( + self, + self.se_atten.neuron, + self.se_atten.type_one_side, + self.se_atten.exclude_types, + self.se_atten.activation_function, + ) + table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.se_atten.enable_compression( + table.data, + table_config, + lower, + upper, + ) + self.compress_data = self.se_atten.compress_data + self.compress_info = self.se_atten.compress_info + self.geo_compress = True + else: + self.geo_compress = False + warnings.warn( + "Attention layer is not 0, only type embedding is compressed. " + "Geometric part is not compressed.", + UserWarning, + stacklevel=2, + ) + def serialize(self) -> dict: """Serialize the descriptor to dict.""" obj = self.se_atten @@ -678,9 +767,15 @@ def _load_compress_data(self, compress: dict) -> None: variables = compress["@variables"] self.type_embd_data = variables["type_embd_data"] self.geo_compress = compress.get("geo_compress", False) + self.tebd_compress = True + self.se_atten.type_embd_data = self.type_embd_data + self.se_atten.tebd_compress = True + self.se_atten.geo_compress = self.geo_compress if self.geo_compress: self.compress_data = variables["compress_data"] self.compress_info = variables["compress_info"] + self.se_atten.compress_data = self.compress_data + self.se_atten.compress_info = self.compress_info self.compress = True @classmethod @@ -919,6 +1014,12 @@ def __init__( self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + self.tebd_compress = False + self.geo_compress = False + self.is_sorted = len(self.exclude_types) == 0 + self.compress_data = [np.zeros(0, dtype=PRECISION_DICT[self.precision])] + self.compress_info = [np.zeros(0, dtype=PRECISION_DICT[self.precision])] + self.type_embd_data = np.zeros(0, dtype=PRECISION_DICT[self.precision]) def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -1057,6 +1158,7 @@ def reinit_exclude( exclude_types: list[tuple[int, int]] = [], ) -> None: self.exclude_types = exclude_types + self.is_sorted = len(self.exclude_types) == 0 self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) def cal_g( @@ -1082,6 +1184,132 @@ def cal_g_strip( gg = self.embeddings_strip[embedding_idx].call(ss) return gg + def enable_compression( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated geometric embedding-net data.""" + net = "filter_net" + dtype = self.mean.dtype + self.compress_info = [ + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ] + self.compress_data = [np.asarray(table_data[net], dtype=dtype)] + self.geo_compress = True + + def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: + """Precompute stripped type embedding network outputs.""" + if self.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.embeddings_strip is None: + raise RuntimeError( + "embeddings_strip must be initialized for type embedding compression" + ) + + full_embd = type_embedding_net.call() + xp = array_api_compat.array_namespace(full_embd) + nt, t_dim = full_embd.shape + if self.type_one_side: + embd_tensor = self.embeddings_strip[0].call(full_embd) + else: + type_embedding_nei = xp.tile( + xp.reshape(full_embd, (1, nt, t_dim)), (nt, 1, 1) + ) + type_embedding_center = xp.tile( + xp.reshape(full_embd, (nt, 1, t_dim)), (1, nt, 1) + ) + two_side_type_embedding = xp.reshape( + xp.concat([type_embedding_nei, type_embedding_center], axis=-1), + (-1, t_dim * 2), + ) + embd_tensor = self.embeddings_strip[0].call(two_side_type_embedding) + self.type_embd_data = embd_tensor + self.tebd_compress = True + + def _tabulate_fusion_se_atten( + self, + table: Array, + table_info: Array, + em_x: Array, + em: Array, + two_embed: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_atten forward.""" + xp = array_api_compat.array_namespace(em_x, em, two_embed) + device = array_api_compat.device(em) + table = xp.asarray(table[...], dtype=em.dtype, device=device) + table_info = xp.asarray(table_info[...], dtype=em.dtype, device=device) + + nloc, nnei = em.shape[:2] + xx = xp.reshape(em_x, (nloc, nnei)) + lower = table_info[0] + upper = table_info[1] + table_max = table_info[2] + stride0 = table_info[3] + stride1 = table_info[4] + + zeros = xp.zeros(xx.shape, dtype=xp.int64, device=device) + nspline = table.shape[0] + last_idx = xp.full(xx.shape, nspline - 1, dtype=xp.int64, device=device) + first_stride = xp.astype(xp.floor((upper - lower) / stride0), xp.int64) + first_stride_value = xp.astype(first_stride, em.dtype) + + first_idx = xp.astype(xp.floor((xx - lower) / stride0), xp.int64) + second_idx = first_stride + xp.astype( + xp.floor((xx - upper) / stride1), xp.int64 + ) + table_idx = xp.where( + xx < lower, + zeros, + xp.where( + xx < upper, + first_idx, + xp.where(xx < table_max, second_idx, last_idx), + ), + ) + table_idx = xp.minimum(xp.maximum(table_idx, zeros), last_idx) + + table_idx_value = xp.astype(table_idx, em.dtype) + dx_first = xx - (table_idx_value * stride0 + lower) + dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx = xp.where( + (xx >= lower) & (xx < upper), + dx_first, + xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + ) + + coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) + coeff = xp.reshape(coeff, (nloc, nnei, last_layer_size, 6)) + dx = xp.reshape(dx, (nloc, nnei, 1)) + values = ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * dx) * dx) * dx + ) + * dx + ) + * dx + ) + values = values * two_embed + values + return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) + def call( self, nlist: Array, @@ -1131,6 +1359,7 @@ def call( rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nnei x 1 ss = rr[..., 0:1] + geo_gr = None if self.tebd_input_mode in ["concat"]: # nfnl x tebd_dim atype_embd = xp.reshape( @@ -1157,8 +1386,7 @@ def call( # nfnl x nnei x ng gg = self.cal_g(ss, 0) elif self.tebd_input_mode in ["strip"]: - # nfnl x nnei x ng - gg_s = self.cal_g(ss, 0) + ss_scalar = ss assert self.embeddings_strip is not None assert type_embedding is not None ntypes_with_padding = type_embedding.shape[0] @@ -1168,7 +1396,14 @@ def call( # (nf x nl x nnei) x ng nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng)) if self.type_one_side: - tt_full = self.cal_g_strip(type_embedding, 0) + if self.tebd_compress: + tt_full = xp.asarray( + self.type_embd_data[...], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + else: + tt_full = self.cal_g_strip(type_embedding, 0) # (nf x nl x nnei) x ng gg_t = xp_take_along_axis(tt_full, nei_type_index, axis=0) else: @@ -1183,46 +1418,70 @@ def call( idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) idx = xp.astype(idx, xp.int64) - # (ntypes) * ntypes * nt - type_embedding_nei = xp.tile( - xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), - (ntypes_with_padding, 1, 1), - ) - # ntypes * (ntypes) * nt - type_embedding_center = xp.tile( - xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)), - (1, ntypes_with_padding, 1), - ) - # (ntypes * ntypes) * (nt+nt) - two_side_type_embedding = xp.reshape( - xp.concat([type_embedding_nei, type_embedding_center], axis=-1), - (-1, nt * 2), - ) - tt_full = self.cal_g_strip(two_side_type_embedding, 0) + if self.tebd_compress: + tt_full = xp.asarray( + self.type_embd_data[...], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + else: + # (ntypes) * ntypes * nt + type_embedding_nei = xp.tile( + xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), + (ntypes_with_padding, 1, 1), + ) + # ntypes * (ntypes) * nt + type_embedding_center = xp.tile( + xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)), + (1, ntypes_with_padding, 1), + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = xp.reshape( + xp.concat([type_embedding_nei, type_embedding_center], axis=-1), + (-1, nt * 2), + ) + tt_full = self.cal_g_strip(two_side_type_embedding, 0) # (nf x nl x nnei) x ng gg_t = xp_take_along_axis(tt_full, idx, axis=0) # (nf x nl) x nnei x ng gg_t = xp.reshape(gg_t, (nf * nloc, nnei, ng)) if self.smooth: gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1)) - # nfnl x nnei x ng - gg = gg_s * gg_t + gg_s + if self.geo_compress: + geo_gr = self._tabulate_fusion_se_atten( + self.compress_data[0], + self.compress_info[0], + ss_scalar, + rr, + gg_t, + self.filter_neuron[-1], + ) + gg = None + else: + # nfnl x nnei x ng + gg_s = self.cal_g(ss_scalar, 0) + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - normed = safe_for_vector_norm( - xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True - ) - input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( - normed, - xp.full_like(normed, 1e-12), - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - # nfnl x ng x 4 - # gr = xp.einsum("lni,lnj->lij", gg, rr) - gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) + if geo_gr is None: + normed = safe_for_vector_norm( + xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True + ) + input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( + normed, + xp.full_like(normed, 1e-12), + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x ng x 4 + # gr = xp.einsum("lni,lnj->lij", gg, rr) + gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) + g2 = xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])) + else: + gr = xp.permute_dims(geo_gr, (0, 2, 1)) + g2 = None gr /= self.nnei gr1 = gr[:, : self.axis_neuron, :] # nfnl x ng x ng1 @@ -1234,7 +1493,7 @@ def call( ) return ( xp.reshape(grrg, (nf, nloc, self.filter_neuron[-1] * self.axis_neuron)), - xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])), + g2, xp.reshape(dmatrix, (nf, nloc, self.nnei, 4))[..., 1:], xp.reshape(gr[..., 1:], (nf, nloc, self.filter_neuron[-1], 3)), xp.reshape(sw, (nf, nloc, nnei, 1)), diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py index 1cc383cfd9..c3b02c3d1e 100644 --- a/deepmd/dpmodel/descriptor/se_atten_v2.py +++ b/deepmd/dpmodel/descriptor/se_atten_v2.py @@ -299,7 +299,13 @@ def _load_compress_data(self, compress: dict) -> None: variables = compress["@variables"] self.type_embd_data = variables["type_embd_data"] self.geo_compress = compress.get("geo_compress", False) + self.tebd_compress = True + self.se_atten.type_embd_data = self.type_embd_data + self.se_atten.tebd_compress = True + self.se_atten.geo_compress = self.geo_compress if self.geo_compress: self.compress_data = variables["compress_data"] self.compress_info = variables["compress_info"] + self.se_atten.compress_data = self.compress_data + self.se_atten.compress_info = self.compress_info self.compress = True diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index bdac1e0cc0..48edbd836a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -44,6 +44,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -388,6 +391,148 @@ def cal_g( gg = self.embeddings[embedding_idx].call(ss) return gg + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression by tabulating embedding networks.""" + if self.compress: + raise ValueError("Compression is already enabled.") + table = DPTabulate( + self, + self.neuron, + self.type_one_side, + self.exclude_types, + self.activation_function, + ) + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self._store_compress_data( + table.data, + [table_extrapolate, table_stride_1, table_stride_2, check_frequency], + lower, + upper, + ) + self.compress = True + + def _store_compress_data( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated embedding-net data in the descriptor state.""" + compress_data = [] + compress_info = [] + dtype = self.davg.dtype + ndim = 1 if self.type_one_side else 2 + for embedding_idx in range(self.ntypes**ndim): + if self.type_one_side: + ii = embedding_idx + ti = -1 + else: + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + if net not in table_data: + compress_data.append(np.asarray([], dtype=dtype)) + compress_info.append(np.asarray([], dtype=dtype)) + continue + compress_data.append(np.asarray(table_data[net], dtype=dtype)) + compress_info.append( + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ) + self.compress_data = compress_data + self.compress_info = compress_info + + def _tabulate_fusion_se_a( + self, + table: Array, + table_info: Array, + em_x: Array, + em: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_a forward.""" + xp = array_api_compat.array_namespace(em_x, em) + device = array_api_compat.device(em) + table = xp.asarray(table[...], dtype=em.dtype, device=device) + table_info = xp.asarray(table_info[...], dtype=em.dtype, device=device) + + nloc, nnei = em.shape[:2] + xx = xp.reshape(em_x, (nloc, nnei)) + lower = table_info[0] + upper = table_info[1] + table_max = table_info[2] + stride0 = table_info[3] + stride1 = table_info[4] + + zeros = xp.zeros(xx.shape, dtype=xp.int64, device=device) + nspline = table.shape[0] + last_idx = xp.full(xx.shape, nspline - 1, dtype=xp.int64, device=device) + first_stride = xp.astype(xp.floor((upper - lower) / stride0), xp.int64) + first_stride_value = xp.astype(first_stride, em.dtype) + + first_idx = xp.astype(xp.floor((xx - lower) / stride0), xp.int64) + second_idx = first_stride + xp.astype( + xp.floor((xx - upper) / stride1), xp.int64 + ) + table_idx = xp.where( + xx < lower, + zeros, + xp.where( + xx < upper, + first_idx, + xp.where(xx < table_max, second_idx, last_idx), + ), + ) + table_idx = xp.minimum(xp.maximum(table_idx, zeros), last_idx) + + table_idx_value = xp.astype(table_idx, em.dtype) + dx_first = xx - (table_idx_value * stride0 + lower) + dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx = xp.where( + (xx >= lower) & (xx < upper), + dx_first, + xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + ) + + coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) + coeff = xp.reshape(coeff, (nloc, nnei, last_layer_size, 6)) + dx = xp.reshape(dx, (nloc, nnei, 1)) + values = ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * dx) * dx) * dx + ) + * dx + ) + * dx + ) + return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) + def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], @@ -450,12 +595,22 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + if self.compress: + gr = self._call_compressed(rr, atype_ext, nlist, exclude_mask) + gr = xp.astype(gr, input_dtype) + gr = xp.reshape(gr, (nf, nloc, ng, 4)) + gr /= self.nnei + gr1 = gr[:, :, : self.axis_neuron, :] + grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) + return grrg, gr[..., 1:], None, None, ww + gr = xp.zeros( [nf * nloc, ng, 4], dtype=input_dtype, device=array_api_compat.device(coord_ext), ) - exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) @@ -513,6 +668,83 @@ def call( grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww + def _call_compressed( + self, + rr: Array, + atype_ext: Array, + nlist: Array, + exclude_mask: Array, + ) -> Array: + """Compressed forward path for the SE-A descriptor.""" + xp = array_api_compat.array_namespace(rr, atype_ext, nlist) + nf, nloc, nnei, _ = rr.shape + sec = self.sel_cumsum + ng = self.neuron[-1] + nfnl = nf * nloc + gr = xp.zeros( + [nfnl, 4, ng], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + exclude_mask = xp.reshape(exclude_mask, (nfnl, nnei)) + rr = xp.reshape(rr, (nfnl, nnei, 4)) + rr = xp.astype(rr, self.dstd.dtype) + + if self.type_one_side: + for embedding_idx, (compress_data_ii, compress_info_ii) in enumerate( + zip(self.compress_data, self.compress_info, strict=True) + ): + if array_api_compat.size(compress_data_ii) == 0: + continue + mm = exclude_mask[:, sec[embedding_idx] : sec[embedding_idx + 1]] + rr_i = rr[:, sec[embedding_idx] : sec[embedding_idx + 1], :] + rr_i = rr_i * xp.astype(mm[:, :, None], rr_i.dtype) + ss = rr_i[:, :, :1] + gr += self._tabulate_fusion_se_a( + compress_data_ii, + compress_info_ii, + ss, + rr_i, + ng, + ) + else: + atype_loc = xp.reshape(atype_ext[:, :nloc], (nfnl,)) + sort_idx = xp.argsort(atype_loc) + unsort_idx = xp.argsort(sort_idx) + rr_s = xp.take(rr, sort_idx, axis=0) + mask_s = xp.take(exclude_mask, sort_idx, axis=0) + dev = array_api_compat.device(rr) + gr_s = xp.zeros([nfnl, 4, ng], dtype=rr.dtype, device=dev) + type_ends = [] + offset = 0 + for ti in range(self.ntypes): + offset += int(xp.sum(xp.astype(atype_loc == ti, xp.int32))) + type_ends.append(offset) + type_starts = [0, *type_ends[:-1]] + for ti in range(self.ntypes): + s, e = type_starts[ti], type_ends[ti] + if s == e: + continue + for tt in range(self.ntypes): + embedding_idx = tt * self.ntypes + ti + compress_data_ii = self.compress_data[embedding_idx] + if array_api_compat.size(compress_data_ii) == 0: + continue + compress_info_ii = self.compress_info[embedding_idx] + mm = mask_s[s:e, sec[tt] : sec[tt + 1]] + rr_i = rr_s[s:e, sec[tt] : sec[tt + 1], :] + rr_i = rr_i * xp.astype(mm[:, :, None], rr_i.dtype) + ss = rr_i[:, :, :1] + gr_s[s:e] = gr_s[s:e] + self._tabulate_fusion_se_a( + compress_data_ii, + compress_info_ii, + ss, + rr_i, + ng, + ) + gr = xp.take(gr_s, unsort_idx, axis=0) + return xp.permute_dims(gr, (0, 2, 1)) + def serialize(self) -> dict: """Serialize the descriptor to dict.""" if not self.type_one_side and self.exclude_types: diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 624889d85f..891f234ced 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -44,6 +44,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -367,6 +370,136 @@ def cal_g( gg = self.embeddings[(ll,)].call(ss) return gg + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression by tabulating embedding networks.""" + if self.compress: + raise ValueError("Compression is already enabled.") + table = DPTabulate( + self, + self.neuron, + self.type_one_side, + self.exclude_types, + self.activation_function, + ) + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self._store_compress_data( + table.data, + [table_extrapolate, table_stride_1, table_stride_2, check_frequency], + lower, + upper, + ) + self.compress = True + + def _store_compress_data( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated embedding-net data in the descriptor state.""" + compress_data = [] + compress_info = [] + dtype = self.davg.dtype + for embedding_idx in range(self.ntypes): + net = "filter_-1_net_" + str(embedding_idx) + if net not in table_data: + compress_data.append(np.asarray([], dtype=dtype)) + compress_info.append(np.asarray([], dtype=dtype)) + continue + compress_data.append(np.asarray(table_data[net], dtype=dtype)) + compress_info.append( + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ) + self.compress_data = compress_data + self.compress_info = compress_info + + def _tabulate_fusion_se_r( + self, + table: Array, + table_info: Array, + em_x: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_r forward.""" + xp = array_api_compat.array_namespace(em_x) + device = array_api_compat.device(em_x) + table = xp.asarray(table[...], dtype=em_x.dtype, device=device) + table_info = xp.asarray(table_info[...], dtype=em_x.dtype, device=device) + + nloc, nnei = em_x.shape[:2] + xx = xp.reshape(em_x, (nloc, nnei)) + lower = table_info[0] + upper = table_info[1] + table_max = table_info[2] + stride0 = table_info[3] + stride1 = table_info[4] + + zeros = xp.zeros(xx.shape, dtype=xp.int64, device=device) + nspline = table.shape[0] + last_idx = xp.full(xx.shape, nspline - 1, dtype=xp.int64, device=device) + first_stride = xp.astype(xp.floor((upper - lower) / stride0), xp.int64) + first_stride_value = xp.astype(first_stride, em_x.dtype) + + first_idx = xp.astype(xp.floor((xx - lower) / stride0), xp.int64) + second_idx = first_stride + xp.astype( + xp.floor((xx - upper) / stride1), xp.int64 + ) + table_idx = xp.where( + xx < lower, + zeros, + xp.where( + xx < upper, + first_idx, + xp.where(xx < table_max, second_idx, last_idx), + ), + ) + table_idx = xp.minimum(xp.maximum(table_idx, zeros), last_idx) + + table_idx_value = xp.astype(table_idx, em_x.dtype) + dx_first = xx - (table_idx_value * stride0 + lower) + dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx = xp.where( + (xx >= lower) & (xx < upper), + dx_first, + xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + ) + + coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) + coeff = xp.reshape(coeff, (nloc, nnei, last_layer_size, 6)) + dx = xp.reshape(dx, (nloc, nnei, 1)) + return ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * dx) * dx) * dx + ) + * dx + ) + * dx + ) + @cast_precision def call( self, @@ -429,14 +562,34 @@ def call( ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) rr = xp.astype(rr, xyz_scatter.dtype) - for tt in range(self.ntypes): - mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] - tr = rr[:, :, sec[tt] : sec[tt + 1], :] - tr = tr * xp.astype(mm[:, :, :, None], tr.dtype) - gg = self.cal_g(tr, tt) - gg = xp.mean(gg, axis=2) - # nf x nloc x ng x 1 - xyz_scatter += gg * (self.sel[tt] / self.nnei) + if self.compress: + rr = xp.reshape(rr, (nf * nloc, nnei, 1)) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + for tt, (compress_data_ii, compress_info_ii) in enumerate( + zip(self.compress_data, self.compress_info, strict=True) + ): + if array_api_compat.size(compress_data_ii) == 0: + continue + mm = exclude_mask[:, sec[tt] : sec[tt + 1]] + tr = rr[:, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) + gg = self._tabulate_fusion_se_r( + compress_data_ii, + compress_info_ii, + tr, + ng, + ) + gg = xp.reshape(xp.sum(gg, axis=1), (nf, nloc, ng)) + xyz_scatter += gg / self.nnei + else: + for tt in range(self.ntypes): + mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] + tr = rr[:, :, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, :, None], tr.dtype) + gg = self.cal_g(tr, tt) + gg = xp.mean(gg, axis=2) + # nf x nloc x ng x 1 + xyz_scatter += gg * (self.sel[tt] / self.nnei) res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale diff --git a/deepmd/dpmodel/entrypoints/__init__.py b/deepmd/dpmodel/entrypoints/__init__.py new file mode 100644 index 0000000000..f732469b75 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Command-line entry points for the DPModel backend.""" diff --git a/deepmd/dpmodel/entrypoints/compress.py b/deepmd/dpmodel/entrypoints/compress.py new file mode 100644 index 0000000000..3848ade9b3 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/compress.py @@ -0,0 +1,93 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Compress a native DPModel file by tabulating embedding networks.""" + +import logging + +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, +) + +log = logging.getLogger(__name__) + + +def _get_saved_min_nbor_dist(model_dict: dict) -> float | None: + """Read min_nbor_dist from known native-model metadata locations.""" + min_nbor_dist = model_dict.get("min_nbor_dist") + if min_nbor_dist is None: + constants = model_dict.get("constants", {}) + min_nbor_dist = constants.get("min_nbor_dist") + if min_nbor_dist is None: + return None + return float(min_nbor_dist) + + +def _compute_min_nbor_dist(training_script: str) -> float: + from deepmd.common import ( + j_loader, + ) + from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, + ) + from deepmd.utils.compat import ( + update_deepmd_input, + ) + from deepmd.utils.data_system import ( + get_data, + ) + + jdata = update_deepmd_input(j_loader(training_script)) + type_map = jdata["model"].get("type_map", None) + train_data = get_data( + jdata["training"]["training_data"], + 0, + type_map, + None, + ) + return float(UpdateSel().get_min_nbor_dist(train_data)) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, + training_script: str | None = None, +) -> None: + """Compress a native ``.dp``/``.yaml`` model.""" + model_dict = load_dp_model(input_file) + model = BaseModel.deserialize(model_dict["model"]) + + min_nbor_dist = model.get_min_nbor_dist() + if min_nbor_dist is None: + min_nbor_dist = _get_saved_min_nbor_dist(model_dict) + if min_nbor_dist is None: + log.info( + "Minimal neighbor distance is not saved in the model, " + "compute it from the training data." + ) + if training_script is None: + raise ValueError( + "The model does not have a minimum neighbor distance, " + "so the training script and data must be provided " + "(via -t,--training-script)." + ) + min_nbor_dist = _compute_min_nbor_dist(training_script) + + model.min_nbor_dist = float(min_nbor_dist) + model.enable_compression( + extrapolate, + stride, + stride * 10, + check_frequency, + ) + + compressed_model_dict = model_dict.copy() + compressed_model_dict["model"] = model.serialize() + compressed_model_dict["min_nbor_dist"] = float(min_nbor_dist) + save_dp_model(output, compressed_model_dict) + log.info("Compressed model saved to %s", output) diff --git a/deepmd/dpmodel/entrypoints/main.py b/deepmd/dpmodel/entrypoints/main.py new file mode 100644 index 0000000000..3ca3adf11b --- /dev/null +++ b/deepmd/dpmodel/entrypoints/main.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD-kit entry point for the DPModel backend.""" + +import argparse +from pathlib import ( + Path, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.dpmodel.entrypoints.compress import ( + enable_compression, +) +from deepmd.loggers.loggers import ( + set_log_handles, +) +from deepmd.main import ( + parse_args, +) + +__all__ = ["main"] + + +def main(args: list[str] | argparse.Namespace | None = None) -> None: + """DPModel backend command dispatcher.""" + if not isinstance(args, argparse.Namespace): + args = parse_args(args=args) + + set_log_handles( + args.log_level, + Path(args.log_path) if args.log_path else None, + mpi_log=None, + ) + + if args.command == "compress": + enable_compression( + input_file=format_model_suffix( + args.input, + preferred_backend="dp", + strict_prefer=True, + ), + output=format_model_suffix( + args.output, + preferred_backend="dp", + strict_prefer=True, + ), + stride=args.step, + extrapolate=args.extrapolate, + check_frequency=args.frequency, + training_script=args.training_script, + ) + elif args.command is None: + pass + else: + raise RuntimeError( + f"Unsupported command '{args.command}' for the DPModel backend." + ) diff --git a/deepmd/jax/entrypoints/compress.py b/deepmd/jax/entrypoints/compress.py new file mode 100644 index 0000000000..337241d111 --- /dev/null +++ b/deepmd/jax/entrypoints/compress.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Compress a JAX model by tabulating embedding networks.""" + +import logging +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import numpy as np + +from deepmd.common import ( + j_loader, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) +from deepmd.jax.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) + +log = logging.getLogger(__name__) + + +def _to_float(value: Any) -> float | None: + if value is None: + return None + value = getattr(value, "value", value) + return float(np.asarray(value)) + + +def _get_saved_min_nbor_dist(data: dict) -> float | None: + """Read min_nbor_dist from known serialized-model metadata locations.""" + min_nbor_dist = _to_float(data.get("min_nbor_dist")) + if min_nbor_dist is not None: + return min_nbor_dist + constants = data.get("constants", {}) + return _to_float(constants.get("min_nbor_dist")) + + +def _get_input_min_nbor_dist(input_file: str, data: dict) -> float | None: + """Read min_nbor_dist from the serialized data or native HLO constants.""" + min_nbor_dist = _get_saved_min_nbor_dist(data) + if min_nbor_dist is not None: + return min_nbor_dist + if Path(input_file).suffix == ".hlo": + return _get_saved_min_nbor_dist(load_dp_model(input_file)) + return None + + +def _compute_min_nbor_dist(training_script: str) -> float: + jdata = update_deepmd_input(j_loader(training_script)) + type_map = jdata["model"].get("type_map", None) + train_data = get_data( + jdata["training"]["training_data"], + 0, + type_map, + None, + ) + return float(UpdateSel().get_min_nbor_dist(train_data)) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, + training_script: str | None = None, +) -> None: + """Compress a JAX ``.jax``/``.hlo`` model.""" + data = serialize_from_file(input_file) + model = BaseModel.deserialize(data["model"]) + + min_nbor_dist = _to_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + min_nbor_dist = _get_input_min_nbor_dist(input_file, data) + if min_nbor_dist is None: + log.info( + "Minimal neighbor distance is not saved in the model, " + "compute it from the training data." + ) + if training_script is None: + raise ValueError( + "The model does not have a minimum neighbor distance, " + "so the training script and data must be provided " + "(via -t,--training-script)." + ) + min_nbor_dist = _compute_min_nbor_dist(training_script) + + model.min_nbor_dist = float(min_nbor_dist) + model.enable_compression( + extrapolate, + stride, + stride * 10, + check_frequency, + ) + + compressed_data = data.copy() + compressed_data["model"] = model.serialize() + compressed_data["min_nbor_dist"] = float(min_nbor_dist) + deserialize_to_file(output, compressed_data) + log.info("Compressed model saved to %s", output) diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py index a365b1dea8..1d3b957c1f 100644 --- a/deepmd/jax/entrypoints/main.py +++ b/deepmd/jax/entrypoints/main.py @@ -6,6 +6,12 @@ Path, ) +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.jax.entrypoints.compress import ( + enable_compression, +) from deepmd.jax.entrypoints.freeze import ( freeze, ) @@ -51,6 +57,23 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: train(**dict_args) elif args.command == "freeze": freeze(**dict_args) + elif args.command == "compress": + enable_compression( + input_file=format_model_suffix( + args.input, + preferred_backend="jax", + strict_prefer=True, + ), + output=format_model_suffix( + args.output, + preferred_backend="jax", + strict_prefer=True, + ), + stride=args.step, + extrapolate=args.extrapolate, + check_frequency=args.frequency, + training_script=args.training_script, + ) elif args.command is None: pass else: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 14386d9f3d..758de69384 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -2,6 +2,9 @@ from pathlib import ( Path, ) +from typing import ( + Any, +) import numpy as np import orbax.checkpoint as ocp @@ -22,6 +25,66 @@ ) +def _state_sequence_to_numpy_list(state_value: Any) -> list[np.ndarray]: + """Convert an Orbax-restored list/dict sequence to NumPy arrays.""" + if isinstance(state_value, dict): + values = [state_value[key] for key in sorted(state_value)] + else: + values = state_value + return [np.asarray(getattr(value, "value", value)) for value in values] + + +def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: + """Create compression variable slots before replacing an NNX state. + + A compressed ``.jax`` checkpoint stores tabulation arrays in the NNX state, + while ``model_def_script`` still describes the original uncompressed model. + Build the corresponding descriptor attributes first so Flax can match the + restored state keys. + """ + if not isinstance(state, dict): + return + if ( + hasattr(obj, "compress") + and "compress_data" in state + and "compress_info" in state + ): + obj.compress_data = _state_sequence_to_numpy_list(state["compress_data"]) + obj.compress_info = _state_sequence_to_numpy_list(state["compress_info"]) + obj.compress = True + for name, child_state in state.items(): + if not isinstance(child_state, dict): + continue + if isinstance(name, int): + try: + child = obj[name] + except (IndexError, KeyError, TypeError): + continue + else: + if not hasattr(obj, name): + continue + child = getattr(obj, name) + _restore_compression_slots_from_state(child, child_state) + + +def _to_optional_float(value: Any) -> float | None: + if value is None: + return None + return float(np.asarray(getattr(value, "value", value))) + + +def _set_model_min_nbor_dist_from_data(model: BaseModel, data: dict) -> None: + if model.get_min_nbor_dist() is not None: + return + min_nbor_dist = _to_optional_float(data.get("min_nbor_dist")) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float( + data.get("constants", {}).get("min_nbor_dist") + ) + if min_nbor_dist is not None: + model.min_nbor_dist = min_nbor_dist + + def deserialize_to_file(model_file: str, data: dict) -> None: """Deserialize the dictionary to a model file. @@ -34,7 +97,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None: """ if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] + model_def_script = data["model_def_script"].copy() + min_nbor_dist = _to_optional_float(data.get("min_nbor_dist")) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float( + data.get("constants", {}).get("min_nbor_dist") + ) + if min_nbor_dist is not None: + model_def_script["_min_nbor_dist"] = min_nbor_dist _, state = nnx.split(model) with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") @@ -48,6 +118,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ) elif model_file.endswith(".hlo"): model = BaseModel.deserialize(data["model"]) + _set_model_min_nbor_dist_from_data(model, data) model_def_script = data["model_def_script"] call_lower = model.call_common_lower @@ -185,10 +256,14 @@ def convert_str_to_int_key(item: dict) -> None: model_def_script = data.model_def_script abstract_model = get_model(model_def_script) + _restore_compression_slots_from_state(abstract_model, state) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state) model = nnx.merge(graphdef, abstract_state) model_dict = model.serialize() + min_nbor_dist = _to_optional_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float(model_def_script.get("_min_nbor_dist")) data = { "backend": "JAX", "jax_version": jax.__version__, @@ -196,6 +271,8 @@ def convert_str_to_int_key(item: dict) -> None: "model_def_script": model_def_script, "@variables": {}, } + if min_nbor_dist is not None: + data["min_nbor_dist"] = min_nbor_dist return data elif model_file.endswith(".hlo"): data = load_dp_model(model_file) diff --git a/deepmd/main.py b/deepmd/main.py index 43f40dc214..56e1863725 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -585,6 +585,8 @@ def main_parser() -> argparse.ArgumentParser: dp compress dp --tf compress -i frozen_model.pb -o compressed_model.pb dp --pt compress -i frozen_model.pth -o compressed_model.pth + dp --dp compress -i frozen_model.dp -o compressed_model.dp + dp --jax compress -i frozen_model.hlo -o compressed_model.hlo """ ), ) @@ -593,14 +595,14 @@ def main_parser() -> argparse.ArgumentParser: "--input", default="frozen_model", type=str, - help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", + help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-o", "--output", default="frozen_model_compressed", type=str, - help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", + help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-s", diff --git a/source/tests/common/dpmodel/test_model_compression.py b/source/tests/common/dpmodel/test_model_compression.py new file mode 100644 index 0000000000..bf57e8b183 --- /dev/null +++ b/source/tests/common/dpmodel/test_model_compression.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import tempfile +import unittest +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.dpmodel.entrypoints.compress import ( + enable_compression, +) +from deepmd.dpmodel.model.model import ( + get_model, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, +) + + +class TestDPModelCompression(unittest.TestCase): + def setUp(self) -> None: + self.coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [1.2, 0.1, 0.0], + [0.1, 1.4, 0.0], + [1.5, 1.5, 0.1], + ] + ], + dtype=np.float64, + ) + self.atype = np.array([[0, 0, 1, 1]], dtype=np.int32) + self.nlist = np.array([[[1, 2], [0, 2], [0, 3], [0, 2]]], dtype=np.int64) + + def _make_descriptor( + self, + type_one_side: bool = True, + exclude_types: list[list[int]] | None = None, + ) -> DescrptSeA: + return DescrptSeA( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + axis_neuron=2, + resnet_dt=False, + type_one_side=type_one_side, + exclude_types=[] if exclude_types is None else exclude_types, + precision="float64", + seed=1234, + ) + + def test_se_e2_a_enable_compression(self) -> None: + for type_one_side, exclude_types in ( + (True, []), + (False, []), + (True, [[0, 1]]), + (False, [[0, 1]]), + ): + with self.subTest(type_one_side=type_one_side, exclude_types=exclude_types): + descriptor = self._make_descriptor(type_one_side, exclude_types) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptSeA.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + actual_item, expected_item, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, atol=1e-10 + ) + + def test_se_e2_r_enable_compression(self) -> None: + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + resnet_dt=False, + precision="float64", + seed=1234, + ) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptSeR.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeR.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose(actual_item, expected_item, atol=1e-10) + np.testing.assert_allclose(reloaded_item, expected_item, atol=1e-10) + + def test_se_atten_enable_compression(self) -> None: + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + precision="float64", + seed=1234, + ) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptDPA1.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose(actual_item, expected_item, atol=1e-10) + np.testing.assert_allclose(reloaded_item, expected_item, atol=1e-10) + + def test_dpmodel_compress_entrypoint(self) -> None: + model_data = { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": [1, 1], + "neuron": [4, 8], + "axis_neuron": 2, + "resnet_dt": False, + "type_one_side": True, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + model = get_model(model_data) + model.min_nbor_dist = 1.0 + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "model.dp" + output_file = Path(tmpdir) / "model-compressed.dp" + save_dp_model( + str(input_file), + { + "model": model.serialize(), + "model_def_script": model_data, + "min_nbor_dist": 1.0, + }, + ) + + enable_compression( + str(input_file), + str(output_file), + stride=0.01, + extrapolate=5, + check_frequency=-1, + ) + + compressed = load_dp_model(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["min_nbor_dist"], 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/jax/test_model_compression.py b/source/tests/jax/test_model_compression.py new file mode 100644 index 0000000000..00bb296a3c --- /dev/null +++ b/source/tests/jax/test_model_compression.py @@ -0,0 +1,305 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import tempfile +import unittest +from importlib.util import ( + find_spec, +) +from pathlib import ( + Path, +) + +import numpy as np + +INSTALLED_JAX = find_spec("jax") is not None and find_spec("orbax") is not None + +if INSTALLED_JAX: + from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, + ) + from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, + ) + from deepmd.jax.descriptor.se_e2_r import ( + DescrptSeR, + ) + from deepmd.jax.env import ( + jax, + jnp, + ) + from deepmd.jax.model.model import ( + get_model, + ) + from deepmd.jax.utils.serialization import ( + serialize_from_file, + ) + from deepmd.main import main as dp_main + + +class TestJAXModelCompression(unittest.TestCase): + def setUp(self) -> None: + self.coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [1.2, 0.1, 0.0], + [0.1, 1.4, 0.0], + [1.5, 1.5, 0.1], + ] + ], + dtype=np.float64, + ) + self.atype = np.array([[0, 0, 1, 1]], dtype=np.int32) + self.nlist = np.array([[[1, 2], [0, 2], [0, 3], [0, 2]]], dtype=np.int64) + + def _make_model_data(self) -> dict: + return { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": [1, 1], + "neuron": [4, 8], + "axis_neuron": 2, + "resnet_dt": False, + "type_one_side": True, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_e2_a_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + descriptor = DescrptSeA( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + axis_neuron=2, + resnet_dt=False, + type_one_side=True, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptSeA.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_e2_r_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + resnet_dt=False, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptSeR.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeR.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_atten_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptDPA1.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_compress_entrypoint(self) -> None: + model_data = self._make_model_data() + model = get_model(copy.deepcopy(model_data)) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "model.hlo" + output_file = Path(tmpdir) / "model-compressed.jax" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) + + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) + + compressed = serialize_from_file(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["min_nbor_dist"], 1.0) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_compress_entrypoint_can_write_hlo(self) -> None: + model_data = self._make_model_data() + model = get_model(copy.deepcopy(model_data)) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "model.hlo" + output_file = Path(tmpdir) / "model-compressed.hlo" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) + + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) + + compressed = load_dp_model(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["constants"]["min_nbor_dist"], 1.0) + self.assertIn("stablehlo", compressed["@variables"]) + self.assertIn("stablehlo_no_ghost", compressed["@variables"]) + + +if __name__ == "__main__": + unittest.main() From ec72b6729877e8c85f10f853a0c79cb10fd9c1f9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 02:19:16 +0800 Subject: [PATCH 2/6] fix(jax): restore compressed DPA1 checkpoints --- deepmd/dpmodel/descriptor/dpa1.py | 21 +- deepmd/dpmodel/descriptor/se_atten_v2.py | 21 +- deepmd/dpmodel/descriptor/se_e2_a.py | 32 ++- deepmd/dpmodel/entrypoints/compress.py | 67 +------ deepmd/dpmodel/entrypoints/compress_common.py | 109 ++++++++++ deepmd/jax/entrypoints/compress.py | 85 ++------ deepmd/jax/utils/serialization.py | 35 ++++ source/tests/jax/test_model_compression.py | 188 +++++++++++------- 8 files changed, 342 insertions(+), 216 deletions(-) create mode 100644 deepmd/dpmodel/entrypoints/compress_common.py diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index c813dc444a..0960de1ac2 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -708,18 +708,33 @@ def serialize(self) -> dict: if obj.tebd_input_mode in ["strip"]: data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) if self.compress: + type_embd_data = ( + self.type_embd_data + if hasattr(self, "type_embd_data") + else obj.type_embd_data + ) compress_dict: dict = { "@variables": { - "type_embd_data": to_numpy_array(self.type_embd_data), + "type_embd_data": to_numpy_array(type_embd_data), }, "geo_compress": self.geo_compress, } if self.geo_compress: + compress_data = ( + self.compress_data + if hasattr(self, "compress_data") + else obj.compress_data + ) + compress_info = ( + self.compress_info + if hasattr(self, "compress_info") + else obj.compress_info + ) compress_dict["@variables"]["compress_data"] = [ - to_numpy_array(d) for d in self.compress_data + to_numpy_array(d) for d in compress_data ] compress_dict["@variables"]["compress_info"] = [ - to_numpy_array(i) for i in self.compress_info + to_numpy_array(i) for i in compress_info ] data["compress"] = compress_dict return data diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py index c3b02c3d1e..6942a5bac3 100644 --- a/deepmd/dpmodel/descriptor/se_atten_v2.py +++ b/deepmd/dpmodel/descriptor/se_atten_v2.py @@ -247,18 +247,33 @@ def serialize(self) -> dict: "spin": None, } if self.compress: + type_embd_data = ( + self.type_embd_data + if hasattr(self, "type_embd_data") + else obj.type_embd_data + ) compress_dict: dict = { "@variables": { - "type_embd_data": to_numpy_array(self.type_embd_data), + "type_embd_data": to_numpy_array(type_embd_data), }, "geo_compress": self.geo_compress, } if self.geo_compress: + compress_data = ( + self.compress_data + if hasattr(self, "compress_data") + else obj.compress_data + ) + compress_info = ( + self.compress_info + if hasattr(self, "compress_info") + else obj.compress_info + ) compress_dict["@variables"]["compress_data"] = [ - to_numpy_array(d) for d in self.compress_data + to_numpy_array(d) for d in compress_data ] compress_dict["@variables"]["compress_info"] = [ - to_numpy_array(i) for i in self.compress_info + to_numpy_array(i) for i in compress_info ] data["compress"] = compress_dict return data diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 48edbd836a..1da0b44f03 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -533,6 +533,14 @@ def _tabulate_fusion_se_a( ) return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) + def _add_to_slice(self, array: Array, start: int, end: int, value: Array) -> Array: + """Return ``array`` with ``value`` added to ``array[start:end]``.""" + xp = array_api_compat.array_namespace(array, value) + return xp.concat( + [array[:start], array[start:end] + value, array[end:]], + axis=0, + ) + def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], @@ -654,8 +662,11 @@ def call( tr = tr * xp.astype(mm[:, :, None], tr.dtype) ss = tr[..., 0:1] gg = self.cal_g(ss, (ti, tt)) - gr_s[s:e] = gr_s[s:e] + xp.sum( - gg[:, :, :, None] * tr[:, :, None, :], axis=1 + gr_s = self._add_to_slice( + gr_s, + s, + e, + xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1), ) gr = xp.take(gr_s, unsort_idx, axis=0) gr = xp.reshape(gr, (nf, nloc, ng, 4)) @@ -735,12 +746,17 @@ def _call_compressed( rr_i = rr_s[s:e, sec[tt] : sec[tt + 1], :] rr_i = rr_i * xp.astype(mm[:, :, None], rr_i.dtype) ss = rr_i[:, :, :1] - gr_s[s:e] = gr_s[s:e] + self._tabulate_fusion_se_a( - compress_data_ii, - compress_info_ii, - ss, - rr_i, - ng, + gr_s = self._add_to_slice( + gr_s, + s, + e, + self._tabulate_fusion_se_a( + compress_data_ii, + compress_info_ii, + ss, + rr_i, + ng, + ), ) gr = xp.take(gr_s, unsort_idx, axis=0) return xp.permute_dims(gr, (0, 2, 1)) diff --git a/deepmd/dpmodel/entrypoints/compress.py b/deepmd/dpmodel/entrypoints/compress.py index 3848ade9b3..edc942dbb3 100644 --- a/deepmd/dpmodel/entrypoints/compress.py +++ b/deepmd/dpmodel/entrypoints/compress.py @@ -3,6 +3,10 @@ import logging +from deepmd.dpmodel.entrypoints.compress_common import ( + enable_model_compression, + resolve_min_nbor_dist, +) from deepmd.dpmodel.model.base_model import ( BaseModel, ) @@ -14,42 +18,6 @@ log = logging.getLogger(__name__) -def _get_saved_min_nbor_dist(model_dict: dict) -> float | None: - """Read min_nbor_dist from known native-model metadata locations.""" - min_nbor_dist = model_dict.get("min_nbor_dist") - if min_nbor_dist is None: - constants = model_dict.get("constants", {}) - min_nbor_dist = constants.get("min_nbor_dist") - if min_nbor_dist is None: - return None - return float(min_nbor_dist) - - -def _compute_min_nbor_dist(training_script: str) -> float: - from deepmd.common import ( - j_loader, - ) - from deepmd.dpmodel.utils.update_sel import ( - UpdateSel, - ) - from deepmd.utils.compat import ( - update_deepmd_input, - ) - from deepmd.utils.data_system import ( - get_data, - ) - - jdata = update_deepmd_input(j_loader(training_script)) - type_map = jdata["model"].get("type_map", None) - train_data = get_data( - jdata["training"]["training_data"], - 0, - type_map, - None, - ) - return float(UpdateSel().get_min_nbor_dist(train_data)) - - def enable_compression( input_file: str, output: str, @@ -62,29 +30,12 @@ def enable_compression( model_dict = load_dp_model(input_file) model = BaseModel.deserialize(model_dict["model"]) - min_nbor_dist = model.get_min_nbor_dist() - if min_nbor_dist is None: - min_nbor_dist = _get_saved_min_nbor_dist(model_dict) - if min_nbor_dist is None: - log.info( - "Minimal neighbor distance is not saved in the model, " - "compute it from the training data." - ) - if training_script is None: - raise ValueError( - "The model does not have a minimum neighbor distance, " - "so the training script and data must be provided " - "(via -t,--training-script)." - ) - min_nbor_dist = _compute_min_nbor_dist(training_script) - - model.min_nbor_dist = float(min_nbor_dist) - model.enable_compression( - extrapolate, - stride, - stride * 10, - check_frequency, + min_nbor_dist = resolve_min_nbor_dist( + model, + [model_dict], + training_script, ) + enable_model_compression(model, min_nbor_dist, stride, extrapolate, check_frequency) compressed_model_dict = model_dict.copy() compressed_model_dict["model"] = model.serialize() diff --git a/deepmd/dpmodel/entrypoints/compress_common.py b/deepmd/dpmodel/entrypoints/compress_common.py new file mode 100644 index 0000000000..03d6c52897 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/compress_common.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Shared helpers for native-model compression entrypoints.""" + +import logging +from collections.abc import ( + Iterable, +) +from typing import ( + Any, +) + +import numpy as np + +from deepmd.common import ( + j_loader, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) + +log = logging.getLogger(__name__) + + +def to_float(value: Any) -> float | None: + """Convert plain, NumPy, or framework-wrapped scalar values to float.""" + if value is None: + return None + value = getattr(value, "value", value) + return float(np.asarray(value)) + + +def get_saved_min_nbor_dist(data: dict) -> float | None: + """Read min_nbor_dist from known serialized-model metadata locations.""" + min_nbor_dist = to_float(data.get("min_nbor_dist")) + if min_nbor_dist is not None: + return min_nbor_dist + constants = data.get("constants", {}) + return to_float(constants.get("min_nbor_dist")) + + +def compute_min_nbor_dist( + training_script: str, + update_sel_cls: type[Any] | None = None, +) -> float: + """Compute min_nbor_dist from the training data.""" + if update_sel_cls is None: + from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, + ) + + update_sel_cls = UpdateSel + + jdata = update_deepmd_input(j_loader(training_script)) + type_map = jdata["model"].get("type_map", None) + train_data = get_data( + jdata["training"]["training_data"], + 0, + type_map, + None, + ) + return float(update_sel_cls().get_min_nbor_dist(train_data)) + + +def resolve_min_nbor_dist( + model: Any, + metadata_sources: Iterable[dict], + training_script: str | None, + update_sel_cls: type[Any] | None = None, +) -> float: + """Resolve min_nbor_dist from model, saved metadata, or training data.""" + min_nbor_dist = to_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + for metadata in metadata_sources: + min_nbor_dist = get_saved_min_nbor_dist(metadata) + if min_nbor_dist is not None: + break + if min_nbor_dist is None: + log.info( + "Minimal neighbor distance is not saved in the model, " + "compute it from the training data." + ) + if training_script is None: + raise ValueError( + "The model does not have a minimum neighbor distance, " + "so the training script and data must be provided " + "(via -t,--training-script)." + ) + min_nbor_dist = compute_min_nbor_dist(training_script, update_sel_cls) + return float(min_nbor_dist) + + +def enable_model_compression( + model: Any, + min_nbor_dist: float, + stride: float, + extrapolate: int, + check_frequency: int, +) -> None: + """Set compression inputs and enable descriptor tabulation.""" + model.min_nbor_dist = float(min_nbor_dist) + model.enable_compression( + table_extrapolate=extrapolate, + table_stride_1=stride, + table_stride_2=stride * 10, + check_frequency=check_frequency, + ) diff --git a/deepmd/jax/entrypoints/compress.py b/deepmd/jax/entrypoints/compress.py index 337241d111..a54addd9be 100644 --- a/deepmd/jax/entrypoints/compress.py +++ b/deepmd/jax/entrypoints/compress.py @@ -5,14 +5,10 @@ from pathlib import ( Path, ) -from typing import ( - Any, -) - -import numpy as np -from deepmd.common import ( - j_loader, +from deepmd.dpmodel.entrypoints.compress_common import ( + enable_model_compression, + resolve_min_nbor_dist, ) from deepmd.dpmodel.utils.serialization import ( load_dp_model, @@ -27,54 +23,10 @@ from deepmd.jax.utils.update_sel import ( UpdateSel, ) -from deepmd.utils.compat import ( - update_deepmd_input, -) -from deepmd.utils.data_system import ( - get_data, -) log = logging.getLogger(__name__) -def _to_float(value: Any) -> float | None: - if value is None: - return None - value = getattr(value, "value", value) - return float(np.asarray(value)) - - -def _get_saved_min_nbor_dist(data: dict) -> float | None: - """Read min_nbor_dist from known serialized-model metadata locations.""" - min_nbor_dist = _to_float(data.get("min_nbor_dist")) - if min_nbor_dist is not None: - return min_nbor_dist - constants = data.get("constants", {}) - return _to_float(constants.get("min_nbor_dist")) - - -def _get_input_min_nbor_dist(input_file: str, data: dict) -> float | None: - """Read min_nbor_dist from the serialized data or native HLO constants.""" - min_nbor_dist = _get_saved_min_nbor_dist(data) - if min_nbor_dist is not None: - return min_nbor_dist - if Path(input_file).suffix == ".hlo": - return _get_saved_min_nbor_dist(load_dp_model(input_file)) - return None - - -def _compute_min_nbor_dist(training_script: str) -> float: - jdata = update_deepmd_input(j_loader(training_script)) - type_map = jdata["model"].get("type_map", None) - train_data = get_data( - jdata["training"]["training_data"], - 0, - type_map, - None, - ) - return float(UpdateSel().get_min_nbor_dist(train_data)) - - def enable_compression( input_file: str, output: str, @@ -87,29 +39,16 @@ def enable_compression( data = serialize_from_file(input_file) model = BaseModel.deserialize(data["model"]) - min_nbor_dist = _to_float(model.get_min_nbor_dist()) - if min_nbor_dist is None: - min_nbor_dist = _get_input_min_nbor_dist(input_file, data) - if min_nbor_dist is None: - log.info( - "Minimal neighbor distance is not saved in the model, " - "compute it from the training data." - ) - if training_script is None: - raise ValueError( - "The model does not have a minimum neighbor distance, " - "so the training script and data must be provided " - "(via -t,--training-script)." - ) - min_nbor_dist = _compute_min_nbor_dist(training_script) - - model.min_nbor_dist = float(min_nbor_dist) - model.enable_compression( - extrapolate, - stride, - stride * 10, - check_frequency, + metadata_sources = [data] + if Path(input_file).suffix == ".hlo": + metadata_sources.append(load_dp_model(input_file)) + min_nbor_dist = resolve_min_nbor_dist( + model, + metadata_sources, + training_script, + UpdateSel, ) + enable_model_compression(model, min_nbor_dist, stride, extrapolate, check_frequency) compressed_data = data.copy() compressed_data["model"] = model.serialize() diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 758de69384..3a33c01e25 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -34,6 +34,11 @@ def _state_sequence_to_numpy_list(state_value: Any) -> list[np.ndarray]: return [np.asarray(getattr(value, "value", value)) for value in values] +def _state_value_to_numpy(state_value: Any) -> np.ndarray: + """Convert an Orbax-restored state value to a NumPy array.""" + return np.asarray(getattr(state_value, "value", state_value)) + + def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: """Create compression variable slots before replacing an NNX state. @@ -52,6 +57,27 @@ def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: obj.compress_data = _state_sequence_to_numpy_list(state["compress_data"]) obj.compress_info = _state_sequence_to_numpy_list(state["compress_info"]) obj.compress = True + if hasattr(obj, "geo_compress"): + obj.geo_compress = True + if hasattr(obj, "se_atten"): + obj.se_atten.compress_data = obj.compress_data + obj.se_atten.compress_info = obj.compress_info + if hasattr(obj.se_atten, "geo_compress"): + obj.se_atten.geo_compress = True + if hasattr(obj, "compress") and "type_embd_data" in state: + obj.type_embd_data = _state_value_to_numpy(state["type_embd_data"]) + obj.compress = True + obj.tebd_compress = True + if hasattr(obj, "geo_compress"): + obj.geo_compress = "compress_data" in state and "compress_info" in state + if hasattr(obj, "se_atten"): + obj.se_atten.type_embd_data = obj.type_embd_data + obj.se_atten.tebd_compress = True + if hasattr(obj.se_atten, "geo_compress"): + obj.se_atten.geo_compress = getattr(obj, "geo_compress", False) + if getattr(obj, "geo_compress", False): + obj.se_atten.compress_data = obj.compress_data + obj.se_atten.compress_info = obj.compress_info for name, child_state in state.items(): if not isinstance(child_state, dict): continue @@ -65,6 +91,15 @@ def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: continue child = getattr(obj, name) _restore_compression_slots_from_state(child, child_state) + if name == "se_atten" and hasattr(obj, "compress"): + if hasattr(child, "type_embd_data"): + obj.type_embd_data = child.type_embd_data + obj.tebd_compress = getattr(child, "tebd_compress", True) + obj.compress = True + if getattr(child, "geo_compress", False): + obj.geo_compress = True + obj.compress_data = child.compress_data + obj.compress_info = child.compress_info def _to_optional_float(value: Any) -> float | None: diff --git a/source/tests/jax/test_model_compression.py b/source/tests/jax/test_model_compression.py index 00bb296a3c..de3243e91e 100644 --- a/source/tests/jax/test_model_compression.py +++ b/source/tests/jax/test_model_compression.py @@ -81,48 +81,83 @@ def _make_model_data(self) -> dict: }, } + def _make_dpa1_model_data(self) -> dict: + return { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": 2, + "neuron": [4, 8], + "axis_neuron": 2, + "tebd_dim": 4, + "tebd_input_mode": "strip", + "resnet_dt": False, + "attn_layer": 0, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_se_e2_a_enable_compression(self) -> None: coord = jnp.array(self.coord) atype = jnp.array(self.atype) nlist = jnp.array(self.nlist) - descriptor = DescrptSeA( - rcut=4.0, - rcut_smth=3.5, - sel=[1, 1], - neuron=[4, 8], - axis_neuron=2, - resnet_dt=False, - type_one_side=True, - precision="float64", - seed=1234, - ) - expected = descriptor.call(coord, atype, nlist) + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptSeA( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + axis_neuron=2, + resnet_dt=False, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) - compressed = DescrptSeA.deserialize(copy.deepcopy(descriptor.serialize())) - compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) - actual = compressed.call(coord, atype, nlist) + compressed = DescrptSeA.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) - self.assertTrue(compressed.compress) - serialized = compressed.serialize() - self.assertEqual(serialized["@version"], 3) - self.assertIn("compress", serialized) - reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) - reloaded_actual = reloaded.call(coord, atype, nlist) + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) - for expected_item, actual_item, reloaded_item in zip( - expected, actual, reloaded_actual, strict=True - ): - if expected_item is None: - self.assertIsNone(actual_item) - self.assertIsNone(reloaded_item) - else: - np.testing.assert_allclose( - np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 - ) - np.testing.assert_allclose( - np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 - ) + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), + np.asarray(expected_item), + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), + np.asarray(expected_item), + atol=1e-10, + ) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_se_e2_r_enable_compression(self) -> None: @@ -214,46 +249,57 @@ def test_se_atten_enable_compression(self) -> None: @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_compress_entrypoint(self) -> None: - model_data = self._make_model_data() - model = get_model(copy.deepcopy(model_data)) + for name, model_data in ( + ("se_e2_a", self._make_model_data()), + ("dpa1", self._make_dpa1_model_data()), + ): + with self.subTest(descriptor=name): + model = get_model(copy.deepcopy(model_data)) - with tempfile.TemporaryDirectory() as tmpdir: - input_file = Path(tmpdir) / "model.hlo" - output_file = Path(tmpdir) / "model-compressed.jax" - save_dp_model( - str(input_file), - { - "backend": "JAX", - "jax_version": jax.__version__, - "model": model.serialize(), - "model_def_script": model_data, - "@variables": { - "stablehlo": np.void(b"stablehlo"), - }, - "constants": { - "min_nbor_dist": 1.0, - }, - }, - ) + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / f"model-{name}.hlo" + output_file = Path(tmpdir) / f"model-{name}-compressed.jax" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) - dp_main( - [ - "--jax", - "compress", - "-i", - str(input_file), - "-o", - str(output_file), - "-s", - "0.01", - ] - ) + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) - compressed = serialize_from_file(str(output_file)) - descriptor = compressed["model"]["descriptor"] - self.assertEqual(descriptor["@version"], 3) - self.assertIn("compress", descriptor) - self.assertEqual(compressed["min_nbor_dist"], 1.0) + compressed = serialize_from_file(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["min_nbor_dist"], 1.0) + if name == "dpa1": + compress = descriptor["compress"] + self.assertTrue(compress["geo_compress"]) + variables = compress["@variables"] + self.assertIn("type_embd_data", variables) + self.assertIn("compress_data", variables) + self.assertIn("compress_info", variables) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_compress_entrypoint_can_write_hlo(self) -> None: From 52f8c44530cc23e555748a3105752797c7160b1b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 27 Jun 2026 16:31:14 +0800 Subject: [PATCH 3/6] fix(dpmodel): avoid default DPA1 compression slots --- deepmd/dpmodel/descriptor/dpa1.py | 6 +++--- deepmd/jax/utils/serialization.py | 20 +++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 0960de1ac2..2ee43ca6d1 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1032,9 +1032,9 @@ def __init__( self.tebd_compress = False self.geo_compress = False self.is_sorted = len(self.exclude_types) == 0 - self.compress_data = [np.zeros(0, dtype=PRECISION_DICT[self.precision])] - self.compress_info = [np.zeros(0, dtype=PRECISION_DICT[self.precision])] - self.type_embd_data = np.zeros(0, dtype=PRECISION_DICT[self.precision]) + self.compress_data = None + self.compress_info = None + self.type_embd_data = None def get_rcut(self) -> float: """Returns the cut-off radius.""" diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 3a33c01e25..1c436e0ad1 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -50,13 +50,14 @@ def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: if not isinstance(state, dict): return if ( - hasattr(obj, "compress") + (hasattr(obj, "compress") or hasattr(obj, "geo_compress")) and "compress_data" in state and "compress_info" in state ): obj.compress_data = _state_sequence_to_numpy_list(state["compress_data"]) obj.compress_info = _state_sequence_to_numpy_list(state["compress_info"]) - obj.compress = True + if hasattr(obj, "compress"): + obj.compress = True if hasattr(obj, "geo_compress"): obj.geo_compress = True if hasattr(obj, "se_atten"): @@ -64,10 +65,14 @@ def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: obj.se_atten.compress_info = obj.compress_info if hasattr(obj.se_atten, "geo_compress"): obj.se_atten.geo_compress = True - if hasattr(obj, "compress") and "type_embd_data" in state: + if ( + hasattr(obj, "compress") or hasattr(obj, "tebd_compress") + ) and "type_embd_data" in state: obj.type_embd_data = _state_value_to_numpy(state["type_embd_data"]) - obj.compress = True - obj.tebd_compress = True + if hasattr(obj, "compress"): + obj.compress = True + if hasattr(obj, "tebd_compress"): + obj.tebd_compress = True if hasattr(obj, "geo_compress"): obj.geo_compress = "compress_data" in state and "compress_info" in state if hasattr(obj, "se_atten"): @@ -92,8 +97,9 @@ def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: child = getattr(obj, name) _restore_compression_slots_from_state(child, child_state) if name == "se_atten" and hasattr(obj, "compress"): - if hasattr(child, "type_embd_data"): - obj.type_embd_data = child.type_embd_data + child_type_embd_data = getattr(child, "type_embd_data", None) + if child_type_embd_data is not None: + obj.type_embd_data = child_type_embd_data obj.tebd_compress = getattr(child, "tebd_compress", True) obj.compress = True if getattr(child, "geo_compress", False): From 48e01a8d28f8d29d9dcb2ea6796b66dcefe890fb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 29 Jun 2026 00:18:43 +0800 Subject: [PATCH 4/6] fix(dpmodel): add c1 tabulate extrapolation --- deepmd/dpmodel/descriptor/dpa1.py | 32 +++- deepmd/dpmodel/descriptor/se_e2_a.py | 32 +++- deepmd/dpmodel/descriptor/se_r.py | 34 +++- .../common/dpmodel/test_model_compression.py | 173 ++++++++++++++++++ 4 files changed, 261 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 2ee43ca6d1..4cb38b9e3e 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1301,10 +1301,22 @@ def _tabulate_fusion_se_atten( table_idx_value = xp.astype(table_idx, em.dtype) dx_first = xx - (table_idx_value * stride0 + lower) dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx_high = table_max - ( + (xp.astype(last_idx, em.dtype) - first_stride_value) * stride1 + upper + ) dx = xp.where( - (xx >= lower) & (xx < upper), - dx_first, - xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + xx < lower, + xp.zeros_like(xx), + xp.where( + xx < upper, + dx_first, + xp.where(xx < table_max, dx_second, dx_high), + ), + ) + extrapolate_delta = xp.where( + xx < lower, + xx - lower, + xp.where(xx >= table_max, xx - table_max, xp.zeros_like(xx)), ) coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) @@ -1322,6 +1334,20 @@ def _tabulate_fusion_se_atten( ) * dx ) + values_grad = ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + ( + 3 * coeff[..., 3] + + (4 * coeff[..., 4] + 5 * coeff[..., 5] * dx) * dx + ) + * dx + ) + * dx + ) + extrapolate_delta = xp.reshape(extrapolate_delta, (nloc, nnei, 1)) + values = values + values_grad * extrapolate_delta values = values * two_embed + values return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 1da0b44f03..292e3d4fab 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -510,10 +510,22 @@ def _tabulate_fusion_se_a( table_idx_value = xp.astype(table_idx, em.dtype) dx_first = xx - (table_idx_value * stride0 + lower) dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx_high = table_max - ( + (xp.astype(last_idx, em.dtype) - first_stride_value) * stride1 + upper + ) dx = xp.where( - (xx >= lower) & (xx < upper), - dx_first, - xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + xx < lower, + xp.zeros_like(xx), + xp.where( + xx < upper, + dx_first, + xp.where(xx < table_max, dx_second, dx_high), + ), + ) + extrapolate_delta = xp.where( + xx < lower, + xx - lower, + xp.where(xx >= table_max, xx - table_max, xp.zeros_like(xx)), ) coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) @@ -531,6 +543,20 @@ def _tabulate_fusion_se_a( ) * dx ) + values_grad = ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + ( + 3 * coeff[..., 3] + + (4 * coeff[..., 4] + 5 * coeff[..., 5] * dx) * dx + ) + * dx + ) + * dx + ) + extrapolate_delta = xp.reshape(extrapolate_delta, (nloc, nnei, 1)) + values = values + values_grad * extrapolate_delta return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) def _add_to_slice(self, array: Array, start: int, end: int, value: Array) -> Array: diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 891f234ced..f9d91c25d8 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -478,16 +478,28 @@ def _tabulate_fusion_se_r( table_idx_value = xp.astype(table_idx, em_x.dtype) dx_first = xx - (table_idx_value * stride0 + lower) dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx_high = table_max - ( + (xp.astype(last_idx, em_x.dtype) - first_stride_value) * stride1 + upper + ) dx = xp.where( - (xx >= lower) & (xx < upper), - dx_first, - xp.where((xx >= upper) & (xx < table_max), dx_second, xp.zeros_like(xx)), + xx < lower, + xp.zeros_like(xx), + xp.where( + xx < upper, + dx_first, + xp.where(xx < table_max, dx_second, dx_high), + ), + ) + extrapolate_delta = xp.where( + xx < lower, + xx - lower, + xp.where(xx >= table_max, xx - table_max, xp.zeros_like(xx)), ) coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) coeff = xp.reshape(coeff, (nloc, nnei, last_layer_size, 6)) dx = xp.reshape(dx, (nloc, nnei, 1)) - return ( + values = ( coeff[..., 0] + ( coeff[..., 1] @@ -499,6 +511,20 @@ def _tabulate_fusion_se_r( ) * dx ) + values_grad = ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + ( + 3 * coeff[..., 3] + + (4 * coeff[..., 4] + 5 * coeff[..., 5] * dx) * dx + ) + * dx + ) + * dx + ) + extrapolate_delta = xp.reshape(extrapolate_delta, (nloc, nnei, 1)) + return values + values_grad * extrapolate_delta @cast_precision def call( diff --git a/source/tests/common/dpmodel/test_model_compression.py b/source/tests/common/dpmodel/test_model_compression.py index bf57e8b183..f561a63097 100644 --- a/source/tests/common/dpmodel/test_model_compression.py +++ b/source/tests/common/dpmodel/test_model_compression.py @@ -63,6 +63,179 @@ def _make_descriptor( seed=1234, ) + @staticmethod + def _poly5(coeff: np.ndarray, xx: np.ndarray | float) -> np.ndarray: + return ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * xx) * xx) * xx + ) + * xx + ) + * xx + ) + + @staticmethod + def _poly5_grad(coeff: np.ndarray, xx: np.ndarray | float) -> np.ndarray: + return ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + ( + 3 * coeff[..., 3] + + (4 * coeff[..., 4] + 5 * coeff[..., 5] * xx) * xx + ) + * xx + ) + * xx + ) + + def _make_c1_tabulation_case( + self, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + table_info = np.array([1.0, 3.0, 5.0, 1.0, 1.0, -1.0], dtype=np.float64) + coeff = np.array( + [ + [ + [1.0, 0.5, 0.25, -0.10, 0.03, 0.02], + [2.0, -0.4, 0.20, 0.05, -0.02, 0.01], + ], + [ + [3.0, -0.2, 0.40, 0.15, -0.03, 0.02], + [4.0, 0.3, -0.10, 0.07, 0.04, -0.01], + ], + [ + [5.0, 0.7, -0.20, 0.12, 0.01, -0.04], + [6.0, -0.6, 0.30, -0.05, 0.02, 0.03], + ], + [ + [7.0, 1.1, -0.35, 0.25, 0.08, -0.02], + [8.0, -0.9, 0.45, -0.15, 0.06, 0.01], + ], + ], + dtype=np.float64, + ) + xx = np.array([[0.5, 1.0, 1.5, 3.25, 5.0, 5.5]], dtype=np.float64) + expected = self._expected_c1_tabulation(coeff, table_info, xx[0]) + table = np.reshape(coeff, (coeff.shape[0], -1)) + return table, coeff, table_info, xx, expected + + def _expected_c1_tabulation( + self, + coeff: np.ndarray, + table_info: np.ndarray, + xx: np.ndarray, + ) -> np.ndarray: + lower, upper, table_max, stride0, stride1 = table_info[:5] + first_stride = int(np.floor((upper - lower) / stride0)) + values = [] + for value in xx: + delta = 0.0 + if value < lower: + table_idx = 0 + dx = 0.0 + delta = value - lower + elif value < upper: + table_idx = int(np.floor((value - lower) / stride0)) + dx = value - (table_idx * stride0 + lower) + elif value < table_max: + table_idx = first_stride + int(np.floor((value - upper) / stride1)) + dx = value - ((table_idx - first_stride) * stride1 + upper) + else: + table_idx = coeff.shape[0] - 1 + dx = table_max - ((table_idx - first_stride) * stride1 + upper) + delta = value - table_max + values.append( + self._poly5(coeff[table_idx], dx) + + self._poly5_grad(coeff[table_idx], dx) * delta + ) + return np.asarray(values, dtype=np.float64) + + def test_tabulate_fusion_se_r_c1_extrapolates_outside_table(self) -> None: + table, coeff, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 2], + resnet_dt=False, + precision="float64", + seed=1234, + ) + + actual = descriptor._tabulate_fusion_se_r( + table, + table_info, + xx[:, :, None], + expected.shape[-1], + ) + + np.testing.assert_allclose(actual[0], expected, atol=1e-12) + np.testing.assert_allclose( + actual[0, 1] - actual[0, 0], + 0.5 * self._poly5_grad(coeff[0], 0.0), + atol=1e-12, + ) + np.testing.assert_allclose( + actual[0, 5] - actual[0, 4], + 0.5 * self._poly5_grad(coeff[-1], 1.0), + atol=1e-12, + ) + + def test_tabulate_fusion_se_a_c1_extrapolates_outside_table(self) -> None: + table, _, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = self._make_descriptor() + em = np.zeros((1, xx.shape[1], 4), dtype=np.float64) + em[:, :, 0] = 1.0 + expected_out = np.zeros((1, 4, expected.shape[-1]), dtype=np.float64) + expected_out[:, 0, :] = np.sum(expected, axis=0) + + actual = descriptor._tabulate_fusion_se_a( + table, + table_info, + xx[:, :, None], + em, + expected.shape[-1], + ) + + np.testing.assert_allclose(actual, expected_out, atol=1e-12) + + def test_tabulate_fusion_se_atten_c1_extrapolates_outside_table(self) -> None: + table, _, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 2], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + precision="float64", + seed=1234, + ) + em = np.zeros((1, xx.shape[1], 4), dtype=np.float64) + em[:, :, 0] = 1.0 + two_embed = np.full_like(expected[None, :, :], 0.5) + expected_out = np.zeros((1, 4, expected.shape[-1]), dtype=np.float64) + expected_out[:, 0, :] = np.sum(expected * (1.0 + two_embed[0]), axis=0) + + actual = descriptor.se_atten._tabulate_fusion_se_atten( + table, + table_info, + xx[:, :, None], + em, + two_embed, + expected.shape[-1], + ) + + np.testing.assert_allclose(actual, expected_out, atol=1e-12) + def test_se_e2_a_enable_compression(self) -> None: for type_one_side, exclude_types in ( (True, []), From 97a4a31672535170769818947fc79335695777ec Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 29 Jun 2026 00:51:51 +0800 Subject: [PATCH 5/6] test(dpmodel): tighten compression entrypoint checks --- .../common/dpmodel/test_model_compression.py | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/source/tests/common/dpmodel/test_model_compression.py b/source/tests/common/dpmodel/test_model_compression.py index f561a63097..446bccae68 100644 --- a/source/tests/common/dpmodel/test_model_compression.py +++ b/source/tests/common/dpmodel/test_model_compression.py @@ -20,6 +20,9 @@ from deepmd.dpmodel.entrypoints.compress import ( enable_compression, ) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) from deepmd.dpmodel.model.model import ( get_model, ) @@ -173,15 +176,17 @@ def test_tabulate_fusion_se_r_c1_extrapolates_outside_table(self) -> None: expected.shape[-1], ) - np.testing.assert_allclose(actual[0], expected, atol=1e-12) + np.testing.assert_allclose(actual[0], expected, rtol=0.0, atol=1e-12) np.testing.assert_allclose( actual[0, 1] - actual[0, 0], 0.5 * self._poly5_grad(coeff[0], 0.0), + rtol=0.0, atol=1e-12, ) np.testing.assert_allclose( actual[0, 5] - actual[0, 4], 0.5 * self._poly5_grad(coeff[-1], 1.0), + rtol=0.0, atol=1e-12, ) @@ -201,7 +206,7 @@ def test_tabulate_fusion_se_a_c1_extrapolates_outside_table(self) -> None: expected.shape[-1], ) - np.testing.assert_allclose(actual, expected_out, atol=1e-12) + np.testing.assert_allclose(actual, expected_out, rtol=0.0, atol=1e-12) def test_tabulate_fusion_se_atten_c1_extrapolates_outside_table(self) -> None: table, _, table_info, xx, expected = self._make_c1_tabulation_case() @@ -234,7 +239,7 @@ def test_tabulate_fusion_se_atten_c1_extrapolates_outside_table(self) -> None: expected.shape[-1], ) - np.testing.assert_allclose(actual, expected_out, atol=1e-12) + np.testing.assert_allclose(actual, expected_out, rtol=0.0, atol=1e-12) def test_se_e2_a_enable_compression(self) -> None: for type_one_side, exclude_types in ( @@ -268,10 +273,10 @@ def test_se_e2_a_enable_compression(self) -> None: self.assertIsNone(reloaded_item) else: np.testing.assert_allclose( - actual_item, expected_item, atol=1e-10 + actual_item, expected_item, rtol=0.0, atol=1e-10 ) np.testing.assert_allclose( - reloaded_item, expected_item, atol=1e-10 + reloaded_item, expected_item, rtol=0.0, atol=1e-10 ) def test_se_e2_r_enable_compression(self) -> None: @@ -304,8 +309,12 @@ def test_se_e2_r_enable_compression(self) -> None: self.assertIsNone(actual_item) self.assertIsNone(reloaded_item) else: - np.testing.assert_allclose(actual_item, expected_item, atol=1e-10) - np.testing.assert_allclose(reloaded_item, expected_item, atol=1e-10) + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) def test_se_atten_enable_compression(self) -> None: descriptor = DescrptDPA1( @@ -343,8 +352,12 @@ def test_se_atten_enable_compression(self) -> None: self.assertIsNone(actual_item) self.assertIsNone(reloaded_item) else: - np.testing.assert_allclose(actual_item, expected_item, atol=1e-10) - np.testing.assert_allclose(reloaded_item, expected_item, atol=1e-10) + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) def test_dpmodel_compress_entrypoint(self) -> None: model_data = { @@ -372,6 +385,7 @@ def test_dpmodel_compress_entrypoint(self) -> None: } model = get_model(model_data) model.min_nbor_dist = 1.0 + expected_output = model.call(self.coord, self.atype) with tempfile.TemporaryDirectory() as tmpdir: input_file = Path(tmpdir) / "model.dp" @@ -399,6 +413,14 @@ def test_dpmodel_compress_entrypoint(self) -> None: self.assertIn("compress", descriptor) self.assertEqual(compressed["min_nbor_dist"], 1.0) + reloaded_model = BaseModel.deserialize(copy.deepcopy(compressed["model"])) + actual_output = reloaded_model.call(self.coord, self.atype) + self.assertEqual(actual_output.keys(), expected_output.keys()) + for key, expected_value in expected_output.items(): + np.testing.assert_allclose( + actual_output[key], expected_value, rtol=0.0, atol=1e-10 + ) + if __name__ == "__main__": unittest.main() From b8a4c977a3beaa647c4768bdea98ec5ef5daf636 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 29 Jun 2026 19:09:07 +0800 Subject: [PATCH 6/6] test(dpmodel): cover one-sided DPA1 compression --- .../common/dpmodel/test_model_compression.py | 81 +++++++++--------- source/tests/jax/test_model_compression.py | 85 ++++++++++--------- 2 files changed, 90 insertions(+), 76 deletions(-) diff --git a/source/tests/common/dpmodel/test_model_compression.py b/source/tests/common/dpmodel/test_model_compression.py index 446bccae68..29f5c281a8 100644 --- a/source/tests/common/dpmodel/test_model_compression.py +++ b/source/tests/common/dpmodel/test_model_compression.py @@ -317,47 +317,52 @@ def test_se_e2_r_enable_compression(self) -> None: ) def test_se_atten_enable_compression(self) -> None: - descriptor = DescrptDPA1( - rcut=4.0, - rcut_smth=3.5, - sel=2, - ntypes=2, - neuron=[4, 8], - axis_neuron=2, - tebd_dim=4, - tebd_input_mode="strip", - resnet_dt=False, - attn_layer=0, - precision="float64", - seed=1234, - ) - expected = descriptor.call(self.coord, self.atype, self.nlist) + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(self.coord, self.atype, self.nlist) - compressed = DescrptDPA1.deserialize(copy.deepcopy(descriptor.serialize())) - compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) - actual = compressed.call(self.coord, self.atype, self.nlist) + compressed = DescrptDPA1.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) - self.assertTrue(compressed.compress) - self.assertTrue(compressed.geo_compress) - serialized = compressed.serialize() - self.assertEqual(serialized["@version"], 3) - self.assertIn("compress", serialized) - reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) - reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) - for expected_item, actual_item, reloaded_item in zip( - expected, actual, reloaded_actual, strict=True - ): - if expected_item is None: - self.assertIsNone(actual_item) - self.assertIsNone(reloaded_item) - else: - np.testing.assert_allclose( - actual_item, expected_item, rtol=0.0, atol=1e-10 - ) - np.testing.assert_allclose( - reloaded_item, expected_item, rtol=0.0, atol=1e-10 - ) + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) def test_dpmodel_compress_entrypoint(self) -> None: model_data = { diff --git a/source/tests/jax/test_model_compression.py b/source/tests/jax/test_model_compression.py index de3243e91e..3cf0b796f6 100644 --- a/source/tests/jax/test_model_compression.py +++ b/source/tests/jax/test_model_compression.py @@ -205,47 +205,56 @@ def test_se_atten_enable_compression(self) -> None: coord = jnp.array(self.coord) atype = jnp.array(self.atype) nlist = jnp.array(self.nlist) - descriptor = DescrptDPA1( - rcut=4.0, - rcut_smth=3.5, - sel=2, - ntypes=2, - neuron=[4, 8], - axis_neuron=2, - tebd_dim=4, - tebd_input_mode="strip", - resnet_dt=False, - attn_layer=0, - precision="float64", - seed=1234, - ) - expected = descriptor.call(coord, atype, nlist) + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) - compressed = DescrptDPA1.deserialize(copy.deepcopy(descriptor.serialize())) - compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) - actual = compressed.call(coord, atype, nlist) + compressed = DescrptDPA1.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) - self.assertTrue(compressed.compress) - self.assertTrue(compressed.geo_compress) - serialized = compressed.serialize() - self.assertEqual(serialized["@version"], 3) - self.assertIn("compress", serialized) - reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) - reloaded_actual = reloaded.call(coord, atype, nlist) + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) - for expected_item, actual_item, reloaded_item in zip( - expected, actual, reloaded_actual, strict=True - ): - if expected_item is None: - self.assertIsNone(actual_item) - self.assertIsNone(reloaded_item) - else: - np.testing.assert_allclose( - np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 - ) - np.testing.assert_allclose( - np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 - ) + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), + np.asarray(expected_item), + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), + np.asarray(expected_item), + atol=1e-10, + ) @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") def test_jax_compress_entrypoint(self) -> None: