diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 9ab141bdc2..e0d9b39f57 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -93,7 +93,7 @@ def call( natoms: int, model_dict: dict[str, Array], label_dict: dict[str, Array], - ) -> dict[str, Array]: + ) -> tuple[Array, dict[str, Array]]: """Calculate loss from model results and labeled results.""" energy = model_dict["energy"] force = model_dict["force"] @@ -244,10 +244,10 @@ def call( if self.has_gf: find_drdq = label_dict["find_drdq"] drdq = label_dict["drdq"] - force_reshape_nframes = xp.reshape(force, (-1, natoms[0] * 3)) - force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms[0] * 3)) + force_reshape_nframes = xp.reshape(force, (-1, natoms * 3)) + force_hat_reshape_nframes = xp.reshape(force_hat, (-1, natoms * 3)) drdq_reshape = xp.reshape( - drdq, (-1, natoms[0] * 3, self.numb_generalized_coord) + drdq, (-1, natoms * 3, self.numb_generalized_coord) ) gen_force_hat = xp.einsum( "bij,bi->bj", drdq_reshape, force_hat_reshape_nframes diff --git a/deepmd/dpmodel/loss/loss.py b/deepmd/dpmodel/loss/loss.py index 4b9831c344..1037c1e469 100644 --- a/deepmd/dpmodel/loss/loss.py +++ b/deepmd/dpmodel/loss/loss.py @@ -28,8 +28,16 @@ def call( natoms: int, model_dict: dict[str, Array], label_dict: dict[str, Array], - ) -> dict[str, Array]: - """Calculate loss from model results and labeled results.""" + ) -> tuple[Array, dict[str, Array]]: + """Calculate loss from model results and labeled results. + + Returns + ------- + loss + The scalar loss to minimize. + more_loss + A dictionary of additional loss terms/metrics for logging. + """ @property @abstractmethod