From e214ee9e62a79195d9038c2c99239ba64af39437 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Mon, 6 Jan 2025 19:52:18 +0800 Subject: [PATCH] fix bug in tf tensor loss --- deepmd/tf/loss/tensor.py | 4 ++-- deepmd/tf/train/trainer.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/deepmd/tf/loss/tensor.py b/deepmd/tf/loss/tensor.py index d7f879b4b4..dbd7da99df 100644 --- a/deepmd/tf/loss/tensor.py +++ b/deepmd/tf/loss/tensor.py @@ -155,7 +155,7 @@ def label_requirement(self) -> list[DataRequirementItem]: # data required data_requirements.append( DataRequirementItem( - "atom_" + self.label_name, + "atomic_" + self.label_name, self.tensor_size, atomic=True, must=False, @@ -176,7 +176,7 @@ def label_requirement(self) -> list[DataRequirementItem]: if self.enable_atomic_weight: data_requirements.append( DataRequirementItem( - "atom_weight", + "atomic_weight", 1, atomic=True, must=False, diff --git a/deepmd/tf/train/trainer.py b/deepmd/tf/train/trainer.py index 1fe72bc5e3..f8010091d8 100644 --- a/deepmd/tf/train/trainer.py +++ b/deepmd/tf/train/trainer.py @@ -281,6 +281,13 @@ def _build_network(self, data, suffix="") -> None: tf.int32, [None], name="t_mesh" ) self.place_holders["is_training"] = tf.placeholder(tf.bool) + # update "atomic_" in self.place_holders.keys() with "atom_" + for kk in list(self.place_holders.keys()): + if "atomic_" in kk: + self.place_holders[kk.replace("atomic_", "atom_")] = ( + self.place_holders.pop(kk) + ) + self.model_pred = self.model.build( self.place_holders["coord"], self.place_holders["type"],