Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 56 additions & 10 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,19 +848,65 @@ def test_polar(
if detail_file is not None:
detail_path = Path(detail_file)

pe = np.concatenate(
(
np.reshape(test_data["polarizability"][:numb_test], [-1, 9]),
np.reshape(polar, [-1, 9]),
),
axis=1,
)
if not atomic:
pe = np.concatenate(
(
np.reshape(test_data["polarizability"][:numb_test], [-1, 9]),
np.reshape(polar, [-1, 9]),
),
axis=1,
)
header_text = (
"data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx "
"data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy "
"pred_pyz pred_pzx pred_pzy pred_pzz"
)
else:
pe = np.concatenate(
(
np.reshape(
test_data["atomic_polarizability"][:numb_test],
[-1, 9 * sel_natoms],
),
np.reshape(polar, [-1, 9 * sel_natoms]),
),
axis=1,
)
header_text = [
f"{letter}{number}"
for number in range(1, sel_natoms + 1)
for letter in [
"data_pxx",
"data_pxy",
"data_pxz",
"data_pyx",
"data_pyy",
"data_pyz",
"data_pzx",
"data_pzy",
"data_pzz",
]
] + [
f"{letter}{number}"
for number in range(1, sel_natoms + 1)
for letter in [
"pred_pxx",
"pred_pxy",
"pred_pxz",
"pred_pyx",
"pred_pyy",
"pred_pyz",
"pred_pzx",
"pred_pzy",
"pred_pzz",
]
]
header_text = " ".join(header_text)

np.savetxt(
detail_path.with_suffix(".out"),
pe,
header="data_pxx data_pxy data_pxz data_pyx data_pyy data_pyz data_pzx "
"data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz "
"pred_pzx pred_pzy pred_pzz",
header=header_text,
)
return {"rmse": (rmse_f, polar.size)}

Expand Down
Loading