diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index d65a3a984b..53c3b1d468 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -549,13 +549,14 @@ def build( aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]]) atype_nall = tf.reshape(atype, [-1, natoms[1]]) - atype_filter = tf.cast(atype_nall >= 0, GLOBAL_TF_FLOAT_PRECISION) + self.atype_nloc = tf.slice( + atype_nall, [0, 0], [-1, natoms[0]] + ) ## lammps will make error + atype_filter = tf.cast(self.atype_nloc >= 0, GLOBAL_TF_FLOAT_PRECISION) + self.atype_nloc = tf.reshape(self.atype_nloc, [-1]) # prevent embedding_lookup error, # but the filter will be applied anyway - atype_nall = tf.clip_by_value(atype_nall, 0, self.ntypes - 1) - self.atype_nloc = tf.reshape( - tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1] - ) ## lammps will make error + self.atype_nloc = tf.clip_by_value(self.atype_nloc, 0, self.ntypes - 1) if type_embedding is not None: atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc) else: