Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions deepmd/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.utils.sess import run_sess
from .loss import Loss

class EnerStdLoss () :

class EnerStdLoss (Loss) :
"""
Standard loss function for DP models
"""
Expand Down Expand Up @@ -221,7 +223,7 @@ def print_on_training(self,
return print_str


class EnerDipoleLoss () :
class EnerDipoleLoss (Loss) :
def __init__ (self,
starter_learning_rate : float,
start_pref_e : float = 0.1,
Expand Down
59 changes: 59 additions & 0 deletions deepmd/loss/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from abc import ABCMeta, abstractmethod
from typing import Tuple, Dict
from deepmd.env import tf


class Loss(metaclass=ABCMeta):
"""The abstract class for the loss function."""
@abstractmethod
def build(self,
learning_rate: tf.Tensor,
natoms: tf.Tensor,
model_dict: Dict[str, tf.Tensor],
label_dict: Dict[str, tf.Tensor],
suffix: str) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
"""Build the loss function graph.

Parameters
----------
learning_rate : tf.Tensor
learning rate
natoms : tf.Tensor
number of atoms
model_dict : dict[str, tf.Tensor]
A dictionary that maps model keys to tensors
label_dict : dict[str, tf.Tensor]
A dictionary that maps label keys to tensors
suffix : str
suffix

Returns
-------
tf.Tensor
the total squared loss
dict[str, tf.Tensor]
A dictionary that maps loss keys to more loss tensors
"""

@abstractmethod
def eval(self,
sess: tf.Session,
feed_dict: Dict[tf.placeholder, tf.Tensor],
natoms: tf.Tensor) -> dict:
"""Eval the loss function.

Parameters
----------
sess : tf.Session
TensorFlow session
feed_dict : dict[tf.placeholder, tf.Tensor]
A dictionary that maps graph elements to values
natoms : tf.Tensor
number of atoms

Returns
-------
dict
A dictionary that maps keys to values. It
should contain key `natoms`
"""
4 changes: 3 additions & 1 deletion deepmd/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.utils.sess import run_sess
from .loss import Loss

class TensorLoss () :

class TensorLoss(Loss) :
"""
Loss function for tensorial properties.
"""
Expand Down