Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 76 additions & 12 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,88 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Compatibility wrappers for TensorFlow neighbor-list formatting."""
"""TensorFlow graph helpers for JAX/jax2tf SavedModel export.

from typing import (
Any,
)
This module is not a generic TF2 compatibility wrapper. The functions here are
traced while saving the JAX ``.savedmodel`` artifact, before control reaches
``jax2tf.convert``. Keep the implementation in plain TensorFlow ops so
AutoGraph can see the tensor-dependent branches and emit graph control flow.
Routing through ndtensorflow/dpmodel helpers can leave symbolic shape
comparisons as Python ``if`` statements during SavedModel tracing.
"""

import tensorflow as tf

from deepmd.tf2.common import (
to_tf_tensor,
)
from deepmd.tf2.utils._dpmodel import format_nlist as tf2_format_nlist

__all__ = ["format_nlist"]
def _mask_out_of_cutoff(
extended_coord: tf.Tensor,
nlist: tf.Tensor,
rcut: float,
) -> tf.Tensor:
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
m_real_nei = nlist >= 0
real_nlist = tf.where(m_real_nei, nlist, tf.zeros_like(nlist))
coord0 = extended_coord[:, :n_nloc, :]
index = tf.reshape(real_nlist, [n_nf, n_nloc * n_nsel])
coord1 = tf.gather(extended_coord, index, batch_dims=1)
coord1 = tf.reshape(coord1, [n_nf, n_nloc, n_nsel, 3])
rr2 = tf.reduce_sum(tf.square(coord0[:, :, None, :] - coord1), axis=-1)
return tf.where(
tf.logical_and(m_real_nei, rr2 > tf.cast(rcut * rcut, rr2.dtype)),
tf.fill(tf.shape(nlist), tf.cast(-1, nlist.dtype)),
nlist,
)


@tf.function(autograph=True)
def format_nlist(
extended_coord: Any,
nlist: Any,
extended_coord: tf.Tensor,
nlist: tf.Tensor,
nsel: int,
rcut: float,
) -> tf.Tensor:
return to_tf_tensor(tf2_format_nlist(extended_coord, nlist, nsel, rcut))
"""Format neighbor list.

If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = tf.reshape(extended_coord, [n_nf, -1, 3])

if n_nsel < nsel:
ret = tf.concat(
[
nlist,
tf.fill([n_nf, n_nloc, nsel - n_nsel], tf.cast(-1, nlist.dtype)),
],
axis=-1,
)
ret = _mask_out_of_cutoff(extended_coord, ret, rcut)
elif n_nsel > nsel:
m_real_nei = nlist >= 0
ret = tf.where(m_real_nei, nlist, tf.zeros_like(nlist))
coord0 = extended_coord[:, :n_nloc, :]
index = tf.reshape(ret, [n_nf, n_nloc * n_nsel])
coord1 = tf.gather(extended_coord, index, batch_dims=1)
coord1 = tf.reshape(coord1, [n_nf, n_nloc, n_nsel, 3])
rr2 = tf.reduce_sum(tf.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tf.where(
m_real_nei,
rr2,
tf.fill(tf.shape(rr2), tf.constant(float("inf"), rr2.dtype)),
)
ret_mapping = tf.argsort(rr2, axis=-1)
rr2 = tf.sort(rr2, axis=-1)
ret = tf.gather(ret, ret_mapping, batch_dims=2)
ret = tf.where(
rr2 > rcut * rcut,
tf.fill(tf.shape(ret), tf.cast(-1, ret.dtype)),
ret,
)
ret = ret[..., :nsel]
else:
ret = _mask_out_of_cutoff(extended_coord, nlist, rcut)
# Reshape anyway; this tells XLA the shape without dynamic shape.
ret = tf.reshape(ret, [n_nf, n_nloc, nsel])
return ret
113 changes: 66 additions & 47 deletions deepmd/jax/jax2tf/make_model.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,47 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Compatibility wrappers for TensorFlow model-call helpers."""
"""Outer TensorFlow call wrapper for the JAX/jax2tf SavedModel.

The wrapper builds PBC ghosts, neighbor lists, and output communication around
the lower JAX model. It deliberately uses the graph-safe helpers in this
package instead of the TF2 eager helpers, because this code is traced by
``tf.saved_model.save`` and must keep tensor-shape branches convertible by
AutoGraph before it invokes the jax2tf-converted model body.
"""

from collections.abc import (
Callable,
)
from typing import (
Any,
)

import tensorflow as tf

from deepmd.dpmodel.output_def import (
ModelOutputDef,
)
from deepmd.tf2.common import (
to_tf_tensor,
unwrap_value,
wrap_value,
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)
from deepmd.tf2.make_model import (
model_call_from_call_lower as tf2_model_call_from_call_lower,
from deepmd.jax.jax2tf.region import (
normalize_coord,
)
from deepmd.jax.jax2tf.transform_output import (
communicate_extended_output,
)

__all__ = ["model_call_from_call_lower"]


def _wrap_call_lower(call_lower: Callable[..., dict[str, Any]]) -> Callable:
def wrapped_call_lower(
extended_coord: Any,
extended_atype: Any,
nlist: Any,
mapping: Any,
**kwargs: Any,
) -> dict[str, Any]:
return wrap_value(
call_lower(
to_tf_tensor(extended_coord),
to_tf_tensor(extended_atype),
to_tf_tensor(nlist),
to_tf_tensor(mapping),
**{kk: to_tf_tensor(vv) for kk, vv in kwargs.items()},
)
)

return wrapped_call_lower


def model_call_from_call_lower(
*, # enforce keyword-only arguments
call_lower: Callable[..., dict[str, Any]],
call_lower: Callable[
[
tf.Tensor,
tf.Tensor,
tf.Tensor,
tf.Tensor,
tf.Tensor,
bool,
],
dict[str, tf.Tensor],
],
rcut: float,
sel: list[int],
mixed_types: bool,
Expand All @@ -60,18 +53,44 @@ def model_call_from_call_lower(
aparam: tf.Tensor,
do_atomic_virial: bool = False,
) -> dict[str, tf.Tensor]:
return unwrap_value(
tf2_model_call_from_call_lower(
call_lower=_wrap_call_lower(call_lower),
rcut=rcut,
sel=sel,
mixed_types=mixed_types,
model_output_def=model_output_def,
coord=coord,
atype=atype,
box=box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
"""Return model prediction from lower interface."""
atype_shape = tf.shape(atype)
nframes, nloc = atype_shape[0], atype_shape[1]
cc, bb, fp, ap = coord, box, fparam, aparam
del coord, box, fparam, aparam
if tf.shape(bb)[-1] != 0:
coord_normalized = normalize_coord(
tf.reshape(cc, [nframes, nloc, 3]),
tf.reshape(bb, [nframes, 3, 3]),
)
else:
coord_normalized = cc
extended_coord, extended_atype, mapping = extend_coord_with_ghosts(
coord_normalized, atype, bb, rcut
)
nlist = build_neighbor_list(
extended_coord,
extended_atype,
nloc,
rcut,
sel,
# types will be distinguished in the lower interface, so it doesn't
# need to be distinguished here
distinguish_types=False,
)
extended_coord = tf.reshape(extended_coord, [nframes, -1, 3])
model_predict_lower = call_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fp,
aparam=ap,
)
model_predict = communicate_extended_output(
model_predict_lower,
model_output_def,
mapping,
do_atomic_virial=do_atomic_virial,
)
return model_predict
Loading
Loading