From 512eeb60bacc2d88dcd9738e468ce26e0198d5be Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 00:51:40 +0800 Subject: [PATCH 01/56] feat(pt_expt): multi-task training support --- deepmd/dpmodel/descriptor/repformers.py | 8 + deepmd/dpmodel/fitting/general_fitting.py | 18 +- deepmd/dpmodel/utils/env_mat_stat.py | 69 + deepmd/pt/model/task/fitting.py | 10 - deepmd/pt_expt/descriptor/dpa1.py | 28 + deepmd/pt_expt/descriptor/dpa2.py | 44 + deepmd/pt_expt/descriptor/dpa3.py | 28 + deepmd/pt_expt/descriptor/hybrid.py | 28 +- deepmd/pt_expt/descriptor/se_atten_v2.py | 7 + deepmd/pt_expt/descriptor/se_e2_a.py | 27 + deepmd/pt_expt/descriptor/se_r.py | 27 + deepmd/pt_expt/descriptor/se_t.py | 27 + deepmd/pt_expt/descriptor/se_t_tebd.py | 28 + deepmd/pt_expt/entrypoints/main.py | 230 +- deepmd/pt_expt/fitting/ener_fitting.py | 10 +- deepmd/pt_expt/fitting/invar_fitting.py | 121 +- deepmd/pt_expt/train/training.py | 1003 +++++--- deepmd/pt_expt/train/wrapper.py | 155 +- deepmd/pt_expt/utils/finetune.py | 94 +- deepmd/pt_expt/utils/multi_task.py | 116 + source/tests/pt/test_fitting_stat.py | 4 +- .../descriptor/test_descrpt_stat_merge.py | 1328 +++++++++++ source/tests/pt_expt/descriptor/test_dpa1.py | 43 + source/tests/pt_expt/descriptor/test_dpa3.py | 51 + .../tests/pt_expt/descriptor/test_hybrid.py | 53 + .../pt_expt/descriptor/test_se_atten_v2.py | 43 + source/tests/pt_expt/descriptor/test_se_r.py | 40 + source/tests/pt_expt/descriptor/test_se_t.py | 40 + .../pt_expt/descriptor/test_se_t_tebd.py | 31 + .../pt_expt/fitting/test_fitting_stat.py | 541 +++++ source/tests/pt_expt/test_change_bias.py | 2 +- source/tests/pt_expt/test_finetune.py | 6 +- source/tests/pt_expt/test_multitask.py | 2099 +++++++++++++++++ source/tests/pt_expt/test_training.py | 84 +- source/tests/pt_expt/test_training_ddp.py | 1666 +++++++++++++ 35 files changed, 7705 insertions(+), 404 deletions(-) create mode 100644 deepmd/pt_expt/utils/multi_task.py create mode 100644 source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py create mode 100644 source/tests/pt_expt/test_multitask.py create mode 100644 source/tests/pt_expt/test_training_ddp.py diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 774cf33d72..5881b3a0b3 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -345,6 +345,14 @@ def get_rcut(self) -> float: """Returns the cut-off radius.""" return self.rcut + def get_rcut_smth(self) -> float: + """Returns the radius where the neighbor information starts to smoothly decay to 0.""" + return self.rcut_smth + + def get_env_protection(self) -> float: + """Returns the protection of building environment matrix.""" + return self.env_protection + def get_nsel(self) -> int: """Returns the number of selected atoms in the cut-off radius.""" return sum(self.sel) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 3a3012440c..810c62f927 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -255,6 +255,7 @@ def compute_input_stats( stat_file_path : Optional[DPPath] The path to the stat file. """ + self._param_stats: dict[str, list[StatItem]] = {} if self.numb_fparam == 0 and self.numb_aparam == 0: # skip data statistics return @@ -296,6 +297,7 @@ def compute_input_stats( self._save_param_stats_to_file( stat_file_path, "fparam", fparam_stats ) + self._param_stats["fparam"] = fparam_stats fparam_avg = np.array( [s.compute_avg() for s in fparam_stats], dtype=np.float64 ) @@ -362,6 +364,7 @@ def compute_input_stats( self._save_param_stats_to_file( stat_file_path, "aparam", aparam_stats ) + self._param_stats["aparam"] = aparam_stats aparam_avg = np.array( [s.compute_avg() for s in aparam_stats], dtype=np.float64 ) @@ -407,6 +410,10 @@ def _load_param_stats_from_file( for ii in range(numb) ] + def get_param_stats(self) -> dict[str, list[StatItem]]: + """Get the stored fparam/aparam statistics (populated by compute_input_stats).""" + return getattr(self, "_param_stats", {}) + @abstractmethod def _net_out_dim(self) -> int: """Set the FittingNet output dim.""" @@ -666,11 +673,7 @@ def _call_common( # check fparam dim, concate to input descriptor if self.numb_fparam > 0: assert fparam is not None, "fparam should not be None" - if fparam.shape[-1] != self.numb_fparam: - raise ValueError( - f"get an input fparam of dim {fparam.shape[-1]}, " - f"which is not consistent with {self.numb_fparam}." - ) + fparam = xp.reshape(fparam, (nf, self.numb_fparam)) fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] fparam = xp.tile( xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1) @@ -687,11 +690,6 @@ def _call_common( # check aparam dim, concate to input descriptor if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" - if aparam.shape[-1] != self.numb_aparam: - raise ValueError( - f"get an input aparam of dim {aparam.shape[-1]}, " - f"which is not consistent with {self.numb_aparam}." - ) aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...] xx = xp.concat( diff --git a/deepmd/dpmodel/utils/env_mat_stat.py b/deepmd/dpmodel/utils/env_mat_stat.py index 721723821e..8d53602b18 100644 --- a/deepmd/dpmodel/utils/env_mat_stat.py +++ b/deepmd/dpmodel/utils/env_mat_stat.py @@ -40,6 +40,75 @@ ) +def merge_env_stat( + base_obj: Union["Descriptor", "DescriptorBlock"], + link_obj: Union["Descriptor", "DescriptorBlock"], + model_prob: float = 1.0, +) -> None: + """Merge descriptor env mat stats from link_obj into base_obj. + + Uses probability-weighted merging: merged = base_stats + link_stats * model_prob, + where model_prob = link_prob / base_prob. + Mutates base_obj.stats for chaining (3+ models). + + Parameters + ---------- + base_obj : Descriptor or DescriptorBlock + The base descriptor whose stats will be updated. + link_obj : Descriptor or DescriptorBlock + The linked descriptor whose stats will be merged in. + model_prob : float + The probability weight ratio (link_prob / base_prob). + """ + if ( + getattr(base_obj, "stats", None) is None + or getattr(link_obj, "stats", None) is None + ): + return + if getattr(base_obj, "set_stddev_constant", False) and getattr( + base_obj, "set_davg_zero", False + ): + return + + # Weighted merge of StatItem objects + base_stats = base_obj.stats + link_stats = link_obj.stats + merged_stats = {} + for kk in base_stats: + merged_stats[kk] = base_stats[kk] + link_stats[kk] * model_prob + + # Compute mean/stddev from merged stats + base_env = EnvMatStatSe(base_obj) + base_env.stats = merged_stats + mean, stddev = base_env() + + # Update base_obj stats for chaining + base_obj.stats = merged_stats + + # Update buffers in-place: davg/dstd (simple) or mean/stddev (blocks) + # mean/stddev are numpy arrays; convert to match the buffer's backend + if hasattr(base_obj, "davg"): + xp = array_api_compat.array_namespace(base_obj.dstd) + device = array_api_compat.device(base_obj.dstd) + if not getattr(base_obj, "set_davg_zero", False): + base_obj.davg[...] = xp.asarray( + mean, dtype=base_obj.davg.dtype, device=device + ) + base_obj.dstd[...] = xp.asarray( + stddev, dtype=base_obj.dstd.dtype, device=device + ) + elif hasattr(base_obj, "mean"): + xp = array_api_compat.array_namespace(base_obj.stddev) + device = array_api_compat.device(base_obj.stddev) + if not getattr(base_obj, "set_davg_zero", False): + base_obj.mean[...] = xp.asarray( + mean, dtype=base_obj.mean.dtype, device=device + ) + base_obj.stddev[...] = xp.asarray( + stddev, dtype=base_obj.stddev.dtype, device=device + ) + + class EnvMatStat(BaseEnvMatStat): def compute_stat(self, env_mat: dict[str, Array]) -> dict[str, StatItem]: """Compute the statistics of the environment matrix for a single system. diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 7aac7b9a29..a07b408d72 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -779,11 +779,6 @@ def _forward_common( assert fparam is not None, "fparam should not be None" assert self.fparam_avg is not None assert self.fparam_inv_std is not None - if fparam.shape[-1] != self.numb_fparam: - raise ValueError( - "get an input fparam of dim {fparam.shape[-1]}, ", - "which is not consistent with {self.numb_fparam}.", - ) fparam = fparam.view([nf, self.numb_fparam]) nb, _ = fparam.shape t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) @@ -804,11 +799,6 @@ def _forward_common( assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None - if aparam.shape[-1] != self.numb_aparam: - raise ValueError( - f"get an input aparam of dim {aparam.shape[-1]}, ", - f"which is not consistent with {self.numb_aparam}.", - ) aparam = aparam.view([nf, -1, self.numb_aparam]) nb, nloc, _ = aparam.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py index d72a12267a..01df91abd6 100644 --- a/deepmd/pt_expt/descriptor/dpa1.py +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,31 @@ class DescrptDPA1(DescrptDPA1DP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and se_atten (all modules and buffers). + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.se_atten, self.se_atten, model_prob) + self._modules["se_atten"] = base_class._modules["se_atten"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py index 0d389af070..1723df5a30 100644 --- a/deepmd/pt_expt/descriptor/dpa2.py +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -14,6 +14,9 @@ build_multiple_neighbor_list, get_multiple_nlist_key, ) +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -30,6 +33,47 @@ class DescrptDPA2(DescrptDPA2DP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: "DescrptDPA2", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding, repinit, repinit_three_body, + g1_shape_tranform, and repformers. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.repinit, self.repinit, model_prob) + if self.use_three_body and "repinit_three_body" in base_class._modules: + merge_env_stat( + base_class.repinit_three_body, + self.repinit_three_body, + model_prob, + ) + merge_env_stat(base_class.repformers, self.repformers, model_prob) + self._modules["repinit"] = base_class._modules["repinit"] + if self.use_three_body and "repinit_three_body" in base_class._modules: + self._modules["repinit_three_body"] = base_class._modules[ + "repinit_three_body" + ] + self._modules["g1_shape_tranform"] = base_class._modules[ + "g1_shape_tranform" + ] + self._modules["repformers"] = base_class._modules["repformers"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/dpa3.py b/deepmd/pt_expt/descriptor/dpa3.py index 7119f043bd..fec047e3fd 100644 --- a/deepmd/pt_expt/descriptor/dpa3.py +++ b/deepmd/pt_expt/descriptor/dpa3.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -16,3 +19,28 @@ @torch_module class DescrptDPA3(DescrptDPA3DP): _update_sel_cls = UpdateSel + + def share_params( + self, + base_class: "DescrptDPA3", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and repflows. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.repflows, self.repflows, model_prob) + self._modules["repflows"] = base_class._modules["repflows"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError diff --git a/deepmd/pt_expt/descriptor/hybrid.py b/deepmd/pt_expt/descriptor/hybrid.py index 9ec5570c7c..07eddd2e01 100644 --- a/deepmd/pt_expt/descriptor/hybrid.py +++ b/deepmd/pt_expt/descriptor/hybrid.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP from deepmd.pt_expt.common import ( @@ -12,4 +15,27 @@ @BaseDescriptor.register("hybrid") @torch_module class DescrptHybrid(DescrptHybridDP): - pass + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all sub-descriptors. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + for ii, des in enumerate(self.descrpt_list): + self.descrpt_list[ii].share_params( + base_class.descrpt_list[ii], + shared_level, + model_prob=model_prob, + resume=resume, + ) + else: + raise NotImplementedError diff --git a/deepmd/pt_expt/descriptor/se_atten_v2.py b/deepmd/pt_expt/descriptor/se_atten_v2.py index 2c4be7d3ae..e0eb3acac3 100644 --- a/deepmd/pt_expt/descriptor/se_atten_v2.py +++ b/deepmd/pt_expt/descriptor/se_atten_v2.py @@ -22,6 +22,13 @@ class DescrptSeAttenV2(DescrptSeAttenV2DP): _update_sel_cls = UpdateSel + def share_params(self, *args: Any, **kwargs: Any) -> None: + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + return DescrptDPA1.share_params(self, *args, **kwargs) + def enable_compression(self, *args: Any, **kwargs: Any) -> None: from deepmd.pt_expt.descriptor.dpa1 import ( DescrptDPA1, diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 38be83c46c..61d611036e 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,30 @@ class DescrptSeA(DescrptSeADP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index c2fd34e6b5..22302f54e6 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -26,6 +29,30 @@ class DescrptSeR(DescrptSeRDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 806d5eca7a..061306f281 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -27,6 +30,30 @@ class DescrptSeT(DescrptSeTDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all modules and buffers. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + if not resume: + merge_env_stat(base_class, self, model_prob) + for item in self._modules: + self._modules[item] = base_class._modules[item] + for item in self._buffers: + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 385bf0dfb6..c0ae308971 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -9,6 +9,9 @@ cast_precision, ) from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.dpmodel.utils.env_mat_stat import ( + merge_env_stat, +) from deepmd.pt_expt.common import ( torch_module, ) @@ -25,6 +28,31 @@ class DescrptSeTTebd(DescrptSeTTebdDP): _update_sel_cls = UpdateSel + def share_params( + self, + base_class: "DescrptSeTTebd", + shared_level: int, + model_prob: float = 1.0, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share type_embedding and se_ttebd. + Level 1: share type_embedding only. + """ + assert self.__class__ == base_class.__class__, ( + "Only descriptors of the same type can share params!" + ) + if shared_level == 0: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + if not resume: + merge_env_stat(base_class.se_ttebd, self.se_ttebd, model_prob) + self._modules["se_ttebd"] = base_class._modules["se_ttebd"] + elif shared_level == 1: + self._modules["type_embedding"] = base_class._modules["type_embedding"] + else: + raise NotImplementedError + def enable_compression( self, min_nbor_dist: float, diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 3c82ff13aa..08fc6682a7 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -4,6 +4,7 @@ import argparse import json import logging +import os from pathlib import ( Path, ) @@ -40,54 +41,112 @@ def get_trainer( restart_model: str | None = None, finetune_model: str | None = None, finetune_links: dict | None = None, + shared_links: dict | None = None, ) -> training.Trainer: """Build a :class:`training.Trainer` from a normalised config.""" model_params = config["model"] training_params = config["training"] - type_map = model_params["type_map"] + multi_task = "model_dict" in model_params - # ----- training data ------------------------------------------------ - training_dataset_params = training_params["training_data"] - training_systems = process_systems( - training_dataset_params["systems"], - patterns=training_dataset_params.get("rglob_patterns", None), - ) - train_data = DeepmdDataSystem( - systems=training_systems, - batch_size=training_dataset_params["batch_size"], - test_size=1, - type_map=type_map, - trn_all_set=True, - sys_probs=training_dataset_params.get("sys_probs", None), - auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"), - ) + if not multi_task: + type_map = model_params["type_map"] - # ----- validation data ---------------------------------------------- - validation_data = None - validation_dataset_params = training_params.get("validation_data", None) - if validation_dataset_params is not None: - val_systems = process_systems( - validation_dataset_params["systems"], - patterns=validation_dataset_params.get("rglob_patterns", None), + # ----- training data ------------------------------------------------ + training_dataset_params = training_params["training_data"] + training_systems = process_systems( + training_dataset_params["systems"], + patterns=training_dataset_params.get("rglob_patterns", None), ) - validation_data = DeepmdDataSystem( - systems=val_systems, - batch_size=validation_dataset_params["batch_size"], + train_data = DeepmdDataSystem( + systems=training_systems, + batch_size=training_dataset_params["batch_size"], test_size=1, type_map=type_map, trn_all_set=True, + sys_probs=training_dataset_params.get("sys_probs", None), + auto_prob_style=training_dataset_params.get("auto_prob", "prob_sys_size"), ) - # ----- stat file path ----------------------------------------------- - stat_file_path = training_params.get("stat_file", None) - if stat_file_path is not None: - if not Path(stat_file_path).exists(): - if stat_file_path.endswith((".h5", ".hdf5")): - with h5py.File(stat_file_path, "w"): - pass + # ----- validation data ---------------------------------------------- + validation_data = None + validation_dataset_params = training_params.get("validation_data", None) + if validation_dataset_params is not None: + val_systems = process_systems( + validation_dataset_params["systems"], + patterns=validation_dataset_params.get("rglob_patterns", None), + ) + validation_data = DeepmdDataSystem( + systems=val_systems, + batch_size=validation_dataset_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + + # ----- stat file path ----------------------------------------------- + stat_file_path = training_params.get("stat_file", None) + if stat_file_path is not None: + if not Path(stat_file_path).exists(): + if stat_file_path.endswith((".h5", ".hdf5")): + with h5py.File(stat_file_path, "w"): + pass + else: + Path(stat_file_path).mkdir() + stat_file_path = DPPath(stat_file_path, "a") + else: + # Multi-task: build per-task data systems + train_data = {} + validation_data = {} + stat_file_path = {} + for model_key in model_params["model_dict"]: + type_map = model_params["model_dict"][model_key]["type_map"] + data_params = training_params["data_dict"][model_key] + + # training data + td_params = data_params["training_data"] + training_systems = process_systems( + td_params["systems"], + patterns=td_params.get("rglob_patterns", None), + ) + train_data[model_key] = DeepmdDataSystem( + systems=training_systems, + batch_size=td_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + sys_probs=td_params.get("sys_probs", None), + auto_prob_style=td_params.get("auto_prob", "prob_sys_size"), + ) + + # validation data + vd_params = data_params.get("validation_data", None) + if vd_params is not None: + val_systems = process_systems( + vd_params["systems"], + patterns=vd_params.get("rglob_patterns", None), + ) + validation_data[model_key] = DeepmdDataSystem( + systems=val_systems, + batch_size=vd_params["batch_size"], + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + else: + validation_data[model_key] = None + + # stat file + _sf = data_params.get("stat_file", None) + if _sf is not None: + if not Path(_sf).exists(): + if _sf.endswith((".h5", ".hdf5")): + with h5py.File(_sf, "w"): + pass + else: + Path(_sf).mkdir(parents=True) + stat_file_path[model_key] = DPPath(_sf, "a") else: - Path(stat_file_path).mkdir() - stat_file_path = DPPath(stat_file_path, "a") + stat_file_path[model_key] = None trainer = training.Trainer( config, @@ -98,6 +157,7 @@ def get_trainer( restart_model=restart_model, finetune_model=finetune_model, finetune_links=finetune_links, + shared_links=shared_links, ) return trainer @@ -151,6 +211,19 @@ def train( if restart is not None and not restart.endswith(".pt"): restart += ".pt" + # Multi-task detection and shared params preprocessing + multi_task = "model_dict" in config.get("model", {}) + shared_links = None + if multi_task: + from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, + ) + + config["model"], shared_links = preprocess_shared_params(config["model"]) + assert "RANDOM" not in config["model"]["model_dict"], ( + "Model name can not be 'RANDOM' in multi-task mode!" + ) + # update fine-tuning config finetune_links = None if finetune is not None: @@ -174,7 +247,7 @@ def train( # argcheck config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") - config = normalize(config) + config = normalize(config, multi_task=multi_task) # neighbour stat if not skip_neighbor_stat: @@ -182,28 +255,56 @@ def train( "Calculate neighbor statistics... " "(add --skip-neighbor-stat to skip this step)" ) - type_map = config["model"].get("type_map") - train_data = get_data(config["training"]["training_data"], 0, type_map, None) from deepmd.pt_expt.model import ( BaseModel, ) - config["model"], _min_nbor_dist = BaseModel.update_sel( - train_data, type_map, config["model"] - ) + if not multi_task: + type_map = config["model"].get("type_map") + train_data = get_data( + config["training"]["training_data"], 0, type_map, None + ) + config["model"], _min_nbor_dist = BaseModel.update_sel( + train_data, type_map, config["model"] + ) + else: + for model_key in config["model"]["model_dict"]: + type_map = config["model"]["model_dict"][model_key]["type_map"] + train_data = get_data( + config["training"]["data_dict"][model_key]["training_data"], + 0, + type_map, + None, + ) + config["model"]["model_dict"][model_key], _min_nbor_dist = ( + BaseModel.update_sel( + train_data, + type_map, + config["model"]["model_dict"][model_key], + ) + ) with open(output, "w") as fp: json.dump(config, fp, indent=4) + import torch.distributed as dist + + if os.environ.get("LOCAL_RANK") is not None: + dist.init_process_group(backend="cuda:nccl,cpu:gloo") + trainer = get_trainer( config, init_model, restart, finetune_model=finetune, finetune_links=finetune_links, + shared_links=shared_links, ) trainer.run() + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() + def freeze( model: str, @@ -219,7 +320,7 @@ def freeze( output : str Path for the output .pte file. head : str or None - Head to freeze in multi-task mode (not yet supported). + Head to freeze in multi-task mode. """ import torch @@ -248,18 +349,43 @@ def freeze( ) model_params = extra_state["model_params"] - if head is not None and "model_dict" in model_params: - raise NotImplementedError( - "Multi-task freeze is not yet supported for the pt_expt backend." - ) + multi_task = "model_dict" in model_params + if multi_task: + if head is None: + raise ValueError( + "Multi-task model requires --head to specify which model to freeze. " + f"Available heads: {list(model_params['model_dict'].keys())}" + ) + if head not in model_params["model_dict"]: + raise ValueError( + f"Head '{head}' not found. " + f"Available: {list(model_params['model_dict'].keys())}" + ) + # Build full multi-task wrapper, load weights, extract single head + model_dict = {} + for key in model_params["model_dict"]: + from copy import ( + deepcopy, + ) - m = get_model(model_params) - wrapper = ModelWrapper(m) - wrapper.load_state_dict(state_dict) - m.eval() + model_dict[key] = get_model(deepcopy(model_params["model_dict"][key])) + wrapper = ModelWrapper(model_dict) + wrapper.load_state_dict(state_dict) - model_dict = m.serialize() - deserialize_to_file(output, {"model": model_dict, "model_def_script": model_params}) + m = wrapper.model[head] + single_model_params = model_params["model_dict"][head] + else: + m = get_model(model_params) + wrapper = ModelWrapper(m) + wrapper.load_state_dict(state_dict) + single_model_params = model_params + + m.eval() + model_dict_serialized = m.serialize() + deserialize_to_file( + output, + {"model": model_dict_serialized, "model_def_script": single_model_params}, + ) log.info("Saved frozen model to %s", output) diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py index f778af8fec..e931c72b5c 100644 --- a/deepmd/pt_expt/fitting/ener_fitting.py +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.pt_expt.common import ( @@ -13,4 +16,9 @@ @BaseFitting.register("ener") @torch_module class EnergyFittingNet(EnergyFittingNetDP): - pass + def share_params(self, *args: Any, **kwargs: Any) -> None: + from deepmd.pt_expt.fitting.invar_fitting import ( + InvarFitting, + ) + + return InvarFitting.share_params(self, *args, **kwargs) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py index f13fe2afbb..35eed09bc3 100644 --- a/deepmd/pt_expt/fitting/invar_fitting.py +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -1,4 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import numpy as np +import torch from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP from deepmd.pt_expt.common import ( @@ -7,9 +13,122 @@ from deepmd.pt_expt.fitting.base_fitting import ( BaseFitting, ) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) @BaseFitting.register("invar") @torch_module class InvarFitting(InvarFittingDP): - pass + def share_params( + self, + base_class: Any, + shared_level: int, + model_prob: float = 1.0, + protection: float = 1e-2, + resume: bool = False, + ) -> None: + """Share parameters with base_class for multi-task training. + + Level 0: share all sub-modules and buffers except bias_atom_e + and case_embd. When not resuming, fparam/aparam statistics are + merged using probability-weighted averaging (matching PT). + """ + assert self.__class__ == base_class.__class__, ( + "Only fitting nets of the same type can share params!" + ) + if shared_level == 0: + # --- weighted fparam stat merging --- + if self.numb_fparam > 0: + if not resume: + base_stats = base_class.get_param_stats().get("fparam", []) + self_stats = self.get_param_stats().get("fparam", []) + if base_stats and self_stats: + assert len(base_stats) == self.numb_fparam + merged = [ + base_stats[ii] + self_stats[ii] * model_prob + for ii in range(self.numb_fparam) + ] + fparam_avg = np.array( + [s.compute_avg() for s in merged], dtype=np.float64 + ) + fparam_std = np.array( + [s.compute_std(protection=protection) for s in merged], + dtype=np.float64, + ) + fparam_inv_std = 1.0 / fparam_std + base_class.fparam_avg.copy_( + torch.tensor( + fparam_avg, + device=DEVICE, + dtype=base_class.fparam_avg.dtype, + ) + ) + base_class.fparam_inv_std.copy_( + torch.tensor( + fparam_inv_std, + device=DEVICE, + dtype=base_class.fparam_inv_std.dtype, + ) + ) + # update stored stats so chained share_params works + base_class._param_stats["fparam"] = merged + self._buffers["fparam_avg"] = base_class._buffers["fparam_avg"] + self._buffers["fparam_inv_std"] = base_class._buffers["fparam_inv_std"] + + # --- weighted aparam stat merging --- + if self.numb_aparam > 0: + if not resume: + base_stats = base_class.get_param_stats().get("aparam", []) + self_stats = self.get_param_stats().get("aparam", []) + if base_stats and self_stats: + assert len(base_stats) == self.numb_aparam + merged = [ + base_stats[ii] + self_stats[ii] * model_prob + for ii in range(self.numb_aparam) + ] + aparam_avg = np.array( + [s.compute_avg() for s in merged], dtype=np.float64 + ) + aparam_std = np.array( + [s.compute_std(protection=protection) for s in merged], + dtype=np.float64, + ) + aparam_inv_std = 1.0 / aparam_std + base_class.aparam_avg.copy_( + torch.tensor( + aparam_avg, + device=DEVICE, + dtype=base_class.aparam_avg.dtype, + ) + ) + base_class.aparam_inv_std.copy_( + torch.tensor( + aparam_inv_std, + device=DEVICE, + dtype=base_class.aparam_inv_std.dtype, + ) + ) + base_class._param_stats["aparam"] = merged + self._buffers["aparam_avg"] = base_class._buffers["aparam_avg"] + self._buffers["aparam_inv_std"] = base_class._buffers["aparam_inv_std"] + + # --- share modules and remaining buffers --- + for item in list(self._modules): + if item in ("bias_atom_e", "case_embd"): + continue + self._modules[item] = base_class._modules[item] + for item in list(self._buffers): + if item in ( + "bias_atom_e", + "case_embd", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ): + continue + self._buffers[item] = base_class._buffers[item] + else: + raise NotImplementedError diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 8f32ca660c..779d786de9 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -22,6 +22,7 @@ import numpy as np import torch +import torch.distributed as dist from deepmd.dpmodel.common import ( to_numpy_array, @@ -141,6 +142,28 @@ def get_additional_data_requirement(_model: Any) -> list[DataRequirementItem]: # --------------------------------------------------------------------------- +def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None: + """Remove ``aten.detach.default`` nodes from an FX graph in-place. + + ``make_fx`` inserts these nodes when recording saved tensors from the + autograd backward pass (``autograd.grad`` with ``create_graph=True``). + The detach breaks the gradient connection between saved activations and + model parameters, causing incorrect second-order derivatives — e.g. + bias gradients become zero for force-loss training. + + Removing these nodes restores the gradient path so that higher-order + derivatives flow correctly through the decomposed backward ops. + """ + graph = gm.graph + for node in list(graph.nodes): + if node.op == "call_function" and node.target == torch.ops.aten.detach.default: + input_node = node.args[0] + node.replace_all_uses_with(input_node) + graph.erase_node(node) + graph.lint() + gm.recompile() + + def _trace_and_compile( model: torch.nn.Module, ext_coord: torch.Tensor, @@ -156,7 +179,7 @@ def _trace_and_compile( Parameters ---------- model : torch.nn.Module - The (uncompiled) model. Temporarily set to eval mode for tracing. + The (uncompiled) model. ext_coord, ext_atype, nlist, mapping, fparam, aparam Sample tensors (already padded to the desired max_nall). compile_opts : dict @@ -187,7 +210,7 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - extended_coord = extended_coord.detach().requires_grad_(True) + extended_coord = extended_coord.requires_grad_(True) return model.forward_lower( extended_coord, extended_atype, @@ -202,13 +225,15 @@ def fn( # change at runtime, the caller catches the error and retraces. traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam) + # make_fx inserts aten.detach.default for saved tensors used in the + # decomposed autograd.grad backward ops. These detach nodes break + # second-order gradient flow (d(force)/d(params) for force training). + # Removing them restores correct higher-order derivatives. + _remove_detach_nodes(traced_lower) + if not was_training: model.eval() - # The inductor backend does not propagate gradients through the - # make_fx-decomposed autograd.grad ops (second-order gradients for - # force training). Use "aot_eager" which correctly preserves the - # gradient chain while still benefiting from make_fx decomposition. if "backend" not in compile_opts: compile_opts["backend"] = "aot_eager" compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts) @@ -387,34 +412,38 @@ class Trainer: """Training driver for the pt_expt backend. Uses ``DeepmdDataSystem`` for data loading (numpy batches converted - to torch tensors at the boundary). Single-task, single-GPU only. + to torch tensors at the boundary). Supports single-task and multi-task + training. Single-GPU only. Parameters ---------- config : dict Full training configuration. - training_data : DeepmdDataSystem - Training data. - stat_file_path : DPPath or None + training_data : DeepmdDataSystem or dict + Training data. Dict of ``{model_key: DeepmdDataSystem}`` for multi-task. + stat_file_path : DPPath or dict or None Path for saving / loading statistics. - validation_data : DeepmdDataSystem or None + validation_data : DeepmdDataSystem or dict or None Validation data. init_model : str or None Path to a checkpoint to initialise weights from. restart_model : str or None Path to a checkpoint to *restart* training from (restores step + optimiser). + shared_links : dict or None + Parameter sharing rules for multi-task training. """ def __init__( self, config: dict[str, Any], - training_data: DeepmdDataSystem, - stat_file_path: DPPath | None = None, - validation_data: DeepmdDataSystem | None = None, + training_data: DeepmdDataSystem | dict, + stat_file_path: DPPath | dict | None = None, + validation_data: DeepmdDataSystem | dict | None = None, init_model: str | None = None, restart_model: str | None = None, finetune_model: str | None = None, finetune_links: dict | None = None, + shared_links: dict | None = None, ) -> None: if finetune_model is not None and ( init_model is not None or restart_model is not None @@ -429,6 +458,18 @@ def __init__( model_params = config["model"] training_params = config["training"] + # Multi-task detection + self.multi_task = "model_dict" in model_params + self.model_keys = ( + list(model_params["model_dict"]) if self.multi_task else ["Default"] + ) + self.num_model = len(self.model_keys) + + # Distributed training detection + self.is_distributed = dist.is_available() and dist.is_initialized() + self.rank = dist.get_rank() if self.is_distributed else 0 + self.world_size = dist.get_world_size() if self.is_distributed else 1 + # Iteration config self.num_steps = training_params["numb_steps"] self.disp_file = training_params.get("disp_file", "lcurve.out") @@ -440,47 +481,137 @@ def __init__( self.lcurve_should_print_header = True # Model --------------------------------------------------------------- - self.model = get_model(deepcopy(model_params)).to(DEVICE) + if not self.multi_task: + self.model = get_model(deepcopy(model_params)).to(DEVICE) + else: + self.model = {} + do_case_embd, case_embd_index = _get_case_embd_config(model_params) + for model_key in self.model_keys: + self.model[model_key] = get_model( + deepcopy(model_params["model_dict"][model_key]) + ).to(DEVICE) + if do_case_embd and not resuming: + self.model[model_key].set_case_embd(case_embd_index[model_key]) # Loss ---------------------------------------------------------------- - self.loss = get_loss( - config.get("loss", {}), - config["learning_rate"]["start_lr"], - len(model_params["type_map"]), - self.model, - ) + if not self.multi_task: + self.loss = get_loss( + config.get("loss", {}), + config["learning_rate"]["start_lr"], + len(model_params["type_map"]), + self.model, + ) + else: + self.loss = {} + for model_key in self.model_keys: + loss_param = config["loss_dict"][model_key] + lr_param = config["learning_rate"]["start_lr"] + ntypes = len(model_params["model_dict"][model_key]["type_map"]) + self.loss[model_key] = get_loss( + loss_param, lr_param, ntypes, self.model[model_key] + ) # Data requirements --------------------------------------------------- - data_requirement = self.loss.label_requirement - data_requirement += get_additional_data_requirement(self.model) - training_data.add_data_requirements(data_requirement) - if validation_data is not None: - validation_data.add_data_requirements(data_requirement) - - self.training_data = training_data - self.validation_data = validation_data - self.valid_numb_batch = training_params.get("validation_data", {}).get( - "numb_btch", 1 - ) + if not self.multi_task: + data_requirement = self.loss.label_requirement + data_requirement += get_additional_data_requirement(self.model) + training_data.add_data_requirements(data_requirement) + if validation_data is not None: + validation_data.add_data_requirements(data_requirement) + + self.training_data = training_data + self.validation_data = validation_data + self.valid_numb_batch = training_params.get("validation_data", {}).get( + "numb_btch", 1 + ) + else: + self.training_data = {} + self.validation_data = {} + self.valid_numb_batch = {} + for model_key in self.model_keys: + data_requirement = self.loss[model_key].label_requirement + data_requirement += get_additional_data_requirement( + self.model[model_key] + ) + training_data[model_key].add_data_requirements(data_requirement) + if validation_data[model_key] is not None: + validation_data[model_key].add_data_requirements(data_requirement) + self.training_data[model_key] = training_data[model_key] + self.validation_data[model_key] = validation_data[model_key] + self.valid_numb_batch[model_key] = ( + training_params["data_dict"][model_key] + .get("validation_data", {}) + .get("numb_btch", 1) + ) # Statistics ---------------------------------------------------------- - data_stat_nbatch = model_params.get("data_stat_nbatch", 10) + if not self.multi_task: + data_stat_nbatch = model_params.get("data_stat_nbatch", 10) - @functools.lru_cache - def get_sample() -> list[dict[str, np.ndarray]]: - return make_stat_input(training_data, data_stat_nbatch) + @functools.lru_cache + def get_sample() -> list[dict[str, np.ndarray]]: + return make_stat_input(training_data, data_stat_nbatch) - finetune_has_new_type = ( - finetune_model is not None - and finetune_links is not None - and finetune_links["Default"].get_has_new_type() - ) - if not resuming or finetune_has_new_type: - self.model.compute_or_load_stat( - sampled_func=get_sample, - stat_file_path=stat_file_path, + finetune_has_new_type = ( + finetune_model is not None + and finetune_links is not None + and finetune_links["Default"].get_has_new_type() + ) + if (not resuming or finetune_has_new_type) and self.rank == 0: + self.model.compute_or_load_stat( + sampled_func=get_sample, + stat_file_path=stat_file_path, + ) + if self.is_distributed: + self._broadcast_model_stat(self.model) + else: + self._finetune_update_stat = False + self._sample_funcs: dict[str, Any] = {} + for model_key in self.model_keys: + _nbatch = model_params["model_dict"][model_key].get( + "data_stat_nbatch", 10 + ) + _data = training_data[model_key] + _stat_path = stat_file_path[model_key] if stat_file_path else None + + def _make_sample( + _d: DeepmdDataSystem = _data, _n: int = _nbatch + ) -> list[dict[str, np.ndarray]]: + return make_stat_input(_d, _n) + + self._sample_funcs[model_key] = _make_sample + + _finetune_has_new_type = ( + finetune_model is not None + and finetune_links is not None + and model_key in finetune_links + and finetune_links[model_key].get_has_new_type() + ) + if _finetune_has_new_type: + self._finetune_update_stat = True + if (not resuming or _finetune_has_new_type) and self.rank == 0: + self.model[model_key].compute_or_load_stat( + sampled_func=_make_sample, + stat_file_path=_stat_path, + ) + if self.is_distributed: + for model_key in self.model_keys: + self._broadcast_model_stat(self.model[model_key]) + + # Model probability (multi-task) -------------------------------------- + if self.multi_task: + from deepmd.dpmodel.utils.training_utils import ( + resolve_model_prob, ) + self.model_prob = resolve_model_prob( + self.model_keys, + training_params.get("model_prob"), + training_data, + ) + else: + self.model_prob = None + # Learning rate ------------------------------------------------------- lr_params = config["learning_rate"].copy() lr_params["num_steps"] = self.num_steps @@ -493,6 +624,48 @@ def get_sample() -> list[dict[str, np.ndarray]]: self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) self.start_step = 0 + # Shared params (multi-task) ------------------------------------------ + if shared_links is not None: + _data_stat_protect = np.array( + [ + model_params["model_dict"][ii].get("data_stat_protect", 1e-2) + for ii in model_params["model_dict"] + ] + ) + assert np.allclose(_data_stat_protect, _data_stat_protect[0]), ( + "Model key 'data_stat_protect' must be the same in each branch when multitask!" + ) + self.wrapper.share_params( + shared_links, + resume=(resuming and not self._finetune_update_stat) or self.rank != 0, + model_key_prob_map=dict(zip(self.model_keys, self.model_prob)), + data_stat_protect=_data_stat_protect[0], + ) + + # DDP wrapping -------------------------------------------------------- + if self.is_distributed: + # Multi-task uses only one fitting_net per step, so unused + # parameters exist in the graph. Single-task doesn't need this. + _find_unused = self.multi_task + if DEVICE.type == "cuda": + from deepmd.pt_expt.utils.env import ( + LOCAL_RANK, + ) + + torch.cuda.set_device(LOCAL_RANK) + self.wrapper = torch.nn.parallel.DistributedDataParallel( + self.wrapper, + device_ids=[LOCAL_RANK], + find_unused_parameters=_find_unused, + output_device=LOCAL_RANK, + ) + else: + # CPU (gloo backend) — no device_ids + self.wrapper = torch.nn.parallel.DistributedDataParallel( + self.wrapper, + find_unused_parameters=_find_unused, + ) + # Optimiser ----------------------------------------------------------- opt_type = training_params.get("opt_type", "Adam") initial_lr = float(self.lr_schedule.value(self.start_step)) @@ -545,9 +718,8 @@ def get_sample() -> list[dict[str, np.ndarray]]: if finetune_model is not None and finetune_links is not None: # --- Finetune: selective weight loading ----------------------- - finetune_rule = finetune_links["Default"] - # Build pretrained model and load weights + # Build pretrained model(s) and load weights if is_pte: from deepmd.pt_expt.model import ( BaseModel, @@ -557,58 +729,125 @@ def get_sample() -> list[dict[str, np.ndarray]]: ) data = serialize_from_file(finetune_model) + pretrained_model_params = data["model_def_script"] pretrained_model = BaseModel.deserialize(data["model"]).to(DEVICE) else: - pretrained_model = get_model( - deepcopy(state_dict["_extra_state"]["model_params"]) - ).to(DEVICE) - pretrained_wrapper = ModelWrapper(pretrained_model) + pretrained_model_params = state_dict["_extra_state"]["model_params"] + + # Build pretrained model (single-task or multi-task) + if "model_dict" not in pretrained_model_params: + # Single-task pretrained → wrap as {"Default": model} + if is_pte: + pretrained_models = pretrained_model + else: + pretrained_models = get_model( + deepcopy(pretrained_model_params) + ).to(DEVICE) + else: + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) if not is_pte: pretrained_wrapper.load_state_dict(state_dict) - # Change type map if needed - if ( - finetune_rule.get_finetune_tmap() - != pretrained_wrapper.model.get_type_map() - ): - model_with_new_type_stat = ( - self.wrapper.model if finetune_rule.get_has_new_type() else None - ) - pretrained_wrapper.model.change_type_map( - finetune_rule.get_finetune_tmap(), - model_with_new_type_stat=model_with_new_type_stat, - ) + # Per-branch type map change + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + _model_key_from = finetune_rule.get_model_branch() + if ( + finetune_rule.get_finetune_tmap() + != pretrained_wrapper.model[_model_key_from].get_type_map() + ): + model_with_new_type_stat = ( + self._unwrapped.model[model_key] + if finetune_rule.get_has_new_type() + else None + ) + pretrained_wrapper.model[_model_key_from].change_type_map( + finetune_rule.get_finetune_tmap(), + model_with_new_type_stat=model_with_new_type_stat, + ) - # Selectively copy weights: descriptor always from pretrained, - # fitting from pretrained unless random_fitting is True + # Selective weight copy (per-branch key remapping) pretrained_state = pretrained_wrapper.state_dict() - target_state = self.wrapper.state_dict() + target_state = self._unwrapped.state_dict() new_state = {} for key in target_state: if key == "_extra_state": new_state[key] = target_state[key] - elif ( - finetune_rule.get_random_fitting() and ".descriptor." not in key - ): - new_state[key] = target_state[key] # keep random init - elif key in pretrained_state: - new_state[key] = pretrained_state[key] # from pretrained - else: - new_state[key] = target_state[key] # fallback - self.wrapper.load_state_dict(new_state) - - # Adjust output bias - bias_mode = ( - "change-by-statistic" - if not finetune_rule.get_random_fitting() - else "set-by-statistic" - ) - self.model = model_change_out_bias( - self.model, get_sample, _bias_adjust_mode=bias_mode - ) + continue + # Find which model_key this key belongs to + matched = False + for model_key in self.model_keys: + if f".{model_key}." not in key: + continue + matched = True + finetune_rule = finetune_links[model_key] + _key_from = finetune_rule.get_model_branch() + pretrained_key = key.replace(f".{model_key}.", f".{_key_from}.") + use_random = ( + finetune_rule.get_random_fitting() + and ".descriptor." not in key + ) + if use_random: + new_state[key] = target_state[key] + elif pretrained_key in pretrained_state: + new_state[key] = pretrained_state[pretrained_key] + else: + new_state[key] = target_state[key] + break + if not matched: + new_state[key] = target_state[key] + self._unwrapped.load_state_dict(new_state) + + # Per-branch bias adjustment (rank 0 only, then broadcast) + if not self.multi_task: + finetune_rule = finetune_links["Default"] + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + if self.rank == 0: + self.model = model_change_out_bias( + self.model, get_sample, _bias_adjust_mode=bias_mode + ) + if self.is_distributed: + self._broadcast_model_stat(self.model) + else: + for model_key in self.model_keys: + finetune_rule = finetune_links[model_key] + if finetune_rule.get_resuming(): + log.info(f"Model branch {model_key} will resume training.") + continue + log.info(f"Model branch {model_key} will be fine-tuned.") + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + if self.rank == 0: + self.model[model_key] = model_change_out_bias( + self.model[model_key], + self._sample_funcs[model_key], + _bias_adjust_mode=bias_mode, + ) + if self.is_distributed: + self._broadcast_model_stat(self.model[model_key]) else: # --- Normal resume (init_model / restart) -------------------- - self.wrapper.load_state_dict(state_dict) + self._unwrapped.load_state_dict(state_dict) + + if shared_links is not None: + # Re-apply sharing after loading checkpoint + self._unwrapped.share_params( + shared_links, + resume=True, + model_key_prob_map=dict(zip(self.model_keys, self.model_prob)), + ) if optimizer_state_dict is not None: self.optimizer.load_state_dict(optimizer_state_dict) @@ -622,13 +861,6 @@ def get_sample() -> list[dict[str, np.ndarray]]: ) # torch.compile ------------------------------------------------------- - # The model's forward uses torch.autograd.grad (for forces) with - # create_graph=True so the loss backward can differentiate through - # forces. torch.compile does not support this "double backward". - # - # Solution: use make_fx to trace the model forward, which decomposes - # torch.autograd.grad into primitive ops. The resulting traced - # module is then compiled by torch.compile — no double backward. self.enable_compile = training_params.get("enable_compile", False) if self.enable_compile: compile_opts = training_params.get("compile_options", {}) @@ -666,108 +898,117 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: normalize_coord, ) - model = self.model - - # --- Estimate max_nall by sampling multiple batches --- - n_sample = 20 - max_nall = 0 - best_sample: ( - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None - ) = None - - for _ii in range(n_sample): - inp, _ = self.get_data(is_train=True) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() - - nframes, nloc = atype.shape[:2] - coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) - atype_np = atype.cpu().numpy() - box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None - - if box_np is not None: - coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3)) - else: - coord_norm = coord_np + for task_key in self.model_keys: + model = self.wrapper.model[task_key] + + # --- Estimate max_nall by sampling multiple batches --- + n_sample = 20 + max_nall = 0 + best_sample: ( + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None + ) = None + + for _ii in range(n_sample): + inp, _ = self.get_data(is_train=True, task_key=task_key) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() + + nframes, nloc = atype.shape[:2] + coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) + atype_np = atype.cpu().numpy() + box_np = ( + box.cpu().numpy().reshape(nframes, 9) if box is not None else None + ) - ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( - coord_norm, atype_np, box_np, model.get_rcut() - ) - nlist_np = build_neighbor_list( - ext_coord_np, - ext_atype_np, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) - nall = ext_coord_np.shape[1] - if nall > max_nall: - max_nall = nall - best_sample = ( + if box_np is not None: + coord_norm = normalize_coord( + coord_np, box_np.reshape(nframes, 3, 3) + ) + else: + coord_norm = coord_np + + ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( + coord_norm, atype_np, box_np, model.get_rcut() + ) + nlist_np = build_neighbor_list( ext_coord_np, ext_atype_np, - mapping_np, - nlist_np, nloc, - inp, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, ) + ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) + nall = ext_coord_np.shape[1] + if nall > max_nall: + max_nall = nall + best_sample = ( + ext_coord_np, + ext_atype_np, + mapping_np, + nlist_np, + nloc, + inp, + ) - # Add 20 % margin and round up to a multiple of 8. - max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 - log.info( - "Estimated max_nall=%d for compiled model (sampled %d batches).", - max_nall, - n_sample, - ) - - # --- Pad the largest sample to max_nall and trace --- - assert best_sample is not None - ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( - best_sample - ) - nframes = ext_coord_np.shape[0] - actual_nall = ext_coord_np.shape[1] - pad_n = max_nall - actual_nall - - if pad_n > 0: - ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) - ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) - mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) + # Add 20 % margin and round up to a multiple of 8. + max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 + log.info( + "Estimated max_nall=%d for compiled model " + "(task=%s, sampled %d batches).", + max_nall, + task_key, + n_sample, + ) - ext_coord = torch.tensor( - ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE - ) - ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) - nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) - mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) - fparam = sample_input.get("fparam") - aparam = sample_input.get("aparam") + # --- Pad the largest sample to max_nall and trace --- + assert best_sample is not None + ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( + best_sample + ) + nframes = ext_coord_np.shape[0] + actual_nall = ext_coord_np.shape[1] + pad_n = max_nall - actual_nall - compile_opts.pop("dynamic", None) # always False for padded approach + if pad_n > 0: + ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) + ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) + mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) - compiled_lower = _trace_and_compile( - model, - ext_coord, - ext_atype, - nlist_t, - mapping_t, - fparam, - aparam, - compile_opts, - ) + ext_coord = torch.tensor( + ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + ) + ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) + fparam = sample_input.get("fparam") + aparam = sample_input.get("aparam") + + task_compile_opts = dict(compile_opts) + task_compile_opts.pop("dynamic", None) # always False for padded approach + + compiled_lower = _trace_and_compile( + model, + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + task_compile_opts, + ) - self.wrapper.model = _CompiledModel( - model, compiled_lower, max_nall, compile_opts - ) - log.info( - "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).", - max_nall, - ) + self.wrapper.model[task_key] = _CompiledModel( + model, compiled_lower, max_nall, task_compile_opts + ) + log.info( + "Model compiled with padded nall=%d (task=%s, dynamic=False).", + max_nall, + task_key, + ) # ------------------------------------------------------------------ # Data helpers @@ -776,14 +1017,29 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: def get_data( self, is_train: bool = True, + task_key: str = "Default", ) -> tuple[dict[str, Any], dict[str, Any]]: """Fetch a batch and split into input / label dicts. + Parameters + ---------- + is_train : bool + Whether to fetch from training or validation data. + task_key : str + Task key for multi-task training. + Returns ------- input_dict, label_dict """ - data_sys = self.training_data if is_train else self.validation_data + if not self.multi_task: + data_sys = self.training_data if is_train else self.validation_data + else: + data_sys = ( + self.training_data[task_key] + if is_train + else self.validation_data[task_key] + ) if data_sys is None: return {}, {} @@ -812,14 +1068,33 @@ def get_data( return input_dict, label_dict + # ------------------------------------------------------------------ + # DDP helpers + # ------------------------------------------------------------------ + + @property + def _unwrapped(self) -> "ModelWrapper": + """Return the raw ModelWrapper, unwrapping DDP if active.""" + if hasattr(self.wrapper, "module"): + return self.wrapper.module + return self.wrapper + + @staticmethod + def _broadcast_model_stat(model: torch.nn.Module) -> None: + """Broadcast model parameters and buffers from rank 0 to all ranks.""" + for p in model.parameters(): + dist.broadcast(p.data, src=0) + for b in model.buffers(): + dist.broadcast(b, src=0) + # ------------------------------------------------------------------ # Checkpointing # ------------------------------------------------------------------ def save_checkpoint(self, step: int) -> None: - self.wrapper.train_infos["step"] = step + self._unwrapped.train_infos["step"] = step state = { - "model": self.wrapper.state_dict(), + "model": self._unwrapped.state_dict(), "optimizer": self.optimizer.state_dict(), } ckpt_path = f"{self.save_ckpt}-{step}.pt" @@ -846,10 +1121,16 @@ def _optimizer_step(self) -> None: self.scheduler.step() def run(self) -> None: - fout = open( - self.disp_file, - mode="w" if not self.restart_training else "a", - buffering=1, + from deepmd.utils import random as dp_random + + fout = ( + open( + self.disp_file, + mode="w" if not self.restart_training else "a", + buffering=1, + ) + if self.rank == 0 + else None ) log.info("Start to train %d steps.", self.num_steps) @@ -860,16 +1141,28 @@ def run(self) -> None: for step_id in range(self.start_step, self.num_steps): cur_lr = float(self.lr_schedule.value(step_id)) + # --- task selection (multi-task) --- + task_key = "Default" + if self.multi_task: + model_index = dp_random.choice( + np.arange(self.num_model, dtype=np.int_), + p=self.model_prob, + ) + task_key = self.model_keys[model_index] + if self.timing_in_training: t_start = time.time() # --- forward / backward --- self.optimizer.zero_grad(set_to_none=True) - input_dict, label_dict = self.get_data(is_train=True) + input_dict, label_dict = self.get_data(is_train=True, task_key=task_key) cur_lr_sched = self.scheduler.get_last_lr()[0] model_pred, loss, more_loss = self.wrapper( - **input_dict, cur_lr=cur_lr_sched, label=label_dict + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + task_key=task_key if self.multi_task else None, ) loss.backward() @@ -890,104 +1183,183 @@ def run(self) -> None: ): self.wrapper.eval() - train_results = {k: v for k, v in more_loss.items() if "l2_" not in k} - - # validation - valid_results: dict[str, Any] = {} - if self.validation_data is not None: - sum_natoms = 0 - for _ii in range(self.valid_numb_batch): - val_input, val_label = self.get_data(is_train=False) - if not val_input: - break - _, _vloss, _vmore = self.wrapper( - **val_input, cur_lr=cur_lr_sched, label=val_label - ) - natoms = int(val_input["atype"].shape[-1]) - sum_natoms += natoms - for k, v in _vmore.items(): - if "l2_" not in k: - valid_results[k] = ( - valid_results.get(k, 0.0) + v * natoms + if self.rank == 0: + if not self.multi_task: + train_results = { + k: v for k, v in more_loss.items() if "l2_" not in k + } + + # validation + valid_results: dict[str, Any] = {} + if self.validation_data is not None: + sum_natoms = 0 + for _ii in range(self.valid_numb_batch): + val_input, val_label = self.get_data(is_train=False) + if not val_input: + break + _, _vloss, _vmore = self._unwrapped( + **val_input, + cur_lr=cur_lr_sched, + label=val_label, ) - if sum_natoms > 0: - valid_results = { - k: v / sum_natoms for k, v in valid_results.items() + natoms = int(val_input["atype"].shape[-1]) + sum_natoms += natoms + for k, v in _vmore.items(): + if "l2_" not in k: + valid_results[k] = ( + valid_results.get(k, 0.0) + v * natoms + ) + if sum_natoms > 0: + valid_results = { + k: v / sum_natoms for k, v in valid_results.items() + } + else: + # Multi-task: compute loss for ALL tasks + train_results = {_key: {} for _key in self.model_keys} + valid_results = {_key: {} for _key in self.model_keys} + + # current task already has loss + train_results[task_key] = { + k: v for k, v in more_loss.items() if "l2_" not in k } - # wall-clock time - current_time = time.time() - wall_elapsed = current_time - wall_start - interval_wall_time = current_time - last_log_time - last_log_time = current_time - if self.timing_in_training: - step_time = t_end - t_start - steps_completed_since_restart = max( - 1, - display_step_id - self.start_step, - ) - eta = int( - (self.num_steps - display_step_id) - / steps_completed_since_restart - * wall_elapsed - ) - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, - eta=eta, - current_time=datetime.datetime.fromtimestamp( - current_time, - tz=datetime.timezone.utc, - ).astimezone(), + # compute loss for other tasks + for _key in self.model_keys: + if _key != task_key: + self.optimizer.zero_grad() + _inp, _lab = self.get_data(is_train=True, task_key=_key) + _, _loss, _more = self._unwrapped( + **_inp, + cur_lr=cur_lr_sched, + label=_lab, + task_key=_key, + ) + train_results[_key] = { + k: v for k, v in _more.items() if "l2_" not in k + } + + # validation for each task + _vdata = self.validation_data[_key] + if _vdata is not None: + _sum_natoms = 0 + _vres: dict[str, Any] = {} + for _ii in range(self.valid_numb_batch[_key]): + _vi, _vl = self.get_data( + is_train=False, task_key=_key + ) + if not _vi: + break + _, _vloss, _vmore = self._unwrapped( + **_vi, + cur_lr=cur_lr_sched, + label=_vl, + task_key=_key, + ) + natoms = int(_vi["atype"].shape[-1]) + _sum_natoms += natoms + for k, v in _vmore.items(): + if "l2_" not in k: + _vres[k] = _vres.get(k, 0.0) + v * natoms + if _sum_natoms > 0: + _vres = { + k: v / _sum_natoms for k, v in _vres.items() + } + valid_results[_key] = _vres + # wall-clock time + current_time = time.time() + wall_elapsed = current_time - wall_start + interval_wall_time = current_time - last_log_time + last_log_time = current_time + if self.timing_in_training: + step_time = t_end - t_start + steps_completed_since_restart = max( + 1, + display_step_id - self.start_step, ) - ) - log.info("step=%d step_time=%.4fs", display_step_id, step_time) - else: - log.info( - format_training_message( - batch=display_step_id, - wall_time=interval_wall_time, + eta = int( + (self.num_steps - display_step_id) + / steps_completed_since_restart + * wall_elapsed + ) + log.info( + format_training_message( + batch=display_step_id, + wall_time=interval_wall_time, + eta=eta, + current_time=datetime.datetime.fromtimestamp( + current_time, + tz=datetime.timezone.utc, + ).astimezone(), + ) + ) + log.info("step=%d step_time=%.4fs", display_step_id, step_time) + else: + log.info( + format_training_message( + batch=display_step_id, + wall_time=interval_wall_time, + ) ) - ) - # log - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="trn", - rmse=train_results, - learning_rate=cur_lr, - ) - ) - if valid_results: - log.info( - format_training_message_per_task( - batch=display_step_id, - task_name="val", - rmse=valid_results, - learning_rate=None, + # log + if not self.multi_task: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="trn", + rmse=train_results, + learning_rate=cur_lr, + ) ) - ) + if valid_results: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name="val", + rmse=valid_results, + learning_rate=None, + ) + ) + else: + for _key in self.model_keys: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name=_key + "_trn", + rmse=train_results[_key], + learning_rate=cur_lr, + ) + ) + if valid_results[_key]: + log.info( + format_training_message_per_task( + batch=display_step_id, + task_name=_key + "_val", + rmse=valid_results[_key], + learning_rate=None, + ) + ) - # lcurve file - if self.lcurve_should_print_header: - self.print_header(fout, train_results, valid_results) - self.lcurve_should_print_header = False - self.print_on_training( - fout, display_step_id, cur_lr, train_results, valid_results - ) + # lcurve file + if self.lcurve_should_print_header: + self.print_header(fout, train_results, valid_results) + self.lcurve_should_print_header = False + self.print_on_training( + fout, display_step_id, cur_lr, train_results, valid_results + ) self.wrapper.train() # --- checkpoint --- - if display_step_id % self.save_freq == 0: + if display_step_id % self.save_freq == 0 and self.rank == 0: self.save_checkpoint(display_step_id) # final save - self.save_checkpoint(self.num_steps) + if self.rank == 0: + self.save_checkpoint(self.num_steps) wall_total = time.time() - wall_start - fout.close() + if fout is not None: + fout.close() log.info("Training finished. Total wall time: %.2fs", wall_total) # ------------------------------------------------------------------ @@ -1000,14 +1372,23 @@ def print_header( train_results: dict[str, Any], valid_results: dict[str, Any], ) -> None: - train_keys = sorted(train_results.keys()) header = "# {:5s}".format("step") - if valid_results: - for k in train_keys: - header += f" {k + '_val':>11s} {k + '_trn':>11s}" + if not self.multi_task: + train_keys = sorted(train_results.keys()) + if valid_results: + for k in train_keys: + header += f" {k + '_val':>11s} {k + '_trn':>11s}" + else: + for k in train_keys: + header += f" {k + '_trn':>11s}" else: - for k in train_keys: - header += f" {k + '_trn':>11s}" + for model_key in self.model_keys: + if valid_results[model_key]: + for k in sorted(train_results[model_key].keys()): + header += f" {k + '_val_' + model_key:>11s} {k + '_trn_' + model_key:>11s}" + else: + for k in sorted(train_results[model_key].keys()): + header += f" {k + '_trn_' + model_key:>11s}" header += " {:8s}\n".format("lr") fout.write(header) fout.flush() @@ -1020,14 +1401,23 @@ def print_on_training( train_results: dict, valid_results: dict, ) -> None: - train_keys = sorted(train_results.keys()) line = f"{step_id:7d}" - if valid_results: - for k in train_keys: - line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" + if not self.multi_task: + train_keys = sorted(train_results.keys()) + if valid_results: + for k in train_keys: + line += f" {valid_results.get(k, float('nan')):11.2e} {train_results[k]:11.2e}" + else: + for k in train_keys: + line += f" {train_results[k]:11.2e}" else: - for k in train_keys: - line += f" {train_results[k]:11.2e}" + for model_key in self.model_keys: + if valid_results[model_key]: + for k in sorted(valid_results[model_key].keys()): + line += f" {valid_results[model_key][k]:11.2e} {train_results[model_key][k]:11.2e}" + else: + for k in sorted(train_results[model_key].keys()): + line += f" {train_results[model_key][k]:11.2e}" line += f" {cur_lr:8.1e}\n" fout.write(line) fout.flush() @@ -1074,3 +1464,40 @@ def model_change_out_bias( f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}." ) return _model + + +def _get_case_embd_config( + model_params: dict[str, Any], +) -> tuple[bool, dict[str, int]]: + """Check whether case embedding is enabled and build the index map. + + Parameters + ---------- + model_params : dict + Model parameters containing ``model_dict``. + + Returns + ------- + do_case_embd : bool + Whether case embedding is enabled. + case_embd_index : dict + Mapping from model key to case index (sorted alphabetically). + """ + assert "model_dict" in model_params, ( + "Only support setting case embedding for multi-task model!" + ) + model_keys = list(model_params["model_dict"]) + sorted_model_keys = sorted(model_keys) + numb_case_embd_list = [ + model_params["model_dict"][mk].get("fitting_net", {}).get("dim_case_embd", 0) + for mk in sorted_model_keys + ] + if not all(item == numb_case_embd_list[0] for item in numb_case_embd_list): + raise ValueError( + "All models must have the same dimension of case embedding, " + f"while the settings are: {numb_case_embd_list}" + ) + if numb_case_embd_list[0] == 0: + return False, {} + case_embd_index = {mk: idx for idx, mk in enumerate(sorted_model_keys)} + return True, case_embd_index diff --git a/deepmd/pt_expt/train/wrapper.py b/deepmd/pt_expt/train/wrapper.py index 281168cdba..f67efe8a8e 100644 --- a/deepmd/pt_expt/train/wrapper.py +++ b/deepmd/pt_expt/train/wrapper.py @@ -10,24 +10,24 @@ class ModelWrapper(torch.nn.Module): - """Simplified model wrapper that bundles a model and a loss. + """Model wrapper that bundles model(s) and loss(es). - Single-task only for now (no multi-task support). + Supports both single-task and multi-task training. Parameters ---------- - model : torch.nn.Module - The model to train. - loss : torch.nn.Module - The loss module. + model : torch.nn.Module or dict + Single model or dict of models keyed by task name. + loss : torch.nn.Module or dict or None + Single loss or dict of losses keyed by task name. model_params : dict, optional Model parameters to store as extra state. """ def __init__( self, - model: torch.nn.Module, - loss: torch.nn.Module | None = None, + model: torch.nn.Module | dict, + loss: torch.nn.Module | dict | None = None, model_params: dict[str, Any] | None = None, ) -> None: super().__init__() @@ -36,10 +36,133 @@ def __init__( "lr": 0, "step": 0, } - self.model = model - self.loss = loss + self.multi_task = False + self.model = torch.nn.ModuleDict() + # Model + if isinstance(model, torch.nn.Module): + self.model["Default"] = model + elif isinstance(model, dict): + self.multi_task = True + for task_key in model: + assert isinstance(model[task_key], torch.nn.Module), ( + f"{task_key} in model_dict is not a torch.nn.Module!" + ) + self.model[task_key] = model[task_key] + # Loss — dpmodel losses are not nn.Module, so store in a plain dict. + self.loss: dict[str, Any] | None = None + if loss is not None: + if isinstance(loss, dict): + self.loss = dict(loss) + else: + self.loss = {"Default": loss} self.inference_only = self.loss is None + def share_params( + self, + shared_links: dict[str, Any], + model_key_prob_map: dict, + data_stat_protect: float = 1e-2, + resume: bool = False, + ) -> None: + """Share parameters between models following rules in shared_links. + + Parameters + ---------- + shared_links : dict + Sharing rules from ``preprocess_shared_params``. + model_key_prob_map : dict + Probability map for each model key (for fitting_net stat weighting). + data_stat_protect : float + Protection value for standard deviation computation. + resume : bool + Whether resuming from checkpoint. + """ + for shared_item in shared_links: + shared_base = shared_links[shared_item]["links"][0] + class_type_base = shared_base["shared_type"] + model_key_base = shared_base["model_key"] + shared_level_base = shared_base["shared_level"] + if "descriptor" in class_type_base: + if class_type_base == "descriptor": + base_class = self.model[model_key_base].get_descriptor() + elif "hybrid" in class_type_base: + hybrid_index = int(class_type_base.split("_")[-1]) + base_class = ( + self.model[model_key_base] + .get_descriptor() + .descrpt_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_base}!") + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert "descriptor" in class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + if class_type_link == "descriptor": + link_class = self.model[model_key_link].get_descriptor() + elif "hybrid" in class_type_link: + hybrid_index = int(class_type_link.split("_")[-1]) + link_class = ( + self.model[model_key_link] + .get_descriptor() + .descrpt_list[hybrid_index] + ) + else: + raise RuntimeError(f"Unknown class_type {class_type_link}!") + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) + link_class.share_params( + base_class, + shared_level_link, + model_prob=frac_prob, + resume=resume, + ) + log.warning( + f"Shared params of {model_key_base}.{class_type_base} " + f"and {model_key_link}.{class_type_link}!" + ) + else: + if hasattr(self.model[model_key_base].atomic_model, class_type_base): + base_class = self.model[model_key_base].atomic_model.__getattr__( + class_type_base + ) + for link_item in shared_links[shared_item]["links"][1:]: + class_type_link = link_item["shared_type"] + model_key_link = link_item["model_key"] + shared_level_link = int(link_item["shared_level"]) + assert shared_level_link >= shared_level_base, ( + "The shared_links must be sorted by shared_level!" + ) + assert class_type_base == class_type_link, ( + f"Class type mismatched: {class_type_base} vs {class_type_link}!" + ) + link_class = self.model[ + model_key_link + ].atomic_model.__getattr__(class_type_link) + frac_prob = ( + model_key_prob_map[model_key_link] + / model_key_prob_map[model_key_base] + ) + link_class.share_params( + base_class, + shared_level_link, + model_prob=frac_prob, + protection=data_stat_protect, + resume=resume, + ) + log.warning( + f"Shared params of {model_key_base}.{class_type_base} " + f"and {model_key_link}.{class_type_link}!" + ) + def forward( self, coord: torch.Tensor, @@ -49,8 +172,16 @@ def forward( aparam: torch.Tensor | None = None, cur_lr: float | torch.Tensor | None = None, label: dict[str, torch.Tensor] | None = None, + task_key: str | None = None, do_atomic_virial: bool = False, ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None, dict | None]: + if not self.multi_task: + task_key = "Default" + else: + assert task_key is not None, ( + f"Multitask model must specify the inference task! " + f"Supported tasks are {list(self.model.keys())}." + ) input_dict = { "coord": coord, "atype": atype, @@ -60,13 +191,13 @@ def forward( "aparam": aparam, } - model_pred = self.model(**input_dict) + model_pred = self.model[task_key](**input_dict) if self.inference_only or label is None: return model_pred, None, None else: natoms = atype.shape[-1] - loss, more_loss = self.loss( + loss, more_loss = self.loss[task_key]( cur_lr, natoms, model_pred, diff --git a/deepmd/pt_expt/utils/finetune.py b/deepmd/pt_expt/utils/finetune.py index 5e49d8738b..d26db9e847 100644 --- a/deepmd/pt_expt/utils/finetune.py +++ b/deepmd/pt_expt/utils/finetune.py @@ -47,7 +47,7 @@ def get_finetune_rules( model_branch: str = "", change_model_params: bool = True, ) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: - """Get fine-tuning rules for a single-task pt_expt model. + """Get fine-tuning rules for a single-task or multi-task pt_expt model. Loads a pretrained ``.pt`` checkpoint or ``.pte`` frozen model and builds ``FinetuneRuleItem`` objects describing how to map types and @@ -70,29 +70,81 @@ def get_finetune_rules( model_config : dict Possibly updated model config. finetune_links : dict[str, FinetuneRuleItem] - Fine-tuning rules keyed by ``"Default"``. + Fine-tuning rules keyed by model branch name (``"Default"`` for + single-task, or per-branch keys for multi-task). """ last_model_params = _load_model_params(finetune_model) if change_model_params and "descriptor" not in last_model_params: - raise ValueError( - "Cannot use --use-pretrain-script: the pretrained model does not " - "contain full model params. If finetuning from a .pte file, " - "re-freeze it with the latest code so that model_def_script is embedded." - ) - + # For multi-task pretrained, check inside model_dict + if "model_dict" not in last_model_params or "descriptor" not in next( + iter(last_model_params["model_dict"].values()) + ): + raise ValueError( + "Cannot use --use-pretrain-script: the pretrained model does not " + "contain full model params. If finetuning from a .pte file, " + "re-freeze it with the latest code so that model_def_script is embedded." + ) + + multi_task = "model_dict" in model_config finetune_from_multi_task = "model_dict" in last_model_params - - # pt_expt is single-task only - if model_branch == "" and "finetune_head" in model_config: - model_branch = model_config["finetune_head"] - model_config, finetune_rule = get_finetune_rule_single( - model_config, - last_model_params, - from_multitask=finetune_from_multi_task, - model_branch="Default", - model_branch_from=model_branch, - change_model_params=change_model_params, - ) - finetune_links: dict[str, FinetuneRuleItem] = {"Default": finetune_rule} + finetune_links: dict[str, FinetuneRuleItem] = {} + + if not multi_task: + # Single-task target + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config, finetune_rule = get_finetune_rule_single( + model_config, + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch="Default", + model_branch_from=model_branch, + change_model_params=change_model_params, + ) + finetune_links["Default"] = finetune_rule + else: + # Multi-task target — mirrors PT's logic + assert model_branch == "", ( + "Multi-task fine-tuning does not support command-line branches chosen!" + "Please define the 'finetune_head' in each model params!" + ) + if not finetune_from_multi_task: + pretrained_keys = ["Default"] + else: + pretrained_keys = list(last_model_params["model_dict"].keys()) + for model_key in model_config["model_dict"]: + resuming = False + if ( + "finetune_head" in model_config["model_dict"][model_key] + and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" + ): + pretrained_key = model_config["model_dict"][model_key]["finetune_head"] + assert pretrained_key in pretrained_keys, ( + f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" + f"Available heads are: {list(pretrained_keys)}" + ) + model_branch_from = pretrained_key + elif ( + "finetune_head" not in model_config["model_dict"][model_key] + and model_key in pretrained_keys + ): + # resume — no finetune + model_branch_from = model_key + resuming = True + else: + # new branch or RANDOM → random fitting + model_branch_from = "RANDOM" + model_config["model_dict"][model_key], finetune_rule = ( + get_finetune_rule_single( + model_config["model_dict"][model_key], + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch=model_key, + model_branch_from=model_branch_from, + change_model_params=change_model_params, + ) + ) + finetune_links[model_key] = finetune_rule + finetune_links[model_key].resuming = resuming return model_config, finetune_links diff --git a/deepmd/pt_expt/utils/multi_task.py b/deepmd/pt_expt/utils/multi_task.py new file mode 100644 index 0000000000..a4600d5ebb --- /dev/null +++ b/deepmd/pt_expt/utils/multi_task.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.fitting import ( + BaseFitting, +) + + +def preprocess_shared_params( + model_config: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any]]: + """Preprocess the model params for multitask model, and generate the links dict for further sharing. + + Args: + model_config: Model params of multitask model. + + Returns + ------- + model_config: Preprocessed model params of multitask model. + Those string names are replaced with real params in `shared_dict` of model params. + shared_links: Dict of link infos for further sharing. + Each item, whose key must be in `shared_dict`, is a dict with following keys: + - "type": The real class type of this item. + - "links": List of shared settings, each sub-item is a dict with following keys: + - "model_key": Model key in the `model_dict` to share this item. + - "shared_type": Type of this shard item. + - "shared_level": Shared level (int) of this item in this model. + Lower for more params to share, 0 means to share all params in this item. + This list are sorted by "shared_level". + """ + assert "model_dict" in model_config, "only multi-task model can use this method!" + supported_types = ["type_map", "descriptor", "fitting_net"] + shared_dict = model_config.get("shared_dict", {}) + shared_links = {} + type_map_keys = [] + + def replace_one_item( + params_dict: dict[str, Any], + key_type: str, + key_in_dict: str, + suffix: str = "", + index: int | None = None, + ) -> None: + shared_type = key_type + shared_key = key_in_dict + shared_level = 0 + if ":" in key_in_dict: + shared_key = key_in_dict.split(":")[0] + shared_level = int(key_in_dict.split(":")[1]) + assert shared_key in shared_dict, ( + f"Appointed {shared_type} {shared_key} are not in the shared_dict! Please check the input params." + ) + if index is None: + params_dict[shared_type] = deepcopy(shared_dict[shared_key]) + else: + params_dict[index] = deepcopy(shared_dict[shared_key]) + if shared_type == "type_map": + if key_in_dict not in type_map_keys: + type_map_keys.append(key_in_dict) + else: + if shared_key not in shared_links: + class_name = get_class_name(shared_type, shared_dict[shared_key]) + shared_links[shared_key] = {"type": class_name, "links": []} + link_item = { + "model_key": model_key, + "shared_type": shared_type + suffix, + "shared_level": shared_level, + } + shared_links[shared_key]["links"].append(link_item) + + for model_key in model_config["model_dict"]: + model_params_item = model_config["model_dict"][model_key] + for item_key in model_params_item: + if item_key in supported_types: + item_params = model_params_item[item_key] + if isinstance(item_params, str): + replace_one_item(model_params_item, item_key, item_params) + elif item_params.get("type", "") == "hybrid": + for ii, hybrid_item in enumerate(item_params["list"]): + if isinstance(hybrid_item, str): + replace_one_item( + model_params_item[item_key]["list"], + item_key, + hybrid_item, + suffix=f"_hybrid_{ii}", + index=ii, + ) + for shared_key in shared_links: + shared_links[shared_key]["links"] = sorted( + shared_links[shared_key]["links"], + key=lambda x: ( + x["shared_level"] + - ("spin" in model_config["model_dict"][x["model_key"]]) * 100 + ), + ) + # little trick to make spin models in the front to be the base models, + # because its type embeddings are more general. + assert len(type_map_keys) == 1, "Multitask model must have only one type_map!" + return model_config, shared_links + + +def get_class_name(item_key: str, item_params: dict[str, Any]) -> type: + if item_key == "descriptor": + return BaseDescriptor.get_class_by_type(item_params.get("type", "se_e2_a")) + elif item_key == "fitting_net": + return BaseFitting.get_class_by_type(item_params.get("type", "ener")) + else: + raise RuntimeError(f"Unknown class_name type {item_key}") diff --git a/source/tests/pt/test_fitting_stat.py b/source/tests/pt/test_fitting_stat.py index 7807523221..80d213bcad 100644 --- a/source/tests/pt/test_fitting_stat.py +++ b/source/tests/pt/test_fitting_stat.py @@ -280,7 +280,7 @@ def test_sharefitting_with_fparam(self): self.config["training"]["data_dict"]["model_2"]["validation_data"][ "systems" ] = self.data_file_single - self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 100 + self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 80 self.config["model"], self.shared_links = preprocess_shared_params( self.config["model"] @@ -391,7 +391,7 @@ def test_sharefitting_using_default_fparam(self): ] = self.data_file data_stat_protect = 5e-3 self.config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 3 - self.config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 100 + self.config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 80 self.config["model"]["model_dict"]["model_1"]["data_stat_protect"] = ( data_stat_protect ) diff --git a/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py b/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py new file mode 100644 index 0000000000..a8f420b2db --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_descrpt_stat_merge.py @@ -0,0 +1,1328 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for probability-weighted stat merging in descriptor share_params.""" + +from typing import ( + ClassVar, +) + +import numpy as np +import pytest +import torch + +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.dpmodel.descriptor.dpa3 import ( + RepFlowArgs, +) +from deepmd.dpmodel.utils.env_mat_stat import ( + EnvMatStatSe, + merge_env_stat, +) +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt_expt.descriptor.dpa3 import ( + DescrptDPA3, +) +from deepmd.pt_expt.descriptor.hybrid import ( + DescrptHybrid, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.env_mat_stat import ( + StatItem, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +def _make_stats(ntypes: int, last_dim: int, rng: np.random.Generator) -> dict: + """Create synthetic StatItem stats for an env mat descriptor. + + The stats dict has keys "r_{i}" and optionally "a_{i}" for each type, + matching the EnvMatStatSe convention. + """ + stats = {} + for ti in range(ntypes): + # Use moderate values to avoid zero-division + n = rng.uniform(100, 500) + s = rng.uniform(-10, 10) + sq = s**2 / n + rng.uniform(0.01, 1.0) # ensure variance > 0 + stats[f"r_{ti}"] = StatItem(number=n, sum=s, squared_sum=sq * n) + if last_dim == 4: + n_a = rng.uniform(100, 500) + s_a = rng.uniform(-10, 10) + sq_a = s_a**2 / n_a + rng.uniform(0.01, 1.0) + stats[f"a_{ti}"] = StatItem(number=n_a, sum=s_a, squared_sum=sq_a * n_a) + return stats + + +def _compute_expected_buffers(descriptor, merged_stats, last_dim): + """Compute expected mean/stddev from merged stats using EnvMatStatSe.""" + env_stat = EnvMatStatSe(descriptor) + env_stat.stats = merged_stats + mean, stddev = env_stat() + return mean, stddev + + +def _merge_stats(base_stats, link_stats, model_prob): + """Manually merge stats dicts.""" + merged = {} + for kk in base_stats: + merged[kk] = base_stats[kk] + link_stats[kk] * model_prob + return merged + + +class TestStatMergeSeA: + """Test stat merging for se_e2_a descriptor.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def _make_descriptor(self, seed): + return DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=seed).to( + self.device + ) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0, 0.1]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged davg/dstd match manually computed values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.stats = base_stats + dd_link.stats = link_stats + + # Set initial davg/dstd on base + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + # Compute expected + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base, merged_stats, self.last_dim + ) + + # share_params with stat merging + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + # Verify buffers match expected + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.dstd.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + # Verify stats updated for chaining + for kk in merged_stats: + assert abs(dd_base.stats[kk].number - merged_stats[kk].number) < 1e-10 + assert abs(dd_base.stats[kk].sum - merged_stats[kk].sum) < 1e-10 + + def test_buffers_aliased(self) -> None: + """After share_params, link buffers should be aliased to base.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.5, resume=False) + + for key in dd_base._buffers: + assert dd_link._buffers[key] is dd_base._buffers[key], ( + f"Buffer {key} not aliased" + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging and preserve original buffers.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + dd_link.stats = link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + original_davg = dd_base.davg.clone() + original_dstd = dd_base.dstd.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + # Buffers should be unchanged + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + np.testing.assert_allclose( + dd_base.dstd.detach().cpu().numpy(), + original_dstd.detach().cpu().numpy(), + ) + + def test_none_stats_skips_merge(self) -> None: + """When stats is None, merging should be silently skipped.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + # stats is not set (default None) + assert getattr(dd_base, "stats", None) is None + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + original_davg = dd_base.davg.clone() + + # Should not raise + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False) + + # davg should be unchanged (merge was skipped) + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + + +class TestStatMergeSeR: + """Test stat merging for se_r descriptor.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 100) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 1 + + def _make_descriptor(self, seed): + return DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=seed).to( + self.device + ) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged davg/dstd match manually computed values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.stats = base_stats + dd_link.stats = link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.dstd.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + dd_link.stats = link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + original_davg = dd_base.davg.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + + +class TestStatMergeSeT: + """Test stat merging for se_t descriptor.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 200) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def _make_descriptor(self, seed): + return DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=seed).to( + self.device + ) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged davg/dstd match manually computed values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.stats = base_stats + dd_link.stats = link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.dstd.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + dd_link.stats = link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dstd0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd_base.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + original_davg = dd_base.davg.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + + +class TestStatMergeDPA1: + """Test stat merging for DPA1 descriptor (se_atten block has mean/stddev).""" + + rcut = 2.2 + rcut_smth = 0.4 + sel = 7 + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 300) + self.ntypes = 2 + self.nnei = self.sel + self.last_dim = 4 + + def _make_descriptor(self, seed): + return DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + seed=seed, + ).to(self.device) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged mean/stddev on se_atten block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.se_atten.stats = base_stats + dd_link.se_atten.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_atten.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_atten.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.se_atten, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.se_atten.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.se_atten.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging on se_atten block.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.se_atten.stats = base_stats + dd_link.se_atten.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_atten.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_atten.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.se_atten.mean.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.se_atten.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + def test_level1_no_merge(self) -> None: + """Level 1 shares type_embedding only, no stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.se_atten.stats = base_stats + dd_link.se_atten.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dd_base.se_atten.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.se_atten.mean.clone() + + dd_link.share_params(dd_base, shared_level=1, model_prob=0.6, resume=False) + + # type_embedding shared + assert dd_link._modules["type_embedding"] is dd_base._modules["type_embedding"] + # se_atten NOT shared + assert dd_link._modules["se_atten"] is not dd_base._modules["se_atten"] + # stats unchanged + np.testing.assert_allclose( + dd_base.se_atten.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + +class TestStatMergeDPA2: + """Test stat merging for DPA2 descriptor (repinit and repformers blocks).""" + + rcut = 2.2 + rcut_smth = 0.4 + sel_mix: ClassVar = [7] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 350) + self.ntypes = 2 + self.nnei = sum(self.sel_mix) + self.last_dim = 4 + + def _make_descriptor(self, seed): + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode="strip", + set_davg_zero=False, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=self.nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=True, + update_g1_has_drrd=False, + update_g1_has_grrg=False, + update_g1_has_attn=False, + update_g2_has_g1g1=False, + update_g2_has_attn=True, + update_h2=False, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=True, + update_style="res_residual", + set_davg_zero=False, + ) + dd = DescrptDPA2( + self.ntypes, + repinit=repinit, + repformer=repformer, + smooth=True, + exclude_types=[], + add_tebd_to_repinit_out=False, + seed=seed, + ).to(self.device) + return dd + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge_repinit(self, model_prob) -> None: + """Verify merged mean/stddev on repinit block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.repinit.stats = base_stats + dd_link.repinit.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.repinit.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repinit.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.repinit, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.repinit.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.repinit.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge_repformers(self, model_prob) -> None: + """Verify merged mean/stddev on repformers block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + nnei_repformers = self.nnei // 2 + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.repformers.stats = base_stats + dd_link.repformers.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, nnei_repformers, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, nnei_repformers, self.last_dim)) + ) + dd_base.repformers.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repformers.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.repformers, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.repformers.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.repformers.stddev.detach().cpu().numpy(), + expected_stddev, + rtol=1e-10, + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging on all blocks.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats_ri = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats_ri = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repinit.stats = base_stats_ri + dd_link.repinit.stats = link_stats_ri + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dd_base.repinit.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.repinit.mean.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.repinit.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + def test_level1_no_merge(self) -> None: + """Level 1 shares type_embedding only, no stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repinit.stats = base_stats + dd_link.repinit.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + dd_base.repinit.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.repinit.mean.clone() + + dd_link.share_params(dd_base, shared_level=1, model_prob=0.6, resume=False) + + # type_embedding shared + assert dd_link._modules["type_embedding"] is dd_base._modules["type_embedding"] + # repinit NOT shared + assert dd_link._modules["repinit"] is not dd_base._modules["repinit"] + # stats unchanged + np.testing.assert_allclose( + dd_base.repinit.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + +class TestStatMergeDPA3: + """Test stat merging for DPA3 descriptor (repflows block has mean/stddev).""" + + rcut = 2.2 + rcut_smth = 0.4 + sel = 7 + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 400) + self.ntypes = 2 + self.nnei = self.sel + self.last_dim = 4 + + def _make_descriptor(self, seed, fix_stat_std=0.0): + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=self.sel, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=self.sel - 1, + axis_neuron=4, + update_angle=True, + update_style="res_residual", + smooth_edge_update=True, + fix_stat_std=fix_stat_std, + ) + dd = DescrptDPA3( + self.ntypes, + repflow=repflow, + seed=seed, + ).to(self.device) + # Override set_davg_zero for testing (default True in repflows) + if fix_stat_std == 0.0: + dd.repflows.set_davg_zero = False + return dd + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged mean/stddev on repflows block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.repflows.stats = base_stats + dd_link.repflows.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.repflows.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repflows.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.repflows, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.repflows.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.repflows.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_default_config_skips_merge(self) -> None: + """Default DPA3 has set_davg_zero=True and set_stddev_constant=True, so merge is no-op.""" + dd_base = self._make_descriptor(GLOBAL_SEED, fix_stat_std=0.3) + dd_link = self._make_descriptor(GLOBAL_SEED + 1, fix_stat_std=0.3) + # Restore defaults + dd_base.repflows.set_davg_zero = True + dd_link.repflows.set_davg_zero = True + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repflows.stats = base_stats + dd_link.repflows.stats = link_stats + + original_mean = dd_base.repflows.mean.clone() + original_stddev = dd_base.repflows.stddev.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False) + + # Buffers should be unchanged + np.testing.assert_allclose( + dd_base.repflows.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + np.testing.assert_allclose( + dd_base.repflows.stddev.detach().cpu().numpy(), + original_stddev.detach().cpu().numpy(), + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repflows.stats = base_stats + dd_link.repflows.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.repflows.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repflows.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.repflows.mean.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.repflows.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + +class TestStatMergeSeTTebd: + """Test stat merging for se_t_tebd descriptor (se_ttebd block has mean/stddev).""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 500) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def _make_descriptor(self, seed): + dd = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + set_davg_zero=False, + seed=seed, + ).to(self.device) + return dd + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge(self, model_prob) -> None: + """Verify merged mean/stddev on se_ttebd block match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.se_ttebd.stats = base_stats + dd_link.se_ttebd.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_ttebd.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_ttebd.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.se_ttebd, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.se_ttebd.mean.detach().cpu().numpy(), expected_mean, rtol=1e-10 + ) + np.testing.assert_allclose( + dd_base.se_ttebd.stddev.detach().cpu().numpy(), expected_stddev, rtol=1e-10 + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.se_ttebd.stats = base_stats + dd_link.se_ttebd.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, self.nnei, self.last_dim)) + ) + dd_base.se_ttebd.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.se_ttebd.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.se_ttebd.mean.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.se_ttebd.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + +class TestStatMergeDPA2ThreeBody: + """Test stat merging for DPA2 descriptor with use_three_body=True.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel_mix: ClassVar = [7] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 600) + self.ntypes = 2 + self.nnei = sum(self.sel_mix) + self.last_dim = 4 + self.three_body_sel = 5 + self.three_body_rcut = self.rcut + self.three_body_rcut_smth = self.rcut_smth + + def _make_descriptor(self, seed): + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode="strip", + set_davg_zero=False, + use_three_body=True, + three_body_sel=self.three_body_sel, + three_body_rcut=self.three_body_rcut, + three_body_rcut_smth=self.three_body_rcut_smth, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=self.nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=True, + update_g1_has_drrd=False, + update_g1_has_grrg=False, + update_g1_has_attn=False, + update_g2_has_g1g1=False, + update_g2_has_attn=True, + update_h2=False, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=True, + update_style="res_residual", + set_davg_zero=False, + ) + dd = DescrptDPA2( + self.ntypes, + repinit=repinit, + repformer=repformer, + smooth=True, + exclude_types=[], + add_tebd_to_repinit_out=False, + seed=seed, + ).to(self.device) + return dd + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge_three_body(self, model_prob) -> None: + """Verify merged mean/stddev on repinit_three_body block.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + assert dd_base.use_three_body + assert dd_base.repinit_three_body is not None + + # repinit_three_body is a DescrptBlockSeTTebd with mean/stddev + nnei_3b = self.three_body_sel + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.repinit_three_body.stats = base_stats + dd_link.repinit_three_body.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, nnei_3b, self.last_dim)) + stddev0 = 0.1 + np.abs( + self.rng.normal(size=(self.ntypes, nnei_3b, self.last_dim)) + ) + dd_base.repinit_three_body.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + dd_base.repinit_three_body.stddev = torch.tensor( + stddev0, dtype=torch.float64, device=self.device + ) + + merged_stats = _merge_stats(base_stats, link_stats, model_prob) + expected_mean, expected_stddev = _compute_expected_buffers( + dd_base.repinit_three_body, merged_stats, self.last_dim + ) + + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + np.testing.assert_allclose( + dd_base.repinit_three_body.mean.detach().cpu().numpy(), + expected_mean, + rtol=1e-10, + ) + np.testing.assert_allclose( + dd_base.repinit_three_body.stddev.detach().cpu().numpy(), + expected_stddev, + rtol=1e-10, + ) + + def test_three_body_aliased(self) -> None: + """After share_params, repinit_three_body modules should be aliased.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repinit_three_body.stats = base_stats + dd_link.repinit_three_body.stats = link_stats + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False) + + assert ( + dd_link._modules["repinit_three_body"] + is dd_base._modules["repinit_three_body"] + ) + + def test_resume_skips_three_body_merge(self) -> None: + """resume=True should skip stat merging on three-body block.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + nnei_3b = self.three_body_sel + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.repinit_three_body.stats = base_stats + dd_link.repinit_three_body.stats = link_stats + + mean0 = self.rng.normal(size=(self.ntypes, nnei_3b, self.last_dim)) + dd_base.repinit_three_body.mean = torch.tensor( + mean0, dtype=torch.float64, device=self.device + ) + original_mean = dd_base.repinit_three_body.mean.clone() + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.repinit_three_body.mean.detach().cpu().numpy(), + original_mean.detach().cpu().numpy(), + ) + + +class TestStatMergeHybrid: + """Test stat merging for hybrid descriptor (passes model_prob to sub-descriptors).""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 700) + self.ntypes = 2 + self.nnei = sum(self.sel) + + def _make_descriptor(self, seed): + """Create a hybrid descriptor with se_e2_a (last_dim=4) + se_r (last_dim=1).""" + dd = DescrptHybrid( + list=[ + DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=seed), + DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=seed + 10), + ], + ).to(self.device) + return dd + + @pytest.mark.parametrize("model_prob", [0.6, 1.0]) # probability weight + def test_stat_merge_sub_descriptors(self, model_prob) -> None: + """Verify merged davg/dstd on each sub-descriptor match expected values.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + # SeA sub-descriptor (last_dim=4) + sea_base_stats = _make_stats(self.ntypes, 4, self.rng) + sea_link_stats = _make_stats(self.ntypes, 4, self.rng) + dd_base.descrpt_list[0].stats = sea_base_stats + dd_link.descrpt_list[0].stats = sea_link_stats + + davg0_sea = self.rng.normal(size=(self.ntypes, self.nnei, 4)) + dstd0_sea = 0.1 + np.abs(self.rng.normal(size=(self.ntypes, self.nnei, 4))) + dd_base.descrpt_list[0].davg = torch.tensor( + davg0_sea, dtype=torch.float64, device=self.device + ) + dd_base.descrpt_list[0].dstd = torch.tensor( + dstd0_sea, dtype=torch.float64, device=self.device + ) + + # SeR sub-descriptor (last_dim=1) + ser_base_stats = _make_stats(self.ntypes, 1, self.rng) + ser_link_stats = _make_stats(self.ntypes, 1, self.rng) + dd_base.descrpt_list[1].stats = ser_base_stats + dd_link.descrpt_list[1].stats = ser_link_stats + + davg0_ser = self.rng.normal(size=(self.ntypes, self.nnei, 1)) + dstd0_ser = 0.1 + np.abs(self.rng.normal(size=(self.ntypes, self.nnei, 1))) + dd_base.descrpt_list[1].davg = torch.tensor( + davg0_ser, dtype=torch.float64, device=self.device + ) + dd_base.descrpt_list[1].dstd = torch.tensor( + dstd0_ser, dtype=torch.float64, device=self.device + ) + + # Compute expected for SeA + merged_sea = _merge_stats(sea_base_stats, sea_link_stats, model_prob) + exp_mean_sea, exp_std_sea = _compute_expected_buffers( + dd_base.descrpt_list[0], merged_sea, 4 + ) + + # Compute expected for SeR + merged_ser = _merge_stats(ser_base_stats, ser_link_stats, model_prob) + exp_mean_ser, exp_std_ser = _compute_expected_buffers( + dd_base.descrpt_list[1], merged_ser, 1 + ) + + # share_params on hybrid passes model_prob to each sub-descriptor + dd_link.share_params( + dd_base, shared_level=0, model_prob=model_prob, resume=False + ) + + # Verify SeA sub-descriptor buffers + np.testing.assert_allclose( + dd_base.descrpt_list[0].davg.detach().cpu().numpy(), + exp_mean_sea, + rtol=1e-10, + ) + np.testing.assert_allclose( + dd_base.descrpt_list[0].dstd.detach().cpu().numpy(), + exp_std_sea, + rtol=1e-10, + ) + + # Verify SeR sub-descriptor buffers + np.testing.assert_allclose( + dd_base.descrpt_list[1].davg.detach().cpu().numpy(), + exp_mean_ser, + rtol=1e-10, + ) + np.testing.assert_allclose( + dd_base.descrpt_list[1].dstd.detach().cpu().numpy(), + exp_std_ser, + rtol=1e-10, + ) + + def test_sub_descriptors_aliased(self) -> None: + """After share_params, sub-descriptor modules should be aliased.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + # Populate stats to avoid None-guard early return + for i in range(2): + last_dim = 4 if i == 0 else 1 + dd_base.descrpt_list[i].stats = _make_stats(self.ntypes, last_dim, self.rng) + dd_link.descrpt_list[i].stats = _make_stats(self.ntypes, last_dim, self.rng) + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=False) + + for i in range(2): + assert ( + dd_link.descrpt_list[i].davg.data_ptr() + == dd_base.descrpt_list[i].davg.data_ptr() + ) + + def test_resume_skips_merge(self) -> None: + """resume=True should skip stat merging on all sub-descriptors.""" + dd_base = self._make_descriptor(GLOBAL_SEED) + dd_link = self._make_descriptor(GLOBAL_SEED + 1) + + sea_base_stats = _make_stats(self.ntypes, 4, self.rng) + sea_link_stats = _make_stats(self.ntypes, 4, self.rng) + dd_base.descrpt_list[0].stats = sea_base_stats + dd_link.descrpt_list[0].stats = sea_link_stats + + davg0 = self.rng.normal(size=(self.ntypes, self.nnei, 4)) + dd_base.descrpt_list[0].davg = torch.tensor( + davg0, dtype=torch.float64, device=self.device + ) + original_davg = dd_base.descrpt_list[0].davg.clone() + + # Need stats on all sub-descriptors to avoid None guard + dd_base.descrpt_list[1].stats = _make_stats(self.ntypes, 1, self.rng) + dd_link.descrpt_list[1].stats = _make_stats(self.ntypes, 1, self.rng) + + dd_link.share_params(dd_base, shared_level=0, model_prob=0.6, resume=True) + + np.testing.assert_allclose( + dd_base.descrpt_list[0].davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + + +class TestMergeEnvStatUnit: + """Unit tests for the merge_env_stat function directly.""" + + rcut = 2.2 + rcut_smth = 0.4 + sel: ClassVar = [5, 2] + + def setup_method(self) -> None: + self.device = env.DEVICE + self.rng = np.random.default_rng(GLOBAL_SEED + 600) + self.ntypes = 2 + self.nnei = sum(self.sel) + self.last_dim = 4 + + def test_merge_produces_correct_stats(self) -> None: + """merge_env_stat should compute merged = base + link * model_prob.""" + dd_base = DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + + dd_link = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd_link.stats = link_stats + + model_prob = 0.3 + merge_env_stat(dd_base, dd_link, model_prob) + + for kk in base_stats: + expected = base_stats[kk] + link_stats[kk] * model_prob + assert abs(dd_base.stats[kk].number - expected.number) < 1e-10 + assert abs(dd_base.stats[kk].sum - expected.sum) < 1e-10 + assert abs(dd_base.stats[kk].squared_sum - expected.squared_sum) < 1e-10 + + def test_chaining_three_models(self) -> None: + """Merging stats from 3 models should accumulate correctly.""" + dd_base = DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd_link1 = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd_link2 = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 2 + ).to(self.device) + + stats_base = _make_stats(self.ntypes, self.last_dim, self.rng) + stats_link1 = _make_stats(self.ntypes, self.last_dim, self.rng) + stats_link2 = _make_stats(self.ntypes, self.last_dim, self.rng) + + dd_base.stats = stats_base + dd_link1.stats = stats_link1 + dd_link2.stats = stats_link2 + + prob1, prob2 = 0.5, 0.3 + + merge_env_stat(dd_base, dd_link1, prob1) + merge_env_stat(dd_base, dd_link2, prob2) + + for kk in stats_base: + expected = ( + stats_base[kk] + stats_link1[kk] * prob1 + stats_link2[kk] * prob2 + ) + assert abs(dd_base.stats[kk].number - expected.number) < 1e-10 + assert abs(dd_base.stats[kk].sum - expected.sum) < 1e-10 + + def test_set_davg_zero_respected(self) -> None: + """When set_davg_zero=True, davg should remain zero after merging.""" + dd_base = DescrptSeA( + self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED, set_davg_zero=True + ).to(self.device) + dd_link = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + seed=GLOBAL_SEED + 1, + set_davg_zero=True, + ).to(self.device) + + base_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + link_stats = _make_stats(self.ntypes, self.last_dim, self.rng) + dd_base.stats = base_stats + dd_link.stats = link_stats + + original_davg = dd_base.davg.clone() + merge_env_stat(dd_base, dd_link, 0.6) + + # davg should stay zero + np.testing.assert_allclose( + dd_base.davg.detach().cpu().numpy(), + original_davg.detach().cpu().numpy(), + ) + # but dstd should be updated + assert dd_base.stats is not base_stats # stats dict replaced diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py index e90c67bc82..2662997a87 100644 --- a/source/tests/pt_expt/descriptor/test_dpa1.py +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -239,3 +239,46 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED, + ).to(self.device) + dd1 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED + 1, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.se_atten.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_atten"] is dd0._modules["se_atten"] + elif shared_level == 1: + assert dd1._modules["se_atten"] is not dd0._modules["se_atten"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py index ecc94d24f5..ef4b479724 100644 --- a/source/tests/pt_expt/descriptor/test_dpa3.py +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -260,3 +260,54 @@ def fn(coord_ext, atype_ext, nlist, mapping): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=True, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + + dd0 = DescrptDPA3( + self.nt, repflow=repflow, exclude_types=[], seed=GLOBAL_SEED + ).to(self.device) + dd1 = DescrptDPA3( + self.nt, repflow=repflow, exclude_types=[], seed=GLOBAL_SEED + 1 + ).to(self.device) + dd0.repflows.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.repflows.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["repflows"] is dd0._modules["repflows"] + elif shared_level == 1: + assert dd1._modules["repflows"] is not dd0._modules["repflows"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py index a3c673d774..5fa8970bf1 100644 --- a/source/tests/pt_expt/descriptor/test_hybrid.py +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -231,3 +231,56 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: recursively shares all sub-descriptors.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg4 = rng.normal(size=(self.nt, nnei, 4)) + dstd4 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptHybrid( + list=[ + DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED), + DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED), + ] + ).to(self.device) + dd1 = DescrptHybrid( + list=[ + DescrptSeA(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1), + DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1), + ] + ).to(self.device) + + # set stats on dd0's sub-descriptors + dd0.descrpt_list[0].davg = torch.tensor( + davg4, dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[0].dstd = torch.tensor( + dstd4, dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[1].davg = torch.tensor( + davg4[..., :1], dtype=torch.float64, device=self.device + ) + dd0.descrpt_list[1].dstd = torch.tensor( + dstd4[..., :1], dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=0) + + # each sub-descriptor's modules/buffers are shared + for ii in range(len(dd0.descrpt_list)): + for key in dd0.descrpt_list[ii]._modules: + assert ( + dd1.descrpt_list[ii]._modules[key] + is dd0.descrpt_list[ii]._modules[key] + ) + for key in dd0.descrpt_list[ii]._buffers: + assert ( + dd1.descrpt_list[ii]._buffers[key] + is dd0.descrpt_list[ii]._buffers[key] + ) + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_atten_v2.py b/source/tests/pt_expt/descriptor/test_se_atten_v2.py index 326a78acad..cc86c1600b 100644 --- a/source/tests/pt_expt/descriptor/test_se_atten_v2.py +++ b/source/tests/pt_expt/descriptor/test_se_atten_v2.py @@ -234,3 +234,46 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED, + ).to(self.device) + dd1 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + seed=GLOBAL_SEED + 1, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.se_atten.stddev = torch.tensor( + dstd0, dtype=torch.float64, device=self.device + ) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_atten"] is dd0._modules["se_atten"] + elif shared_level == 1: + assert dd1._modules["se_atten"] is not dd0._modules["se_atten"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index cde3295e7a..9056c9f308 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -216,3 +216,43 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: all modules and buffers are shared.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 1)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 1))) + + dd0 = DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd1 = DescrptSeR(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1).to( + self.device + ) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=0) + + # all modules and buffers are shared (same object) + for key in dd0._modules: + assert dd1._modules[key] is dd0._modules[key] + for key in dd0._buffers: + assert dd1._buffers[key] is dd0._buffers[key] + + # forward pass produces identical output + inputs = ( + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + rd0 = dd0(*inputs)[0] + rd1 = dd1(*inputs)[0] + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy() + ) + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py index bb1f9a4b3f..ed71f1e0ed 100644 --- a/source/tests/pt_expt/descriptor/test_se_t.py +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -220,3 +220,43 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + def test_share_params(self) -> None: + """share_params level 0: all modules and buffers are shared.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED).to( + self.device + ) + dd1 = DescrptSeT(self.rcut, self.rcut_smth, self.sel, seed=GLOBAL_SEED + 1).to( + self.device + ) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=0) + + # all modules and buffers are shared (same object) + for key in dd0._modules: + assert dd1._modules[key] is dd0._modules[key] + for key in dd0._buffers: + assert dd1._buffers[key] is dd0._buffers[key] + + # forward pass produces identical output + inputs = ( + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + rd0 = dd0(*inputs)[0] + rd1 = dd1(*inputs)[0] + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), rd1.detach().cpu().numpy() + ) + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=1) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index 30808f5070..41643b41ef 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -251,3 +251,34 @@ def fn(coord_ext, atype_ext, nlist): rtol=rtol, atol=atol, ) + + @pytest.mark.parametrize("shared_level", [0, 1]) # sharing level + def test_share_params(self, shared_level) -> None: + """share_params level 0: share all; level 1: share type_embedding only.""" + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg0 = rng.normal(size=(self.nt, nnei, 4)) + dstd0 = 0.1 + np.abs(rng.normal(size=(self.nt, nnei, 4))) + + dd0 = DescrptSeTTebd( + self.rcut, self.rcut_smth, self.sel, self.nt, seed=GLOBAL_SEED + ).to(self.device) + dd1 = DescrptSeTTebd( + self.rcut, self.rcut_smth, self.sel, self.nt, seed=GLOBAL_SEED + 1 + ).to(self.device) + dd0.davg = torch.tensor(davg0, dtype=torch.float64, device=self.device) + dd0.dstd = torch.tensor(dstd0, dtype=torch.float64, device=self.device) + + dd1.share_params(dd0, shared_level=shared_level) + + # type_embedding is always shared + assert dd1._modules["type_embedding"] is dd0._modules["type_embedding"] + + if shared_level == 0: + assert dd1._modules["se_ttebd"] is dd0._modules["se_ttebd"] + elif shared_level == 1: + assert dd1._modules["se_ttebd"] is not dd0._modules["se_ttebd"] + + # invalid level raises + with pytest.raises(NotImplementedError): + dd1.share_params(dd0, shared_level=2) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index dcb99dd324..038e0dcf27 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -1,18 +1,42 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os +import shutil +import tempfile import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) import numpy as np import torch +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.descriptor import ( DescrptSeA, ) +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) from deepmd.pt_expt.fitting import ( EnergyFittingNet, ) from deepmd.pt_expt.utils import ( env, ) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): @@ -74,6 +98,58 @@ def _brute_aparam_pt(data, ndim): return avg, std +def _get_weighted_fitting_stat( + model_prob: list, *stat_arrays, protection: float +) -> tuple[np.ndarray, np.ndarray]: + """Compute probability-weighted fparam avg and std (matching PT).""" + n_arrays = len(stat_arrays) + assert len(model_prob) == n_arrays + nframes = [stat.shape[0] for stat in stat_arrays] + sums = [stat.sum(axis=0) for stat in stat_arrays] + squared_sums = [(stat**2).sum(axis=0) for stat in stat_arrays] + weighted_sum = sum(model_prob[i] * sums[i] for i in range(n_arrays)) + total_weighted_frames = sum(model_prob[i] * nframes[i] for i in range(n_arrays)) + weighted_avg = weighted_sum / total_weighted_frames + weighted_square_sum = sum(model_prob[i] * squared_sums[i] for i in range(n_arrays)) + weighted_square_avg = weighted_square_sum / total_weighted_frames + weighted_std = np.sqrt(weighted_square_avg - weighted_avg**2) + weighted_std = np.where(weighted_std < protection, protection, weighted_std) + return weighted_avg, weighted_std + + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") +_PT_DATA_NO_FPARAM = str( + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" +) +_PT_DATA_SINGLE = str( + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" +) + +_descriptor_se_e2_a = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, +} + +_fitting_net = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, +} + + +def _skip_if_no_data() -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + + class TestEnerFittingStat(unittest.TestCase): def setUp(self) -> None: self.device = env.DEVICE @@ -125,3 +201,468 @@ def test(self) -> None: np.testing.assert_almost_equal(frefs_inv, fparam_inv_std_np) np.testing.assert_almost_equal(arefa, aparam_avg_np) np.testing.assert_almost_equal(arefs_inv, aparam_inv_std_np) + + +class TestMultiTaskFittingStat(unittest.TestCase): + """Test shared fitting stat (fparam_avg/fparam_inv_std) in multi-task. + + Corresponds to PT's TestMultiTaskFittingStat in test_fitting_stat.py. + Verifies: + 1. fparam stats are shared between models (same tensor values) + 2. stat file contents match raw data (number, sum, squared_sum) + 3. weighted stat computation matches model values + 4. case_embd with default_fparam works correctly + """ + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + if not os.path.isdir(_PT_DATA_SINGLE): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA_SINGLE}") + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_fitstat_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + self.stat_files = "se_e2_a_share_fit" + os.makedirs(self.stat_files, exist_ok=True) + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def _make_sharefit_config( + self, + *, + numb_fparam: int = 2, + numb_aparam: int = 0, + default_fparam: list | None = None, + dim_case_embd: int = 2, + model_keys: list[str] | None = None, + data_dirs: dict[str, str] | None = None, + model_probs: dict[str, float] | None = None, + ) -> dict: + """Build a multi-task config with shared fitting + fparam.""" + if model_keys is None: + model_keys = ["model_1", "model_2"] + if data_dirs is None: + data_dirs = dict.fromkeys(model_keys, _PT_DATA) + if model_probs is None: + model_probs = {mk: 1.0 / len(model_keys) for mk in model_keys} + + shared_fitting: dict = deepcopy(_fitting_net) + shared_fitting["numb_fparam"] = numb_fparam + if numb_aparam > 0: + shared_fitting["numb_aparam"] = numb_aparam + shared_fitting["dim_case_embd"] = dim_case_embd + if default_fparam is not None: + shared_fitting["default_fparam"] = default_fparam + + shared_dict: dict = { + "my_type_map": ["O", "H"], + "my_descriptor": deepcopy(_descriptor_se_e2_a), + "my_fitting": shared_fitting, + } + + model_dict = {} + loss_dict = {} + data_dict = {} + for mk in model_keys: + model_dict[mk] = { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + } + loss_dict[mk] = { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + } + data_dict[mk] = { + "stat_file": f"{self.stat_files}/{mk}", + "training_data": { + "systems": [data_dirs[mk]], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dirs[mk]], + "batch_size": 1, + "numb_btch": 1, + }, + } + + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": model_dict, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": loss_dict, + "training": { + "model_prob": model_probs, + "data_dict": data_dict, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 1, + }, + } + return config + + def test_sharefitting_with_fparam(self) -> None: + """Shared fitting with fparam data: weighted fparam stat merging.""" + model_prob = [0.3, 0.7] + config = self._make_sharefit_config( + numb_fparam=2, + default_fparam=[1.0, 0.0], + data_dirs={"model_1": _PT_DATA, "model_2": _PT_DATA_SINGLE}, + model_probs={"model_1": model_prob[0], "model_2": model_prob[1]}, + ) + # data_0 has 80 frames; use data_stat_nbatch=100 to cover all frames + config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 80 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + trainer.run() + + # fparam_avg and fparam_inv_std should be shared between models + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H/fparam") + fparam_data1 = np.load(os.path.join(_PT_DATA, "set.000", "fparam.npy")) + fparam_data2 = np.load(os.path.join(_PT_DATA_SINGLE, "set.000", "fparam.npy")) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = _get_weighted_fitting_stat( + model_prob, fparam_data1, fparam_data2, protection=1e-2 + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1.0 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + def test_sharefitting_using_default_fparam(self) -> None: + """3 models with dim_case_embd=3, default fparam, no fparam in data.""" + default_fparam = [1.0, 0.0] + model_prob = [0.1, 0.3, 0.6] + data_stat_protect = 5e-3 + config = self._make_sharefit_config( + numb_fparam=2, + default_fparam=default_fparam, + dim_case_embd=3, + model_keys=["model_1", "model_2", "model_3"], + data_dirs={ + "model_1": _PT_DATA_NO_FPARAM, + "model_2": _PT_DATA_SINGLE, + "model_3": _PT_DATA, + }, + model_probs={ + "model_1": model_prob[0], + "model_2": model_prob[1], + "model_3": model_prob[2], + }, + ) + # model_1 uses data without fparam → default_fparam is used + config["model"]["model_dict"]["model_1"]["data_stat_nbatch"] = 3 + config["model"]["model_dict"]["model_3"]["data_stat_nbatch"] = 80 + config["model"]["model_dict"]["model_1"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"]["model_dict"]["model_2"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"]["model_dict"]["model_3"]["data_stat_protect"] = ( + data_stat_protect + ) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + trainer.run() + + # fparam_avg shared across all 3 models + multi_state_dict = trainer.wrapper.model.state_dict() + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_avg"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_2.atomic_model.fitting_net.fparam_inv_std"], + ) + torch.testing.assert_close( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"], + multi_state_dict["model_3.atomic_model.fitting_net.fparam_inv_std"], + ) + + # check fitting stat in stat_file is correct + fparam_stat_model1 = np.load(f"{self.stat_files}/model_1/O H/fparam") + fparam_stat_model2 = np.load(f"{self.stat_files}/model_2/O H/fparam") + fparam_stat_model3 = np.load(f"{self.stat_files}/model_3/O H/fparam") + fparam_data1 = np.array([default_fparam]).repeat(3, axis=0) + fparam_data2 = np.load(os.path.join(_PT_DATA_SINGLE, "set.000", "fparam.npy")) + fparam_data3 = np.load(os.path.join(_PT_DATA, "set.000", "fparam.npy")) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 0], [fparam_data1.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 1], fparam_data1.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model1[:, 2], (fparam_data1**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 0], [fparam_data2.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 1], fparam_data2.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model2[:, 2], (fparam_data2**2).sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 0], [fparam_data3.shape[0]] * 2 + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 1], fparam_data3.sum(axis=0) + ) + np.testing.assert_almost_equal( + fparam_stat_model3[:, 2], (fparam_data3**2).sum(axis=0) + ) + + # check shared fitting stat is computed correctly + weighted_avg, weighted_std = _get_weighted_fitting_stat( + model_prob, + fparam_data1, + fparam_data2, + fparam_data3, + protection=data_stat_protect, + ) + np.testing.assert_almost_equal( + weighted_avg, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_avg"] + ), + ) + np.testing.assert_almost_equal( + 1.0 / weighted_std, + to_numpy_array( + multi_state_dict["model_1.atomic_model.fitting_net.fparam_inv_std"] + ), + ) + + # case_embd should be set on all 3 models + ce1 = trainer.wrapper.model["model_1"].atomic_model.fitting_net.case_embd + ce2 = trainer.wrapper.model["model_2"].atomic_model.fitting_net.case_embd + ce3 = trainer.wrapper.model["model_3"].atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce1) + self.assertIsNotNone(ce2) + self.assertIsNotNone(ce3) + + # dim_case_embd=3 → each is a 3-element one-hot vector + self.assertEqual(ce1.shape[-1], 3) + self.assertEqual(ce2.shape[-1], 3) + self.assertEqual(ce3.shape[-1], 3) + + # Each should be one-hot + self.assertEqual(ce1.sum().item(), 1.0) + self.assertEqual(ce2.sum().item(), 1.0) + self.assertEqual(ce3.sum().item(), 1.0) + + # All three should be different + self.assertFalse(torch.equal(ce1, ce2)) + self.assertFalse(torch.equal(ce1, ce3)) + self.assertFalse(torch.equal(ce2, ce3)) + + # case_embd should NOT be shared in state_dict + for state_key in multi_state_dict: + if "case_embd" in state_key and "model_1" in state_key: + k2 = state_key.replace("model_1", "model_2") + k3 = state_key.replace("model_1", "model_3") + self.assertFalse( + torch.equal(multi_state_dict[state_key], multi_state_dict[k2]), + ) + self.assertFalse( + torch.equal(multi_state_dict[state_key], multi_state_dict[k3]), + ) + + def test_sharefitting_with_aparam(self) -> None: + """Weighted aparam stat merging in share_params (unit test). + + Directly tests the aparam branch in InvarFitting.share_params by + creating two fittings with different aparam stats and verifying that + share_params produces the correct probability-weighted merged result. + """ + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + ntypes = descrpt.get_ntypes() + dim_out = descrpt.get_dim_out() + + fit_base = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_aparam=3, seed=1 + ).to(env.DEVICE) + fit_link = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_aparam=3, seed=2 + ).to(env.DEVICE) + + # give both fittings different aparam stats + data_base = _make_fake_data_pt( + [10, 100], [5, 2], [0, 10, 100], [2, 0.4, 0.00001] + ) + data_link = _make_fake_data_pt([50], [8], [5, 20, 50], [1, 0.5, 0.01]) + fit_base.compute_input_stats(data_base, protection=1e-2) + fit_link.compute_input_stats(data_link, protection=1e-2) + + # record base's aparam_avg before share_params + orig_base_avg = fit_base.aparam_avg.clone() + + # share_params with model_prob=0.6 — should do weighted merging + model_prob = 0.6 + fit_link.share_params( + fit_base, shared_level=0, model_prob=model_prob, protection=1e-2 + ) + + # base's aparam_avg was UPDATED (weighted merging happened) + self.assertFalse( + torch.equal(fit_base.aparam_avg, orig_base_avg), + "aparam_avg should have changed after weighted merging", + ) + + # buffers are shared (same data_ptr) + self.assertEqual(fit_link.aparam_avg.data_ptr(), fit_base.aparam_avg.data_ptr()) + self.assertEqual( + fit_link.aparam_inv_std.data_ptr(), fit_base.aparam_inv_std.data_ptr() + ) + + # verify the merged stats match manual computation + # reconstruct raw aparam data from each fitting's stats + base_aparam_stats = fit_base.get_param_stats().get("aparam", []) + # the merged stats should have 3 StatItem objects + self.assertEqual(len(base_aparam_stats), 3) + + # manually compute the weighted average from raw data + # data_base has two systems: [10 natoms, 5 frames] + [100 natoms, 2 frames] + # data_link has one system: [50 natoms, 8 frames] + # aparam per system: reshape to (nframes * natoms, numb_aparam) + all_base = np.concatenate( + [d["aparam"].reshape(-1, 3) for d in data_base], axis=0 + ) + all_link = np.concatenate( + [d["aparam"].reshape(-1, 3) for d in data_link], axis=0 + ) + # weighted stat: base contributes with weight 1.0, link with model_prob + total_n = all_base.shape[0] + model_prob * all_link.shape[0] + weighted_sum = all_base.sum(axis=0) + model_prob * all_link.sum(axis=0) + weighted_avg = weighted_sum / total_n + weighted_sq_sum = (all_base**2).sum(axis=0) + model_prob * (all_link**2).sum( + axis=0 + ) + weighted_sq_avg = weighted_sq_sum / total_n + weighted_std = np.sqrt(weighted_sq_avg - weighted_avg**2) + weighted_std = np.where(weighted_std < 1e-2, 1e-2, weighted_std) + + aparam_avg_np = to_numpy_array(fit_base.aparam_avg) + aparam_inv_std_np = to_numpy_array(fit_base.aparam_inv_std) + np.testing.assert_almost_equal(aparam_avg_np, weighted_avg) + np.testing.assert_almost_equal(aparam_inv_std_np, 1.0 / weighted_std) + + def test_sharefitting_resume_preserves_stats(self) -> None: + """resume=True in share_params skips stat merging, preserves buffers.""" + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + ntypes = descrpt.get_ntypes() + dim_out = descrpt.get_dim_out() + + fit_base = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_fparam=2, seed=1 + ).to(env.DEVICE) + fit_link = EnergyFittingNet( + ntypes, dim_out, neuron=[16, 16], numb_fparam=2, seed=2 + ).to(env.DEVICE) + + # give both fittings different stats + data_base = _make_fake_data_pt([10], [5], [0, 10], [2, 0.4]) + data_link = _make_fake_data_pt([100], [2], [100, 0], [0.001, 3]) + fit_base.compute_input_stats(data_base, protection=1e-2) + fit_link.compute_input_stats(data_link, protection=1e-2) + + # record base's fparam_avg BEFORE sharing + orig_avg = fit_base.fparam_avg.clone() + orig_inv_std = fit_base.fparam_inv_std.clone() + + # share_params with resume=True: should NOT re-merge stats + fit_link.share_params(fit_base, shared_level=0, resume=True) + + # base's fparam_avg unchanged (no weighted merging happened) + torch.testing.assert_close(fit_base.fparam_avg, orig_avg) + torch.testing.assert_close(fit_base.fparam_inv_std, orig_inv_std) + + # buffers are shared (same data_ptr) + self.assertEqual(fit_link.fparam_avg.data_ptr(), fit_base.fparam_avg.data_ptr()) + self.assertEqual( + fit_link.fparam_inv_std.data_ptr(), fit_base.fparam_inv_std.data_ptr() + ) + + def test_case_embd_mismatched_dim_raises(self) -> None: + """dim_case_embd must be the same across all models.""" + config = self._make_sharefit_config(dim_case_embd=2) + # Override model_2 to have a different dim_case_embd + config["model"]["model_dict"]["model_2"]["fitting_net"] = deepcopy(_fitting_net) + config["model"]["model_dict"]["model_2"]["fitting_net"]["dim_case_embd"] = 3 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + with self.assertRaises( + ValueError, msg="Should reject mismatched dim_case_embd" + ): + get_trainer(config, shared_links=shared_links) diff --git a/source/tests/pt_expt/test_change_bias.py b/source/tests/pt_expt/test_change_bias.py index 03329642e9..e3749671aa 100644 --- a/source/tests/pt_expt/test_change_bias.py +++ b/source/tests/pt_expt/test_change_bias.py @@ -145,7 +145,7 @@ def setUpClass(cls) -> None: cls.model_path = os.path.join(cls.tmpdir, "model.ckpt.pt") # Record original bias - cls.original_bias = to_numpy(trainer.wrapper.model.get_out_bias()) + cls.original_bias = to_numpy(trainer.wrapper.model["Default"].get_out_bias()) # Pre-freeze shared .pte and .pt2 files so individual tests don't # each pay the AOTInductor compilation cost (~82s per .pt2). diff --git a/source/tests/pt_expt/test_finetune.py b/source/tests/pt_expt/test_finetune.py index 063bb85f71..b000c313a3 100644 --- a/source/tests/pt_expt/test_finetune.py +++ b/source/tests/pt_expt/test_finetune.py @@ -371,8 +371,10 @@ def test_finetune_change_type(self) -> None: wrapper_new = ModelWrapper(model_new) _, has_new_type = get_index_between_two_maps(old_type_map, new_type_map) - model_with_new_type_stat = wrapper_new.model if has_new_type else None - pretrained_wrapper.model.change_type_map( + model_with_new_type_stat = ( + wrapper_new.model["Default"] if has_new_type else None + ) + pretrained_wrapper.model["Default"].change_type_map( new_type_map, model_with_new_type_stat=model_with_new_type_stat, ) diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py new file mode 100644 index 0000000000..c21e1070ae --- /dev/null +++ b/source/tests/pt_expt/test_multitask.py @@ -0,0 +1,2099 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for multi-task training in the pt_expt backend. + +Verifies that: +1. Multi-task training completes without error for various descriptors +2. Shared descriptor parameters are identical between tasks +3. lcurve.out has per-model columns +4. Checkpoint save/load roundtrip works +5. Multi-task freeze extracts single head correctly +6. Shared fitting_net with case_embd works (share_fitting) +7. Shared fitting stat (fparam_avg/fparam_inv_std) are shared between models +8. Case embedding with 3 models and dim_case_embd=3 works correctly +9. Multi-task descriptor gradients match sum of single-task gradients +""" + +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from unittest import ( + mock, +) + +import numpy as np +import torch + +import deepmd.utils.random as dp_random +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.pt_expt.utils.stat import ( + make_stat_input, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, + process_systems, +) + +_energy_data_requirement = [ + DataRequirementItem("energy", ndof=1, atomic=False, must=False, high_prec=True), + DataRequirementItem("force", ndof=3, atomic=True, must=False, high_prec=False), + DataRequirementItem("virial", ndof=9, atomic=False, must=False, high_prec=False), +] + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent / "pt" / "water" / "data" / "data_0") + +# Path to the multitask template JSON +_MULTITASK_JSON = str(Path(__file__).parent.parent / "pt" / "water" / "multitask.json") +_MULTITASK_SHAREFIT_JSON = str( + Path(__file__).parent.parent / "pt" / "model" / "water" / "multitask_sharefit.json" +) + + +def _skip_if_no_data() -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + + +# --------------------------------------------------------------------------- +# Descriptor configs (small models for fast testing) +# --------------------------------------------------------------------------- +_descriptor_se_e2_a = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, +} + +_descriptor_dpa1 = { + "type": "se_atten", + "sel": 18, + "rcut_smth": 0.5, + "rcut": 3.0, + "neuron": [8, 16], + "axis_neuron": 4, + "attn": 16, + "attn_layer": 2, + "attn_dotr": True, + "attn_mask": False, + "set_davg_zero": True, + "type_one_side": True, + "seed": 1, +} + +_descriptor_dpa2 = { + "type": "dpa2", + "repinit": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 18, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "activation_function": "tanh", + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 0.5, + "nsel": 12, + "nlayers": 2, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": False, + "update_g1_has_conv": True, + "update_g1_has_grrg": True, + "update_g1_has_drrd": True, + "update_g1_has_attn": True, + "update_g2_has_g1g1": True, + "update_g2_has_attn": True, + "attn2_has_gate": True, + }, + "seed": 1, + "add_tebd_to_repinit_out": False, +} + +_descriptor_dpa3 = { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 5, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 3.0, + "e_rcut_smth": 0.5, + "e_sel": 12, + "a_rcut": 3.0, + "a_rcut_smth": 0.5, + "a_sel": 8, + "axis_neuron": 4, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": True, + "update_angle": True, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const", + "smooth_edge_update": True, + }, + "activation_function": "silut:10.0", + "use_tebd_bias": False, + "precision": "float32", + "concat_output_tebd": False, +} + +_fitting_net = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, +} + + +def _make_multitask_config( + descriptor: dict, + data_dir: str = _PT_DATA, + numb_steps: int = 1, + share_fitting: bool = False, +) -> dict: + """Build a multi-task config with the given descriptor.""" + type_map = ["O", "H"] + fitting = deepcopy(_fitting_net) + + shared_dict: dict = { + "my_type_map": type_map, + "my_descriptor": deepcopy(descriptor), + } + + if share_fitting: + shared_fitting = deepcopy(fitting) + shared_fitting["dim_case_embd"] = 2 + shared_dict["my_fitting"] = shared_fitting + fitting_ref_1: dict | str = "my_fitting" + fitting_ref_2: dict | str = "my_fitting" + else: + fitting_ref_1 = deepcopy(fitting) + fitting_ref_2 = deepcopy(fitting) + + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_ref_1, + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_ref_2, + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + return config + + +class MultiTaskTrainTest: + """Mixin that tests multi-task training for a particular descriptor type. + + Subclasses must set ``self.config``, ``self.shared_links``, + and ``self.share_fitting`` before calling these test methods. + """ + + def test_multitask_train(self) -> None: + """Train, verify lcurve format and shared params.""" + trainer = get_trainer(deepcopy(self.config), shared_links=self.shared_links) + trainer.run() + + # --- lcurve.out format --- + lcurve_path = "lcurve.out" + self.assertTrue(os.path.exists(lcurve_path), "lcurve.out not created") + with open(lcurve_path) as f: + lines = f.readlines() + header_line = lines[0] + header_cols = header_line.strip().lstrip("#").split() + model_keys = list(self.config["training"]["model_prob"].keys()) + for mk in model_keys: + cols_for_model = [c for c in header_cols if mk in c] + self.assertGreater( + len(cols_for_model), 0, f"No lcurve columns found for {mk}" + ) + data_lines = [line for line in lines if not line.startswith("#")] + self.assertGreater(len(data_lines), 0, "No data lines in lcurve.out") + data_cols = data_lines[0].split() + self.assertEqual(len(data_cols), len(header_cols)) + + # --- model keys --- + self.assertEqual(len(trainer.wrapper.model), 2) + self.assertIn("model_1", trainer.wrapper.model) + self.assertIn("model_2", trainer.wrapper.model) + + # --- shared descriptor params are identical --- + multi_state_dict = trainer.wrapper.model.state_dict() + for state_key in multi_state_dict: + if "model_1" in state_key: + partner_key = state_key.replace("model_1", "model_2") + self.assertIn(partner_key, multi_state_dict) + if "model_2" in state_key: + partner_key = state_key.replace("model_2", "model_1") + self.assertIn(partner_key, multi_state_dict) + + is_descriptor = "model_1.atomic_model.descriptor" in state_key + is_shared_fitting = ( + self.share_fitting + and "model_1.atomic_model.fitting_net" in state_key + and "fitting_net.bias_atom_e" not in state_key + and "fitting_net.case_embd" not in state_key + ) + if is_descriptor or is_shared_fitting: + partner_key = state_key.replace("model_1", "model_2") + torch.testing.assert_close( + multi_state_dict[state_key], + multi_state_dict[partner_key], + msg=f"Shared param mismatch: {state_key}", + ) + + # --- checkpoint exists --- + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + self.assertGreater(len(ckpt_files), 0, "No checkpoint files saved") + + # --- case_embd verification (share_fitting only) --- + # Verify that each branch's case_embd is a distinct one-hot vector + # matching the alphabetical sort order, so the shared fitting net + # can distinguish which training dataset is being used. + if self.share_fitting: + ce1 = trainer.wrapper.model["model_1"].atomic_model.fitting_net.case_embd + ce2 = trainer.wrapper.model["model_2"].atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce1, "case_embd not set on model_1") + self.assertIsNotNone(ce2, "case_embd not set on model_2") + dim = ce1.shape[0] + # Sorted keys: ["model_1", "model_2"] → indices 0, 1 + expected_eye = torch.eye(dim, dtype=ce1.dtype, device=ce1.device) + torch.testing.assert_close( + ce1, + expected_eye[0], + msg="model_1 case_embd should be one-hot index 0 (alphabetical order)", + ) + torch.testing.assert_close( + ce2, + expected_eye[1], + msg="model_2 case_embd should be one-hot index 1 (alphabetical order)", + ) + # case_embd should NOT be shared in state_dict + for state_key in multi_state_dict: + if ( + "model_1.atomic_model.fitting_net" in state_key + and "case_embd" in state_key + ): + partner_key = state_key.replace("model_1", "model_2") + self.assertFalse( + torch.equal( + multi_state_dict[state_key], + multi_state_dict[partner_key], + ), + f"case_embd should NOT be shared: {state_key}", + ) + + def test_multitask_finetune(self) -> None: + """Train, then finetune with 4 branches from pretrained 2-branch model. + + For mixed_types descriptors, uses extended type_map ["O","H","B"] to test + change_type_map + model_with_new_type_stat integration. For non-mixed_types + descriptors, uses same type_map ["O","H"]. + + Builds a reference state_dict by manually replicating the trainer's + finetune operations (load pretrained, change_type_map, weight copy) and + verifies per-branch weight inheritance: + - model_1: resume (ALL weights match reference) + - model_2: finetune from model_2 (all except out_bias/out_std match) + - model_3: finetune from model_2 as new head (cross-branch key remap) + - model_4: random fitting (descriptor from pretrained, random fitting_net) + """ + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train pretrained 2-branch model (2 steps) + config_pretrain = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting, numb_steps=2 + ) + config_pretrain["training"]["save_freq"] = 2 + config_pretrain["model"], shared_links_pre = preprocess_shared_params( + config_pretrain["model"] + ) + config_pretrain = update_deepmd_input(config_pretrain, warning=False) + config_pretrain = normalize(config_pretrain, multi_task=True) + trainer = get_trainer(config_pretrain, shared_links=shared_links_pre) + trainer.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Pretrained checkpoint not created") + + # Phase 2: build reference state_dict + # For mixed_types: extend type_map to ["O","H","B"], build + # model_with_new_type_stat with computed stats, and apply + # change_type_map on pretrained. + # For non-mixed_types: use pretrained state directly (no extension). + ft_type_map = ["O", "H", "B"] if self.mixed_types else ["O", "H"] + + state_dict_full = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + state_dict_ckpt = ( + state_dict_full["model"] if "model" in state_dict_full else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + # Build pretrained wrapper (separate model per branch) + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + + # Record pretrained state BEFORE change_type_map — used later to + # verify O/H stats are inherited from pretrained, not recomputed. + pretrained_oh_state = { + k: v.clone() for k, v in pretrained_wrapper.model.state_dict().items() + } + + if self.mixed_types: + # Build a model with extended type_map and compute stats so that + # the new type ("B", unseen in data) gets proper default stats + # (davg=0, dstd=0.1) instead of the no-stat defaults (0/1). + ref_model_params = deepcopy( + pretrained_model_params["model_dict"]["model_1"] + ) + ref_model_params["type_map"] = ft_type_map + ref_model = get_model(ref_model_params).to(DEVICE) + + data_systems = process_systems([_PT_DATA]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=ft_type_map, + trn_all_set=True, + ) + data.add_data_requirements(_energy_data_requirement) + ref_model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + + # Apply change_type_map on each pretrained branch + for pk in pretrained_model_params["model_dict"]: + pretrained_wrapper.model[pk].change_type_map( + ft_type_map, + model_with_new_type_stat=ref_model, + ) + + ref_state_dict = pretrained_wrapper.model.state_dict() + + # Phase 3: build 4-branch finetune config + finetune_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + if self.mixed_types: + finetune_config["model"]["shared_dict"]["my_type_map"] = ft_type_map + + # Add model_3 and model_4 (copies of model_2) + finetune_config["model"]["model_dict"]["model_3"] = deepcopy( + finetune_config["model"]["model_dict"]["model_2"] + ) + finetune_config["model"]["model_dict"]["model_4"] = deepcopy( + finetune_config["model"]["model_dict"]["model_2"] + ) + finetune_config["loss_dict"]["model_3"] = deepcopy( + finetune_config["loss_dict"]["model_2"] + ) + finetune_config["loss_dict"]["model_4"] = deepcopy( + finetune_config["loss_dict"]["model_2"] + ) + finetune_config["training"]["model_prob"]["model_3"] = deepcopy( + finetune_config["training"]["model_prob"]["model_2"] + ) + finetune_config["training"]["model_prob"]["model_4"] = deepcopy( + finetune_config["training"]["model_prob"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_3"] = deepcopy( + finetune_config["training"]["data_dict"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_3"]["stat_file"] = ( + finetune_config["training"]["data_dict"]["model_3"]["stat_file"].replace( + "model_2", "model_3" + ) + ) + finetune_config["training"]["data_dict"]["model_4"] = deepcopy( + finetune_config["training"]["data_dict"]["model_2"] + ) + finetune_config["training"]["data_dict"]["model_4"]["stat_file"] = ( + finetune_config["training"]["data_dict"]["model_4"]["stat_file"].replace( + "model_2", "model_4" + ) + ) + + # Set finetune rules: + # model_1: no finetune_head → resume from model_1 (resuming=True) + # model_2: finetune_head="model_2" → finetune from model_2 + finetune_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + # model_3: finetune_head="model_2" → finetune from model_2 (new head) + finetune_config["model"]["model_dict"]["model_3"]["finetune_head"] = "model_2" + # model_4: no finetune_head, new name → random fitting + + finetune_config["training"]["numb_steps"] = 1 + finetune_config["training"]["save_freq"] = 1 + + finetune_config["model"], shared_links_ft = preprocess_shared_params( + finetune_config["model"] + ) + finetune_config["model"], finetune_links = get_finetune_rules( + ckpt_path, finetune_config["model"] + ) + finetune_config = update_deepmd_input(finetune_config, warning=False) + finetune_config = normalize(finetune_config, multi_task=True) + + trainer_ft = get_trainer( + deepcopy(finetune_config), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links, + ) + + # Phase 4: verify weight inheritance against reference + ft_state_dict = trainer_ft.wrapper.model.state_dict() + + # When type_map is extended, type_embedding weights for the new type + # are randomly initialized (np.random.default_rng) during + # change_type_map; since reference and trainer build separate + # pretrained wrappers, these random values differ — skip them. + _skip_type_embed = self.mixed_types + + for state_key in ft_state_dict: + if _skip_type_embed and "type_embedding" in state_key: + continue + if "model_1" in state_key: + # model_1: resume — ALL weights match reference model_1 + torch.testing.assert_close( + ref_state_dict[state_key], + ft_state_dict[state_key], + msg=f"model_1 (resume) weight mismatch: {state_key}", + ) + elif ( + "model_2" in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_2: finetune — all except out_bias/out_std + torch.testing.assert_close( + ref_state_dict[state_key], + ft_state_dict[state_key], + msg=f"model_2 (finetune) weight mismatch: {state_key}", + ) + elif ( + "model_3" in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_3: finetune from model_2 — cross-branch key remap + ref_key = state_key.replace("model_3", "model_2") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_3 (finetune from model_2) weight mismatch: {state_key}", + ) + elif ( + "model_4" in state_key + and "fitting_net" not in state_key + and "out_bias" not in state_key + and "out_std" not in state_key + ): + # model_4: random fitting — descriptor from pretrained + # (RANDOM + from_multitask uses first pretrained key = model_1; + # since descriptors are shared, model_1 == model_2 in pretrained) + ref_key = state_key.replace("model_4", "model_2") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_4 (random fitting) descriptor mismatch: {state_key}", + ) + + # Phase 5: verify O/H descriptor stats are inherited from pretrained + # (not recomputed from finetune data). + # For mixed_types: pretrained has shape [2,...] (O,H); finetuned has + # shape [3,...] (O,H,B). The first 2 entries must match pretrained. + # For non-mixed_types: shapes are identical, already fully checked above. + _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd") + if self.mixed_types: + n_old = len(["O", "H"]) + n_new = len(ft_type_map) + checked_count = 0 + for key in ft_state_dict: + if not any(key.endswith(s) for s in _STAT_SUFFIXES): + continue + # Use model_1 (all branches share descriptor after share_params) + if "model_1" not in key: + continue + pre_key = key # same key in pretrained_oh_state + if pre_key not in pretrained_oh_state: + continue + pre_val = pretrained_oh_state[pre_key] + ft_val = ft_state_dict[key] + # Find the type axis (size grew from n_old to n_new) + for ax in range(pre_val.ndim): + if pre_val.shape[ax] == n_old and ft_val.shape[ax] == n_new: + for ti, tname in enumerate(["O", "H"]): + torch.testing.assert_close( + ft_val.select(ax, ti), + pre_val.select(ax, ti), + msg=( + f"{tname} stat not inherited from pretrained: {key}" + ), + ) + checked_count += 1 + break + self.assertGreater( + checked_count, + 0, + "No descriptor stat keys found for O/H inheritance check", + ) + + # Phase 6: verify case_embd inheritance (share_fitting only) + # Pretrained branches keep their case_embd (dataset correspondence). + # New branches (model_3 finetune from model_2, model_4 random) get + # case_embd from the weight copy: model_3 copies model_2's, model_4 + # keeps target default (zeros since set_case_embd is skipped on finetune). + if self.share_fitting: + + def _get_case_embd(mk): + return trainer_ft.wrapper.model[mk].atomic_model.fitting_net.case_embd + + ce1 = _get_case_embd("model_1") + ce2 = _get_case_embd("model_2") + ce3 = _get_case_embd("model_3") + ce4 = _get_case_embd("model_4") + # Pretrained had sorted keys ["model_1","model_2"] → one-hot [1,0], [0,1] + dim = ce1.shape[0] + expected_eye = torch.eye(dim, dtype=ce1.dtype, device=ce1.device) + # model_1 (resume): inherits pretrained model_1's case_embd + torch.testing.assert_close( + ce1, + expected_eye[0], + msg="model_1 case_embd should match pretrained model_1", + ) + # model_2 (finetune from model_2): inherits pretrained model_2's case_embd + torch.testing.assert_close( + ce2, + expected_eye[1], + msg="model_2 case_embd should match pretrained model_2", + ) + # model_3 (finetune from model_2): weight copy from model_2 + torch.testing.assert_close( + ce3, + expected_eye[1], + msg="model_3 case_embd should match pretrained model_2 (finetune source)", + ) + # model_4 (random fitting): target default (zeros, set_case_embd skipped) + torch.testing.assert_close( + ce4, + torch.zeros_like(ce4), + msg="model_4 case_embd should be zeros (random fitting, no re-init on finetune)", + ) + + # Run 1 step to verify no crash + trainer_ft.run() + + def test_multitask_finetune_from_single_task(self) -> None: + """Finetune multi-task model from a single-task pretrained .pt checkpoint. + + Tests the single-task pretrained → multi-task finetune path + (finetune_from_multi_task=False, training.py:714-721). + + model_1: finetune_head="Default" → copies from single-task pretrained + model_2: no finetune_head, not in pretrained_keys=["Default"] → RANDOM fitting + """ + if self.share_fitting: + # Single-task pretrained has no dim_case_embd; incompatible with + # shared fitting multi-task target. + return + + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train single-task model (2 steps) + single_config = { + "model": { + "type_map": ["O", "H"], + "descriptor": deepcopy(self.descriptor), + "fitting_net": deepcopy(_fitting_net), + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": {"systems": [_PT_DATA], "batch_size": 1}, + "validation_data": { + "systems": [_PT_DATA], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 2, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": 2, + }, + } + single_config = update_deepmd_input(single_config, warning=False) + single_config = normalize(single_config, multi_task=False) + trainer_st = get_trainer(single_config) + trainer_st.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Single-task checkpoint not created") + + # Phase 2: build reference state_dict from single-task checkpoint + state_dict_full = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + state_dict_ckpt = ( + state_dict_full["model"] if "model" in state_dict_full else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + # Single-task pretrained → wrap as {"Default": model} + ref_model = get_model(deepcopy(pretrained_model_params)).to(DEVICE) + pretrained_wrapper = ModelWrapper(ref_model) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + ref_state_dict = pretrained_wrapper.model.state_dict() + + # Phase 3: build 2-branch multi-task finetune config + finetune_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + # model_1: finetune_head="Default" → copy from single-task + finetune_config["model"]["model_dict"]["model_1"]["finetune_head"] = "Default" + # model_2: no finetune_head, "model_2" not in pretrained_keys=["Default"] → RANDOM + finetune_config["training"]["numb_steps"] = 1 + finetune_config["training"]["save_freq"] = 1 + + finetune_config["model"], shared_links_ft = preprocess_shared_params( + finetune_config["model"] + ) + finetune_config["model"], finetune_links = get_finetune_rules( + ckpt_path, finetune_config["model"] + ) + finetune_config = update_deepmd_input(finetune_config, warning=False) + finetune_config = normalize(finetune_config, multi_task=True) + + trainer_ft = get_trainer( + deepcopy(finetune_config), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links, + ) + + # Phase 4: verify weight inheritance + ft_state_dict = trainer_ft.wrapper.model.state_dict() + + for state_key in ft_state_dict: + if "model_1" in state_key: + # model_1: finetune from "Default" — all except out_bias/out_std + if "out_bias" in state_key or "out_std" in state_key: + continue + ref_key = state_key.replace("model_1", "Default") + self.assertIn(ref_key, ref_state_dict, f"Missing ref key: {ref_key}") + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_1 (from Default) weight mismatch: {state_key}", + ) + elif "model_2" in state_key: + if "out_bias" in state_key or "out_std" in state_key: + continue + ref_key = state_key.replace("model_2", "Default") + if ".descriptor." in state_key: + # Descriptor from pretrained (RANDOM uses first pretrained key) + self.assertIn( + ref_key, ref_state_dict, f"Missing ref key: {ref_key}" + ) + torch.testing.assert_close( + ref_state_dict[ref_key], + ft_state_dict[state_key], + msg=f"model_2 (RANDOM) descriptor mismatch: {state_key}", + ) + + # model_2 fitting NN weights (networks.*) should differ (random init) + fitting_nn_mismatch = 0 + for state_key in ft_state_dict: + if ( + "model_2" in state_key + and ".fitting_net." in state_key + and "networks" in state_key + ): + ref_key = state_key.replace("model_2", "Default") + if ref_key in ref_state_dict and not torch.equal( + ref_state_dict[ref_key], ft_state_dict[state_key] + ): + fitting_nn_mismatch += 1 + self.assertGreater( + fitting_nn_mismatch, + 0, + "model_2 fitting NN weights should differ from pretrained (random init)", + ) + + # Phase 5: run 1 step to verify no crash + trainer_ft.run() + + def test_multitask_finetune_no_change_model_params(self) -> None: + """Test change_model_params=False preserves user config in multi-task finetune. + + Contrasts with change_model_params=True which overwrites descriptor/fitting + from pretrained (preserving trainable flags). + """ + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + # Phase 1: train 2-branch multi-task model (2 steps) + config_pretrain = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting, numb_steps=2 + ) + config_pretrain["training"]["save_freq"] = 2 + config_pretrain["model"], shared_links_pre = preprocess_shared_params( + config_pretrain["model"] + ) + config_pretrain = update_deepmd_input(config_pretrain, warning=False) + config_pretrain = normalize(config_pretrain, multi_task=True) + trainer = get_trainer(config_pretrain, shared_links=shared_links_pre) + trainer.run() + + ckpt_path = os.path.join(os.getcwd(), "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt_path), "Pretrained checkpoint not created") + + # Phase 2: build finetune config with marker in descriptor + ft_config = _make_multitask_config( + self.descriptor, share_fitting=self.share_fitting + ) + # model_1: no finetune_head → resume (model_1 in pretrained_keys) + # model_2: finetune_head="model_2" → finetune + ft_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + ft_config["training"]["numb_steps"] = 1 + ft_config["training"]["save_freq"] = 1 + + # Add markers to descriptor in each branch (before preprocess_shared_params + # resolves shared_dict references) + ft_config["model"]["shared_dict"]["my_descriptor"]["_test_marker"] = True + + # Phase 3: test change_model_params=False + ft_config_false = deepcopy(ft_config) + ft_config_false["model"], _ = preprocess_shared_params(ft_config_false["model"]) + model_config_false, finetune_links_false = get_finetune_rules( + ckpt_path, deepcopy(ft_config_false["model"]), change_model_params=False + ) + + # User config preserved: marker still present + self.assertTrue( + model_config_false["model_dict"]["model_1"]["descriptor"].get( + "_test_marker", False + ), + "model_1 descriptor should preserve _test_marker with change_model_params=False", + ) + self.assertTrue( + model_config_false["model_dict"]["model_2"]["descriptor"].get( + "_test_marker", False + ), + "model_2 descriptor should preserve _test_marker with change_model_params=False", + ) + # FinetuneRuleItem has correct type_map + for mk in ("model_1", "model_2"): + self.assertEqual( + finetune_links_false[mk].get_finetune_tmap(), + ["O", "H"], + f"{mk} finetune tmap should be ['O','H']", + ) + # model_1 is resuming, model_2 is not + self.assertTrue( + finetune_links_false["model_1"].resuming, + "model_1 should be resuming (no finetune_head, name in pretrained_keys)", + ) + self.assertFalse( + finetune_links_false["model_2"].resuming, + "model_2 should not be resuming (has finetune_head)", + ) + + # Phase 4: test change_model_params=True (contrast) + ft_config_true = deepcopy(ft_config) + # Also set trainable=False to verify it's preserved + ft_config_true["model"]["shared_dict"]["my_descriptor"]["trainable"] = False + ft_config_true["model"], _ = preprocess_shared_params(ft_config_true["model"]) + model_config_true, finetune_links_true = get_finetune_rules( + ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True + ) + + # Marker overwritten from pretrained + self.assertFalse( + model_config_true["model_dict"]["model_1"]["descriptor"].get( + "_test_marker", False + ), + "model_1 descriptor should NOT have _test_marker with change_model_params=True", + ) + self.assertFalse( + model_config_true["model_dict"]["model_2"]["descriptor"].get( + "_test_marker", False + ), + "model_2 descriptor should NOT have _test_marker with change_model_params=True", + ) + # trainable=False should be preserved + self.assertFalse( + model_config_true["model_dict"]["model_1"]["descriptor"].get( + "trainable", True + ), + "model_1 descriptor trainable should be preserved as False", + ) + self.assertFalse( + model_config_true["model_dict"]["model_2"]["descriptor"].get( + "trainable", True + ), + "model_2 descriptor trainable should be preserved as False", + ) + + # Phase 5: build trainer with change_model_params=False → run 1 step + ft_config_run = deepcopy(ft_config) + ft_config_run["model"], shared_links_ft = preprocess_shared_params( + ft_config_run["model"] + ) + ft_config_run["model"], finetune_links_run = get_finetune_rules( + ckpt_path, ft_config_run["model"], change_model_params=False + ) + ft_config_run = update_deepmd_input(ft_config_run, warning=False) + ft_config_run = normalize(ft_config_run, multi_task=True) + trainer_ft = get_trainer( + deepcopy(ft_config_run), + finetune_model=ckpt_path, + shared_links=shared_links_ft, + finetune_links=finetune_links_run, + ) + trainer_ft.run() + + def test_change_type_map_stat(self) -> None: + """Validate change_type_map preserves existing types' stats. + + Tests two modes: + 1. WITHOUT model_with_new_type_stat: existing types preserved, + new type gets default values (zeros for davg/bias, ones for dstd/std). + 2. WITH model_with_new_type_stat: existing types preserved, + new type gets data-computed values (davg=0, dstd=0.1 for zero + observations via StatItem defaults). + """ + if not self.mixed_types: + return + + old_tmap = ["O", "H"] + new_tmap = ["O", "H", "B"] + + model_config = deepcopy(self.config["model"]["model_dict"]["model_1"]) + + # Build model with old type_map and compute stats + model = get_model(deepcopy(model_config)).to(DEVICE) + data_systems = process_systems([_PT_DATA]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=old_tmap, + trn_all_set=True, + ) + data.add_data_requirements(_energy_data_requirement) + model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + sd_before = {k: v.clone() for k, v in model.state_dict().items()} + + # ---- Test 1: change_type_map WITHOUT model_with_new_type_stat ---- + model.change_type_map(new_tmap, model_with_new_type_stat=None) + sd_no_stat = model.state_dict() + + # Stat-like keys: descriptor mean/stddev/davg/dstd and atomic out_bias/out_std + _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd", "out_bias", "out_std") + + def _is_stat_key(k: str) -> bool: + return any(k.endswith(s) for s in _STAT_SUFFIXES) + + def _is_std_like(k: str) -> bool: + return k.endswith(("stddev", "dstd", "out_std")) + + for key in sd_no_stat: + if key not in sd_before or not _is_stat_key(key): + continue + old_val = sd_before[key] + new_val = sd_no_stat[key] + if old_val.shape == new_val.shape: + continue + # Find the type axis: size went from len(old_tmap) to len(new_tmap) + for ax in range(old_val.ndim): + if old_val.shape[ax] == len(old_tmap) and new_val.shape[ax] == len( + new_tmap + ): + # Existing types preserved + torch.testing.assert_close( + new_val.select(ax, 0), + old_val.select(ax, 0), + msg=f"O stat changed (no model_with_new_type_stat): {key}", + ) + torch.testing.assert_close( + new_val.select(ax, 1), + old_val.select(ax, 1), + msg=f"H stat changed (no model_with_new_type_stat): {key}", + ) + # New type B: defaults (zeros for mean/davg/bias, ones for std) + new_B = new_val.select(ax, 2) + if _is_std_like(key): + torch.testing.assert_close( + new_B, + torch.ones_like(new_B), + msg=f"B default should be ones: {key}", + ) + else: + torch.testing.assert_close( + new_B, + torch.zeros_like(new_B), + msg=f"B default should be zeros: {key}", + ) + break + + # ---- Test 2: change_type_map WITH model_with_new_type_stat ---- + # Build fresh model with old type_map + model2 = get_model(deepcopy(model_config)).to(DEVICE) + model2.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + sd_before2 = {k: v.clone() for k, v in model2.state_dict().items()} + + # Build model_with_new_type_stat with extended type_map + model_ext_config = deepcopy(model_config) + model_ext_config["type_map"] = new_tmap + model_ext = get_model(model_ext_config).to(DEVICE) + data_ext = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=new_tmap, + trn_all_set=True, + ) + data_ext.add_data_requirements(_energy_data_requirement) + model_ext.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data_ext, 1), + stat_file_path=None, + ) + + model2.change_type_map(new_tmap, model_with_new_type_stat=model_ext) + sd_with_stat = model2.state_dict() + + for key in sd_with_stat: + if key not in sd_before2 or not _is_stat_key(key): + continue + old_val = sd_before2[key] + new_val = sd_with_stat[key] + if old_val.shape == new_val.shape: + continue + for ax in range(old_val.ndim): + if old_val.shape[ax] == len(old_tmap) and new_val.shape[ax] == len( + new_tmap + ): + # Existing types preserved + torch.testing.assert_close( + new_val.select(ax, 0), + old_val.select(ax, 0), + msg=f"O stat changed (with model_with_new_type_stat): {key}", + ) + torch.testing.assert_close( + new_val.select(ax, 1), + old_val.select(ax, 1), + msg=f"H stat changed (with model_with_new_type_stat): {key}", + ) + # New type B: descriptor stats should use model_ext's + # computed values, NOT the no-stat defaults (ones) + new_B = new_val.select(ax, 2) + is_descrpt_std = key.endswith(("stddev", "dstd")) + if is_descrpt_std: + # B has zero observations → StatItem default = 0.1 + # (not ones like the no-stat default) + self.assertFalse( + torch.allclose(new_B, torch.ones_like(new_B)), + f"B descriptor stat should NOT be ones " + f"(should be 0.1 from StatItem default): {key}", + ) + break + + def test_multitask_restart(self) -> None: + """Train, then restart from checkpoint and verify.""" + # Phase 1: train + config1 = deepcopy(self.config) + config1["training"]["numb_steps"] = 2 + config1["training"]["save_freq"] = 2 + trainer1 = get_trainer(config1, shared_links=self.shared_links) + trainer1.run() + + ckpt_path = "model.ckpt.pt" + self.assertTrue(os.path.exists(ckpt_path), "Checkpoint not created") + + # Phase 2: restart to step 4 + config2 = deepcopy(self.config) + config2["training"]["numb_steps"] = 4 + config2["training"]["save_freq"] = 4 + trainer2 = get_trainer( + config2, + restart_model=ckpt_path, + shared_links=self.shared_links, + ) + self.assertEqual(trainer2.start_step, 2) + trainer2.run() + + def test_multitask_freeze(self) -> None: + """Train, then freeze with --head and verify. + + Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + # Train + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + # Freeze head model_1 + ckpt_path = "model.ckpt.pt" + output_path = "frozen_model_1.pte" + freeze(model=ckpt_path, output=output_path, head="model_1") + self.assertTrue(os.path.exists(output_path), "Frozen model not created") + + # Verify frozen model loads + from deepmd.pt_expt.model import ( + BaseModel, + ) + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + data = serialize_from_file(output_path) + self.assertIn("model", data) + frozen_model = BaseModel.deserialize(data["model"]) + self.assertIsInstance(frozen_model, torch.nn.Module) + + def test_multitask_freeze_no_head_raises(self) -> None: + """Freezing multi-task model without --head raises ValueError. + + Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + ckpt_path = "model.ckpt.pt" + with self.assertRaises(ValueError, msg="Should require --head"): + freeze(model=ckpt_path, output="frozen.pte", head=None) + + def test_multitask_freeze_invalid_head_raises(self) -> None: + """Freezing multi-task model with invalid --head raises ValueError. + + Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + """ + if self.descriptor.get("type") != "dpa3": + return + + from deepmd.pt_expt.entrypoints.main import ( + freeze, + ) + + config = deepcopy(self.config) + trainer = get_trainer(config, shared_links=self.shared_links) + trainer.run() + + ckpt_path = "model.ckpt.pt" + with self.assertRaises(ValueError, msg="Should reject invalid head"): + freeze(model=ckpt_path, output="frozen.pte", head="nonexistent") + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f == "lcurve.out": + os.remove(f) + if f.endswith(".pte"): + os.remove(f) + if os.path.isdir("stat_files"): + shutil.rmtree("stat_files") + + +class TestMultiTaskSeA(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with se_e2_a descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_sea_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_se_e2_a + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = False + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskSeAShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with se_e2_a descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_sea_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_se_e2_a + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = False + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA1(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA1 (se_atten) descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa1_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa1 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA1ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA1 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa1_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa1 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA2(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA2 descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa2_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa2 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA2ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA2 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa2_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa2 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA3(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA3 descriptor.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa3_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa3 + config = _make_multitask_config(self.descriptor, share_fitting=False) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = False + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskDPA3ShareFit(unittest.TestCase, MultiTaskTrainTest): + """Multi-task training with DPA3 descriptor and shared fitting_net.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_dpa3_sf_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.descriptor = _descriptor_dpa3 + config = _make_multitask_config(self.descriptor, share_fitting=True) + config["model"], self.shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + self.config = config + self.share_fitting = True + self.mixed_types = True + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + +class TestMultiTaskCompile(unittest.TestCase): + """Verify that multi-task + torch.compile works correctly.""" + + @classmethod + def setUpClass(cls) -> None: + _skip_if_no_data() + + def _check_compile_correctness(self, share_fitting: bool = False) -> None: + """Compiled multi-task model predictions and gradients match uncompiled. + + For each branch: feed the same batch through wrapper (which computes + loss), call loss.backward(), then compare: + 1. model predictions (energy, force) + 2. loss values + 3. parameter gradients (second-order, through force loss) + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Build uncompiled trainer + config_uc = _make_multitask_config( + _descriptor_se_e2_a, share_fitting=share_fitting + ) + config_uc["model"], shared_links_uc = preprocess_shared_params( + config_uc["model"] + ) + config_uc = update_deepmd_input(config_uc, warning=False) + config_uc = normalize(config_uc, multi_task=True) + + # Build compiled trainer + config_c = _make_multitask_config( + _descriptor_se_e2_a, share_fitting=share_fitting + ) + config_c["training"]["enable_compile"] = True + config_c["model"], shared_links_c = preprocess_shared_params(config_c["model"]) + config_c = update_deepmd_input(config_c, warning=False) + config_c = normalize(config_c, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_corr_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer_uc = get_trainer(config_uc, shared_links=shared_links_uc) + trainer_c = get_trainer(config_c, shared_links=shared_links_c) + for mk in ("model_1", "model_2"): + self.assertIsInstance(trainer_c.wrapper.model[mk], _CompiledModel) + + # Copy uncompiled weights → compiled (same starting point) + for mk in ("model_1", "model_2"): + trainer_c.wrapper.model[mk].original_model.load_state_dict( + trainer_uc.wrapper.model[mk].state_dict() + ) + + # For each branch, run one forward+backward and compare + for task_key in ("model_1", "model_2"): + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + input_dict, label_dict = trainer_uc.get_data( + is_train=True, task_key=task_key + ) + + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + pred_uc, loss_uc, _ = trainer_uc.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=task_key, + ) + pred_c, loss_c, _ = trainer_c.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + task_key=task_key, + ) + + # Compare predictions + torch.testing.assert_close( + pred_c["energy"], + pred_uc["energy"], + atol=1e-10, + rtol=1e-10, + ) + torch.testing.assert_close( + pred_c["force"], + pred_uc["force"], + atol=1e-10, + rtol=1e-10, + ) + torch.testing.assert_close(loss_c, loss_uc, atol=1e-10, rtol=1e-10) + + # Compare gradients (second-order, through force loss) + loss_uc.backward() + loss_c.backward() + for (name_uc, p_uc), (name_c, p_c) in zip( + trainer_uc.wrapper.model[task_key].named_parameters(), + trainer_c.wrapper.model[task_key].original_model.named_parameters(), + strict=True, + ): + if p_uc.grad is not None: + self.assertIsNotNone( + p_c.grad, + msg=f"grad is None for {name_c} (task={task_key})", + ) + torch.testing.assert_close( + p_c.grad, + p_uc.grad, + atol=1e-10, + rtol=1e-10, + msg=f"grad mismatch on {name_uc} (task={task_key})", + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compile_multitask_correctness(self) -> None: + """Compiled multi-task predictions and gradients match uncompiled.""" + self._check_compile_correctness(share_fitting=False) + + def test_compile_multitask_correctness_share_fitting(self) -> None: + """Compiled multi-task with shared fitting: predictions and gradients match.""" + self._check_compile_correctness(share_fitting=True) + + def test_compile_multitask_train(self) -> None: + """Train multi-task model with torch.compile for a few steps.""" + config = _make_multitask_config(_descriptor_se_e2_a) + config["training"]["enable_compile"] = True + config["training"]["numb_steps"] = 2 + config["training"]["save_freq"] = 2 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_train_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compile_multitask_train_share_fitting(self) -> None: + """Train multi-task model with shared fitting + compile for a few steps.""" + config = _make_multitask_config(_descriptor_se_e2_a, share_fitting=True) + config["training"]["enable_compile"] = True + config["training"]["numb_steps"] = 2 + config["training"]["save_freq"] = 2 + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_compile_sf_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Gradient accumulation test helpers +# --------------------------------------------------------------------------- + + +def _generate_random_data_dir( + path: str, + atom_types: list[int], + nframes: int, + seed: int, + nfparam: int = 0, + naparam: int = 0, +) -> None: + """Create a minimal deepmd data directory with random data.""" + rng = np.random.RandomState(seed) + natoms = len(atom_types) + os.makedirs(os.path.join(path, "set.000"), exist_ok=True) + + # type.raw + with open(os.path.join(path, "type.raw"), "w") as f: + for t in atom_types: + f.write(f"{t}\n") + + # box: diagonal 20x20x20 + box = np.tile(np.diag([20.0, 20.0, 20.0]).flatten(), (nframes, 1)) + np.save(os.path.join(path, "set.000", "box.npy"), box) + + # coord + coord = rng.random((nframes, natoms * 3)) * 20.0 + np.save(os.path.join(path, "set.000", "coord.npy"), coord) + + # energy + energy = rng.random((nframes,)) + np.save(os.path.join(path, "set.000", "energy.npy"), energy) + + # force + force = rng.random((nframes, natoms * 3)) + np.save(os.path.join(path, "set.000", "force.npy"), force) + + # fparam (frame parameters) + if nfparam > 0: + fparam = rng.random((nframes, nfparam)) + np.save(os.path.join(path, "set.000", "fparam.npy"), fparam) + + # aparam (atomic parameters) + if naparam > 0: + aparam = rng.random((nframes, natoms * naparam)) + np.save(os.path.join(path, "set.000", "aparam.npy"), aparam) + + +def _make_gradient_test_mt_config( + data_dir_0: str, + data_dir_1: str, + numb_fparam: int = 0, + numb_aparam: int = 0, +) -> dict: + """Multi-task config for gradient accumulation test.""" + type_map = ["O", "H", "C"] + descriptor = deepcopy(_descriptor_dpa3) + fitting_1: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + } + fitting_2: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 2, + } + if numb_fparam > 0: + fitting_1["numb_fparam"] = numb_fparam + fitting_2["numb_fparam"] = numb_fparam + if numb_aparam > 0: + fitting_1["numb_aparam"] = numb_aparam + fitting_2["numb_aparam"] = numb_aparam + return { + "model": { + "shared_dict": { + "my_type_map": type_map, + "my_descriptor": descriptor, + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_1, + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": fitting_2, + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 0.001, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir_0], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir_0], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir_1], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir_1], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 2, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + + +def _make_gradient_test_st_config( + data_dir: str, + fitting_seed: int, + numb_fparam: int = 0, + numb_aparam: int = 0, +) -> dict: + """Single-task config for gradient accumulation test.""" + type_map = ["O", "H", "C"] + descriptor = deepcopy(_descriptor_dpa3) + fitting: dict = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": fitting_seed, + } + if numb_fparam > 0: + fitting["numb_fparam"] = numb_fparam + if numb_aparam > 0: + fitting["numb_aparam"] = numb_aparam + return { + "model": { + "type_map": type_map, + "descriptor": descriptor, + "fitting_net": fitting, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 0.001, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + + +def _deterministic_task_choice(task_sequence: list[int]): + """Return a patched dp_random.choice that forces task selection order.""" + original = dp_random.choice + it = iter(task_sequence) + + def patched(a, size=None, replace=True, p=None): + # Task selection: array with >=2 elements and probability vector + if hasattr(a, "__len__") and len(a) >= 2 and p is not None: + return next(it) + return original(a, size=size, replace=replace, p=p) + + return patched + + +def _make_recording_step( + trainer, + modules_to_record: dict, + recorded_grads: list[dict], +): + """Patch _optimizer_step: record grads from listed modules, skip optimizer. + + Parameters + ---------- + trainer : Trainer + The trainer whose scheduler.step() is called. + modules_to_record : dict[str, torch.nn.Module] + Named modules whose parameter gradients to record. + recorded_grads : list[dict[str, torch.Tensor]] + Appended with {module_key/param_name: grad} at each step. + """ + + def recording_step(): + grads = {} + for mod_key, mod in modules_to_record.items(): + for n, p in mod.named_parameters(): + if p.grad is not None: + grads[f"{mod_key}/{n}"] = p.grad.clone() + recorded_grads.append(grads) + trainer.scheduler.step() + + return recording_step + + +class TestMultiTaskGradient(unittest.TestCase): + """Verify multi-task descriptor gradients match sum of single-task gradients.""" + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_grad_") + self._old_cwd = os.getcwd() + os.chdir(self.tmpdir) + + self.nfparam = 2 + self.naparam = 3 + + self.data_dir_0 = os.path.join(self.tmpdir, "data_task0") + _generate_random_data_dir( + self.data_dir_0, + atom_types=[0, 0, 1, 1, 1, 2], + nframes=1, + seed=42, + nfparam=self.nfparam, + naparam=self.naparam, + ) + self.data_dir_1 = os.path.join(self.tmpdir, "data_task1") + _generate_random_data_dir( + self.data_dir_1, + atom_types=[0, 1, 1, 2, 2, 2, 2], + nframes=1, + seed=137, + nfparam=self.nfparam, + naparam=self.naparam, + ) + + def tearDown(self) -> None: + os.chdir(self._old_cwd) + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_gradient_accumulation(self) -> None: + """Sum of per-task grads from multi-task run() == + sum of grads from two single-task run() calls. + """ + # ===== Multi-task trainer ===== + mt_config = _make_gradient_test_mt_config( + self.data_dir_0, + self.data_dir_1, + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + mt_config["model"], shared_links = preprocess_shared_params(mt_config["model"]) + mt_config = update_deepmd_input(mt_config, warning=False) + mt_config = normalize(mt_config, multi_task=True) + + mt_trainer = get_trainer(deepcopy(mt_config), shared_links=shared_links) + mt_desc = mt_trainer.wrapper.model["model_1"].atomic_model.descriptor + mt_fit_1 = mt_trainer.wrapper.model["model_1"].atomic_model.fitting_net + mt_fit_2 = mt_trainer.wrapper.model["model_2"].atomic_model.fitting_net + + # Verify descriptor params are aliased (share_params) + mt_desc_2 = mt_trainer.wrapper.model["model_2"].atomic_model.descriptor + for (n1, p1), (n2, p2) in zip( + mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True + ): + assert p1.data_ptr() == p2.data_ptr(), ( + f"Descriptor params not aliased: {n1}" + ) + + # Record grads for descriptor + both fitting heads + mt_grads: list[dict[str, torch.Tensor]] = [] + mt_trainer._optimizer_step = _make_recording_step( + mt_trainer, + {"desc": mt_desc, "fit_1": mt_fit_1, "fit_2": mt_fit_2}, + mt_grads, + ) + with mock.patch( + "deepmd.utils.random.choice", + _deterministic_task_choice([0, 1]), + ): + mt_trainer.run() # 2 steps: task_0 then task_1 + + assert len(mt_grads) == 2 + + # ===== Single-task trainer for task_0 ===== + st0_config = _make_gradient_test_st_config( + self.data_dir_0, + fitting_seed=1, # same as model_1 + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + st0_config = update_deepmd_input(st0_config, warning=False) + st0_config = normalize(st0_config) + + os.chdir(tempfile.mkdtemp(dir=self.tmpdir)) # fresh cwd + st0_trainer = get_trainer(deepcopy(st0_config)) + + # Copy MT model_1 state → ST0 to ensure identical params+buffers + # (stat buffers like davg/dstd/bias_atom_e differ due to data) + mt_m1 = mt_trainer.wrapper.model["model_1"] + st0_m = st0_trainer.wrapper.model["Default"] + st0_m.load_state_dict(mt_m1.state_dict()) + + st0_desc = st0_m.atomic_model.descriptor + st0_fit = st0_m.atomic_model.fitting_net + + st0_grads: list[dict[str, torch.Tensor]] = [] + st0_trainer._optimizer_step = _make_recording_step( + st0_trainer, {"desc": st0_desc, "fit": st0_fit}, st0_grads + ) + st0_trainer.run() # 1 step + assert len(st0_grads) == 1 + + # ===== Single-task trainer for task_1 ===== + st1_config = _make_gradient_test_st_config( + self.data_dir_1, + fitting_seed=2, # same as model_2 + numb_fparam=self.nfparam, + numb_aparam=self.naparam, + ) + st1_config = update_deepmd_input(st1_config, warning=False) + st1_config = normalize(st1_config) + + os.chdir(tempfile.mkdtemp(dir=self.tmpdir)) # fresh cwd + st1_trainer = get_trainer(deepcopy(st1_config)) + + # Copy MT model_2 state → ST1 to ensure identical params+buffers + mt_m2 = mt_trainer.wrapper.model["model_2"] + st1_m = st1_trainer.wrapper.model["Default"] + st1_m.load_state_dict(mt_m2.state_dict()) + + st1_desc = st1_m.atomic_model.descriptor + st1_fit = st1_m.atomic_model.fitting_net + + st1_grads: list[dict[str, torch.Tensor]] = [] + st1_trainer._optimizer_step = _make_recording_step( + st1_trainer, {"desc": st1_desc, "fit": st1_fit}, st1_grads + ) + st1_trainer.run() # 1 step + assert len(st1_grads) == 1 + + # ===== Comparison: descriptor gradients ===== + # Multi-task descriptor grad at each step should match single-task + desc_keys = [k for k in mt_grads[0] if k.startswith("desc/")] + assert len(desc_keys) > 0, "No descriptor gradients" + + # Per-task descriptor grad: mt step_0 == st_0, mt step_1 == st_1 + for name in desc_keys: + np.testing.assert_allclose( + mt_grads[0][name].detach().cpu().numpy(), + st0_grads[0][name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Task_0 descriptor grad mismatch: {name}", + ) + np.testing.assert_allclose( + mt_grads[1][name].detach().cpu().numpy(), + st1_grads[0][name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Task_1 descriptor grad mismatch: {name}", + ) + + # Descriptor grad sum: mt(step0 + step1) == st0 + st1 + for name in desc_keys: + mt_sum = mt_grads[0][name] + mt_grads[1][name] + st_sum = st0_grads[0][name] + st1_grads[0][name] + np.testing.assert_allclose( + mt_sum.detach().cpu().numpy(), + st_sum.detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Descriptor grad sum mismatch: {name}", + ) + + # ===== Comparison: fitting head gradients ===== + # Step 0 trains model_1 → mt fit_1 grads == st0 fit grads + fit1_keys = [k for k in mt_grads[0] if k.startswith("fit_1/")] + for name in fit1_keys: + st_name = name.replace("fit_1/", "fit/") + np.testing.assert_allclose( + mt_grads[0][name].detach().cpu().numpy(), + st0_grads[0][st_name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Fitting head grad mismatch (task_0): {name}", + ) + # Verify fit_2 not in step 0 grads (not part of computation graph) + assert not any(k.startswith("fit_2/") for k in mt_grads[0]), ( + "fit_2 should have no gradients in step 0 (task_0)" + ) + + # Step 1 trains model_2 → mt fit_2 grads == st1 fit grads + fit2_keys = [k for k in mt_grads[1] if k.startswith("fit_2/")] + for name in fit2_keys: + st_name = name.replace("fit_2/", "fit/") + np.testing.assert_allclose( + mt_grads[1][name].detach().cpu().numpy(), + st1_grads[0][st_name].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"Fitting head grad mismatch (task_1): {name}", + ) + # Verify fit_1 not in step 1 grads + assert not any(k.startswith("fit_1/") for k in mt_grads[1]), ( + "fit_1 should have no gradients in step 1 (task_1)" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 3b3ab247bb..e0250eb24e 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -193,7 +193,7 @@ def test_nall_growth_triggers_recompile(self) -> None: trainer = get_trainer(config) # The wrapper.model should be a _CompiledModel - compiled_model = trainer.wrapper.model + compiled_model = trainer.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) original_max_nall = compiled_model._max_nall @@ -273,7 +273,7 @@ def test_compiled_matches_uncompiled(self) -> None: config_compiled = update_deepmd_input(config_compiled, warning=False) config_compiled = normalize(config_compiled) trainer_compiled = get_trainer(config_compiled) - compiled_model = trainer_compiled.wrapper.model + compiled_model = trainer_compiled.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) # Copy uncompiled weights to compiled model so they match @@ -330,6 +330,84 @@ def test_compiled_matches_uncompiled(self) -> None: finally: shutil.rmtree(tmpdir, ignore_errors=True) + def test_compiled_gradients_match_uncompiled(self) -> None: + """Parameter gradients from compiled model must match uncompiled. + + Verifies second-order derivatives are correct: the loss includes + force terms, and force is computed via autograd.grad(create_graph=True), + so loss.backward() requires second-order differentiation through the + make_fx-decomposed backward ops. + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config_uc = _make_config(self.data_dir, numb_steps=1) + config_uc = update_deepmd_input(config_uc, warning=False) + config_uc = normalize(config_uc) + + config_c = _make_config(self.data_dir, numb_steps=1) + config_c["training"]["enable_compile"] = True + config_c = update_deepmd_input(config_c, warning=False) + config_c = normalize(config_c) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_grad_consistency_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer_uc = get_trainer(config_uc) + trainer_c = get_trainer(config_c) + compiled_model = trainer_c.wrapper.model["Default"] + self.assertIsInstance(compiled_model, _CompiledModel) + + # Match weights + compiled_model.original_model.load_state_dict( + trainer_uc.model.state_dict() + ) + + # Forward + backward through wrapper (includes loss) + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + input_dict, label_dict = trainer_uc.get_data(is_train=True) + cur_lr = trainer_uc.scheduler.get_last_lr()[0] + + _, loss_uc, _ = trainer_uc.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + ) + _, loss_c, _ = trainer_c.wrapper( + **input_dict, + cur_lr=cur_lr, + label=label_dict, + ) + loss_uc.backward() + loss_c.backward() + + for (name_uc, p_uc), (name_c, p_c) in zip( + trainer_uc.model.named_parameters(), + compiled_model.original_model.named_parameters(), + strict=True, + ): + if p_uc.grad is not None: + self.assertIsNotNone( + p_c.grad, + msg=f"grad is None for {name_c}", + ) + torch.testing.assert_close( + p_c.grad, + p_uc.grad, + atol=1e-10, + rtol=1e-10, + msg=f"grad mismatch on {name_uc}", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + class TestGetData(unittest.TestCase): """Test the batch data conversion in Trainer.get_data.""" @@ -496,7 +574,7 @@ def test_restart_with_compile(self) -> None: trainer2 = get_trainer(config2, restart_model=ckpt_path) self.assertEqual(trainer2.start_step, 5) - self.assertIsInstance(trainer2.wrapper.model, _CompiledModel) + self.assertIsInstance(trainer2.wrapper.model["Default"], _CompiledModel) trainer2.run() with open(os.path.join(tmpdir, "lcurve.out")) as f: diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py new file mode 100644 index 0000000000..053c9af4d5 --- /dev/null +++ b/source/tests/pt_expt/test_training_ddp.py @@ -0,0 +1,1666 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for distributed (DDP) training in the pt_expt backend. + +Uses ``torch.multiprocessing.spawn`` + ``gloo`` backend to verify DDP on CPU. + +Verifies that: +1. Single-task DDP training completes and produces correct outputs +2. Multi-task DDP training completes and produces correct outputs +3. DDP gradient averaging matches manual average of per-rank gradients +4. Multi-task DDP gradient averaging works correctly +5. Finetune + DDP: selective weight copy via _unwrapped +6. Finetune + DDP with random fitting: descriptor from pretrained, fitting random +7. Finetune + DDP with new type: exercises _unwrapped.model["Default"] + stat broadcast +""" + +import os +import shutil +import socket +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, +) +from deepmd.pt_expt.utils.multi_task import ( + preprocess_shared_params, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) + +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent / "pt" / "water" / "data" / "data_0") + +EXAMPLE_DIR = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "examples", + "water", +) + + +def _find_free_port(): + """Find a free TCP port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def _make_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal single-task config.""" + return { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _make_multitask_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal multi-task config with shared descriptor.""" + descriptor = { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + } + fitting = { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + } + return { + "model": { + "shared_dict": { + "my_type_map": ["O", "H"], + "my_descriptor": deepcopy(descriptor), + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _make_dpa1_config(data_dir: str, numb_steps: int = 2) -> dict: + """Build a minimal DPA1 config (mixed_types) for finetune new-type tests.""" + return { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [4, 8], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 1, + "attn_dotr": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [8, 8], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _subsample_data(src_dir: str, dst_dir: str, nframes: int = 2) -> None: + """Copy a data system, keeping only the first *nframes* frames.""" + shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True) + set_dir = os.path.join(dst_dir, "set.000") + for name in os.listdir(set_dir): + if name.endswith(".npy"): + arr = np.load(os.path.join(set_dir, name)) + np.save(os.path.join(set_dir, name), arr[:nframes]) + + +# --------------------------------------------------------------------------- +# Worker functions for mp.spawn +# --------------------------------------------------------------------------- + + +def _worker_single_task_train(rank, world_size, port, data_dir, result_dict): + """Worker: run single-task DDP training.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_st_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + # Collect results + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get final weights + weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "weights": weights, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_train(rank, world_size, port, data_dir, result_dict): + """Worker: run multi-task DDP training.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=2) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get shared descriptor params from model_1 + desc_params = {} + for name, p in trainer._unwrapped.model[ + "model_1" + ].atomic_model.descriptor.named_parameters(): + desc_params[name] = p.detach().cpu().clone() + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "desc_params": desc_params, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_gradient_test(rank, world_size, port, data_dir, result_dict): + """Worker: run 1 step of DDP training, collect gradients and input data.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_grad_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=1) + config["model"]["descriptor"]["precision"] = "float64" + config["model"]["fitting_net"]["precision"] = "float64" + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + + # Run one forward/backward step manually + trainer.wrapper.train() + trainer.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = trainer.get_data(is_train=True, task_key="Default") + + cur_lr_sched = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + ) + loss.backward() # DDP all-reduces gradients here + + # Collect post-all-reduce gradients + grads = {} + for name, p in trainer._unwrapped.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().cpu().clone() + + # Collect input batch (for single-process replay) + batch = {} + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.detach().cpu().clone() + else: + batch[k] = v + for k, v in label_dict.items(): + if isinstance(v, torch.Tensor): + batch[f"label_{k}"] = v.detach().cpu().clone() + + # Initial model state dict (before any optimizer step) + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + + result_dict[rank] = { + "grads": grads, + "batch": batch, + "init_state": init_state, + "config": config, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_gradient_test(rank, world_size, port, data_dir, result_dict): + """Worker: run 1 step of multi-task DDP training, collect gradients.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_grad_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=1) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + + # Run one step with deterministic task selection + + # Force task_key = "model_1" for all ranks (deterministic) + trainer.wrapper.train() + trainer.optimizer.zero_grad(set_to_none=True) + input_dict, label_dict = trainer.get_data(is_train=True, task_key="model_1") + cur_lr_sched = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **input_dict, + cur_lr=cur_lr_sched, + label=label_dict, + task_key="model_1", + ) + loss.backward() + + grads = {} + for name, p in trainer._unwrapped.named_parameters(): + if p.grad is not None: + grads[name] = p.grad.detach().cpu().clone() + + result_dict[rank] = { + "grads": grads, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_check_resume( + rank, world_size, port, data_dir, ckpt_path, numb_steps, is_restart, result_dict +): + """Worker: build DDP trainer from checkpoint, capture initial state, then train. + + Parameters + ---------- + is_restart : bool + True → restart_model (continue training, restore optimizer & step). + False → init_model (inherit weights, reset step to 0). + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_resume_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=numb_steps) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + if is_restart: + trainer = get_trainer(config, restart_model=ckpt_path) + else: + trainer = get_trainer(config, init_model=ckpt_path) + + # Capture initial state BEFORE training + init_weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + start_step = trainer.start_step + init_lr = trainer.scheduler.get_last_lr()[0] + + trainer.run() + + result_dict[rank] = { + "init_weights": init_weights, + "start_step": start_step, + "init_lr": init_lr, + "lcurve_exists": os.path.exists("lcurve.out"), + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_finetune( + rank, world_size, port, ckpt_path, config_dict, model_branch, result_dict +): + """Worker: DDP finetune from checkpoint. + + Parameters + ---------- + ckpt_path : str + Absolute path to pretrained checkpoint (.pt). + config_dict : dict + Already normalized config with absolute data paths. + model_branch : str or None + ``"RANDOM"`` for random fitting, ``None`` for normal. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_ft_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = deepcopy(config_dict) + config["model"], finetune_links = get_finetune_rules( + ckpt_path, + config["model"], + model_branch=model_branch or "", + ) + + trainer = get_trainer( + config, + finetune_model=ckpt_path, + finetune_links=finetune_links, + ) + + # Capture state after finetune setup (before training) + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + + trainer.run() + + result_dict[rank] = { + "init_state": init_state, + "lcurve_exists": os.path.exists("lcurve.out"), + "ckpt_files": [f for f in os.listdir(".") if f.endswith(".pt")], + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +# --------------------------------------------------------------------------- +# Test classes +# --------------------------------------------------------------------------- + + +class TestDDPSingleTaskTrain(unittest.TestCase): + """Smoke test: single-task DDP training with 2 ranks.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_single_task_trains(self) -> None: + """2 ranks, se_e2_a, 2 training steps — verify completion and outputs.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_single_task_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Only rank 0 produces lcurve.out and checkpoints + self.assertTrue(results[0]["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse( + results[1]["lcurve_exists"], "rank 1 should NOT produce lcurve.out" + ) + self.assertGreater(results[0]["num_ckpts"], 0, "rank 0 should save checkpoints") + self.assertEqual( + results[1]["num_ckpts"], 0, "rank 1 should NOT save checkpoints" + ) + + # Final weights should be identical across ranks + for name in results[0]["weights"]: + torch.testing.assert_close( + results[0]["weights"][name], + results[1]["weights"][name], + msg=f"Weights differ across ranks: {name}", + ) + + +class TestDDPMultiTaskTrain(unittest.TestCase): + """Smoke test: multi-task DDP training with 2 ranks.""" + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_multitask_trains(self) -> None: + """2 ranks, multi-task, 2 steps — verify completion.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Shared descriptor params should be identical across ranks + for name in results[0]["desc_params"]: + torch.testing.assert_close( + results[0]["desc_params"][name], + results[1]["desc_params"][name], + msg=f"Shared descriptor param differs across ranks: {name}", + ) + + +class TestDDPGradientAveraging(unittest.TestCase): + """Core DDP correctness: gradient averaging matches manual computation. + + Each DDP rank processes different data. After all-reduce, all ranks have + the averaged gradient. We verify: + 1. Both ranks have identical gradients (DDP guarantee) + 2. The DDP gradient equals the average of per-rank single-process gradients + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_gradient_equals_average(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_gradient_test, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # 1. Verify gradients are identical on both ranks (DDP guarantee) + for name in r0["grads"]: + self.assertIn(name, r1["grads"], f"Grad key missing on rank 1: {name}") + torch.testing.assert_close( + r0["grads"][name], + r1["grads"][name], + atol=0, + rtol=0, + msg=f"Gradients should be identical across ranks: {name}", + ) + + # 2. Rebuild model in single process, replay each rank's batch, + # compute manual average, compare to DDP gradient + config = r0["config"] + tmpdir = tempfile.mkdtemp(prefix="ddp_grad_verify_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + # Load same initial state as DDP workers + state_to_load = dict(trainer._unwrapped.state_dict()) + for k in r0["init_state"]: + state_to_load[k] = r0["init_state"][k] + trainer._unwrapped.load_state_dict(state_to_load) + trainer.wrapper.train() + + # Forward+backward with rank 0's batch + trainer.optimizer.zero_grad(set_to_none=True) + input_0 = { + k: v.clone() + for k, v in r0["batch"].items() + if not k.startswith("label_") + } + label_0 = { + k[len("label_") :]: v.clone() + for k, v in r0["batch"].items() + if k.startswith("label_") + } + input_0["coord"] = input_0["coord"].requires_grad_(True) + cur_lr = trainer.scheduler.get_last_lr()[0] + _, loss_0, _ = trainer.wrapper(**input_0, cur_lr=cur_lr, label=label_0) + loss_0.backward() + grad_0 = { + name: p.grad.detach().clone() + for name, p in trainer._unwrapped.named_parameters() + if p.grad is not None + } + + # Forward+backward with rank 1's batch + trainer.optimizer.zero_grad(set_to_none=True) + input_1 = { + k: v.clone() + for k, v in r1["batch"].items() + if not k.startswith("label_") + } + label_1 = { + k[len("label_") :]: v.clone() + for k, v in r1["batch"].items() + if k.startswith("label_") + } + input_1["coord"] = input_1["coord"].requires_grad_(True) + _, loss_1, _ = trainer.wrapper(**input_1, cur_lr=cur_lr, label=label_1) + loss_1.backward() + grad_1 = { + name: p.grad.detach().clone() + for name, p in trainer._unwrapped.named_parameters() + if p.grad is not None + } + + # Expected = average of the two + for name in r0["grads"]: + if name in grad_0 and name in grad_1: + expected = (grad_0[name] + grad_1[name]) / 2.0 + torch.testing.assert_close( + r0["grads"][name], + expected, + atol=1e-10, + rtol=1e-10, + msg=f"DDP grad != avg(rank0, rank1) for {name}", + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestDDPMultiTaskGradient(unittest.TestCase): + """Verify DDP gradient averaging with multi-task training.""" + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_multitask_gradient(self) -> None: + """Both ranks pick same task; gradients should be identical after all-reduce.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_gradient_test, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Gradients should be identical across ranks + for name in r0["grads"]: + self.assertIn(name, r1["grads"], f"Grad key missing on rank 1: {name}") + torch.testing.assert_close( + r0["grads"][name], + r1["grads"][name], + atol=0, + rtol=0, + msg=f"Multi-task DDP gradients differ across ranks: {name}", + ) + + +class _DDPResumeBase(unittest.TestCase): + """Shared setup: train 2 steps in single process, save checkpoint + weights.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + cls._tmpdir = tempfile.mkdtemp(prefix="ddp_resume_setup_") + old_cwd = os.getcwd() + os.chdir(cls._tmpdir) + try: + config = _make_config(cls.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), "Checkpoint not created" + + # Record phase-1 final weights for comparison + cls.phase1_weights = { + name: p.detach().cpu().clone() + for name, p in trainer.wrapper.named_parameters() + } + cls.lr_config = config["learning_rate"].copy() + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + +class TestDDPInitModel(_DDPResumeBase): + """DDP init_model: inherits weights but resets step to 0.""" + + def test_ddp_init_model(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_check_resume, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + 2, # numb_steps: train 2 fresh steps from step 0 + False, # is_restart=False → init_model + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, + ) + + # init_model resets step to 0 + self.assertEqual(r0["start_step"], 0) + self.assertEqual(r1["start_step"], 0) + + # LR should be lr_schedule(0), i.e. start_lr + lr_params = self.lr_config.copy() + lr_params["num_steps"] = 2 # init_model config uses numb_steps=2 + expected_lr = LearningRateExp(**lr_params).value(0) + self.assertAlmostEqual(r0["init_lr"], expected_lr, places=10) + + # Only rank 0 produces lcurve + self.assertTrue(r0["lcurve_exists"]) + self.assertFalse(r1["lcurve_exists"]) + + # Initial weights (after checkpoint load) must match phase-1 final weights + for name in self.phase1_weights: + self.assertIn(name, r0["init_weights"], f"Missing param: {name}") + torch.testing.assert_close( + r0["init_weights"][name], + self.phase1_weights[name], + msg=f"init_model did not inherit weights correctly: {name}", + ) + + # Initial weights identical across ranks + for name in r0["init_weights"]: + torch.testing.assert_close( + r0["init_weights"][name], + r1["init_weights"][name], + msg=f"init_model weights differ across ranks: {name}", + ) + + +class TestDDPRestart(_DDPResumeBase): + """DDP restart: continues training from saved step with restored optimizer.""" + + def test_ddp_restart(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_check_resume, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + 4, # numb_steps: continue to step 4 + True, # is_restart=True → restart_model + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + from deepmd.dpmodel.utils.learning_rate import ( + LearningRateExp, + ) + + # restart restores the step counter + self.assertEqual(r0["start_step"], 2) + self.assertEqual(r1["start_step"], 2) + + # LR should be lr_schedule(2) with num_steps=4 (the restart config) + lr_params = self.lr_config.copy() + lr_params["num_steps"] = 4 # restart config uses numb_steps=4 + lr_sched = LearningRateExp(**lr_params) + expected_lr = lr_sched.value(2) + start_lr = lr_sched.value(0) + self.assertAlmostEqual(r0["init_lr"], expected_lr, places=10) + # Verify it is NOT equal to start_lr (i.e. the LR actually decayed) + self.assertNotAlmostEqual( + r0["init_lr"], + start_lr, + places=10, + msg="restart LR should differ from start_lr", + ) + + # Only rank 0 produces lcurve + self.assertTrue(r0["lcurve_exists"]) + self.assertFalse(r1["lcurve_exists"]) + + # Initial weights (after checkpoint load) must match phase-1 final weights + for name in self.phase1_weights: + self.assertIn(name, r0["init_weights"], f"Missing param: {name}") + torch.testing.assert_close( + r0["init_weights"][name], + self.phase1_weights[name], + msg=f"restart did not load weights correctly: {name}", + ) + + # Initial weights identical across ranks + for name in r0["init_weights"]: + torch.testing.assert_close( + r0["init_weights"][name], + r1["init_weights"][name], + msg=f"restart weights differ across ranks: {name}", + ) + + +# --------------------------------------------------------------------------- +# Finetune + DDP tests +# --------------------------------------------------------------------------- + + +class _DDPFinetuneBase(unittest.TestCase): + """Shared setup: train pretrained se_e2_a model, save checkpoint + weights.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + cls._tmpdir = tempfile.mkdtemp(prefix="ddp_ft_setup_") + old_cwd = os.getcwd() + os.chdir(cls._tmpdir) + try: + config = _make_config(cls.data_dir, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), "Pretrained checkpoint not created" + + # Save pretrained state for comparison (excluding _extra_state) + state = torch.load(cls.ckpt_path, map_location="cpu", weights_only=True) + model_state = state["model"] if "model" in state else state + cls.pretrained_state = { + k: v.clone() for k, v in model_state.items() if k != "_extra_state" + } + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + +class TestDDPFinetune(_DDPFinetuneBase): + """DDP finetune: same type_map, descriptor + fitting from pretrained.""" + + def test_ddp_finetune(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + + config = _make_config(self.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, None, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Only rank 0 writes output + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + self.assertGreater(len(r0["ckpt_files"]), 0, "rank 0 should save checkpoints") + self.assertEqual(len(r1["ckpt_files"]), 0, "rank 1 should NOT save checkpoints") + + # Descriptor + fitting weights must match pretrained + for key in self.pretrained_state: + if key in r0["init_state"] and (".descriptor." in key or ".fitting" in key): + torch.testing.assert_close( + r0["init_state"][key], + self.pretrained_state[key], + msg=f"Weight should match pretrained: {key}", + ) + + # Init state identical across ranks + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune init state differs across ranks: {name}", + ) + + +class TestDDPFinetuneRandomFitting(_DDPFinetuneBase): + """DDP finetune with random fitting: descriptor from pretrained, fitting random.""" + + def test_ddp_finetune_random_fitting(self) -> None: + port = _find_free_port() + result_dict = mp.Manager().dict() + + config = _make_config(self.data_dir, numb_steps=2) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, "RANDOM", result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Descriptor weights must match pretrained + for key in self.pretrained_state: + if key in r0["init_state"] and ".descriptor." in key: + torch.testing.assert_close( + r0["init_state"][key], + self.pretrained_state[key], + msg=f"Descriptor weight should match pretrained: {key}", + ) + + # Fitting weights should NOT match pretrained (random init) + # bias_atom_e is set by bias adjustment, not random init — skip it + has_fitting_diff = False + for key in self.pretrained_state: + if ( + key in r0["init_state"] + and ".fitting" in key + and "bias_atom_e" not in key + and r0["init_state"][key].is_floating_point() + ): + if not torch.equal(r0["init_state"][key], self.pretrained_state[key]): + has_fitting_diff = True + self.assertTrue( + has_fitting_diff, "Random fitting should produce different weights" + ) + + # Init state identical across ranks + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune random fitting state differs across ranks: {name}", + ) + + +class TestDDPFinetuneNewType(unittest.TestCase): + """DDP finetune with type_map change (new type). + + Exercises the ``_unwrapped.model["Default"]`` path (line 712) when + ``finetune_rule.get_has_new_type()`` is True, plus stat recomputation + and broadcast for the new type. Uses DPA1 (mixed_types) which supports + ``change_type_map``. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + raw_data = os.path.join(data_dir, "data_0") + + # Subsample data for faster DPA1 test + cls._data_tmpdir = tempfile.mkdtemp(prefix="ddp_ft_nt_data_") + _subsample_data(raw_data, os.path.join(cls._data_tmpdir, "data_0")) + cls.data_dir = os.path.join(cls._data_tmpdir, "data_0") + + # Train pretrained DPA1 with type_map=["O", "H"] + cls._train_tmpdir = tempfile.mkdtemp(prefix="ddp_ft_nt_train_") + old_cwd = os.getcwd() + os.chdir(cls._train_tmpdir) + try: + config = _make_dpa1_config(cls.data_dir, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + cls.ckpt_path = os.path.join(cls._train_tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), ( + "DPA1 pretrained checkpoint not created" + ) + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._data_tmpdir, ignore_errors=True) + shutil.rmtree(cls._train_tmpdir, ignore_errors=True) + + def test_ddp_finetune_new_type(self) -> None: + """Finetune DPA1 from ["O","H"] to ["O","H","B"] under DDP.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + + # Finetune config with new type "B" added + config = _make_dpa1_config(self.data_dir, numb_steps=2) + config["model"]["type_map"] = ["O", "H", "B"] + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + mp.spawn( + _worker_finetune, + args=(2, port, self.ckpt_path, config, None, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Training completes without error + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + + # Init state identical across ranks (stat broadcast worked) + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Finetune new_type init state differs across ranks: {name}", + ) + + +def _make_dpa1_multitask_config( + data_dir: str, numb_steps: int = 2, type_map: list | None = None +) -> dict: + """Build a minimal multi-task DPA1 config (mixed_types) for finetune tests.""" + if type_map is None: + type_map = ["O", "H"] + descriptor = { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [4, 8], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 1, + "attn_dotr": True, + "seed": 1, + } + fitting = { + "neuron": [8, 8], + "resnet_dt": True, + "seed": 1, + } + return { + "model": { + "shared_dict": { + "my_type_map": list(type_map), + "my_descriptor": deepcopy(descriptor), + }, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": deepcopy(fitting), + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "model_prob": { + "model_1": 0.5, + "model_2": 0.5, + }, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [data_dir], + "batch_size": 1, + }, + "validation_data": { + "systems": [data_dir], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + + +def _worker_multitask_finetune( + rank, world_size, port, data_dir, ckpt_path, finetune_config, result_dict +): + """Worker: DDP multi-task finetune from checkpoint.""" + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_ft_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = deepcopy(finetune_config) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + config["model"], finetune_links = get_finetune_rules( + ckpt_path, config["model"] + ) + trainer = get_trainer( + config, + finetune_model=ckpt_path, + finetune_links=finetune_links, + shared_links=shared_links, + ) + # Capture init state before training + init_state = { + k: v.detach().cpu().clone() + for k, v in trainer._unwrapped.state_dict().items() + if k != "_extra_state" + } + trainer.run() + result_dict[rank] = { + "init_state": init_state, + "lcurve_exists": os.path.exists("lcurve.out"), + "ckpt_files": [f for f in os.listdir(".") if f.endswith(".pt")], + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +class TestDDPMultiTaskFinetune(unittest.TestCase): + """DDP multi-task finetune with type_map change (new type). + + Trains a 2-branch multi-task DPA1 model with type_map ["O","H"], then + finetunes 4 branches with extended type_map ["O","H","B"] under DDP. + Builds a reference state_dict by manually replicating the trainer's + finetune operations (load pretrained, change_type_map with computed + model_with_new_type_stat, weight copy) to verify correctness. + """ + + @classmethod + def setUpClass(cls) -> None: + from deepmd.pt_expt.model import ( + get_model, + ) + from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, + ) + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + from deepmd.pt_expt.utils.stat import ( + make_stat_input, + ) + from deepmd.utils.data import ( + DataRequirementItem, + ) + from deepmd.utils.data_system import ( + DeepmdDataSystem, + process_systems, + ) + + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + raw_data = os.path.join(data_dir, "data_0") + + # Subsample data for faster test + cls._data_tmpdir = tempfile.mkdtemp(prefix="ddp_mt_ft_data_") + _subsample_data(raw_data, os.path.join(cls._data_tmpdir, "data_0")) + cls.data_dir = os.path.join(cls._data_tmpdir, "data_0") + + ft_type_map = ["O", "H", "B"] + + # Train pretrained 2-branch multi-task DPA1 model + cls._train_tmpdir = tempfile.mkdtemp(prefix="ddp_mt_ft_train_") + old_cwd = os.getcwd() + os.chdir(cls._train_tmpdir) + try: + config = _make_dpa1_multitask_config(cls.data_dir, numb_steps=2) + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + cls.ckpt_path = os.path.join(cls._train_tmpdir, "model.ckpt.pt") + assert os.path.exists(cls.ckpt_path), ( + "DPA1 multi-task pretrained checkpoint not created" + ) + + # Build reference state_dict with extended type_map + state_dict_full = torch.load( + cls.ckpt_path, map_location=DEVICE, weights_only=True + ) + state_dict_ckpt = ( + state_dict_full["model"] + if "model" in state_dict_full + else state_dict_full + ) + pretrained_model_params = state_dict_ckpt["_extra_state"]["model_params"] + + pretrained_models = {} + for pk in pretrained_model_params["model_dict"]: + pretrained_models[pk] = get_model( + deepcopy(pretrained_model_params["model_dict"][pk]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_models) + pretrained_wrapper.load_state_dict(state_dict_ckpt) + + # Record pretrained state BEFORE change_type_map for O/H + # inheritance verification + cls.pretrained_oh_state = { + k: v.cpu().clone() + for k, v in pretrained_wrapper.model.state_dict().items() + } + + # Build model_with_new_type_stat with computed stats + ref_model_params = deepcopy( + pretrained_model_params["model_dict"]["model_1"] + ) + ref_model_params["type_map"] = ft_type_map + ref_model = get_model(ref_model_params).to(DEVICE) + + energy_data_req = [ + DataRequirementItem( + "energy", ndof=1, atomic=False, must=False, high_prec=True + ), + DataRequirementItem( + "force", ndof=3, atomic=True, must=False, high_prec=False + ), + DataRequirementItem( + "virial", ndof=9, atomic=False, must=False, high_prec=False + ), + ] + data_systems = process_systems([cls.data_dir]) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=ft_type_map, + trn_all_set=True, + ) + data.add_data_requirements(energy_data_req) + ref_model.compute_or_load_stat( + sampled_func=lambda: make_stat_input(data, 1), + stat_file_path=None, + ) + + for pk in pretrained_model_params["model_dict"]: + pretrained_wrapper.model[pk].change_type_map( + ft_type_map, + model_with_new_type_stat=ref_model, + ) + + cls.ref_state_dict = pretrained_wrapper.model.state_dict() + finally: + os.chdir(old_cwd) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._data_tmpdir, ignore_errors=True) + shutil.rmtree(cls._train_tmpdir, ignore_errors=True) + + def test_ddp_multitask_finetune(self) -> None: + """Finetune 4-branch DPA1 from 2-branch with extended type_map under DDP.""" + ft_type_map = ["O", "H", "B"] + ft_config = _make_dpa1_multitask_config( + self.data_dir, numb_steps=1, type_map=ft_type_map + ) + + # Add model_3 and model_4 + ft_config["model"]["model_dict"]["model_3"] = deepcopy( + ft_config["model"]["model_dict"]["model_2"] + ) + ft_config["model"]["model_dict"]["model_4"] = deepcopy( + ft_config["model"]["model_dict"]["model_2"] + ) + ft_config["loss_dict"]["model_3"] = deepcopy(ft_config["loss_dict"]["model_2"]) + ft_config["loss_dict"]["model_4"] = deepcopy(ft_config["loss_dict"]["model_2"]) + ft_config["training"]["model_prob"]["model_3"] = 0.25 + ft_config["training"]["model_prob"]["model_4"] = 0.25 + ft_config["training"]["model_prob"]["model_1"] = 0.25 + ft_config["training"]["model_prob"]["model_2"] = 0.25 + ft_config["training"]["data_dict"]["model_3"] = deepcopy( + ft_config["training"]["data_dict"]["model_2"] + ) + ft_config["training"]["data_dict"]["model_3"]["stat_file"] = ( + "./stat_files/model_3" + ) + ft_config["training"]["data_dict"]["model_4"] = deepcopy( + ft_config["training"]["data_dict"]["model_2"] + ) + ft_config["training"]["data_dict"]["model_4"]["stat_file"] = ( + "./stat_files/model_4" + ) + + # Finetune rules: + # model_1: no finetune_head → resume + # model_2: finetune from model_2 + ft_config["model"]["model_dict"]["model_2"]["finetune_head"] = "model_2" + # model_3: finetune from model_2 + ft_config["model"]["model_dict"]["model_3"]["finetune_head"] = "model_2" + # model_4: no finetune_head, new key → random fitting + + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_finetune, + args=( + 2, + port, + self.data_dir, + self.ckpt_path, + ft_config, + result_dict, + ), + nprocs=2, + join=True, + ) + results = dict(result_dict) + r0, r1 = results[0], results[1] + + # Only rank 0 writes output + self.assertTrue(r0["lcurve_exists"], "rank 0 should produce lcurve.out") + self.assertFalse(r1["lcurve_exists"], "rank 1 should NOT produce lcurve.out") + + # Init state identical across ranks (DDP sync for finetune) + for name in r0["init_state"]: + torch.testing.assert_close( + r0["init_state"][name], + r1["init_state"][name], + msg=f"Multi-task finetune init state differs across ranks: {name}", + ) + + # Verify weight inheritance against reference (with extended type_map) + # Keys in init_state have "model." prefix from wrapper; ref_state_dict + # is from wrapper.model.state_dict() so keys don't have "model." prefix + ref = self.ref_state_dict + init = r0["init_state"] + for key in init: + # Skip type_embedding (random init for new type B differs) + if "type_embedding" in key: + continue + # Strip "model." prefix + model_key = key.split("model.", 1)[-1] if key.startswith("model.") else key + if "model_1" in key: + # model_1: resume — ALL weights match reference + if model_key in ref: + torch.testing.assert_close( + ref[model_key], + init[key], + msg=f"model_1 (resume) DDP mismatch: {key}", + ) + elif "model_2" in key and "out_bias" not in key and "out_std" not in key: + if model_key in ref: + torch.testing.assert_close( + ref[model_key], + init[key], + msg=f"model_2 (finetune) DDP mismatch: {key}", + ) + elif "model_3" in key and "out_bias" not in key and "out_std" not in key: + ref_key = model_key.replace("model_3", "model_2") + if ref_key in ref: + torch.testing.assert_close( + ref[ref_key], + init[key], + msg=f"model_3 (from model_2) DDP mismatch: {key}", + ) + elif ( + "model_4" in key + and "fitting_net" not in key + and "out_bias" not in key + and "out_std" not in key + ): + ref_key = model_key.replace("model_4", "model_2") + if ref_key in ref: + torch.testing.assert_close( + ref[ref_key], + init[key], + msg=f"model_4 (random) descriptor DDP mismatch: {key}", + ) + + # Verify O/H descriptor stats are inherited from pretrained (not + # recomputed). pretrained_oh_state has shape [2,...] for O,H; + # finetuned init has shape [3,...] for O,H,B. + _STAT_SUFFIXES = ("mean", "stddev", "davg", "dstd") + n_old = 2 # ["O", "H"] + n_new = 3 # ["O", "H", "B"] + checked_count = 0 + pretrained_oh = self.pretrained_oh_state + for key in init: + if "type_embedding" in key: + continue + if not any(key.endswith(s) for s in _STAT_SUFFIXES): + continue + # Use model_1 (all branches share descriptor after share_params) + if "model_1" not in key: + continue + # init_state has "model." prefix; pretrained_oh_state doesn't + pre_key = key.split("model.", 1)[-1] if key.startswith("model.") else key + if pre_key not in pretrained_oh: + continue + pre_val = pretrained_oh[pre_key] + ft_val = init[key] + # Find the type axis (size grew from n_old to n_new) + for ax in range(pre_val.ndim): + if pre_val.shape[ax] == n_old and ft_val.shape[ax] == n_new: + for ti, tname in enumerate(["O", "H"]): + torch.testing.assert_close( + ft_val.select(ax, ti), + pre_val.select(ax, ti), + msg=(f"{tname} stat not inherited from pretrained: {key}"), + ) + checked_count += 1 + break + self.assertGreater( + checked_count, + 0, + "No descriptor stat keys found for O/H inheritance check", + ) + + +if __name__ == "__main__": + unittest.main() From 9f4d232dccb9f0c2709b37094b362e7c8be3facc Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 08:07:24 +0800 Subject: [PATCH 02/56] fix(dpmodel): wrap fparam/aparam reshape with descriptive ValueError The reshape-first approach changed the error from a custom ValueError to a generic reshape error, breaking test_self_exception assertions. --- deepmd/dpmodel/fitting/general_fitting.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 810c62f927..a669cdcfc8 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -673,7 +673,13 @@ def _call_common( # check fparam dim, concate to input descriptor if self.numb_fparam > 0: assert fparam is not None, "fparam should not be None" - fparam = xp.reshape(fparam, (nf, self.numb_fparam)) + try: + fparam = xp.reshape(fparam, (nf, self.numb_fparam)) + except (ValueError, RuntimeError) as e: + raise ValueError( + f"input fparam: cannot reshape {fparam.shape} " + f"into ({nf}, {self.numb_fparam})." + ) from e fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] fparam = xp.tile( xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1) @@ -690,7 +696,13 @@ def _call_common( # check aparam dim, concate to input descriptor if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" - aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + try: + aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + except (ValueError, RuntimeError) as e: + raise ValueError( + f"input aparam: cannot reshape {aparam.shape} " + f"into ({nf}, {nloc}, {self.numb_aparam})." + ) from e aparam = (aparam - self.aparam_avg[...]) * self.aparam_inv_std[...] xx = xp.concat( [xx, aparam], From 9f1f1d8254a3280d24ead91186e6e4805ab19b43 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 08:20:38 +0800 Subject: [PATCH 03/56] fix: address CodeQL findings in PR #5397 - Remove unused _min_nbor_dist variables in main.py - Remove unused _MULTITASK_JSON/_MULTITASK_SHAREFIT_JSON globals - Fix duplicate unittest import (use unittest.mock.patch directly) - Bind test socket to 127.0.0.1 instead of all interfaces - Remove redundant nframes assignment in _compile_model --- deepmd/pt_expt/entrypoints/main.py | 12 +++++------- deepmd/pt_expt/train/training.py | 1 - source/tests/pt_expt/test_multitask.py | 12 +++--------- source/tests/pt_expt/test_training_ddp.py | 2 +- 4 files changed, 9 insertions(+), 18 deletions(-) diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 08fc6682a7..52c5e26016 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -264,7 +264,7 @@ def train( train_data = get_data( config["training"]["training_data"], 0, type_map, None ) - config["model"], _min_nbor_dist = BaseModel.update_sel( + config["model"], _ = BaseModel.update_sel( train_data, type_map, config["model"] ) else: @@ -276,12 +276,10 @@ def train( type_map, None, ) - config["model"]["model_dict"][model_key], _min_nbor_dist = ( - BaseModel.update_sel( - train_data, - type_map, - config["model"]["model_dict"][model_key], - ) + config["model"]["model_dict"][model_key], _ = BaseModel.update_sel( + train_data, + type_map, + config["model"]["model_dict"][model_key], ) with open(output, "w") as fp: diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 779d786de9..44f88bb82e 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -969,7 +969,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( best_sample ) - nframes = ext_coord_np.shape[0] actual_nall = ext_coord_np.shape[1] pad_n = max_nall - actual_nall diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py index c21e1070ae..4576fb0d36 100644 --- a/source/tests/pt_expt/test_multitask.py +++ b/source/tests/pt_expt/test_multitask.py @@ -23,8 +23,8 @@ from pathlib import ( Path, ) -from unittest import ( - mock, +from unittest.mock import ( + patch, ) import numpy as np @@ -72,12 +72,6 @@ # Paths to the water data used by PT tests _PT_DATA = str(Path(__file__).parent.parent / "pt" / "water" / "data" / "data_0") -# Path to the multitask template JSON -_MULTITASK_JSON = str(Path(__file__).parent.parent / "pt" / "water" / "multitask.json") -_MULTITASK_SHAREFIT_JSON = str( - Path(__file__).parent.parent / "pt" / "model" / "water" / "multitask_sharefit.json" -) - def _skip_if_no_data() -> None: if not os.path.isdir(_PT_DATA): @@ -1962,7 +1956,7 @@ def test_gradient_accumulation(self) -> None: {"desc": mt_desc, "fit_1": mt_fit_1, "fit_2": mt_fit_2}, mt_grads, ) - with mock.patch( + with patch( "deepmd.utils.random.choice", _deterministic_task_choice([0, 1]), ): diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py index 053c9af4d5..3923547f03 100644 --- a/source/tests/pt_expt/test_training_ddp.py +++ b/source/tests/pt_expt/test_training_ddp.py @@ -62,7 +62,7 @@ def _find_free_port(): """Find a free TCP port on localhost.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) + s.bind(("127.0.0.1", 0)) return s.getsockname()[1] From f3f5474788ab17b6f667b741a4bccbba42f06f4e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 08:28:45 +0800 Subject: [PATCH 04/56] fix(pt_expt): access unwrapped module in _compile_model for DDP compat Under DDP self.wrapper is DistributedDataParallel which has no .model attribute. Use .module to access the underlying ModelWrapper. --- deepmd/pt_expt/train/training.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 44f88bb82e..9ada95f44e 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -898,8 +898,16 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: normalize_coord, ) + # Under DDP, self.wrapper is a DistributedDataParallel wrapper; + # access the underlying ModelWrapper via .module. + wrapper_mod = ( + self.wrapper.module + if isinstance(self.wrapper, torch.nn.parallel.DistributedDataParallel) + else self.wrapper + ) + for task_key in self.model_keys: - model = self.wrapper.model[task_key] + model = wrapper_mod.model[task_key] # --- Estimate max_nall by sampling multiple batches --- n_sample = 20 @@ -1000,7 +1008,7 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: task_compile_opts, ) - self.wrapper.model[task_key] = _CompiledModel( + wrapper_mod.model[task_key] = _CompiledModel( model, compiled_lower, max_nall, task_compile_opts ) log.info( From 665b85abd87c378f1077f04aa0b08e0f339f9d99 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 09:33:22 +0800 Subject: [PATCH 05/56] test(pt_expt): add DDP + torch.compile training tests Cover the _compile_model DDP unwrap fix with single-task and multi-task tests that enable_compile=True under 2-rank gloo DDP. --- source/tests/pt_expt/test_training_ddp.py | 206 ++++++++++++++++++++++ 1 file changed, 206 insertions(+) diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py index 3923547f03..38771968ac 100644 --- a/source/tests/pt_expt/test_training_ddp.py +++ b/source/tests/pt_expt/test_training_ddp.py @@ -11,6 +11,7 @@ 5. Finetune + DDP: selective weight copy via _unwrapped 6. Finetune + DDP with random fitting: descriptor from pretrained, fitting random 7. Finetune + DDP with new type: exercises _unwrapped.model["Default"] + stat broadcast +8. DDP + torch.compile: single-task and multi-task compile under DDP """ import os @@ -1369,6 +1370,211 @@ def _worker_multitask_finetune( dist.destroy_process_group() +def _worker_single_task_compile_train(rank, world_size, port, data_dir, result_dict): + """Worker: run single-task DDP training with torch.compile enabled. + + This exercises the ``_compile_model`` code path under DDP, which must + unwrap ``DistributedDataParallel`` to access ``wrapper.module.model``. + Before the fix, ``self.wrapper.model[task_key]`` raised ``AttributeError`` + because ``DistributedDataParallel`` does not expose ``.model`` directly. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_st_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_config(data_dir, numb_steps=2) + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + trainer = get_trainer(config) + trainer.run() + + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Check the compiled model is a _CompiledModel + is_compiled = isinstance( + trainer._unwrapped.model["Default"], _CompiledModel + ) + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + weights = { + name: p.detach().cpu().clone() + for name, p in trainer._unwrapped.named_parameters() + } + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "weights": weights, + "is_compiled": is_compiled, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +def _worker_multitask_compile_train(rank, world_size, port, data_dir, result_dict): + """Worker: run multi-task DDP training with torch.compile enabled. + + Exercises the per-branch compilation loop in ``_compile_model`` under DDP. + """ + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = str(port) + os.environ["DEVICE"] = "cpu" + dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + try: + tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_mt_rank{rank}_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + config = _make_multitask_config(data_dir, numb_steps=2) + config["training"]["enable_compile"] = True + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + trainer = get_trainer(config, shared_links=shared_links) + trainer.run() + + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + # Check both branch models are compiled + compiled_flags = {} + for mk in ("model_1", "model_2"): + compiled_flags[mk] = isinstance( + trainer._unwrapped.model[mk], _CompiledModel + ) + + lcurve_exists = os.path.exists("lcurve.out") + ckpt_files = [f for f in os.listdir(".") if f.endswith(".pt")] + + # Get shared descriptor params from model_1 + desc_params = {} + for name, p in trainer._unwrapped.model[ + "model_1" + ].original_model.atomic_model.descriptor.named_parameters(): + desc_params[name] = p.detach().cpu().clone() + + result_dict[rank] = { + "lcurve_exists": lcurve_exists, + "num_ckpts": len(ckpt_files), + "desc_params": desc_params, + "compiled_flags": compiled_flags, + } + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + finally: + dist.destroy_process_group() + + +class TestDDPCompileSingleTask(unittest.TestCase): + """DDP + torch.compile: single-task training with 2 ranks. + + Exercises ``_compile_model`` under DDP, which requires unwrapping + ``DistributedDataParallel`` to access ``wrapper.module.model``. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = os.path.join(data_dir, "data_0") + + def test_ddp_compile_single_task(self) -> None: + """2 ranks, se_e2_a, enable_compile=True, 2 steps.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_single_task_compile_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Both ranks have compiled models + self.assertTrue(results[0]["is_compiled"], "rank 0 model should be compiled") + self.assertTrue(results[1]["is_compiled"], "rank 1 model should be compiled") + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Final weights identical across ranks + for name in results[0]["weights"]: + torch.testing.assert_close( + results[0]["weights"][name], + results[1]["weights"][name], + msg=f"Compiled DDP weights differ across ranks: {name}", + ) + + +class TestDDPCompileMultiTask(unittest.TestCase): + """DDP + torch.compile: multi-task training with 2 ranks. + + Exercises the per-branch compilation loop in ``_compile_model`` under DDP. + """ + + @classmethod + def setUpClass(cls) -> None: + if not os.path.isdir(_PT_DATA): + raise unittest.SkipTest(f"Test data not found: {_PT_DATA}") + cls.data_dir = _PT_DATA + + def test_ddp_compile_multitask(self) -> None: + """2 ranks, multi-task, enable_compile=True, 2 steps.""" + port = _find_free_port() + result_dict = mp.Manager().dict() + mp.spawn( + _worker_multitask_compile_train, + args=(2, port, self.data_dir, result_dict), + nprocs=2, + join=True, + ) + results = dict(result_dict) + + # Both ranks have compiled models for both branches + for mk in ("model_1", "model_2"): + self.assertTrue( + results[0]["compiled_flags"][mk], + f"rank 0 {mk} should be compiled", + ) + self.assertTrue( + results[1]["compiled_flags"][mk], + f"rank 1 {mk} should be compiled", + ) + + # Only rank 0 produces output files + self.assertTrue(results[0]["lcurve_exists"]) + self.assertFalse(results[1]["lcurve_exists"]) + self.assertGreater(results[0]["num_ckpts"], 0) + self.assertEqual(results[1]["num_ckpts"], 0) + + # Shared descriptor params identical across ranks + for name in results[0]["desc_params"]: + torch.testing.assert_close( + results[0]["desc_params"][name], + results[1]["desc_params"][name], + msg=f"Compiled DDP shared descriptor param differs: {name}", + ) + + class TestDDPMultiTaskFinetune(unittest.TestCase): """DDP multi-task finetune with type_map change (new type). From aabb710db1cdacaf58382ae4e01797c0418fc753 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 17:46:52 +0800 Subject: [PATCH 06/56] feat(pt_expt): use inductor+dynamic for training compile Replace aot_eager+padding+manual recompile with symbolic make_fx + torch.compile(backend="inductor", dynamic=True). The compiled graph natively handles varying nframes/nloc/nall so the per-batch padding and runtime _recompile pass can be removed. Use a trace-time nframes of 7 (prime) and reshape with -1 in dpmodel (general_fitting, env_mat) to prevent PyTorch's symbolic tracer from unifying the batch dim with numb_fparam / numb_aparam / ntypes / dim_case_embd. Add TestCompiledVaryingNframesWithParams covering collisions with fparam/aparam, and TestCompileCaseEmbdVaryingNframes covering dim_case_embd > 0 with runtime nframes matching the embed dim. --- deepmd/dpmodel/fitting/general_fitting.py | 23 +- deepmd/dpmodel/utils/env_mat.py | 14 +- deepmd/pt_expt/train/training.py | 288 ++++++++-------------- source/tests/pt_expt/test_multitask.py | 200 +++++++++++++++ source/tests/pt_expt/test_training.py | 251 ++++++++++++++++--- 5 files changed, 541 insertions(+), 235 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a669cdcfc8..ee270c6c7a 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -674,7 +674,12 @@ def _call_common( if self.numb_fparam > 0: assert fparam is not None, "fparam should not be None" try: - fparam = xp.reshape(fparam, (nf, self.numb_fparam)) + # Use -1 for nframes so the shape is inferred from the total + # size. Passing the concrete symbol `nf` here would let + # torch.fx's symbolic tracer specialise when `nf` happens to + # equal another tensor dim (e.g. numb_fparam), baking the + # batch size into the compiled graph. + fparam = xp.reshape(fparam, (-1, self.numb_fparam)) except (ValueError, RuntimeError) as e: raise ValueError( f"input fparam: cannot reshape {fparam.shape} " @@ -682,7 +687,7 @@ def _call_common( ) from e fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] fparam = xp.tile( - xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1) + xp.reshape(fparam, (-1, 1, self.numb_fparam)), (1, nloc, 1) ) xx = xp.concat( [xx, fparam], @@ -697,7 +702,9 @@ def _call_common( if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" try: - aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) + # Use -1 for nframes so the shape is inferred from the total + # size; see the fparam branch above for rationale. + aparam = xp.reshape(aparam, (-1, nloc, self.numb_aparam)) except (ValueError, RuntimeError) as e: raise ValueError( f"input aparam: cannot reshape {aparam.shape} " @@ -744,8 +751,12 @@ def _call_common( device=array_api_compat.device(descriptor), ) for type_i in range(self.ntypes): + # Use -1 for nframes so the shape is inferred; see the fparam + # branch above for rationale (avoid symbolic-dim collision + # with numb_fparam / other dims during symbolic tracing). mask = xp.tile( - xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) + xp.reshape((atype == type_i), (-1, nloc, 1)), + (1, 1, net_dim_out), ) atom_property = self.nets[(type_i,)](xx) if self.remove_vaccum_contribution is not None and not ( @@ -761,7 +772,7 @@ def _call_common( if self.eval_return_middle_output and len(self.neuron) > 0: mid = self.nets[(type_i,)].call_until_last(xx) mid_mask = xp.tile( - xp.reshape((atype == type_i), (nf, nloc, 1)), + xp.reshape((atype == type_i), (-1, nloc, 1)), (1, 1, self.neuron[-1]), ) mid = xp.where(mid_mask, mid, xp.zeros_like(mid)) @@ -778,7 +789,7 @@ def _call_common( xp.reshape(atype, (-1,)), axis=0, ), - (nf, nloc, net_dim_out), + (-1, nloc, net_dim_out), ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 9856741317..6b798ae498 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -68,16 +68,22 @@ def _make_env_mat( xp = array_api_compat.array_namespace(nlist) nf, nloc, nnei = nlist.shape # nf x nall x 3 - coord = xp.reshape(coord, (nf, -1, 3)) + # Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise + # both to (nf, nall, 3) using -1 for nframes so the shape is inferred + # from the total size. Passing the symbolic nf here can trigger + # torch.fx symbolic-tracer specialisation when nf happens to collide + # with another dim (e.g. numb_fparam) during training compile. + if coord.ndim == 2: + coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3)) mask = nlist >= 0 nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 - index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) + index = xp.tile(xp.reshape(nlist, (-1, nloc * nnei, 1)), (1, 1, 3)) coord_r = xp_take_along_axis(coord, index, 1) # nf x nloc x nnei x 3 - coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) + coord_r = xp.reshape(coord_r, (-1, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3)) + coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (-1, nloc, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 9ada95f44e..8771c52af5 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -174,16 +174,17 @@ def _trace_and_compile( aparam: torch.Tensor | None, compile_opts: dict[str, Any], ) -> torch.nn.Module: - """Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``. + """Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True. Parameters ---------- model : torch.nn.Module The (uncompiled) model. ext_coord, ext_atype, nlist, mapping, fparam, aparam - Sample tensors (already padded to the desired max_nall). + Sample tensors used to seed the symbolic tracer. compile_opts : dict - Options forwarded to ``torch.compile`` (excluding ``dynamic``). + Options forwarded to ``torch.compile`` (the ``dynamic`` and + ``backend`` keys are ignored and replaced). Returns ------- @@ -210,7 +211,7 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - extended_coord = extended_coord.requires_grad_(True) + extended_coord = extended_coord.detach().requires_grad_(True) return model.forward_lower( extended_coord, extended_atype, @@ -220,10 +221,40 @@ def fn( aparam=aparam, ) - # Use default tracing_mode="real" (concrete shapes) for best - # runtime performance. If data-dependent intermediate shapes - # change at runtime, the caller catches the error and retraces. - traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam) + # Pick a trace-time nframes that's unlikely to collide with any other + # tensor dim in the graph. The symbolic tracer merges symbols that + # are numerically equal at trace time, which bakes nframes into the + # compiled graph whenever it matches e.g. numb_fparam, numb_aparam, + # ntypes, axis_neuron, or neuron sizes (8, 16, 32, ...). Using a + # prime value of 7 avoids the common small-dim collisions while still + # being cheap to trace. + _TRACE_NFRAMES = 7 + cur_nframes = ext_coord.shape[0] + if cur_nframes != _TRACE_NFRAMES: + + def _expand(t: torch.Tensor | None) -> torch.Tensor | None: + if t is None: + return None + # Repeat rows so total nframes == _TRACE_NFRAMES. Use index + # gather (mod) so we don't require divisibility. + idx = ( + torch.arange(_TRACE_NFRAMES, dtype=torch.long, device=t.device) + % cur_nframes + ) + return t.index_select(0, idx) + + ext_coord = _expand(ext_coord) + ext_atype = _expand(ext_atype) + nlist = _expand(nlist) + mapping = _expand(mapping) + fparam = _expand(fparam) + aparam = _expand(aparam) + + traced_lower = make_fx( + fn, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + )(ext_coord, ext_atype, nlist, mapping, fparam, aparam) # make_fx inserts aten.detach.default for saved tensors used in the # decomposed autograd.grad backward ops. These detach nodes break @@ -234,72 +265,36 @@ def fn( if not was_training: model.eval() - if "backend" not in compile_opts: - compile_opts["backend"] = "aot_eager" - compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts) - return compiled_lower + # Work on a copy; ignore caller-supplied dynamic/backend. + compile_opts = { + k: v for k, v in compile_opts.items() if k not in ("dynamic", "backend") + } + opts = compile_opts.setdefault("options", {}) + opts.setdefault("max_autotune", False) + opts.setdefault("epilogue_fusion", False) + opts.setdefault("triton.cudagraphs", False) + opts.setdefault("shape_padding", True) + opts.setdefault("max_fusion_size", 8) + + return torch.compile( + traced_lower, + backend="inductor", + dynamic=True, + **compile_opts, + ) class _CompiledModel(torch.nn.Module): - """Coord extension (eager) -> pad nall -> compiled forward_lower. - - If a batch's ``nall`` exceeds the current ``max_nall``, the model is - automatically re-traced and recompiled with a larger pad size. - """ + """Coord extension (eager) -> compiled forward_lower (dynamic shapes).""" def __init__( self, original_model: torch.nn.Module, compiled_forward_lower: torch.nn.Module, - max_nall: int, - compile_opts: dict[str, Any], ) -> None: super().__init__() self.original_model = original_model self.compiled_forward_lower = compiled_forward_lower - self._max_nall = max_nall - self._compile_opts = compile_opts - - def _recompile( - self, - ext_coord: torch.Tensor, - ext_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor, - fparam: torch.Tensor | None, - aparam: torch.Tensor | None, - new_max_nall: int, - ) -> None: - """Re-trace and recompile for the given inputs. - - If *new_max_nall* differs from the current ``_max_nall``, the - inputs are padded (or already padded by the caller). - """ - # Pad if the caller provides unpadded tensors (nall growth case) - actual_nall = ext_coord.shape[1] - pad_n = new_max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) - - ext_coord = ext_coord.detach() - - self.compiled_forward_lower = _trace_and_compile( - self.original_model, - ext_coord, - ext_atype, - nlist, - mapping, - fparam, - aparam, - self._compile_opts, - ) - self._max_nall = new_max_nall - log.info( - "Recompiled model with max_nall=%d.", - new_max_nall, - ) def forward( self, @@ -343,27 +338,6 @@ def forward( distinguish_types=False, ) ext_coord = ext_coord.reshape(nframes, -1, 3) - - # Grow max_nall if needed (retrace + recompile) - actual_nall = ext_coord.shape[1] - if actual_nall > self._max_nall: - new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8 - log.info( - "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.", - actual_nall, - self._max_nall, - new_max_nall, - ) - self._recompile( - ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall - ) - - # Pad to max_nall so compiled graph sees a fixed shape - pad_n = self._max_nall - actual_nall - if pad_n > 0: - ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n)) - ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n)) - mapping = torch.nn.functional.pad(mapping, (0, pad_n)) ext_coord = ext_coord.detach().requires_grad_(True) result = self.compiled_forward_lower( @@ -379,18 +353,12 @@ def forward( out["atom_energy"] = result["atom_energy"] out["energy"] = result["energy"] if "extended_force" in result: - ext_force = result["extended_force"] # (nf, nall_padded, 3) - # mapping may be padded; only use actual_nall entries - map_actual = mapping[:, :actual_nall] # (nf, actual_nall) - ext_force_actual = ext_force[:, :actual_nall, :] # (nf, actual_nall, 3) - # scatter-sum extended forces onto local atoms - idx = map_actual.unsqueeze(-1).expand_as( - ext_force_actual - ) # (nf, actual_nall, 3) + ext_force = result["extended_force"] # (nf, nall, 3) + idx = mapping.unsqueeze(-1).expand_as(ext_force) # (nf, nall, 3) force = torch.zeros( nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device ) - force.scatter_add_(1, idx, ext_force_actual) + force.scatter_add_(1, idx, ext_force) out["force"] = force if "virial" in result: out["virial"] = result["virial"] @@ -878,17 +846,13 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: computation) with ``create_graph=True``, which creates a "double backward" that ``torch.compile`` cannot handle. - Solution: use ``make_fx`` to trace ``forward_lower``, decomposing - ``torch.autograd.grad`` into primitive ops. The coord extension + - nlist build (data-dependent control flow) are kept outside the - compiled region. - - To avoid the overhead of symbolic tracing and dynamic shapes, the - extended-atom dimension (nall) is padded to a fixed maximum - estimated from the training data. This allows concrete-shape - tracing and ``dynamic=False``. If a batch exceeds the current - max_nall at runtime, the model is automatically re-traced and - recompiled with a larger pad size. + Solution: use ``make_fx`` in ``tracing_mode="symbolic"`` to trace + ``forward_lower``, decomposing ``torch.autograd.grad`` into + primitive ops. The symbolic trace keeps the extended-atom + dimension (``nall``) and batch dimension (``nframes``) as + symbolic shapes, so no padding or recompile-on-growth logic is + needed. The coord extension + nlist build (data-dependent + control flow) are kept outside the compiled region. """ from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, @@ -909,111 +873,53 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: for task_key in self.model_keys: model = wrapper_mod.model[task_key] - # --- Estimate max_nall by sampling multiple batches --- - n_sample = 20 - max_nall = 0 - best_sample: ( - tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None - ) = None - - for _ii in range(n_sample): - inp, _ = self.get_data(is_train=True, task_key=task_key) - coord = inp["coord"].detach() - atype = inp["atype"].detach() - box = inp.get("box") - if box is not None: - box = box.detach() - - nframes, nloc = atype.shape[:2] - coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3) - atype_np = atype.cpu().numpy() - box_np = ( - box.cpu().numpy().reshape(nframes, 9) if box is not None else None - ) - - if box_np is not None: - coord_norm = normalize_coord( - coord_np, box_np.reshape(nframes, 3, 3) - ) - else: - coord_norm = coord_np + inp, _ = self.get_data(is_train=True, task_key=task_key) + coord = inp["coord"].detach() + atype = inp["atype"].detach() + box = inp.get("box") + if box is not None: + box = box.detach() - ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts( - coord_norm, atype_np, box_np, model.get_rcut() - ) - nlist_np = build_neighbor_list( - ext_coord_np, - ext_atype_np, - nloc, - model.get_rcut(), - model.get_sel(), - distinguish_types=False, - ) - ext_coord_np = ext_coord_np.reshape(nframes, -1, 3) - nall = ext_coord_np.shape[1] - if nall > max_nall: - max_nall = nall - best_sample = ( - ext_coord_np, - ext_atype_np, - mapping_np, - nlist_np, - nloc, - inp, - ) + nframes, nloc = atype.shape[:2] + coord_3d = coord.reshape(nframes, nloc, 3) + box_flat = box.reshape(nframes, 9) if box is not None else None - # Add 20 % margin and round up to a multiple of 8. - max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8 - log.info( - "Estimated max_nall=%d for compiled model " - "(task=%s, sampled %d batches).", - max_nall, - task_key, - n_sample, - ) + if box_flat is not None: + coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3)) + else: + coord_norm = coord_3d - # --- Pad the largest sample to max_nall and trace --- - assert best_sample is not None - ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = ( - best_sample + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_norm, atype, box_flat, model.get_rcut() ) - actual_nall = ext_coord_np.shape[1] - pad_n = max_nall - actual_nall - - if pad_n > 0: - ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0))) - ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n))) - mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n))) - - ext_coord = torch.tensor( - ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE + nlist_t = build_neighbor_list( + ext_coord, + ext_atype, + nloc, + model.get_rcut(), + model.get_sel(), + distinguish_types=False, ) - ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE) - nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE) - mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE) - fparam = sample_input.get("fparam") - aparam = sample_input.get("aparam") + ext_coord = ext_coord.reshape(nframes, -1, 3) - task_compile_opts = dict(compile_opts) - task_compile_opts.pop("dynamic", None) # always False for padded approach + fparam = inp.get("fparam") + aparam = inp.get("aparam") compiled_lower = _trace_and_compile( model, ext_coord, ext_atype, nlist_t, - mapping_t, + mapping, fparam, aparam, - task_compile_opts, + compile_opts, ) - wrapper_mod.model[task_key] = _CompiledModel( - model, compiled_lower, max_nall, task_compile_opts - ) + wrapper_mod.model[task_key] = _CompiledModel(model, compiled_lower) log.info( - "Model compiled with padded nall=%d (task=%s, dynamic=False).", - max_nall, + "Model compiled (task=%s, tracing_mode=symbolic, " + "dynamic=True, backend=inductor).", task_key, ) diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py index 4576fb0d36..23b799ceb6 100644 --- a/source/tests/pt_expt/test_multitask.py +++ b/source/tests/pt_expt/test_multitask.py @@ -2089,5 +2089,205 @@ def test_gradient_accumulation(self) -> None: ) +class TestCompileCaseEmbdVaryingNframes(unittest.TestCase): + """Compiled multi-task with ``dim_case_embd > 0`` and varying ``nframes``. + + The shared-fitting path in ``GeneralFitting.call`` tiles the per-task + case embedding as ``xp.tile(reshape(case_embd, (1, 1, -1)), (nf, nloc, 1))`` + (see ``deepmd/dpmodel/fitting/general_fitting.py``). Under + ``tracing_mode="symbolic"`` the ``nf`` multiplier must stay symbolic; + otherwise the compiled graph hard-codes a specific batch size and + subsequent calls with a different ``nframes`` error out. + + The test uses two systems with different atom counts and per-system + ``batch_size=[2, 3]`` so every branch's compiled graph sees both + nframes values. ``dim_case_embd=2`` is deliberately chosen to also + collide numerically with the nframes=2 runtime case. ``dp_random.choice`` + is mocked so both tasks and both systems are sampled. + """ + + @classmethod + def setUpClass(cls) -> None: + cls.tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_case_embd_vary_") + cls.sys0_m1 = os.path.join(cls.tmpdir, "sys0_model1_6atoms") + cls.sys1_m1 = os.path.join(cls.tmpdir, "sys1_model1_4atoms") + cls.sys0_m2 = os.path.join(cls.tmpdir, "sys0_model2_6atoms") + cls.sys1_m2 = os.path.join(cls.tmpdir, "sys1_model2_4atoms") + for path, seed in ( + (cls.sys0_m1, 11), + (cls.sys1_m1, 12), + (cls.sys0_m2, 21), + (cls.sys1_m2, 22), + ): + _generate_random_data_dir( + path, + atom_types=[i % 2 for i in range(6 if "6atoms" in path else 4)], + nframes=4, + seed=seed, + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def _make_config(self) -> dict: + type_map = ["O", "H"] + fitting = deepcopy(_fitting_net) + fitting["dim_case_embd"] = 2 + shared_dict: dict = { + "my_type_map": type_map, + "my_descriptor": deepcopy(_descriptor_se_e2_a), + "my_fitting": fitting, + } + config = { + "model": { + "shared_dict": shared_dict, + "model_dict": { + "model_1": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + "model_2": { + "type_map": "my_type_map", + "descriptor": "my_descriptor", + "fitting_net": "my_fitting", + "data_stat_nbatch": 1, + }, + }, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss_dict": { + "model_1": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "model_2": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + }, + "training": { + "enable_compile": True, + "model_prob": {"model_1": 0.5, "model_2": 0.5}, + "data_dict": { + "model_1": { + "stat_file": "./stat_files/model_1", + "training_data": { + "systems": [self.sys0_m1, self.sys1_m1], + "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0_m1], + "batch_size": 1, + "numb_btch": 1, + }, + }, + "model_2": { + "stat_file": "./stat_files/model_2", + "training_data": { + "systems": [self.sys0_m2, self.sys1_m2], + "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0_m2], + "batch_size": 1, + "numb_btch": 1, + }, + }, + }, + "numb_steps": 1, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + config["model"], shared_links = preprocess_shared_params(config["model"]) + config = update_deepmd_input(config, warning=False) + config = normalize(config, multi_task=True) + return config, shared_links + + def test_compiled_varying_nframes_with_case_embd(self) -> None: + """Compiled shared-fitting graph handles nframes in {2, 3} per branch.""" + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + config, shared_links = self._make_config() + tmpdir = tempfile.mkdtemp(prefix="pt_expt_mt_case_embd_run_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(deepcopy(config), shared_links=shared_links) + # Both branches must be compiled. + for mk in ("model_1", "model_2"): + self.assertIsInstance(trainer.wrapper.model[mk], _CompiledModel) + ce = trainer.wrapper.model[ + mk + ].original_model.atomic_model.fitting_net.case_embd + self.assertIsNotNone(ce, f"case_embd not set on {mk}") + self.assertEqual(int(ce.shape[0]), 2) + + # Drive 6 steps alternating (task, system_index) so each branch's + # compiled graph sees both nframes=2 (sys0) and nframes=3 (sys1). + trainer.wrapper.train() + task_sequence = ["model_1", "model_2"] * 3 + sys_sequence = [0, 1, 0, 1, 0, 1] + sys_iter = iter(sys_sequence) + + original_choice = dp_random.choice + + def task_or_system_choice(a, size=None, replace=True, p=None): + # Per-branch system selection: alternate between the two + # systems so every compiled graph sees both nframes values. + if hasattr(a, "__len__") and len(a) == 2 and p is not None: + return next(sys_iter) + return original_choice(a, size=size, replace=replace, p=p) + + seen_nframes: set[int] = set() + with patch.object(dp_random, "choice", side_effect=task_or_system_choice): + for task_key in task_sequence: + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True, task_key=task_key) + seen_nframes.add(int(inp["coord"].shape[0])) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper( + **inp, cur_lr=lr, label=lab, task_key=task_key + ) + loss.backward() + trainer.optimizer.step() + self.assertFalse(torch.isnan(loss), "loss is NaN") + self.assertFalse(torch.isinf(loss), "loss is Inf") + + self.assertEqual( + seen_nframes, + {2, 3}, + msg=( + f"nframes did not vary across steps: {seen_nframes}. " + "Expected both 2 and 3 (matching and not matching dim_case_embd=2)." + ), + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index e0250eb24e..1397a65c7d 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -12,6 +12,9 @@ import shutil import tempfile import unittest +from unittest.mock import ( + patch, +) import torch @@ -164,8 +167,8 @@ def test_training_loop_compiled(self) -> None: self._run_training(config) -class TestCompiledRecompile(unittest.TestCase): - """Test that _CompiledModel recompiles when nall exceeds max_nall.""" +class TestCompiledDynamicShapes(unittest.TestCase): + """Test that _CompiledModel handles varying nall via dynamic shapes.""" @classmethod def setUpClass(cls) -> None: @@ -174,8 +177,13 @@ def setUpClass(cls) -> None: raise unittest.SkipTest(f"Example data not found: {data_dir}") cls.data_dir = data_dir - def test_nall_growth_triggers_recompile(self) -> None: - """Shrink max_nall to force a recompile, then verify training works.""" + def test_compiled_handles_varying_nall(self) -> None: + """Run several training steps, assert finite loss each step. + + With ``tracing_mode="symbolic"`` + ``dynamic=True``, nall is a + symbolic dim so nall growth across batches is handled without + any recompile or padding. + """ from deepmd.pt_expt.train.training import ( _CompiledModel, ) @@ -185,7 +193,7 @@ def test_nall_growth_triggers_recompile(self) -> None: config = update_deepmd_input(config, warning=False) config = normalize(config) - tmpdir = tempfile.mkdtemp(prefix="pt_expt_recompile_") + tmpdir = tempfile.mkdtemp(prefix="pt_expt_dynamic_") try: old_cwd = os.getcwd() os.chdir(tmpdir) @@ -196,36 +204,18 @@ def test_nall_growth_triggers_recompile(self) -> None: compiled_model = trainer.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) - original_max_nall = compiled_model._max_nall - self.assertGreater(original_max_nall, 0) - - # Artificially shrink max_nall to 1 so the next batch - # will certainly exceed it and trigger recompilation. - compiled_model._max_nall = 1 - old_compiled_lower = compiled_model.compiled_forward_lower - - # Run one training step — should trigger recompile trainer.wrapper.train() - trainer.optimizer.zero_grad(set_to_none=True) - inp, lab = trainer.get_data(is_train=True) - lr = trainer.scheduler.get_last_lr()[0] - _, loss, more_loss = trainer.wrapper(**inp, cur_lr=lr, label=lab) - loss.backward() - trainer.optimizer.step() - - # max_nall should have grown beyond 1 - new_max_nall = compiled_model._max_nall - self.assertGreater(new_max_nall, 1) - - # compiled_forward_lower should be a new object - self.assertIsNot( - compiled_model.compiled_forward_lower, - old_compiled_lower, - ) - - # Loss should be a finite scalar - self.assertFalse(torch.isnan(loss)) - self.assertFalse(torch.isinf(loss)) + for _ in range(3): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer.optimizer.step() + + # Loss should be a finite scalar at every step + self.assertFalse(torch.isnan(loss)) + self.assertFalse(torch.isinf(loss)) finally: os.chdir(old_cwd) finally: @@ -704,5 +694,198 @@ def test_training_loop(self) -> None: shutil.rmtree(tmpdir, ignore_errors=True) +class TestCompiledVaryingNframesWithParams(unittest.TestCase): + """Compiled training with varying ``nframes`` + ``nall`` + fparam/aparam. + + Exercises the compiled forward path under all three kinds of shape + variation simultaneously: + + * Different systems have different atom counts -> varying ``nloc`` / ``nall``. + * Per-system ``batch_size: [2, 3]`` -> varying ``nframes`` (2 vs 3). + * Both ``fparam`` (per-frame) and ``aparam`` (per-atom) labels are + provided, covering the ``dim_fparam`` / ``dim_aparam`` > 0 branches + inside ``forward_lower``. + + The chosen values (``nframes`` in {2, 3}, ``numb_fparam=2``, + ``numb_aparam=3``) are deliberately chosen so the runtime ``nframes`` + collides with the per-frame / per-atom feature dims — this is the + exact pattern that previously caused PyTorch's symbolic tracer to + specialise the batch dim (see _trace_and_compile in training.py). + + ``dp_random.choice`` is mocked to alternate between the two systems + so both are guaranteed to be sampled across ``nsteps``. + """ + + NFPARAM = 2 + NAPARAM = 3 + + @classmethod + def setUpClass(cls) -> None: + # Reuse the data-dir helper from the multitask gradient tests so we + # don't duplicate the npy/raw layout boilerplate. + from .test_multitask import ( + _generate_random_data_dir, + ) + + cls.tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_params_data_") + cls.sys0 = os.path.join(cls.tmpdir, "sys0_8atoms") + cls.sys1 = os.path.join(cls.tmpdir, "sys1_4atoms") + # Atom types alternate 0/1 to match the ["O", "H"] type_map below. + _generate_random_data_dir( + cls.sys0, + atom_types=[i % 2 for i in range(8)], + nframes=4, + seed=42, + nfparam=cls.NFPARAM, + naparam=cls.NAPARAM, + ) + _generate_random_data_dir( + cls.sys1, + atom_types=[i % 2 for i in range(4)], + nframes=4, + seed=137, + nfparam=cls.NFPARAM, + naparam=cls.NAPARAM, + ) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def _make_config(self, enable_compile: bool) -> dict: + config = { + "model": { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "numb_fparam": self.NFPARAM, + "numb_aparam": self.NAPARAM, + "seed": 1, + }, + "data_stat_nbatch": 1, + }, + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [self.sys0, self.sys1], + # Per-system batch sizes: sys0 gets nframes=2, sys1 gets nframes=3. + # Combined with sys0=8 atoms / sys1=4 atoms this guarantees + # both `nframes` and `nall` vary across steps. + "batch_size": [2, 3], + }, + "validation_data": { + "systems": [self.sys0], + "batch_size": 1, + "numb_btch": 1, + }, + "numb_steps": 6, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 100, + }, + } + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config + + def _run_steps(self, enable_compile: bool, nsteps: int = 6) -> None: + from deepmd.utils import data_system as _data_system + + config = self._make_config(enable_compile=enable_compile) + sys_sequence = [i % 2 for i in range(nsteps)] + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_params_run_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + if enable_compile: + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + self.assertIsInstance( + trainer.wrapper.model["Default"], _CompiledModel + ) + + trainer.wrapper.train() + seen_nframes = set() + seen_nall = set() + with patch.object( + _data_system.dp_random, + "choice", + side_effect=sys_sequence, + ): + for _ in range(nsteps): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + seen_nframes.add(int(inp["coord"].shape[0])) + seen_nall.add(int(inp["atype"].shape[1])) + # fparam/aparam must be present in every batch + self.assertIn("fparam", inp) + self.assertIn("aparam", inp) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer.optimizer.step() + self.assertFalse(torch.isnan(loss), "loss is NaN") + self.assertFalse(torch.isinf(loss), "loss is Inf") + + # The two systems differ in both batch-size-auto and natoms, + # so both nframes and nloc should have varied across steps. + self.assertGreater( + len(seen_nframes), + 1, + msg=f"nframes did not vary across steps: {seen_nframes}", + ) + self.assertGreater( + len(seen_nall), + 1, + msg=f"nloc did not vary across steps: {seen_nall}", + ) + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compiled(self) -> None: + """Compiled training with varying nframes + fparam/aparam.""" + self._run_steps(enable_compile=True) + + def test_uncompiled(self) -> None: + """Baseline: same config, uncompiled, should also succeed.""" + self._run_steps(enable_compile=False) + + if __name__ == "__main__": unittest.main() From f774cd23e34c72fb191a6135cbdb873da378b7ed Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 17:47:51 +0800 Subject: [PATCH 07/56] test(pt_expt): port silut activation + repformers accessors from #5393 Add silut/custom_silu support to _torch_activation using native torch ops (sigmoid, tanh, where) so the custom silu stays traceable by make_fx / torch.export. Cross-backend consistency tests cover multiple thresholds across the silu/tanh branches, and a pt_expt unit file exercises default/custom threshold, gradient flow, make_fx, and torch.export. Also port DescrptBlockRepformers accessor tests (get_rcut_smth, get_env_protection). The underlying accessor methods already exist on this branch; these tests guard against regressions. --- deepmd/pt_expt/utils/network.py | 9 ++ .../common/dpmodel/test_descriptor_dpa2.py | 36 ++++++ source/tests/consistent/test_activation.py | 59 +++++++++ source/tests/pt_expt/utils/test_activation.py | 120 ++++++++++++++++++ 4 files changed, 224 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_activation.py diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 1629ecb83a..adef443de9 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import math from typing import ( Any, ClassVar, @@ -182,6 +183,14 @@ def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor: return torch.sigmoid(x) elif name == "silu": return torch.nn.functional.silu(x) + elif name.startswith("silut") or name.startswith("custom_silu"): + threshold = float(name.split(":")[-1]) if ":" in name else 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + silu = x * torch.sigmoid(x) + tanh_branch = torch.tanh(slope * (x - threshold)) + const + return torch.where(x < threshold, silu, tanh_branch) elif name in ("none", "linear"): return x else: diff --git a/source/tests/common/dpmodel/test_descriptor_dpa2.py b/source/tests/common/dpmodel/test_descriptor_dpa2.py index 7867fee874..af58d12790 100644 --- a/source/tests/common/dpmodel/test_descriptor_dpa2.py +++ b/source/tests/common/dpmodel/test_descriptor_dpa2.py @@ -10,6 +10,9 @@ RepformerArgs, RepinitArgs, ) +from deepmd.dpmodel.descriptor.repformers import ( + DescrptBlockRepformers, +) from ...seed import ( GLOBAL_SEED, @@ -69,3 +72,36 @@ def test_self_consistency( for ii in [0, 1, 2, 3, 4]: np.testing.assert_equal(mm0[ii].shape, desired_shape[ii]) np.testing.assert_allclose(mm0[ii], mm1[ii]) + + +class TestDescrptBlockRepformersAccessors(unittest.TestCase): + def test_get_rcut_smth(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_rcut_smth(), 5.0) + + def test_get_env_protection(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + env_protection=1.0, + ) + self.assertEqual(block.get_env_protection(), 1.0) + + def test_get_env_protection_default(self) -> None: + block = DescrptBlockRepformers( + rcut=6.0, + rcut_smth=5.0, + sel=40, + ntypes=2, + nlayers=3, + ) + self.assertEqual(block.get_env_protection(), 0.0) diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index 31351d4a9d..b46319e338 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, parameterized, ) @@ -29,6 +30,13 @@ from deepmd.pt.utils.utils import ( to_torch_tensor, ) +if INSTALLED_PT_EXPT: + import torch + + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE + from deepmd.pt_expt.utils.network import ( + _torch_activation, + ) if INSTALLED_TF: from deepmd.tf.common import get_activation_func as get_activation_fn_tf from deepmd.tf.env import ( @@ -98,3 +106,54 @@ def test_pd_consistent_with_ref(self): ActivationFn_pd(self.activation)(to_paddle_tensor(self.random_input)) ) np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + +@parameterized( + ( + "silut", # default threshold 3.0 + "silut:3.0", # explicit threshold 3.0 + "silut:10.0", # large threshold + "custom_silu:5.0", # alias + ), +) +class TestSilutVariantsConsistent(unittest.TestCase): + """Cross-backend consistency for silut with different thresholds.""" + + def setUp(self) -> None: + (self.activation,) = self.param + # Parse threshold to build input that covers both branches + threshold = ( + float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0 + ) + rng = np.random.default_rng(GLOBAL_SEED) + # Values below threshold (silu branch) and above threshold (tanh branch) + below = rng.uniform(-threshold - 5, threshold - 0.1, size=(5, 10)) + above = rng.uniform(threshold + 0.1, threshold + 20, size=(5, 10)) + self.random_input = np.concatenate([below, above], axis=0) + self.ref = get_activation_fn_dp(self.activation)(self.random_input) + + @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") + def test_pt_consistent_with_ref(self) -> None: + if INSTALLED_PT: + test = torch_to_numpy( + ActivationFn_pt(self.activation)(to_torch_tensor(self.random_input)) + ) + np.testing.assert_allclose(self.ref, test, atol=1e-10) + + @unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed") + def test_pt_expt_consistent_with_ref(self) -> None: + if INSTALLED_PT_EXPT: + x = torch.tensor( + self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE + ) + test = _torch_activation(x, self.activation).detach().numpy() + np.testing.assert_allclose(self.ref, test, atol=1e-10) diff --git a/source/tests/pt_expt/utils/test_activation.py b/source/tests/pt_expt/utils/test_activation.py new file mode 100644 index 0000000000..dd8cc5b59f --- /dev/null +++ b/source/tests/pt_expt/utils/test_activation.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import numpy as np +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.utils.network import ( + get_activation_fn, +) +from deepmd.pt_expt.utils.network import ( + _torch_activation, +) + + +class TestSilutActivation: + """Tests for silut activation in _torch_activation.""" + + def setup_method(self) -> None: + # x values spanning both branches: below threshold and above + self.x_np = np.array( + [-5.0, -1.0, 0.0, 1.0, 2.5, 3.0, 5.0, 10.0, 15.0, 20.0], + dtype=np.float64, + ) + self.x_torch = torch.tensor(self.x_np, dtype=torch.float64) + + def test_silut_with_threshold(self) -> None: + """silut:10.0 matches dpmodel numerically.""" + result = _torch_activation(self.x_torch, "silut:10.0") + dp_fn = get_activation_fn("silut:10.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_default_threshold(self) -> None: + """Silut without parameter uses default threshold 3.0.""" + result = _torch_activation(self.x_torch, "silut") + dp_fn = get_activation_fn("silut") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_custom_silu_alias(self) -> None: + """custom_silu:5.0 is an alias for silut:5.0.""" + result = _torch_activation(self.x_torch, "custom_silu:5.0") + dp_fn = get_activation_fn("custom_silu:5.0") + expected = dp_fn(self.x_np) + np.testing.assert_allclose( + result.detach().numpy(), expected, rtol=1e-12, atol=1e-12 + ) + + def test_silut_gradient(self) -> None: + """Gradient flows through both branches of silut.""" + x = self.x_torch.clone().requires_grad_(True) + y = _torch_activation(x, "silut:3.0") + loss = y.sum() + loss.backward() + grad = x.grad + assert grad is not None + # gradient should be finite everywhere + assert torch.all(torch.isfinite(grad)) + # gradient should be non-zero for non-zero inputs + nonzero_mask = self.x_np != 0.0 + assert torch.all(grad[nonzero_mask] != 0.0) + + def test_silut_make_fx(self) -> None: + """make_fx can trace through silut activation.""" + + def fn(x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + traced = make_fx(fn)(self.x_torch) + result = traced(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + ) + + def test_silut_below_threshold_is_silu(self) -> None: + """Below threshold, silut equals silu exactly.""" + threshold = 10.0 + x_below = torch.tensor([-5.0, 0.0, 1.0, 5.0, 9.9], dtype=torch.float64) + result = _torch_activation(x_below, "silut:10.0") + silu = x_below * torch.sigmoid(x_below) + np.testing.assert_allclose( + result.detach().numpy(), silu.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_above_threshold_is_tanh_branch(self) -> None: + """Above threshold, silut equals tanh(slope*(x-T))+const.""" + import math + + threshold = 3.0 + sig_t = 1.0 / (1.0 + math.exp(-threshold)) + slope = sig_t + threshold * sig_t * (1.0 - sig_t) + const = threshold * sig_t + + x_above = torch.tensor([3.5, 5.0, 10.0, 20.0], dtype=torch.float64) + result = _torch_activation(x_above, "silut:3.0") + expected = torch.tanh(slope * (x_above - threshold)) + const + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-14, atol=1e-14 + ) + + def test_silut_export(self) -> None: + """torch.export.export can trace through silut activation.""" + + class SilutModule(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return _torch_activation(x, "silut:10.0") + + mod = SilutModule() + exported = torch.export.export(mod, (self.x_torch,)) + result = exported.module()(self.x_torch) + expected = _torch_activation(self.x_torch, "silut:10.0") + np.testing.assert_allclose( + result.detach().numpy(), expected.detach().numpy(), rtol=1e-12, atol=1e-12 + ) From 0b5468e4d5c4a27926996bfe2b43db5dca637347 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 18:22:14 +0800 Subject: [PATCH 08/56] test(pt_expt): assert virial in compile correctness tests Extend the compiled-vs-uncompiled assertions in TestCompiledConsistency (single-task) and _check_compile_correctness (multi-task) to also cover ``atom_energy`` and the reduced ``virial``. Atomic virial is intentionally not exercised because training never sets ``do_atomic_virial=True``. --- source/tests/pt_expt/test_multitask.py | 34 ++++++++++++++----------- source/tests/pt_expt/test_training.py | 35 ++++++++------------------ 2 files changed, 31 insertions(+), 38 deletions(-) diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py index 23b799ceb6..847ca42501 100644 --- a/source/tests/pt_expt/test_multitask.py +++ b/source/tests/pt_expt/test_multitask.py @@ -1469,7 +1469,7 @@ def _check_compile_correctness(self, share_fitting: bool = False) -> None: For each branch: feed the same batch through wrapper (which computes loss), call loss.backward(), then compare: - 1. model predictions (energy, force) + 1. model predictions (atom_energy, energy, force, virial) 2. loss values 3. parameter gradients (second-order, through force loss) """ @@ -1534,19 +1534,25 @@ def _check_compile_correctness(self, share_fitting: bool = False) -> None: task_key=task_key, ) - # Compare predictions - torch.testing.assert_close( - pred_c["energy"], - pred_uc["energy"], - atol=1e-10, - rtol=1e-10, - ) - torch.testing.assert_close( - pred_c["force"], - pred_uc["force"], - atol=1e-10, - rtol=1e-10, - ) + # Compare predictions: atom_energy, energy, force, virial. + # Atomic virial is not exercised because training does not + # pass ``do_atomic_virial=True``; the compiled graph is + # traced with the default (False) so per-atom virial is not + # computed by the compiled path. + for key in ("atom_energy", "energy", "force", "virial"): + self.assertIn( + key, pred_uc, f"uncompiled missing '{key}' (task={task_key})" + ) + self.assertIn( + key, pred_c, f"compiled missing '{key}' (task={task_key})" + ) + torch.testing.assert_close( + pred_c[key], + pred_uc[key], + atol=1e-10, + rtol=1e-10, + msg=f"{key} mismatch (task={task_key})", + ) torch.testing.assert_close(loss_c, loss_uc, atol=1e-10, rtol=1e-10) # Compare gradients (second-order, through force loss) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 1397a65c7d..e8a18dd1af 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -287,33 +287,20 @@ def test_compiled_matches_uncompiled(self) -> None: pred_c = compiled_model(coord.clone(), atype, box) - # Energy - torch.testing.assert_close( - pred_c["energy"], - pred_uc["energy"], - atol=1e-10, - rtol=1e-10, - msg="energy mismatch between compiled and uncompiled", - ) - # Force - self.assertIn("force", pred_c, "compiled model missing 'force'") - self.assertIn("force", pred_uc, "uncompiled model missing 'force'") - torch.testing.assert_close( - pred_c["force"], - pred_uc["force"], - atol=1e-10, - rtol=1e-10, - msg="force mismatch between compiled and uncompiled", - ) - # Virial - if "virial" in pred_uc: - self.assertIn("virial", pred_c, "compiled model missing 'virial'") + # Compare predictions: atom_energy, energy, force, virial. + # Atomic virial is not exercised here because training does + # not pass ``do_atomic_virial=True``; the compiled graph is + # traced with the default (False) so per-atom virial is not + # computed by the compiled path. + for key in ("atom_energy", "energy", "force", "virial"): + self.assertIn(key, pred_uc, f"uncompiled missing '{key}'") + self.assertIn(key, pred_c, f"compiled missing '{key}'") torch.testing.assert_close( - pred_c["virial"], - pred_uc["virial"], + pred_c[key], + pred_uc[key], atol=1e-10, rtol=1e-10, - msg="virial mismatch between compiled and uncompiled", + msg=f"{key} mismatch between compiled and uncompiled", ) finally: os.chdir(old_cwd) From 9bf006b2163c2d498adaa850cfb299f01a6d6ade Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 18:53:57 +0800 Subject: [PATCH 09/56] test(pt_expt): port silu compile and varying-natoms tests from #5393 Adds the remaining tests from PR #5393 that were not yet on this branch: ``test_training_loop_compiled_silu`` (silu activation under torch.compile) and ``TestCompiledVaryingNatoms`` (compiled training across systems with different atom counts). Also drops a stray unused ``threshold`` variable in ``test_silut_below_threshold_is_silu`` to match the upstream PR. --- source/tests/pt_expt/test_training.py | 122 ++++++++++++++++++ source/tests/pt_expt/utils/test_activation.py | 1 - 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index e8a18dd1af..102962209b 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -166,6 +166,16 @@ def test_training_loop_compiled(self) -> None: config = normalize(config) self._run_training(config) + def test_training_loop_compiled_silu(self) -> None: + """Run compiled training with silu activation.""" + config = _make_config(self.data_dir, numb_steps=5) + config["model"]["descriptor"]["activation_function"] = "silu" + config["model"]["fitting_net"]["activation_function"] = "silu" + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + self._run_training(config) + class TestCompiledDynamicShapes(unittest.TestCase): """Test that _CompiledModel handles varying nall via dynamic shapes.""" @@ -874,5 +884,117 @@ def test_uncompiled(self) -> None: self._run_steps(enable_compile=False) +def _create_small_system( + path: str, natoms_o: int = 2, natoms_h: int = 4, nframes: int = 10 +) -> None: + """Create a minimal deepmd data system with few atoms.""" + import numpy as np + + natoms = natoms_o + natoms_h + set_dir = os.path.join(path, "set.000") + os.makedirs(set_dir, exist_ok=True) + + with open(os.path.join(path, "type.raw"), "w") as f: + for _ in range(natoms_o): + f.write("0\n") + for _ in range(natoms_h): + f.write("1\n") + with open(os.path.join(path, "type_map.raw"), "w") as f: + f.write("O\nH\n") + + rng = np.random.default_rng(42) + box_len = 5.0 + box = np.zeros((nframes, 9), dtype=np.float32) + box[:, 0] = box_len + box[:, 4] = box_len + box[:, 8] = box_len + coord = rng.uniform(0, box_len, size=(nframes, natoms * 3)).astype(np.float32) + energy = rng.normal(-100, 10, size=(nframes,)).astype(np.float32) + force = rng.normal(0, 1, size=(nframes, natoms * 3)).astype(np.float32) + np.save(os.path.join(set_dir, "coord.npy"), coord) + np.save(os.path.join(set_dir, "force.npy"), force) + np.save(os.path.join(set_dir, "energy.npy"), energy) + np.save(os.path.join(set_dir, "box.npy"), box) + + +class TestCompiledVaryingNatoms(unittest.TestCase): + """Test compiled training with systems of different atom counts. + + Uses the 192-atom ``data_0`` alongside a synthetic 6-atom system so that + different ``nloc`` / ``nall`` appear across steps, exercising the + dynamic-shape compile path. + + ``dp_random.choice`` is mocked to alternate [0, 1, 0, 1, ...] so that + both systems are guaranteed to be sampled. + + ``batch_size: "auto"`` assigns different batch sizes per system (based + on atom count), so both ``nframes`` and ``natoms`` vary across steps. + """ + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + cls.small_data_dir = tempfile.mkdtemp(prefix="pt_expt_small_data_") + _create_small_system(cls.small_data_dir) + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls.small_data_dir, ignore_errors=True) + + def _make_varying_config(self, enable_compile: bool, numb_steps: int = 10) -> dict: + """Config with two systems of different natoms and auto batch size.""" + config = _make_config(self.data_dir, numb_steps=numb_steps) + config["training"]["training_data"]["systems"].append(self.small_data_dir) + config["training"]["training_data"]["batch_size"] = "auto" + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + return config + + def _run_steps(self, config: dict, nsteps: int = 6) -> None: + """Run *nsteps* training steps and assert finite loss at each.""" + # Alternate between system 0 (192 atoms) and system 1 (6 atoms) + sys_sequence = [i % 2 for i in range(nsteps)] + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_varying_") + try: + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + trainer = get_trainer(config) + trainer.wrapper.train() + with patch( + "deepmd.utils.data_system.dp_random.choice", + side_effect=sys_sequence, + ): + for _ in range(nsteps): + trainer.optimizer.zero_grad(set_to_none=True) + inp, lab = trainer.get_data(is_train=True) + lr = trainer.scheduler.get_last_lr()[0] + _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) + loss.backward() + trainer._optimizer_step() + self.assertFalse(torch.isnan(loss), "loss is NaN") + self.assertFalse(torch.isinf(loss), "loss is Inf") + finally: + os.chdir(old_cwd) + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_compiled_varying_natoms(self) -> None: + """Compiled training with 192-atom and 6-atom systems.""" + config = self._make_varying_config(enable_compile=True) + self._run_steps(config) + + def test_uncompiled_varying_natoms(self) -> None: + """Uncompiled training with varying natoms as baseline.""" + config = self._make_varying_config(enable_compile=False) + self._run_steps(config) + + if __name__ == "__main__": unittest.main() diff --git a/source/tests/pt_expt/utils/test_activation.py b/source/tests/pt_expt/utils/test_activation.py index dd8cc5b59f..23550d3315 100644 --- a/source/tests/pt_expt/utils/test_activation.py +++ b/source/tests/pt_expt/utils/test_activation.py @@ -80,7 +80,6 @@ def fn(x: torch.Tensor) -> torch.Tensor: def test_silut_below_threshold_is_silu(self) -> None: """Below threshold, silut equals silu exactly.""" - threshold = 10.0 x_below = torch.tensor([-5.0, 0.0, 1.0, 5.0, 9.9], dtype=torch.float64) result = _torch_activation(x_below, "silut:10.0") silu = x_below * torch.sigmoid(x_below) From 7722f5257c114a01a704a3614aeacc6ad1883125 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 19:11:53 +0800 Subject: [PATCH 10/56] test(pt_expt): compare compiled vs uncompiled with varying natoms Replace the two finite-loss smoke tests with a single test that builds both trainers, syncs weights, and per-step asserts identical predictions, loss, and per-parameter gradients (second-order through the force loss). Also add a silu full-model consistency test and write virial.npy in the small synthetic system so the virial passthrough is exercised on every step. Factor the prediction/grad comparison loops into shared helpers. --- source/tests/pt_expt/test_training.py | 218 +++++++++++++++++--------- 1 file changed, 147 insertions(+), 71 deletions(-) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 102962209b..649b1914da 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -40,6 +40,53 @@ "water", ) +# Keys present on the compiled path. ``atom_virial`` is intentionally excluded: +# training never passes ``do_atomic_virial=True``, so the compiled graph is +# traced with the default (False) and per-atom virial is not emitted. +_COMPILE_PRED_KEYS = ("atom_energy", "energy", "force", "virial") +_COMPILE_TOL = {"atol": 1e-10, "rtol": 1e-10} + + +def _assert_compile_predictions_match( + testcase: unittest.TestCase, + out_c: dict, + out_uc: dict, + *, + ctx: str = "", +) -> None: + for key in _COMPILE_PRED_KEYS: + testcase.assertIn(key, out_uc, f"{ctx}uncompiled missing '{key}'") + testcase.assertIn(key, out_c, f"{ctx}compiled missing '{key}'") + torch.testing.assert_close( + out_c[key], + out_uc[key], + **_COMPILE_TOL, + msg=f"{ctx}{key} mismatch between compiled and uncompiled", + ) + + +def _assert_compile_grads_match( + testcase: unittest.TestCase, + model_c: torch.nn.Module, + model_uc: torch.nn.Module, + *, + ctx: str = "", +) -> None: + for (name_uc, p_uc), (_, p_c) in zip( + model_uc.named_parameters(), + model_c.named_parameters(), + strict=True, + ): + if p_uc.grad is None: + continue + testcase.assertIsNotNone(p_c.grad, msg=f"{ctx}grad is None for {name_uc}") + torch.testing.assert_close( + p_c.grad, + p_uc.grad, + **_COMPILE_TOL, + msg=f"{ctx}grad mismatch on {name_uc}", + ) + def _make_config(data_dir: str, numb_steps: int = 5) -> dict: """Build a minimal config dict pointing at *data_dir*.""" @@ -242,37 +289,41 @@ def setUpClass(cls) -> None: raise unittest.SkipTest(f"Example data not found: {data_dir}") cls.data_dir = data_dir - def test_compiled_matches_uncompiled(self) -> None: - """Energy, force, virial from compiled model must match uncompiled.""" + def _check_consistency(self, activation: str | None = None) -> None: + """Compiled model predictions match uncompiled for the given activation. + + ``activation`` overrides both descriptor and fitting-net activation + functions when provided. ``None`` keeps the config default (tanh). + """ from deepmd.pt_expt.train.training import ( _CompiledModel, ) - config = _make_config(self.data_dir, numb_steps=1) - # enable virial in loss so the model returns it - config["loss"]["start_pref_v"] = 1.0 - config["loss"]["limit_pref_v"] = 1.0 - config = update_deepmd_input(config, warning=False) - config = normalize(config) + def _build_config(enable_compile: bool) -> dict: + config = _make_config(self.data_dir, numb_steps=1) + # enable virial in loss so the model returns it + config["loss"]["start_pref_v"] = 1.0 + config["loss"]["limit_pref_v"] = 1.0 + if activation is not None: + config["model"]["descriptor"]["activation_function"] = activation + config["model"]["fitting_net"]["activation_function"] = activation + if enable_compile: + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + return normalize(config) tmpdir = tempfile.mkdtemp(prefix="pt_expt_consistency_") try: old_cwd = os.getcwd() os.chdir(tmpdir) try: - trainer = get_trainer(config) + trainer = get_trainer(_build_config(enable_compile=False)) # Uncompiled model reference uncompiled_model = trainer.model uncompiled_model.eval() # Build compiled model from the same weights - config_compiled = _make_config(self.data_dir, numb_steps=1) - config_compiled["loss"]["start_pref_v"] = 1.0 - config_compiled["loss"]["limit_pref_v"] = 1.0 - config_compiled["training"]["enable_compile"] = True - config_compiled = update_deepmd_input(config_compiled, warning=False) - config_compiled = normalize(config_compiled) - trainer_compiled = get_trainer(config_compiled) + trainer_compiled = get_trainer(_build_config(enable_compile=True)) compiled_model = trainer_compiled.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) @@ -297,26 +348,20 @@ def test_compiled_matches_uncompiled(self) -> None: pred_c = compiled_model(coord.clone(), atype, box) - # Compare predictions: atom_energy, energy, force, virial. - # Atomic virial is not exercised here because training does - # not pass ``do_atomic_virial=True``; the compiled graph is - # traced with the default (False) so per-atom virial is not - # computed by the compiled path. - for key in ("atom_energy", "energy", "force", "virial"): - self.assertIn(key, pred_uc, f"uncompiled missing '{key}'") - self.assertIn(key, pred_c, f"compiled missing '{key}'") - torch.testing.assert_close( - pred_c[key], - pred_uc[key], - atol=1e-10, - rtol=1e-10, - msg=f"{key} mismatch between compiled and uncompiled", - ) + _assert_compile_predictions_match(self, pred_c, pred_uc) finally: os.chdir(old_cwd) finally: shutil.rmtree(tmpdir, ignore_errors=True) + def test_compiled_matches_uncompiled(self) -> None: + """Energy, force, virial from compiled model must match uncompiled.""" + self._check_consistency() + + def test_compiled_matches_uncompiled_silu(self) -> None: + """Same numerical equivalence under silu activation (full model).""" + self._check_consistency(activation="silu") + def test_compiled_gradients_match_uncompiled(self) -> None: """Parameter gradients from compiled model must match uncompiled. @@ -373,23 +418,9 @@ def test_compiled_gradients_match_uncompiled(self) -> None: loss_uc.backward() loss_c.backward() - for (name_uc, p_uc), (name_c, p_c) in zip( - trainer_uc.model.named_parameters(), - compiled_model.original_model.named_parameters(), - strict=True, - ): - if p_uc.grad is not None: - self.assertIsNotNone( - p_c.grad, - msg=f"grad is None for {name_c}", - ) - torch.testing.assert_close( - p_c.grad, - p_uc.grad, - atol=1e-10, - rtol=1e-10, - msg=f"grad mismatch on {name_uc}", - ) + _assert_compile_grads_match( + self, compiled_model.original_model, trainer_uc.model + ) finally: os.chdir(old_cwd) finally: @@ -911,10 +942,12 @@ def _create_small_system( coord = rng.uniform(0, box_len, size=(nframes, natoms * 3)).astype(np.float32) energy = rng.normal(-100, 10, size=(nframes,)).astype(np.float32) force = rng.normal(0, 1, size=(nframes, natoms * 3)).astype(np.float32) + virial = rng.normal(0, 1, size=(nframes, 9)).astype(np.float32) np.save(os.path.join(set_dir, "coord.npy"), coord) np.save(os.path.join(set_dir, "force.npy"), force) np.save(os.path.join(set_dir, "energy.npy"), energy) np.save(os.path.join(set_dir, "box.npy"), box) + np.save(os.path.join(set_dir, "virial.npy"), virial) class TestCompiledVaryingNatoms(unittest.TestCase): @@ -949,14 +982,31 @@ def _make_varying_config(self, enable_compile: bool, numb_steps: int = 10) -> di config = _make_config(self.data_dir, numb_steps=numb_steps) config["training"]["training_data"]["systems"].append(self.small_data_dir) config["training"]["training_data"]["batch_size"] = "auto" + # enable virial in loss so the model returns it (virial.npy exists in + # both systems), exercising the compiled virial passthrough on each step + config["loss"]["start_pref_v"] = 1.0 + config["loss"]["limit_pref_v"] = 1.0 if enable_compile: config["training"]["enable_compile"] = True config = update_deepmd_input(config, warning=False) config = normalize(config) return config - def _run_steps(self, config: dict, nsteps: int = 6) -> None: - """Run *nsteps* training steps and assert finite loss at each.""" + def test_compiled_matches_uncompiled_varying_natoms(self) -> None: + """Compiled and uncompiled produce identical predictions/loss/grads + across batches with varying ``nframes`` and ``natoms``. + + The loss config has ``start_pref_f=1000`` and ``start_pref_v=1.0``, + so ``loss.backward()`` propagates through ``F = -dE/dr`` (computed + via ``autograd.grad(..., create_graph=True)``); the per-parameter + grad comparison therefore exercises the second-order derivative + ``d^2 E / (dr d theta)`` on each step at each system size. + """ + from deepmd.pt_expt.train.training import ( + _CompiledModel, + ) + + nsteps = 4 # Alternate between system 0 (192 atoms) and system 1 (6 atoms) sys_sequence = [i % 2 for i in range(nsteps)] @@ -965,36 +1015,62 @@ def _run_steps(self, config: dict, nsteps: int = 6) -> None: old_cwd = os.getcwd() os.chdir(tmpdir) try: - trainer = get_trainer(config) - trainer.wrapper.train() + trainer_uc = get_trainer(self._make_varying_config(False)) + trainer_c = get_trainer(self._make_varying_config(True)) + compiled_model = trainer_c.wrapper.model["Default"] + self.assertIsInstance(compiled_model, _CompiledModel) + + # Sync weights so predictions can be compared exactly + compiled_model.original_model.load_state_dict( + trainer_uc.model.state_dict() + ) + trainer_uc.wrapper.train() + trainer_c.wrapper.train() + with patch( "deepmd.utils.data_system.dp_random.choice", side_effect=sys_sequence, ): - for _ in range(nsteps): - trainer.optimizer.zero_grad(set_to_none=True) - inp, lab = trainer.get_data(is_train=True) - lr = trainer.scheduler.get_last_lr()[0] - _, loss, _ = trainer.wrapper(**inp, cur_lr=lr, label=lab) - loss.backward() - trainer._optimizer_step() - self.assertFalse(torch.isnan(loss), "loss is NaN") - self.assertFalse(torch.isinf(loss), "loss is Inf") + for step in range(nsteps): + trainer_uc.optimizer.zero_grad(set_to_none=True) + trainer_c.optimizer.zero_grad(set_to_none=True) + + # Single shared batch; mock yields one value per call + inp, lab = trainer_uc.get_data(is_train=True) + lr = trainer_uc.scheduler.get_last_lr()[0] + + out_uc, loss_uc, _ = trainer_uc.wrapper( + **inp, cur_lr=lr, label=lab + ) + out_c, loss_c, _ = trainer_c.wrapper( + **inp, cur_lr=lr, label=lab + ) + + ctx = f"step={step} " + _assert_compile_predictions_match(self, out_c, out_uc, ctx=ctx) + torch.testing.assert_close( + loss_c, + loss_uc, + **_COMPILE_TOL, + msg=f"{ctx}loss mismatch", + ) + + loss_uc.backward() + loss_c.backward() + _assert_compile_grads_match( + self, + compiled_model.original_model, + trainer_uc.model, + ctx=ctx, + ) + + trainer_uc._optimizer_step() + trainer_c._optimizer_step() finally: os.chdir(old_cwd) finally: shutil.rmtree(tmpdir, ignore_errors=True) - def test_compiled_varying_natoms(self) -> None: - """Compiled training with 192-atom and 6-atom systems.""" - config = self._make_varying_config(enable_compile=True) - self._run_steps(config) - - def test_uncompiled_varying_natoms(self) -> None: - """Uncompiled training with varying natoms as baseline.""" - config = self._make_varying_config(enable_compile=False) - self._run_steps(config) - if __name__ == "__main__": unittest.main() From be14ac25d8dc2356ea8aaf71c2adcb38dd65e9d8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 16 Apr 2026 23:35:35 +0800 Subject: [PATCH 11/56] test(pt_expt): cover DPA2/DPA3 in varying-natoms compile correctness Parametrize TestCompiledVaryingNatoms over se_e2_a, DPA2 and DPA3 with strict atol=rtol=1e-10 on float64 (machine epsilon). DPA1 (se_atten) is intentionally omitted: its compiled path is intermittently incorrect (~20% of compiles produce grad diffs up to 0.67 at the first embedding layer), and including it would have required masking the bug with a loose tolerance. --- source/tests/pt_expt/test_training.py | 115 ++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 8 deletions(-) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 649b1914da..272c143270 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -46,6 +46,70 @@ _COMPILE_PRED_KEYS = ("atom_energy", "energy", "force", "virial") _COMPILE_TOL = {"atol": 1e-10, "rtol": 1e-10} +# DPA3 descriptor config used to extend the varying-natoms compile-correctness +# test to a non-trivial architecture (repflow with attention). ``precision: +# float64`` is set explicitly so the strict ``atol=rtol=1e-10`` comparison +# holds at machine epsilon. +# +# DPA1 (se_atten) is intentionally NOT covered here: inductor's compile of the +# se_atten attention path is intermittently incorrect — see the "known +# limitations" section of the multi-task compile memo for details. +_DESCRIPTOR_DPA2 = { + "type": "dpa2", + "repinit": { + "rcut": 4.0, + "rcut_smth": 0.5, + "nsel": 18, + "neuron": [2, 4, 8], + "axis_neuron": 4, + "activation_function": "tanh", + }, + "repformer": { + "rcut": 3.0, + "rcut_smth": 0.5, + "nsel": 12, + "nlayers": 2, + "g1_dim": 8, + "g2_dim": 5, + "attn2_hidden": 3, + "attn2_nhead": 1, + "attn1_hidden": 5, + "attn1_nhead": 1, + "axis_neuron": 4, + "update_h2": False, + "update_g1_has_conv": True, + "update_g1_has_grrg": True, + "update_g1_has_drrd": True, + "update_g1_has_attn": True, + "update_g2_has_g1g1": True, + "update_g2_has_attn": True, + "attn2_has_gate": True, + }, + "precision": "float64", + "seed": 1, + "add_tebd_to_repinit_out": False, +} + +_DESCRIPTOR_DPA3 = { + "type": "dpa3", + "repflow": { + "n_dim": 8, + "e_dim": 5, + "a_dim": 4, + "nlayers": 2, + "e_rcut": 3.0, + "e_rcut_smth": 0.5, + "e_sel": 12, + "a_rcut": 3.0, + "a_rcut_smth": 0.5, + "a_sel": 8, + "axis_neuron": 4, + }, + "precision": "float64", + "concat_output_tebd": False, + "seed": 1, +} + def _assert_compile_predictions_match( testcase: unittest.TestCase, @@ -977,30 +1041,45 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: shutil.rmtree(cls.small_data_dir, ignore_errors=True) - def _make_varying_config(self, enable_compile: bool, numb_steps: int = 10) -> dict: - """Config with two systems of different natoms and auto batch size.""" - config = _make_config(self.data_dir, numb_steps=numb_steps) + def _make_varying_config( + self, + enable_compile: bool, + descriptor: dict | None = None, + ) -> dict: + """Config with two systems of different natoms and auto batch size. + + ``descriptor`` overrides the default se_e2_a descriptor when given. + """ + config = _make_config(self.data_dir) config["training"]["training_data"]["systems"].append(self.small_data_dir) config["training"]["training_data"]["batch_size"] = "auto" # enable virial in loss so the model returns it (virial.npy exists in # both systems), exercising the compiled virial passthrough on each step config["loss"]["start_pref_v"] = 1.0 config["loss"]["limit_pref_v"] = 1.0 + if descriptor is not None: + config["model"]["descriptor"] = descriptor if enable_compile: config["training"]["enable_compile"] = True config = update_deepmd_input(config, warning=False) config = normalize(config) return config - def test_compiled_matches_uncompiled_varying_natoms(self) -> None: - """Compiled and uncompiled produce identical predictions/loss/grads - across batches with varying ``nframes`` and ``natoms``. + def _check_varying_natoms(self, descriptor: dict | None = None) -> None: + """Per-step compiled-vs-uncompiled comparison for the given descriptor. The loss config has ``start_pref_f=1000`` and ``start_pref_v=1.0``, so ``loss.backward()`` propagates through ``F = -dE/dr`` (computed via ``autograd.grad(..., create_graph=True)``); the per-parameter grad comparison therefore exercises the second-order derivative ``d^2 E / (dr d theta)`` on each step at each system size. + + Verifies multi-step training-trajectory equivalence: weights are + synced once at the start, then both trainers step their own Adam + states forward. All assertions use the strict + ``atol=rtol=1e-10`` tolerance; if a descriptor's compiled path + cannot meet that on float64 the descriptor has a real numerical + problem (see the DPA1 limitation note where this happened). """ from deepmd.pt_expt.train.training import ( _CompiledModel, @@ -1015,8 +1094,8 @@ def test_compiled_matches_uncompiled_varying_natoms(self) -> None: old_cwd = os.getcwd() os.chdir(tmpdir) try: - trainer_uc = get_trainer(self._make_varying_config(False)) - trainer_c = get_trainer(self._make_varying_config(True)) + trainer_uc = get_trainer(self._make_varying_config(False, descriptor)) + trainer_c = get_trainer(self._make_varying_config(True, descriptor)) compiled_model = trainer_c.wrapper.model["Default"] self.assertIsInstance(compiled_model, _CompiledModel) @@ -1071,6 +1150,26 @@ def test_compiled_matches_uncompiled_varying_natoms(self) -> None: finally: shutil.rmtree(tmpdir, ignore_errors=True) + def test_compiled_matches_uncompiled_varying_natoms_se_e2_a(self) -> None: + """se_e2_a: compiled vs uncompiled match across varying nframes/natoms.""" + self._check_varying_natoms() # uses default se_e2_a from _make_config + + def test_compiled_matches_uncompiled_varying_natoms_dpa2(self) -> None: + """DPA2: compiled vs uncompiled match across varying nframes/natoms. + + Exercises the DPA2 repinit + repformers stack; matches at machine + epsilon (~1e-12) on float64 just like se_e2_a. + """ + self._check_varying_natoms(_DESCRIPTOR_DPA2) + + def test_compiled_matches_uncompiled_varying_natoms_dpa3(self) -> None: + """DPA3: compiled vs uncompiled match across varying nframes/natoms. + + Exercises a non-trivial multi-layer repflow descriptor; matches at + machine epsilon (~1e-12) on float64 just like se_e2_a. + """ + self._check_varying_natoms(_DESCRIPTOR_DPA3) + if __name__ == "__main__": unittest.main() From 4c0b8ec4d10c9a03a1e12eccfd70faf20e3edf50 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 17 Apr 2026 00:06:05 +0800 Subject: [PATCH 12/56] test(pt_expt): exercise DPA2 three-body branch in compile correctness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable use_three_body=True on the DPA2 varying-natoms test so the compiled three-body neighbor path is also covered. three_body_rcut=3.0 matches the repformer rcut and is large enough to find neighbors in the 6-atom small system (~2.75Å nearest-neighbor distance). --- source/tests/pt_expt/test_training.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 272c143270..8c7f7e2a76 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -63,6 +63,10 @@ "neuron": [2, 4, 8], "axis_neuron": 4, "activation_function": "tanh", + "use_three_body": True, + "three_body_sel": 12, + "three_body_rcut": 3.0, + "three_body_rcut_smth": 0.5, }, "repformer": { "rcut": 3.0, From 80c714c6a2dd74efff77c548c246f6c3fd4270b7 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 17 Apr 2026 19:49:31 +0800 Subject: [PATCH 13/56] fix(dpmodel): restore nf in reshapes to fix zero-atom and add silu_backward decomp Revert (-1, nloc*nnei, ...) reshapes back to (nf, -1, ...) in env_mat.py and general_fitting.py. The -1-for-nf pattern breaks zero-atom systems: numpy cannot infer -1 when other dims multiply to zero (0/0), and torch.export shape assertions hit Mod(0,0). Using nf is safe because _TRACE_NFRAMES=7 already prevents symbolic-tracer specialisation during training compile. Add silu_backward decomposition table to make_fx in training.py so inductor can compile second-order gradients through silu without requiring a fused higher-order derivative kernel. --- deepmd/dpmodel/fitting/general_fitting.py | 22 ++++++---------------- deepmd/dpmodel/utils/env_mat.py | 12 +++++------- deepmd/pt_expt/train/training.py | 11 +++++++++++ 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index ee270c6c7a..b9129a4364 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -674,12 +674,7 @@ def _call_common( if self.numb_fparam > 0: assert fparam is not None, "fparam should not be None" try: - # Use -1 for nframes so the shape is inferred from the total - # size. Passing the concrete symbol `nf` here would let - # torch.fx's symbolic tracer specialise when `nf` happens to - # equal another tensor dim (e.g. numb_fparam), baking the - # batch size into the compiled graph. - fparam = xp.reshape(fparam, (-1, self.numb_fparam)) + fparam = xp.reshape(fparam, (nf, self.numb_fparam)) except (ValueError, RuntimeError) as e: raise ValueError( f"input fparam: cannot reshape {fparam.shape} " @@ -687,7 +682,7 @@ def _call_common( ) from e fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...] fparam = xp.tile( - xp.reshape(fparam, (-1, 1, self.numb_fparam)), (1, nloc, 1) + xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1) ) xx = xp.concat( [xx, fparam], @@ -702,9 +697,7 @@ def _call_common( if self.numb_aparam > 0 and not self.use_aparam_as_mask: assert aparam is not None, "aparam should not be None" try: - # Use -1 for nframes so the shape is inferred from the total - # size; see the fparam branch above for rationale. - aparam = xp.reshape(aparam, (-1, nloc, self.numb_aparam)) + aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam)) except (ValueError, RuntimeError) as e: raise ValueError( f"input aparam: cannot reshape {aparam.shape} " @@ -751,11 +744,8 @@ def _call_common( device=array_api_compat.device(descriptor), ) for type_i in range(self.ntypes): - # Use -1 for nframes so the shape is inferred; see the fparam - # branch above for rationale (avoid symbolic-dim collision - # with numb_fparam / other dims during symbolic tracing). mask = xp.tile( - xp.reshape((atype == type_i), (-1, nloc, 1)), + xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out), ) atom_property = self.nets[(type_i,)](xx) @@ -772,7 +762,7 @@ def _call_common( if self.eval_return_middle_output and len(self.neuron) > 0: mid = self.nets[(type_i,)].call_until_last(xx) mid_mask = xp.tile( - xp.reshape((atype == type_i), (-1, nloc, 1)), + xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, self.neuron[-1]), ) mid = xp.where(mid_mask, mid, xp.zeros_like(mid)) @@ -789,7 +779,7 @@ def _call_common( xp.reshape(atype, (-1,)), axis=0, ), - (-1, nloc, net_dim_out), + (nf, nloc, net_dim_out), ) # nf x nloc exclude_mask = self.emask.build_type_exclude_mask(atype) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 6b798ae498..0b0bd18c35 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -69,21 +69,19 @@ def _make_env_mat( nf, nloc, nnei = nlist.shape # nf x nall x 3 # Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise - # both to (nf, nall, 3) using -1 for nframes so the shape is inferred - # from the total size. Passing the symbolic nf here can trigger - # torch.fx symbolic-tracer specialisation when nf happens to collide - # with another dim (e.g. numb_fparam) during training compile. + # both to (nf, nall, 3) using shape-based inference so the concrete nf + # value is not baked into the reshape. if coord.ndim == 2: coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3)) mask = nlist >= 0 nlist = nlist * xp.astype(mask, nlist.dtype) # nf x (nloc x nnei) x 3 - index = xp.tile(xp.reshape(nlist, (-1, nloc * nnei, 1)), (1, 1, 3)) + index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3)) coord_r = xp_take_along_axis(coord, index, 1) # nf x nloc x nnei x 3 - coord_r = xp.reshape(coord_r, (-1, nloc, nnei, 3)) + coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (-1, nloc, 1, 3)) + coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, nloc, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 8771c52af5..f5ccde3ca0 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -250,10 +250,21 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: fparam = _expand(fparam) aparam = _expand(aparam) + # Decompose silu_backward into primitive ops (sigmoid + mul + ...) + # so that inductor can compile the graph without requiring a + # higher-order derivative that PyTorch does not register for + # the fused silu backward kernel. + from torch._decomp import ( + get_decompositions, + ) + + decomp_table = get_decompositions([torch.ops.aten.silu_backward.default]) + traced_lower = make_fx( fn, tracing_mode="symbolic", _allow_non_fake_inputs=True, + decomposition_table=decomp_table, )(ext_coord, ext_atype, nlist, mapping, fparam, aparam) # make_fx inserts aten.detach.default for saved tensors used in the From 6158d9cea6269b8c36aef345e00d774d58141982 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 17 Apr 2026 20:14:28 +0800 Subject: [PATCH 14/56] fix: address CodeQL findings in PR #5397 Replace assert with if/raise ValueError for user-facing config validation (data_stat_protect, finetune branch/head checks). Wrap train() in try/finally for destroy_process_group cleanup. Add parents=True, exist_ok=True to stat_file mkdir. Add strict=True to zip() calls. Fix minor test issues. --- deepmd/pt_expt/entrypoints/main.py | 29 +++++++++++++------------- deepmd/pt_expt/train/training.py | 15 ++++++++----- deepmd/pt_expt/utils/finetune.py | 18 +++++++++------- source/tests/pt_expt/test_multitask.py | 6 +++--- 4 files changed, 38 insertions(+), 30 deletions(-) diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 52c5e26016..40302ee7b3 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -91,7 +91,7 @@ def get_trainer( with h5py.File(stat_file_path, "w"): pass else: - Path(stat_file_path).mkdir() + Path(stat_file_path).mkdir(parents=True, exist_ok=True) stat_file_path = DPPath(stat_file_path, "a") else: # Multi-task: build per-task data systems @@ -143,7 +143,7 @@ def get_trainer( with h5py.File(_sf, "w"): pass else: - Path(_sf).mkdir(parents=True) + Path(_sf).mkdir(parents=True, exist_ok=True) stat_file_path[model_key] = DPPath(_sf, "a") else: stat_file_path[model_key] = None @@ -290,18 +290,19 @@ def train( if os.environ.get("LOCAL_RANK") is not None: dist.init_process_group(backend="cuda:nccl,cpu:gloo") - trainer = get_trainer( - config, - init_model, - restart, - finetune_model=finetune, - finetune_links=finetune_links, - shared_links=shared_links, - ) - trainer.run() - - if dist.is_available() and dist.is_initialized(): - dist.destroy_process_group() + try: + trainer = get_trainer( + config, + init_model, + restart, + finetune_model=finetune, + finetune_links=finetune_links, + shared_links=shared_links, + ) + trainer.run() + finally: + if dist.is_available() and dist.is_initialized(): + dist.destroy_process_group() def freeze( diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index f5ccde3ca0..1ae16c92b1 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -611,13 +611,16 @@ def _make_sample( for ii in model_params["model_dict"] ] ) - assert np.allclose(_data_stat_protect, _data_stat_protect[0]), ( - "Model key 'data_stat_protect' must be the same in each branch when multitask!" - ) + if not np.allclose(_data_stat_protect, _data_stat_protect[0]): + raise ValueError( + "Model key 'data_stat_protect' must be the same in each branch when multitask!" + ) self.wrapper.share_params( shared_links, resume=(resuming and not self._finetune_update_stat) or self.rank != 0, - model_key_prob_map=dict(zip(self.model_keys, self.model_prob)), + model_key_prob_map=dict( + zip(self.model_keys, self.model_prob, strict=True) + ), data_stat_protect=_data_stat_protect[0], ) @@ -825,7 +828,9 @@ def _make_sample( self._unwrapped.share_params( shared_links, resume=True, - model_key_prob_map=dict(zip(self.model_keys, self.model_prob)), + model_key_prob_map=dict( + zip(self.model_keys, self.model_prob, strict=True) + ), ) if optimizer_state_dict is not None: diff --git a/deepmd/pt_expt/utils/finetune.py b/deepmd/pt_expt/utils/finetune.py index d26db9e847..473bb43710 100644 --- a/deepmd/pt_expt/utils/finetune.py +++ b/deepmd/pt_expt/utils/finetune.py @@ -105,10 +105,11 @@ def get_finetune_rules( finetune_links["Default"] = finetune_rule else: # Multi-task target — mirrors PT's logic - assert model_branch == "", ( - "Multi-task fine-tuning does not support command-line branches chosen!" - "Please define the 'finetune_head' in each model params!" - ) + if model_branch != "": + raise ValueError( + "Multi-task fine-tuning does not support command-line branches chosen! " + "Please define the 'finetune_head' in each model params!" + ) if not finetune_from_multi_task: pretrained_keys = ["Default"] else: @@ -120,10 +121,11 @@ def get_finetune_rules( and model_config["model_dict"][model_key]["finetune_head"] != "RANDOM" ): pretrained_key = model_config["model_dict"][model_key]["finetune_head"] - assert pretrained_key in pretrained_keys, ( - f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model!" - f"Available heads are: {list(pretrained_keys)}" - ) + if pretrained_key not in pretrained_keys: + raise ValueError( + f"'{pretrained_key}' head chosen to finetune not exist in the pretrained model! " + f"Available heads are: {list(pretrained_keys)}" + ) model_branch_from = pretrained_key elif ( "finetune_head" not in model_config["model_dict"][model_key] diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py index 847ca42501..e5c6955ac0 100644 --- a/source/tests/pt_expt/test_multitask.py +++ b/source/tests/pt_expt/test_multitask.py @@ -936,7 +936,7 @@ def test_multitask_finetune_no_change_model_params(self) -> None: # Also set trainable=False to verify it's preserved ft_config_true["model"]["shared_dict"]["my_descriptor"]["trainable"] = False ft_config_true["model"], _ = preprocess_shared_params(ft_config_true["model"]) - model_config_true, finetune_links_true = get_finetune_rules( + model_config_true, _finetune_links_true = get_finetune_rules( ckpt_path, deepcopy(ft_config_true["model"]), change_model_params=True ) @@ -1164,7 +1164,7 @@ def test_multitask_restart(self) -> None: def test_multitask_freeze(self) -> None: """Train, then freeze with --head and verify. - Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. """ if self.descriptor.get("type") != "dpa3": return @@ -1948,7 +1948,7 @@ def test_gradient_accumulation(self) -> None: # Verify descriptor params are aliased (share_params) mt_desc_2 = mt_trainer.wrapper.model["model_2"].atomic_model.descriptor - for (n1, p1), (n2, p2) in zip( + for (n1, p1), (_n2, p2) in zip( mt_desc.named_parameters(), mt_desc_2.named_parameters(), strict=True ): assert p1.data_ptr() == p2.data_ptr(), ( From c2efbf1a125949c88dce5cf752b990de347b1f3b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 17 Apr 2026 20:43:04 +0800 Subject: [PATCH 15/56] fix(pt): wrap fparam/aparam reshape with descriptive ValueError Match the dpmodel try/except pattern so shape mismatches produce a clear error instead of a raw RuntimeError from torch.view. --- deepmd/pt/model/task/fitting.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index a07b408d72..439d3d11d9 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -779,6 +779,11 @@ def _forward_common( assert fparam is not None, "fparam should not be None" assert self.fparam_avg is not None assert self.fparam_inv_std is not None + if fparam.numel() != nf * self.numb_fparam: + raise ValueError( + f"input fparam: cannot reshape {list(fparam.shape)} " + f"into ({nf}, {self.numb_fparam})." + ) fparam = fparam.view([nf, self.numb_fparam]) nb, _ = fparam.shape t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) @@ -799,6 +804,11 @@ def _forward_common( assert aparam is not None, "aparam should not be None" assert self.aparam_avg is not None assert self.aparam_inv_std is not None + if aparam.numel() % (nf * self.numb_aparam) != 0: + raise ValueError( + f"input aparam: cannot reshape {list(aparam.shape)} " + f"into ({nf}, nloc, {self.numb_aparam})." + ) aparam = aparam.view([nf, -1, self.numb_aparam]) nb, nloc, _ = aparam.shape t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) From 1e694a37418a57fddc621d135b18a59626a44e38 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 08:59:44 +0800 Subject: [PATCH 16/56] feat(pt_expt): reject DPA1/se_atten_v2 with attention at compile time DPA1's se_atten attention path produces incorrect force gradients under inductor compile. Add get_numb_attn_layer() API and isinstance guard to reject attn_layer>0 at compile time with a clear error message. DPA1 with attn_layer=0 compiles correctly and is now tested. The guard also covers se_atten_v2 which inherits from DPA1. --- deepmd/dpmodel/descriptor/dpa1.py | 4 ++ deepmd/pt_expt/train/training.py | 17 ++++++++ source/tests/pt_expt/test_training.py | 59 +++++++++++++++++++++++---- 3 files changed, 73 insertions(+), 7 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 17266974c6..b3b00efb79 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -381,6 +381,10 @@ def get_dim_out(self) -> int: def get_dim_emb(self) -> int: return self.se_atten.dim_emb + def get_numb_attn_layer(self) -> int: + """Returns the number of se_atten attention layers.""" + return self.se_atten.attn_layer + def mixed_types(self) -> bool: """If true, the descriptor 1. assumes total number of atoms aligned across frames; diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1ae16c92b1..05acb89826 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -889,6 +889,23 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: for task_key in self.model_keys: model = wrapper_mod.model[task_key] + # Guard: DPA1/se_atten_v2 attention produces incorrect force + # gradients under inductor compile. Other descriptors (DPA2, + # DPA3) use different attention mechanisms that compile correctly. + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + + descriptor = model.get_descriptor() + if isinstance(descriptor, DescrptDPA1DP): + n_attn = descriptor.get_numb_attn_layer() + if n_attn > 0: + raise RuntimeError( + f"Cannot compile model (task={task_key}): DPA1 " + f"descriptor has {n_attn} se_atten attention " + f"layer(s). Compiled se_atten produces incorrect " + f"force gradients. Set 'enable_compile: false' or " + f"use attn_layer=0." + ) + inp, _ = self.get_data(is_train=True, task_key=task_key) coord = inp["coord"].detach() atype = inp["atype"].detach() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 8c7f7e2a76..de5727b5ee 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -8,6 +8,7 @@ 4. Loss decreases over those steps """ +import copy import os import shutil import tempfile @@ -46,14 +47,36 @@ _COMPILE_PRED_KEYS = ("atom_energy", "energy", "force", "virial") _COMPILE_TOL = {"atol": 1e-10, "rtol": 1e-10} -# DPA3 descriptor config used to extend the varying-natoms compile-correctness -# test to a non-trivial architecture (repflow with attention). ``precision: -# float64`` is set explicitly so the strict ``atol=rtol=1e-10`` comparison -# holds at machine epsilon. +# Descriptor configs used to extend compile-correctness tests to non-trivial +# architectures. ``precision: float64`` is set so the strict ``atol=rtol=1e-10`` +# comparison holds at machine epsilon. # -# DPA1 (se_atten) is intentionally NOT covered here: inductor's compile of the -# se_atten attention path is intermittently incorrect — see the "known -# limitations" section of the multi-task compile memo for details. +# DPA1 with attn_layer=0 (no se_atten attention) compiles correctly. +# DPA1 with attn_layer>0 is rejected at compile time because inductor produces +# incorrect force gradients through the attention path. +_DESCRIPTOR_DPA1_NO_ATTN = { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "axis_neuron": 4, + "attn_layer": 0, + "precision": "float64", + "seed": 1, +} + +_DESCRIPTOR_DPA1_WITH_ATTN = { + "type": "dpa1", + "sel": 12, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "axis_neuron": 4, + "attn_layer": 2, + "precision": "float64", + "seed": 1, +} _DESCRIPTOR_DPA2 = { "type": "dpa2", "repinit": { @@ -1174,6 +1197,28 @@ def test_compiled_matches_uncompiled_varying_natoms_dpa3(self) -> None: """ self._check_varying_natoms(_DESCRIPTOR_DPA3) + def test_compiled_matches_uncompiled_varying_natoms_dpa1_no_attn(self) -> None: + """DPA1 (attn_layer=0): compiled vs uncompiled match. + + DPA1 without attention compiles correctly and matches eager mode. + """ + self._check_varying_natoms(_DESCRIPTOR_DPA1_NO_ATTN) + + def test_compile_rejects_dpa1_with_attention(self) -> None: + """DPA1 (attn_layer>0): compile must raise RuntimeError. + + Compiled attention produces incorrect force gradients; the compile + path rejects models with se_atten attention layers. + """ + config = _make_config(self.data_dir, numb_steps=2) + config["model"]["descriptor"] = copy.deepcopy(_DESCRIPTOR_DPA1_WITH_ATTN) + config["training"]["enable_compile"] = True + config = update_deepmd_input(config, warning=False) + config = normalize(config) + with self.assertRaises(RuntimeError, msg="se_atten attention"): + trainer = get_trainer(config) + trainer.run() + if __name__ == "__main__": unittest.main() From 6d39ddf45fb3b296052af4e5b50c7659176ed265 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 11:28:04 +0800 Subject: [PATCH 17/56] fix(pt_expt): remove false DPA1 attention compile guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DPA1 with se_atten attention layers compiles correctly under inductor+dynamic — the _check_varying_natoms test passes at 1e-10 tolerance for attn_layer=2. The guard added in 1e694a37 was based on an inconclusive diagnostic; replace the rejection test with a compile-correctness test for DPA1 with attention. --- deepmd/pt_expt/train/training.py | 17 ----------------- source/tests/pt_expt/test_training.py | 27 ++++++--------------------- 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 05acb89826..1ae16c92b1 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -889,23 +889,6 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: for task_key in self.model_keys: model = wrapper_mod.model[task_key] - # Guard: DPA1/se_atten_v2 attention produces incorrect force - # gradients under inductor compile. Other descriptors (DPA2, - # DPA3) use different attention mechanisms that compile correctly. - from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP - - descriptor = model.get_descriptor() - if isinstance(descriptor, DescrptDPA1DP): - n_attn = descriptor.get_numb_attn_layer() - if n_attn > 0: - raise RuntimeError( - f"Cannot compile model (task={task_key}): DPA1 " - f"descriptor has {n_attn} se_atten attention " - f"layer(s). Compiled se_atten produces incorrect " - f"force gradients. Set 'enable_compile: false' or " - f"use attn_layer=0." - ) - inp, _ = self.get_data(is_train=True, task_key=task_key) coord = inp["coord"].detach() atype = inp["atype"].detach() diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index de5727b5ee..9073b5fd61 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -8,7 +8,6 @@ 4. Loss decreases over those steps """ -import copy import os import shutil import tempfile @@ -50,10 +49,6 @@ # Descriptor configs used to extend compile-correctness tests to non-trivial # architectures. ``precision: float64`` is set so the strict ``atol=rtol=1e-10`` # comparison holds at machine epsilon. -# -# DPA1 with attn_layer=0 (no se_atten attention) compiles correctly. -# DPA1 with attn_layer>0 is rejected at compile time because inductor produces -# incorrect force gradients through the attention path. _DESCRIPTOR_DPA1_NO_ATTN = { "type": "dpa1", "sel": 12, @@ -1198,26 +1193,16 @@ def test_compiled_matches_uncompiled_varying_natoms_dpa3(self) -> None: self._check_varying_natoms(_DESCRIPTOR_DPA3) def test_compiled_matches_uncompiled_varying_natoms_dpa1_no_attn(self) -> None: - """DPA1 (attn_layer=0): compiled vs uncompiled match. - - DPA1 without attention compiles correctly and matches eager mode. - """ + """DPA1 (attn_layer=0): compiled vs uncompiled match.""" self._check_varying_natoms(_DESCRIPTOR_DPA1_NO_ATTN) - def test_compile_rejects_dpa1_with_attention(self) -> None: - """DPA1 (attn_layer>0): compile must raise RuntimeError. + def test_compiled_matches_uncompiled_varying_natoms_dpa1_with_attn(self) -> None: + """DPA1 (attn_layer=2): compiled vs uncompiled match. - Compiled attention produces incorrect force gradients; the compile - path rejects models with se_atten attention layers. + Exercises DPA1 with se_atten attention layers; matches at machine + epsilon (~1e-12) on float64 just like se_e2_a. """ - config = _make_config(self.data_dir, numb_steps=2) - config["model"]["descriptor"] = copy.deepcopy(_DESCRIPTOR_DPA1_WITH_ATTN) - config["training"]["enable_compile"] = True - config = update_deepmd_input(config, warning=False) - config = normalize(config) - with self.assertRaises(RuntimeError, msg="se_atten attention"): - trainer = get_trainer(config) - trainer.run() + self._check_varying_natoms(_DESCRIPTOR_DPA1_WITH_ATTN) if __name__ == "__main__": From 23eb6dd6c261fd901b83b10dd7c1b729475fa5a2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 11:32:56 +0800 Subject: [PATCH 18/56] refactor(dpmodel): remove unused get_numb_attn_layer API The compile guard that used this method was removed in 6d39ddf4. --- deepmd/dpmodel/descriptor/dpa1.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index b3b00efb79..17266974c6 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -381,10 +381,6 @@ def get_dim_out(self) -> int: def get_dim_emb(self) -> int: return self.se_atten.dim_emb - def get_numb_attn_layer(self) -> int: - """Returns the number of se_atten attention layers.""" - return self.se_atten.attn_layer - def mixed_types(self) -> bool: """If true, the descriptor 1. assumes total number of atoms aligned across frames; From bacd31281c042c17722ccba02fdc853ea6b8a750 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 13:04:17 +0800 Subject: [PATCH 19/56] fix(test): use real path for PT water data, remove unused API Use ``pt/model/water`` instead of the ``pt/water`` symlink in test_fitting_stat.py to avoid FileNotFoundError on CI. Remove unused ``get_numb_attn_layer()`` from DescrptDPA1. --- source/tests/pt_expt/fitting/test_fitting_stat.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index 038e0dcf27..a525d74538 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -117,13 +117,17 @@ def _get_weighted_fitting_stat( return weighted_avg, weighted_std -# Paths to the water data used by PT tests -_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") +# Paths to the water data used by PT tests. +# ``source/tests/pt/water`` is a symlink to ``model/water``; use the real +# path so CI checkouts that materialise symlinks as text files still work. +_PT_DATA = str( + Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "data_0" +) _PT_DATA_NO_FPARAM = str( - Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" + Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "data_1" ) _PT_DATA_SINGLE = str( - Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" + Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "single" ) _descriptor_se_e2_a = { From f8342028f8b56ce852b2229be6fe9dabcc490c6c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 13:13:50 +0800 Subject: [PATCH 20/56] fix(pt_expt): rebuild FX graph after detach node removal to avoid segfaults After Graph.erase_node() stale C-level prev/next pointers may remain on neighbouring Node objects. Dynamo re-tracing can dereference them and segfault. Rebuild into a fresh graph to eliminate stale pointers. --- deepmd/pt_expt/train/training.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 1ae16c92b1..0151fca27e 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -164,6 +164,23 @@ def _remove_detach_nodes(gm: torch.fx.GraphModule) -> None: gm.recompile() +def _rebuild_graph_module(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """Copy all nodes into a fresh ``torch.fx.Graph``. + + After ``Graph.erase_node()`` the C-level prev/next pointers on + neighbouring ``Node`` objects may become stale. When ``torch.compile`` + (dynamo) later re-traces the graph it walks these pointers, which can + cause segfaults. Rebuilding into a new graph eliminates stale pointers. + """ + old_graph = gm.graph + new_graph = torch.fx.Graph() + val_map: dict[torch.fx.Node, torch.fx.Node] = {} + for node in old_graph.nodes: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + new_graph.lint() + return torch.fx.GraphModule(gm, new_graph) + + def _trace_and_compile( model: torch.nn.Module, ext_coord: torch.Tensor, @@ -272,6 +289,9 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: # second-order gradient flow (d(force)/d(params) for force training). # Removing them restores correct higher-order derivatives. _remove_detach_nodes(traced_lower) + # Rebuild into a fresh graph to eliminate stale C-level node pointers + # left by erase_node(), which can cause segfaults during dynamo re-trace. + traced_lower = _rebuild_graph_module(traced_lower) if not was_training: model.eval() From 447a5721d38252c82aa5315e2f1ce314a9c0916a Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 13:18:26 +0800 Subject: [PATCH 21/56] fix(pt_expt): tune inductor options for compile training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - max_fusion_size 8 → 64 to avoid scheduler timeouts on large descriptors - add triton.mix_order_reduction=False for PyTorch <=2.11 bugs (pytorch/pytorch#174379, #178080, #179494) - hardcode defaults, let user compile_options override per-key --- deepmd/pt_expt/train/training.py | 34 +++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 0151fca27e..f5478cdb92 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -189,7 +189,7 @@ def _trace_and_compile( mapping: torch.Tensor, fparam: torch.Tensor | None, aparam: torch.Tensor | None, - compile_opts: dict[str, Any], + compile_opts: dict[str, Any] | None = None, ) -> torch.nn.Module: """Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True. @@ -199,9 +199,9 @@ def _trace_and_compile( The (uncompiled) model. ext_coord, ext_atype, nlist, mapping, fparam, aparam Sample tensors used to seed the symbolic tracer. - compile_opts : dict - Options forwarded to ``torch.compile`` (the ``dynamic`` and - ``backend`` keys are ignored and replaced). + compile_opts : dict or None + User-supplied inductor options. These are merged on top of the + built-in defaults (user values take precedence). Returns ------- @@ -296,22 +296,28 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: if not was_training: model.eval() - # Work on a copy; ignore caller-supplied dynamic/backend. - compile_opts = { - k: v for k, v in compile_opts.items() if k not in ("dynamic", "backend") + # Inductor defaults tuned for second-order-gradient training graphs. + # User-supplied compile_opts override these on a per-key basis. + inductor_options: dict[str, Any] = { + "max_autotune": False, + "shape_padding": True, + "epilogue_fusion": False, + "triton.cudagraphs": False, + "max_fusion_size": 64, + # NOTE: mix_order_reduction hits multiple bugs under + # data-dependent symbolic shapes on PyTorch <=2.11 + # (pytorch/pytorch#174379, #178080, #179494) -- our + # edge count is exactly that kind of shape. + "triton.mix_order_reduction": False, } - opts = compile_opts.setdefault("options", {}) - opts.setdefault("max_autotune", False) - opts.setdefault("epilogue_fusion", False) - opts.setdefault("triton.cudagraphs", False) - opts.setdefault("shape_padding", True) - opts.setdefault("max_fusion_size", 8) + if compile_opts: + inductor_options.update(compile_opts) return torch.compile( traced_lower, backend="inductor", dynamic=True, - **compile_opts, + options=inductor_options, ) From fb25ccbf506f101e8dfcccb491304b0ee0cc9ed4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 13:20:27 +0800 Subject: [PATCH 22/56] fix(pt_expt): disable DDPOptimizer to prevent compiled graph splitting DDPOptimizer splits the inner compiled graph at bucket boundaries, producing subgraph outputs with symbolic integers that crash AOT Autograd (pytorch/pytorch#134182). --- deepmd/pt_expt/train/training.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index f5478cdb92..2b59c86e40 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -24,6 +24,14 @@ import torch import torch.distributed as dist +# Disable DDPOptimizer: our compile region wraps only the inner compute +# function, not the whole DDP model. DDPOptimizer assumes it owns the +# full model graph and splits at bucket boundaries, producing subgraphs +# whose outputs include symbolic integers. AOT Autograd then crashes +# with ``'int' object has no attribute 'meta'`` +# (pytorch/pytorch#134182). +torch._dynamo.config.optimize_ddp = False + from deepmd.dpmodel.common import ( to_numpy_array, ) From 479900d075c5948c561c24fab7f6441286f2a5bd Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 13:25:13 +0800 Subject: [PATCH 23/56] fix(test): add .cpu() before .numpy() for GPU-compatible activation tests --- source/tests/consistent/test_activation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/test_activation.py b/source/tests/consistent/test_activation.py index b46319e338..803336c15c 100644 --- a/source/tests/consistent/test_activation.py +++ b/source/tests/consistent/test_activation.py @@ -113,7 +113,7 @@ def test_pt_expt_consistent_with_ref(self) -> None: x = torch.tensor( self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE ) - test = _torch_activation(x, self.activation).detach().numpy() + test = _torch_activation(x, self.activation).detach().cpu().numpy() np.testing.assert_allclose(self.ref, test, atol=1e-10) @@ -155,5 +155,5 @@ def test_pt_expt_consistent_with_ref(self) -> None: x = torch.tensor( self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE ) - test = _torch_activation(x, self.activation).detach().numpy() + test = _torch_activation(x, self.activation).detach().cpu().numpy() np.testing.assert_allclose(self.ref, test, atol=1e-10) From b67a181325bbd5927fc719d4c2728e8562891d28 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 15:27:09 +0800 Subject: [PATCH 24/56] fix(pt_expt): revert inductor options that cause numerical divergence max_fusion_size=64 causes DPA3 force divergence and triton.mix_order_reduction=False causes DPA1-attention divergence (both on CPU, float64, 1e-10 tolerance). Revert to max_fusion_size=8 and remove mix_order_reduction (kept as comment for GPU users). Also revert test_fitting_stat.py path back to pt/water symlink which was confirmed working at c2efbf1. --- deepmd/pt_expt/train/training.py | 11 +++++------ source/tests/pt_expt/fitting/test_fitting_stat.py | 12 ++++-------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 2b59c86e40..a81de2de8b 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -311,12 +311,11 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None: "shape_padding": True, "epilogue_fusion": False, "triton.cudagraphs": False, - "max_fusion_size": 64, - # NOTE: mix_order_reduction hits multiple bugs under - # data-dependent symbolic shapes on PyTorch <=2.11 - # (pytorch/pytorch#174379, #178080, #179494) -- our - # edge count is exactly that kind of shape. - "triton.mix_order_reduction": False, + "max_fusion_size": 8, + # NOTE: On GPU with PyTorch <=2.11, consider adding + # "triton.mix_order_reduction": False to work around + # pytorch/pytorch#174379, #178080, #179494 under + # data-dependent symbolic shapes. } if compile_opts: inductor_options.update(compile_opts) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index a525d74538..038e0dcf27 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -117,17 +117,13 @@ def _get_weighted_fitting_stat( return weighted_avg, weighted_std -# Paths to the water data used by PT tests. -# ``source/tests/pt/water`` is a symlink to ``model/water``; use the real -# path so CI checkouts that materialise symlinks as text files still work. -_PT_DATA = str( - Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "data_0" -) +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") _PT_DATA_NO_FPARAM = str( - Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "data_1" + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" ) _PT_DATA_SINGLE = str( - Path(__file__).parent.parent.parent / "pt" / "model" / "water" / "data" / "single" + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" ) _descriptor_se_e2_a = { From 7ce7352f394a18a3bc0936cc4fb52394cf114202 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 16:27:30 +0800 Subject: [PATCH 25/56] fix(test): make DDP tests device-adaptive instead of hardcoding CPU Remove os.environ["DEVICE"] = "cpu" from all worker functions and replace hardcoded backend="gloo" with auto-detected _DDP_BACKEND (nccl on CUDA, gloo on CPU). env.py now evaluates DEVICE naturally based on hardware availability. --- source/tests/pt_expt/test_training_ddp.py | 33 ++++++++++------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py index 38771968ac..19235c3ed4 100644 --- a/source/tests/pt_expt/test_training_ddp.py +++ b/source/tests/pt_expt/test_training_ddp.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """Tests for distributed (DDP) training in the pt_expt backend. -Uses ``torch.multiprocessing.spawn`` + ``gloo`` backend to verify DDP on CPU. +Uses ``torch.multiprocessing.spawn`` with auto-detected backend +(``nccl`` on CUDA, ``gloo`` on CPU). Verifies that: 1. Single-task DDP training completes and produces correct outputs @@ -59,6 +60,9 @@ "water", ) +# Auto-detect DDP backend based on device availability. +_DDP_BACKEND = "nccl" if torch.cuda.is_available() else "gloo" + def _find_free_port(): """Find a free TCP port on localhost.""" @@ -306,8 +310,7 @@ def _worker_single_task_train(rank, world_size, port, data_dir, result_dict): """Worker: run single-task DDP training.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_st_rank{rank}_") old_cwd = os.getcwd() @@ -345,8 +348,7 @@ def _worker_multitask_train(rank, world_size, port, data_dir, result_dict): """Worker: run multi-task DDP training.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_rank{rank}_") old_cwd = os.getcwd() @@ -385,8 +387,7 @@ def _worker_gradient_test(rank, world_size, port, data_dir, result_dict): """Worker: run 1 step of DDP training, collect gradients and input data.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_grad_rank{rank}_") old_cwd = os.getcwd() @@ -453,8 +454,7 @@ def _worker_multitask_gradient_test(rank, world_size, port, data_dir, result_dic """Worker: run 1 step of multi-task DDP training, collect gradients.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_grad_rank{rank}_") old_cwd = os.getcwd() @@ -509,8 +509,7 @@ def _worker_check_resume( """ os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_resume_rank{rank}_") old_cwd = os.getcwd() @@ -564,8 +563,7 @@ def _worker_finetune( """ os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_ft_rank{rank}_") old_cwd = os.getcwd() @@ -1331,8 +1329,7 @@ def _worker_multitask_finetune( """Worker: DDP multi-task finetune from checkpoint.""" os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_mt_ft_rank{rank}_") old_cwd = os.getcwd() @@ -1380,8 +1377,7 @@ def _worker_single_task_compile_train(rank, world_size, port, data_dir, result_d """ os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_st_rank{rank}_") old_cwd = os.getcwd() @@ -1431,8 +1427,7 @@ def _worker_multitask_compile_train(rank, world_size, port, data_dir, result_dic """ os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = str(port) - os.environ["DEVICE"] = "cpu" - dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) + dist.init_process_group(backend=_DDP_BACKEND, rank=rank, world_size=world_size) try: tmpdir = tempfile.mkdtemp(prefix=f"ddp_compile_mt_rank{rank}_") old_cwd = os.getcwd() From 975db178ac2e4ce85e0dd24aed33f3dc367c6269 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 16:41:23 +0800 Subject: [PATCH 26/56] fix(test): correct freeze test docstrings to match dpa3 guard --- source/tests/pt_expt/test_multitask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/pt_expt/test_multitask.py b/source/tests/pt_expt/test_multitask.py index e5c6955ac0..92ca70936f 100644 --- a/source/tests/pt_expt/test_multitask.py +++ b/source/tests/pt_expt/test_multitask.py @@ -1200,7 +1200,7 @@ def test_multitask_freeze(self) -> None: def test_multitask_freeze_no_head_raises(self) -> None: """Freezing multi-task model without --head raises ValueError. - Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. """ if self.descriptor.get("type") != "dpa3": return @@ -1220,7 +1220,7 @@ def test_multitask_freeze_no_head_raises(self) -> None: def test_multitask_freeze_invalid_head_raises(self) -> None: """Freezing multi-task model with invalid --head raises ValueError. - Only runs for se_e2_a descriptor to avoid redundant slow freeze tests. + Only runs for dpa3 descriptor to avoid redundant slow freeze tests. """ if self.descriptor.get("type") != "dpa3": return From 64dc70367e7e2f04e2b44b406ee1a32ae2f8cdbb Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 20:39:31 +0800 Subject: [PATCH 27/56] fix(pt_expt): move optimize_ddp into _compile_model, resolve test symlinks Move torch._dynamo.config.optimize_ddp = False from module level into _compile_model() so it only applies when compile is active. Resolve symlinks in test_fitting_stat.py data paths for reliable CI access. --- deepmd/pt_expt/train/training.py | 16 ++++++++-------- .../tests/pt_expt/fitting/test_fitting_stat.py | 13 +++++++++---- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index a81de2de8b..7a62495b19 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -24,14 +24,6 @@ import torch import torch.distributed as dist -# Disable DDPOptimizer: our compile region wraps only the inner compute -# function, not the whole DDP model. DDPOptimizer assumes it owns the -# full model graph and splits at bucket boundaries, producing subgraphs -# whose outputs include symbolic integers. AOT Autograd then crashes -# with ``'int' object has no attribute 'meta'`` -# (pytorch/pytorch#134182). -torch._dynamo.config.optimize_ddp = False - from deepmd.dpmodel.common import ( to_numpy_array, ) @@ -903,6 +895,14 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: needed. The coord extension + nlist build (data-dependent control flow) are kept outside the compiled region. """ + # Disable DDPOptimizer: our compile region wraps only the inner + # compute function, not the whole DDP model. DDPOptimizer assumes + # it owns the full model graph and splits at bucket boundaries, + # producing subgraphs whose outputs include symbolic integers. + # AOT Autograd then crashes with ``'int' object has no attribute + # 'meta'`` (pytorch/pytorch#134182). + torch._dynamo.config.optimize_ddp = False + from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index 038e0dcf27..590bd6d553 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -117,13 +117,18 @@ def _get_weighted_fitting_stat( return weighted_avg, weighted_std -# Paths to the water data used by PT tests -_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") +# Paths to the water data used by PT tests. +# resolve() follows the ``pt/water -> model/water`` symlink so numpy can +# always open the real file, even on CI runners where symlink handling +# can be fragile. +_PT_DATA = str( + (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0").resolve() +) _PT_DATA_NO_FPARAM = str( - Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" + (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1").resolve() ) _PT_DATA_SINGLE = str( - Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" + (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single").resolve() ) _descriptor_se_e2_a = { From 28fbcac97b8cd9fa4ac87779b00fd74aea6edadc Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 21:32:51 +0800 Subject: [PATCH 28/56] fix(test): backup/restore fparam.npy in TestFparam instead of deleting TestFparam.tearDown deleted the committed fparam.npy, breaking other tests (e.g. test_fitting_stat) running later in the same CI shard. Also revert the unnecessary .resolve() workaround in test_fitting_stat. --- source/tests/pt/test_training.py | 15 +++++++++++++-- source/tests/pt_expt/fitting/test_fitting_stat.py | 13 ++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index b4cc926844..261e1e3374 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -268,11 +268,22 @@ def setUp(self) -> None: self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 self.set_path = Path(__file__).parent / "water/data/data_0" / "set.000" - shutil.copyfile(self.set_path / "energy.npy", self.set_path / "fparam.npy") + # Backup the committed fparam.npy (numb_fparam=2) before overwriting + # with a 1-column version for this test. + self._fparam_backup = self.set_path / "fparam.npy.bak" + fparam_path = self.set_path / "fparam.npy" + if fparam_path.exists(): + shutil.copyfile(fparam_path, self._fparam_backup) + shutil.copyfile(self.set_path / "energy.npy", fparam_path) self.config["model"]["data_stat_nbatch"] = 100 def tearDown(self) -> None: - (self.set_path / "fparam.npy").unlink(missing_ok=True) + # Restore the original fparam.npy so other tests can use it. + fparam_path = self.set_path / "fparam.npy" + if self._fparam_backup.exists(): + shutil.move(str(self._fparam_backup), str(fparam_path)) + else: + fparam_path.unlink(missing_ok=True) DPTrainTest.tearDown(self) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py index 590bd6d553..038e0dcf27 100644 --- a/source/tests/pt_expt/fitting/test_fitting_stat.py +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -117,18 +117,13 @@ def _get_weighted_fitting_stat( return weighted_avg, weighted_std -# Paths to the water data used by PT tests. -# resolve() follows the ``pt/water -> model/water`` symlink so numpy can -# always open the real file, even on CI runners where symlink handling -# can be fragile. -_PT_DATA = str( - (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0").resolve() -) +# Paths to the water data used by PT tests +_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0") _PT_DATA_NO_FPARAM = str( - (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1").resolve() + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1" ) _PT_DATA_SINGLE = str( - (Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single").resolve() + Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single" ) _descriptor_se_e2_a = { From fbb361addaf7bc92c52d196a9b18d736703df9db Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 18 Apr 2026 22:06:46 +0800 Subject: [PATCH 29/56] fix(test): skip DDP tests when NCCL is selected with fewer than 2 GPUs NCCL rejects two ranks sharing the same GPU device, causing all DDP tests to fail on single-GPU CI runners. Skip the entire module when the backend is NCCL and device_count < 2. --- source/tests/pt_expt/test_training_ddp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/source/tests/pt_expt/test_training_ddp.py b/source/tests/pt_expt/test_training_ddp.py index 19235c3ed4..0d71e66870 100644 --- a/source/tests/pt_expt/test_training_ddp.py +++ b/source/tests/pt_expt/test_training_ddp.py @@ -63,6 +63,10 @@ # Auto-detect DDP backend based on device availability. _DDP_BACKEND = "nccl" if torch.cuda.is_available() else "gloo" +# NCCL requires at least 2 GPUs for multi-rank tests. +if _DDP_BACKEND == "nccl" and torch.cuda.device_count() < 2: + raise unittest.SkipTest("NCCL DDP tests require at least 2 GPUs") + def _find_free_port(): """Find a free TCP port on localhost.""" From 7739fadaf0dbcd5e6347154c17e5a4b355ead5a1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 20 Apr 2026 10:57:39 +0800 Subject: [PATCH 30/56] perf(pt2): optimize .pt2 C++ inference path Three optimizations for .pt2 (AOTInductor) C++ inference: 1. Replace buildTypeSortedNlist with createNlistTensor: avoids expensive CPU-side distance computation and type-based sorting. The .pt2 model's compiled graph contains format_nlist which handles distance sorting on-device. The new createNlistTensor pads/truncates to the expected nnei (= sum(sel)) stored in model metadata, since torch.export specializes nnei as a static dimension. 2. Export with do_atomic_virial=False by default: avoids 3 extra torch.autograd.grad backward passes in atomic_virial_corr(). The reduced virial (energy_derv_c_redu) is still correct without the correction (it sums to zero). Add --atomic-virial flag to dp convert-backend for users who need per-atom virial (~2.5x cost). C++ raises error if atomic virial requested but model lacks it. 3. Cache mapping tensor: only rebuild when ago==0 (nlist updated), avoiding redundant CPU-to-device copies every step. Also adds nnei and do_atomic_virial to .pt2 metadata for C++ to read. --- deepmd/entrypoints/convert_backend.py | 15 ++- deepmd/main.py | 8 ++ deepmd/pt_expt/utils/serialization.py | 39 ++++++-- source/api_cc/include/DeepPotPTExpt.h | 5 +- source/api_cc/include/DeepSpinPTExpt.h | 5 +- source/api_cc/src/DeepPotPTExpt.cc | 97 +++++++++++-------- source/api_cc/src/DeepSpinPTExpt.cc | 89 +++++++++++------- source/api_cc/src/commonPTExpt.h | 108 +++++----------------- source/tests/infer/gen_dpa1.py | 2 +- source/tests/infer/gen_dpa2.py | 2 +- source/tests/infer/gen_dpa3.py | 2 +- source/tests/infer/gen_fparam_aparam.py | 8 +- source/tests/infer/gen_model_devi.py | 4 +- source/tests/infer/gen_sea.py | 2 +- source/tests/infer/gen_spin.py | 2 +- source/tests/infer/gen_spin_model_devi.py | 2 +- 16 files changed, 207 insertions(+), 183 deletions(-) diff --git a/deepmd/entrypoints/convert_backend.py b/deepmd/entrypoints/convert_backend.py index a8cf20c6b3..db593507ca 100644 --- a/deepmd/entrypoints/convert_backend.py +++ b/deepmd/entrypoints/convert_backend.py @@ -12,6 +12,7 @@ def convert_backend( *, # Enforce keyword-only arguments INPUT: str, OUTPUT: str, + atomic_virial: bool = False, **kwargs: Any, ) -> None: """Convert a model file from one backend to another. @@ -20,12 +21,22 @@ def convert_backend( ---------- INPUT : str The input model file. - INPUT : str + OUTPUT : str The output model file. + atomic_virial : bool + If True, export .pt2/.pte models with per-atom virial correction. + This adds ~2.5x inference cost. Default False. """ inp_backend: Backend = Backend.detect_backend_by_model(INPUT)() out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)() inp_hook = inp_backend.serialize_hook out_hook = out_backend.deserialize_hook data = inp_hook(INPUT) - out_hook(OUTPUT, data) + # Forward atomic_virial to pt_expt deserialize_to_file if applicable + import inspect + + sig = inspect.signature(out_hook) + if "do_atomic_virial" in sig.parameters: + out_hook(OUTPUT, data, do_atomic_virial=atomic_virial) + else: + out_hook(OUTPUT, data) diff --git a/deepmd/main.py b/deepmd/main.py index 3afcda8b4a..0ba87a733b 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -915,6 +915,14 @@ def main_parser() -> argparse.ArgumentParser: ) parser_convert_backend.add_argument("INPUT", help="The input model file.") parser_convert_backend.add_argument("OUTPUT", help="The output model file.") + parser_convert_backend.add_argument( + "--atomic-virial", + action="store_true", + default=False, + help="Export .pt2/.pte models with per-atom virial correction. " + "This adds ~2.5x inference cost but is required for " + "LAMMPS compute/atom virial output.", + ) # * show model ****************************************************************** parser_show = subparsers.add_parser( diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index f59c397525..ab772178fa 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -205,7 +205,10 @@ def _build_dynamic_shapes( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) {0: nframes_dim, 1: nall_dim}, # extended_spin: (nframes, nall, 3) - {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + { + 0: nframes_dim, + 1: nloc_dim, + }, # nlist: (nframes, nloc, nnei) — nnei is static {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) {0: nframes_dim} if fparam is not None else None, # fparam {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam @@ -217,7 +220,10 @@ def _build_dynamic_shapes( return ( {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) - {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + { + 0: nframes_dim, + 1: nloc_dim, + }, # nlist: (nframes, nloc, nnei) — nnei is static {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) {0: nframes_dim} if fparam is not None else None, # fparam {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam @@ -257,6 +263,7 @@ def _collect_metadata(model: torch.nn.Module, is_spin: bool = False) -> dict: "type_map": model.get_type_map(), "rcut": model.get_rcut(), "sel": model.get_sel(), + "nnei": sum(model.get_sel()), "dim_fparam": model.get_dim_fparam(), "dim_aparam": model.get_dim_aparam(), "mixed_types": model.mixed_types(), @@ -336,6 +343,7 @@ def deserialize_to_file( model_file: str, data: dict, model_json_override: dict | None = None, + do_atomic_virial: bool = False, ) -> None: """Deserialize a dictionary to a .pte or .pt2 model file. @@ -356,16 +364,24 @@ def deserialize_to_file( If provided, this dict is stored in model.json instead of ``data``. Used by ``dp compress`` to store the compressed model dict while tracing the uncompressed model (make_fx cannot trace custom ops). + do_atomic_virial : bool + If True, export with per-atom virial correction (3 extra backward + passes, ~2.5x slower). Default False for best performance. """ if model_file.endswith(".pt2"): - _deserialize_to_file_pt2(model_file, data, model_json_override) + _deserialize_to_file_pt2( + model_file, data, model_json_override, do_atomic_virial + ) else: - _deserialize_to_file_pte(model_file, data, model_json_override) + _deserialize_to_file_pte( + model_file, data, model_json_override, do_atomic_virial + ) def _trace_and_export( data: dict, model_json_override: dict | None = None, + do_atomic_virial: bool = False, ) -> tuple: """Common logic: build model, trace, export. @@ -447,7 +463,7 @@ def _trace_and_export( mapping_t, fparam=fparam, aparam=aparam, - do_atomic_virial=True, + do_atomic_virial=do_atomic_virial, tracing_mode="symbolic", _allow_non_fake_inputs=True, ) @@ -463,7 +479,7 @@ def _trace_and_export( mapping_t, fparam=fparam, aparam=aparam, - do_atomic_virial=True, + do_atomic_virial=do_atomic_virial, tracing_mode="symbolic", _allow_non_fake_inputs=True, ) @@ -504,7 +520,10 @@ def _trace_and_export( exported = move_to_device_pass(exported, target_device) - # 8. Prepare JSON-serializable model dict + # 8. Record export-time config in metadata + metadata["do_atomic_virial"] = do_atomic_virial + + # 9. Prepare JSON-serializable model dict json_source = model_json_override if model_json_override is not None else data data_for_json = deepcopy(json_source) data_for_json = _numpy_to_json_serializable(data_for_json) @@ -516,10 +535,11 @@ def _deserialize_to_file_pte( model_file: str, data: dict, model_json_override: dict | None = None, + do_atomic_virial: bool = False, ) -> None: """Deserialize a dictionary to a .pte model file.""" exported, metadata, data_for_json, output_keys = _trace_and_export( - data, model_json_override + data, model_json_override, do_atomic_virial ) model_def_script = data.get("model_def_script") or {} @@ -537,6 +557,7 @@ def _deserialize_to_file_pt2( model_file: str, data: dict, model_json_override: dict | None = None, + do_atomic_virial: bool = False, ) -> None: """Deserialize a dictionary to a .pt2 model file (AOTInductor). @@ -551,7 +572,7 @@ def _deserialize_to_file_pt2( ) exported, metadata, data_for_json, output_keys = _trace_and_export( - data, model_json_override + data, model_json_override, do_atomic_virial ) # Compile via AOTInductor into a .pt2 package diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index 0d42324d24..5cc9eca06e 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -208,9 +208,10 @@ class DeepPotPTExpt : public DeepPotBackend { bool gpu_enabled; std::vector type_map; std::vector output_keys; // sorted internal output key names - bool mixed_types; - std::vector sel; + bool do_atomic_virial; // whether model was exported with atomic virial corr + int nnei; // expected nlist nnei dimension (= sum(sel)) NeighborListData nlist_data; + at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) std::unique_ptr loader; /** diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h index af108c7690..f73e7c4cfd 100644 --- a/source/api_cc/include/DeepSpinPTExpt.h +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -183,9 +183,10 @@ class DeepSpinPTExpt : public DeepSpinBackend { bool gpu_enabled; std::vector type_map; std::vector output_keys; - bool mixed_types; - std::vector sel; + bool do_atomic_virial; // whether model was exported with atomic virial corr + int nnei; // expected nlist nnei dimension (= sum(sel)) NeighborListData nlist_data; + at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) std::unique_ptr loader; std::vector run_model(const torch::Tensor& coord, diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index c1f3d9d674..0d52213157 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -17,7 +17,7 @@ #include "errors.h" #include "neighbor_list.h" -using deepmd::ptexpt::buildTypeSortedNlist; +using deepmd::ptexpt::createNlistTensor; using deepmd::ptexpt::parse_json; using deepmd::ptexpt::read_zip_entry; @@ -93,7 +93,6 @@ void DeepPotPTExpt::init(const std::string& model, ntypes = static_cast(metadata["type_map"].as_array().size()); dfparam = metadata["dim_fparam"].as_int(); daparam = metadata["dim_aparam"].as_int(); - mixed_types = metadata["mixed_types"].as_bool(); aparam_nall = false; // pt_expt models use nloc for aparam if (metadata.obj_val.count("has_default_fparam")) { has_default_fparam_ = metadata["has_default_fparam"].as_bool(); @@ -120,16 +119,29 @@ void DeepPotPTExpt::init(const std::string& model, } } + if (metadata.obj_val.count("do_atomic_virial")) { + do_atomic_virial = metadata["do_atomic_virial"].as_bool(); + } else { + // Older models without this field were exported with do_atomic_virial=True + do_atomic_virial = true; + } + + // Read expected nnei (= sum(sel)) — the .pt2 graph has this dimension static. + if (metadata.obj_val.count("nnei")) { + nnei = metadata["nnei"].as_int(); + } else { + // Fallback: compute from sel array + nnei = 0; + for (const auto& v : metadata["sel"].as_array()) { + nnei += v.as_int(); + } + } + type_map.clear(); for (const auto& v : metadata["type_map"].as_array()) { type_map.push_back(v.as_string()); } - sel.clear(); - for (const auto& v : metadata["sel"].as_array()) { - sel.push_back(v.as_int()); - } - // Parse output keys from metadata output_keys.clear(); for (const auto& v : metadata["output_keys"].as_array()) { @@ -256,37 +268,33 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - } - // Build type-sorted, sel-limited nlist expected by the .pt2 model - at::Tensor firstneigh_tensor = - buildTypeSortedNlist(nlist_data.jlist, coord_d, datype, sel, nloc, - mixed_types) - .to(device); - // Build mapping tensor. - // NOTE: must .clone() because the local vector goes out of scope before - // run_model is called, and torch::from_blob does not copy the data. - at::Tensor mapping_tensor; - if (lmp_list.mapping) { - std::vector mapping(nall_real); - for (int ii = 0; ii < nall_real; ii++) { - mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; - } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); - } else { - // Default identity mapping for local atoms - std::vector mapping(nall_real); - for (int ii = 0; ii < nall_real; ii++) { - mapping[ii] = ii; + // Rebuild mapping tensor only when nlist is updated (ago == 0). + if (lmp_list.mapping) { + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); + } else { + // Default identity mapping for local atoms + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = ii; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); } + // Build raw nlist tensor — the .pt2 model's compiled graph handles + // format_nlist (distance sort + truncation) internally on-device. + at::Tensor firstneigh_tensor = + createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); // Build fparam/aparam tensors (cast to float64 for the model) auto valuetype_options = std::is_same::value @@ -364,6 +372,12 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, nall_real); if (atomic) { + if (!do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } // Extract atom_energy: energy (nf, nloc, 1) torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); @@ -518,8 +532,6 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, ncell, ext_stt, ext_end, region, ncell); } - // 3. Build type-sorted, sel-limited nlist (uses double coords for distances) - // 4. Convert to tensors (always float64 for .pt2 model) // NOTE: must .clone() because from_blob does not copy data, and the local // vectors would go out of scope before run_model completes. @@ -532,10 +544,9 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); + // Build raw nlist tensor — the .pt2 model handles format_nlist internally. at::Tensor nlist_tensor = - buildTypeSortedNlist(nlist_raw, coord_cpy_d, atype_cpy, sel, nloc, - mixed_types) - .to(device); + createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) @@ -614,6 +625,12 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, fold_back(force, extended_force, mapping_vec, nloc, nall, 3, nframes); if (atomic) { + if (!do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } // atom_energy: energy (nf, nloc, 1) — already on local atoms torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index ae4ef423ed..b4814a7001 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -17,7 +17,7 @@ #include "errors.h" #include "neighbor_list.h" -using deepmd::ptexpt::buildTypeSortedNlist; +using deepmd::ptexpt::createNlistTensor; using deepmd::ptexpt::parse_json; using deepmd::ptexpt::read_zip_entry; @@ -93,7 +93,6 @@ void DeepSpinPTExpt::init(const std::string& model, ntypes = static_cast(metadata["type_map"].as_array().size()); dfparam = metadata["dim_fparam"].as_int(); daparam = metadata["dim_aparam"].as_int(); - mixed_types = metadata["mixed_types"].as_bool(); aparam_nall = false; // Spin-specific metadata @@ -133,16 +132,29 @@ void DeepSpinPTExpt::init(const std::string& model, } } + if (metadata.obj_val.count("do_atomic_virial")) { + do_atomic_virial = metadata["do_atomic_virial"].as_bool(); + } else { + // Older models without this field were exported with do_atomic_virial=True + do_atomic_virial = true; + } + + // Read expected nnei (= sum(sel)) — the .pt2 graph has this dimension static. + if (metadata.obj_val.count("nnei")) { + nnei = metadata["nnei"].as_int(); + } else { + // Fallback: compute from sel array + nnei = 0; + for (const auto& v : metadata["sel"].as_array()) { + nnei += v.as_int(); + } + } + type_map.clear(); for (const auto& v : metadata["type_map"].as_array()) { type_map.push_back(v.as_string()); } - sel.clear(); - for (const auto& v : metadata["sel"].as_array()) { - sel.push_back(v.as_int()); - } - output_keys.clear(); for (const auto& v : metadata["output_keys"].as_array()) { output_keys.push_back(v.as_string()); @@ -285,33 +297,31 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - } - at::Tensor firstneigh_tensor = - buildTypeSortedNlist(nlist_data.jlist, coord_d, datype, sel, nloc, - mixed_types) - .to(device); - // Build mapping tensor - at::Tensor mapping_tensor; - if (lmp_list.mapping) { - std::vector mapping(nall_real); - for (int ii = 0; ii < nall_real; ii++) { - mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; - } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); - } else { - std::vector mapping(nall_real); - for (int ii = 0; ii < nall_real; ii++) { - mapping[ii] = ii; + // Rebuild mapping tensor only when nlist is updated (ago == 0). + if (lmp_list.mapping) { + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = fwd_map[lmp_list.mapping[bkw_map[ii]]]; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); + } else { + std::vector mapping(nall_real); + for (int ii = 0; ii < nall_real; ii++) { + mapping[ii] = ii; + } + mapping_tensor = + torch::from_blob(mapping.data(), {1, nall_real}, int_option) + .clone() + .to(device); } - mapping_tensor = - torch::from_blob(mapping.data(), {1, nall_real}, int_option) - .clone() - .to(device); } + // Build raw nlist tensor — the .pt2 model handles format_nlist internally. + at::Tensor firstneigh_tensor = + createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); // Build fparam/aparam tensors auto valuetype_options = std::is_same::value @@ -398,6 +408,12 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, fwd_map.size(), nall_real); if (atomic) { + if (!do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); torch::Tensor cpu_atom_energy_ = atom_energy_tensor.to(torch::kCPU); @@ -580,10 +596,9 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); + // Build raw nlist tensor — the .pt2 model handles format_nlist internally. at::Tensor nlist_tensor = - buildTypeSortedNlist(nlist_raw, coord_cpy_d, atype_cpy, sel, nloc, - mixed_types) - .to(device); + createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) @@ -671,6 +686,12 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, fold_back(force_mag, extended_force_mag, mapping_vec, nloc, nall, 3, nframes); if (atomic) { + if (!do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } // atom_energy: energy (nf, nloc, 1) torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index 7dd02d09a9..73a988ffbc 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -438,100 +438,38 @@ inline std::string read_zip_entry(const std::string& zip_path, } // ============================================================================ -// Build type-sorted, sel-limited neighbor list tensor. +// Create raw neighbor list tensor (no type-sorting or sel-limiting). +// The .pt2 model's compiled graph already contains format_nlist which +// sorts by distance and truncates on-device, so CPU-side sorting is +// unnecessary. // ============================================================================ /** - * @brief Convert a raw neighbor list to the sel-limited format expected by the - * pt_expt model. + * @brief Convert a raw neighbor list to a flat tensor. * - * For non-mixed-type models (distinguish_types=true): the nlist has shape - * (nframes, nloc, sum(sel)), where the first sel[0] entries are neighbors of - * type 0, the next sel[1] are type 1, etc. Within each type group neighbors - * are sorted by distance (ascending). - * - * For mixed-type models (distinguish_types=false): all neighbors go into a - * single group sorted by distance, truncated to sum(sel). + * Simply flattens the jagged nlist into a (1, nloc, nnei) int32 tensor. + * The .pt2 model handles format_nlist internally (distance sort + truncation). + */ +/** + * @brief Create nlist tensor with exactly `expected_nnei` neighbors per atom. * - * Missing slots are filled with -1. + * Each row in `data` may have a different number of neighbors. This function + * pads short rows with -1 and truncates long rows to produce a tensor of shape + * [1, nloc, expected_nnei]. The .pt2 compiled graph has nnei baked as a + * static dimension, so the tensor must match exactly. */ -template -inline torch::Tensor buildTypeSortedNlist( - const std::vector>& raw_nlist, - const std::vector& coord_ext, - const std::vector& atype_ext, - const std::vector& sel, - int nloc, - bool mixed_types) { - int nsel = 0; - for (auto s : sel) { - nsel += s; - } - int ntypes = sel.size(); - std::vector result(static_cast(nloc) * nsel, -1); - +inline torch::Tensor createNlistTensor( + const std::vector>& data, int expected_nnei) { + int nloc = static_cast(data.size()); + std::vector flat_data(static_cast(nloc) * expected_nnei, -1); for (int ii = 0; ii < nloc; ++ii) { - const auto& neighbors = raw_nlist[ii]; - VALUETYPE xi = coord_ext[ii * 3 + 0]; - VALUETYPE yi = coord_ext[ii * 3 + 1]; - VALUETYPE zi = coord_ext[ii * 3 + 2]; - int offset = ii * nsel; - - if (mixed_types) { - std::vector> all_neighbors; - for (int jj : neighbors) { - if (jj < 0) { - continue; - } - int jtype = atype_ext[jj]; - if (jtype < 0) { - continue; - } - VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; - VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; - VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; - VALUETYPE rr = dx * dx + dy * dy + dz * dz; - all_neighbors.emplace_back(rr, jj); - } - std::sort(all_neighbors.begin(), all_neighbors.end()); - int count = std::min(static_cast(all_neighbors.size()), nsel); - for (int kk = 0; kk < count; ++kk) { - result[offset + kk] = all_neighbors[kk].second; - } - } else { - std::vector>> by_type(ntypes); - for (int jj : neighbors) { - if (jj < 0) { - continue; - } - int jtype = atype_ext[jj]; - if (jtype < 0 || jtype >= ntypes) { - continue; - } - VALUETYPE dx = coord_ext[jj * 3 + 0] - xi; - VALUETYPE dy = coord_ext[jj * 3 + 1] - yi; - VALUETYPE dz = coord_ext[jj * 3 + 2] - zi; - VALUETYPE rr = dx * dx + dy * dy + dz * dz; - by_type[jtype].emplace_back(rr, jj); - } - int col = 0; - for (int tt = 0; tt < ntypes; ++tt) { - auto& group = by_type[tt]; - std::sort(group.begin(), group.end()); - int count = std::min(static_cast(group.size()), sel[tt]); - for (int kk = 0; kk < count; ++kk) { - result[offset + col + kk] = group[kk].second; - } - col += sel[tt]; - } + int ncopy = std::min(static_cast(data[ii].size()), expected_nnei); + for (int jj = 0; jj < ncopy; ++jj) { + flat_data[static_cast(ii) * expected_nnei + jj] = data[ii][jj]; } } - - torch::Tensor tensor = - torch::from_blob(result.data(), {1, nloc, nsel}, - torch::TensorOptions().dtype(torch::kInt64)) - .clone(); - return tensor; + torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); + return flat_tensor.view({1, nloc, expected_nnei}); } } // namespace ptexpt diff --git a/source/tests/infer/gen_dpa1.py b/source/tests/infer/gen_dpa1.py index 26e8fc4afd..d587d7e534 100644 --- a/source/tests/infer/gen_dpa1.py +++ b/source/tests/infer/gen_dpa1.py @@ -85,7 +85,7 @@ def main(): pt2_path = os.path.join(base_dir, "deeppot_dpa1.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data)) + pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data), do_atomic_virial=True) pth_path = os.path.join(base_dir, "deeppot_dpa1.pth") print(f"Exporting to {pth_path} ...") # noqa: T201 diff --git a/source/tests/infer/gen_dpa2.py b/source/tests/infer/gen_dpa2.py index 84b4d314a3..8ce277fcf5 100644 --- a/source/tests/infer/gen_dpa2.py +++ b/source/tests/infer/gen_dpa2.py @@ -108,7 +108,7 @@ def main(): pt2_path = os.path.join(base_dir, "deeppot_dpa2.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data)) + pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data), do_atomic_virial=True) pth_path = os.path.join(base_dir, "deeppot_dpa2.pth") print(f"Exporting to {pth_path} ...") # noqa: T201 diff --git a/source/tests/infer/gen_dpa3.py b/source/tests/infer/gen_dpa3.py index 322163462d..e5e2d66579 100644 --- a/source/tests/infer/gen_dpa3.py +++ b/source/tests/infer/gen_dpa3.py @@ -86,7 +86,7 @@ def main(): pt2_path = os.path.join(base_dir, "deeppot_dpa3.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data)) + pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data), do_atomic_virial=True) pth_path = os.path.join(base_dir, "deeppot_dpa3.pth") print(f"Exporting to {pth_path} ...") # noqa: T201 diff --git a/source/tests/infer/gen_fparam_aparam.py b/source/tests/infer/gen_fparam_aparam.py index f0d80d5764..13dbe4ffcf 100644 --- a/source/tests/infer/gen_fparam_aparam.py +++ b/source/tests/infer/gen_fparam_aparam.py @@ -67,7 +67,9 @@ def main(): } pt2_default_path = os.path.join(base_dir, "fparam_aparam_default.pt2") print(f"Exporting to {pt2_default_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_default_path, copy.deepcopy(data_default)) + pt_expt_deserialize_to_file( + pt2_default_path, copy.deepcopy(data_default), do_atomic_virial=True + ) # ---- 3. Export fparam_aparam.pt2 and .pth (without default_fparam) ---- config_no_default = copy.deepcopy(config) @@ -84,7 +86,9 @@ def main(): pt2_path = os.path.join(base_dir, "fparam_aparam.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data_no_default)) + pt_expt_deserialize_to_file( + pt2_path, copy.deepcopy(data_no_default), do_atomic_virial=True + ) pth_path = os.path.join(base_dir, "fparam_aparam.pth") pth_exported = False diff --git a/source/tests/infer/gen_model_devi.py b/source/tests/infer/gen_model_devi.py index bdceaa05aa..7e43add02e 100644 --- a/source/tests/infer/gen_model_devi.py +++ b/source/tests/infer/gen_model_devi.py @@ -99,7 +99,9 @@ def main(): } pt2_path = os.path.join(base_dir, f"model_devi_md{idx}.pt2") print(f"Exporting to {pt2_path} ...") # noqa: T201 - pt_expt_deserialize_to_file(pt2_path, copy.deepcopy(data)) + pt_expt_deserialize_to_file( + pt2_path, copy.deepcopy(data), do_atomic_virial=True + ) models.append(pt2_path) print("Export done.") # noqa: T201 diff --git a/source/tests/infer/gen_sea.py b/source/tests/infer/gen_sea.py index 84b78bbb78..67f8679175 100644 --- a/source/tests/infer/gen_sea.py +++ b/source/tests/infer/gen_sea.py @@ -57,7 +57,7 @@ def main(): } print(f"Exporting to {pt2_path} ...") # noqa: T201 - deserialize_to_file(pt2_path, data) + deserialize_to_file(pt2_path, data, do_atomic_virial=True) print("Done!") # noqa: T201 diff --git a/source/tests/infer/gen_spin.py b/source/tests/infer/gen_spin.py index 3053e0ad4f..d37e3207ff 100644 --- a/source/tests/infer/gen_spin.py +++ b/source/tests/infer/gen_spin.py @@ -112,7 +112,7 @@ def main(): convert_backend(INPUT=yaml_path, OUTPUT=pth_path) print(f"Converting to {pt2_path} ...") # noqa: T201 - convert_backend(INPUT=yaml_path, OUTPUT=pt2_path) + convert_backend(INPUT=yaml_path, OUTPUT=pt2_path, atomic_virial=True) print("Export done.") # noqa: T201 diff --git a/source/tests/infer/gen_spin_model_devi.py b/source/tests/infer/gen_spin_model_devi.py index 3dc5240e14..72792205e2 100644 --- a/source/tests/infer/gen_spin_model_devi.py +++ b/source/tests/infer/gen_spin_model_devi.py @@ -116,7 +116,7 @@ def main(): pt2_path = os.path.join(base_dir, f"deeppot_dpa_spin_md{idx}.pt2") pt2_paths.append(pt2_path) print(f"Converting to {pt2_path} ...") # noqa: T201 - convert_backend(INPUT=yaml_path, OUTPUT=pt2_path) + convert_backend(INPUT=yaml_path, OUTPUT=pt2_path, atomic_virial=True) print("Export done.") # noqa: T201 From b7509db581b4b050b33dfc87764e10a705a58388 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 20 Apr 2026 19:50:10 +0800 Subject: [PATCH 31/56] feat(pt2): make nlist nnei dimension dynamic in .pt2 export Instead of truncating the neighbor list to sum(sel) in C++ (which drops neighbors without distance sorting), make the nnei dimension dynamic in torch.export. The .pt2 model's compiled format_nlist now sorts by distance on-device and truncates to sum(sel), matching the .pth behavior. C++ createNlistTensor pads to at least nnei+1 columns so the format_nlist sort branch is always active. --- deepmd/pt_expt/utils/serialization.py | 27 ++++++++--- source/api_cc/src/DeepPotPTExpt.cc | 14 +++--- source/api_cc/src/DeepSpinPTExpt.cc | 13 +++--- source/api_cc/src/commonPTExpt.h | 45 ++++++++++--------- source/tests/pt_expt/export_helpers.py | 2 +- source/tests/pt_expt/model/test_ener_model.py | 2 +- .../pt_expt/model/test_export_pipeline.py | 8 +++- 7 files changed, 72 insertions(+), 39 deletions(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index ab772178fa..a0723ae5b8 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -141,6 +141,15 @@ def _make_sample_inputs( sel, distinguish_types=not mixed_types, ) + # Pad nlist with extra -1 columns so n_nnei > nnei at trace time. + # This ensures format_nlist's distance-sort branch is traced into the + # compiled graph, allowing the .pt2 model to handle variable-size + # neighbor lists at runtime (e.g. LAMMPS rcut + skin). + nnei = sum(sel) + n_pad = max(1, nnei // 4) # pad by ~25%, at least 1 + nlist = np.concatenate( + [nlist, -np.ones((nframes, nloc, n_pad), dtype=nlist.dtype)], axis=-1 + ) extended_coord = extended_coord.reshape(nframes, -1, 3) # Convert to torch tensors @@ -178,11 +187,12 @@ def _make_sample_inputs( def _build_dynamic_shapes( *sample_inputs: torch.Tensor | None, has_spin: bool = False, + model_nnei: int = 1, ) -> tuple: """Build dynamic shape specifications for torch.export. - Marks nframes, nloc and nall as dynamic dimensions so the exported - program handles arbitrary frame and atom counts. + Marks nframes, nloc, nall and nnei as dynamic dimensions so the exported + program handles arbitrary frame, atom and neighbor counts. Parameters ---------- @@ -190,12 +200,15 @@ def _build_dynamic_shapes( Sample inputs: either 6 tensors (non-spin) or 7 tensors (spin). has_spin : bool Whether the inputs include an extended_spin tensor. + model_nnei : int + The model's sum(sel). Used as the min for the dynamic nnei dim. Returns a tuple (not dict) to match positional args of the make_fx traced module, whose arg names may have suffixes like ``_1``. """ nframes_dim = torch.export.Dim("nframes", min=1) nall_dim = torch.export.Dim("nall", min=1) nloc_dim = torch.export.Dim("nloc", min=1) + nnei_dim = torch.export.Dim("nnei", min=max(2, model_nnei)) if has_spin: # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) @@ -208,7 +221,8 @@ def _build_dynamic_shapes( { 0: nframes_dim, 1: nloc_dim, - }, # nlist: (nframes, nloc, nnei) — nnei is static + 2: nnei_dim, + }, # nlist: (nframes, nloc, nnei) — nnei is dynamic {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) {0: nframes_dim} if fparam is not None else None, # fparam {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam @@ -223,7 +237,8 @@ def _build_dynamic_shapes( { 0: nframes_dim, 1: nloc_dim, - }, # nlist: (nframes, nloc, nnei) — nnei is static + 2: nnei_dim, + }, # nlist: (nframes, nloc, nnei) — nnei is dynamic {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) {0: nframes_dim} if fparam is not None else None, # fparam {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam @@ -493,7 +508,9 @@ def _trace_and_export( # graph. Exporting on CPU keeps devices consistent; we move the # ExportedProgram to the target device afterwards via the official # move_to_device_pass (avoids FakeTensor device-propagation errors). - dynamic_shapes = _build_dynamic_shapes(*sample_inputs, has_spin=is_spin) + dynamic_shapes = _build_dynamic_shapes( + *sample_inputs, has_spin=is_spin, model_nnei=sum(model.get_sel()) + ) exported = torch.export.export( traced, sample_inputs, diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index 0d52213157..f24c2a40e4 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -291,10 +291,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, .to(device); } } - // Build raw nlist tensor — the .pt2 model's compiled graph handles - // format_nlist (distance sort + truncation) internally on-device. - at::Tensor firstneigh_tensor = - createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. + at::Tensor firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) + .to(torch::kInt64) + .to(device); // Build fparam/aparam tensors (cast to float64 for the model) auto valuetype_options = std::is_same::value @@ -544,9 +545,10 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); - // Build raw nlist tensor — the .pt2 model handles format_nlist internally. + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw, nnei + 1).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index b4814a7001..ea2df3b587 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -319,9 +319,11 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, .to(device); } } - // Build raw nlist tensor — the .pt2 model handles format_nlist internally. - at::Tensor firstneigh_tensor = - createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. + at::Tensor firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) + .to(torch::kInt64) + .to(device); // Build fparam/aparam tensors auto valuetype_options = std::is_same::value @@ -596,9 +598,10 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); - // Build raw nlist tensor — the .pt2 model handles format_nlist internally. + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw, nnei + 1).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index 73a988ffbc..6036632fcc 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -438,38 +438,43 @@ inline std::string read_zip_entry(const std::string& zip_path, } // ============================================================================ -// Create raw neighbor list tensor (no type-sorting or sel-limiting). -// The .pt2 model's compiled graph already contains format_nlist which -// sorts by distance and truncates on-device, so CPU-side sorting is -// unnecessary. +// Create raw neighbor list tensor. +// The .pt2 compiled graph has format_nlist which sorts by distance and +// truncates on-device (GPU). The C++ side just flattens the jagged nlist +// and pads to at least `min_nnei` columns so format_nlist can work. // ============================================================================ /** - * @brief Convert a raw neighbor list to a flat tensor. - * - * Simply flattens the jagged nlist into a (1, nloc, nnei) int32 tensor. - * The .pt2 model handles format_nlist internally (distance sort + truncation). - */ -/** - * @brief Create nlist tensor with exactly `expected_nnei` neighbors per atom. + * @brief Flatten a jagged neighbor list into a tensor, padding to min_nnei. * * Each row in `data` may have a different number of neighbors. This function - * pads short rows with -1 and truncates long rows to produce a tensor of shape - * [1, nloc, expected_nnei]. The .pt2 compiled graph has nnei baked as a - * static dimension, so the tensor must match exactly. + * pads short rows with -1 to produce a tensor of shape + * [1, nloc, max(max_row_length, min_nnei)]. No truncation or distance + * sorting is done — the .pt2 model's compiled format_nlist handles that + * on-device. + * + * @param data Jagged neighbor list: data[i] contains neighbor indices + * for local atom i. + * @param min_nnei Minimum number of neighbor columns. Must be > sum(sel) + * so that format_nlist's sort branch executes. */ inline torch::Tensor createNlistTensor( - const std::vector>& data, int expected_nnei) { + const std::vector>& data, int min_nnei) { int nloc = static_cast(data.size()); - std::vector flat_data(static_cast(nloc) * expected_nnei, -1); + // Find max row length + int max_nnei = 0; + for (int ii = 0; ii < nloc; ++ii) { + max_nnei = std::max(max_nnei, static_cast(data[ii].size())); + } + int nnei = std::max(max_nnei, min_nnei); + std::vector flat_data(static_cast(nloc) * nnei, -1); for (int ii = 0; ii < nloc; ++ii) { - int ncopy = std::min(static_cast(data[ii].size()), expected_nnei); - for (int jj = 0; jj < ncopy; ++jj) { - flat_data[static_cast(ii) * expected_nnei + jj] = data[ii][jj]; + for (size_t jj = 0; jj < data[ii].size(); ++jj) { + flat_data[static_cast(ii) * nnei + jj] = data[ii][jj]; } } torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); - return flat_tensor.view({1, nloc, expected_nnei}); + return flat_tensor.view({1, nloc, nnei}); } } // namespace ptexpt diff --git a/source/tests/pt_expt/export_helpers.py b/source/tests/pt_expt/export_helpers.py index 453b9d0c01..ae4db82ddb 100644 --- a/source/tests/pt_expt/export_helpers.py +++ b/source/tests/pt_expt/export_helpers.py @@ -213,7 +213,7 @@ def model_forward_lower_export_round_trip( tracing_mode="symbolic", _allow_non_fake_inputs=True, ) - dynamic_shapes = _build_dynamic_shapes(*inputs_2f) + dynamic_shapes = _build_dynamic_shapes(*inputs_2f, model_nnei=sum(md_pt.get_sel())) exported_dyn = torch.export.export( traced_sym, inputs_2f, diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index cd520b5cc0..b91653a260 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -366,7 +366,7 @@ def test_forward_lower_exportable(self) -> None: _allow_non_fake_inputs=True, ) - dynamic_shapes = _build_dynamic_shapes(*inputs_5f) + dynamic_shapes = _build_dynamic_shapes(*inputs_5f, model_nnei=sum(md.get_sel())) exported_dyn = torch.export.export( traced_sym, inputs_5f, diff --git a/source/tests/pt_expt/model/test_export_pipeline.py b/source/tests/pt_expt/model/test_export_pipeline.py index 8bff3d7130..23e0a62a98 100644 --- a/source/tests/pt_expt/model/test_export_pipeline.py +++ b/source/tests/pt_expt/model/test_export_pipeline.py @@ -149,7 +149,13 @@ def test_export_pipeline(self, descriptor_type, with_fparam) -> None: # 6. Export with dynamic shapes (same as dp freeze) dynamic_shapes = _build_dynamic_shapes( - ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam, + aparam, + model_nnei=sum(model2.get_sel()), ) exported = torch.export.export( traced, From eec2528dbc627b099e6243c236b69b72a56a5d8d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 20 Apr 2026 20:21:11 +0800 Subject: [PATCH 32/56] fix(pt2): pad nlist in Python eval path for dynamic nnei The .pt2 model's compiled format_nlist expects n_nnei > sum(sel) to trigger the sort branch. Pad the nlist with one extra -1 column before calling the AOTInductor runner. --- deepmd/pt_expt/infer/deep_eval.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 19476a8537..a9be8e40a6 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -759,6 +759,19 @@ def _eval_model( # returning a dict just like the .pte module. # It also filters non-tensor args automatically, matching the # export-time signature where None args were excluded. + # Pad nlist with extra -1 column so n_nnei > nnei, ensuring + # format_nlist's compiled sort branch executes. + nlist_t = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], 1), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) model_ret = self._pt2_runner( ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t ) @@ -898,6 +911,19 @@ def _eval_model_spin( # Call the model with spin (7 args) if self._is_pt2: + # Pad nlist with extra -1 column so n_nnei > nnei, ensuring + # format_nlist's compiled sort branch executes. + nlist_t = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], 1), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) model_ret = self._pt2_runner( ext_coord_t, ext_atype_t, From 217a58761f86ccedde46fa36e2dc1251ff050193 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 20 Apr 2026 21:30:29 +0800 Subject: [PATCH 33/56] fix(pt2): move atomic virial check before run_model and reject unsupported --atomic-virial Move the do_atomic_virial check before run_model in both DeepPotPTExpt and DeepSpinPTExpt so we fail fast without wasting GPU computation. Also reject --atomic-virial flag in convert_backend when the output backend doesn't support it. --- deepmd/entrypoints/convert_backend.py | 4 ++++ source/api_cc/src/DeepPotPTExpt.cc | 32 +++++++++++++++------------ source/api_cc/src/DeepSpinPTExpt.cc | 32 +++++++++++++++------------ 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/deepmd/entrypoints/convert_backend.py b/deepmd/entrypoints/convert_backend.py index db593507ca..b0ca2f1d97 100644 --- a/deepmd/entrypoints/convert_backend.py +++ b/deepmd/entrypoints/convert_backend.py @@ -39,4 +39,8 @@ def convert_backend( if "do_atomic_virial" in sig.parameters: out_hook(OUTPUT, data, do_atomic_virial=atomic_virial) else: + if atomic_virial: + raise ValueError( + "--atomic-virial is only supported for pt_expt .pt2/.pte outputs" + ) out_hook(OUTPUT, data) diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index f24c2a40e4..fa7e6b009c 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -337,6 +337,14 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } + // Fail fast: check atomic virial availability before running the model + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } + // Run the .pt2 model auto flat_outputs = run_model(coord_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, fparam_tensor, aparam_tensor); @@ -373,12 +381,6 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, nall_real); if (atomic) { - if (!do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } // Extract atom_energy: energy (nf, nloc, 1) torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); @@ -595,11 +597,19 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // 5. Run the .pt2 model + // 5. Fail fast: check atomic virial availability before running the model + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } + + // 6. Run the .pt2 model auto flat_outputs = run_model(coord_Tensor, atype_Tensor, nlist_tensor, mapping_tensor, fparam_tensor, aparam_tensor); - // 6. Map flat outputs to internal keys + // 7. Map flat outputs to internal keys std::map output_map; extract_outputs(output_map, flat_outputs); @@ -627,12 +637,6 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, fold_back(force, extended_force, mapping_vec, nloc, nall, 3, nframes); if (atomic) { - if (!do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } // atom_energy: energy (nf, nloc, 1) — already on local atoms torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index ea2df3b587..a3271a3a7b 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -365,6 +365,14 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } + // Fail fast: check atomic virial availability before running the model + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } + // Run the .pt2 model (7 args for spin) auto flat_outputs = run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, @@ -410,12 +418,6 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, fwd_map.size(), nall_real); if (atomic) { - if (!do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); torch::Tensor cpu_atom_energy_ = atom_energy_tensor.to(torch::kCPU); @@ -648,12 +650,20 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // 5. Run the .pt2 model (7 args for spin) + // 5. Fail fast: check atomic virial availability before running the model + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial is not available in this .pt2 model " + "(exported without --atomic-virial). " + "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); + } + + // 6. Run the .pt2 model (7 args for spin) auto flat_outputs = run_model(coord_Tensor, atype_Tensor, spin_Tensor, nlist_tensor, mapping_tensor, fparam_tensor, aparam_tensor); - // 6. Extract outputs + // 7. Extract outputs std::map output_map; extract_outputs(output_map, flat_outputs); @@ -689,12 +699,6 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, fold_back(force_mag, extended_force_mag, mapping_vec, nloc, nall, 3, nframes); if (atomic) { - if (!do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } // atom_energy: energy (nf, nloc, 1) torch::Tensor atom_energy_tensor = output_map["energy"].view({-1}).to(floatType); From 8a9fe6338414280e1ea5bb9d5502d6ff41289547 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 06:50:07 +0800 Subject: [PATCH 34/56] fix(pt2): move nlist padding inside traced fn and strip shape assertions Move nlist padding (+1 column of -1s) inside the `fn` closure in both `make_model.forward_common_lower_exportable` and `SpinModel.forward_common_lower_exportable`, making it part of the traced graph. This fixes proxy tensor shape mismatches from make_fx and removes the need for external padding in deep_eval.py. Also apply `_strip_shape_assertions` unconditionally (not just spin models) to remove spurious torch.export guards like Ne(nnei, sum(sel)). Export tests that verify atomic virial now pass `do_atomic_virial=True` to `deserialize_to_file` so the exported model includes the correction. --- deepmd/pt_expt/infer/deep_eval.py | 26 ----------- deepmd/pt_expt/model/make_model.py | 37 +++++++++++---- deepmd/pt_expt/model/spin_model.py | 37 +++++++++++---- deepmd/pt_expt/utils/serialization.py | 47 +++++++++----------- source/tests/pt_expt/export_helpers.py | 27 ++++++++++- source/tests/pt_expt/infer/test_deep_eval.py | 14 +++--- 6 files changed, 110 insertions(+), 78 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index a9be8e40a6..19476a8537 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -759,19 +759,6 @@ def _eval_model( # returning a dict just like the .pte module. # It also filters non-tensor args automatically, matching the # export-time signature where None args were excluded. - # Pad nlist with extra -1 column so n_nnei > nnei, ensuring - # format_nlist's compiled sort branch executes. - nlist_t = torch.cat( - [ - nlist_t, - -torch.ones( - (*nlist_t.shape[:2], 1), - dtype=nlist_t.dtype, - device=nlist_t.device, - ), - ], - dim=-1, - ) model_ret = self._pt2_runner( ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t ) @@ -911,19 +898,6 @@ def _eval_model_spin( # Call the model with spin (7 args) if self._is_pt2: - # Pad nlist with extra -1 column so n_nnei > nnei, ensuring - # format_nlist's compiled sort branch executes. - nlist_t = torch.cat( - [ - nlist_t, - -torch.ones( - (*nlist_t.shape[:2], 1), - dtype=nlist_t.dtype, - device=nlist_t.device, - ), - ], - dim=-1, - ) model_ret = self._pt2_runner( ext_coord_t, ext_atype_t, diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 4bd9792420..f986c840c4 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -346,6 +346,21 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + # Pad nlist with one extra -1 column inside the traced function. + # This ensures n_nnei > sum(sel), forcing the sort branch in + # _format_nlist. The padding becomes part of the compiled graph, + # so callers never need to pad externally. + nlist = torch.cat( + [ + nlist, + -torch.ones( + (*nlist.shape[:2], 1), + dtype=nlist.dtype, + device=nlist.device, + ), + ], + dim=-1, + ) return model.forward_common_lower( extended_coord, extended_atype, @@ -356,13 +371,19 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - nlist, - mapping, - fparam, - aparam, - ) + # Force format_nlist to always use the sort branch during tracing. + model.need_sorted_nlist_for_lower = lambda: True + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + finally: + del model.need_sorted_nlist_for_lower + return traced return CM diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py index 70f41f0701..92b9bfa095 100644 --- a/deepmd/pt_expt/model/spin_model.py +++ b/deepmd/pt_expt/model/spin_model.py @@ -96,6 +96,18 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + # Pad nlist inside traced function (see make_model.py for rationale). + nlist = torch.cat( + [ + nlist, + -torch.ones( + (*nlist.shape[:2], 1), + dtype=nlist.dtype, + device=nlist.device, + ), + ], + dim=-1, + ) return model.forward_common_lower( extended_coord, extended_atype, @@ -107,15 +119,22 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - extended_spin, - nlist, - mapping, - fparam, - aparam, - ) + # Force format_nlist to always use the sort branch during tracing. + backbone = model.backbone_model + backbone.need_sorted_nlist_for_lower = lambda: True + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) + finally: + del backbone.need_sorted_nlist_for_lower + return traced def forward_common_lower( self, *args: Any, **kwargs: Any diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index a0723ae5b8..165066664c 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -17,23 +17,23 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: - """Remove shape-guard assertion nodes from a spin model's exported graph. + """Remove shape-guard assertion nodes from an exported graph. ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape - relationships discovered during tracing. For the spin model, the atom- - doubling logic creates slice patterns that depend on ``(nall - nloc)``, - producing guards like ``Ne(nall, nloc)``. These guards are spurious: the - model computes correct results even when ``nall == nloc`` (NoPBC, no ghost - atoms). - - This function is **only called for spin models** (guarded by ``if is_spin`` - in ``_trace_and_export``). The assertion messages use opaque symbolic - variable names (e.g. ``Ne(s22, s96)``) rather than human-readable names, - so filtering by message content is not reliable. Since + relationships discovered during tracing. These guards can be spurious: + + * **Spin models**: atom-doubling logic creates slice patterns that depend + on ``(nall - nloc)``, producing guards like ``Ne(nall, nloc)``. + * **All models**: the nlist padding inside ``forward_common_lower_exportable`` + and the subsequent sort/truncate in ``_format_nlist`` can produce guards + like ``Ne(nnei, sum(sel))``. These are spurious because the compiled + graph handles any ``nnei >= sum(sel)`` correctly. + + The assertion messages use opaque symbolic variable names (e.g. + ``Ne(s22, s96)``) rather than human-readable names, so filtering by + message content is not reliable. Since ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape - guards into these deferred assertions, and the only shape relationships in - the spin model involve nall/nloc, removing all of them is safe in this - context. + guards into these deferred assertions, removing all of them is safe. """ graph = graph_module.graph for node in list(graph.nodes): @@ -141,10 +141,8 @@ def _make_sample_inputs( sel, distinguish_types=not mixed_types, ) - # Pad nlist with extra -1 columns so n_nnei > nnei at trace time. - # This ensures format_nlist's distance-sort branch is traced into the - # compiled graph, allowing the .pt2 model to handle variable-size - # neighbor lists at runtime (e.g. LAMMPS rcut + skin). + # Pad nlist so nnei > sum(sel) in the sample tensors. + # This prevents torch.export from specializing nnei to sum(sel). nnei = sum(sel) n_pad = max(1, nnei // 4) # pad by ~25%, at least 1 nlist = np.concatenate( @@ -519,15 +517,10 @@ def _trace_and_export( prefer_deferred_runtime_asserts_over_guards=True, ) - if is_spin: - # torch.export re-introduces shape-guard assertions even when - # the make_fx graph has none. The spin model's atom-doubling - # logic creates slice patterns that depend on (nall - nloc), - # producing guards like Ne(nall, nloc). These guards are - # spurious: the model is correct when nall == nloc (NoPBC). - # Strip them from the exported graph so the model can be - # used with any valid nall >= nloc. - _strip_shape_assertions(exported.graph_module) + # torch.export inserts _assert_scalar guards for symbolic shape + # relationships (e.g. Ne(nnei, sum(sel)), Ne(nall, nloc)). These + # are spurious — the model handles any valid input shapes correctly. + _strip_shape_assertions(exported.graph_module) # 7. Move the exported program to the target device if needed. if target_device.type != "cpu": diff --git a/source/tests/pt_expt/export_helpers.py b/source/tests/pt_expt/export_helpers.py index ae4db82ddb..beb286fe25 100644 --- a/source/tests/pt_expt/export_helpers.py +++ b/source/tests/pt_expt/export_helpers.py @@ -70,6 +70,12 @@ def export_save_load_and_compare( strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) + # Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel))) + from deepmd.pt_expt.utils.serialization import ( + _strip_shape_assertions, + ) + + _strip_shape_assertions(exported.graph_module) # 4. .pte save -> load round-trip with tempfile.NamedTemporaryFile(suffix=".pte") as f: @@ -199,9 +205,22 @@ def model_forward_lower_export_round_trip( ) # 5. Symbolic trace + dynamic shapes + .pte round-trip + # Pad nlist with extra -1 columns so nnei > sum(sel) in the sample. + # This prevents torch.export from specializing nnei to sum(sel). + nlist_padded = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], max(1, nlist_t.shape[2] // 4)), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) inputs_2f = tuple( torch.cat([t, t], dim=0) if t is not None else None - for t in (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) + for t in (ext_coord, ext_atype, nlist_padded, mapping_t, fparam, aparam) ) traced_sym = md_pt.forward_lower_exportable( inputs_2f[0], @@ -221,6 +240,12 @@ def model_forward_lower_export_round_trip( strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) + # Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel))) + from deepmd.pt_expt.utils.serialization import ( + _strip_shape_assertions, + ) + + _strip_shape_assertions(exported_dyn.graph_module) with tempfile.NamedTemporaryFile(suffix=".pte") as f: torch.export.save(exported_dyn, f.name) loaded = torch.export.load(f.name).module() diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 6797fa2c03..5a1a315b7c 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -63,11 +63,11 @@ def setUpClass(cls) -> None: cls.model = cls.model.to(torch.float64) cls.model.eval() - # Serialize and save to .pte + # Serialize and save to .pte (with atomic virial for test_dynamic_shapes) cls.model_data = {"model": cls.model.serialize()} cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) cls.tmpfile.close() - deserialize_to_file(cls.tmpfile.name, cls.model_data) + deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True) # Create DeepPot for testing cls.dp = DeepPot(cls.tmpfile.name) @@ -547,14 +547,14 @@ def setUpClass(cls) -> None: # compilation (tests/pt/__init__.py sets it to "cuda:9999999"). torch.set_default_device(None) try: - deserialize_to_file(cls.tmpfile.name, cls.model_data) + deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True) finally: torch.set_default_device("cuda:9999999") # Also save to .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) cls.pte_tmpfile.close() - deserialize_to_file(cls.pte_tmpfile.name, cls.model_data) + deserialize_to_file(cls.pte_tmpfile.name, cls.model_data, do_atomic_virial=True) # Create DeepPot for .pt2 cls.dp = DeepPot(cls.tmpfile.name) @@ -1070,7 +1070,7 @@ def setUpClass(cls) -> None: cls.model_data = {"model": cls.model.serialize()} cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) cls.tmpfile.close() - deserialize_to_file(cls.tmpfile.name, cls.model_data) + deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True) cls.dp = DeepPot(cls.tmpfile.name) @@ -1187,14 +1187,14 @@ def setUpClass(cls) -> None: cls.tmpfile.close() torch.set_default_device(None) try: - deserialize_to_file(cls.tmpfile.name, cls.model_data) + deserialize_to_file(cls.tmpfile.name, cls.model_data, do_atomic_virial=True) finally: torch.set_default_device("cuda:9999999") # Also save .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) cls.pte_tmpfile.close() - deserialize_to_file(cls.pte_tmpfile.name, cls.model_data) + deserialize_to_file(cls.pte_tmpfile.name, cls.model_data, do_atomic_virial=True) cls.dp = DeepPot(cls.tmpfile.name) cls.dp_pte = DeepPot(cls.pte_tmpfile.name) From 711a1f4ce36981a1013c472427c83213643d2eda Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 08:06:19 +0800 Subject: [PATCH 35/56] fix(test): export test models with atomic_virial=True for .pte/.pt2 The test_models.py tests compare per-atom virial against reference values that include the atomic virial correction. Since convert_backend now defaults to atomic_virial=False for performance, test model generation must explicitly request atomic_virial=True. --- source/tests/infer/case.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/source/tests/infer/case.py b/source/tests/infer/case.py index 828974c6e6..84acc5bad5 100644 --- a/source/tests/infer/case.py +++ b/source/tests/infer/case.py @@ -173,7 +173,12 @@ def get_model(self, suffix: str, out_file: str | None = None) -> str: out_file = tempfile.NamedTemporaryFile( suffix=suffix, dir=tempdir.name, delete=False, prefix=self.key + "_" ).name - convert_backend(INPUT=self.filename, OUTPUT=out_file) + # For .pte/.pt2, export with atomic virial so tests can verify + # per-atom virial against reference values. + kwargs: dict = {} + if suffix in (".pte", ".pt2"): + kwargs["atomic_virial"] = True + convert_backend(INPUT=self.filename, OUTPUT=out_file, **kwargs) return out_file From acba914232d7ca8330dc5131ab1f0f428188f0f1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 12:28:25 +0800 Subject: [PATCH 36/56] perf(pt2): cache firstneigh_tensor across timesteps Move createNlistTensor() inside the `if (ago == 0)` block so the nlist tensor is only rebuilt when LAMMPS updates the neighbor list, matching the existing caching pattern for mapping_tensor. Previously, the nlist was rebuilt every step: flatten jagged list, allocate std::vector, copy into tensor, cast int32->int64, transfer to device. This was the dominant CPU+H2D cost outside the model itself. --- source/api_cc/include/DeepPotPTExpt.h | 3 ++- source/api_cc/include/DeepSpinPTExpt.h | 3 ++- source/api_cc/src/DeepPotPTExpt.cc | 11 ++++++----- source/api_cc/src/DeepSpinPTExpt.cc | 11 ++++++----- 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/source/api_cc/include/DeepPotPTExpt.h b/source/api_cc/include/DeepPotPTExpt.h index 5cc9eca06e..1bcf44a885 100644 --- a/source/api_cc/include/DeepPotPTExpt.h +++ b/source/api_cc/include/DeepPotPTExpt.h @@ -211,7 +211,8 @@ class DeepPotPTExpt : public DeepPotBackend { bool do_atomic_virial; // whether model was exported with atomic virial corr int nnei; // expected nlist nnei dimension (= sum(sel)) NeighborListData nlist_data; - at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) + at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) + at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) std::unique_ptr loader; /** diff --git a/source/api_cc/include/DeepSpinPTExpt.h b/source/api_cc/include/DeepSpinPTExpt.h index f73e7c4cfd..47b38767d4 100644 --- a/source/api_cc/include/DeepSpinPTExpt.h +++ b/source/api_cc/include/DeepSpinPTExpt.h @@ -186,7 +186,8 @@ class DeepSpinPTExpt : public DeepSpinBackend { bool do_atomic_virial; // whether model was exported with atomic virial corr int nnei; // expected nlist nnei dimension (= sum(sel)) NeighborListData nlist_data; - at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) + at::Tensor mapping_tensor; // cached mapping tensor (LAMMPS path) + at::Tensor firstneigh_tensor; // cached nlist tensor (LAMMPS path) std::unique_ptr loader; std::vector run_model(const torch::Tensor& coord, diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index fa7e6b009c..4f16aacf04 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -290,12 +290,13 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, .clone() .to(device); } + + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. + firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) + .to(torch::kInt64) + .to(device); } - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. - at::Tensor firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) - .to(torch::kInt64) - .to(device); // Build fparam/aparam tensors (cast to float64 for the model) auto valuetype_options = std::is_same::value diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index a3271a3a7b..c002a1ec8e 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -318,12 +318,13 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, .clone() .to(device); } + + // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 + // columns so the .pt2 model's compiled format_nlist sorts on-device. + firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) + .to(torch::kInt64) + .to(device); } - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. - at::Tensor firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) - .to(torch::kInt64) - .to(device); // Build fparam/aparam tensors auto valuetype_options = std::is_same::value From 9426572243b74b8485c805434d278cce81140ae3 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 12:54:52 +0800 Subject: [PATCH 37/56] test(pt2): add regression test for oversized nlist with distance sorting Adds test_oversized_nlist to both .pte and .pt2 test classes. The test shuffles neighbor columns so real neighbors appear past sum(sel), forcing format_nlist to sort by distance rather than just truncating. This catches the bug fixed in b7509db5 where missing distance sort would produce wrong results with oversized neighbor lists. --- source/tests/pt_expt/infer/test_deep_eval.py | 136 +++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 5a1a315b7c..1cc4d074c9 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -273,6 +273,74 @@ def test_dynamic_shapes(self) -> None: err_msg=f"nloc={nloc}, key={key}", ) + def test_oversized_nlist(self) -> None: + """Test that the exported model handles nlist with more neighbors than nnei. + + In LAMMPS, the neighbor list is built with rcut + skin, so atoms + typically have more neighbors than sum(sel). The compiled + format_nlist must sort by distance and truncate correctly. + + The test shuffles neighbor columns so that some real (close) neighbors + appear AFTER sum(sel), forcing format_nlist to distance-sort rather + than just truncating by column index. A model that truncates without + sorting would pick wrong neighbors and produce different energy/force. + """ + exported = torch.export.load(self.tmpfile.name) + exported_mod = exported.module() + + nnei = sum(self.sel) # model's expected neighbor count + nloc = 5 + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( + self.model, nloc=nloc + ) + + # Pad nlist with -1 columns, then shuffle column order so real + # neighbors are interspersed with absent ones beyond column sum(sel). + n_extra = nnei # double the nlist width + nlist_padded = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], n_extra), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) + # Shuffle columns: move some real neighbors past sum(sel) boundary. + rng = np.random.default_rng(42) + perm = rng.permutation(nlist_padded.shape[-1]) + nlist_shuffled = nlist_padded[:, :, perm] + assert nlist_shuffled.shape[-1] > nnei + + # Reference: eager model with shuffled oversized nlist + ec = ext_coord.detach().requires_grad_(True) + ref_ret = self.model.forward_common_lower( + ec, + ext_atype, + nlist_shuffled, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + # Exported model with same shuffled oversized nlist + pte_ret = exported_mod( + ext_coord, ext_atype, nlist_shuffled, mapping_t, fparam, aparam + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ref_ret[key] is not None and key in pte_ret: + np.testing.assert_allclose( + ref_ret[key].detach().cpu().numpy(), + pte_ret[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"oversized nlist, key={key}", + ) + def test_serialize_round_trip(self) -> None: """Test .pte → serialize_from_file → deserialize → model gives same outputs.""" loaded_data = serialize_from_file(self.tmpfile.name) @@ -768,6 +836,74 @@ def test_no_pbc(self) -> None: v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) + def test_oversized_nlist(self) -> None: + """Test that the exported model handles nlist with more neighbors than nnei. + + In LAMMPS, the neighbor list is built with rcut + skin, so atoms + typically have more neighbors than sum(sel). The compiled + format_nlist must sort by distance and truncate correctly. + + The test shuffles neighbor columns so that some real (close) neighbors + appear AFTER sum(sel), forcing format_nlist to distance-sort rather + than just truncating by column index. A model that truncates without + sorting would pick wrong neighbors and produce different energy/force. + """ + exported = torch.export.load(self.pte_tmpfile.name) + exported_mod = exported.module() + + nnei = sum(self.sel) # model's expected neighbor count + nloc = 5 + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( + self.model, nloc=nloc + ) + + # Pad nlist with -1 columns, then shuffle column order so real + # neighbors are interspersed with absent ones beyond column sum(sel). + n_extra = nnei # double the nlist width + nlist_padded = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], n_extra), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) + # Shuffle columns: move some real neighbors past sum(sel) boundary. + rng = np.random.default_rng(42) + perm = rng.permutation(nlist_padded.shape[-1]) + nlist_shuffled = nlist_padded[:, :, perm] + assert nlist_shuffled.shape[-1] > nnei + + # Reference: eager model with shuffled oversized nlist + ec = ext_coord.detach().requires_grad_(True) + ref_ret = self.model.forward_common_lower( + ec, + ext_atype, + nlist_shuffled, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + # Exported model with same shuffled oversized nlist + pte_ret = exported_mod( + ext_coord, ext_atype, nlist_shuffled, mapping_t, fparam, aparam + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ref_ret[key] is not None and key in pte_ret: + np.testing.assert_allclose( + ref_ret[key].detach().cpu().numpy(), + pte_ret[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"oversized nlist, key={key}", + ) + def test_serialize_round_trip(self) -> None: """Test .pt2 → serialize_from_file → deserialize → model gives same outputs.""" loaded_data = serialize_from_file(self.tmpfile.name) From 97870bf8b62d0b9c14db08eeff4b81309372796c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 13:47:55 +0800 Subject: [PATCH 38/56] test(pt2): add regression test for oversized nlist with distance sorting Adds test_oversized_nlist to both .pte and .pt2 test classes. The test verifies two things: 1. The exported model correctly handles an oversized, shuffled nlist (sorts by distance, keeps closest sum(sel) neighbors). 2. Naive truncation of the shuffled nlist to sum(sel) columns (simulating the old C++ createNlistTensor bug) produces wrong energy, proving that close neighbors beyond sum(sel) are lost without distance sorting. --- source/tests/pt_expt/infer/test_deep_eval.py | 77 +++++++++++++++++--- 1 file changed, 67 insertions(+), 10 deletions(-) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 1cc4d074c9..f96f08ae28 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -280,10 +280,16 @@ def test_oversized_nlist(self) -> None: typically have more neighbors than sum(sel). The compiled format_nlist must sort by distance and truncate correctly. - The test shuffles neighbor columns so that some real (close) neighbors - appear AFTER sum(sel), forcing format_nlist to distance-sort rather - than just truncating by column index. A model that truncates without - sorting would pick wrong neighbors and produce different energy/force. + The test verifies two things: + + 1. **Correctness**: the exported model with an oversized, shuffled + nlist produces the same results as the eager model (both sort by + distance and keep the closest sum(sel) neighbors). + + 2. **Naive truncation produces wrong results**: simply taking the + first sum(sel) columns of the shuffled nlist (simulating a C++ + implementation that truncates without sorting) gives a different + energy. This proves the distance sort is necessary. """ exported = torch.export.load(self.tmpfile.name) exported_mod = exported.module() @@ -314,6 +320,7 @@ def test_oversized_nlist(self) -> None: nlist_shuffled = nlist_padded[:, :, perm] assert nlist_shuffled.shape[-1] > nnei + # --- Part 1: exported model sorts correctly --- # Reference: eager model with shuffled oversized nlist ec = ext_coord.detach().requires_grad_(True) ref_ret = self.model.forward_common_lower( @@ -341,6 +348,31 @@ def test_oversized_nlist(self) -> None: err_msg=f"oversized nlist, key={key}", ) + # --- Part 2: naive truncation gives wrong results --- + # Simulate the old C++ bug: truncate shuffled nlist to sum(sel) columns + # without distance sorting. Some close neighbors that were shuffled + # beyond column sum(sel) are lost, producing wrong energy. + nlist_truncated = nlist_shuffled[:, :, :nnei] + ec2 = ext_coord.detach().requires_grad_(True) + trunc_ret = self.model.forward_common_lower( + ec2, + ext_atype, + nlist_truncated, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + # The truncated result MUST differ from the correctly sorted result, + # proving that naive truncation discards real neighbors. + e_ref = ref_ret["energy_redu"].detach().cpu().numpy() + e_trunc = trunc_ret["energy_redu"].detach().cpu().numpy() + assert not np.allclose(e_ref, e_trunc, rtol=1e-10, atol=1e-10), ( + "Naive truncation of shuffled nlist should give different energy, " + "but got the same result. The test data may not have enough " + "neighbors shuffled beyond sum(sel) to trigger the bug." + ) + def test_serialize_round_trip(self) -> None: """Test .pte → serialize_from_file → deserialize → model gives same outputs.""" loaded_data = serialize_from_file(self.tmpfile.name) @@ -843,10 +875,16 @@ def test_oversized_nlist(self) -> None: typically have more neighbors than sum(sel). The compiled format_nlist must sort by distance and truncate correctly. - The test shuffles neighbor columns so that some real (close) neighbors - appear AFTER sum(sel), forcing format_nlist to distance-sort rather - than just truncating by column index. A model that truncates without - sorting would pick wrong neighbors and produce different energy/force. + The test verifies two things: + + 1. **Correctness**: the exported model with an oversized, shuffled + nlist produces the same results as the eager model (both sort by + distance and keep the closest sum(sel) neighbors). + + 2. **Naive truncation produces wrong results**: simply taking the + first sum(sel) columns of the shuffled nlist (simulating a C++ + implementation that truncates without sorting) gives a different + energy. This proves the distance sort is necessary. """ exported = torch.export.load(self.pte_tmpfile.name) exported_mod = exported.module() @@ -877,7 +915,7 @@ def test_oversized_nlist(self) -> None: nlist_shuffled = nlist_padded[:, :, perm] assert nlist_shuffled.shape[-1] > nnei - # Reference: eager model with shuffled oversized nlist + # --- Part 1: exported model sorts correctly --- ec = ext_coord.detach().requires_grad_(True) ref_ret = self.model.forward_common_lower( ec, @@ -889,7 +927,6 @@ def test_oversized_nlist(self) -> None: do_atomic_virial=True, ) - # Exported model with same shuffled oversized nlist pte_ret = exported_mod( ext_coord, ext_atype, nlist_shuffled, mapping_t, fparam, aparam ) @@ -904,6 +941,26 @@ def test_oversized_nlist(self) -> None: err_msg=f"oversized nlist, key={key}", ) + # --- Part 2: naive truncation gives wrong results --- + nlist_truncated = nlist_shuffled[:, :, :nnei] + ec2 = ext_coord.detach().requires_grad_(True) + trunc_ret = self.model.forward_common_lower( + ec2, + ext_atype, + nlist_truncated, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + e_ref = ref_ret["energy_redu"].detach().cpu().numpy() + e_trunc = trunc_ret["energy_redu"].detach().cpu().numpy() + assert not np.allclose(e_ref, e_trunc, rtol=1e-10, atol=1e-10), ( + "Naive truncation of shuffled nlist should give different energy, " + "but got the same result. The test data may not have enough " + "neighbors shuffled beyond sum(sel) to trigger the bug." + ) + def test_serialize_round_trip(self) -> None: """Test .pt2 → serialize_from_file → deserialize → model gives same outputs.""" loaded_data = serialize_from_file(self.tmpfile.name) From 755c2d3c53ee7b9aa0118c9767f37997f304cc18 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 14:42:44 +0800 Subject: [PATCH 39/56] refactor(pt2): remove need_sorted_nlist_for_lower monkey-patch The fn-internal +1 pad already guarantees n_nnei > sum(sel) at trace time, so _format_nlist's distance-sort branch is always traced. The monkey-patch (mutating the model, try/finally cleanup) was redundant. Two mechanisms remain and suffice: - Inner pad (+1 column) inside the traced fn: ensures sort branch is in the compiled graph and fires at runtime. - C++ createNlistTensor pads to nnei+1: satisfies torch.export Dim guard at the model input level. --- deepmd/pt_expt/model/make_model.py | 23 ++++++++++------------- deepmd/pt_expt/model/spin_model.py | 26 +++++++++++--------------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index f986c840c4..547cf9663d 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -371,19 +371,16 @@ def fn( do_atomic_virial=do_atomic_virial, ) - # Force format_nlist to always use the sort branch during tracing. - model.need_sorted_nlist_for_lower = lambda: True - try: - traced = make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - nlist, - mapping, - fparam, - aparam, - ) - finally: - del model.need_sorted_nlist_for_lower + # The +1 pad inside fn guarantees n_nnei > sum(sel) at trace + # time, so _format_nlist's distance-sort branch is always traced. + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) return traced return CM diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py index 92b9bfa095..9dfb616e29 100644 --- a/deepmd/pt_expt/model/spin_model.py +++ b/deepmd/pt_expt/model/spin_model.py @@ -119,21 +119,17 @@ def fn( do_atomic_virial=do_atomic_virial, ) - # Force format_nlist to always use the sort branch during tracing. - backbone = model.backbone_model - backbone.need_sorted_nlist_for_lower = lambda: True - try: - traced = make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - extended_spin, - nlist, - mapping, - fparam, - aparam, - ) - finally: - del backbone.need_sorted_nlist_for_lower + # The +1 pad inside fn guarantees n_nnei > sum(sel) at trace + # time, so _format_nlist's distance-sort branch is always traced. + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) return traced def forward_common_lower( From 444b4e57ac548fefb388b02111f48aca11854025 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 21:21:17 +0800 Subject: [PATCH 40/56] refactor(pt2): remove redundant C++ nnei+1 nlist padding The fn-internal +1 pad inside the traced graph already ensures format_nlist's distance-sort branch fires at runtime. The C++ side no longer needs to pad nlist to nnei+1. Simplify createNlistTensor to just flatten the jagged nlist without a min_nnei parameter. Also add test_oversized_nlist to TestSpinEnerModelExportable to verify the spin model export path handles oversized nlist correctly. --- source/api_cc/src/DeepPotPTExpt.cc | 13 +- source/api_cc/src/DeepSpinPTExpt.cc | 13 +- source/api_cc/src/commonPTExpt.h | 30 ++--- .../pt_expt/model/test_spin_ener_model.py | 124 ++++++++++++++++++ 4 files changed, 148 insertions(+), 32 deletions(-) diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index 4f16aacf04..86301364bc 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -291,11 +291,9 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, .to(device); } - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. - firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) - .to(torch::kInt64) - .to(device); + // Flatten raw nlist — the .pt2 model sorts by distance on-device. + firstneigh_tensor = + createNlistTensor(nlist_data.jlist).to(torch::kInt64).to(device); } // Build fparam/aparam tensors (cast to float64 for the model) @@ -548,10 +546,9 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. + // Flatten raw nlist — the .pt2 model sorts by distance on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw, nnei + 1).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index c002a1ec8e..dbaaa89e06 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -319,11 +319,9 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, .to(device); } - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. - firstneigh_tensor = createNlistTensor(nlist_data.jlist, nnei + 1) - .to(torch::kInt64) - .to(device); + // Flatten raw nlist — the .pt2 model sorts by distance on-device. + firstneigh_tensor = + createNlistTensor(nlist_data.jlist).to(torch::kInt64).to(device); } // Build fparam/aparam tensors @@ -601,10 +599,9 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, torch::from_blob(atype_64.data(), {1, nall}, int_options) .clone() .to(device); - // Build raw nlist tensor — pass all neighbors, padded to at least nnei+1 - // columns so the .pt2 model's compiled format_nlist sorts on-device. + // Flatten raw nlist — the .pt2 model sorts by distance on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw, nnei + 1).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index 6036632fcc..152243a655 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -439,34 +439,32 @@ inline std::string read_zip_entry(const std::string& zip_path, // ============================================================================ // Create raw neighbor list tensor. -// The .pt2 compiled graph has format_nlist which sorts by distance and -// truncates on-device (GPU). The C++ side just flattens the jagged nlist -// and pads to at least `min_nnei` columns so format_nlist can work. +// The .pt2 compiled graph already contains format_nlist (distance sort + +// truncation) and an internal +1 pad that guarantees the sort branch fires. +// The C++ side just flattens the jagged nlist into a rectangular tensor. // ============================================================================ /** - * @brief Flatten a jagged neighbor list into a tensor, padding to min_nnei. + * @brief Flatten a jagged neighbor list into a tensor. * * Each row in `data` may have a different number of neighbors. This function * pads short rows with -1 to produce a tensor of shape - * [1, nloc, max(max_row_length, min_nnei)]. No truncation or distance - * sorting is done — the .pt2 model's compiled format_nlist handles that - * on-device. + * [1, nloc, max_row_length]. No truncation or distance sorting is done — + * the .pt2 model's compiled format_nlist handles that on-device. * - * @param data Jagged neighbor list: data[i] contains neighbor indices - * for local atom i. - * @param min_nnei Minimum number of neighbor columns. Must be > sum(sel) - * so that format_nlist's sort branch executes. + * @param data Jagged neighbor list: data[i] contains neighbor indices + * for local atom i. */ inline torch::Tensor createNlistTensor( - const std::vector>& data, int min_nnei) { + const std::vector>& data) { int nloc = static_cast(data.size()); - // Find max row length - int max_nnei = 0; + int nnei = 0; for (int ii = 0; ii < nloc; ++ii) { - max_nnei = std::max(max_nnei, static_cast(data[ii].size())); + nnei = std::max(nnei, static_cast(data[ii].size())); + } + if (nnei == 0) { + nnei = 1; // at least 1 column to avoid empty tensor } - int nnei = std::max(max_nnei, min_nnei); std::vector flat_data(static_cast(nloc) * nnei, -1); for (int ii = 0; ii < nloc; ++ii) { for (size_t jj = 0; jj < data[ii].size(); ++jj) { diff --git a/source/tests/pt_expt/model/test_spin_ener_model.py b/source/tests/pt_expt/model/test_spin_ener_model.py index 5dd6dbecf2..f7f96392d3 100644 --- a/source/tests/pt_expt/model/test_spin_ener_model.py +++ b/source/tests/pt_expt/model/test_spin_ener_model.py @@ -359,6 +359,56 @@ def setUp(self) -> None: class TestSpinEnerModelExportable(unittest.TestCase): + @staticmethod + def _build_extended_inputs(model, natoms=6): + """Build extended inputs for spin model export tests.""" + from deepmd.dpmodel.utils import ( + build_neighbor_list, + extend_coord_with_ghosts, + normalize_coord, + ) + + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = ( + torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + * 0.5 + ) + + rcut = model.get_rcut() + sel = SPIN_DATA["descriptor"]["sel"] + coord_np = coord.unsqueeze(0).numpy() + atype_np = atype.unsqueeze(0).numpy() + box_np = cell.reshape(1, 9).numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, natoms, 3), + box_np.reshape(1, 3, 3), + ) + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype_np, box_np, rcut + ) + nlist = build_neighbor_list( + ext_coord, ext_atype, natoms, rcut, sel, distinguish_types=True + ) + ext_coord = ext_coord.reshape(1, -1, 3) + spin_np = spin.unsqueeze(0).numpy() + ext_spin = np.take_along_axis( + spin_np, + np.repeat(mapping[:, :, np.newaxis], 3, axis=2), + axis=1, + ) + + ext_coord_t = torch.tensor(ext_coord, dtype=dtype, device=env.DEVICE) + ext_atype_t = torch.tensor(ext_atype, dtype=torch.int64, device=env.DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=env.DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=env.DEVICE) + ext_spin_t = torch.tensor(ext_spin, dtype=dtype, device=env.DEVICE) + return ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, sel + def test_forward_lower_exportable(self) -> None: """Test that SpinEnergyModel.forward_lower_exportable works with make_fx and torch.export.""" from deepmd.dpmodel.utils import ( @@ -538,6 +588,80 @@ def test_forward_lower_exportable(self) -> None: err_msg=f"loaded vs eager (nf=1): {key}", ) + def test_oversized_nlist(self) -> None: + """Test that the exported spin model handles nlist with more neighbors than nnei. + + Verifies two things: + 1. The exported model with an oversized, shuffled nlist produces the + same results as the eager model. + 2. Naive truncation of the shuffled nlist gives different energy, + proving distance sort is necessary. + """ + model = _make_model() + ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, sel = ( + self._build_extended_inputs(model) + ) + nnei = sum(sel) + + # Pad and shuffle nlist + n_extra = nnei + nlist_padded = torch.cat( + [ + nlist_t, + -torch.ones( + (*nlist_t.shape[:2], n_extra), + dtype=nlist_t.dtype, + device=nlist_t.device, + ), + ], + dim=-1, + ) + rng = np.random.default_rng(42) + perm = rng.permutation(nlist_padded.shape[-1]) + nlist_shuffled = nlist_padded[:, :, perm] + assert nlist_shuffled.shape[-1] > nnei + + output_keys = ("energy", "extended_force", "extended_force_mag", "virial") + + # --- Part 1: exported model sorts correctly --- + traced = model.forward_lower_exportable( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_shuffled, + mapping_t, + ) + ret_traced = traced( + ext_coord_t, ext_atype_t, ext_spin_t, nlist_shuffled, mapping_t, None, None + ) + + ec = ext_coord_t.detach().requires_grad_(True) + ret_eager = model.forward_lower( + ec, ext_atype_t, ext_spin_t, nlist_shuffled, mapping_t + ) + + for key in output_keys: + if ret_eager.get(key) is not None and key in ret_traced: + np.testing.assert_allclose( + ret_eager[key].detach().cpu().numpy(), + ret_traced[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"oversized nlist, key={key}", + ) + + # --- Part 2: naive truncation gives wrong results --- + nlist_truncated = nlist_shuffled[:, :, :nnei] + ec2 = ext_coord_t.detach().requires_grad_(True) + ret_trunc = model.forward_lower( + ec2, ext_atype_t, ext_spin_t, nlist_truncated, mapping_t + ) + e_ref = ret_eager["energy"].detach().cpu().numpy() + e_trunc = ret_trunc["energy"].detach().cpu().numpy() + assert not np.allclose(e_ref, e_trunc, rtol=1e-10, atol=1e-10), ( + "Naive truncation of shuffled nlist should give different energy." + ) + if __name__ == "__main__": unittest.main() From 583df2e53fe19e93d08e0823c171df8fe078bc1e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 21 Apr 2026 21:24:14 +0800 Subject: [PATCH 41/56] fix(test): link torch in C++ test binary, fix spin PtExpt guards The test binary (runUnitTests_cc) was missing TORCH_LIBRARIES in its link line, so __has_include for the AOTInductor header always failed and BUILD_PT_EXPT/BUILD_PT_EXPT_SPIN evaluated to 0. All PtExpt C++ tests were silently skipping in CI. Also fix DeepSpin test guards: change #ifndef BUILD_PYTORCH to #if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN, and add #include "DeepSpinPTExpt.h" so the macro is defined. --- source/api_cc/tests/CMakeLists.txt | 1 + source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc | 5 +++-- source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/source/api_cc/tests/CMakeLists.txt b/source/api_cc/tests/CMakeLists.txt index a3e7d067f7..a812f776fc 100644 --- a/source/api_cc/tests/CMakeLists.txt +++ b/source/api_cc/tests/CMakeLists.txt @@ -11,6 +11,7 @@ if(ENABLE_TENSORFLOW) endif() if(ENABLE_PYTORCH) target_compile_definitions(runUnitTests_cc PRIVATE BUILD_PYTORCH) + target_link_libraries(runUnitTests_cc "${TORCH_LIBRARIES}") endif() if(ENABLE_JAX) target_compile_definitions(runUnitTests_cc PRIVATE BUILD_JAX) diff --git a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc index 27c1836aa1..3a248e2cf3 100644 --- a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc +++ b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc @@ -8,6 +8,7 @@ #include #include "DeepSpin.h" +#include "DeepSpinPTExpt.h" #include "neighbor_list.h" #include "test_utils.h" @@ -52,7 +53,7 @@ class TestInferDeepSpinDpaPtExpt : public ::testing::Test { GTEST_SKIP() << "Skipping: " << model_path << " not found."; } } -#ifndef BUILD_PYTORCH +#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN GTEST_SKIP() << "Skip because PyTorch support is not enabled."; #endif dp.init(model_path); @@ -238,7 +239,7 @@ class TestInferDeepSpinDpaPtExptNopbc : public ::testing::Test { GTEST_SKIP() << "Skipping: " << model_path << " not found."; } } -#ifndef BUILD_PYTORCH +#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN GTEST_SKIP() << "Skip because PyTorch support is not enabled."; #endif dp.init(model_path); diff --git a/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc b/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc index e58d9c0f78..ada7c2eba7 100644 --- a/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc +++ b/source/api_cc/tests/test_deepspin_model_devi_ptexpt.cc @@ -8,6 +8,7 @@ #include #include "DeepSpin.h" +#include "DeepSpinPTExpt.h" #include "neighbor_list.h" #include "test_utils.h" @@ -46,7 +47,7 @@ class TestInferDeepSpinModeDeviPtExpt : public ::testing::Test { GTEST_SKIP() << "Skipping: " << model1_path << " not found."; } } -#ifndef BUILD_PYTORCH +#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT_SPIN GTEST_SKIP() << "Skip because PyTorch support is not enabled."; #endif dp0.init(model0_path); From 862f5609a47c6eaedb344ac6fbf8244d95734b70 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 22 Apr 2026 15:28:54 +0800 Subject: [PATCH 42/56] refactor(cc): unify createNlistTensor and fix oversized nlist handling - Move createNlistTensor to shared commonPT.h, replacing duplicates in DeepPotPT.cc, DeepSpinPT.cc, and commonPTExpt.h. The unified version handles jagged rows with a min_nnei parameter for .pt2 models. - Fix shuffle() in common.cc to skip -1 padding entries (was causing out-of-bounds access on fwd_map[-1]). - Add oversized nlist C++ tests for both DeepPot and DeepSpin .pt2 backends, with _pad_shuffle_nlist test helper. --- source/api_cc/include/commonPT.h | 35 +++++++++++++++ source/api_cc/src/DeepPotPT.cc | 16 ------- source/api_cc/src/DeepPotPTExpt.cc | 6 +-- source/api_cc/src/DeepSpinPT.cc | 18 +------- source/api_cc/src/DeepSpinPTExpt.cc | 6 +-- source/api_cc/src/common.cc | 4 +- source/api_cc/src/commonPTExpt.h | 31 ------------- .../tests/test_deeppot_dpa_ptexpt_spin.cc | 42 +++++++++++++++++ source/api_cc/tests/test_deeppot_ptexpt.cc | 45 +++++++++++++++++++ source/api_cc/tests/test_utils.h | 29 ++++++++++++ 10 files changed, 160 insertions(+), 72 deletions(-) diff --git a/source/api_cc/include/commonPT.h b/source/api_cc/include/commonPT.h index a888d29758..20c3d5c1f0 100644 --- a/source/api_cc/include/commonPT.h +++ b/source/api_cc/include/commonPT.h @@ -101,6 +101,41 @@ inline void build_comm_dict_with_virtual_atoms( remapped_sendnum.data(), remapped_recvnum.data()); } +/** + * @brief Flatten a jagged neighbor list into a [1, nloc, nnei] tensor. + * + * Each row in @p data may have a different number of neighbors. Short rows + * are padded with -1. The output width is max(min_nnei, max_row_length). + * No truncation or distance sorting is done — the model's format_nlist + * handles that on-device. + * + * @param data Jagged neighbor list: data[i] holds neighbor indices + * for local atom i. + * @param min_nnei Minimum width of the nnei dimension. For .pt2 models + * this should be sum(sel) from the model metadata, because + * torch.export marks nnei >= sum(sel) as a dynamic constraint. + * For .pth models 0 (the default) is fine. + */ +inline torch::Tensor createNlistTensor( + const std::vector>& data, int min_nnei = 0) { + int nloc = static_cast(data.size()); + int nnei = min_nnei; + for (int ii = 0; ii < nloc; ++ii) { + nnei = std::max(nnei, static_cast(data[ii].size())); + } + if (nnei == 0) { + nnei = 1; // at least 1 column to avoid empty tensor + } + std::vector flat_data(static_cast(nloc) * nnei, -1); + for (int ii = 0; ii < nloc; ++ii) { + for (size_t jj = 0; jj < data[ii].size(); ++jj) { + flat_data[static_cast(ii) * nnei + jj] = data[ii][jj]; + } + } + torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); + return flat_tensor.view({1, nloc, nnei}); +} + } // namespace deepmd #endif // BUILD_PYTORCH diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index fc09949691..d69dbb8f82 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -32,22 +32,6 @@ void DeepPotPT::translate_error(std::function f) { } } -torch::Tensor createNlistTensor(const std::vector>& data) { - size_t total_size = 0; - for (const auto& row : data) { - total_size += row.size(); - } - std::vector flat_data; - flat_data.reserve(total_size); - for (const auto& row : data) { - flat_data.insert(flat_data.end(), row.begin(), row.end()); - } - - torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); - int nloc = data.size(); - int nnei = nloc > 0 ? total_size / nloc : 0; - return flat_tensor.view({1, nloc, nnei}); -} DeepPotPT::DeepPotPT() : inited(false) {} DeepPotPT::DeepPotPT(const std::string& model, const int& gpu_rank, diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index 86301364bc..e95f622a19 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -12,12 +12,12 @@ #include "SimulationRegion.h" #include "common.h" +#include "commonPT.h" #include "commonPTExpt.h" #include "device.h" #include "errors.h" #include "neighbor_list.h" -using deepmd::ptexpt::createNlistTensor; using deepmd::ptexpt::parse_json; using deepmd::ptexpt::read_zip_entry; @@ -293,7 +293,7 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, // Flatten raw nlist — the .pt2 model sorts by distance on-device. firstneigh_tensor = - createNlistTensor(nlist_data.jlist).to(torch::kInt64).to(device); + createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); } // Build fparam/aparam tensors (cast to float64 for the model) @@ -548,7 +548,7 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, .to(device); // Flatten raw nlist — the .pt2 model sorts by distance on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index e7fdb0a6a5..5add377045 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -31,22 +31,6 @@ void DeepSpinPT::translate_error(std::function f) { } } -torch::Tensor createNlistTensor2(const std::vector>& data) { - size_t total_size = 0; - for (const auto& row : data) { - total_size += row.size(); - } - std::vector flat_data; - flat_data.reserve(total_size); - for (const auto& row : data) { - flat_data.insert(flat_data.end(), row.begin(), row.end()); - } - - torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); - int nloc = data.size(); - int nnei = nloc > 0 ? total_size / nloc : 0; - return flat_tensor.view({1, nloc, nnei}); -} DeepSpinPT::DeepSpinPT() : inited(false) {} DeepSpinPT::DeepSpinPT(const std::string& model, const int& gpu_rank, @@ -209,7 +193,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener, .to(device); } } - at::Tensor firstneigh = createNlistTensor2(nlist_data.jlist); + at::Tensor firstneigh = createNlistTensor(nlist_data.jlist); firstneigh_tensor = firstneigh.to(torch::kInt64).to(device); bool do_atom_virial_tensor = atomic; c10::optional fparam_tensor; diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index dbaaa89e06..afeeab7f05 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -12,12 +12,12 @@ #include "SimulationRegion.h" #include "common.h" +#include "commonPT.h" #include "commonPTExpt.h" #include "device.h" #include "errors.h" #include "neighbor_list.h" -using deepmd::ptexpt::createNlistTensor; using deepmd::ptexpt::parse_json; using deepmd::ptexpt::read_zip_entry; @@ -321,7 +321,7 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, // Flatten raw nlist — the .pt2 model sorts by distance on-device. firstneigh_tensor = - createNlistTensor(nlist_data.jlist).to(torch::kInt64).to(device); + createNlistTensor(nlist_data.jlist, nnei).to(torch::kInt64).to(device); } // Build fparam/aparam tensors @@ -601,7 +601,7 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, .to(device); // Flatten raw nlist — the .pt2 model sorts by distance on-device. at::Tensor nlist_tensor = - createNlistTensor(nlist_raw).to(torch::kInt64).to(device); + createNlistTensor(nlist_raw, nnei).to(torch::kInt64).to(device); std::vector mapping_64(mapping_vec.begin(), mapping_vec.end()); at::Tensor mapping_tensor = torch::from_blob(mapping_64.data(), {1, nall}, int_options) diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 1ad1a5c97b..b446417625 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -295,13 +295,13 @@ void deepmd::NeighborListData::shuffle(const AtomMap& map) { void deepmd::NeighborListData::shuffle(const std::vector& fwd_map) { int nloc = fwd_map.size(); for (unsigned ii = 0; ii < ilist.size(); ++ii) { - if (ilist[ii] < nloc) { + if (ilist[ii] >= 0 && ilist[ii] < nloc) { ilist[ii] = fwd_map[ilist[ii]]; } } for (unsigned ii = 0; ii < jlist.size(); ++ii) { for (unsigned jj = 0; jj < jlist[ii].size(); ++jj) { - if (jlist[ii][jj] < nloc) { + if (jlist[ii][jj] >= 0 && jlist[ii][jj] < nloc) { jlist[ii][jj] = fwd_map[jlist[ii][jj]]; } } diff --git a/source/api_cc/src/commonPTExpt.h b/source/api_cc/src/commonPTExpt.h index 152243a655..ddc8ad5014 100644 --- a/source/api_cc/src/commonPTExpt.h +++ b/source/api_cc/src/commonPTExpt.h @@ -444,36 +444,5 @@ inline std::string read_zip_entry(const std::string& zip_path, // The C++ side just flattens the jagged nlist into a rectangular tensor. // ============================================================================ -/** - * @brief Flatten a jagged neighbor list into a tensor. - * - * Each row in `data` may have a different number of neighbors. This function - * pads short rows with -1 to produce a tensor of shape - * [1, nloc, max_row_length]. No truncation or distance sorting is done — - * the .pt2 model's compiled format_nlist handles that on-device. - * - * @param data Jagged neighbor list: data[i] contains neighbor indices - * for local atom i. - */ -inline torch::Tensor createNlistTensor( - const std::vector>& data) { - int nloc = static_cast(data.size()); - int nnei = 0; - for (int ii = 0; ii < nloc; ++ii) { - nnei = std::max(nnei, static_cast(data[ii].size())); - } - if (nnei == 0) { - nnei = 1; // at least 1 column to avoid empty tensor - } - std::vector flat_data(static_cast(nloc) * nnei, -1); - for (int ii = 0; ii < nloc; ++ii) { - for (size_t jj = 0; jj < data[ii].size(); ++jj) { - flat_data[static_cast(ii) * nnei + jj] = data[ii][jj]; - } - } - torch::Tensor flat_tensor = torch::tensor(flat_data, torch::kInt32); - return flat_tensor.view({1, nloc, nnei}); -} - } // namespace ptexpt } // namespace deepmd diff --git a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc index 3a248e2cf3..95e4b3e48c 100644 --- a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc +++ b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc @@ -471,3 +471,45 @@ TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_lmp_nlist_atomic) { EXPECT_LT(fabs(atom_vir[ii] - expected_atom_v[ii]), EPSILON); } } + +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_lmp_nlist_oversized) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial; + + std::vector > nlist_data = { + {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5}, + {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + // Pad with extra -1 entries and shuffle to create oversized nlist + std::vector > nlist_oversized; + _pad_shuffle_nlist(nlist_oversized, nlist_data, 50); + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_oversized); + dp.compute(ener, force, force_mag, virial, coord, spin, atype, box, 0, inlist, + 0); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(force_mag.size(), natoms * 3); + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + EXPECT_LT(fabs(force_mag[ii] - expected_fm[ii]), EPSILON); + } + EXPECT_FALSE(virial.empty()) << "Virial should not be empty"; + EXPECT_EQ(virial.size(), 9); + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index cfde173a82..7369d27708 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -356,6 +356,51 @@ TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_2rc) { } } +TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_oversized) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_f = this->expected_f; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + std::vector& expected_tot_v = this->expected_tot_v; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + // Pad with extra -1 entries and shuffle to create oversized nlist + std::vector > nlist_oversized; + _pad_shuffle_nlist(nlist_oversized, nlist_data, 200); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_oversized); + + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + + EXPECT_LT(fabs(ener - expected_tot_e), EPSILON); + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON); + } + for (int ii = 0; ii < 3 * 3; ++ii) { + EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); + } +} + TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_type_sel) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/api_cc/tests/test_utils.h b/source/api_cc/tests/test_utils.h index 51f05f9b54..9d5d5a4815 100644 --- a/source/api_cc/tests/test_utils.h +++ b/source/api_cc/tests/test_utils.h @@ -1,6 +1,7 @@ // SPDX-License-Identifier: LGPL-3.0-or-later #pragma once #include +#include #include "SimulationRegion.h" #include "gtest/gtest.h" @@ -84,6 +85,34 @@ inline void _build_nlist(std::vector>& nlist_data, coord_cpy.assign(coord_cpy_.begin(), coord_cpy_.end()); } +/** + * @brief Pad each atom's neighbor list with -1 entries and shuffle. + * + * Mimics the Python test_oversized_nlist approach: append n_extra + * padding entries (-1) to each row, then apply a deterministic + * permutation so that real neighbors are no longer at the front. + * This exercises the model's internal distance-sort + truncation. + */ +inline void _pad_shuffle_nlist(std::vector>& nlist_out, + const std::vector>& nlist_in, + int n_extra, + unsigned int seed = 42) { + nlist_out.resize(nlist_in.size()); + for (size_t ii = 0; ii < nlist_in.size(); ++ii) { + // copy original + pad with -1 + nlist_out[ii] = nlist_in[ii]; + for (int jj = 0; jj < n_extra; ++jj) { + nlist_out[ii].push_back(-1); + } + // deterministic shuffle (std::minstd_rand for reproducibility) + std::minstd_rand rng(seed + static_cast(ii)); + for (size_t jj = nlist_out[ii].size() - 1; jj > 0; --jj) { + size_t kk = rng() % (jj + 1); + std::swap(nlist_out[ii][jj], nlist_out[ii][kk]); + } + } +} + template class EnergyModelTest { protected: From 12ad03b9e0ab8ff24c833cc9b354e8abd4fdb626 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 22 Apr 2026 15:52:14 +0800 Subject: [PATCH 43/56] fix(test): remove unused _DESCRIPTOR_DPA1_WITH_ATTN config --- source/tests/pt_expt/test_training.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/source/tests/pt_expt/test_training.py b/source/tests/pt_expt/test_training.py index 512033dbe1..16bf20cffa 100644 --- a/source/tests/pt_expt/test_training.py +++ b/source/tests/pt_expt/test_training.py @@ -61,17 +61,6 @@ "seed": 1, } -_DESCRIPTOR_DPA1_WITH_ATTN = { - "type": "dpa1", - "sel": 12, - "rcut_smth": 0.50, - "rcut": 3.00, - "neuron": [8, 16], - "axis_neuron": 4, - "attn_layer": 2, - "precision": "float64", - "seed": 1, -} _DESCRIPTOR_DPA2 = { "type": "dpa2", "repinit": { From b0096e036fde82661bcef845404ad22d73ae8bdd Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 22 Apr 2026 18:53:26 +0800 Subject: [PATCH 44/56] fix(pt2): relax nnei dynamic shape lower bound from 2 to 1 --- deepmd/pt_expt/utils/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 165066664c..0e5b3d6934 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -206,7 +206,7 @@ def _build_dynamic_shapes( nframes_dim = torch.export.Dim("nframes", min=1) nall_dim = torch.export.Dim("nall", min=1) nloc_dim = torch.export.Dim("nloc", min=1) - nnei_dim = torch.export.Dim("nnei", min=max(2, model_nnei)) + nnei_dim = torch.export.Dim("nnei", min=max(1, model_nnei)) if has_spin: # (ext_coord, ext_atype, ext_spin, nlist, mapping, fparam, aparam) From afa732f28a0f6b9d7b64d475ca29f293be532977 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 23 Apr 2026 22:49:03 +0800 Subject: [PATCH 45/56] fix(pt2): remove eliminate_dead_code from _strip_shape_assertions Dead-code elimination after removing _assert_scalar nodes incorrectly removes intermediate computation nodes that share sub-expressions with the autograd gradient path, producing NaN forces for DPA1/se_atten in the NoPBC case. Remove the eliminate_dead_code() call; the leftover nodes are harmless unused scalar computations. --- deepmd/pt_expt/utils/serialization.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 0e5b3d6934..2cac839cbf 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -34,6 +34,17 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: message content is not reliable. Since ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape guards into these deferred assertions, removing all of them is safe. + + .. note:: + + We intentionally do **not** call ``graph.eliminate_dead_code()`` + after removing assertion nodes. Dead-code elimination can remove + intermediate computation nodes that share sub-expressions with the + autograd gradient path (traced via ``torch.autograd.grad`` inside the + exported function). Removing those nodes produces NaN forces for + models like DPA1/se_atten in the NoPBC case. The leftover "dead" + nodes (computing the boolean condition for the removed assertions) + are harmless — they just compute unused scalar values. """ graph = graph_module.graph for node in list(graph.nodes): @@ -42,7 +53,6 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: and node.target is torch.ops.aten._assert_scalar.default ): graph.erase_node(node) - graph.eliminate_dead_code() graph_module.recompile() From 59973bebd8b33aa656790307b088b55e161f51e4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 24 Apr 2026 12:02:32 +0800 Subject: [PATCH 46/56] fix(pt2): neutralise shape assertions instead of erasing them Erasing _assert_scalar nodes from the exported FX graph (and especially calling eliminate_dead_code afterwards) disturbs the graph structure and produces NaN gradients for DPA1/se_atten in the NoPBC case on some Python/torch versions. Replace each assertion's condition with True so the node stays in the graph but never fires at runtime. This preserves the graph topology and avoids the NaN issue across all tested configurations. --- deepmd/pt_expt/utils/serialization.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 2cac839cbf..35649ae6b6 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -17,7 +17,7 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: - """Remove shape-guard assertion nodes from an exported graph. + """Neutralise shape-guard assertion nodes in an exported graph. ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape relationships discovered during tracing. These guards can be spurious: @@ -29,22 +29,10 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: like ``Ne(nnei, sum(sel))``. These are spurious because the compiled graph handles any ``nnei >= sum(sel)`` correctly. - The assertion messages use opaque symbolic variable names (e.g. - ``Ne(s22, s96)``) rather than human-readable names, so filtering by - message content is not reliable. Since - ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape - guards into these deferred assertions, removing all of them is safe. - - .. note:: - - We intentionally do **not** call ``graph.eliminate_dead_code()`` - after removing assertion nodes. Dead-code elimination can remove - intermediate computation nodes that share sub-expressions with the - autograd gradient path (traced via ``torch.autograd.grad`` inside the - exported function). Removing those nodes produces NaN forces for - models like DPA1/se_atten in the NoPBC case. The leftover "dead" - nodes (computing the boolean condition for the removed assertions) - are harmless — they just compute unused scalar values. + Instead of erasing the assertion nodes (which can disturb the FX graph + structure and produce NaN gradients on some Python/torch versions), we + replace each assertion's condition with ``True`` so that the node stays + in the graph but never fires at runtime. """ graph = graph_module.graph for node in list(graph.nodes): @@ -52,7 +40,10 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: node.op == "call_function" and node.target is torch.ops.aten._assert_scalar.default ): - graph.erase_node(node) + # Replace the condition with True so the assertion always passes + # but the node stays in the graph. Erasing nodes can disturb the + # graph structure and produce NaN on some Python/torch versions. + node.args = (True, node.args[1]) graph_module.recompile() From 39f84acdf27abd8b5e9e01cd4b922641165f5e58 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 24 Apr 2026 21:48:10 +0800 Subject: [PATCH 47/56] fix(pt2): prevent CUDA NaN in attention backward by disabling kernel fusion AOTInductor's aggressive kernel fusion (realize_opcount_threshold) causes NaN in the backward pass (force/virial) of attention-based descriptors (DPA1, DPA2) on CUDA for certain coordinate patterns (e.g. NoPbc with two well-separated clusters). Setting realize_opcount_threshold=0 during compilation prevents the problematic fusion while keeping correct results on both CPU and CUDA. --- deepmd/pt_expt/utils/serialization.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 35649ae6b6..58b14cdf63 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -586,8 +586,18 @@ def _deserialize_to_file_pt2( data, model_json_override, do_atomic_virial ) - # Compile via AOTInductor into a .pt2 package - aoti_compile_and_package(exported, package_path=model_file) + # Compile via AOTInductor into a .pt2 package. + # realize_opcount_threshold=0 prevents aggressive kernel fusion that + # causes NaN in the backward pass (force/virial) of attention-based + # descriptors (DPA1, DPA2) on CUDA for certain coordinate patterns. + import torch._inductor.config as _inductor_config + + saved_threshold = _inductor_config.realize_opcount_threshold + _inductor_config.realize_opcount_threshold = 0 + try: + aoti_compile_and_package(exported, package_path=model_file) + finally: + _inductor_config.realize_opcount_threshold = saved_threshold # Embed metadata into the .pt2 ZIP archive model_def_script = data.get("model_def_script") or {} From 57cfe3e0ebc72ec11e798abbd530e047451c5477 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 24 Apr 2026 22:21:39 +0800 Subject: [PATCH 48/56] fix(pt2): only apply realize_opcount_threshold=0 on CUDA The inductor fusion workaround causes force divergence for DPA3 on CPU. Restrict it to CUDA where the NaN issue actually occurs. --- deepmd/pt_expt/utils/serialization.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 58b14cdf63..8ac3ca4da2 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -586,14 +586,18 @@ def _deserialize_to_file_pt2( data, model_json_override, do_atomic_virial ) - # Compile via AOTInductor into a .pt2 package. - # realize_opcount_threshold=0 prevents aggressive kernel fusion that + # On CUDA, aggressive kernel fusion (default realize_opcount_threshold=30) # causes NaN in the backward pass (force/virial) of attention-based - # descriptors (DPA1, DPA2) on CUDA for certain coordinate patterns. + # descriptors (DPA1, DPA2). Setting threshold=0 prevents fusion and + # avoids the NaN. Only applied on CUDA; CPU compilation is unaffected. import torch._inductor.config as _inductor_config + import deepmd.pt_expt.utils.env as _env + + is_cuda = _env.DEVICE.type == "cuda" saved_threshold = _inductor_config.realize_opcount_threshold - _inductor_config.realize_opcount_threshold = 0 + if is_cuda: + _inductor_config.realize_opcount_threshold = 0 try: aoti_compile_and_package(exported, package_path=model_file) finally: From cdbfeadce27ebac95b82c188def31723efc7d68c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:48:59 +0800 Subject: [PATCH 49/56] refactor(pt_expt): use extra_nlist_sort override instead of stripping shape assertions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the previous "strip all _assert_scalar nodes for non-spin models" approach with a cleaner one that prevents the spurious nnei guards from being emitted in the first place. The compiled .pt2 graph must tolerate oversized nlists at runtime (LAMMPS builds with rcut+skin), which requires _format_nlist's distance-sort branch to be present in the compiled graph. Previously this was forced by padding the nlist with one -1 column inside the traced fn so that n_nnei > sum(sel) at trace time; the resulting Eq(n_nnei, nnei) shape guard was then stripped from the exported graph for every model. This commit instead: * Swaps the operand order in _format_nlist to 'if extra_nlist_sort or n_nnei > nnei:' so Python short-circuits and the symbolic n_nnei comparison is never evaluated when the override is set. * Adds a small _pad_nlist_for_export helper that prepends a single -1 column inside the traced fn (still needed so the terminal ret[..., :nnei] slice produces a statically sized output, but no longer triggers a guard). * Wraps the make_fx() call in each forward_lower_exportable with a try/finally that monkeypatches model.need_sorted_nlist_for_lower to return True (or the spin backbone's), forcing the sort branch. * Re-narrows _strip_shape_assertions to spin models only — non-spin exports no longer emit any spurious shape guards. Eight models are touched: ener, spin_ener, dp_zbl, dipole, dos, dp_linear, polar, property — each has its own forward_lower_exportable that needed the same treatment. --- deepmd/dpmodel/model/make_model.py | 8 ++- deepmd/pt_expt/model/dipole_model.py | 16 ++++-- deepmd/pt_expt/model/dos_model.py | 16 ++++-- deepmd/pt_expt/model/dp_linear_model.py | 16 ++++-- deepmd/pt_expt/model/dp_zbl_model.py | 18 +++++-- deepmd/pt_expt/model/make_model.py | 67 ++++++++++++++++--------- deepmd/pt_expt/model/polar_model.py | 16 ++++-- deepmd/pt_expt/model/property_model.py | 16 ++++-- deepmd/pt_expt/model/spin_model.py | 47 +++++++++-------- deepmd/pt_expt/train/training.py | 1 + deepmd/pt_expt/utils/serialization.py | 56 ++++++++++----------- source/tests/pt_expt/export_helpers.py | 27 +--------- 12 files changed, 185 insertions(+), 119 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 597f8ea006..fb77838b4c 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -614,7 +614,13 @@ def _format_nlist( axis=-1, ) - if n_nnei > nnei or extra_nlist_sort: + # Order matters for torch.export: Python evaluates `or` left-to-right + # with short-circuit. When `extra_nlist_sort=True` (Python bool) is + # on the left, the right-hand `n_nnei > nnei` is not evaluated, so no + # symbolic guard is registered on the dynamic `n_nnei` dimension. + # Swapping the operands would force the SymInt comparison to run and + # emit an `_assert_scalar` node in the exported graph. + if extra_nlist_sort or n_nnei > nnei: n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 diff --git a/deepmd/pt_expt/model/dipole_model.py b/deepmd/pt_expt/model/dipole_model.py index 79ae26024e..4b0e570ecb 100644 --- a/deepmd/pt_expt/model/dipole_model.py +++ b/deepmd/pt_expt/model/dipole_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -16,6 +17,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -137,6 +139,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -147,6 +150,13 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # See make_model.py for the rationale of the pad + monkeypatch. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced diff --git a/deepmd/pt_expt/model/dos_model.py b/deepmd/pt_expt/model/dos_model.py index 2e69d90ab3..219c22e753 100644 --- a/deepmd/pt_expt/model/dos_model.py +++ b/deepmd/pt_expt/model/dos_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -16,6 +17,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -117,6 +119,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -127,6 +130,13 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # See make_model.py for the rationale of the pad + monkeypatch. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced diff --git a/deepmd/pt_expt/model/dp_linear_model.py b/deepmd/pt_expt/model/dp_linear_model.py index 46790c877e..0ac75659b0 100644 --- a/deepmd/pt_expt/model/dp_linear_model.py +++ b/deepmd/pt_expt/model/dp_linear_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -19,6 +20,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -142,6 +144,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -152,9 +155,16 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # See make_model.py for the rationale of the pad + monkeypatch. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced @classmethod def update_sel( diff --git a/deepmd/pt_expt/model/dp_zbl_model.py b/deepmd/pt_expt/model/dp_zbl_model.py index b7f164114b..baa30c4ce0 100644 --- a/deepmd/pt_expt/model/dp_zbl_model.py +++ b/deepmd/pt_expt/model/dp_zbl_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -16,6 +17,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -139,6 +141,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -149,6 +152,15 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # Force `_format_nlist`'s sort branch into the compiled graph so the + # exported model tolerates oversized nlists at runtime — see + # make_model.py for the full rationale. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 547cf9663d..b28b81ffb1 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import math +import types from typing import ( Any, ) @@ -28,6 +29,28 @@ ) +def _pad_nlist_for_export(nlist: torch.Tensor) -> torch.Tensor: + """Append a single ``-1`` column to ``nlist`` for export-time tracing. + + Used inside ``forward_common_lower_exportable`` (and its spin counterpart) + so that ``_format_nlist``'s terminal slice ``ret[..., :nnei]`` truncates + to a statically sized output. Without the extra column, torch.export + cannot prove the ``ret.shape[-1] == nnei`` assertion at trace time and + would specialise the dynamic ``nnei`` dim to the sample value. + + Combined with the short-circuit order in ``_format_nlist`` + (``extra_nlist_sort`` on the left) and the ``need_sorted_nlist_for_lower`` + override during tracing, this keeps the compiled graph's ``nnei`` axis + fully dynamic and free of symbolic shape guards. + """ + pad = -torch.ones( + (*nlist.shape[:2], 1), + dtype=nlist.dtype, + device=nlist.device, + ) + return torch.cat([nlist, pad], dim=-1) + + def _cal_hessian_ext( model: Any, kk: str, @@ -346,21 +369,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) - # Pad nlist with one extra -1 column inside the traced function. - # This ensures n_nnei > sum(sel), forcing the sort branch in - # _format_nlist. The padding becomes part of the compiled graph, - # so callers never need to pad externally. - nlist = torch.cat( - [ - nlist, - -torch.ones( - (*nlist.shape[:2], 1), - dtype=nlist.dtype, - device=nlist.device, - ), - ], - dim=-1, - ) + nlist = _pad_nlist_for_export(nlist) return model.forward_common_lower( extended_coord, extended_atype, @@ -371,16 +380,26 @@ def fn( do_atomic_virial=do_atomic_virial, ) - # The +1 pad inside fn guarantees n_nnei > sum(sel) at trace - # time, so _format_nlist's distance-sort branch is always traced. - traced = make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - nlist, - mapping, - fparam, - aparam, + # Force `_format_nlist`'s sort branch into the compiled graph so the + # exported model tolerates oversized nlists at runtime (LAMMPS builds + # nlists with rcut+skin). Combined with the short-circuit order in + # `_format_nlist`, no symbolic guard on the dynamic `nnei` axis is + # emitted. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType( + lambda self: True, model ) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort return traced return CM diff --git a/deepmd/pt_expt/model/polar_model.py b/deepmd/pt_expt/model/polar_model.py index d421bb76a4..dd6b1c5d0f 100644 --- a/deepmd/pt_expt/model/polar_model.py +++ b/deepmd/pt_expt/model/polar_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -16,6 +17,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -117,6 +119,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -127,6 +130,13 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # See make_model.py for the rationale of the pad + monkeypatch. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced diff --git a/deepmd/pt_expt/model/property_model.py b/deepmd/pt_expt/model/property_model.py index 72a327fb03..223f8e5d78 100644 --- a/deepmd/pt_expt/model/property_model.py +++ b/deepmd/pt_expt/model/property_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -16,6 +17,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -124,6 +126,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) + nlist = _pad_nlist_for_export(nlist) return model.forward_lower( extended_coord, extended_atype, @@ -134,6 +137,13 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, **make_fx_kwargs)( - extended_coord, extended_atype, nlist, mapping, fparam, aparam - ) + # See make_model.py for the rationale of the pad + monkeypatch. + _orig_need_sort = model.need_sorted_nlist_for_lower + model.need_sorted_nlist_for_lower = types.MethodType(lambda self: True, model) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) + finally: + model.need_sorted_nlist_for_lower = _orig_need_sort + return traced diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py index 9dfb616e29..e69ee29f5a 100644 --- a/deepmd/pt_expt/model/spin_model.py +++ b/deepmd/pt_expt/model/spin_model.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import types from typing import ( Any, ) @@ -17,6 +18,7 @@ ) from .make_model import ( + _pad_nlist_for_export, make_model, ) from .model import ( @@ -96,18 +98,7 @@ def fn( aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) - # Pad nlist inside traced function (see make_model.py for rationale). - nlist = torch.cat( - [ - nlist, - -torch.ones( - (*nlist.shape[:2], 1), - dtype=nlist.dtype, - device=nlist.device, - ), - ], - dim=-1, - ) + nlist = _pad_nlist_for_export(nlist) return model.forward_common_lower( extended_coord, extended_atype, @@ -119,17 +110,29 @@ def fn( do_atomic_virial=do_atomic_virial, ) - # The +1 pad inside fn guarantees n_nnei > sum(sel) at trace - # time, so _format_nlist's distance-sort branch is always traced. - traced = make_fx(fn, **make_fx_kwargs)( - extended_coord, - extended_atype, - extended_spin, - nlist, - mapping, - fparam, - aparam, + # Force the sort branch of `_format_nlist` into the compiled graph by + # overriding `need_sorted_nlist_for_lower` on the backbone (which is + # where `call_common_lower` reads it). Short-circuit `or` in + # `_format_nlist` then skips the symbolic `n_nnei > nnei` comparison, + # so no spurious shape guard is emitted. See make_model.py for the + # non-spin counterpart. + backbone = model.backbone_model + _orig_need_sort = backbone.need_sorted_nlist_for_lower + backbone.need_sorted_nlist_for_lower = types.MethodType( + lambda self: True, backbone ) + try: + traced = make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) + finally: + backbone.need_sorted_nlist_for_lower = _orig_need_sort return traced def forward_common_lower( diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index a26cc44b19..050b0ee52f 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -945,6 +945,7 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None: n_attn, task_key, ) + inp, _ = self.get_data(is_train=True, task_key=task_key) coord = inp["coord"].detach() atype = inp["atype"].detach() diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 8ac3ca4da2..197fa8a760 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -17,22 +17,27 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: - """Neutralise shape-guard assertion nodes in an exported graph. + """Neutralise shape-guard assertion nodes in a spin model's exported graph. ``torch.export`` inserts ``aten._assert_scalar`` nodes for symbolic shape - relationships discovered during tracing. These guards can be spurious: - - * **Spin models**: atom-doubling logic creates slice patterns that depend - on ``(nall - nloc)``, producing guards like ``Ne(nall, nloc)``. - * **All models**: the nlist padding inside ``forward_common_lower_exportable`` - and the subsequent sort/truncate in ``_format_nlist`` can produce guards - like ``Ne(nnei, sum(sel))``. These are spurious because the compiled - graph handles any ``nnei >= sum(sel)`` correctly. - - Instead of erasing the assertion nodes (which can disturb the FX graph - structure and produce NaN gradients on some Python/torch versions), we - replace each assertion's condition with ``True`` so that the node stays - in the graph but never fires at runtime. + relationships discovered during tracing. For the spin model, the atom- + doubling logic creates slice patterns that depend on ``(nall - nloc)``, + producing guards like ``Ne(nall, nloc)``. These guards are spurious: the + model computes correct results even when ``nall == nloc`` (NoPBC, no ghost + atoms). + + This function is **only called for spin models** (guarded by ``if is_spin`` + in ``_trace_and_export``). The assertion messages use opaque symbolic + variable names (e.g. ``Ne(s22, s96)``) rather than human-readable names, + so filtering by message content is not reliable. Since + ``prefer_deferred_runtime_asserts_over_guards=True`` converts all shape + guards into these deferred assertions, and the only shape relationships in + the spin model involve nall/nloc, neutralising all of them is safe in this + context. + + We replace each assertion's condition with ``True`` rather than erasing the + node; erasing nodes can disturb the FX graph structure and produce NaN + gradients on some Python/torch versions. """ graph = graph_module.graph for node in list(graph.nodes): @@ -40,9 +45,6 @@ def _strip_shape_assertions(graph_module: torch.nn.Module) -> None: node.op == "call_function" and node.target is torch.ops.aten._assert_scalar.default ): - # Replace the condition with True so the assertion always passes - # but the node stays in the graph. Erasing nodes can disturb the - # graph structure and produce NaN on some Python/torch versions. node.args = (True, node.args[1]) graph_module.recompile() @@ -142,13 +144,6 @@ def _make_sample_inputs( sel, distinguish_types=not mixed_types, ) - # Pad nlist so nnei > sum(sel) in the sample tensors. - # This prevents torch.export from specializing nnei to sum(sel). - nnei = sum(sel) - n_pad = max(1, nnei // 4) # pad by ~25%, at least 1 - nlist = np.concatenate( - [nlist, -np.ones((nframes, nloc, n_pad), dtype=nlist.dtype)], axis=-1 - ) extended_coord = extended_coord.reshape(nframes, -1, 3) # Convert to torch tensors @@ -518,10 +513,15 @@ def _trace_and_export( prefer_deferred_runtime_asserts_over_guards=True, ) - # torch.export inserts _assert_scalar guards for symbolic shape - # relationships (e.g. Ne(nnei, sum(sel)), Ne(nall, nloc)). These - # are spurious — the model handles any valid input shapes correctly. - _strip_shape_assertions(exported.graph_module) + if is_spin: + # The spin model's atom-doubling slice patterns depend on + # (nall - nloc), producing guards like Ne(nall, nloc). These are + # spurious — the model is correct when nall == nloc (NoPBC). + # Non-spin models don't emit shape guards because the short-circuit + # order in `_format_nlist` (dpmodel) keeps the dynamic `nnei` axis + # free of symbolic comparisons when `extra_nlist_sort=True` + # (see `forward_common_lower_exportable` in pt_expt/model/make_model.py). + _strip_shape_assertions(exported.graph_module) # 7. Move the exported program to the target device if needed. if target_device.type != "cpu": diff --git a/source/tests/pt_expt/export_helpers.py b/source/tests/pt_expt/export_helpers.py index beb286fe25..ae4db82ddb 100644 --- a/source/tests/pt_expt/export_helpers.py +++ b/source/tests/pt_expt/export_helpers.py @@ -70,12 +70,6 @@ def export_save_load_and_compare( strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) - # Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel))) - from deepmd.pt_expt.utils.serialization import ( - _strip_shape_assertions, - ) - - _strip_shape_assertions(exported.graph_module) # 4. .pte save -> load round-trip with tempfile.NamedTemporaryFile(suffix=".pte") as f: @@ -205,22 +199,9 @@ def model_forward_lower_export_round_trip( ) # 5. Symbolic trace + dynamic shapes + .pte round-trip - # Pad nlist with extra -1 columns so nnei > sum(sel) in the sample. - # This prevents torch.export from specializing nnei to sum(sel). - nlist_padded = torch.cat( - [ - nlist_t, - -torch.ones( - (*nlist_t.shape[:2], max(1, nlist_t.shape[2] // 4)), - dtype=nlist_t.dtype, - device=nlist_t.device, - ), - ], - dim=-1, - ) inputs_2f = tuple( torch.cat([t, t], dim=0) if t is not None else None - for t in (ext_coord, ext_atype, nlist_padded, mapping_t, fparam, aparam) + for t in (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam) ) traced_sym = md_pt.forward_lower_exportable( inputs_2f[0], @@ -240,12 +221,6 @@ def model_forward_lower_export_round_trip( strict=False, prefer_deferred_runtime_asserts_over_guards=True, ) - # Strip spurious shape-guard assertions (e.g. Ne(nnei, sum(sel))) - from deepmd.pt_expt.utils.serialization import ( - _strip_shape_assertions, - ) - - _strip_shape_assertions(exported_dyn.graph_module) with tempfile.NamedTemporaryFile(suffix=".pte") as f: torch.export.save(exported_dyn, f.name) loaded = torch.export.load(f.name).module() From f4804560c3ced9b2f7acd2aec5bb5eb6c3ea6fd2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:49:30 +0800 Subject: [PATCH 50/56] chore(pt2): note thread-safety constraint of realize_opcount_threshold The save/restore around aoti_compile_and_package mutates torch._inductor.config, which is a process-wide singleton. Concurrent .pt2 exports from a thread pool would race on this global. Add a comment so callers know to serialise exports if running threaded. Processes are fine (each has its own inductor config). --- deepmd/pt_expt/utils/serialization.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 197fa8a760..04cdedd6cf 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -590,6 +590,12 @@ def _deserialize_to_file_pt2( # causes NaN in the backward pass (force/virial) of attention-based # descriptors (DPA1, DPA2). Setting threshold=0 prevents fusion and # avoids the NaN. Only applied on CUDA; CPU compilation is unaffected. + # + # NOTE: `torch._inductor.config` is a process-wide singleton. The + # save/restore pattern here is NOT thread-safe — concurrent AOTInductor + # compilations from multiple threads would race on this global. Callers + # must serialise `.pt2` exports if running under a thread pool. Processes + # are fine (each has its own inductor config). import torch._inductor.config as _inductor_config import deepmd.pt_expt.utils.env as _env From d41eefd97b45a60bdb047d132f2b4188f0a2e5f8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:49:59 +0800 Subject: [PATCH 51/56] feat(cli): warn instead of raise for --atomic-virial on non-pt_expt outputs Previously, passing --atomic-virial INPUT.pth OUTPUT.pth raised ValueError, breaking scripts that pass the flag indiscriminately when converting between any pair of formats. Switch to a warning + no-op: the flag is only meaningful for .pt2/.pte outputs, but a workflow that loops over multiple output formats now keeps working. Also clarify the CLI help text and the docstring to explicitly mention this behaviour. --- deepmd/entrypoints/convert_backend.py | 16 ++++++++++++---- deepmd/main.py | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/deepmd/entrypoints/convert_backend.py b/deepmd/entrypoints/convert_backend.py index b0ca2f1d97..e5cd51b386 100644 --- a/deepmd/entrypoints/convert_backend.py +++ b/deepmd/entrypoints/convert_backend.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from typing import ( Any, ) @@ -7,6 +8,8 @@ Backend, ) +log = logging.getLogger(__name__) + def convert_backend( *, # Enforce keyword-only arguments @@ -25,14 +28,17 @@ def convert_backend( The output model file. atomic_virial : bool If True, export .pt2/.pte models with per-atom virial correction. - This adds ~2.5x inference cost. Default False. + This adds ~2.5x inference cost. Default False. Silently ignored + (with a warning) for backends that don't support the flag. """ inp_backend: Backend = Backend.detect_backend_by_model(INPUT)() out_backend: Backend = Backend.detect_backend_by_model(OUTPUT)() inp_hook = inp_backend.serialize_hook out_hook = out_backend.deserialize_hook data = inp_hook(INPUT) - # Forward atomic_virial to pt_expt deserialize_to_file if applicable + # Forward atomic_virial to pt_expt deserialize_to_file if applicable; + # warn and skip the flag for backends that don't accept it so that + # scripts passing --atomic-virial indiscriminately don't break. import inspect sig = inspect.signature(out_hook) @@ -40,7 +46,9 @@ def convert_backend( out_hook(OUTPUT, data, do_atomic_virial=atomic_virial) else: if atomic_virial: - raise ValueError( - "--atomic-virial is only supported for pt_expt .pt2/.pte outputs" + log.warning( + "--atomic-virial is only meaningful for pt_expt .pt2/.pte " + "outputs; ignoring it for output backend %s", + out_backend.name, ) out_hook(OUTPUT, data) diff --git a/deepmd/main.py b/deepmd/main.py index 0ba87a733b..bf59dfdad5 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -921,7 +921,8 @@ def main_parser() -> argparse.ArgumentParser: default=False, help="Export .pt2/.pte models with per-atom virial correction. " "This adds ~2.5x inference cost but is required for " - "LAMMPS compute/atom virial output.", + "LAMMPS compute/atom virial output. " + "Ignored (with a warning) for other output backends.", ) # * show model ****************************************************************** From c80db58dde1bbc3f9c303ad73961c0ccffb36005 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:52:26 +0800 Subject: [PATCH 52/56] refactor(pt2): move atomic_virial check to compute() entry; clarify error message MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 'atomic && !do_atomic_virial' check used to fire after building the coord/atype/mapping/nlist/fparam/aparam tensors, wasting that allocation work on a doomed call. Hoist the throw to the entry of each compute() overload (4 sites: DeepPotPTExpt nlist + no-nlist, DeepSpinPTExpt nlist + no-nlist) so the failure happens before any tensor is built. Also rewrite the error message to: * Mention the LAMMPS context (compute */atom/virial) — the most common trigger. * Reference the metadata flag name (do_atomic_virial=False) so users can verify by inspecting the .pt2 archive. * Note the ~2.5x cost trade-off so users understand why it's off by default. * Provide a copy-pastable regen command with file extensions. --- source/api_cc/src/DeepPotPTExpt.cc | 45 +++++++++++++++++------------ source/api_cc/src/DeepSpinPTExpt.cc | 43 +++++++++++++++------------ 2 files changed, 50 insertions(+), 38 deletions(-) diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index e95f622a19..db099ed464 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -223,6 +223,16 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { + // Fail fast before allocating any tensors: refuse to run if the caller + // asked for atomic virial but the .pt2 was exported without it. + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial was requested (e.g. by LAMMPS compute */atom/virial) " + "but this .pt2 model was exported without it (metadata field " + "do_atomic_virial=False). Atomic virial adds ~2.5x inference cost " + "and is off by default for .pt2. To enable it, regenerate with: " + "dp convert-backend --atomic-virial INPUT.pth OUTPUT.pt2"); + } torch::Device device(torch::kCUDA, gpu_id); if (!gpu_enabled) { device = torch::Device(torch::kCPU); @@ -264,12 +274,15 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, .clone() .to(device); + // LAMMPS sets ago=0 on every nlist rebuild (neighbor rebuild, re-partition, + // atom exchange between subdomains), so `ago > 0` implies the cached + // mapping and nlist tensors are still valid. Rebuild only on ago==0. if (ago == 0) { nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - // Rebuild mapping tensor only when nlist is updated (ago == 0). + // Rebuild mapping tensor if (lmp_list.mapping) { std::vector mapping(nall_real); for (int ii = 0; ii < nall_real; ii++) { @@ -336,14 +349,6 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // Fail fast: check atomic virial availability before running the model - if (atomic && !do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } - // Run the .pt2 model auto flat_outputs = run_model(coord_Tensor, atype_Tensor, firstneigh_tensor, mapping_tensor, fparam_tensor, aparam_tensor); @@ -450,6 +455,16 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { + // Fail fast before allocating any tensors (same check as the nlist + // overload — see its comment). + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial was requested (e.g. by LAMMPS compute */atom/virial) " + "but this .pt2 model was exported without it (metadata field " + "do_atomic_virial=False). Atomic virial adds ~2.5x inference cost " + "and is off by default for .pt2. To enable it, regenerate with: " + "dp convert-backend --atomic-virial INPUT.pth OUTPUT.pt2"); + } int natoms = atype.size(); int nframes = coord.size() / (natoms * 3); if (nframes > 1) { @@ -595,19 +610,11 @@ void DeepPotPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // 5. Fail fast: check atomic virial availability before running the model - if (atomic && !do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } - - // 6. Run the .pt2 model + // 5. Run the .pt2 model auto flat_outputs = run_model(coord_Tensor, atype_Tensor, nlist_tensor, mapping_tensor, fparam_tensor, aparam_tensor); - // 7. Map flat outputs to internal keys + // 6. Map flat outputs to internal keys std::map output_map; extract_outputs(output_map, flat_outputs); diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index afeeab7f05..dcd7df55a4 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -242,6 +242,15 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { + // Fail fast before allocating any tensors. + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial was requested (e.g. by LAMMPS compute */atom/virial) " + "but this .pt2 model was exported without it (metadata field " + "do_atomic_virial=False). Atomic virial adds ~2.5x inference cost " + "and is off by default for .pt2. To enable it, regenerate with: " + "dp convert-backend --atomic-virial INPUT.pth OUTPUT.pt2"); + } torch::Device device(torch::kCUDA, gpu_id); if (!gpu_enabled) { device = torch::Device(torch::kCPU); @@ -293,12 +302,15 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, .clone() .to(device); + // LAMMPS sets ago=0 on every nlist rebuild, so ago>0 implies the cached + // mapping and nlist tensors are still valid — see DeepPotPTExpt.cc for + // the same rationale. if (ago == 0) { nlist_data.copy_from_nlist(lmp_list, nall - nghost); nlist_data.shuffle_exclude_empty(fwd_map); nlist_data.padding(); - // Rebuild mapping tensor only when nlist is updated (ago == 0). + // Rebuild mapping tensor if (lmp_list.mapping) { std::vector mapping(nall_real); for (int ii = 0; ii < nall_real; ii++) { @@ -364,14 +376,6 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // Fail fast: check atomic virial availability before running the model - if (atomic && !do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } - // Run the .pt2 model (7 args for spin) auto flat_outputs = run_model(coord_Tensor, atype_Tensor, spin_Tensor, firstneigh_tensor, @@ -494,6 +498,15 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, const std::vector& fparam, const std::vector& aparam, const bool atomic) { + // Fail fast before allocating any tensors. + if (atomic && !do_atomic_virial) { + throw deepmd::deepmd_exception( + "Atomic virial was requested (e.g. by LAMMPS compute */atom/virial) " + "but this .pt2 model was exported without it (metadata field " + "do_atomic_virial=False). Atomic virial adds ~2.5x inference cost " + "and is off by default for .pt2. To enable it, regenerate with: " + "dp convert-backend --atomic-virial INPUT.pth OUTPUT.pt2"); + } int natoms = atype.size(); torch::Device device(torch::kCUDA, gpu_id); @@ -648,20 +661,12 @@ void DeepSpinPTExpt::compute(ENERGYVTYPE& ener, aparam_tensor = torch::zeros({0}, options).to(device); } - // 5. Fail fast: check atomic virial availability before running the model - if (atomic && !do_atomic_virial) { - throw deepmd::deepmd_exception( - "Atomic virial is not available in this .pt2 model " - "(exported without --atomic-virial). " - "Regenerate with: dp convert-backend --atomic-virial INPUT OUTPUT"); - } - - // 6. Run the .pt2 model (7 args for spin) + // 5. Run the .pt2 model (7 args for spin) auto flat_outputs = run_model(coord_Tensor, atype_Tensor, spin_Tensor, nlist_tensor, mapping_tensor, fparam_tensor, aparam_tensor); - // 7. Extract outputs + // 6. Extract outputs std::map output_map; extract_outputs(output_map, flat_outputs); From 72f95f87e135e1b2c2076a31bcf11b4ef42fc094 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:54:27 +0800 Subject: [PATCH 53/56] fix(cc): guard empty-vector dereferences in NeighborListData and convert_nlist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two latent UB bugs surface as libstdc++ debug-mode assertions when the caller hands in a neighbor list whose every row is empty (a fully-dissociated subdomain — realistic under aggressive spatial partitioning): * deepmd::NeighborListData::copy_from_nlist: '&jlist[ii][0]' on an empty row, even with 'memcpy(..., 0)' as the size. Guard the memcpy + masking with 'if (jnum > 0)'. Same fix for the outer '&ilist[0]' when 'inum == 0'. * deepmd::convert_nlist: '&from_nlist[ii][0]' on an empty row. Replace with vector::data(), which is well-defined for empty vectors (may return nullptr; the receiver only dereferences when numneigh[ii] > 0 anyway). Also remove the recently-added 'if (nnei == 0) nnei = 1;' clamp from createNlistTensor in commonPT.h. PyTorch accepts zero-sized dimensions and the eager '_format_nlist' pad branch handles them; the clamp was a defensive no-op that masked the actual UB. Update the docstring to clarify the contract. Add a regression test (cpu_lmp_nlist_empty_subdomain) that constructs an explicitly-empty InputNlist and verifies the .pt2 path runs cleanly: finite energy, exactly-zero force/virial (no interactions), no crash. These fixes affect all C++ backends that go through copy_from_nlist or convert_nlist (TF, PT, JAX, PD, and pt_expt), so they can be cherry-picked in isolation. --- source/api_cc/include/commonPT.h | 10 +++-- source/api_cc/src/common.cc | 16 ++++++-- source/api_cc/tests/test_deeppot_ptexpt.cc | 46 ++++++++++++++++++++++ source/lib/src/neighbor_list.cc | 6 ++- 4 files changed, 70 insertions(+), 8 deletions(-) diff --git a/source/api_cc/include/commonPT.h b/source/api_cc/include/commonPT.h index 20c3d5c1f0..61b823ae62 100644 --- a/source/api_cc/include/commonPT.h +++ b/source/api_cc/include/commonPT.h @@ -109,6 +109,13 @@ inline void build_comm_dict_with_virtual_atoms( * No truncation or distance sorting is done — the model's format_nlist * handles that on-device. * + * If @p min_nnei is 0 (the default used by the .pth callers) and every row + * is empty (no atom has any neighbor — fully-dissociated system), the + * output shape is [1, nloc, 0]. PyTorch accepts zero-sized dimensions, and + * the eager `_format_nlist` pads it back up to sum(sel). .pt2 callers + * always pass @p min_nnei = sum(sel) > 0, so the output width is at least + * sum(sel) for them. + * * @param data Jagged neighbor list: data[i] holds neighbor indices * for local atom i. * @param min_nnei Minimum width of the nnei dimension. For .pt2 models @@ -123,9 +130,6 @@ inline torch::Tensor createNlistTensor( for (int ii = 0; ii < nloc; ++ii) { nnei = std::max(nnei, static_cast(data[ii].size())); } - if (nnei == 0) { - nnei = 1; // at least 1 column to avoid empty tensor - } std::vector flat_data(static_cast(nloc) * nnei, -1); for (int ii = 0; ii < nloc; ++ii) { for (size_t jj = 0; jj < data[ii].size(); ++jj) { diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index b446417625..02bd0ea243 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -276,13 +276,21 @@ void deepmd::NeighborListData::copy_from_nlist(const InputNlist& inlist, int inum = natoms >= 0 ? natoms : inlist.inum; ilist.resize(inum); jlist.resize(inum); - memcpy(&ilist[0], inlist.ilist, inum * sizeof(int)); + if (inum > 0) { + memcpy(&ilist[0], inlist.ilist, inum * sizeof(int)); + } for (int ii = 0; ii < inum; ++ii) { int jnum = inlist.numneigh[ii]; jlist[ii].resize(jnum); - memcpy(&jlist[ii][0], inlist.firstneigh[ii], jnum * sizeof(int)); - for (int jj = 0; jj < jnum; ++jj) { - jlist[ii][jj] &= inlist.mask; + // Guard against empty jlist[ii]: `&vec[0]` is undefined behaviour for + // empty vectors and libstdc++ debug mode asserts on it. This happens + // when a subdomain's local atoms legitimately have zero neighbours + // within cutoff (e.g. under spatial partitioning). + if (jnum > 0) { + memcpy(&jlist[ii][0], inlist.firstneigh[ii], jnum * sizeof(int)); + for (int jj = 0; jj < jnum; ++jj) { + jlist[ii][jj] &= inlist.mask; + } } } } diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index 7369d27708..ec7f32222c 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -401,6 +401,52 @@ TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_oversized) { } } +// Edge case: a subdomain whose every local atom has zero neighbors within +// cutoff (e.g. under aggressive spatial partitioning on a small/sparse +// subdomain, or at the start of a simulation before atoms settle in). +// The C++ side builds an InputNlist with empty rows, which feeds +// `createNlistTensor` with min_nnei=sum(sel); the compiled .pt2 graph must +// then run cleanly on an all-`-1` nlist and produce a sensible (finite, +// interaction-free) result. Verifies both that the code path doesn't +// crash and that the forces/virial collapse to zero (no interactions). +TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_empty_subdomain) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + int& natoms = this->natoms; + deepmd::DeepPot& dp = this->dp; + + // Pass coord/atype as-is; the model sees them, but `dp.compute` only + // uses the provided InputNlist for neighbor information. + int nall = natoms; + std::vector > nlist_data(natoms); // every row empty + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force(nall * 3, 0.0), virial(9, 0.0); + // Must not throw: zero-neighbor input is legal and expected under some + // spatial-partitioning configurations. + ASSERT_NO_THROW( + dp.compute(ener, force, virial, coord, atype, box, 0, inlist, 0)); + EXPECT_EQ(force.size(), natoms * 3); + EXPECT_EQ(virial.size(), 9); + EXPECT_TRUE(std::isfinite(ener)); + // With no neighbors, interaction forces and virial must be exactly zero + // (the descriptor sees only -1 entries, so pair contributions vanish). + for (int ii = 0; ii < natoms * 3; ++ii) { + EXPECT_TRUE(std::isfinite(force[ii])); + EXPECT_LT(fabs(force[ii]), EPSILON); + } + for (int ii = 0; ii < 9; ++ii) { + EXPECT_TRUE(std::isfinite(virial[ii])); + EXPECT_LT(fabs(virial[ii]), EPSILON); + } +} + TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_type_sel) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/lib/src/neighbor_list.cc b/source/lib/src/neighbor_list.cc index 6723e3de66..3a0d8eb122 100644 --- a/source/lib/src/neighbor_list.cc +++ b/source/lib/src/neighbor_list.cc @@ -852,7 +852,11 @@ void deepmd::convert_nlist(InputNlist& to_nlist, for (int ii = 0; ii < to_nlist.inum; ++ii) { to_nlist.ilist[ii] = ii; to_nlist.numneigh[ii] = from_nlist[ii].size(); - to_nlist.firstneigh[ii] = &from_nlist[ii][0]; + // `&vec[0]` is undefined behaviour for empty vectors (libstdc++ debug + // mode asserts on it). When numneigh[ii] is 0 the pointer is never + // dereferenced; use vector::data() which is well-defined for empty + // vectors (may return nullptr). + to_nlist.firstneigh[ii] = from_nlist[ii].data(); } } From 43d42e6abaee36e3fe225a0c83475f73af3764ef Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 11:56:11 +0800 Subject: [PATCH 54/56] test(cc): tighten oversized-nlist sanity check; verify all predictions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 'cpu_lmp_nlist_oversized' tests rely on '_pad_shuffle_nlist' to distribute real neighbors past the sum(sel) boundary, but with a fixed seed it's possible for the shuffle to keep real neighbors at the front — making the oversized test a tautology. Add a sanity check that truncates the shuffled oversized nlist to one column and asserts that at least one prediction (energy / force / [force_mag] / virial) deviates from the reference by more than SANITY_MIN_DEV. The new sanity check uses a separate threshold (SANITY_MIN_DEV = 1e-3) rather than EPSILON. EPSILON is a numerical-precision tolerance for positive assertions ("model output matches reference within EPSILON"); the sanity check is a physical-signal tolerance ("a meaningful deviation exceeds SANITY_MIN_DEV"). Conflating the two thresholds either makes positive assertions too loose (hides real model regressions) or makes negative assertions sensitive to numerical noise. Independent of the sanity check, also tighten EPSILON in test_deeppot_dpa_ptexpt_spin.cc from 1e-6/1e-1 (double/float) to the project-standard 1e-10/1e-4. All 18 existing spin tests pass at the standard tolerance — the previous local override was an unjustified relaxation. --- .../tests/test_deeppot_dpa_ptexpt_spin.cc | 60 +++++++++++++++- source/api_cc/tests/test_deeppot_ptexpt.cc | 69 +++++++++++++++++++ 2 files changed, 128 insertions(+), 1 deletion(-) diff --git a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc index 95e4b3e48c..5f5ae7a6b6 100644 --- a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc +++ b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc @@ -14,7 +14,7 @@ // Spin models need relaxed epsilon #undef EPSILON -#define EPSILON (std::is_same::value ? 1e-6 : 1e-1) +#define EPSILON (std::is_same::value ? 1e-10 : 1e-4) // ============================================================================ // PBC test fixture @@ -513,3 +513,61 @@ TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, cpu_lmp_nlist_oversized) { EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON); } } + +// Sanity check that the shuffle in `_pad_shuffle_nlist` actually redistributes +// real neighbors away from the front of each row — see the non-spin version +// in test_deeppot_ptexpt.cc for the rationale. At least one prediction +// (energy / force / force_mag / virial) must deviate from the reference by +// more than EPSILON when the nlist is truncated to a single column. +TYPED_TEST(TestInferDeepSpinDpaPtExptNopbc, + cpu_lmp_nlist_oversized_shuffle_sanity) { + using VALUETYPE = TypeParam; + const std::vector& coord = this->coord; + const std::vector& spin = this->spin; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_f = this->expected_f; + std::vector& expected_fm = this->expected_fm; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepSpin& dp = this->dp; + double ener; + std::vector force, force_mag, virial; + + std::vector > nlist_data = { + {1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5}, + {0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}}; + std::vector > nlist_oversized; + _pad_shuffle_nlist(nlist_oversized, nlist_data, 50); + for (auto& row : nlist_oversized) { + row.resize(1); + } + std::vector ilist(natoms), numneigh(natoms); + std::vector firstneigh(natoms); + deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_oversized); + dp.compute(ener, force, force_mag, virial, coord, spin, atype, box, 0, inlist, + 0); + + // See test_deeppot_ptexpt.cc for the rationale of the separate negative- + // direction threshold. For the spin nopbc test the strongest signal + // is in force_mag (max magnitude ~0.97); 1e-3 is well below all + // expected signal magnitudes and well above the float precision floor. + static constexpr double SANITY_MIN_DEV = 1e-3; + bool any_deviates = fabs(ener - expected_tot_e) > SANITY_MIN_DEV; + for (int ii = 0; ii < natoms * 3 && !any_deviates; ++ii) { + any_deviates = fabs(force[ii] - expected_f[ii]) > SANITY_MIN_DEV; + } + for (int ii = 0; ii < natoms * 3 && !any_deviates; ++ii) { + any_deviates = fabs(force_mag[ii] - expected_fm[ii]) > SANITY_MIN_DEV; + } + for (int ii = 0; ii < 9 && !any_deviates; ++ii) { + any_deviates = fabs(virial[ii] - expected_tot_v[ii]) > SANITY_MIN_DEV; + } + EXPECT_TRUE(any_deviates) + << "Every prediction stayed within SANITY_MIN_DEV after single-column " + "truncation of the shuffled oversized nlist — the shuffle appears " + "to have kept real neighbors at the front. The oversized test " + "above may be a tautology; increase n_extra in _pad_shuffle_nlist."; +} diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index ec7f32222c..5834911526 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -401,6 +401,75 @@ TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_oversized) { } } +// Sanity check for the oversized-nlist test above: the Python counterpart +// asserts that naïvely truncating the shuffled nlist (without distance +// sorting) produces a DIFFERENT prediction, proving that the shuffle did +// move real neighbors past the sum(sel) boundary. The C++ API always +// sorts on-device, so we verify the equivalent property by truncating the +// shuffled nlist to a single column: with the shuffle ratio (~5 real +// entries out of 205), almost every row ends up holding a -1, and the +// model's prediction must deviate from the reference on at least one of +// {energy, force, virial}. If every prediction still matches the +// reference within EPSILON, the shuffle is not meaningful and the +// oversized test above is a tautology. +TYPED_TEST(TestInferDeepPotAPtExpt, cpu_lmp_nlist_oversized_shuffle_sanity) { + using VALUETYPE = TypeParam; + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + std::vector& expected_f = this->expected_f; + std::vector& expected_tot_v = this->expected_tot_v; + int& natoms = this->natoms; + double& expected_tot_e = this->expected_tot_e; + deepmd::DeepPot& dp = this->dp; + float rc = dp.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + std::vector > nlist_oversized; + _pad_shuffle_nlist(nlist_oversized, nlist_data, 200); + for (auto& row : nlist_oversized) { + row.resize(1); + } + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_oversized); + + double ener; + std::vector force_(nall * 3, 0.0), virial(9, 0.0); + dp.compute(ener, force_, virial, coord_cpy, atype_cpy, box, nall - nloc, + inlist, 0); + std::vector force; + _fold_back(force, force_, mapping, nloc, nall, 3); + + // Threshold for the negative-direction check is intentionally separate + // from EPSILON. EPSILON is a numerical-precision tolerance for positive + // assertions ("model matches reference"); here we want a physical-signal + // tolerance ("broken model produces a meaningful deviation"). 1e-3 is + // 100x below the smallest expected force magnitude (~1e-1) and well + // above the float precision noise floor (~1e-5), so it cleanly + // distinguishes "actually broken" from "noise". + static constexpr double SANITY_MIN_DEV = 1e-3; + bool any_deviates = fabs(ener - expected_tot_e) > SANITY_MIN_DEV; + for (int ii = 0; ii < natoms * 3 && !any_deviates; ++ii) { + any_deviates = fabs(force[ii] - expected_f[ii]) > SANITY_MIN_DEV; + } + for (int ii = 0; ii < 9 && !any_deviates; ++ii) { + any_deviates = fabs(virial[ii] - expected_tot_v[ii]) > SANITY_MIN_DEV; + } + EXPECT_TRUE(any_deviates) + << "Every prediction (energy/force/virial) stayed within " + "SANITY_MIN_DEV after single-column truncation of the shuffled " + "oversized nlist — the shuffle appears to have kept real neighbors " + "at the front. The oversized test above may be a tautology; " + "increase n_extra in _pad_shuffle_nlist."; +} + // Edge case: a subdomain whose every local atom has zero neighbors within // cutoff (e.g. under aggressive spatial partitioning on a small/sparse // subdomain, or at the start of a simulation before atoms settle in). From e30206de4fc93fb391e7aaf307d322c73baa3a44 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 12:39:16 +0800 Subject: [PATCH 55/56] fix(cc): apply empty-vector guard to make_inlist and DeepSpinTF; add backend-agnostic tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two more sites with the same '&vec[0]' UB pattern that the previous commit (72f95f87) fixed in convert_nlist and copy_from_nlist: * deepmd::NeighborListData::make_inlist (common.cc): 'firstneigh[ii] = &jlist[ii][0]' — UB if jlist[ii] is empty. Also apply the same guard to the trailing 'inlist.{ilist,numneigh,firstneigh} = &vec[0]' calls for safety. * DeepSpinTF.cc:1124: 'extend_firstneigh[ii] = &extend_neigh[ii][0]' — UB if extend_neigh[ii] is empty. Both replaced with vector::data(), which is well-defined for empty vectors (downstream consumers gate on numneigh > 0 before dereferencing). Both call sites are TF-only (make_inlist is invoked from DeepPotTF, DeepTensorTF, DataModifierTF, DeepSpinTF; the extend_firstneigh site lives entirely in DeepSpinTF), so they are reachable under the same fully-dissociated-subdomain conditions that motivated 72f95f87 but through TF paths rather than the pt_expt path. Add a backend-agnostic test file (test_neighbor_list_data.cc) covering all four UB sites: * MakeInlistEmptyRows — every jlist[ii] empty (this commit). * MakeInlistMixedEmptyAndNonemptyRows — partial-empty scenario. * ConvertNlistEmptyRows — fix from 72f95f87. * CopyFromNlistEmptyRows — fix from 72f95f87. * CopyFromNlistInumZero — fix from 72f95f87 (inum == 0 guard). * RoundTripWithEmptyRows — end-to-end coverage of all paths. The tests exercise NeighborListData and convert_nlist directly with no model dependency, so they run regardless of which backend (TF/PT/JAX/PD/PT_EXPT) is enabled. --- source/api_cc/src/DeepSpinTF.cc | 4 +- source/api_cc/src/common.cc | 12 +- .../api_cc/tests/test_neighbor_list_data.cc | 139 ++++++++++++++++++ 3 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 source/api_cc/tests/test_neighbor_list_data.cc diff --git a/source/api_cc/src/DeepSpinTF.cc b/source/api_cc/src/DeepSpinTF.cc index 7f569ea231..091fca64c2 100644 --- a/source/api_cc/src/DeepSpinTF.cc +++ b/source/api_cc/src/DeepSpinTF.cc @@ -1121,7 +1121,9 @@ void DeepSpinTF::extend(int& extend_inum, extend_firstneigh.resize(extend_nloc); extend_numneigh.resize(extend_nloc); for (int ii = 0; ii < extend_nloc; ii++) { - extend_firstneigh[ii] = &extend_neigh[ii][0]; + // `&vec[0]` is undefined behaviour for empty vectors; use + // vector::data() instead. See convert_nlist for the same fix. + extend_firstneigh[ii] = extend_neigh[ii].data(); extend_numneigh[ii] = extend_neigh[ii].size(); } diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 02bd0ea243..7aae726242 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -361,12 +361,16 @@ void deepmd::NeighborListData::make_inlist(InputNlist& inlist) { firstneigh.resize(nloc); for (int ii = 0; ii < nloc; ++ii) { numneigh[ii] = jlist[ii].size(); - firstneigh[ii] = &jlist[ii][0]; + // `&vec[0]` is undefined behaviour for empty vectors (libstdc++ + // debug mode asserts on it). When numneigh[ii] is 0 the pointer is + // never dereferenced; use vector::data() which is well-defined for + // empty vectors. Mirrors the fix in convert_nlist. + firstneigh[ii] = jlist[ii].data(); } inlist.inum = nloc; - inlist.ilist = &ilist[0]; - inlist.numneigh = &numneigh[0]; - inlist.firstneigh = &firstneigh[0]; + inlist.ilist = ilist.data(); + inlist.numneigh = numneigh.data(); + inlist.firstneigh = firstneigh.data(); } #ifdef BUILD_TENSORFLOW diff --git a/source/api_cc/tests/test_neighbor_list_data.cc b/source/api_cc/tests/test_neighbor_list_data.cc new file mode 100644 index 0000000000..3e5198d30d --- /dev/null +++ b/source/api_cc/tests/test_neighbor_list_data.cc @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: LGPL-3.0-or-later +// Backend-agnostic unit tests for deepmd::NeighborListData and the +// deepmd::convert_nlist helper. Exercises edge cases (empty rows, empty +// nlist) that surfaced as latent libstdc++-debug-mode UB assertions in +// production code paths. +#include + +#include + +#include "common.h" +#include "neighbor_list.h" + +namespace deepmd { + +// Build a NeighborListData with @p nloc local atoms and zero neighbors per +// atom. Realistic under aggressive spatial partitioning where a subdomain's +// every local atom has no neighbors within cutoff. +TEST(TestNeighborListData, MakeInlistEmptyRows) { + NeighborListData data; + const int nloc = 4; + data.ilist.resize(nloc); + for (int ii = 0; ii < nloc; ++ii) { + data.ilist[ii] = ii; + } + data.jlist.resize(nloc); // every row default-constructed to empty + + // Must not trigger UB ('&vec[0]' on empty vector) under libstdc++ debug. + InputNlist inlist; + ASSERT_NO_THROW(data.make_inlist(inlist)); + + EXPECT_EQ(inlist.inum, nloc); + ASSERT_NE(inlist.numneigh, nullptr); + ASSERT_NE(inlist.firstneigh, nullptr); + for (int ii = 0; ii < nloc; ++ii) { + EXPECT_EQ(inlist.numneigh[ii], 0); + } +} + +TEST(TestNeighborListData, MakeInlistMixedEmptyAndNonemptyRows) { + NeighborListData data; + data.ilist = {0, 1, 2}; + data.jlist.resize(3); + // row 0: empty (legitimate edge case) + data.jlist[1] = {7, 8}; + // row 2: empty + InputNlist inlist; + ASSERT_NO_THROW(data.make_inlist(inlist)); + EXPECT_EQ(inlist.numneigh[0], 0); + EXPECT_EQ(inlist.numneigh[1], 2); + EXPECT_EQ(inlist.numneigh[2], 0); + // Only the populated row's firstneigh should be dereferenced. + EXPECT_EQ(inlist.firstneigh[1][0], 7); + EXPECT_EQ(inlist.firstneigh[1][1], 8); +} + +// convert_nlist(jagged) must not dereference an empty row when populating +// firstneigh. Regression test for the same `&vec[0]` UB pattern fixed in +// commit 72f95f87. +TEST(TestNeighborListData, ConvertNlistEmptyRows) { + std::vector> input = {{}, {}, {}}; // all rows empty + std::vector ilist(input.size()), numneigh(input.size()); + std::vector firstneigh(input.size()); + InputNlist out(static_cast(input.size()), ilist.data(), numneigh.data(), + firstneigh.data()); + ASSERT_NO_THROW(convert_nlist(out, input)); + EXPECT_EQ(out.inum, 3); + for (int ii = 0; ii < 3; ++ii) { + EXPECT_EQ(out.numneigh[ii], 0); + // firstneigh[ii] may be vector::data()'s sentinel or nullptr — must not + // be dereferenced when numneigh[ii] == 0. + } +} + +// copy_from_nlist must not dereference an empty source row even when +// memcpy size is 0. Regression test for the same UB pattern. +TEST(TestNeighborListData, CopyFromNlistEmptyRows) { + // Build an InputNlist with all empty rows. + const int nloc = 4; + std::vector src_ilist(nloc), src_numneigh(nloc, 0); + std::vector src_firstneigh(nloc, nullptr); + for (int ii = 0; ii < nloc; ++ii) { + src_ilist[ii] = ii; + } + InputNlist src(nloc, src_ilist.data(), src_numneigh.data(), + src_firstneigh.data()); + src.mask = ~0; // identity mask, must not be applied to absent neighbors + + NeighborListData data; + ASSERT_NO_THROW(data.copy_from_nlist(src)); + EXPECT_EQ(static_cast(data.ilist.size()), nloc); + EXPECT_EQ(static_cast(data.jlist.size()), nloc); + for (int ii = 0; ii < nloc; ++ii) { + EXPECT_TRUE(data.jlist[ii].empty()); + } +} + +// copy_from_nlist with an empty source list (inum == 0) must not +// dereference '&ilist[0]' on the empty target ilist. +TEST(TestNeighborListData, CopyFromNlistInumZero) { + InputNlist src; + src.inum = 0; + src.ilist = nullptr; + src.numneigh = nullptr; + src.firstneigh = nullptr; + src.mask = ~0; + + NeighborListData data; + ASSERT_NO_THROW(data.copy_from_nlist(src)); + EXPECT_TRUE(data.ilist.empty()); + EXPECT_TRUE(data.jlist.empty()); +} + +// Round-trip: convert_nlist(jagged) → copy_from_nlist → make_inlist +// must preserve both empty and non-empty rows without UB. +TEST(TestNeighborListData, RoundTripWithEmptyRows) { + std::vector> input = {{}, {3, 4}, {}, {5}}; + std::vector ilist(input.size()), numneigh(input.size()); + std::vector firstneigh(input.size()); + InputNlist src(static_cast(input.size()), ilist.data(), numneigh.data(), + firstneigh.data()); + ASSERT_NO_THROW(convert_nlist(src, input)); + + NeighborListData data; + ASSERT_NO_THROW(data.copy_from_nlist(src)); + EXPECT_EQ(static_cast(data.jlist.size()), 4); + EXPECT_TRUE(data.jlist[0].empty()); + EXPECT_EQ(data.jlist[1], (std::vector{3, 4})); + EXPECT_TRUE(data.jlist[2].empty()); + EXPECT_EQ(data.jlist[3], (std::vector{5})); + + InputNlist out; + ASSERT_NO_THROW(data.make_inlist(out)); + EXPECT_EQ(out.numneigh[0], 0); + EXPECT_EQ(out.numneigh[1], 2); + EXPECT_EQ(out.numneigh[2], 0); + EXPECT_EQ(out.numneigh[3], 1); +} + +} // namespace deepmd From 9d2f577190905c52f874bf205d5f1e54064dc490 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 25 Apr 2026 12:55:56 +0800 Subject: [PATCH 56/56] test(cc): regression test for atomic_virial fail-fast guard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes the coverage gap left by commit c80db58d: the 'atomic && !do_atomic_virial' throw branch in DeepPotPTExpt::compute was not exercised by any existing test. Approach: * gen_sea.py: after exporting deeppot_sea.pt2 with do_atomic_virial= True, copy the archive to deeppot_sea_no_atomic_virial.pt2 and patch its metadata.json so do_atomic_virial=False. Cheap (ZIP rewrite, no AOTInductor recompile — adds <1s to gen time). Use ZIP_STORED to match the format expected by the C++ read_zip_entry in commonPTExpt.h. * test_deeppot_ptexpt.cc: new TYPED_TEST cpu_atomic_throws_when_ disabled — load the patched .pt2, call compute() with atomic=true, expect deepmd_exception. Also verify atomic=false on the same model still works (sanity check that the guard fires only when actually requested). Covers the most-impactful gap from PR #5407's review: the change in .pt2 default (do_atomic_virial off) made this code path the typical failure mode for users who'd previously been getting per-atom virial "for free". Without this test, regressions in the throw or in the metadata round-trip would slip through. --- source/api_cc/tests/test_deeppot_ptexpt.cc | 46 ++++++++++++++++++++++ source/tests/infer/gen_sea.py | 42 ++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index 5834911526..201724a725 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -581,6 +581,52 @@ TYPED_TEST(TestInferDeepPotAPtExpt, print_summary) { dp.print_summary(""); } +// Regression test for the fail-fast guard hoisted in commit c80db58d. +// `deeppot_sea_no_atomic_virial.pt2` is a copy of deeppot_sea.pt2 with +// the do_atomic_virial=false flag patched into its metadata.json. +// Calling compute() with atomic=true on this model must throw before +// any tensors are allocated. +TYPED_TEST(TestInferDeepPotAPtExpt, cpu_atomic_throws_when_disabled) { + using VALUETYPE = TypeParam; + deepmd::DeepPot dp_no_av; + ASSERT_NO_THROW( + dp_no_av.init("../../tests/infer/deeppot_sea_no_atomic_virial.pt2")); + + std::vector& coord = this->coord; + std::vector& atype = this->atype; + std::vector& box = this->box; + int& natoms = this->natoms; + + // Build an LMP-style nlist so we exercise the nlist-overload of + // compute(); the no-nlist overload has the same guard but is + // covered by symmetry. + float rc = dp_no_av.cutoff(); + int nloc = coord.size() / 3; + std::vector coord_cpy; + std::vector atype_cpy, mapping; + std::vector > nlist_data; + _build_nlist(nlist_data, coord_cpy, atype_cpy, mapping, coord, + atype, box, rc); + int nall = coord_cpy.size() / 3; + std::vector ilist(nloc), numneigh(nloc); + std::vector firstneigh(nloc); + deepmd::InputNlist inlist(nloc, &ilist[0], &numneigh[0], &firstneigh[0]); + convert_nlist(inlist, nlist_data); + + double ener; + std::vector force(nall * 3, 0.0), virial(9, 0.0), atom_ener, + atom_vir; + // atomic=true => guard must trip and throw deepmd_exception. + EXPECT_THROW( + dp_no_av.compute(ener, force, virial, atom_ener, atom_vir, coord_cpy, + atype_cpy, box, nall - nloc, inlist, 0), + deepmd::deepmd_exception); + // atomic=false on the same model must work normally (sanity check + // that the guard fires only when actually requested). + EXPECT_NO_THROW(dp_no_av.compute(ener, force, virial, coord_cpy, atype_cpy, + box, nall - nloc, inlist, 0)); +} + template class TestInferDeepPotAPtExptNoPbc : public ::testing::Test { protected: diff --git a/source/tests/infer/gen_sea.py b/source/tests/infer/gen_sea.py index 67f8679175..02f4e7ee63 100644 --- a/source/tests/infer/gen_sea.py +++ b/source/tests/infer/gen_sea.py @@ -59,8 +59,50 @@ def main(): print(f"Exporting to {pt2_path} ...") # noqa: T201 deserialize_to_file(pt2_path, data, do_atomic_virial=True) + # Produce a variant for regression-testing the C++ "atomic && + # !do_atomic_virial" throw path by copying the .pt2 archive and + # flipping the do_atomic_virial flag in its metadata.json — much + # cheaper than running a second AOTInductor compile. The compiled + # graph itself supports atomic virial; only the C++ guard differs. + import shutil + + pt2_no_aviral = os.path.join(base_dir, "deeppot_sea_no_atomic_virial.pt2") + print(f"Patching to {pt2_no_aviral} ...") # noqa: T201 + shutil.copyfile(pt2_path, pt2_no_aviral) + _patch_no_atomic_virial(pt2_no_aviral) + print("Done!") # noqa: T201 +def _patch_no_atomic_virial(pt2_path: str) -> None: + """Flip do_atomic_virial=False in the metadata.json of a .pt2 archive. + + The .pt2 is a ZIP archive; the metadata blob lives at + ``extra/metadata.json``. We rewrite the archive with that one entry + replaced and all other entries preserved verbatim. + """ + import json + import zipfile + + metadata_name = "extra/metadata.json" + tmp_path = pt2_path + ".tmp" + # PyTorch .pt2 archives use ZIP_STORED (uncompressed) so that the C++ + # reader (read_zip_entry in commonPTExpt.h) and torch's mmap-based + # tensor loader can read entries without decompression. Preserve + # that on rewrite — using ZIP_DEFLATED would yield bytes the C++ + # reader treats as raw, resulting in JSON parse errors. + with zipfile.ZipFile(pt2_path, "r") as src: + names = src.namelist() + meta = json.loads(src.read(metadata_name).decode("utf-8")) + meta["do_atomic_virial"] = False + with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_STORED) as dst: + for name in names: + if name == metadata_name: + dst.writestr(name, json.dumps(meta)) + else: + dst.writestr(name, src.read(name)) + os.replace(tmp_path, pt2_path) + + if __name__ == "__main__": main()