Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.

If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.

Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.

Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])
Comment thread
njzjz marked this conversation as resolved.

if n_nsel < nsel:
# make a copy before revise
ret = tnp.concatenate(
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:
# make a copy before revise
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]
else: # n_nsel == nsel:
ret = nlist
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
Expand Down Expand Up @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
Expand All @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand All @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand Down
91 changes: 91 additions & 0 deletions source/jax2tf_tests/test_format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)

GLOBAL_SEED = 20241110


class TestFormatNlist(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array(
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
)
self.icoord = tnp.array(
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
dtype=tnp.float64,
)
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(

Check notice

Code scanning / CodeQL

Unused local variable

Variable mapping is not used.
self.icoord, self.atype, self.cell, self.rcut
)
self.nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)

def test_format_nlist_equal(self):
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_less(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) - 5,
distinguish_types=False,
)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_large(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 5,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
Comment thread
njzjz marked this conversation as resolved.
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

def test_format_nlist_larger_rcut(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut * 2,
40,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))