Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions deepmd/pt/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from typing import (
Optional,
)

import torch

from deepmd.common import (
j_loader,
)
from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.update_sel import (
UpdateSel,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
get_data,
)

log = logging.getLogger(__name__)


def enable_compression(
Expand All @@ -14,12 +35,44 @@ def enable_compression(
stride: float = 0.01,
extrapolate: int = 5,
check_frequency: int = -1,
training_script: Optional[str] = None,
):
saved_model = torch.jit.load(input_file, map_location="cpu")
model_def_script = json.loads(saved_model.model_def_script)
model = get_model(model_def_script)
model.load_state_dict(saved_model.state_dict())

if model.get_min_nbor_dist() is None:
log.info(
"Minimal neighbor distance is not saved in the model, compute it from the training data."
)
if training_script is None:
raise ValueError(
"The model does not have a minimum neighbor distance, "
"so the training script and data must be provided "
"(via -t,--training-script)."
)

jdata = j_loader(training_script)
jdata = update_deepmd_input(jdata)

type_map = jdata["model"].get("type_map", None)
train_data = get_data(
jdata["training"]["training_data"],
0, # not used
type_map,
None,
)
update_sel = UpdateSel()
t_min_nbor_dist = update_sel.get_min_nbor_dist(
train_data,
)
model.min_nbor_dist = torch.tensor(
t_min_nbor_dist,
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE,
)
Comment thread
njzjz marked this conversation as resolved.

model.enable_compression(
extrapolate,
stride,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
stride=FLAGS.step,
extrapolate=FLAGS.extrapolate,
check_frequency=FLAGS.frequency,
training_script=FLAGS.training_script,
)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")
Expand Down
159 changes: 158 additions & 1 deletion source/tests/pt/test_model_compression_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,48 @@ def _init_models_exclude_types():
return INPUT, frozen_model, compressed_model


def _init_models_skip_neighbor_stat():
suffix = "-skip-neighbor-stat"
data_file = str(tests_path / os.path.join("model_compression", "data"))
frozen_model = str(tests_path / f"dp-original{suffix}.pth")
compressed_model = str(tests_path / f"dp-compressed{suffix}.pth")
INPUT = str(tests_path / "input.json")
jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json")))
jdata["training"]["training_data"]["systems"] = data_file
with open(INPUT, "w") as fp:
json.dump(jdata, fp, indent=4)

ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat")
np.testing.assert_equal(ret, 0, "DP train failed!")
ret = run_dp("dp --pt freeze -o " + frozen_model)
np.testing.assert_equal(ret, 0, "DP freeze failed!")
ret = run_dp(
"dp --pt compress "
+ " -i "
+ frozen_model
+ " -o "
+ compressed_model
+ " -t "
+ INPUT
)
np.testing.assert_equal(ret, 0, "DP model compression failed!")
return INPUT, frozen_model, compressed_model


def setUpModule():
global \
INPUT, \
FROZEN_MODEL, \
COMPRESSED_MODEL, \
INPUT_ET, \
FROZEN_MODEL_ET, \
COMPRESSED_MODEL_ET
COMPRESSED_MODEL_ET, \
FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \
COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT
INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models()
_, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = (
_init_models_skip_neighbor_stat()
)
INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types()


Expand Down Expand Up @@ -572,5 +605,129 @@ def test_2frame_atm(self):
np.testing.assert_almost_equal(vv0, vv1, default_places)


class TestSkipNeighborStat(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT)
cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT)
cls.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
]
)
cls.atype = [0, 1, 1, 0, 1, 1]
cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0])

def test_attrs(self):
self.assertEqual(self.dp_original.get_ntypes(), 2)
self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places)
self.assertEqual(self.dp_original.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_original.get_dim_fparam(), 0)
self.assertEqual(self.dp_original.get_dim_aparam(), 0)

self.assertEqual(self.dp_compressed.get_ntypes(), 2)
self.assertAlmostEqual(
self.dp_compressed.get_rcut(), 6.0, places=default_places
)
self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"])
self.assertEqual(self.dp_compressed.get_dim_fparam(), 0)
self.assertEqual(self.dp_compressed.get_dim_aparam(), 0)

def test_1frame(self):
ee0, ff0, vv0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=False
)
ee1, ff1, vv1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=False
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_1frame_atm(self):
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
self.coords, self.box, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
self.coords, self.box, self.atype, atomic=True
)
# check shape of the returns
nframes = 1
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))
# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)

def test_2frame_atm(self):
coords2 = np.concatenate((self.coords, self.coords))
box2 = np.concatenate((self.box, self.box))
ee0, ff0, vv0, ae0, av0 = self.dp_original.eval(
coords2, box2, self.atype, atomic=True
)
ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval(
coords2, box2, self.atype, atomic=True
)
# check shape of the returns
nframes = 2
natoms = len(self.atype)
self.assertEqual(ee0.shape, (nframes, 1))
self.assertEqual(ff0.shape, (nframes, natoms, 3))
self.assertEqual(vv0.shape, (nframes, 9))
self.assertEqual(ae0.shape, (nframes, natoms, 1))
self.assertEqual(av0.shape, (nframes, natoms, 9))
self.assertEqual(ee1.shape, (nframes, 1))
self.assertEqual(ff1.shape, (nframes, natoms, 3))
self.assertEqual(vv1.shape, (nframes, 9))
self.assertEqual(ae1.shape, (nframes, natoms, 1))
self.assertEqual(av1.shape, (nframes, natoms, 9))

# check values
np.testing.assert_almost_equal(ff0, ff1, default_places)
np.testing.assert_almost_equal(ae0, ae1, default_places)
np.testing.assert_almost_equal(av0, av1, default_places)
np.testing.assert_almost_equal(ee0, ee1, default_places)
np.testing.assert_almost_equal(vv0, vv1, default_places)


if __name__ == "__main__":
unittest.main()