Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,72 @@ def main_parser() -> argparse.ArgumentParser:
help="treat all types as a single type. Used with se_atten descriptor.",
)

# change_bias
parser_change_bias = subparsers.add_parser(
"change-bias",
parents=[parser_log],
help="(Supported backend: PyTorch) Change model out bias according to the input data.",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
epilog=textwrap.dedent(
"""\
examples:
dp change-bias model.pt -s data -n 10 -m change
"""
),
)
parser_change_bias.add_argument(
"INPUT", help="The input checkpoint file or frozen model file"
)
parser_change_bias_source = parser_change_bias.add_mutually_exclusive_group()
parser_change_bias_source.add_argument(
"-s",
"--system",
default=".",
type=str,
help="The system dir. Recursively detect systems in this directory",
)
parser_change_bias_source.add_argument(
"-b",
"--bias-value",
default=None,
type=float,
nargs="+",
help="The user defined value for each type in the type_map of the model, split with spaces.\n"
"For example, '-93.57 -187.1' for energy bias of two elements. "
"Only supports energy bias changing.",
)
parser_change_bias.add_argument(
"-n",
"--numb-batch",
default=0,
type=int,
help="The number of frames for bias changing in one data system. 0 means all data.",
)
parser_change_bias.add_argument(
"-m",
"--mode",
type=str,
default="change",
choices=["change", "set"],
help="The mode for changing energy bias: \n"
"change (default) : perform predictions using input model on target dataset, "
"and do least square on the errors to obtain the target shift as bias.\n"
"set : directly use the statistic bias in the target dataset.",
)
parser_change_bias.add_argument(
"-o",
"--output",
default=None,
type=str,
help="The model after changing bias.",
)
parser_change_bias.add_argument(
"--model-branch",
type=str,
default=None,
help="Model branch chosen for changing bias if multi-task model.",
)

Comment thread
iProzd marked this conversation as resolved.
# --version
parser.add_argument(
"--version", action="version", version=f"DeePMD-kit v{__version__}"
Expand Down Expand Up @@ -831,6 +897,7 @@ def main():
"convert-from",
"train-nvnmd",
"show",
"change-bias",
):
deepmd_main = BACKENDS[args.backend]().entry_point_hook
elif args.command is None:
Expand Down
137 changes: 137 additions & 0 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
Comment thread
iProzd marked this conversation as resolved.
import json
import logging
import os
Expand All @@ -23,6 +24,9 @@
from deepmd import (
__version__,
)
from deepmd.common import (
expand_sys_str,
)
from deepmd.env import (
GLOBAL_CONFIG,
)
Expand All @@ -44,6 +48,9 @@
from deepmd.pt.train import (
training,
)
from deepmd.pt.train.wrapper import (
ModelWrapper,
)
from deepmd.pt.utils import (
env,
)
Expand All @@ -59,6 +66,12 @@
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.argcheck import (
normalize,
)
Expand Down Expand Up @@ -377,6 +390,128 @@ def show(FLAGS):
log.info(f"The fitting_net parameter is {fitting_net}")


def change_bias(FLAGS):
Comment thread
iProzd marked this conversation as resolved.
Comment thread
iProzd marked this conversation as resolved.
if FLAGS.INPUT.endswith(".pt"):
Comment thread
iProzd marked this conversation as resolved.
old_state_dict = torch.load(FLAGS.INPUT, map_location=env.DEVICE)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
Comment thread
iProzd marked this conversation as resolved.
Comment thread
iProzd marked this conversation as resolved.
model_params = model_state_dict["_extra_state"]["model_params"]
elif FLAGS.INPUT.endswith(".pth"):
old_model = torch.jit.load(FLAGS.INPUT, map_location=env.DEVICE)
model_params_string = old_model.get_model_def_script()
model_params = json.loads(model_params_string)
old_state_dict = old_model.state_dict()
model_state_dict = old_state_dict
else:
raise RuntimeError(
"The model provided must be a checkpoint file with a .pt extension "
"or a frozen model with a .pth extension"
)
multi_task = "model_dict" in model_params
Comment thread
iProzd marked this conversation as resolved.
model_branch = FLAGS.model_branch
bias_adjust_mode = (
Comment thread
iProzd marked this conversation as resolved.
"change-by-statistic" if FLAGS.mode == "change" else "set-by-statistic"
)
if multi_task:
assert (
model_branch is not None
), "For multitask model, the model branch must be set!"
assert model_branch in model_params["model_dict"], (
f"For multitask model, the model branch must be in the 'model_dict'! "
f"Available options are : {list(model_params['model_dict'].keys())}."
)
log.info(f"Changing out bias for model {model_branch}.")
Comment thread
iProzd marked this conversation as resolved.
model = training.get_model_for_wrapper(model_params)
type_map = (
model_params["type_map"]
if not multi_task
else model_params["model_dict"][model_branch]["type_map"]
)
model_to_change = model if not multi_task else model[model_branch]
if FLAGS.INPUT.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])
else:
# for .pth
model.load_state_dict(old_state_dict)

if FLAGS.bias_value is not None:
# use user-defined bias
assert model_to_change.model_type in [
Comment thread
iProzd marked this conversation as resolved.
"ener"
], "User-defined bias is only available for energy model!"
assert (
len(FLAGS.bias_value) == len(type_map)
), f"The number of elements in the bias should be the same as that in the type_map: {type_map}."
old_bias = model_to_change.get_out_bias()
bias_to_set = torch.tensor(
FLAGS.bias_value, dtype=old_bias.dtype, device=old_bias.device
).view(old_bias.shape)
model_to_change.set_out_bias(bias_to_set)
log.info(
f"Change output bias of {type_map!s} "
f"from {to_numpy_array(old_bias).reshape(-1)!s} "
f"to {to_numpy_array(bias_to_set).reshape(-1)!s}."
)
updated_model = model_to_change
else:
# calculate bias on given systems
data_systems = process_systems(expand_sys_str(FLAGS.system))
data_single = DpLoaderSet(
data_systems,
1,
type_map,
)
mock_loss = training.get_loss(
{"inference": True}, 1.0, len(type_map), model_to_change
)
data_requirement = mock_loss.label_requirement
data_requirement += training.get_additional_data_requirement(model_to_change)
data_single.add_data_requirement(data_requirement)
nbatches = FLAGS.numb_batch if FLAGS.numb_batch != 0 else float("inf")
sampled_data = make_stat_input(
data_single.systems,
data_single.dataloaders,
nbatches,
)
updated_model = training.model_change_out_bias(
model_to_change, sampled_data, _bias_adjust_mode=bias_adjust_mode
)

if not multi_task:
model = updated_model
else:
model[model_branch] = updated_model

if FLAGS.INPUT.endswith(".pt"):
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
torch.save(old_state_dict, output_path)
Comment thread
iProzd marked this conversation as resolved.
else:
# for .pth
output_path = (
FLAGS.output
if FLAGS.output is not None
else FLAGS.INPUT.replace(".pth", "_updated.pth")
)
model = torch.jit.script(model)
torch.jit.save(
model,
output_path,
{},
)
log.info(f"Saved model to {output_path}")


Comment thread
iProzd marked this conversation as resolved.
@record
def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
if not isinstance(args, argparse.Namespace):
Expand All @@ -401,6 +536,8 @@ def main(args: Optional[Union[List[str], argparse.Namespace]] = None):
freeze(FLAGS)
elif FLAGS.command == "show":
show(FLAGS)
elif FLAGS.command == "change-bias":
change_bias(FLAGS)
else:
raise RuntimeError(f"Invalid command {FLAGS.command}!")

Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(
self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
self.has_pf = (start_pref_pf != 0.0 and limit_pref_pf != 0.0) or inference
self.has_gf = (start_pref_gf != 0.0 and limit_pref_gf != 0.0) or inference
self.has_gf = start_pref_gf != 0.0 and limit_pref_gf != 0.0

self.start_pref_e = start_pref_e
self.limit_pref_e = limit_pref_e
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def init_out_stat(self):
self.register_buffer("out_bias", out_bias_data)
self.register_buffer("out_std", out_std_data)

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.out_bias = out_bias

def __setitem__(self, key, value):
if key in ["out_bias"]:
self.out_bias = value
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def forward_common(
def get_out_bias(self) -> torch.Tensor:
return self.atomic_model.get_out_bias()

def set_out_bias(self, out_bias: torch.Tensor) -> None:
self.atomic_model.set_out_bias(out_bias)

def change_out_bias(
self,
merged,
Expand Down
Loading