diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 3091fd56ca..c8fd4e86e8 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -27,8 +27,12 @@ from deepmd.tf.loss.tensor import ( TensorLoss, ) +from deepmd.tf.utils.errors import ( + GraphWithoutTensorError, +) from deepmd.tf.utils.graph import ( get_fitting_net_variables_from_graph_def, + get_tensor_by_name_from_graph, ) from deepmd.tf.utils.network import ( one_layer, @@ -429,6 +433,16 @@ def build( atype = input_dict.get("atype", None) nframes = input_dict.get("nframes") start_index = 0 + + with tf.variable_scope("fitting_attr" + suffix, reuse=reuse): + self.t_constant_matrix = tf.get_variable( + "t_constant_matrix", + self.constant_matrix.shape, + dtype=GLOBAL_TF_FLOAT_PRECISION, + trainable=False, + initializer=tf.constant_initializer(self.constant_matrix), + ) + inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]) rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) if nframes is None: @@ -452,7 +466,9 @@ def build( # nframes x nloc_masked constant_matrix = tf.reshape( tf.reshape( - tf.tile(tf.repeat(self.constant_matrix, natoms[2:]), [nframes]), + tf.tile( + tf.repeat(self.t_constant_matrix, natoms[2:]), [nframes] + ), [nframes, -1], )[nloc_mask], [nframes, -1], @@ -504,7 +520,9 @@ def build( # shift and scale sel_type_idx = self.sel_type.index(type_i) final_layer = final_layer * self.scale[sel_type_idx] - final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye( + final_layer = final_layer + tf.slice( + self.t_constant_matrix, [sel_type_idx], [1] + ) * tf.eye( 3, batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]], dtype=GLOBAL_TF_FLOAT_PRECISION, @@ -551,6 +569,16 @@ def init_variables( self.fitting_net_variables = get_fitting_net_variables_from_graph_def( graph_def, suffix=suffix ) + if self.shift_diag: + try: + self.constant_matrix = get_tensor_by_name_from_graph( + graph, f"fitting_attr{suffix}/t_constant_matrix" + ) + except GraphWithoutTensorError: + warnings.warn( + "You are trying to read a model trained with shift_diag=True, but the mean of the diagonal terms of the polarizability is not stored in the graph. This will lead to wrong inference results. You may train your model with the latest DeePMD-kit to avoid this issue.", + stacklevel=2, + ) def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None: """Receive the mixed precision setting. @@ -612,6 +640,15 @@ def serialize(self, suffix: str) -> dict: variables=self.fitting_net_variables, suffix=suffix, ), + "@variables": { + "fparam_avg": None, + "fparam_inv_std": None, + "aparam_avg": None, + "aparam_inv_std": None, + "case_embd": None, + "scale": self.scale.reshape(-1, 1), + "constant_matrix": self.constant_matrix.reshape(-1), + }, "type_map": self.type_map, } return data @@ -639,6 +676,7 @@ def deserialize(cls, data: dict, suffix: str): data["nets"], suffix=suffix, ) + fitting.constant_matrix = data["@variables"]["constant_matrix"].ravel() return fitting diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 7ee8477a3b..35230e2140 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -354,6 +354,9 @@ def test_tf_consistent_with_ref(self) -> None: data1.pop("@version") data2.pop("@version") + if tf_obj.__class__.__name__.startswith("Polar"): + data1["@variables"].pop("bias_atom_e") + np.testing.assert_equal(data1, data2) for rr1, rr2 in zip(ret1, ret2): np.testing.assert_allclose(