diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index bf5c49a4d6..560ea5a1ba 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import argparse -import copy import io import json import logging @@ -101,6 +100,38 @@ log = logging.getLogger(__name__) +def _update_changed_model_tensors( + target_state_dict: dict[str, Any], + source_state_dict: dict[str, Any], + key_prefix: str | None = None, +) -> None: + """Copy changed tensors into an existing state dict without breaking aliases.""" + for key, source_value in source_state_dict.items(): + if key == "_extra_state": + continue + if key_prefix is not None and not key.startswith(key_prefix): + continue + if key not in target_state_dict: + target_state_dict[key] = ( + source_value.detach().clone() + if torch.is_tensor(source_value) + else source_value + ) + continue + target_value = target_state_dict[key] + if torch.is_tensor(target_value) and torch.is_tensor(source_value): + if ( + target_value.shape == source_value.shape + and target_value.dtype == source_value.dtype + ): + if not torch.equal(target_value, source_value): + target_value.copy_(source_value) + else: + target_state_dict[key] = source_value.detach().clone() + elif target_value != source_value: + target_state_dict[key] = source_value + + def get_trainer( config: dict[str, Any], init_model: str | None = None, @@ -512,7 +543,7 @@ def change_bias( old_state_dict = torch.load( input_file, map_location=env.DEVICE, weights_only=True ) - model_state_dict = copy.deepcopy(old_state_dict.get("model", old_state_dict)) + model_state_dict = old_state_dict.get("model", old_state_dict) model_params = model_state_dict["_extra_state"]["model_params"] elif input_file.endswith(".pth"): old_model = torch.jit.load(input_file, map_location=env.DEVICE) @@ -545,7 +576,7 @@ def change_bias( model_to_change = model if not multi_task else model[model_branch] if input_file.endswith(".pt"): wrapper = ModelWrapper(model) - wrapper.load_state_dict(old_state_dict["model"]) + wrapper.load_state_dict(model_state_dict) else: # for .pth model.load_state_dict(old_state_dict) @@ -608,12 +639,12 @@ def change_bias( output if output is not None else input_file.replace(".pt", "_updated.pt") ) wrapper = ModelWrapper(model) - if "model" in old_state_dict: - old_state_dict["model"] = wrapper.state_dict() - old_state_dict["model"]["_extra_state"] = model_state_dict["_extra_state"] - else: - old_state_dict = wrapper.state_dict() - old_state_dict["_extra_state"] = model_state_dict["_extra_state"] + key_prefix = f"model.{model_branch}." if multi_task else None + _update_changed_model_tensors( + model_state_dict, + wrapper.state_dict(), + key_prefix=key_prefix, + ) torch.save(old_state_dict, output_path) else: # for .pth diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 4dd6c81d58..da28229bf4 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -41,6 +41,37 @@ log = logging.getLogger(__name__) +def _update_changed_model_tensors( + target_state_dict: dict[str, Any], + source_state_dict: dict[str, Any], +) -> None: + """Copy changed tensors into an existing state dict without breaking aliases.""" + import torch + + for key, source_value in source_state_dict.items(): + if key == "_extra_state": + continue + if key not in target_state_dict: + target_state_dict[key] = ( + source_value.detach().clone() + if torch.is_tensor(source_value) + else source_value + ) + continue + target_value = target_state_dict[key] + if torch.is_tensor(target_value) and torch.is_tensor(source_value): + if ( + target_value.shape == source_value.shape + and target_value.dtype == source_value.dtype + ): + if not torch.equal(target_value, source_value): + target_value.copy_(source_value) + else: + target_state_dict[key] = source_value.detach().clone() + elif target_value != source_value: + target_state_dict[key] = source_value + + def _detect_lmdb_path(systems_raw: Any) -> str | None: """Return the LMDB path when ``systems_raw`` is a scalar LMDB string. @@ -596,12 +627,7 @@ def change_bias( output if output is not None else input_file.replace(".pt", "_updated.pt") ) wrapper = ModelWrapper(model_to_change) - if "model" in old_state_dict: - old_state_dict["model"] = wrapper.state_dict() - old_state_dict["model"]["_extra_state"] = extra_state - else: - old_state_dict = wrapper.state_dict() - old_state_dict["_extra_state"] = extra_state + _update_changed_model_tensors(model_state_dict, wrapper.state_dict()) torch.save(old_state_dict, output_path) elif input_file.endswith((".pte", ".pt2")): output_path = ( diff --git a/source/tests/pt/test_change_bias.py b/source/tests/pt/test_change_bias.py index 58fd953656..a883449d22 100644 --- a/source/tests/pt/test_change_bias.py +++ b/source/tests/pt/test_change_bias.py @@ -30,12 +30,21 @@ from deepmd.pt.utils.env import ( DEVICE, ) +from deepmd.pt.utils.multi_task import ( + preprocess_shared_params, +) from deepmd.pt.utils.stat import ( make_stat_input, ) from deepmd.pt.utils.utils import ( to_torch_tensor, ) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) from .common import ( run_dp, @@ -165,3 +174,99 @@ def tearDown(self) -> None: os.remove(f) if f in ["stat_files"]: shutil.rmtree(f) + + +class TestChangeBiasMultitask(unittest.TestCase): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/multitask.json") + with open(input_json) as f: + config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.stat_files = "change-bias-multitask-stat" + os.makedirs(self.stat_files, exist_ok=True) + config["model"]["shared_dict"]["my_descriptor"] = deepcopy( + model_se_e2_a["descriptor"] + ) + for model_key in config["training"]["data_dict"]: + config["training"]["data_dict"][model_key]["training_data"]["systems"] = ( + data_file + ) + config["training"]["data_dict"][model_key]["validation_data"]["systems"] = ( + data_file + ) + config["training"]["data_dict"][model_key]["stat_file"] = ( + f"{self.stat_files}/{model_key}" + ) + config["training"]["numb_steps"] = 0 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=True) + config = normalize(config, multi_task=True) + self.trainer = get_trainer(deepcopy(config), shared_links=shared_links) + self.model_path = Path(current_path) / "change-bias-multitask-model.pt" + self.model_path_user_bias = ( + Path(current_path) / "change-bias-multitask-model-user-bias.pt" + ) + torch.save({"model": self.trainer.wrapper.state_dict()}, self.model_path) + + @staticmethod + def _share_storage(lhs: torch.Tensor, rhs: torch.Tensor) -> bool: + return lhs.untyped_storage().data_ptr() == rhs.untyped_storage().data_ptr() + + def _find_shared_descriptor_pair( + self, state_dict: dict[str, torch.Tensor] + ) -> tuple[str, str]: + for key, value in state_dict.items(): + if not ( + key.startswith("model.model_1.") + and "descriptor" in key + and torch.is_tensor(value) + ): + continue + peer_key = key.replace("model.model_1.", "model.model_2.", 1) + if ( + peer_key in state_dict + and torch.is_tensor(state_dict[peer_key]) + and self._share_storage(value, state_dict[peer_key]) + ): + return key, peer_key + self.fail("No shared descriptor tensor pair found in multitask checkpoint.") + + def test_change_bias_preserves_shared_checkpoint_storage(self) -> None: + state_dict = torch.load( + str(self.model_path), map_location=DEVICE, weights_only=True + )["model"] + shared_key, peer_key = self._find_shared_descriptor_pair(state_dict) + + user_bias = [0.1, 0.2, 0.3] + run_dp( + f"dp --pt change-bias {self.model_path!s} --model-branch model_1 " + f"-b {' '.join([str(_) for _ in user_bias])} " + f"-o {self.model_path_user_bias!s}" + ) + updated_state_dict = torch.load( + str(self.model_path_user_bias), map_location=DEVICE, weights_only=True + )["model"] + + self.assertTrue( + self._share_storage( + updated_state_dict[shared_key], updated_state_dict[peer_key] + ) + ) + bias_keys = [ + key + for key in updated_state_dict + if key.startswith("model.model_1.") and key.endswith("out_bias") + ] + self.assertEqual(len(bias_keys), 1) + updated_bias = updated_state_dict[bias_keys[0]] + expected_bias = to_torch_tensor(np.array(user_bias)).view(updated_bias.shape) + torch.testing.assert_close(updated_bias, expected_bias) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("change-bias-multitask-model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in [self.stat_files]: + shutil.rmtree(f) diff --git a/source/tests/pt_expt/test_change_bias.py b/source/tests/pt_expt/test_change_bias.py index 2f441f12bd..2401aa59c5 100644 --- a/source/tests/pt_expt/test_change_bias.py +++ b/source/tests/pt_expt/test_change_bias.py @@ -279,6 +279,36 @@ def test_change_bias_with_user_defined(self) -> None: expected_bias = np.array(user_bias).reshape(updated_bias.shape) np.testing.assert_allclose(updated_bias, expected_bias) + def test_change_bias_preserves_checkpoint_storage_aliases(self) -> None: + alias_path = os.path.join(self.tmpdir, "model_alias.pt") + output_path = os.path.join(self.tmpdir, "model_alias_user_bias.pt") + state_dict = torch.load(self.model_path, map_location=DEVICE, weights_only=True) + model_state = state_dict["model"] + key = "model.Default.atomic_model.descriptor.davg" + peer_key = "model.Default.atomic_model.descriptor.dstd" + self.assertIn(key, model_state) + self.assertIn(peer_key, model_state) + model_state[peer_key] = model_state[key] + torch.save(state_dict, alias_path) + + user_bias = [0.1, 3.2] + run_dp( + f"dp --pt-expt change-bias {alias_path} " + f"-b {' '.join(str(v) for v in user_bias)} -o {output_path}" + ) + updated_state = torch.load(output_path, map_location=DEVICE, weights_only=True)[ + "model" + ] + self.assertEqual( + updated_state[key].untyped_storage().data_ptr(), + updated_state[peer_key].untyped_storage().data_ptr(), + ) + updated_bias = updated_state["model.Default.atomic_model.out_bias"] + expected_bias = torch.as_tensor( + user_bias, dtype=updated_bias.dtype, device=updated_bias.device + ).view(updated_bias.shape) + torch.testing.assert_close(updated_bias, expected_bias) + def test_change_bias_frozen_pte(self) -> None: from deepmd.pt_expt.model.model import ( BaseModel,