From 8cb3ca18ae4a1e641bb8edccfe6163d62a5b6de2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 18 Feb 2022 01:06:42 -0500 Subject: [PATCH] recover input stats from frozen models Before, only NN parameters were recovered. --- deepmd/descriptor/loc_frame.py | 18 +++++++++++++++++- deepmd/descriptor/se.py | 4 +++- deepmd/train/trainer.py | 5 ++--- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/deepmd/descriptor/loc_frame.py b/deepmd/descriptor/loc_frame.py index 3a178ff494..b85e70823a 100644 --- a/deepmd/descriptor/loc_frame.py +++ b/deepmd/descriptor/loc_frame.py @@ -8,6 +8,7 @@ from deepmd.env import default_tf_session_config from deepmd.utils.sess import run_sess from .descriptor import Descriptor +from deepmd.utils.graph import get_tensor_by_name @Descriptor.register("loc_frame") class DescrptLocFrame (Descriptor) : @@ -367,4 +368,19 @@ def _compute_dstats_sys_nonsmth (self, def _compute_std (self,sumv2, sumv, sumn) : return np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn)) - + def init_variables(self, + model_file : str, + suffix : str = "", + ) -> None: + """ + Init the embedding net variables with the given frozen model + + Parameters + ---------- + model_file : str + The input frozen model file + suffix : str, optional + The suffix of the scope + """ + self.davg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_avg' % suffix) + self.tavg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_std' % suffix) diff --git a/deepmd/descriptor/se.py b/deepmd/descriptor/se.py index c42a7fb46d..832dcfcd58 100644 --- a/deepmd/descriptor/se.py +++ b/deepmd/descriptor/se.py @@ -1,7 +1,7 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.utils.graph import get_embedding_net_variables +from deepmd.utils.graph import get_embedding_net_variables, get_tensor_by_name from .descriptor import Descriptor @@ -106,6 +106,8 @@ def init_variables(self, The suffix of the scope """ self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix) + self.davg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_avg' % suffix) + self.tavg = get_tensor_by_name(model_file, 'descrpt_attr%s/t_std' % suffix) @property def precision(self) -> tf.DType: diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 2b3f8a249c..6009759781 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -282,10 +282,9 @@ def build (self, )) self.type_map = data.get_type_map() self.batch_size = data.get_batch_size() - if self.run_opt.init_mode not in ('init_from_model', 'restart'): + if self.run_opt.init_mode not in ('init_from_model', 'restart', 'init_from_frz_model'): # self.saver.restore (in self._init_session) will restore avg and std variables, so data_stat is useless - # currently init_from_frz_model does not restore data_stat variables - # TODO: restore avg and std in the init_from_frz_model mode + # init_from_frz_model will restore data_stat variables in `init_variables` method log.info("data stating... (this step may take long time)") self.model.data_stat(data)