diff --git a/deepmd/backend/dpmodel.py b/deepmd/backend/dpmodel.py index 31585aa7a6..0e6e0964f3 100644 --- a/deepmd/backend/dpmodel.py +++ b/deepmd/backend/dpmodel.py @@ -34,7 +34,10 @@ class DPModelBackend(Backend): name = "DPModel" """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( - Backend.Feature.DEEP_EVAL | Backend.Feature.NEIGHBOR_STAT | Backend.Feature.IO + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO ) """The features of the backend.""" suffixes: ClassVar[list[str]] = [".dp", ".yaml", ".yml"] @@ -59,7 +62,9 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]: Callable[[Namespace], None] The entry point hook of the backend. """ - raise NotImplementedError(f"Unsupported backend: {self.name}") + from deepmd.dpmodel.entrypoints.main import main as deepmd_main + + return deepmd_main @property def deep_eval(self) -> type["DeepEvalBackend"]: diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 27e2d68bfc..d8f7bd1a43 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import math +import warnings from collections.abc import ( Callable, ) @@ -32,6 +33,7 @@ EnvMat, NetworkCollection, PairExcludeMask, + tabulate_fusion, ) from deepmd.dpmodel.utils.env_mat_stat import ( EnvMatStatSe, @@ -65,6 +67,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -346,6 +351,8 @@ def __init__( self.concat_output_tebd = concat_output_tebd self.trainable = trainable self.precision = precision + self.tebd_compress = False + self.geo_compress = False self.compress = False def get_rcut(self) -> float: @@ -763,6 +770,89 @@ def call_graph( grrg = xp.concat([grrg, atype_embd], axis=-1) return grrg, rot_mat + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression. + + For DPA-1, compression is available for stripped type embeddings. The + type embedding branch is always precomputed; the radial embedding table + is enabled only when there is no attention layer, matching the PT/TF + compression semantics. + """ + if self.compress: + raise ValueError("Compression is already enabled.") + if self.se_atten.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.se_atten.resnet_dt: + raise RuntimeError( + "Model compression error: descriptor resnet_dt must be false!" + ) + for tt in self.se_atten.exclude_types: + if (tt[0] not in range(self.se_atten.ntypes)) or ( + tt[1] not in range(self.se_atten.ntypes) + ): + raise RuntimeError( + "exclude types" + + str(tt) + + " must within the number of atomic types " + + str(self.se_atten.ntypes) + + "!" + ) + if ( + self.se_atten.ntypes * self.se_atten.ntypes + - len(self.se_atten.exclude_types) + == 0 + ): + raise RuntimeError( + "Empty embedding-nets are not supported in model compression!" + ) + + self.se_atten.type_embedding_compression(self.type_embedding) + self.type_embd_data = self.se_atten.type_embd_data + self.tebd_compress = True + self.compress = True + + if self.se_atten.attn_layer == 0: + table = DPTabulate( + self, + self.se_atten.neuron, + self.se_atten.type_one_side, + self.se_atten.exclude_types, + self.se_atten.activation_function, + ) + table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.se_atten.enable_compression( + table.data, + table_config, + lower, + upper, + ) + self.compress_data = self.se_atten.compress_data + self.compress_info = self.se_atten.compress_info + self.geo_compress = True + else: + self.geo_compress = False + warnings.warn( + "Attention layer is not 0, only type embedding is compressed. " + "Geometric part is not compressed.", + UserWarning, + stacklevel=2, + ) + def serialize(self) -> dict: """Serialize the descriptor to dict.""" obj = self.se_atten @@ -815,18 +905,33 @@ def serialize(self) -> dict: if obj.tebd_input_mode in ["strip"]: data.update({"embeddings_strip": obj.embeddings_strip.serialize()}) if self.compress: + type_embd_data = ( + self.type_embd_data + if hasattr(self, "type_embd_data") + else obj.type_embd_data + ) compress_dict: dict = { "@variables": { - "type_embd_data": to_numpy_array(self.type_embd_data), + "type_embd_data": to_numpy_array(type_embd_data), }, "geo_compress": self.geo_compress, } if self.geo_compress: + compress_data = ( + self.compress_data + if hasattr(self, "compress_data") + else obj.compress_data + ) + compress_info = ( + self.compress_info + if hasattr(self, "compress_info") + else obj.compress_info + ) compress_dict["@variables"]["compress_data"] = [ - to_numpy_array(d) for d in self.compress_data + to_numpy_array(d) for d in compress_data ] compress_dict["@variables"]["compress_info"] = [ - to_numpy_array(i) for i in self.compress_info + to_numpy_array(i) for i in compress_info ] data["compress"] = compress_dict return data @@ -874,9 +979,15 @@ def _load_compress_data(self, compress: dict) -> None: variables = compress["@variables"] self.type_embd_data = variables["type_embd_data"] self.geo_compress = compress.get("geo_compress", False) + self.tebd_compress = True + self.se_atten.type_embd_data = self.type_embd_data + self.se_atten.tebd_compress = True + self.se_atten.geo_compress = self.geo_compress if self.geo_compress: self.compress_data = variables["compress_data"] self.compress_info = variables["compress_info"] + self.se_atten.compress_data = self.compress_data + self.se_atten.compress_info = self.compress_info self.compress = True @classmethod @@ -1115,6 +1226,12 @@ def __init__( self.mean = np.zeros(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.stddev = np.ones(wanted_shape, dtype=PRECISION_DICT[self.precision]) self.orig_sel = self.sel + self.tebd_compress = False + self.geo_compress = False + self.is_sorted = len(self.exclude_types) == 0 + self.compress_data = None + self.compress_info = None + self.type_embd_data = None def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -1253,6 +1370,7 @@ def reinit_exclude( exclude_types: list[tuple[int, int]] = [], ) -> None: self.exclude_types = exclude_types + self.is_sorted = len(self.exclude_types) == 0 self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) def cal_g( @@ -1278,6 +1396,82 @@ def cal_g_strip( gg = self.embeddings_strip[embedding_idx].call(ss) return gg + def enable_compression( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated geometric embedding-net data.""" + net = "filter_net" + dtype = self.mean.dtype + self.compress_info = [ + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ] + self.compress_data = [np.asarray(table_data[net], dtype=dtype)] + self.geo_compress = True + + def type_embedding_compression(self, type_embedding_net: TypeEmbedNet) -> None: + """Precompute stripped type embedding network outputs.""" + if self.tebd_input_mode != "strip": + raise RuntimeError("Type embedding compression only works in strip mode") + if self.embeddings_strip is None: + raise RuntimeError( + "embeddings_strip must be initialized for type embedding compression" + ) + + full_embd = type_embedding_net.call() + xp = array_api_compat.array_namespace(full_embd) + nt, t_dim = full_embd.shape + if self.type_one_side: + embd_tensor = self.embeddings_strip[0].call(full_embd) + else: + type_embedding_nei = xp.tile( + xp.reshape(full_embd, (1, nt, t_dim)), (nt, 1, 1) + ) + type_embedding_center = xp.tile( + xp.reshape(full_embd, (nt, 1, t_dim)), (1, nt, 1) + ) + two_side_type_embedding = xp.reshape( + xp.concat([type_embedding_nei, type_embedding_center], axis=-1), + (-1, t_dim * 2), + ) + embd_tensor = self.embeddings_strip[0].call(two_side_type_embedding) + self.type_embd_data = embd_tensor + self.tebd_compress = True + + def _tabulate_fusion_se_atten( + self, + table: Array, + table_info: Array, + em_x: Array, + em: Array, + two_embed: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_atten forward.""" + xp = array_api_compat.array_namespace(em_x, em, two_embed) + values = tabulate_fusion( + table, + table_info, + em_x, + last_layer_size, + reference=em, + ) + values = values * two_embed + values + return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) + def call( self, nlist: Array, @@ -1327,6 +1521,7 @@ def call( rr = rr * xp.astype(exclude_mask[:, :, None], rr.dtype) # nfnl x nnei x 1 ss = rr[..., 0:1] + geo_gr = None if self.tebd_input_mode in ["concat"]: # nfnl x tebd_dim atype_embd = xp.reshape( @@ -1353,8 +1548,7 @@ def call( # nfnl x nnei x ng gg = self.cal_g(ss, 0) elif self.tebd_input_mode in ["strip"]: - # nfnl x nnei x ng - gg_s = self.cal_g(ss, 0) + ss_scalar = ss assert self.embeddings_strip is not None assert type_embedding is not None ntypes_with_padding = type_embedding.shape[0] @@ -1364,7 +1558,14 @@ def call( # (nf x nl x nnei) x ng nei_type_index = xp.tile(xp.reshape(nei_type, (-1, 1)), (1, ng)) if self.type_one_side: - tt_full = self.cal_g_strip(type_embedding, 0) + if self.tebd_compress: + tt_full = xp.asarray( + self.type_embd_data[...], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + else: + tt_full = self.cal_g_strip(type_embedding, 0) # (nf x nl x nnei) x ng gg_t = xp_take_along_axis(tt_full, nei_type_index, axis=0) else: @@ -1379,46 +1580,70 @@ def call( idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) idx = xp.astype(idx, xp.int64) - # (ntypes) * ntypes * nt - type_embedding_nei = xp.tile( - xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), - (ntypes_with_padding, 1, 1), - ) - # ntypes * (ntypes) * nt - type_embedding_center = xp.tile( - xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)), - (1, ntypes_with_padding, 1), - ) - # (ntypes * ntypes) * (nt+nt) - two_side_type_embedding = xp.reshape( - xp.concat([type_embedding_nei, type_embedding_center], axis=-1), - (-1, nt * 2), - ) - tt_full = self.cal_g_strip(two_side_type_embedding, 0) + if self.tebd_compress: + tt_full = xp.asarray( + self.type_embd_data[...], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + else: + # (ntypes) * ntypes * nt + type_embedding_nei = xp.tile( + xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), + (ntypes_with_padding, 1, 1), + ) + # ntypes * (ntypes) * nt + type_embedding_center = xp.tile( + xp.reshape(type_embedding, (ntypes_with_padding, 1, nt)), + (1, ntypes_with_padding, 1), + ) + # (ntypes * ntypes) * (nt+nt) + two_side_type_embedding = xp.reshape( + xp.concat([type_embedding_nei, type_embedding_center], axis=-1), + (-1, nt * 2), + ) + tt_full = self.cal_g_strip(two_side_type_embedding, 0) # (nf x nl x nnei) x ng gg_t = xp_take_along_axis(tt_full, idx, axis=0) # (nf x nl) x nnei x ng gg_t = xp.reshape(gg_t, (nf * nloc, nnei, ng)) if self.smooth: gg_t = gg_t * xp.reshape(sw, (-1, self.nnei, 1)) - # nfnl x nnei x ng - gg = gg_s * gg_t + gg_s + if self.geo_compress: + geo_gr = self._tabulate_fusion_se_atten( + self.compress_data[0], + self.compress_info[0], + ss_scalar, + rr, + gg_t, + self.filter_neuron[-1], + ) + gg = None + else: + # nfnl x nnei x ng + gg_s = self.cal_g(ss_scalar, 0) + gg = gg_s * gg_t + gg_s else: raise NotImplementedError - normed = safe_for_vector_norm( - xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True - ) - input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( - normed, - xp.full_like(normed, 1e-12), - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - # nfnl x ng x 4 - # gr = xp.einsum("lni,lnj->lij", gg, rr) - gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) + if geo_gr is None: + normed = safe_for_vector_norm( + xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4], axis=-1, keepdims=True + ) + input_r = xp.reshape(rr, (-1, nnei, 4))[:, :, 1:4] / xp.maximum( + normed, + xp.full_like(normed, 1e-12), + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x ng x 4 + # gr = xp.einsum("lni,lnj->lij", gg, rr) + gr = xp.sum(gg[:, :, :, None] * rr[:, :, None, :], axis=1) + g2 = xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])) + else: + gr = xp.permute_dims(geo_gr, (0, 2, 1)) + g2 = None gr /= self.nnei gr1 = gr[:, : self.axis_neuron, :] # nfnl x ng x ng1 @@ -1430,7 +1655,7 @@ def call( ) return ( xp.reshape(grrg, (nf, nloc, self.filter_neuron[-1] * self.axis_neuron)), - xp.reshape(gg, (nf, nloc, self.nnei, self.filter_neuron[-1])), + g2, xp.reshape(dmatrix, (nf, nloc, self.nnei, 4))[..., 1:], xp.reshape(gr[..., 1:], (nf, nloc, self.filter_neuron[-1], 3)), xp.reshape(sw, (nf, nloc, nnei, 1)), diff --git a/deepmd/dpmodel/descriptor/se_atten_v2.py b/deepmd/dpmodel/descriptor/se_atten_v2.py index 1cc383cfd9..6942a5bac3 100644 --- a/deepmd/dpmodel/descriptor/se_atten_v2.py +++ b/deepmd/dpmodel/descriptor/se_atten_v2.py @@ -247,18 +247,33 @@ def serialize(self) -> dict: "spin": None, } if self.compress: + type_embd_data = ( + self.type_embd_data + if hasattr(self, "type_embd_data") + else obj.type_embd_data + ) compress_dict: dict = { "@variables": { - "type_embd_data": to_numpy_array(self.type_embd_data), + "type_embd_data": to_numpy_array(type_embd_data), }, "geo_compress": self.geo_compress, } if self.geo_compress: + compress_data = ( + self.compress_data + if hasattr(self, "compress_data") + else obj.compress_data + ) + compress_info = ( + self.compress_info + if hasattr(self, "compress_info") + else obj.compress_info + ) compress_dict["@variables"]["compress_data"] = [ - to_numpy_array(d) for d in self.compress_data + to_numpy_array(d) for d in compress_data ] compress_dict["@variables"]["compress_info"] = [ - to_numpy_array(i) for i in self.compress_info + to_numpy_array(i) for i in compress_info ] data["compress"] = compress_dict return data @@ -299,7 +314,13 @@ def _load_compress_data(self, compress: dict) -> None: variables = compress["@variables"] self.type_embd_data = variables["type_embd_data"] self.geo_compress = compress.get("geo_compress", False) + self.tebd_compress = True + self.se_atten.type_embd_data = self.type_embd_data + self.se_atten.tebd_compress = True + self.se_atten.geo_compress = self.geo_compress if self.geo_compress: self.compress_data = variables["compress_data"] self.compress_info = variables["compress_info"] + self.se_atten.compress_data = self.compress_data + self.se_atten.compress_info = self.compress_info self.compress = True diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index bdac1e0cc0..c7e4bc8257 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -28,6 +28,7 @@ EnvMat, NetworkCollection, PairExcludeMask, + tabulate_fusion, ) from deepmd.dpmodel.utils.env_mat_stat import ( EnvMatStatSe, @@ -44,6 +45,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -388,6 +392,106 @@ def cal_g( gg = self.embeddings[embedding_idx].call(ss) return gg + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression by tabulating embedding networks.""" + if self.compress: + raise ValueError("Compression is already enabled.") + table = DPTabulate( + self, + self.neuron, + self.type_one_side, + self.exclude_types, + self.activation_function, + ) + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self._store_compress_data( + table.data, + [table_extrapolate, table_stride_1, table_stride_2, check_frequency], + lower, + upper, + ) + self.compress = True + + def _store_compress_data( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated embedding-net data in the descriptor state.""" + compress_data = [] + compress_info = [] + dtype = self.davg.dtype + ndim = 1 if self.type_one_side else 2 + for embedding_idx in range(self.ntypes**ndim): + if self.type_one_side: + ii = embedding_idx + ti = -1 + else: + ii = embedding_idx // self.ntypes + ti = embedding_idx % self.ntypes + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + if net not in table_data: + compress_data.append(np.asarray([], dtype=dtype)) + compress_info.append(np.asarray([], dtype=dtype)) + continue + compress_data.append(np.asarray(table_data[net], dtype=dtype)) + compress_info.append( + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ) + self.compress_data = compress_data + self.compress_info = compress_info + + def _tabulate_fusion_se_a( + self, + table: Array, + table_info: Array, + em_x: Array, + em: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_a forward.""" + xp = array_api_compat.array_namespace(em_x, em) + values = tabulate_fusion( + table, + table_info, + em_x, + last_layer_size, + reference=em, + ) + return xp.sum(em[:, :, :, None] * values[:, :, None, :], axis=1) + + def _add_to_slice(self, array: Array, start: int, end: int, value: Array) -> Array: + """Return ``array`` with ``value`` added to ``array[start:end]``.""" + xp = array_api_compat.array_namespace(array, value) + return xp.concat( + [array[:start], array[start:end] + value, array[end:]], + axis=0, + ) + def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], @@ -450,12 +554,22 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] + exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) + if self.compress: + gr = self._call_compressed(rr, atype_ext, nlist, exclude_mask) + gr = xp.astype(gr, input_dtype) + gr = xp.reshape(gr, (nf, nloc, ng, 4)) + gr /= self.nnei + gr1 = gr[:, :, : self.axis_neuron, :] + grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) + return grrg, gr[..., 1:], None, None, ww + gr = xp.zeros( [nf * nloc, ng, 4], dtype=input_dtype, device=array_api_compat.device(coord_ext), ) - exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) @@ -499,8 +613,11 @@ def call( tr = tr * xp.astype(mm[:, :, None], tr.dtype) ss = tr[..., 0:1] gg = self.cal_g(ss, (ti, tt)) - gr_s[s:e] = gr_s[s:e] + xp.sum( - gg[:, :, :, None] * tr[:, :, None, :], axis=1 + gr_s = self._add_to_slice( + gr_s, + s, + e, + xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1), ) gr = xp.take(gr_s, unsort_idx, axis=0) gr = xp.reshape(gr, (nf, nloc, ng, 4)) @@ -513,6 +630,88 @@ def call( grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww + def _call_compressed( + self, + rr: Array, + atype_ext: Array, + nlist: Array, + exclude_mask: Array, + ) -> Array: + """Compressed forward path for the SE-A descriptor.""" + xp = array_api_compat.array_namespace(rr, atype_ext, nlist) + nf, nloc, nnei, _ = rr.shape + sec = self.sel_cumsum + ng = self.neuron[-1] + nfnl = nf * nloc + gr = xp.zeros( + [nfnl, 4, ng], + dtype=rr.dtype, + device=array_api_compat.device(rr), + ) + exclude_mask = xp.reshape(exclude_mask, (nfnl, nnei)) + rr = xp.reshape(rr, (nfnl, nnei, 4)) + rr = xp.astype(rr, self.dstd.dtype) + + if self.type_one_side: + for embedding_idx, (compress_data_ii, compress_info_ii) in enumerate( + zip(self.compress_data, self.compress_info, strict=True) + ): + if array_api_compat.size(compress_data_ii) == 0: + continue + mm = exclude_mask[:, sec[embedding_idx] : sec[embedding_idx + 1]] + rr_i = rr[:, sec[embedding_idx] : sec[embedding_idx + 1], :] + rr_i = rr_i * xp.astype(mm[:, :, None], rr_i.dtype) + ss = rr_i[:, :, :1] + gr += self._tabulate_fusion_se_a( + compress_data_ii, + compress_info_ii, + ss, + rr_i, + ng, + ) + else: + atype_loc = xp.reshape(atype_ext[:, :nloc], (nfnl,)) + sort_idx = xp.argsort(atype_loc) + unsort_idx = xp.argsort(sort_idx) + rr_s = xp.take(rr, sort_idx, axis=0) + mask_s = xp.take(exclude_mask, sort_idx, axis=0) + dev = array_api_compat.device(rr) + gr_s = xp.zeros([nfnl, 4, ng], dtype=rr.dtype, device=dev) + type_ends = [] + offset = 0 + for ti in range(self.ntypes): + offset += int(xp.sum(xp.astype(atype_loc == ti, xp.int32))) + type_ends.append(offset) + type_starts = [0, *type_ends[:-1]] + for ti in range(self.ntypes): + s, e = type_starts[ti], type_ends[ti] + if s == e: + continue + for tt in range(self.ntypes): + embedding_idx = tt * self.ntypes + ti + compress_data_ii = self.compress_data[embedding_idx] + if array_api_compat.size(compress_data_ii) == 0: + continue + compress_info_ii = self.compress_info[embedding_idx] + mm = mask_s[s:e, sec[tt] : sec[tt + 1]] + rr_i = rr_s[s:e, sec[tt] : sec[tt + 1], :] + rr_i = rr_i * xp.astype(mm[:, :, None], rr_i.dtype) + ss = rr_i[:, :, :1] + gr_s = self._add_to_slice( + gr_s, + s, + e, + self._tabulate_fusion_se_a( + compress_data_ii, + compress_info_ii, + ss, + rr_i, + ng, + ), + ) + gr = xp.take(gr_s, unsort_idx, axis=0) + return xp.permute_dims(gr, (0, 2, 1)) + def serialize(self) -> dict: """Serialize the descriptor to dict.""" if not self.type_one_side and self.exclude_types: diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 624889d85f..e992409452 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -28,6 +28,7 @@ EnvMat, NetworkCollection, PairExcludeMask, + tabulate_fusion, ) from deepmd.dpmodel.utils.env_mat_stat import ( EnvMatStatSe, @@ -44,6 +45,9 @@ from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.tabulate_math import ( + DPTabulate, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -367,6 +371,79 @@ def cal_g( gg = self.embeddings[(ll,)].call(ss) return gg + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Enable descriptor compression by tabulating embedding networks.""" + if self.compress: + raise ValueError("Compression is already enabled.") + table = DPTabulate( + self, + self.neuron, + self.type_one_side, + self.exclude_types, + self.activation_function, + ) + lower, upper = table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self._store_compress_data( + table.data, + [table_extrapolate, table_stride_1, table_stride_2, check_frequency], + lower, + upper, + ) + self.compress = True + + def _store_compress_data( + self, + table_data: dict[str, Array], + table_config: list[int | float], + lower: dict[str, int], + upper: dict[str, int], + ) -> None: + """Store tabulated embedding-net data in the descriptor state.""" + compress_data = [] + compress_info = [] + dtype = self.davg.dtype + for embedding_idx in range(self.ntypes): + net = "filter_-1_net_" + str(embedding_idx) + if net not in table_data: + compress_data.append(np.asarray([], dtype=dtype)) + compress_info.append(np.asarray([], dtype=dtype)) + continue + compress_data.append(np.asarray(table_data[net], dtype=dtype)) + compress_info.append( + np.asarray( + [ + lower[net], + upper[net], + upper[net] * table_config[0], + table_config[1], + table_config[2], + table_config[3], + ], + dtype=dtype, + ) + ) + self.compress_data = compress_data + self.compress_info = compress_info + + def _tabulate_fusion_se_r( + self, + table: Array, + table_info: Array, + em_x: Array, + last_layer_size: int, + ) -> Array: + """Pure Array API implementation of tabulate_fusion_se_r forward.""" + return tabulate_fusion(table, table_info, em_x, last_layer_size) + @cast_precision def call( self, @@ -429,14 +506,34 @@ def call( ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) rr = xp.astype(rr, xyz_scatter.dtype) - for tt in range(self.ntypes): - mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] - tr = rr[:, :, sec[tt] : sec[tt + 1], :] - tr = tr * xp.astype(mm[:, :, :, None], tr.dtype) - gg = self.cal_g(tr, tt) - gg = xp.mean(gg, axis=2) - # nf x nloc x ng x 1 - xyz_scatter += gg * (self.sel[tt] / self.nnei) + if self.compress: + rr = xp.reshape(rr, (nf * nloc, nnei, 1)) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + for tt, (compress_data_ii, compress_info_ii) in enumerate( + zip(self.compress_data, self.compress_info, strict=True) + ): + if array_api_compat.size(compress_data_ii) == 0: + continue + mm = exclude_mask[:, sec[tt] : sec[tt + 1]] + tr = rr[:, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) + gg = self._tabulate_fusion_se_r( + compress_data_ii, + compress_info_ii, + tr, + ng, + ) + gg = xp.reshape(xp.sum(gg, axis=1), (nf, nloc, ng)) + xyz_scatter += gg / self.nnei + else: + for tt in range(self.ntypes): + mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]] + tr = rr[:, :, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, :, None], tr.dtype) + gg = self.cal_g(tr, tt) + gg = xp.mean(gg, axis=2) + # nf x nloc x ng x 1 + xyz_scatter += gg * (self.sel[tt] / self.nnei) res_rescale = 1.0 / 5.0 res = xyz_scatter * res_rescale diff --git a/deepmd/dpmodel/entrypoints/__init__.py b/deepmd/dpmodel/entrypoints/__init__.py new file mode 100644 index 0000000000..f732469b75 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Command-line entry points for the DPModel backend.""" diff --git a/deepmd/dpmodel/entrypoints/compress.py b/deepmd/dpmodel/entrypoints/compress.py new file mode 100644 index 0000000000..edc942dbb3 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/compress.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Compress a native DPModel file by tabulating embedding networks.""" + +import logging + +from deepmd.dpmodel.entrypoints.compress_common import ( + enable_model_compression, + resolve_min_nbor_dist, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, +) + +log = logging.getLogger(__name__) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, + training_script: str | None = None, +) -> None: + """Compress a native ``.dp``/``.yaml`` model.""" + model_dict = load_dp_model(input_file) + model = BaseModel.deserialize(model_dict["model"]) + + min_nbor_dist = resolve_min_nbor_dist( + model, + [model_dict], + training_script, + ) + enable_model_compression(model, min_nbor_dist, stride, extrapolate, check_frequency) + + compressed_model_dict = model_dict.copy() + compressed_model_dict["model"] = model.serialize() + compressed_model_dict["min_nbor_dist"] = float(min_nbor_dist) + save_dp_model(output, compressed_model_dict) + log.info("Compressed model saved to %s", output) diff --git a/deepmd/dpmodel/entrypoints/compress_common.py b/deepmd/dpmodel/entrypoints/compress_common.py new file mode 100644 index 0000000000..03d6c52897 --- /dev/null +++ b/deepmd/dpmodel/entrypoints/compress_common.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Shared helpers for native-model compression entrypoints.""" + +import logging +from collections.abc import ( + Iterable, +) +from typing import ( + Any, +) + +import numpy as np + +from deepmd.common import ( + j_loader, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) + +log = logging.getLogger(__name__) + + +def to_float(value: Any) -> float | None: + """Convert plain, NumPy, or framework-wrapped scalar values to float.""" + if value is None: + return None + value = getattr(value, "value", value) + return float(np.asarray(value)) + + +def get_saved_min_nbor_dist(data: dict) -> float | None: + """Read min_nbor_dist from known serialized-model metadata locations.""" + min_nbor_dist = to_float(data.get("min_nbor_dist")) + if min_nbor_dist is not None: + return min_nbor_dist + constants = data.get("constants", {}) + return to_float(constants.get("min_nbor_dist")) + + +def compute_min_nbor_dist( + training_script: str, + update_sel_cls: type[Any] | None = None, +) -> float: + """Compute min_nbor_dist from the training data.""" + if update_sel_cls is None: + from deepmd.dpmodel.utils.update_sel import ( + UpdateSel, + ) + + update_sel_cls = UpdateSel + + jdata = update_deepmd_input(j_loader(training_script)) + type_map = jdata["model"].get("type_map", None) + train_data = get_data( + jdata["training"]["training_data"], + 0, + type_map, + None, + ) + return float(update_sel_cls().get_min_nbor_dist(train_data)) + + +def resolve_min_nbor_dist( + model: Any, + metadata_sources: Iterable[dict], + training_script: str | None, + update_sel_cls: type[Any] | None = None, +) -> float: + """Resolve min_nbor_dist from model, saved metadata, or training data.""" + min_nbor_dist = to_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + for metadata in metadata_sources: + min_nbor_dist = get_saved_min_nbor_dist(metadata) + if min_nbor_dist is not None: + break + if min_nbor_dist is None: + log.info( + "Minimal neighbor distance is not saved in the model, " + "compute it from the training data." + ) + if training_script is None: + raise ValueError( + "The model does not have a minimum neighbor distance, " + "so the training script and data must be provided " + "(via -t,--training-script)." + ) + min_nbor_dist = compute_min_nbor_dist(training_script, update_sel_cls) + return float(min_nbor_dist) + + +def enable_model_compression( + model: Any, + min_nbor_dist: float, + stride: float, + extrapolate: int, + check_frequency: int, +) -> None: + """Set compression inputs and enable descriptor tabulation.""" + model.min_nbor_dist = float(min_nbor_dist) + model.enable_compression( + table_extrapolate=extrapolate, + table_stride_1=stride, + table_stride_2=stride * 10, + check_frequency=check_frequency, + ) diff --git a/deepmd/dpmodel/entrypoints/main.py b/deepmd/dpmodel/entrypoints/main.py new file mode 100644 index 0000000000..3ca3adf11b --- /dev/null +++ b/deepmd/dpmodel/entrypoints/main.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""DeePMD-kit entry point for the DPModel backend.""" + +import argparse +from pathlib import ( + Path, +) + +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.dpmodel.entrypoints.compress import ( + enable_compression, +) +from deepmd.loggers.loggers import ( + set_log_handles, +) +from deepmd.main import ( + parse_args, +) + +__all__ = ["main"] + + +def main(args: list[str] | argparse.Namespace | None = None) -> None: + """DPModel backend command dispatcher.""" + if not isinstance(args, argparse.Namespace): + args = parse_args(args=args) + + set_log_handles( + args.log_level, + Path(args.log_path) if args.log_path else None, + mpi_log=None, + ) + + if args.command == "compress": + enable_compression( + input_file=format_model_suffix( + args.input, + preferred_backend="dp", + strict_prefer=True, + ), + output=format_model_suffix( + args.output, + preferred_backend="dp", + strict_prefer=True, + ), + stride=args.step, + extrapolate=args.extrapolate, + check_frequency=args.frequency, + training_script=args.training_script, + ) + elif args.command is None: + pass + else: + raise RuntimeError( + f"Unsupported command '{args.command}' for the DPModel backend." + ) diff --git a/deepmd/dpmodel/utils/__init__.py b/deepmd/dpmodel/utils/__init__.py index 66dd8fe3c1..acf16fb8a2 100644 --- a/deepmd/dpmodel/utils/__init__.py +++ b/deepmd/dpmodel/utils/__init__.py @@ -65,6 +65,9 @@ save_dp_model, traverse_model_dict, ) +from .tabulate import ( + tabulate_fusion, +) from .training_utils import ( compute_total_numb_batch, resolve_model_prob, @@ -119,6 +122,7 @@ "save_dp_model", "segment_mean", "segment_sum", + "tabulate_fusion", "to_face_distance", "traverse_model_dict", ] diff --git a/deepmd/dpmodel/utils/tabulate.py b/deepmd/dpmodel/utils/tabulate.py new file mode 100644 index 0000000000..b5bf2f3322 --- /dev/null +++ b/deepmd/dpmodel/utils/tabulate.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat + +from deepmd.dpmodel.array_api import ( + Array, +) + + +def tabulate_fusion( + table: Array, + table_info: Array, + em_x: Array, + last_layer_size: int, + reference: Array | None = None, +) -> Array: + """Evaluate tabulated embedding-net values with C1 extrapolation.""" + if reference is None: + xp = array_api_compat.array_namespace(em_x) + reference = em_x + else: + xp = array_api_compat.array_namespace(em_x, reference) + device = array_api_compat.device(reference) + table = xp.asarray(table[...], dtype=reference.dtype, device=device) + table_info = xp.asarray(table_info[...], dtype=reference.dtype, device=device) + + nloc, nnei = em_x.shape[:2] + xx = xp.reshape(em_x, (nloc, nnei)) + lower = table_info[0] + upper = table_info[1] + table_max = table_info[2] + stride0 = table_info[3] + stride1 = table_info[4] + + zeros = xp.zeros(xx.shape, dtype=xp.int64, device=device) + nspline = table.shape[0] + last_idx = xp.full(xx.shape, nspline - 1, dtype=xp.int64, device=device) + first_stride = xp.astype(xp.floor((upper - lower) / stride0), xp.int64) + first_stride_value = xp.astype(first_stride, reference.dtype) + + first_idx = xp.astype(xp.floor((xx - lower) / stride0), xp.int64) + second_idx = first_stride + xp.astype(xp.floor((xx - upper) / stride1), xp.int64) + table_idx = xp.where( + xx < lower, + zeros, + xp.where( + xx < upper, + first_idx, + xp.where(xx < table_max, second_idx, last_idx), + ), + ) + table_idx = xp.minimum(xp.maximum(table_idx, zeros), last_idx) + + table_idx_value = xp.astype(table_idx, reference.dtype) + dx_first = xx - (table_idx_value * stride0 + lower) + dx_second = xx - ((table_idx_value - first_stride_value) * stride1 + upper) + dx_high = table_max - ( + (xp.astype(last_idx, reference.dtype) - first_stride_value) * stride1 + upper + ) + dx = xp.where( + xx < lower, + xp.zeros_like(xx), + xp.where( + xx < upper, + dx_first, + xp.where(xx < table_max, dx_second, dx_high), + ), + ) + extrapolate_delta = xp.where( + xx < lower, + xx - lower, + xp.where(xx >= table_max, xx - table_max, xp.zeros_like(xx)), + ) + + coeff = xp.take(table, xp.reshape(table_idx, (-1,)), axis=0) + coeff = xp.reshape(coeff, (nloc, nnei, last_layer_size, 6)) + dx = xp.reshape(dx, (nloc, nnei, 1)) + values = ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * dx) * dx) * dx + ) + * dx + ) + * dx + ) + values_grad = ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + (3 * coeff[..., 3] + (4 * coeff[..., 4] + 5 * coeff[..., 5] * dx) * dx) + * dx + ) + * dx + ) + extrapolate_delta = xp.reshape(extrapolate_delta, (nloc, nnei, 1)) + return values + values_grad * extrapolate_delta diff --git a/deepmd/jax/entrypoints/compress.py b/deepmd/jax/entrypoints/compress.py new file mode 100644 index 0000000000..a54addd9be --- /dev/null +++ b/deepmd/jax/entrypoints/compress.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Compress a JAX model by tabulating embedding networks.""" + +import logging +from pathlib import ( + Path, +) + +from deepmd.dpmodel.entrypoints.compress_common import ( + enable_model_compression, + resolve_min_nbor_dist, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) +from deepmd.jax.utils.serialization import ( + deserialize_to_file, + serialize_from_file, +) +from deepmd.jax.utils.update_sel import ( + UpdateSel, +) + +log = logging.getLogger(__name__) + + +def enable_compression( + input_file: str, + output: str, + stride: float = 0.01, + extrapolate: int = 5, + check_frequency: int = -1, + training_script: str | None = None, +) -> None: + """Compress a JAX ``.jax``/``.hlo`` model.""" + data = serialize_from_file(input_file) + model = BaseModel.deserialize(data["model"]) + + metadata_sources = [data] + if Path(input_file).suffix == ".hlo": + metadata_sources.append(load_dp_model(input_file)) + min_nbor_dist = resolve_min_nbor_dist( + model, + metadata_sources, + training_script, + UpdateSel, + ) + enable_model_compression(model, min_nbor_dist, stride, extrapolate, check_frequency) + + compressed_data = data.copy() + compressed_data["model"] = model.serialize() + compressed_data["min_nbor_dist"] = float(min_nbor_dist) + deserialize_to_file(output, compressed_data) + log.info("Compressed model saved to %s", output) diff --git a/deepmd/jax/entrypoints/main.py b/deepmd/jax/entrypoints/main.py index a365b1dea8..1d3b957c1f 100644 --- a/deepmd/jax/entrypoints/main.py +++ b/deepmd/jax/entrypoints/main.py @@ -6,6 +6,12 @@ Path, ) +from deepmd.backend.suffix import ( + format_model_suffix, +) +from deepmd.jax.entrypoints.compress import ( + enable_compression, +) from deepmd.jax.entrypoints.freeze import ( freeze, ) @@ -51,6 +57,23 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: train(**dict_args) elif args.command == "freeze": freeze(**dict_args) + elif args.command == "compress": + enable_compression( + input_file=format_model_suffix( + args.input, + preferred_backend="jax", + strict_prefer=True, + ), + output=format_model_suffix( + args.output, + preferred_backend="jax", + strict_prefer=True, + ), + stride=args.step, + extrapolate=args.extrapolate, + check_frequency=args.frequency, + training_script=args.training_script, + ) elif args.command is None: pass else: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 14386d9f3d..b993acc7b8 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -2,6 +2,9 @@ from pathlib import ( Path, ) +from typing import ( + Any, +) import numpy as np import orbax.checkpoint as ocp @@ -22,6 +25,141 @@ ) +def _state_sequence_to_numpy_list(state_value: Any) -> list[np.ndarray]: + """Convert an Orbax-restored list/dict sequence to NumPy arrays.""" + if isinstance(state_value, dict): + values = [state_value[key] for key in sorted(state_value)] + else: + values = state_value + return [np.asarray(getattr(value, "value", value)) for value in values] + + +def _state_value_to_numpy(state_value: Any) -> np.ndarray: + """Convert an Orbax-restored state value to a NumPy array.""" + return np.asarray(getattr(state_value, "value", state_value)) + + +def _restore_compression_slots_from_state(obj: Any, state: Any) -> None: + """Create compression variable slots before replacing an NNX state. + + A compressed ``.jax`` checkpoint stores tabulation arrays in the NNX state, + while ``model_def_script`` still describes the original uncompressed model. + Build the corresponding descriptor attributes first so Flax can match the + restored state keys. + """ + if not isinstance(state, dict): + return + if ( + (hasattr(obj, "compress") or hasattr(obj, "geo_compress")) + and "compress_data" in state + and "compress_info" in state + ): + obj.compress_data = _state_sequence_to_numpy_list(state["compress_data"]) + obj.compress_info = _state_sequence_to_numpy_list(state["compress_info"]) + if hasattr(obj, "compress"): + obj.compress = True + if hasattr(obj, "geo_compress"): + obj.geo_compress = True + if hasattr(obj, "se_atten"): + obj.se_atten.compress_data = obj.compress_data + obj.se_atten.compress_info = obj.compress_info + if hasattr(obj.se_atten, "geo_compress"): + obj.se_atten.geo_compress = True + if ( + hasattr(obj, "compress") or hasattr(obj, "tebd_compress") + ) and "type_embd_data" in state: + obj.type_embd_data = _state_value_to_numpy(state["type_embd_data"]) + if hasattr(obj, "compress"): + obj.compress = True + if hasattr(obj, "tebd_compress"): + obj.tebd_compress = True + if hasattr(obj, "geo_compress"): + obj.geo_compress = "compress_data" in state and "compress_info" in state + if hasattr(obj, "se_atten"): + obj.se_atten.type_embd_data = obj.type_embd_data + obj.se_atten.tebd_compress = True + if hasattr(obj.se_atten, "geo_compress"): + obj.se_atten.geo_compress = getattr(obj, "geo_compress", False) + if getattr(obj, "geo_compress", False): + obj.se_atten.compress_data = obj.compress_data + obj.se_atten.compress_info = obj.compress_info + for name, child_state in state.items(): + if not isinstance(child_state, dict): + continue + if isinstance(name, int): + try: + child = obj[name] + except (IndexError, KeyError, TypeError): + continue + else: + if not hasattr(obj, name): + continue + child = getattr(obj, name) + _restore_compression_slots_from_state(child, child_state) + if name == "se_atten" and hasattr(obj, "compress"): + child_type_embd_data = getattr(child, "type_embd_data", None) + if child_type_embd_data is not None: + obj.type_embd_data = child_type_embd_data + obj.tebd_compress = getattr(child, "tebd_compress", True) + obj.compress = True + if getattr(child, "geo_compress", False): + obj.geo_compress = True + obj.compress_data = child.compress_data + obj.compress_info = child.compress_info + + +def _to_optional_float(value: Any) -> float | None: + if value is None: + return None + return float(np.asarray(getattr(value, "value", value))) + + +def _set_model_min_nbor_dist_from_data(model: BaseModel, data: dict) -> None: + if model.get_min_nbor_dist() is not None: + return + min_nbor_dist = _to_optional_float(data.get("min_nbor_dist")) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float( + data.get("constants", {}).get("min_nbor_dist") + ) + if min_nbor_dist is not None: + model.min_nbor_dist = min_nbor_dist + + +def _find_compressed_type_two_side_descriptors(data: Any) -> list[str]: + """Find compressed descriptors whose JAX HLO path is not exportable.""" + if not isinstance(data, dict): + return [] + matches = [] + descriptor_type = data.get("type") + if ( + descriptor_type in {"se_e2_a", "se_a", "dpa1", "se_atten"} + and "compress" in data + and data.get("type_one_side") is False + ): + matches.append(descriptor_type) + for value in data.values(): + if isinstance(value, dict): + matches.extend(_find_compressed_type_two_side_descriptors(value)) + elif isinstance(value, list): + for item in value: + matches.extend(_find_compressed_type_two_side_descriptors(item)) + return matches + + +def _check_compressed_hlo_exportable(data: dict) -> None: + """Reject compressed descriptors that cannot be traced to StableHLO.""" + descriptor_types = _find_compressed_type_two_side_descriptors(data.get("model", {})) + if descriptor_types: + names = ", ".join(sorted(set(descriptor_types))) + raise ValueError( + "Compressed JAX HLO export does not support type_one_side=False for " + f"{names} descriptors because the compressed path uses data-dependent " + "type slices that cannot be traced. Use type_one_side=True for HLO " + "export, or write a .jax checkpoint instead." + ) + + def deserialize_to_file(model_file: str, data: dict) -> None: """Deserialize the dictionary to a model file. @@ -34,7 +172,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None: """ if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) - model_def_script = data["model_def_script"] + model_def_script = data["model_def_script"].copy() + min_nbor_dist = _to_optional_float(data.get("min_nbor_dist")) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float( + data.get("constants", {}).get("min_nbor_dist") + ) + if min_nbor_dist is not None: + model_def_script["_min_nbor_dist"] = min_nbor_dist _, state = nnx.split(model) with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") @@ -47,7 +192,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: ), ) elif model_file.endswith(".hlo"): + _check_compressed_hlo_exportable(data) model = BaseModel.deserialize(data["model"]) + _set_model_min_nbor_dist_from_data(model, data) model_def_script = data["model_def_script"] call_lower = model.call_common_lower @@ -185,10 +332,14 @@ def convert_str_to_int_key(item: dict) -> None: model_def_script = data.model_def_script abstract_model = get_model(model_def_script) + _restore_compression_slots_from_state(abstract_model, state) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state) model = nnx.merge(graphdef, abstract_state) model_dict = model.serialize() + min_nbor_dist = _to_optional_float(model.get_min_nbor_dist()) + if min_nbor_dist is None: + min_nbor_dist = _to_optional_float(model_def_script.get("_min_nbor_dist")) data = { "backend": "JAX", "jax_version": jax.__version__, @@ -196,6 +347,8 @@ def convert_str_to_int_key(item: dict) -> None: "model_def_script": model_def_script, "@variables": {}, } + if min_nbor_dist is not None: + data["min_nbor_dist"] = min_nbor_dist return data elif model_file.endswith(".hlo"): data = load_dp_model(model_file) diff --git a/deepmd/main.py b/deepmd/main.py index 43f40dc214..56e1863725 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -585,6 +585,8 @@ def main_parser() -> argparse.ArgumentParser: dp compress dp --tf compress -i frozen_model.pb -o compressed_model.pb dp --pt compress -i frozen_model.pth -o compressed_model.pth + dp --dp compress -i frozen_model.dp -o compressed_model.dp + dp --jax compress -i frozen_model.hlo -o compressed_model.hlo """ ), ) @@ -593,14 +595,14 @@ def main_parser() -> argparse.ArgumentParser: "--input", default="frozen_model", type=str, - help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", + help="The original frozen model, which will be compressed by the code. Filename (prefix) of the input model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-o", "--output", default="frozen_model_compressed", type=str, - help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth", + help="The compressed model. Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; DPModel backend: suffix is .dp; JAX backend: suffix is .hlo or .jax", ) parser_compress.add_argument( "-s", diff --git a/source/tests/common/dpmodel/test_model_compression.py b/source/tests/common/dpmodel/test_model_compression.py new file mode 100644 index 0000000000..29f5c281a8 --- /dev/null +++ b/source/tests/common/dpmodel/test_model_compression.py @@ -0,0 +1,431 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import tempfile +import unittest +from pathlib import ( + Path, +) + +import numpy as np + +from deepmd.dpmodel.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.dpmodel.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.dpmodel.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.dpmodel.entrypoints.compress import ( + enable_compression, +) +from deepmd.dpmodel.model.base_model import ( + BaseModel, +) +from deepmd.dpmodel.model.model import ( + get_model, +) +from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, +) + + +class TestDPModelCompression(unittest.TestCase): + def setUp(self) -> None: + self.coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [1.2, 0.1, 0.0], + [0.1, 1.4, 0.0], + [1.5, 1.5, 0.1], + ] + ], + dtype=np.float64, + ) + self.atype = np.array([[0, 0, 1, 1]], dtype=np.int32) + self.nlist = np.array([[[1, 2], [0, 2], [0, 3], [0, 2]]], dtype=np.int64) + + def _make_descriptor( + self, + type_one_side: bool = True, + exclude_types: list[list[int]] | None = None, + ) -> DescrptSeA: + return DescrptSeA( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + axis_neuron=2, + resnet_dt=False, + type_one_side=type_one_side, + exclude_types=[] if exclude_types is None else exclude_types, + precision="float64", + seed=1234, + ) + + @staticmethod + def _poly5(coeff: np.ndarray, xx: np.ndarray | float) -> np.ndarray: + return ( + coeff[..., 0] + + ( + coeff[..., 1] + + ( + coeff[..., 2] + + (coeff[..., 3] + (coeff[..., 4] + coeff[..., 5] * xx) * xx) * xx + ) + * xx + ) + * xx + ) + + @staticmethod + def _poly5_grad(coeff: np.ndarray, xx: np.ndarray | float) -> np.ndarray: + return ( + coeff[..., 1] + + ( + 2 * coeff[..., 2] + + ( + 3 * coeff[..., 3] + + (4 * coeff[..., 4] + 5 * coeff[..., 5] * xx) * xx + ) + * xx + ) + * xx + ) + + def _make_c1_tabulation_case( + self, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + table_info = np.array([1.0, 3.0, 5.0, 1.0, 1.0, -1.0], dtype=np.float64) + coeff = np.array( + [ + [ + [1.0, 0.5, 0.25, -0.10, 0.03, 0.02], + [2.0, -0.4, 0.20, 0.05, -0.02, 0.01], + ], + [ + [3.0, -0.2, 0.40, 0.15, -0.03, 0.02], + [4.0, 0.3, -0.10, 0.07, 0.04, -0.01], + ], + [ + [5.0, 0.7, -0.20, 0.12, 0.01, -0.04], + [6.0, -0.6, 0.30, -0.05, 0.02, 0.03], + ], + [ + [7.0, 1.1, -0.35, 0.25, 0.08, -0.02], + [8.0, -0.9, 0.45, -0.15, 0.06, 0.01], + ], + ], + dtype=np.float64, + ) + xx = np.array([[0.5, 1.0, 1.5, 3.25, 5.0, 5.5]], dtype=np.float64) + expected = self._expected_c1_tabulation(coeff, table_info, xx[0]) + table = np.reshape(coeff, (coeff.shape[0], -1)) + return table, coeff, table_info, xx, expected + + def _expected_c1_tabulation( + self, + coeff: np.ndarray, + table_info: np.ndarray, + xx: np.ndarray, + ) -> np.ndarray: + lower, upper, table_max, stride0, stride1 = table_info[:5] + first_stride = int(np.floor((upper - lower) / stride0)) + values = [] + for value in xx: + delta = 0.0 + if value < lower: + table_idx = 0 + dx = 0.0 + delta = value - lower + elif value < upper: + table_idx = int(np.floor((value - lower) / stride0)) + dx = value - (table_idx * stride0 + lower) + elif value < table_max: + table_idx = first_stride + int(np.floor((value - upper) / stride1)) + dx = value - ((table_idx - first_stride) * stride1 + upper) + else: + table_idx = coeff.shape[0] - 1 + dx = table_max - ((table_idx - first_stride) * stride1 + upper) + delta = value - table_max + values.append( + self._poly5(coeff[table_idx], dx) + + self._poly5_grad(coeff[table_idx], dx) * delta + ) + return np.asarray(values, dtype=np.float64) + + def test_tabulate_fusion_se_r_c1_extrapolates_outside_table(self) -> None: + table, coeff, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 2], + resnet_dt=False, + precision="float64", + seed=1234, + ) + + actual = descriptor._tabulate_fusion_se_r( + table, + table_info, + xx[:, :, None], + expected.shape[-1], + ) + + np.testing.assert_allclose(actual[0], expected, rtol=0.0, atol=1e-12) + np.testing.assert_allclose( + actual[0, 1] - actual[0, 0], + 0.5 * self._poly5_grad(coeff[0], 0.0), + rtol=0.0, + atol=1e-12, + ) + np.testing.assert_allclose( + actual[0, 5] - actual[0, 4], + 0.5 * self._poly5_grad(coeff[-1], 1.0), + rtol=0.0, + atol=1e-12, + ) + + def test_tabulate_fusion_se_a_c1_extrapolates_outside_table(self) -> None: + table, _, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = self._make_descriptor() + em = np.zeros((1, xx.shape[1], 4), dtype=np.float64) + em[:, :, 0] = 1.0 + expected_out = np.zeros((1, 4, expected.shape[-1]), dtype=np.float64) + expected_out[:, 0, :] = np.sum(expected, axis=0) + + actual = descriptor._tabulate_fusion_se_a( + table, + table_info, + xx[:, :, None], + em, + expected.shape[-1], + ) + + np.testing.assert_allclose(actual, expected_out, rtol=0.0, atol=1e-12) + + def test_tabulate_fusion_se_atten_c1_extrapolates_outside_table(self) -> None: + table, _, table_info, xx, expected = self._make_c1_tabulation_case() + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 2], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + precision="float64", + seed=1234, + ) + em = np.zeros((1, xx.shape[1], 4), dtype=np.float64) + em[:, :, 0] = 1.0 + two_embed = np.full_like(expected[None, :, :], 0.5) + expected_out = np.zeros((1, 4, expected.shape[-1]), dtype=np.float64) + expected_out[:, 0, :] = np.sum(expected * (1.0 + two_embed[0]), axis=0) + + actual = descriptor.se_atten._tabulate_fusion_se_atten( + table, + table_info, + xx[:, :, None], + em, + two_embed, + expected.shape[-1], + ) + + np.testing.assert_allclose(actual, expected_out, rtol=0.0, atol=1e-12) + + def test_se_e2_a_enable_compression(self) -> None: + for type_one_side, exclude_types in ( + (True, []), + (False, []), + (True, [[0, 1]]), + (False, [[0, 1]]), + ): + with self.subTest(type_one_side=type_one_side, exclude_types=exclude_types): + descriptor = self._make_descriptor(type_one_side, exclude_types) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptSeA.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) + + def test_se_e2_r_enable_compression(self) -> None: + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + resnet_dt=False, + precision="float64", + seed=1234, + ) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptSeR.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeR.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) + + def test_se_atten_enable_compression(self) -> None: + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(self.coord, self.atype, self.nlist) + + compressed = DescrptDPA1.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(self.coord, self.atype, self.nlist) + + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(self.coord, self.atype, self.nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + actual_item, expected_item, rtol=0.0, atol=1e-10 + ) + np.testing.assert_allclose( + reloaded_item, expected_item, rtol=0.0, atol=1e-10 + ) + + def test_dpmodel_compress_entrypoint(self) -> None: + model_data = { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": [1, 1], + "neuron": [4, 8], + "axis_neuron": 2, + "resnet_dt": False, + "type_one_side": True, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + model = get_model(model_data) + model.min_nbor_dist = 1.0 + expected_output = model.call(self.coord, self.atype) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / "model.dp" + output_file = Path(tmpdir) / "model-compressed.dp" + save_dp_model( + str(input_file), + { + "model": model.serialize(), + "model_def_script": model_data, + "min_nbor_dist": 1.0, + }, + ) + + enable_compression( + str(input_file), + str(output_file), + stride=0.01, + extrapolate=5, + check_frequency=-1, + ) + + compressed = load_dp_model(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["min_nbor_dist"], 1.0) + + reloaded_model = BaseModel.deserialize(copy.deepcopy(compressed["model"])) + actual_output = reloaded_model.call(self.coord, self.atype) + self.assertEqual(actual_output.keys(), expected_output.keys()) + for key, expected_value in expected_output.items(): + np.testing.assert_allclose( + actual_output[key], expected_value, rtol=0.0, atol=1e-10 + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/jax/test_model_compression.py b/source/tests/jax/test_model_compression.py new file mode 100644 index 0000000000..32ac8c84b9 --- /dev/null +++ b/source/tests/jax/test_model_compression.py @@ -0,0 +1,435 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import tempfile +import unittest +from importlib.util import ( + find_spec, +) +from pathlib import ( + Path, +) + +import numpy as np + +INSTALLED_JAX = find_spec("jax") is not None and find_spec("orbax") is not None + +if INSTALLED_JAX: + from deepmd.dpmodel.utils.serialization import ( + load_dp_model, + save_dp_model, + ) + from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, + ) + from deepmd.jax.descriptor.se_e2_r import ( + DescrptSeR, + ) + from deepmd.jax.env import ( + jax, + jnp, + ) + from deepmd.jax.model.model import ( + get_model, + ) + from deepmd.jax.utils.serialization import ( + serialize_from_file, + ) + from deepmd.main import main as dp_main + + +class TestJAXModelCompression(unittest.TestCase): + def setUp(self) -> None: + self.coord = np.array( + [ + [ + [0.0, 0.0, 0.0], + [1.2, 0.1, 0.0], + [0.1, 1.4, 0.0], + [1.5, 1.5, 0.1], + ] + ], + dtype=np.float64, + ) + self.atype = np.array([[0, 0, 1, 1]], dtype=np.int32) + self.nlist = np.array([[[1, 2], [0, 2], [0, 3], [0, 2]]], dtype=np.int64) + + def _make_model_data(self, type_one_side: bool = True) -> dict: + return { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": [1, 1], + "neuron": [4, 8], + "axis_neuron": 2, + "resnet_dt": False, + "type_one_side": type_one_side, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + + def _make_se_r_model_data(self) -> dict: + return { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_r", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": [1, 1], + "neuron": [4, 8], + "resnet_dt": False, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + + def _make_dpa1_model_data(self, type_one_side: bool = False) -> dict: + return { + "type": "standard", + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "rcut": 4.0, + "rcut_smth": 3.5, + "sel": 2, + "neuron": [4, 8], + "axis_neuron": 2, + "tebd_dim": 4, + "tebd_input_mode": "strip", + "resnet_dt": False, + "attn_layer": 0, + "type_one_side": type_one_side, + "precision": "float64", + "seed": 1234, + }, + "fitting_net": { + "type": "ener", + "neuron": [8], + "resnet_dt": False, + "precision": "float64", + "seed": 5678, + }, + } + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_e2_a_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptSeA( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + axis_neuron=2, + resnet_dt=False, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptSeA.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeA.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), + np.asarray(expected_item), + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), + np.asarray(expected_item), + atol=1e-10, + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_e2_r_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + descriptor = DescrptSeR( + rcut=4.0, + rcut_smth=3.5, + sel=[1, 1], + neuron=[4, 8], + resnet_dt=False, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptSeR.deserialize(copy.deepcopy(descriptor.serialize())) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptSeR.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), np.asarray(expected_item), atol=1e-10 + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), np.asarray(expected_item), atol=1e-10 + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_se_atten_enable_compression(self) -> None: + coord = jnp.array(self.coord) + atype = jnp.array(self.atype) + nlist = jnp.array(self.nlist) + for type_one_side in (True, False): + with self.subTest(type_one_side=type_one_side): + descriptor = DescrptDPA1( + rcut=4.0, + rcut_smth=3.5, + sel=2, + ntypes=2, + neuron=[4, 8], + axis_neuron=2, + tebd_dim=4, + tebd_input_mode="strip", + resnet_dt=False, + attn_layer=0, + type_one_side=type_one_side, + precision="float64", + seed=1234, + ) + expected = descriptor.call(coord, atype, nlist) + + compressed = DescrptDPA1.deserialize( + copy.deepcopy(descriptor.serialize()) + ) + compressed.enable_compression(1.0, 5, 0.001, 0.01, -1) + actual = compressed.call(coord, atype, nlist) + + self.assertTrue(compressed.compress) + self.assertTrue(compressed.geo_compress) + serialized = compressed.serialize() + self.assertEqual(serialized["@version"], 3) + self.assertIn("compress", serialized) + reloaded = DescrptDPA1.deserialize(copy.deepcopy(serialized)) + reloaded_actual = reloaded.call(coord, atype, nlist) + + for expected_item, actual_item, reloaded_item in zip( + expected, actual, reloaded_actual, strict=True + ): + if expected_item is None: + self.assertIsNone(actual_item) + self.assertIsNone(reloaded_item) + else: + np.testing.assert_allclose( + np.asarray(actual_item), + np.asarray(expected_item), + atol=1e-10, + ) + np.testing.assert_allclose( + np.asarray(reloaded_item), + np.asarray(expected_item), + atol=1e-10, + ) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_compress_entrypoint(self) -> None: + for name, model_data in ( + ("se_e2_a", self._make_model_data()), + ("dpa1", self._make_dpa1_model_data()), + ): + with self.subTest(descriptor=name): + model = get_model(copy.deepcopy(model_data)) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / f"model-{name}.hlo" + output_file = Path(tmpdir) / f"model-{name}-compressed.jax" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) + + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) + + compressed = serialize_from_file(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["min_nbor_dist"], 1.0) + if name == "dpa1": + compress = descriptor["compress"] + self.assertTrue(compress["geo_compress"]) + variables = compress["@variables"] + self.assertIn("type_embd_data", variables) + self.assertIn("compress_data", variables) + self.assertIn("compress_info", variables) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_compress_entrypoint_can_write_hlo(self) -> None: + for name, model_data in ( + ("se_e2_a", self._make_model_data(type_one_side=True)), + ("se_e2_r", self._make_se_r_model_data()), + ("dpa1", self._make_dpa1_model_data(type_one_side=True)), + ): + with self.subTest(descriptor=name): + model = get_model(copy.deepcopy(model_data)) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / f"model-{name}.hlo" + output_file = Path(tmpdir) / f"model-{name}-compressed.hlo" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) + + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) + + compressed = load_dp_model(str(output_file)) + descriptor = compressed["model"]["descriptor"] + self.assertEqual(descriptor["@version"], 3) + self.assertIn("compress", descriptor) + self.assertEqual(compressed["constants"]["min_nbor_dist"], 1.0) + self.assertIn("stablehlo", compressed["@variables"]) + self.assertIn("stablehlo_no_ghost", compressed["@variables"]) + + @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed") + def test_jax_compress_entrypoint_rejects_type_two_side_hlo(self) -> None: + for name, model_data in ( + ("se_e2_a", self._make_model_data(type_one_side=False)), + ("dpa1", self._make_dpa1_model_data(type_one_side=False)), + ): + with self.subTest(descriptor=name): + model = get_model(copy.deepcopy(model_data)) + + with tempfile.TemporaryDirectory() as tmpdir: + input_file = Path(tmpdir) / f"model-{name}.hlo" + output_file = Path(tmpdir) / f"model-{name}-compressed.hlo" + save_dp_model( + str(input_file), + { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model.serialize(), + "model_def_script": model_data, + "@variables": { + "stablehlo": np.void(b"stablehlo"), + }, + "constants": { + "min_nbor_dist": 1.0, + }, + }, + ) + + with self.assertRaisesRegex( + ValueError, + "Compressed JAX HLO export does not support " + "type_one_side=False", + ): + dp_main( + [ + "--jax", + "compress", + "-i", + str(input_file), + "-o", + str(output_file), + "-s", + "0.01", + ] + ) + + +if __name__ == "__main__": + unittest.main()