Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dpdata/formats/deepmd/comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def dump(folder, data, set_size=5000, comp_prec=np.float32, remove_sets=True):
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/npy format."
)
continue
if nframes > 0 and np.asarray(data[dtype.name]).size == 0:
# an optional frame property (e.g. forces/virials when
# cal_force/cal_stress is disabled) may be empty while the
# system still has frames. Skip it instead of writing a
# meaningless (nframes, 0) array that cannot be reshaped on load.
continue
Comment thread
coderabbitai[bot] marked this conversation as resolved.
ddata = np.reshape(data[dtype.name], [nframes, -1])
if np.issubdtype(ddata.dtype, np.floating):
ddata = ddata.astype(comp_prec)
Expand Down
6 changes: 6 additions & 0 deletions dpdata/formats/deepmd/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,11 @@ def dump(folder, data):
f"Shape of {dtype.name} is not (nframes, ...), but {dtype.shape}. This type of data will not converted to deepmd/raw format."
)
continue
if nframes > 0 and np.asarray(data[dtype.name]).size == 0:
# an optional frame property (e.g. forces/virials when
# cal_force/cal_stress is disabled) may be empty while the
# system still has frames. Skip it instead of writing a
# meaningless (nframes, 0) array that cannot be reshaped on load.
continue
Comment thread
coderabbitai[bot] marked this conversation as resolved.
ddata = np.reshape(data[dtype.name], [nframes, -1])
np.savetxt(os.path.join(folder, f"{dtype.deepmd_name}.raw"), ddata)
29 changes: 29 additions & 0 deletions tests/test_abacus_pw_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import shutil
import tempfile
import unittest

import numpy as np
Expand Down Expand Up @@ -163,6 +164,34 @@ def test_noforcestress_job(self):
# test append self
system_ch4.append(system_ch4)

def test_noforcestress_deepmd_roundtrip(self):
# a converged scf without force/stress should survive a
# round-trip through deepmd/npy without raising a reshape error
system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf")
tmp_dir = tempfile.mkdtemp()
try:
system_ch4.to("deepmd/npy", tmp_dir)
reloaded = dpdata.LabeledSystem(tmp_dir, fmt="deepmd/npy")
self.assertEqual(reloaded.get_nframes(), system_ch4.get_nframes())
# empty force/virial should not be written as bogus data
self.assertFalse(reloaded.data.get("forces", np.empty(0)).size)
self.assertTrue("virials" not in reloaded.data)
finally:
shutil.rmtree(tmp_dir)

def test_noforcestress_deepmd_raw_roundtrip(self):
# same as above but for the deepmd/raw format
system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf")
tmp_dir = tempfile.mkdtemp()
try:
system_ch4.to("deepmd/raw", tmp_dir)
reloaded = dpdata.LabeledSystem(tmp_dir, fmt="deepmd/raw")
self.assertEqual(reloaded.get_nframes(), system_ch4.get_nframes())
self.assertFalse(reloaded.data.get("forces", np.empty(0)).size)
self.assertTrue("virials" not in reloaded.data)
finally:
shutil.rmtree(tmp_dir)


if __name__ == "__main__":
unittest.main()
Loading