Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -372,3 +373,7 @@ def _get_output_shape(self, odef, nframes, natoms):
return [nframes, natoms, *odef.shape, 1]
else:
raise RuntimeError("unknown category")

def get_model_def_script(self) -> dict:
"""Get model defination script."""
return json.loads(self.model.get_model_def_script())
5 changes: 5 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from deepmd.entrypoints.neighbor_stat import (
neighbor_stat,
)
from deepmd.entrypoints.show import (
show,
)
from deepmd.entrypoints.test import (
test,
)
Expand Down Expand Up @@ -81,5 +84,7 @@ def main(args: argparse.Namespace):
start_dpgui(**dict_args)
elif args.command == "convert-backend":
convert_backend(**dict_args)
elif args.command == "show":
show(**dict_args)
else:
raise ValueError(f"Unknown command: {args.command}")
68 changes: 68 additions & 0 deletions deepmd/entrypoints/show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
)

from deepmd.infer.deep_eval import (
DeepEval,
)

log = logging.getLogger(__name__)


def show(
*,
INPUT: str,
ATTRIBUTES: List[str],
**kwargs,
):
model = DeepEval(INPUT, head=0)
model_params = model.get_model_def_script()
model_is_multi_task = "model_dict" in model_params
log.info("This is a multitask model") if model_is_multi_task else log.info(
"This is a singletask model"
)

if "model-branch" in ATTRIBUTES:
# The model must be multitask mode
if not model_is_multi_task:
raise RuntimeError(
"The 'model-branch' option requires a multitask model."
" The provided model does not meet this criterion."
)
model_branches = list(model_params["model_dict"].keys())
model_branches += ["RANDOM"]
log.info(
f"Available model branches are {model_branches}, "
f"where 'RANDOM' means using a randomly initialized fitting net."
)
if "type-map" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
type_map = model_params["model_dict"][branch]["type_map"]
log.info(f"The type_map of branch {branch} is {type_map}")
else:
type_map = model_params["type_map"]
log.info(f"The type_map is {type_map}")
if "descriptor" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
descriptor = model_params["model_dict"][branch]["descriptor"]
log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
else:
descriptor = model_params["descriptor"]
log.info(f"The descriptor parameter is {descriptor}")
if "fitting-net" in ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
fitting_net = model_params["model_dict"][branch]["fitting_net"]
log.info(
f"The fitting_net parameter of branch {branch} is {fitting_net}"
)
else:
fitting_net = model_params["fitting_net"]
log.info(f"The fitting_net parameter is {fitting_net}")
8 changes: 8 additions & 0 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ def get_has_spin(self):
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""

def get_model_def_script(self) -> dict:
"""Get model defination script."""
raise NotImplementedError("Not implemented in this backend.")


class DeepEval(ABC):
"""High-level Deep Evaluator interface.
Expand Down Expand Up @@ -546,3 +550,7 @@ def has_spin(self) -> bool:
def get_ntypes_spin(self) -> int:
"""Get the number of spin atom types of this model. Only used in old implement."""
return self.deep_eval.get_ntypes_spin()

def get_model_def_script(self) -> dict:
"""Get model defination script."""
return self.deep_eval.get_model_def_script()
14 changes: 10 additions & 4 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def main_parser() -> argparse.ArgumentParser:
parser_show = subparsers.add_parser(
"show",
parents=[parser_log],
help="(Supported backend: PyTorch) Show the information of a model",
help="Show the information of a model",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
Expand Down Expand Up @@ -880,15 +880,21 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
return parsed_args


def main():
def main(args: Optional[List[str]] = None):
"""DeePMD-kit new entry point.

Parameters
----------
args : List[str]
list of command line arguments, main purpose is testing default option None
takes arguments from sys.argv

Raises
------
RuntimeError
if no command was input
"""
args = parse_args()
args = parse_args(args=args)

if args.backend not in BACKEND_TABLE:
raise ValueError(f"Unknown backend {args.backend}")
Expand All @@ -900,6 +906,7 @@ def main():
"neighbor-stat",
"gui",
"convert-backend",
"show",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
Expand All @@ -910,7 +917,6 @@ def main():
"compress",
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
Expand Down
67 changes: 0 additions & 67 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,71 +356,6 @@ def freeze(FLAGS):
)


def show(FLAGS):
if FLAGS.INPUT.split(".")[-1] == "pt":
state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
if "model" in state_dict:
state_dict = state_dict["model"]
model_params = state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.split(".")[-1] == "pth":
model_params_string = torch.jit.load(
FLAGS.INPUT, map_location=env.DEVICE
).model_def_script
model_params = json.loads(model_params_string)
else:
raise RuntimeError(
"The model provided must be a checkpoint file with a .pt extension "
"or a frozen model with a .pth extension"
)
model_is_multi_task = "model_dict" in model_params
log.info("This is a multitask model") if model_is_multi_task else log.info(
"This is a singletask model"
)

if "model-branch" in FLAGS.ATTRIBUTES:
# The model must be multitask mode
if not model_is_multi_task:
raise RuntimeError(
"The 'model-branch' option requires a multitask model."
" The provided model does not meet this criterion."
)
model_branches = list(model_params["model_dict"].keys())
model_branches += ["RANDOM"]
log.info(
f"Available model branches are {model_branches}, "
f"where 'RANDOM' means using a randomly initialized fitting net."
)
if "type-map" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
type_map = model_params["model_dict"][branch]["type_map"]
log.info(f"The type_map of branch {branch} is {type_map}")
else:
type_map = model_params["type_map"]
log.info(f"The type_map is {type_map}")
if "descriptor" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
descriptor = model_params["model_dict"][branch]["descriptor"]
log.info(f"The descriptor parameter of branch {branch} is {descriptor}")
else:
descriptor = model_params["descriptor"]
log.info(f"The descriptor parameter is {descriptor}")
if "fitting-net" in FLAGS.ATTRIBUTES:
if model_is_multi_task:
model_branches = list(model_params["model_dict"].keys())
for branch in model_branches:
fitting_net = model_params["model_dict"][branch]["fitting_net"]
log.info(
f"The fitting_net parameter of branch {branch} is {fitting_net}"
)
else:
fitting_net = model_params["fitting_net"]
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
if FLAGS.INPUT.endswith(".pt"):
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
Expand Down Expand Up @@ -574,8 +509,6 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
FLAGS.model = FLAGS.checkpoint_folder
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pth"))
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
else:
Expand Down
15 changes: 14 additions & 1 deletion deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(
*args: Any,
auto_batch_size: Union[bool, int, AutoBatchSize] = True,
neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None,
head: Optional[str] = None,
head: Optional[Union[str, int]] = None,
**kwargs: Any,
):
self.output_def = output_def
Expand All @@ -110,9 +111,12 @@ def __init__(
if "model" in state_dict:
state_dict = state_dict["model"]
self.input_param = state_dict["_extra_state"]["model_params"]
self.model_def_script = self.input_param
self.multi_task = "model_dict" in self.input_param
if self.multi_task:
model_keys = list(self.input_param["model_dict"].keys())
if isinstance(head, int):
head = model_keys[0]
assert (
head is not None
), f"Head must be set for multitask model! Available heads are: {model_keys}"
Expand All @@ -134,6 +138,11 @@ def __init__(
elif str(self.model_path).endswith(".pth"):
model = torch.jit.load(model_file, map_location=env.DEVICE)
self.dp = ModelWrapper(model)
model_def_script = self.dp.model["Default"].get_model_def_script()
if model_def_script:
self.model_def_script = json.loads(model_def_script)
else:
self.model_def_script = {}
else:
raise ValueError("Unknown model file format!")
self.rcut = self.dp.model["Default"].get_rcut()
Expand Down Expand Up @@ -590,6 +599,10 @@ def eval_typeebd(self) -> np.ndarray:
typeebd = torch.cat(out, dim=1)
return to_numpy_array(typeebd)

def get_model_def_script(self) -> str:
"""Get model defination script."""
return self.model_def_script


# For tests only
def eval_model(
Expand Down
8 changes: 8 additions & 0 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
from functools import (
lru_cache,
)
Expand Down Expand Up @@ -1123,6 +1124,13 @@ def get_numb_dos(self) -> int:
def get_has_efield(self) -> bool:
return self.has_efield

def get_model_def_script(self) -> dict:
"""Get model defination script."""
t_script = self._get_tensor("train_attr/training_script:0")
[script] = run_sess(self.sess, [t_script], feed_dict={})
model_def_script = script.decode("utf-8")
return json.loads(model_def_script)["model"]


class DeepEvalOld:
# old class for DipoleChargeModifier only
Expand Down
1 change: 1 addition & 0 deletions source/tests/infer/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def __init__(self, filename: str):
self.type_map = config["type_map"]
self.dim_fparam = config["dim_fparam"]
self.dim_aparam = config["dim_aparam"]
self.model_def_script = config.get("model_def_script")

@lru_cache
def get_model(self, suffix: str, out_file: Optional[str] = None) -> str:
Expand Down
38 changes: 38 additions & 0 deletions source/tests/infer/deeppot-testcase.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,44 @@ rcut: 6.0
type_map: ["O", "H"]
dim_fparam: 0
dim_aparam: 0
model_def_script:
{
"data_bias_nsample": 10,
"data_stat_nbatch": 10,
"data_stat_protect": 0.01,
"descriptor":
{
"activation_function": "tanh",
"axis_neuron": 4,
"exclude_types": [],
"neuron": [2, 4, 8],
"precision": "default",
"rcut": 6.0,
"rcut_smth": 0.5,
"resnet_dt": False,
"seed": 1,
"sel": [46, 92],
"set_davg_zero": False,
"trainable": True,
"type": "se_e2_a",
"type_one_side": False,
},
"fitting_net":
{
"activation_function": "tanh",
"atom_ener": [],
"neuron": [6, 6, 6],
"numb_aparam": 0,
"numb_fparam": 0,
"precision": "default",
"rcond": 0.001,
"resnet_dt": True,
"seed": 1,
"trainable": True,
"type": "ener",
},
"type_map": ["O", "H"],
}
results:
- coord:
[
Expand Down
Loading