From 02097f1bb654c539366cf75723ba040802899140 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 12:49:09 +0800 Subject: [PATCH 1/4] fix(jax): restore jax2tf savedmodel export --- deepmd/jax/jax2tf/format_nlist.py | 70 +++++-- deepmd/jax/jax2tf/make_model.py | 113 +++++----- deepmd/jax/jax2tf/nlist.py | 178 +++++++++++++--- deepmd/jax/jax2tf/region.py | 77 ++++--- deepmd/jax/jax2tf/serialization.py | 320 ++++++++++++++++++++++++++++- deepmd/jax/utils/serialization.py | 8 +- 6 files changed, 642 insertions(+), 124 deletions(-) diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py index f9b216fb27..75e0329ac6 100644 --- a/deepmd/jax/jax2tf/format_nlist.py +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -1,24 +1,66 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Compatibility wrappers for TensorFlow neighbor-list formatting.""" +"""TensorFlow graph helpers for JAX/jax2tf SavedModel export. -from typing import ( - Any, -) +This module is not a generic TF2 compatibility wrapper. The functions here are +traced while saving the JAX ``.savedmodel`` artifact, before control reaches +``jax2tf.convert``. Keep the implementation in plain TensorFlow ops so +AutoGraph can see the tensor-dependent branches and emit graph control flow. +Routing through ndtensorflow/dpmodel helpers can leave symbolic shape +comparisons as Python ``if`` statements during SavedModel tracing. +""" import tensorflow as tf -from deepmd.tf2.common import ( - to_tf_tensor, -) -from deepmd.tf2.utils._dpmodel import format_nlist as tf2_format_nlist - -__all__ = ["format_nlist"] - +@tf.function(autograph=True) def format_nlist( - extended_coord: Any, - nlist: Any, + extended_coord: tf.Tensor, + nlist: tf.Tensor, nsel: int, rcut: float, ) -> tf.Tensor: - return to_tf_tensor(tf2_format_nlist(extended_coord, nlist, nsel, rcut)) + """Format neighbor list. + + If nnei == nsel, do nothing; + If nnei < nsel, pad -1; + If nnei > nsel, sort by distance and truncate. 
+ """ + nlist_shape = tf.shape(nlist) + n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2] + extended_coord = tf.reshape(extended_coord, [n_nf, -1, 3]) + + if n_nsel < nsel: + ret = tf.concat( + [ + nlist, + tf.fill([n_nf, n_nloc, nsel - n_nsel], tf.cast(-1, nlist.dtype)), + ], + axis=-1, + ) + elif n_nsel > nsel: + m_real_nei = nlist >= 0 + ret = tf.where(m_real_nei, nlist, tf.zeros_like(nlist)) + coord0 = extended_coord[:, :n_nloc, :] + index = tf.reshape(ret, [n_nf, n_nloc * n_nsel]) + coord1 = tf.gather(extended_coord, index, batch_dims=1) + coord1 = tf.reshape(coord1, [n_nf, n_nloc, n_nsel, 3]) + rr2 = tf.reduce_sum(tf.square(coord0[:, :, None, :] - coord1), axis=-1) + rr2 = tf.where( + m_real_nei, + rr2, + tf.fill(tf.shape(rr2), tf.constant(float("inf"), rr2.dtype)), + ) + ret_mapping = tf.argsort(rr2, axis=-1) + rr2 = tf.sort(rr2, axis=-1) + ret = tf.gather(ret, ret_mapping, batch_dims=2) + ret = tf.where( + rr2 > rcut * rcut, + tf.fill(tf.shape(ret), tf.cast(-1, ret.dtype)), + ret, + ) + ret = ret[..., :nsel] + else: + ret = nlist + # Reshape anyway; this tells XLA the shape without dynamic shape. + ret = tf.reshape(ret, [n_nf, n_nloc, nsel]) + return ret diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index ebb93492a9..dba6d6946b 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -1,54 +1,47 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Compatibility wrappers for TensorFlow model-call helpers.""" +"""Outer TensorFlow call wrapper for the JAX/jax2tf SavedModel. + +The wrapper builds PBC ghosts, neighbor lists, and output communication around +the lower JAX model. It deliberately uses the graph-safe helpers in this +package instead of the TF2 eager helpers, because this code is traced by +``tf.saved_model.save`` and must keep tensor-shape branches convertible by +AutoGraph before it invokes the jax2tf-converted model body. 
+""" from collections.abc import ( Callable, ) -from typing import ( - Any, -) import tensorflow as tf from deepmd.dpmodel.output_def import ( ModelOutputDef, ) -from deepmd.tf2.common import ( - to_tf_tensor, - unwrap_value, - wrap_value, +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, ) -from deepmd.tf2.make_model import ( - model_call_from_call_lower as tf2_model_call_from_call_lower, +from deepmd.jax.jax2tf.region import ( + normalize_coord, +) +from deepmd.jax.jax2tf.transform_output import ( + communicate_extended_output, ) - -__all__ = ["model_call_from_call_lower"] - - -def _wrap_call_lower(call_lower: Callable[..., dict[str, Any]]) -> Callable: - def wrapped_call_lower( - extended_coord: Any, - extended_atype: Any, - nlist: Any, - mapping: Any, - **kwargs: Any, - ) -> dict[str, Any]: - return wrap_value( - call_lower( - to_tf_tensor(extended_coord), - to_tf_tensor(extended_atype), - to_tf_tensor(nlist), - to_tf_tensor(mapping), - **{kk: to_tf_tensor(vv) for kk, vv in kwargs.items()}, - ) - ) - - return wrapped_call_lower def model_call_from_call_lower( *, # enforce keyword-only arguments - call_lower: Callable[..., dict[str, Any]], + call_lower: Callable[ + [ + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + bool, + ], + dict[str, tf.Tensor], + ], rcut: float, sel: list[int], mixed_types: bool, @@ -60,18 +53,44 @@ def model_call_from_call_lower( aparam: tf.Tensor, do_atomic_virial: bool = False, ) -> dict[str, tf.Tensor]: - return unwrap_value( - tf2_model_call_from_call_lower( - call_lower=_wrap_call_lower(call_lower), - rcut=rcut, - sel=sel, - mixed_types=mixed_types, - model_output_def=model_output_def, - coord=coord, - atype=atype, - box=box, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, + """Return model prediction from lower interface.""" + atype_shape = tf.shape(atype) + nframes, nloc = atype_shape[0], atype_shape[1] + cc, bb, fp, ap = coord, box, fparam, aparam + 
del coord, box, fparam, aparam + if tf.shape(bb)[-1] != 0: + coord_normalized = normalize_coord( + tf.reshape(cc, [nframes, nloc, 3]), + tf.reshape(bb, [nframes, 3, 3]), ) + else: + coord_normalized = cc + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, bb, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + # types will be distinguished in the lower interface, so it doesn't + # need to be distinguished here + distinguish_types=False, + ) + extended_coord = tf.reshape(extended_coord, [nframes, -1, 3]) + model_predict_lower = call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + ) + model_predict = communicate_extended_output( + model_predict_lower, + model_output_def, + mapping, + do_atomic_virial=do_atomic_virial, ) + return model_predict diff --git a/deepmd/jax/jax2tf/nlist.py b/deepmd/jax/jax2tf/nlist.py index 9ba8fa78cc..db975ae19a 100644 --- a/deepmd/jax/jax2tf/nlist.py +++ b/deepmd/jax/jax2tf/nlist.py @@ -1,53 +1,171 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Compatibility wrappers for TensorFlow neighbor-list helpers.""" +"""Neighbor-list helpers for the JAX/jax2tf SavedModel wrapper. -from typing import ( - Any, -) +These routines mirror the backend-independent neighbor-list logic, but keep it +expressed directly as TensorFlow graph code. During SavedModel export the input +sizes are symbolic tensors; using ndtensorflow array-api helpers here can route +dynamic ``shape``/``size`` checks through Python control flow before AutoGraph +has converted them. That breaks tracing and also makes it easy to bypass the +jax2tf/XlaCallModule export path. 
+""" import tensorflow as tf -from deepmd.tf2.common import ( - to_tf_tensor, -) -from deepmd.tf2.utils._dpmodel import build_neighbor_list as tf2_build_neighbor_list -from deepmd.tf2.utils._dpmodel import ( - extend_coord_with_ghosts as tf2_extend_coord_with_ghosts, +from .region import ( + to_face_distance, ) -__all__ = [ - "build_neighbor_list", - "extend_coord_with_ghosts", -] + +def tf_take_along_axis(params: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor: + return tf.gather(params, indices, batch_dims=axis) def build_neighbor_list( - coord: Any, - atype: Any, + coord: tf.Tensor, + atype: tf.Tensor, nloc: int, rcut: float, sel: int | list[int], distinguish_types: bool = True, ) -> tf.Tensor: - return to_tf_tensor( - tf2_build_neighbor_list( - coord, - atype, - nloc, - rcut, - sel, - distinguish_types=distinguish_types, + """Build neighbor list for a batch of frames. Keeps nsel neighbors.""" + batch_size = tf.shape(coord)[0] + coord = tf.reshape(coord, (batch_size, -1)) + nall = tf.shape(coord)[1] // 3 + # Fill virtual atoms with large coords so they are not neighbors of any + # real atom. + if tf.size(coord) > 0: + xmax = tf.reduce_max(coord) + tf.cast(2.0 * rcut, coord.dtype) + else: + xmax = tf.cast(2.0 * rcut, coord.dtype) + is_vir = atype < 0 + coord1 = tf.where( + is_vir[:, :, None], xmax, tf.reshape(coord, (batch_size, nall, 3)) + ) + coord1 = tf.reshape(coord1, (batch_size, nall * 3)) + if isinstance(sel, int): + sel = [sel] + nsel = sum(sel) + coord0 = coord1[:, : nloc * 3] + diff = ( + tf.reshape(coord1, [batch_size, -1, 3])[:, None, :, :] + - tf.reshape(coord0, [batch_size, -1, 3])[:, :, None, :] + ) + rr = tf.linalg.norm(diff, axis=-1) + # If central atom has two zero distances, sorting sometimes cannot exclude + # itself.
+ rr -= tf.eye(nloc, nall, dtype=diff.dtype)[None, :, :] + nlist = tf.cast(tf.argsort(rr, axis=-1), tf.int64) + rr = tf.sort(rr, axis=-1) + rr = rr[:, :, 1:] + nlist = nlist[:, :, 1:] + nnei = tf.shape(rr)[2] + if nsel <= nnei: + rr = rr[:, :, :nsel] + nlist = nlist[:, :, :nsel] + else: + rr = tf.concat( + [ + rr, + tf.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + + tf.cast(rcut, rr.dtype), + ], + axis=-1, ) + nlist = tf.concat( + [nlist, tf.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)], + axis=-1, + ) + nlist = tf.where( + tf.logical_or((rr > tf.cast(rcut, rr.dtype)), is_vir[:, :nloc, None]), + tf.fill(tf.shape(nlist), tf.cast(-1, nlist.dtype)), + nlist, ) + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) + return nlist + + +def nlist_distinguish_types( + nlist: tf.Tensor, + atype: tf.Tensor, + sel: list[int], +) -> tf.Tensor: + """Given a nlist that does not distinguish atom types, return one that does.""" + nloc = tf.shape(nlist)[1] + ret_nlist = [] + tmp_atype = tf.tile(atype[:, None, :], [1, nloc, 1]) + mask = nlist == -1 + tnlist_0 = tf.where(mask, tf.zeros_like(nlist), nlist) + tnlist = tf_take_along_axis(tmp_atype, tnlist_0, axis=2) + tnlist = tf.where( + mask, tf.fill(tf.shape(tnlist), tf.cast(-1, tnlist.dtype)), tnlist + ) + for ii, ss in enumerate(sel): + pick_mask = tf.cast(tnlist == ii, tf.int32) + sorted_indices = tf.argsort(-pick_mask, stable=True, axis=-1) + pick_mask_sorted = -tf.sort(-pick_mask, axis=-1) + inlist = tf_take_along_axis(nlist, sorted_indices, axis=2) + inlist = tf.where( + ~tf.cast(pick_mask_sorted, tf.bool), + tf.fill(tf.shape(inlist), tf.cast(-1, inlist.dtype)), + inlist, + ) + ret_nlist.append(inlist[..., :ss]) + ret = tf.concat(ret_nlist, axis=-1) + return ret + + +def tf_outer(a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: + return tf.einsum("i,j->ij", a, b) + def extend_coord_with_ghosts( - coord: Any, - atype: Any, - cell: Any | None, + coord: tf.Tensor, + atype: tf.Tensor, + cell: 
tf.Tensor, rcut: float, ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - return tuple( - to_tf_tensor(value) - for value in tf2_extend_coord_with_ghosts(coord, atype, cell, rcut) + """Extend atom coordinates by appending periodic images.""" + atype_shape = tf.shape(atype) + nf, nloc = atype_shape[0], atype_shape[1] + # int64 for index + aidx = tf.range(nloc, dtype=tf.int64) + aidx = tf.tile(aidx[None, :], [nf, 1]) + if tf.shape(cell)[-1] == 0: + nall = nloc + extend_coord = coord + extend_atype = atype + extend_aidx = aidx + else: + coord = tf.reshape(coord, (nf, nloc, 3)) + cell = tf.reshape(cell, (nf, 3, 3)) + to_face = to_face_distance(cell) + nbuff = tf.cast(tf.math.ceil(tf.cast(rcut, to_face.dtype) / to_face), tf.int64) + nbuff = tf.reduce_max(nbuff, axis=0) + xi = tf.range(-nbuff[0], nbuff[0] + 1, 1, dtype=tf.int64) + yi = tf.range(-nbuff[1], nbuff[1] + 1, 1, dtype=tf.int64) + zi = tf.range(-nbuff[2], nbuff[2] + 1, 1, dtype=tf.int64) + xyz = tf_outer(xi, tf.constant([1, 0, 0], dtype=tf.int64))[:, None, None, :] + xyz = ( + xyz + tf_outer(yi, tf.constant([0, 1, 0], dtype=tf.int64))[None, :, None, :] + ) + xyz = ( + xyz + tf_outer(zi, tf.constant([0, 0, 1], dtype=tf.int64))[None, None, :, :] + ) + xyz = tf.reshape(xyz, (-1, 3)) + xyz = tf.cast(xyz, coord.dtype) + shift_idx = tf.gather(xyz, tf.argsort(tf.linalg.norm(xyz, axis=1)), axis=0) + ns = tf.shape(shift_idx)[0] + nall = ns * nloc + shift_vec = tf.einsum("sd,fdk->fsk", shift_idx, cell) + extend_coord = coord[:, None, :, :] + shift_vec[:, :, None, :] + extend_atype = tf.tile(atype[:, :, None], [1, ns, 1]) + extend_aidx = tf.tile(aidx[:, :, None], [1, ns, 1]) + + return ( + tf.reshape(extend_coord, (nf, nall * 3)), + tf.reshape(extend_atype, (nf, nall)), + tf.reshape(extend_aidx, (nf, nall)), ) diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py index 3c80277f2e..c197588f87 100644 --- a/deepmd/jax/jax2tf/region.py +++ b/deepmd/jax/jax2tf/region.py @@ -1,33 +1,56 @@ # SPDX-License-Identifier: 
LGPL-3.0-or-later -"""Compatibility wrappers for TensorFlow region helpers.""" +"""TensorFlow geometry helpers used while exporting JAX models through jax2tf. -from typing import ( - Any, -) +Keep these helpers free of TensorFlow eager array wrappers. They run inside the +SavedModel tracing path for ``.savedmodel`` and should stay as small, plain TF +graph functions that AutoGraph and the TensorFlow serializer can inspect. +""" import tensorflow as tf -from deepmd.tf2.common import ( - to_tf_tensor, -) -from deepmd.tf2.utils._dpmodel import inter2phys as tf2_inter2phys -from deepmd.tf2.utils._dpmodel import normalize_coord as tf2_normalize_coord -from deepmd.tf2.utils._dpmodel import to_face_distance as tf2_to_face_distance -__all__ = [ - "inter2phys", - "normalize_coord", - "to_face_distance", -] - - -def inter2phys(coord: Any, cell: Any) -> tf.Tensor: - return to_tf_tensor(tf2_inter2phys(coord, cell)) - - -def normalize_coord(coord: Any, cell: Any) -> tf.Tensor: - return to_tf_tensor(tf2_normalize_coord(coord, cell)) - - -def to_face_distance(cell: Any) -> tf.Tensor: - return to_tf_tensor(tf2_to_face_distance(cell)) +def phys2inter( + coord: tf.Tensor, + cell: tf.Tensor, +) -> tf.Tensor: + """Convert physical coordinates to internal coordinates.""" + rec_cell = tf.linalg.inv(cell) + return tf.matmul(coord, rec_cell) + + +def inter2phys( + coord: tf.Tensor, + cell: tf.Tensor, +) -> tf.Tensor: + """Convert internal coordinates to physical coordinates.""" + return tf.matmul(coord, cell) + + +def normalize_coord( + coord: tf.Tensor, + cell: tf.Tensor, +) -> tf.Tensor: + """Apply PBC according to the atomic coordinates.""" + icoord = phys2inter(coord, cell) + icoord = tf.math.floormod(icoord, tf.cast(1.0, icoord.dtype)) + return inter2phys(icoord, cell) + + +def to_face_distance( + cell: tf.Tensor, +) -> tf.Tensor: + """Compute the to-face-distance of the simulation cell.""" + cshape = tf.shape(cell) + dist = b_to_face_distance(tf.reshape(cell, [-1, 3, 3])) + return 
tf.reshape(dist, tf.concat([cshape[:-2], [3]], axis=0)) + + +def b_to_face_distance(cell: tf.Tensor) -> tf.Tensor: + volume = tf.linalg.det(cell) + c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) + h2yz = volume / tf.linalg.norm(c_yz, axis=-1) + c_zx = tf.linalg.cross(cell[:, 2, ...], cell[:, 0, ...]) + h2zx = volume / tf.linalg.norm(c_zx, axis=-1) + c_xy = tf.linalg.cross(cell[:, 0, ...], cell[:, 1, ...]) + h2xy = volume / tf.linalg.norm(c_xy, axis=-1) + return tf.stack([h2yz, h2zx, h2xy], axis=1) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index b26aeb2c40..d86c341c10 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -1,8 +1,320 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -"""Compatibility wrapper for the TF2 SavedModel exporter.""" +"""JAX/jax2tf SavedModel export. -from deepmd.tf2.utils.serialization import ( - deserialize_to_file, +The ``.savedmodel`` suffix is the JAX SavedModel artifact used by the JAX C++ +inference path. It is intentionally different from the TF2 eager +``.savedmodeltf`` artifact: the model body below must pass through +``jax2tf.convert`` so TensorFlow stores XlaCallModule nodes. Do not replace this +module with the TF2 SavedModel exporter unless the file suffix and C++ loader +contract are changed together. 
+""" + +import json +from collections.abc import ( + Callable, +) + +import tensorflow as tf +from jax.experimental import ( + jax2tf, +) + +from deepmd.jax.jax2tf.format_nlist import ( + format_nlist, +) +from deepmd.jax.jax2tf.make_model import ( + model_call_from_call_lower, +) +from deepmd.jax.model.base_model import ( + BaseModel, ) -__all__ = ["deserialize_to_file"] + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a JAX/jax2tf SavedModel.""" + if model_file.endswith(".savedmodel"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_common_lower + + tf_model = tf.Module() + + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ) -> Callable: + def call_lower_with_fixed_do_atomic_virial( + coord: tf.Tensor, + atype: tf.Tensor, + nlist: tf.Tensor, + mapping: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + return call_lower( + coord, + atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + # nghost >= 1 is assumed if there are ghost atoms. Other workarounds + # do not work, such as nall; nloc + nghost - 1. + if has_ghost_atoms: + nghost = "nghost" + else: + nghost = "0" + # The converted function is the part that carries the JAX model + # semantics into TensorFlow. Its SavedModel graph is expected to + # contain XlaCallModule ops; a graph made only of ordinary TF ops + # means this path has accidentally fallen back to the TF2 exporter.
+ return jax2tf.convert( + call_lower_with_fixed_do_atomic_virial, + polymorphic_shapes=[ + f"(nf, nloc + {nghost}, 3)", + f"(nf, nloc + {nghost})", + f"(nf, nloc, {model.get_nnei()})", + f"(nf, nloc + {nghost})", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", + ], + with_gradient=True, + ) + + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_without_atomic_virial( + coord: tf.Tensor, + atype: tf.Tensor, + nlist: tf.Tensor, + mapping: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), + ) + + tf_model.call_lower = call_lower_without_atomic_virial + + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_with_atomic_virial( + coord: tf.Tensor, + atype: tf.Tensor, + nlist: tf.Tensor, + mapping: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + nlist = format_nlist(coord, nlist, model.get_nnei(), 
model.get_rcut()) + return tf.cond( + tf.shape(coord)[1] == tf.shape(nlist)[1], + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + )(coord, atype, nlist, mapping, fparam, aparam), + lambda: exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=True + )(coord, atype, nlist, mapping, fparam, aparam), + ) + + tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + + def make_call_whether_do_atomic_virial(do_atomic_virial: bool) -> Callable: + if do_atomic_virial: + call_lower = call_lower_with_atomic_virial + else: + call_lower = call_lower_without_atomic_virial + + def call( + coord: tf.Tensor, + atype: tf.Tensor, + box: tf.Tensor | None = None, + fparam: tf.Tensor | None = None, + aparam: tf.Tensor | None = None, + ) -> dict[str, tf.Tensor]: + return model_call_from_call_lower( + call_lower=call_lower, + rcut=model.get_rcut(), + sel=model.get_sel(), + mixed_types=model.mixed_types(), + model_output_def=model.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return call + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_with_atomic_virial( + coord: tf.Tensor, + atype: tf.Tensor, + box: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + return make_call_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, box, fparam, aparam + ) + + tf_model.call_atomic_virial = call_with_atomic_virial + + @tf.function( + autograph=True, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], tf.float64), + 
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_without_atomic_virial( + coord: tf.Tensor, + atype: tf.Tensor, + box: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + return make_call_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, box, fparam, aparam + ) + + tf_model.call = call_without_atomic_virial + + @tf.function + def get_type_map() -> tf.Tensor: + return tf.constant(model.get_type_map(), dtype=tf.string) + + tf_model.get_type_map = get_type_map + + @tf.function + def get_rcut() -> tf.Tensor: + return tf.constant(model.get_rcut(), dtype=tf.double) + + tf_model.get_rcut = get_rcut + + @tf.function + def get_dim_fparam() -> tf.Tensor: + return tf.constant(model.get_dim_fparam(), dtype=tf.int64) + + tf_model.get_dim_fparam = get_dim_fparam + + @tf.function + def get_dim_aparam() -> tf.Tensor: + return tf.constant(model.get_dim_aparam(), dtype=tf.int64) + + tf_model.get_dim_aparam = get_dim_aparam + + @tf.function + def get_sel_type() -> tf.Tensor: + return tf.constant(model.get_sel_type(), dtype=tf.int64) + + tf_model.get_sel_type = get_sel_type + + @tf.function + def is_aparam_nall() -> tf.Tensor: + return tf.constant(model.is_aparam_nall(), dtype=tf.bool) + + tf_model.is_aparam_nall = is_aparam_nall + + @tf.function + def model_output_type() -> tf.Tensor: + return tf.constant(model.model_output_type(), dtype=tf.string) + + tf_model.model_output_type = model_output_type + + @tf.function + def mixed_types() -> tf.Tensor: + return tf.constant(model.mixed_types(), dtype=tf.bool) + + tf_model.mixed_types = mixed_types + + if model.get_min_nbor_dist() is not None: + + @tf.function + def get_min_nbor_dist() -> tf.Tensor: + return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + + tf_model.get_min_nbor_dist = get_min_nbor_dist + + @tf.function + def get_sel() -> tf.Tensor: + return tf.constant(model.get_sel(), 
dtype=tf.int64) + + tf_model.get_sel = get_sel + + @tf.function + def get_model_def_script() -> tf.Tensor: + return tf.constant( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string + ) + + tf_model.get_model_def_script = get_model_def_script + + @tf.function + def has_message_passing() -> tf.Tensor: + return tf.constant(model.has_message_passing(), dtype=tf.bool) + + tf_model.has_message_passing = has_message_passing + + @tf.function + def has_default_fparam() -> tf.Tensor: + return tf.constant(model.has_default_fparam(), dtype=tf.bool) + + tf_model.has_default_fparam = has_default_fparam + + @tf.function + def get_default_fparam() -> tf.Tensor: + default_fparam = model.get_default_fparam() + if default_fparam is None: + return tf.constant([], dtype=tf.double) + return tf.constant(default_fparam, dtype=tf.double) + + tf_model.get_default_fparam = get_default_fparam + + tf.saved_model.save( + tf_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index e484475235..e82bd683f5 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -285,8 +285,12 @@ def call_lower_with_fixed_do_atomic_virial( } save_dp_model(filename=model_file, model_dict=data) elif model_file.endswith(".savedmodel"): - from deepmd.tf2.utils.serialization import ( - deserialize_to_savedmodel, + # Keep the historical JAX/JAX2TF meaning of ".savedmodel": this + # exporter must lower the JAX model through jax2tf and preserve + # XlaCallModule ops in the SavedModel. The TF2 eager SavedModel + # exporter owns the ".savedmodeltf" suffix. 
+ from deepmd.jax.jax2tf.serialization import ( + deserialize_to_file as deserialize_to_savedmodel, ) return deserialize_to_savedmodel(model_file, data) From ecc6d204d73cacc6b7dd7a2596892d0a2b3345a5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 13:18:44 +0800 Subject: [PATCH 2/4] fix(jax): address savedmodel review comments --- deepmd/jax/jax2tf/region.py | 2 +- deepmd/jax/jax2tf/serialization.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/jax2tf/region.py b/deepmd/jax/jax2tf/region.py index c197588f87..eb94e473cc 100644 --- a/deepmd/jax/jax2tf/region.py +++ b/deepmd/jax/jax2tf/region.py @@ -46,7 +46,7 @@ def to_face_distance( def b_to_face_distance(cell: tf.Tensor) -> tf.Tensor: - volume = tf.linalg.det(cell) + volume = tf.abs(tf.linalg.det(cell)) c_yz = tf.linalg.cross(cell[:, 1, ...], cell[:, 2, ...]) h2yz = volume / tf.linalg.norm(c_yz, axis=-1) c_zx = tf.linalg.cross(cell[:, 2, ...], cell[:, 0, ...]) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index d86c341c10..dd496dedf0 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -28,12 +28,16 @@ from deepmd.jax.model.base_model import ( BaseModel, ) +from deepmd.jax.utils.serialization import ( + _set_model_min_nbor_dist_from_data, +) def deserialize_to_file(model_file: str, data: dict) -> None: """Deserialize the dictionary to a JAX/jax2tf SavedModel.""" if model_file.endswith(".savedmodel"): model = BaseModel.deserialize(data["model"]) + _set_model_min_nbor_dist_from_data(model, data) model_def_script = data["model_def_script"] call_lower = model.call_common_lower @@ -297,6 +301,7 @@ def has_message_passing() -> tf.Tensor: return tf.constant(model.has_message_passing(), dtype=tf.bool) tf_model.has_message_passing = has_message_passing + tf_model.do_message_passing = has_message_passing @tf.function def has_default_fparam() -> tf.Tensor: From 
976fe89f88b2a514a3d33154aeeaf30599d6e16a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 14:58:16 +0800 Subject: [PATCH 3/4] fix(jax): update jax2tf savedmodel tests --- deepmd/jax/jax2tf/format_nlist.py | 24 +++++++++- source/jax2tf_tests/test_format_nlist.py | 20 ++++---- source/jax2tf_tests/test_nlist.py | 60 ++++++++++++++---------- source/jax2tf_tests/test_region.py | 22 ++++----- 4 files changed, 79 insertions(+), 47 deletions(-) diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py index 75e0329ac6..169e5fdeaa 100644 --- a/deepmd/jax/jax2tf/format_nlist.py +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -12,6 +12,27 @@ import tensorflow as tf +def _mask_out_of_cutoff( + extended_coord: tf.Tensor, + nlist: tf.Tensor, + rcut: float, +) -> tf.Tensor: + nlist_shape = tf.shape(nlist) + n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2] + m_real_nei = nlist >= 0 + real_nlist = tf.where(m_real_nei, nlist, tf.zeros_like(nlist)) + coord0 = extended_coord[:, :n_nloc, :] + index = tf.reshape(real_nlist, [n_nf, n_nloc * n_nsel]) + coord1 = tf.gather(extended_coord, index, batch_dims=1) + coord1 = tf.reshape(coord1, [n_nf, n_nloc, n_nsel, 3]) + rr2 = tf.reduce_sum(tf.square(coord0[:, :, None, :] - coord1), axis=-1) + return tf.where( + tf.logical_and(m_real_nei, rr2 > tf.cast(rcut * rcut, rr2.dtype)), + tf.fill(tf.shape(nlist), tf.cast(-1, nlist.dtype)), + nlist, + ) + + @tf.function(autograph=True) def format_nlist( extended_coord: tf.Tensor, @@ -37,6 +58,7 @@ def format_nlist( ], axis=-1, ) + ret = _mask_out_of_cutoff(extended_coord, ret, rcut) elif n_nsel > nsel: m_real_nei = nlist >= 0 ret = tf.where(m_real_nei, nlist, tf.zeros_like(nlist)) @@ -60,7 +82,7 @@ def format_nlist( ) ret = ret[..., :nsel] else: - ret = nlist + ret = _mask_out_of_cutoff(extended_coord, nlist, rcut) # Reshape anyway; this tells XLA the shape without dynamic shape. 
ret = tf.reshape(ret, [n_nf, n_nloc, nsel]) return ret diff --git a/source/jax2tf_tests/test_format_nlist.py b/source/jax2tf_tests/test_format_nlist.py index 20201147ad..f1e6e26b87 100644 --- a/source/jax2tf_tests/test_format_nlist.py +++ b/source/jax2tf_tests/test_format_nlist.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf -import tensorflow.experimental.numpy as tnp from deepmd.jax.jax2tf.format_nlist import ( format_nlist, @@ -11,6 +10,7 @@ ) GLOBAL_SEED = 20241110 +DTYPE = tf.float64 class TestFormatNlist(tf.test.TestCase): @@ -19,14 +19,14 @@ def setUp(self) -> None: self.nloc = 3 self.ns = 5 * 5 * 3 self.nall = self.ns * self.nloc - self.cell = tnp.array( - [[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64 + self.cell = tf.constant( + [[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=DTYPE ) - self.icoord = tnp.array( + self.icoord = tf.constant( [[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]], - dtype=tnp.float64, + dtype=DTYPE, ) - self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32) + self.atype = tf.constant([[1, 0, 1]], dtype=tf.int32) self.nsel = [10, 10] self.rcut = 1.01 @@ -69,10 +69,10 @@ def test_format_nlist_large(self) -> None: ) # random shuffle shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) - nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = tf.gather(nlist, shuffle_idx, axis=2) nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) # we only need to ensure the result is correct, no need to check the order - self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1)) + self.assertAllEqual(tf.sort(nlist, axis=-1), tf.sort(self.nlist, axis=-1)) def test_format_nlist_larger_rcut(self) -> None: nlist = build_neighbor_list( @@ -85,10 +85,10 @@ def test_format_nlist_larger_rcut(self) -> None: ) # random shuffle shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) - nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = 
tf.gather(nlist, shuffle_idx, axis=2) nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) # we only need to ensure the result is correct, no need to check the order - self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1)) + self.assertAllEqual(tf.sort(nlist, axis=-1), tf.sort(self.nlist, axis=-1)) def test_format_nlist_dynamic_nnei_graph(self) -> None: @tf.function( diff --git a/source/jax2tf_tests/test_nlist.py b/source/jax2tf_tests/test_nlist.py index 8ac9b8daa5..363ba60ff3 100644 --- a/source/jax2tf_tests/test_nlist.py +++ b/source/jax2tf_tests/test_nlist.py @@ -2,7 +2,6 @@ import tensorflow as tf -import tensorflow.experimental.numpy as tnp from deepmd.jax.jax2tf.nlist import ( build_neighbor_list, @@ -12,7 +11,7 @@ inter2phys, ) -dtype = tnp.float64 +DTYPE = tf.float64 class TestNeighList(tf.test.TestCase): @@ -21,26 +20,29 @@ def setUp(self) -> None: self.nloc = 3 self.ns = 5 * 5 * 3 self.nall = self.ns * self.nloc - self.cell = tnp.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) - self.icoord = tnp.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) - self.atype = tnp.array([-1, 0, 1], dtype=tnp.int32) + self.cell = tf.constant( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=DTYPE + ) + self.icoord = tf.constant([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=DTYPE) + self.atype = tf.constant([-1, 0, 1], dtype=tf.int32) [self.cell, self.icoord, self.atype] = [ - tnp.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] + tf.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] ] self.coord = tf.reshape(inter2phys(self.icoord, self.cell), [-1, self.nloc * 3]) self.cell = tf.reshape(self.cell, [-1, 9]) [self.cell, self.coord, self.atype] = [ - tnp.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] + tf.tile(ii, [self.nf, 1]) for ii in [self.cell, self.coord, self.atype] ] self.rcut = 1.01 self.prec = 1e-10 self.nsel = [10, 10] - self.ref_nlist = 
tnp.array( + self.ref_nlist = tf.constant( [ [-1] * sum(self.nsel), [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], - ] + ], + dtype=tf.int64, ) def test_build_notype(self) -> None: @@ -58,10 +60,14 @@ def test_build_notype(self) -> None: self.assertAllClose(nlist[0], nlist[1]) nlist_mask = nlist[0] == -1 nlist_loc = tf.gather(mapping[0], tf.where(nlist_mask, 0, nlist[0])) - nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + nlist_loc = tf.where( + nlist_mask, + tf.fill(tf.shape(nlist_loc), tf.cast(-1, nlist_loc.dtype)), + nlist_loc, + ) self.assertAllClose( - tnp.sort(nlist_loc, axis=-1), - tnp.sort(self.ref_nlist, axis=-1), + tf.sort(nlist_loc, axis=-1), + tf.sort(self.ref_nlist, axis=-1), ) def test_build_type(self) -> None: @@ -79,11 +85,15 @@ def test_build_type(self) -> None: self.assertAllClose(nlist[0], nlist[1]) nlist_mask = nlist[0] == -1 nlist_loc = tf.gather(mapping[0], tf.where(nlist_mask, 0, nlist[0])) - nlist_loc = tnp.where(nlist_mask, tnp.full_like(nlist_loc, -1), nlist_loc) + nlist_loc = tf.where( + nlist_mask, + tf.fill(tf.shape(nlist_loc), tf.cast(-1, nlist_loc.dtype)), + nlist_loc, + ) for ii in range(2): self.assertAllClose( - tnp.sort(tnp.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), - tnp.sort(tnp.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), + tf.sort(tf.split(nlist_loc, self.nsel, axis=-1)[ii], axis=-1), + tf.sort(tf.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), ) def test_extend_coord(self) -> None: @@ -105,7 +115,7 @@ def test_extend_coord(self) -> None: ) shift_vec = tf.reshape(shift_vec, [-1, self.nall, 3]) # hack!!! 
assumes identical cell across frames - shift_vec = tnp.matmul( + shift_vec = tf.matmul( shift_vec, tf.linalg.inv(tf.reshape(self.cell, [self.nf, 3, 3])[0]) ) # nf x nall x 3 @@ -115,40 +125,40 @@ def test_extend_coord(self) -> None: # check: shift idx aligned with grid mm, _, cc = tf.unique_with_counts(shift_vec[0][:, 0]) self.assertAllClose( - tnp.sort(mm), - tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + tf.sort(mm), + tf.constant([-2, -1, 0, 1, 2], dtype=DTYPE), rtol=self.prec, atol=self.prec, ) self.assertAllClose( cc, - tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + tf.constant([self.ns * self.nloc // 5] * 5, dtype=tf.int32), rtol=self.prec, atol=self.prec, ) mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 1]) self.assertAllClose( - tnp.sort(mm), - tnp.array([-2, -1, 0, 1, 2], dtype=dtype), + tf.sort(mm), + tf.constant([-2, -1, 0, 1, 2], dtype=DTYPE), rtol=self.prec, atol=self.prec, ) self.assertAllClose( cc, - tnp.array([self.ns * self.nloc // 5] * 5, dtype=tnp.int32), + tf.constant([self.ns * self.nloc // 5] * 5, dtype=tf.int32), rtol=self.prec, atol=self.prec, ) mm, _, cc = tf.unique_with_counts(shift_vec[1][:, 2]) self.assertAllClose( - tnp.sort(mm), - tnp.array([-1, 0, 1], dtype=dtype), + tf.sort(mm), + tf.constant([-1, 0, 1], dtype=DTYPE), rtol=self.prec, atol=self.prec, ) self.assertAllClose( cc, - tnp.array([self.ns * self.nloc // 3] * 3, dtype=tnp.int32), + tf.constant([self.ns * self.nloc // 3] * 3, dtype=tf.int32), rtol=self.prec, atol=self.prec, ) diff --git a/source/jax2tf_tests/test_region.py b/source/jax2tf_tests/test_region.py index c7bd182889..9ed03fba53 100644 --- a/source/jax2tf_tests/test_region.py +++ b/source/jax2tf_tests/test_region.py @@ -2,7 +2,6 @@ import tensorflow as tf -import tensorflow.experimental.numpy as tnp from deepmd.jax.jax2tf.region import ( inter2phys, @@ -10,24 +9,25 @@ ) GLOBAL_SEED = 20241109 +DTYPE = tf.float64 class TestRegion(tf.test.TestCase): def setUp(self) -> None: - self.cell = tnp.array( - [[1, 
0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], + self.cell = tf.constant( + [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=DTYPE ) - self.cell = tnp.reshape(self.cell, [1, 1, -1, 3]) - self.cell = tnp.tile(self.cell, [4, 5, 1, 1]) + self.cell = tf.reshape(self.cell, [1, 1, -1, 3]) + self.cell = tf.tile(self.cell, [4, 5, 1, 1]) self.prec = 1e-8 def test_inter_to_phys(self) -> None: rng = tf.random.Generator.from_seed(GLOBAL_SEED) - inter = rng.normal(shape=[4, 5, 3, 3]) + inter = rng.normal(shape=[4, 5, 3, 3], dtype=DTYPE) phys = inter2phys(inter, self.cell) for ii in range(4): for jj in range(5): - expected_phys = tnp.matmul(inter[ii, jj], self.cell[ii, jj]) + expected_phys = tf.matmul(inter[ii, jj], self.cell[ii, jj]) self.assertAllClose( phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec ) @@ -36,14 +36,14 @@ def test_to_face_dist(self) -> None: cell0 = self.cell[0][0] vol = tf.linalg.det(cell0) # area of surfaces xy, xz, yz - sxy = tf.linalg.norm(tnp.cross(cell0[0], cell0[1])) - sxz = tf.linalg.norm(tnp.cross(cell0[0], cell0[2])) - syz = tf.linalg.norm(tnp.cross(cell0[1], cell0[2])) + sxy = tf.linalg.norm(tf.linalg.cross(cell0[0], cell0[1])) + sxz = tf.linalg.norm(tf.linalg.cross(cell0[0], cell0[2])) + syz = tf.linalg.norm(tf.linalg.cross(cell0[1], cell0[2])) # vol / area gives distance dz = vol / sxy dy = vol / sxz dx = vol / syz - expected = tnp.array([dx, dy, dz]) + expected = tf.stack([dx, dy, dz]) dists = to_face_distance(self.cell) for ii in range(4): for jj in range(5): From 0b3809c166a033ccc88ac122e616d240b0cc87af Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 30 Jun 2026 16:23:21 +0800 Subject: [PATCH 4/4] test(jax): cover jax2tf savedmodel paths --- source/jax2tf_tests/test_make_model.py | 78 +++++++++++++++ source/jax2tf_tests/test_nlist.py | 29 ++++++ source/jax2tf_tests/test_region.py | 22 +++++ source/jax2tf_tests/test_serialization.py | 114 ++++++++++++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 
source/jax2tf_tests/test_make_model.py create mode 100644 source/jax2tf_tests/test_serialization.py diff --git a/source/jax2tf_tests/test_make_model.py b/source/jax2tf_tests/test_make_model.py new file mode 100644 index 0000000000..18e830c162 --- /dev/null +++ b/source/jax2tf_tests/test_make_model.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.jax.jax2tf.make_model import ( + model_call_from_call_lower, +) + +DTYPE = tf.float64 + + +class TestMakeModel(tf.test.TestCase): + def setUp(self) -> None: + self.output_def = ModelOutputDef( + FittingOutputDef([OutputVariableDef("coord_x", [1])]) + ) + + @tf.function( + input_signature=[ + tf.TensorSpec([None, None, 3], DTYPE), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, None], DTYPE), + ] + ) + def call_model( + self, + coord: tf.Tensor, + atype: tf.Tensor, + box: tf.Tensor, + ) -> tf.Tensor: + def call_lower( + extended_coord: tf.Tensor, + extended_atype: tf.Tensor, + nlist: tf.Tensor, + mapping: tf.Tensor, + fparam: tf.Tensor, + aparam: tf.Tensor, + ) -> dict[str, tf.Tensor]: + del extended_atype, nlist, mapping, fparam, aparam + return {"coord_x": extended_coord[..., :1]} + + nframes = tf.shape(coord)[0] + nloc = tf.shape(atype)[1] + ret = model_call_from_call_lower( + call_lower=call_lower, + rcut=0.4, + sel=[1], + mixed_types=True, + model_output_def=self.output_def, + coord=coord, + atype=atype, + box=box, + fparam=tf.zeros([nframes, 0], dtype=DTYPE), + aparam=tf.zeros([nframes, nloc, 0], dtype=DTYPE), + ) + return ret["coord_x"] + + def test_model_call_without_box(self) -> None: + coord = tf.constant([[[0.2, 0.0, 0.0], [0.8, 0.0, 0.0]]], dtype=DTYPE) + atype = tf.constant([[0, 1]], dtype=tf.int32) + box = tf.zeros([1, 0, 0], dtype=DTYPE) + + coord_x = self.call_model(coord, atype, box) + + self.assertAllClose(coord_x, coord[..., 
:1]) + + def test_model_call_with_box_normalizes_coord(self) -> None: + coord = tf.constant([[[1.2, 0.0, 0.0], [-0.2, 0.0, 0.0]]], dtype=DTYPE) + atype = tf.constant([[0, 1]], dtype=tf.int32) + box = tf.eye(3, batch_shape=[1], dtype=DTYPE) + + coord_x = self.call_model(coord, atype, box) + + self.assertAllClose(coord_x[:, :2], [[[0.2], [0.8]]]) diff --git a/source/jax2tf_tests/test_nlist.py b/source/jax2tf_tests/test_nlist.py index 363ba60ff3..9e5c035c08 100644 --- a/source/jax2tf_tests/test_nlist.py +++ b/source/jax2tf_tests/test_nlist.py @@ -96,6 +96,35 @@ def test_build_type(self) -> None: tf.sort(tf.split(self.ref_nlist, self.nsel, axis=-1)[ii], axis=-1), ) + def test_build_pad(self) -> None: + coord = tf.constant([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], dtype=DTYPE) + atype = tf.constant([[0, 0]], dtype=tf.int32) + + nlist = build_neighbor_list( + coord, + atype, + nloc=2, + rcut=1.0, + sel=3, + distinguish_types=False, + ) + + expected = tf.constant([[[1, -1, -1], [0, -1, -1]]], dtype=tf.int64) + self.assertAllEqual(nlist, expected) + + def test_extend_coord_empty_cell(self) -> None: + coord = tf.constant([[[0.0, 0.0, 0.0], [0.5, 0.0, 0.0]]], dtype=DTYPE) + atype = tf.constant([[1, 0]], dtype=tf.int32) + empty_cell = tf.zeros([1, 0], dtype=DTYPE) + + ecoord, eatype, mapping = extend_coord_with_ghosts( + coord, atype, empty_cell, self.rcut + ) + + self.assertAllClose(ecoord, tf.reshape(coord, [1, 6])) + self.assertAllEqual(eatype, atype) + self.assertAllEqual(mapping, tf.constant([[0, 1]], dtype=tf.int64)) + def test_extend_coord(self) -> None: ecoord, eatype, mapping = extend_coord_with_ghosts( self.coord, self.atype, self.cell, self.rcut diff --git a/source/jax2tf_tests/test_region.py b/source/jax2tf_tests/test_region.py index 9ed03fba53..df042d9068 100644 --- a/source/jax2tf_tests/test_region.py +++ b/source/jax2tf_tests/test_region.py @@ -5,6 +5,8 @@ from deepmd.jax.jax2tf.region import ( inter2phys, + normalize_coord, + phys2inter, to_face_distance, ) @@ 
-32,6 +34,26 @@ def test_inter_to_phys(self) -> None: phys[ii, jj], expected_phys, rtol=self.prec, atol=self.prec ) + def test_phys_to_inter(self) -> None: + rng = tf.random.Generator.from_seed(GLOBAL_SEED) + inter = rng.normal(shape=[4, 5, 3, 3], dtype=DTYPE) + phys = inter2phys(inter, self.cell) + actual_inter = phys2inter(phys, self.cell) + self.assertAllClose(actual_inter, inter, rtol=self.prec, atol=self.prec) + + def test_normalize_coord(self) -> None: + inter = tf.constant( + [[[[1.2, -0.3, 0.5], [0.25, 1.75, -1.25], [-0.5, 0.5, 2.0]]]], + dtype=DTYPE, + ) + cell = self.cell[:1, :1] + coord = inter2phys(inter, cell) + expected = inter2phys(tf.math.floormod(inter, tf.constant(1.0, DTYPE)), cell) + + actual = normalize_coord(coord, cell) + + self.assertAllClose(actual, expected, rtol=self.prec, atol=self.prec) + def test_to_face_dist(self) -> None: cell0 = self.cell[0][0] vol = tf.linalg.det(cell0) diff --git a/source/jax2tf_tests/test_serialization.py b/source/jax2tf_tests/test_serialization.py new file mode 100644 index 0000000000..cc352f33df --- /dev/null +++ b/source/jax2tf_tests/test_serialization.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) + +import pytest +from tensorflow.core.protobuf import ( + saved_model_pb2, +) + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) + + +def _saved_model_ops(model_dir: Path) -> set[str]: + saved_model = saved_model_pb2.SavedModel() + saved_model.ParseFromString((model_dir / "saved_model.pb").read_bytes()) + ops = set() + for meta_graph in saved_model.meta_graphs: + ops.update(node.op for node in meta_graph.graph_def.node) + for func in meta_graph.graph_def.library.function: + ops.update(node.op for node in func.node_def) + return ops + + +def test_savedmodel_export_contains_xla_call_module(tmp_path, monkeypatch) -> None: + pytest.importorskip("jax") + pytest.importorskip("flax") + 
pytest.importorskip("orbax.checkpoint") + + import jax.numpy as jnp + + from deepmd.jax.jax2tf import ( + serialization, + ) + + class DummyModel: + def call_common_lower( + self, + coord, + atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial: bool = False, + ): + del nlist, mapping, fparam, aparam, do_atomic_virial + return { + "coord_x": coord[..., :1] + + jnp.asarray(atype[..., None], dtype=coord.dtype) * 0.0 + } + + def get_nnei(self) -> int: + return 1 + + def get_rcut(self) -> float: + return 1.0 + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + def get_sel(self) -> list[int]: + return [1] + + def mixed_types(self) -> bool: + return True + + def model_output_def(self) -> ModelOutputDef: + return ModelOutputDef(FittingOutputDef([OutputVariableDef("coord_x", [1])])) + + def get_type_map(self) -> list[str]: + return ["O"] + + def get_sel_type(self) -> list[int]: + return [] + + def is_aparam_nall(self) -> bool: + return False + + def model_output_type(self) -> list[str]: + return ["coord_x"] + + def get_min_nbor_dist(self) -> None: + return None + + def has_message_passing(self) -> bool: + return False + + def has_default_fparam(self) -> bool: + return False + + def get_default_fparam(self) -> None: + return None + + monkeypatch.setattr( + serialization.BaseModel, + "deserialize", + staticmethod(lambda data: DummyModel()), + ) + + model_dir = tmp_path / "dummy.savedmodel" + serialization.deserialize_to_file( + str(model_dir), + {"model": {"type": "dummy"}, "model_def_script": {"type": "dummy"}}, + ) + + assert "XlaCallModule" in _saved_model_ops(model_dir)