From 8ca67714475057b2e2d9a66ac59036d422a7a898 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 14 Mar 2023 18:38:19 -0400 Subject: [PATCH 1/3] =?UTF-8?q?atype=5Ffilter=20should=20have=20the=20shap?= =?UTF-8?q?e=20of=20nloc=E2=80=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepmd/fit/ener.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index d65a3a984b..81e296dff0 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -549,13 +549,13 @@ def build( aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) atype_nall = tf.reshape(atype, [-1, natoms[1]]) - atype_filter = tf.cast(atype_nall >= 0, GLOBAL_TF_FLOAT_PRECISION) - # prevent embedding_lookup error, - # but the filter will be applied anyway - atype_nall = tf.clip_by_value(atype_nall, 0, self.ntypes - 1) self.atype_nloc = tf.reshape( tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1] ) ## lammps will make error + atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) + # prevent embedding_lookup error, + # but the filter will be applied anyway + self.atype_nloc = tf.clip_by_value(self.atype_nloc, 0, self.ntypes - 1) if type_embedding is not None: atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc) else: From 82ba54f1f10718cf04788a009ec4165ec17e96dd Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 14 Mar 2023 18:46:54 -0400 Subject: [PATCH 2/3] make filter before reshape Signed-off-by: Jinzhe Zeng --- deepmd/fit/ener.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 81e296dff0..7754a95cb7 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -549,10 +549,9 @@ def build( aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) atype_nall = tf.reshape(atype, [-1, natoms[1]]) - self.atype_nloc = tf.reshape( - tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1] - ) ## lammps will make error + self.atype_nloc = tf.slice(atype_nall, [0, 0], [-1, natoms[0]]) ## lammps will make error atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) + self.atype_nloc = tf.reshape(self.atype_nloc, [-1]) # prevent embedding_lookup error, # but the filter will be applied anyway self.atype_nloc = tf.clip_by_value(self.atype_nloc, 0, self.ntypes - 1) From d9878d93670346a3a0d3b791dbed5b8f098a5ad7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Mar 2023 22:47:22 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/fit/ener.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 7754a95cb7..53c3b1d468 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -549,7 +549,9 @@ def build( aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) atype_nall = tf.reshape(atype, [-1, natoms[1]]) - self.atype_nloc = tf.slice(atype_nall, [0, 0], [-1, natoms[0]]) ## lammps will make error + self.atype_nloc = tf.slice( + atype_nall, [0, 0], [-1, natoms[0]] + ) ## lammps will make error atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) self.atype_nloc = tf.reshape(self.atype_nloc, [-1]) # prevent embedding_lookup error,