Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
200edb6
use tf.variable to store bias_atom_polar
Yi-FanLi Feb 6, 2025
ac4a4d5
read variable t_bias_atom_polar when shift_diag=True
Yi-FanLi Feb 6, 2025
bb4edf0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2025
489ac47
clean commented lines
Yi-FanLi Feb 6, 2025
dea0aba
Merge remote-tracking branch 'refs/remotes/origin/bias_atom_polar' in…
Yi-FanLi Feb 6, 2025
4296757
use self.t_bias_atom_polar for type_embedding
Yi-FanLi Feb 6, 2025
fc0f88f
fix syntax error
Yi-FanLi Feb 6, 2025
4390144
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2025
b794731
use warning when t_bias_atom_polar is not found
Yi-FanLi Feb 6, 2025
8259be7
Merge branch 'bias_atom_polar' of github.com:Yi-FanLi/deepmd-kit into…
Yi-FanLi Feb 6, 2025
22e1fd5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2025
e54c56c
remove unnecessary pass
Yi-FanLi Feb 6, 2025
e8bdc39
use stacklevel=2 for warning
Yi-FanLi Feb 6, 2025
36d4677
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2025
72c50f6
add bias_atom_polar in serialize and deserialize
Yi-FanLi Feb 6, 2025
059df07
Merge branch 'bias_atom_polar' of github.com:Yi-FanLi/deepmd-kit into…
Yi-FanLi Feb 6, 2025
4bcca99
rename constant_matrix to bias_atom_polar
Yi-FanLi Feb 6, 2025
5ca01b4
rename bias_atom_polar to constant_matrix to keep consistent with oth…
Yi-FanLi Feb 7, 2025
7312a43
pop bias_atom_e from PolarFitting of the dp backend
Yi-FanLi Feb 7, 2025
4dfb83a
remove remaining print
Yi-FanLi Feb 7, 2025
9b7ef69
pop bias_atom_e from variables for dpmodel in test_tf_consistent_with…
Yi-FanLi Feb 7, 2025
c9baf66
do not pop bias_atom_e from dpmodel
Yi-FanLi Feb 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions deepmd/tf/fit/polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
from deepmd.tf.loss.tensor import (
TensorLoss,
)
from deepmd.tf.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.tf.utils.graph import (
get_fitting_net_variables_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd.tf.utils.network import (
one_layer,
Expand Down Expand Up @@ -429,6 +433,16 @@ def build(
atype = input_dict.get("atype", None)
nframes = input_dict.get("nframes")
start_index = 0

with tf.variable_scope("fitting_attr" + suffix, reuse=reuse):
self.t_constant_matrix = tf.get_variable(
"t_constant_matrix",
self.constant_matrix.shape,
dtype=GLOBAL_TF_FLOAT_PRECISION,
trainable=False,
initializer=tf.constant_initializer(self.constant_matrix),
)

inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
if nframes is None:
Expand All @@ -452,7 +466,9 @@ def build(
# nframes x nloc_masked
constant_matrix = tf.reshape(
tf.reshape(
tf.tile(tf.repeat(self.constant_matrix, natoms[2:]), [nframes]),
tf.tile(
tf.repeat(self.t_constant_matrix, natoms[2:]), [nframes]
),
[nframes, -1],
)[nloc_mask],
[nframes, -1],
Expand Down Expand Up @@ -504,7 +520,9 @@ def build(
# shift and scale
sel_type_idx = self.sel_type.index(type_i)
final_layer = final_layer * self.scale[sel_type_idx]
final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(
final_layer = final_layer + tf.slice(
self.t_constant_matrix, [sel_type_idx], [1]
) * tf.eye(
3,
batch_shape=[tf.shape(inputs)[0], natoms[2 + type_i]],
dtype=GLOBAL_TF_FLOAT_PRECISION,
Expand Down Expand Up @@ -551,6 +569,16 @@ def init_variables(
self.fitting_net_variables = get_fitting_net_variables_from_graph_def(
graph_def, suffix=suffix
)
if self.shift_diag:
try:
Comment thread Fixed
self.constant_matrix = get_tensor_by_name_from_graph(
graph, f"fitting_attr{suffix}/t_constant_matrix"
)
except GraphWithoutTensorError:
warnings.warn(
"You are trying to read a model trained with shift_diag=True, but the mean of the diagonal terms of the polarizability is not stored in the graph. This will lead to wrong inference results. You may train your model with the latest DeePMD-kit to avoid this issue.",
stacklevel=2,
)

def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
"""Receive the mixed precision setting.
Expand Down Expand Up @@ -612,6 +640,15 @@ def serialize(self, suffix: str) -> dict:
variables=self.fitting_net_variables,
suffix=suffix,
),
"@variables": {
"fparam_avg": None,
"fparam_inv_std": None,
"aparam_avg": None,
"aparam_inv_std": None,
"case_embd": None,
"scale": self.scale.reshape(-1, 1),
"constant_matrix": self.constant_matrix.reshape(-1),
},
"type_map": self.type_map,
}
return data
Expand Down Expand Up @@ -639,6 +676,7 @@ def deserialize(cls, data: dict, suffix: str):
data["nets"],
suffix=suffix,
)
fitting.constant_matrix = data["@variables"]["constant_matrix"].ravel()
return fitting


Expand Down
3 changes: 3 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def test_tf_consistent_with_ref(self) -> None:
data1.pop("@version")
data2.pop("@version")

if tf_obj.__class__.__name__.startswith("Polar"):
data1["@variables"].pop("bias_atom_e")

np.testing.assert_equal(data1, data2)
for rr1, rr2 in zip(ret1, ret2):
np.testing.assert_allclose(
Expand Down