Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import argparse
import copy
import io
import json
import logging
Expand Down Expand Up @@ -101,6 +100,38 @@
log = logging.getLogger(__name__)


def _update_changed_model_tensors(
target_state_dict: dict[str, Any],
source_state_dict: dict[str, Any],
key_prefix: str | None = None,
) -> None:
"""Copy changed tensors into an existing state dict without breaking aliases."""
for key, source_value in source_state_dict.items():
if key == "_extra_state":
continue
if key_prefix is not None and not key.startswith(key_prefix):
continue
if key not in target_state_dict:
target_state_dict[key] = (
source_value.detach().clone()
if torch.is_tensor(source_value)
else source_value
)
continue
target_value = target_state_dict[key]
if torch.is_tensor(target_value) and torch.is_tensor(source_value):
if (
target_value.shape == source_value.shape
and target_value.dtype == source_value.dtype
):
if not torch.equal(target_value, source_value):
target_value.copy_(source_value)
else:
target_state_dict[key] = source_value.detach().clone()
elif target_value != source_value:
target_state_dict[key] = source_value


def get_trainer(
config: dict[str, Any],
init_model: str | None = None,
Expand Down Expand Up @@ -512,7 +543,7 @@ def change_bias(
old_state_dict = torch.load(
input_file, map_location=env.DEVICE, weights_only=True
)
model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict))
model_state_dict = old_state_dict.get("model", old_state_dict)
model_params = model_state_dict["_extra_state"]["model_params"]
elif input_file.endswith(".pth"):
old_model = torch.jit.load(input_file, map_location=env.DEVICE)
Expand Down Expand Up @@ -545,7 +576,7 @@ def change_bias(
model_to_change = model if not multi_task else model[model_branch]
if input_file.endswith(".pt"):
wrapper = ModelWrapper(model)
wrapper.load_state_dict(old_state_dict["model"])
wrapper.load_state_dict(model_state_dict)
else:
# for .pth
model.load_state_dict(old_state_dict)
Expand Down Expand Up @@ -608,12 +639,12 @@ def change_bias(
output if output is not None else input_file.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model)
if "model" in old_state_dict:
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"]
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = model_state_dict["_extra_state"]
key_prefix = f"model.{model_branch}." if multi_task else None
_update_changed_model_tensors(
model_state_dict,
wrapper.state_dict(),
key_prefix=key_prefix,
)
torch.save(old_state_dict, output_path)
else:
# for .pth
Expand Down
38 changes: 32 additions & 6 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,37 @@
log = logging.getLogger(__name__)


def _update_changed_model_tensors(
target_state_dict: dict[str, Any],
source_state_dict: dict[str, Any],
) -> None:
"""Copy changed tensors into an existing state dict without breaking aliases."""
import torch

for key, source_value in source_state_dict.items():
if key == "_extra_state":
continue
if key not in target_state_dict:
target_state_dict[key] = (
source_value.detach().clone()
if torch.is_tensor(source_value)
else source_value
)
continue
target_value = target_state_dict[key]
if torch.is_tensor(target_value) and torch.is_tensor(source_value):
if (
target_value.shape == source_value.shape
and target_value.dtype == source_value.dtype
):
if not torch.equal(target_value, source_value):
target_value.copy_(source_value)
else:
target_state_dict[key] = source_value.detach().clone()
elif target_value != source_value:
target_state_dict[key] = source_value


def _detect_lmdb_path(systems_raw: Any) -> str | None:
"""Return the LMDB path when ``systems_raw`` is a scalar LMDB string.

Expand Down Expand Up @@ -596,12 +627,7 @@ def change_bias(
output if output is not None else input_file.replace(".pt", "_updated.pt")
)
wrapper = ModelWrapper(model_to_change)
if "model" in old_state_dict:
old_state_dict["model"] = wrapper.state_dict()
old_state_dict["model"]["_extra_state"] = extra_state
else:
old_state_dict = wrapper.state_dict()
old_state_dict["_extra_state"] = extra_state
_update_changed_model_tensors(model_state_dict, wrapper.state_dict())
torch.save(old_state_dict, output_path)
elif input_file.endswith((".pte", ".pt2")):
output_path = (
Expand Down
105 changes: 105 additions & 0 deletions source/tests/pt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,21 @@
from deepmd.pt.utils.env import (
DEVICE,
)
from deepmd.pt.utils.multi_task import (
preprocess_shared_params,
)
from deepmd.pt.utils.stat import (
make_stat_input,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)

from .common import (
run_dp,
Expand Down Expand Up @@ -165,3 +174,99 @@ def tearDown(self) -> None:
os.remove(f)
if f in ["stat_files"]:
shutil.rmtree(f)


class TestChangeBiasMultitask(unittest.TestCase):
def setUp(self) -> None:
input_json = str(Path(__file__).parent / "water/multitask.json")
with open(input_json) as f:
config = json.load(f)
data_file = [str(Path(__file__).parent / "water/data/data_0")]
self.stat_files = "change-bias-multitask-stat"
os.makedirs(self.stat_files, exist_ok=True)
config["model"]["shared_dict"]["my_descriptor"] = deepcopy(
model_se_e2_a["descriptor"]
)
for model_key in config["training"]["data_dict"]:
config["training"]["data_dict"][model_key]["training_data"]["systems"] = (
data_file
)
config["training"]["data_dict"][model_key]["validation_data"]["systems"] = (
data_file
)
config["training"]["data_dict"][model_key]["stat_file"] = (
f"{self.stat_files}/{model_key}"
)
config["training"]["numb_steps"] = 0
config["model"], shared_links = preprocess_shared_params(config["model"])
config = update_deepmd_input(config, warning=True)
config = normalize(config, multi_task=True)
self.trainer = get_trainer(deepcopy(config), shared_links=shared_links)
self.model_path = Path(current_path) / "change-bias-multitask-model.pt"
self.model_path_user_bias = (
Path(current_path) / "change-bias-multitask-model-user-bias.pt"
)
torch.save({"model": self.trainer.wrapper.state_dict()}, self.model_path)

@staticmethod
def _share_storage(lhs: torch.Tensor, rhs: torch.Tensor) -> bool:
return lhs.untyped_storage().data_ptr() == rhs.untyped_storage().data_ptr()

def _find_shared_descriptor_pair(
self, state_dict: dict[str, torch.Tensor]
) -> tuple[str, str]:
Comment thread
njzjz marked this conversation as resolved.
Dismissed
for key, value in state_dict.items():
if not (
key.startswith("model.model_1.")
and "descriptor" in key
and torch.is_tensor(value)
):
continue
peer_key = key.replace("model.model_1.", "model.model_2.", 1)
if (
peer_key in state_dict
and torch.is_tensor(state_dict[peer_key])
and self._share_storage(value, state_dict[peer_key])
):
return key, peer_key
self.fail("No shared descriptor tensor pair found in multitask checkpoint.")

def test_change_bias_preserves_shared_checkpoint_storage(self) -> None:
state_dict = torch.load(
str(self.model_path), map_location=DEVICE, weights_only=True
)["model"]
shared_key, peer_key = self._find_shared_descriptor_pair(state_dict)

user_bias = [0.1, 0.2, 0.3]
run_dp(
f"dp --pt change-bias {self.model_path!s} --model-branch model_1 "
f"-b {' '.join([str(_) for _ in user_bias])} "
f"-o {self.model_path_user_bias!s}"
)
updated_state_dict = torch.load(
str(self.model_path_user_bias), map_location=DEVICE, weights_only=True
)["model"]

self.assertTrue(
self._share_storage(
updated_state_dict[shared_key], updated_state_dict[peer_key]
)
)
bias_keys = [
key
for key in updated_state_dict
if key.startswith("model.model_1.") and key.endswith("out_bias")
]
self.assertEqual(len(bias_keys), 1)
updated_bias = updated_state_dict[bias_keys[0]]
expected_bias = to_torch_tensor(np.array(user_bias)).view(updated_bias.shape)
torch.testing.assert_close(updated_bias, expected_bias)

def tearDown(self) -> None:
for f in os.listdir("."):
if f.startswith("change-bias-multitask-model") and f.endswith(".pt"):
os.remove(f)
if f in ["lcurve.out"]:
os.remove(f)
if f in [self.stat_files]:
shutil.rmtree(f)
30 changes: 30 additions & 0 deletions source/tests/pt_expt/test_change_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,36 @@ def test_change_bias_with_user_defined(self) -> None:
expected_bias = np.array(user_bias).reshape(updated_bias.shape)
np.testing.assert_allclose(updated_bias, expected_bias)

def test_change_bias_preserves_checkpoint_storage_aliases(self) -> None:
alias_path = os.path.join(self.tmpdir, "model_alias.pt")
output_path = os.path.join(self.tmpdir, "model_alias_user_bias.pt")
state_dict = torch.load(self.model_path, map_location=DEVICE, weights_only=True)
model_state = state_dict["model"]
key = "model.Default.atomic_model.descriptor.davg"
peer_key = "model.Default.atomic_model.descriptor.dstd"
self.assertIn(key, model_state)
self.assertIn(peer_key, model_state)
model_state[peer_key] = model_state[key]
torch.save(state_dict, alias_path)

user_bias = [0.1, 3.2]
run_dp(
f"dp --pt-expt change-bias {alias_path} "
f"-b {' '.join(str(v) for v in user_bias)} -o {output_path}"
)
updated_state = torch.load(output_path, map_location=DEVICE, weights_only=True)[
"model"
]
self.assertEqual(
updated_state[key].untyped_storage().data_ptr(),
updated_state[peer_key].untyped_storage().data_ptr(),
)
updated_bias = updated_state["model.Default.atomic_model.out_bias"]
expected_bias = torch.as_tensor(
user_bias, dtype=updated_bias.dtype, device=updated_bias.device
).view(updated_bias.shape)
torch.testing.assert_close(updated_bias, expected_bias)

def test_change_bias_frozen_pte(self) -> None:
from deepmd.pt_expt.model.model import (
BaseModel,
Expand Down
Loading