From e68ed8f401cef890ebee9e3b9ae9d0e878846e5f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 May 2022 02:06:17 -0400 Subject: [PATCH] add Loss abstract class --- deepmd/loss/ener.py | 6 +++-- deepmd/loss/loss.py | 59 +++++++++++++++++++++++++++++++++++++++++++ deepmd/loss/tensor.py | 4 ++- 3 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 deepmd/loss/loss.py diff --git a/deepmd/loss/ener.py b/deepmd/loss/ener.py index 29e1fa4068..4c1b5622fd 100644 --- a/deepmd/loss/ener.py +++ b/deepmd/loss/ener.py @@ -5,8 +5,10 @@ from deepmd.env import global_cvt_2_tf_float from deepmd.env import global_cvt_2_ener_float from deepmd.utils.sess import run_sess +from .loss import Loss -class EnerStdLoss () : + +class EnerStdLoss (Loss) : """ Standard loss function for DP models """ @@ -221,7 +223,7 @@ def print_on_training(self, return print_str -class EnerDipoleLoss () : +class EnerDipoleLoss (Loss) : def __init__ (self, starter_learning_rate : float, start_pref_e : float = 0.1, diff --git a/deepmd/loss/loss.py b/deepmd/loss/loss.py new file mode 100644 index 0000000000..6ae9dc7399 --- /dev/null +++ b/deepmd/loss/loss.py @@ -0,0 +1,59 @@ +from abc import ABCMeta, abstractmethod +from typing import Tuple, Dict +from deepmd.env import tf + + +class Loss(metaclass=ABCMeta): + """The abstract class for the loss function.""" + @abstractmethod + def build(self, + learning_rate: tf.Tensor, + natoms: tf.Tensor, + model_dict: Dict[str, tf.Tensor], + label_dict: Dict[str, tf.Tensor], + suffix: str) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + """Build the loss function graph. + + Parameters + ---------- + learning_rate : tf.Tensor + learning rate + natoms : tf.Tensor + number of atoms + model_dict : dict[str, tf.Tensor] + A dictionary that maps model keys to tensors + label_dict : dict[str, tf.Tensor] + A dictionary that maps label keys to tensors + suffix : str + suffix + + Returns + ------- + tf.Tensor + the total squared loss + dict[str, tf.Tensor] + A dictionary that maps loss keys to more loss tensors + """ + + @abstractmethod + def eval(self, + sess: tf.Session, + feed_dict: Dict[tf.placeholder, tf.Tensor], + natoms: tf.Tensor) -> dict: + """Eval the loss function. + + Parameters + ---------- + sess : tf.Session + TensorFlow session + feed_dict : dict[tf.placeholder, tf.Tensor] + A dictionary that maps graph elements to values + natoms : tf.Tensor + number of atoms + + Returns + ------- + dict + A dictionary that maps keys to values. It + should contain key `natoms` + """ diff --git a/deepmd/loss/tensor.py b/deepmd/loss/tensor.py index de4dee6fa8..64763627a3 100644 --- a/deepmd/loss/tensor.py +++ b/deepmd/loss/tensor.py @@ -5,8 +5,10 @@ from deepmd.env import global_cvt_2_tf_float from deepmd.env import global_cvt_2_ener_float from deepmd.utils.sess import run_sess +from .loss import Loss -class TensorLoss () : + +class TensorLoss(Loss) : """ Loss function for tensorial properties. """