diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index ded716bd15..a446bde06f 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -84,6 +84,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return self.descriptor.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return self.descriptor.need_sorted_nlist_for_lower() + def forward_atomic( self, extended_coord: np.ndarray, diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index da5f8debe2..d522347f41 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -96,6 +96,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return any(model.has_message_passing() for model in self.models) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return True + def get_rcut(self) -> float: """Get the cut-off radius.""" return max(self.get_model_rcuts()) diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index ac6076a8e3..bf345eaa12 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -119,6 +119,10 @@ def mixed_types(self) -> bool: def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" + @abstractmethod + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + @abstractmethod def fwd( self, diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 8a8cbe5815..4218c24e3e 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -135,6 +135,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return False + def change_type_map( self, type_map: List[str], model_with_new_type_stat=None ) -> None: diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index 3e0ad13420..e48479cca8 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -132,6 +132,10 @@ def call( def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" + @abstractmethod + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + def extend_descrpt_stat(des, type_map, des_with_stat=None): r""" diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 360df6a591..70cb818eef 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -371,6 +371,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.se_atten.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.se_atten.need_sorted_nlist_for_lower() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.se_atten.get_env_protection() @@ -956,6 +960,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False + class NeighborGatedAttention(NativeOP): def __init__( diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 5fcf1e27b9..0de63bce4a 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -553,6 +553,10 @@ def has_message_passing(self) -> bool: [self.repinit.has_message_passing(), self.repformers.has_message_passing()] ) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 3b08426b13..4cd4e230ae 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -146,6 +146,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + def get_env_protection(self) -> float: """Returns the protection of building environment matrix. All descriptors should be the same.""" all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list] diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 49bf000248..6ce54c6f12 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -104,6 +104,10 @@ def mixed_types(self) -> bool: def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" + @abstractmethod + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + @abstractmethod def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index af286a35e7..bb84816d3d 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -401,6 +401,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False + # translated by GPT and modified def get_residual( diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 75ac11dbed..11856521c8 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -265,6 +265,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 4b89e1dd90..2d9f6f5a52 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -223,6 +223,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 72d8a24bd9..364600aa8b 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -215,6 +215,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index a130437b3d..ee4c1f035a 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -214,7 +214,12 @@ def call_lower( """ nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.reshape(nframes, -1, 3) - nlist = self.format_nlist(extended_coord, extended_atype, nlist) + nlist = self.format_nlist( + extended_coord, + extended_atype, + nlist, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) cc_ext, _, fp, ap, input_prec = self.input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) @@ -311,6 +316,7 @@ def format_nlist( extended_coord: np.ndarray, extended_atype: np.ndarray, nlist: np.ndarray, + extra_nlist_sort: bool = False, ): """Format the neighbor list. @@ -336,6 +342,8 @@ def format_nlist( atomic type in extended region. nf x nall nlist neighbor list. nf x nloc x nsel + extra_nlist_sort + whether to forcibly sort the nlist. Returns ------- @@ -345,7 +353,12 @@ def format_nlist( """ n_nf, n_nloc, n_nnei = nlist.shape mixed_types = self.mixed_types() - ret = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + ret = self._format_nlist( + extended_coord, + nlist, + sum(self.get_sel()), + extra_nlist_sort=extra_nlist_sort, + ) if not mixed_types: ret = nlist_distinguish_types(ret, extended_atype, self.get_sel()) return ret @@ -355,6 +368,7 @@ def _format_nlist( extended_coord: np.ndarray, nlist: np.ndarray, nnei: int, + extra_nlist_sort: bool = False, ): n_nf, n_nloc, n_nnei = nlist.shape extended_coord = extended_coord.reshape([n_nf, -1, 3]) @@ -370,7 +384,9 @@ def _format_nlist( ], axis=-1, ) - elif n_nnei > nnei: + + if n_nnei > nnei or extra_nlist_sort: + n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 ret = np.where(m_real_nei, nlist, 0) @@ -384,9 +400,11 @@ def _format_nlist( ret = np.take_along_axis(ret, ret_mapping, axis=2) ret = np.where(rr > rcut, -1, ret) ret = ret[..., :nnei] - else: # n_nnei == nnei: - # copy anyway... + # not extra_nlist_sort and n_nnei <= nnei: + elif n_nnei == nnei: ret = nlist + else: + pass assert ret.shape[-1] == nnei return ret @@ -483,6 +501,10 @@ def has_message_passing(self) -> bool: """Returns whether the model has message passing.""" return self.atomic_model.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.atomic_model.need_sorted_nlist_for_lower() + def atomic_output_def(self) -> FittingOutputDef: """Get the output def of the atomic model.""" return self.atomic_model.atomic_output_def() diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 549a6dcaee..8def2e48de 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -118,6 +118,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return self.descriptor.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return self.descriptor.need_sorted_nlist_for_lower() + def serialize(self) -> dict: dd = BaseAtomicModel.serialize(self) dd.update( diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index d068066306..3c7692212e 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -108,6 +108,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return any(model.has_message_passing() for model in self.models) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return True + def get_out_bias(self) -> torch.Tensor: return self.out_bias diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index c4f293bd13..7ef87524dd 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -164,6 +164,10 @@ def has_message_passing(self) -> bool: """Returns whether the atomic model has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the atomic model needs sorted nlist when using `forward_lower`.""" + return False + def change_type_map( self, type_map: List[str], model_with_new_type_stat=None ) -> None: diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 6f67877350..16c3d96301 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -175,6 +175,10 @@ def forward( def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" + @abstractmethod + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + def make_default_type_embedding( ntypes, diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index 0bc4a03807..14767cb100 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -364,6 +364,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.se_atten.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.se_atten.need_sorted_nlist_for_lower() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.se_atten.get_env_protection() diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 4d830ace1b..7e5262e275 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -318,6 +318,10 @@ def has_message_passing(self) -> bool: [self.repinit.has_message_passing(), self.repformers.has_message_passing()] ) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" # the env_protection of repinit is the same as that of the repformer diff --git a/deepmd/pt/model/descriptor/hybrid.py b/deepmd/pt/model/descriptor/hybrid.py index 41fb5e68e3..7156396c48 100644 --- a/deepmd/pt/model/descriptor/hybrid.py +++ b/deepmd/pt/model/descriptor/hybrid.py @@ -153,6 +153,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return any(descrpt.has_message_passing() for descrpt in self.descrpt_list) + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return True + def get_env_protection(self) -> float: """Returns the protection of building environment matrix. All descriptors should be the same.""" all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list] diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bca6fa6eec..bc8c331ec3 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -557,3 +557,7 @@ def get_stats(self) -> Dict[str, StatItem]: def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return True + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index ee05c3e613..44564a6fd3 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -160,6 +160,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.sea.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.sea.need_sorted_nlist_for_lower() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.sea.get_env_protection() @@ -712,3 +716,7 @@ def forward( def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index 2d182b7ee2..92d6e223e4 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -591,6 +591,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False + class NeighborGatedAttention(nn.Module): def __init__( diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 3ff74bfb22..da8d422444 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -183,6 +183,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return False + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return False + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.env_protection diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index ea9824127a..5e7e507fbf 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -189,6 +189,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.seat.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor needs sorted nlist when using `forward_lower`.""" + return self.seat.need_sorted_nlist_for_lower() + def get_env_protection(self) -> float: """Returns the protection of building environment matrix.""" return self.seat.get_env_protection() @@ -727,3 +731,7 @@ def forward( def has_message_passing(self) -> bool: """Returns whether the descriptor block has message passing.""" return False + + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the descriptor block needs sorted nlist when using `forward_lower`.""" + return False diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 2d732c2800..0d4a53a850 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -111,6 +111,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index 0dd6af7b80..27d62fa882 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -101,6 +101,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 4d7c16eb7d..4016f0eb35 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -112,6 +112,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) model_predict = {} diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index fb16478bc0..e58ba1df62 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -115,6 +115,7 @@ def forward_lower( aparam=aparam, do_atomic_virial=do_atomic_virial, comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/frozen.py b/deepmd/pt/model/model/frozen.py index 79bc450333..395d81c217 100644 --- a/deepmd/pt/model/model/frozen.py +++ b/deepmd/pt/model/model/frozen.py @@ -111,6 +111,10 @@ def has_message_passing(self) -> bool: """Returns whether the descriptor has message passing.""" return self.model.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.model.need_sorted_nlist_for_lower() + @torch.jit.export def forward( self, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 32432725d3..d4ed4b028b 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -215,6 +215,7 @@ def forward_common_lower( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, comm_dict: Optional[Dict[str, torch.Tensor]] = None, + extra_nlist_sort: bool = False, ): """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -239,6 +240,8 @@ def forward_common_lower( whether calculate atomic virial. comm_dict The data needed for communication for parallel inference. + extra_nlist_sort + whether to forcibly sort the nlist. Returns ------- @@ -248,7 +251,9 @@ def forward_common_lower( """ nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.view(nframes, -1, 3) - nlist = self.format_nlist(extended_coord, extended_atype, nlist) + nlist = self.format_nlist( + extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort + ) cc_ext, _, fp, ap, input_prec = self.input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) @@ -349,6 +354,7 @@ def format_nlist( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, + extra_nlist_sort: bool = False, ): """Format the neighbor list. @@ -374,6 +380,8 @@ def format_nlist( atomic type in extended region. nf x nall nlist neighbor list. nf x nloc x nsel + extra_nlist_sort + whether to forcibly sort the nlist. Returns ------- @@ -382,7 +390,12 @@ def format_nlist( """ mixed_types = self.mixed_types() - nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + nlist = self._format_nlist( + extended_coord, + nlist, + sum(self.get_sel()), + extra_nlist_sort=extra_nlist_sort, + ) if not mixed_types: nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel()) return nlist @@ -392,6 +405,7 @@ def _format_nlist( extended_coord: torch.Tensor, nlist: torch.Tensor, nnei: int, + extra_nlist_sort: bool = False, ): n_nf, n_nloc, n_nnei = nlist.shape # nf x nall x 3 @@ -411,7 +425,9 @@ def _format_nlist( ], dim=-1, ) - elif n_nnei > nnei: + + if n_nnei > nnei or extra_nlist_sort: + n_nf, n_nloc, n_nnei = nlist.shape m_real_nei = nlist >= 0 nlist = torch.where(m_real_nei, nlist, 0) # nf x nloc x 3 @@ -428,7 +444,7 @@ def _format_nlist( nlist = torch.gather(nlist, 2, nlist_mapping) nlist = torch.where(rr > rcut, -1, nlist) nlist = nlist[..., :nnei] - else: # n_nnei == nnei: + else: # not extra_nlist_sort and n_nnei <= nnei: pass # great! assert nlist.shape[-1] == nnei return nlist @@ -552,6 +568,10 @@ def has_message_passing(self) -> bool: """Returns whether the model has message passing.""" return self.atomic_model.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.atomic_model.need_sorted_nlist_for_lower() + def forward( self, coord, diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 449fdbe700..7fbb7bdcf4 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -95,6 +95,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) if self.get_fitting_net() is not None: model_predict = {} diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 551c0b86b2..717a7ee7c8 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -342,6 +342,10 @@ def has_message_passing(self) -> bool: """Returns whether the model has message passing.""" return self.backbone_model.has_message_passing() + def need_sorted_nlist_for_lower(self) -> bool: + """Returns whether the model needs sorted nlist when using `forward_lower`.""" + return self.backbone_model.need_sorted_nlist_for_lower() + def model_output_def(self): """Get the output def for the model.""" model_output_type = self.backbone_model.model_output_type() @@ -467,6 +471,7 @@ def forward_common_lower( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, + extra_nlist_sort: bool = False, ): nframes, nloc = nlist.shape[:2] ( @@ -487,6 +492,7 @@ def forward_common_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=extra_nlist_sort, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: @@ -611,6 +617,7 @@ def forward_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extra_nlist_sort=self.backbone_model.need_sorted_nlist_for_lower(), ) model_predict = {} model_predict["atom_energy"] = model_ret["energy"] diff --git a/source/lmp/tests/test_lammps_dpa_sel_pt.py b/source/lmp/tests/test_lammps_dpa_sel_pt.py new file mode 100644 index 0000000000..03e2501efb --- /dev/null +++ b/source/lmp/tests/test_lammps_dpa_sel_pt.py @@ -0,0 +1,724 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import os +import shutil +import subprocess as sp +import sys +import tempfile +from pathlib import ( + Path, +) + +import constants +import numpy as np +import pytest +from lammps import ( + PyLammps, +) +from write_lmp_data import ( + write_lmp_data, +) + +pbtxt_file2 = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot-1.pbtxt" +) +# large repinit sel but small repformer sel +pb_file = ( + Path(__file__).parent.parent.parent / "tests" / "infer" / "deeppot_dpa_sel.pth" +) +pb_file2 = Path(__file__).parent / "graph2.pb" +system_file = Path(__file__).parent.parent.parent / "tests" +data_file = Path(__file__).parent / "data.lmp" +data_file_si = Path(__file__).parent / "data.si" +data_type_map_file = Path(__file__).parent / "data_type_map.lmp" +md_file = Path(__file__).parent / "md.out" + +# this is as the same as python and c++ tests, test_deeppot_a.py +expected_ae = np.array( + [ + -94.40466356082422, + -188.20655580528742, + -188.172650838896, + -94.3984730612324, + -188.18804200217326, + -188.20912570390797, + ] +) +expected_e = np.sum(expected_ae) +expected_f = np.array( + [ + -0.5269430960718773, + 0.09443722477575306, + -0.018996127144558193, + 0.07511784469939177, + -0.004636423045215931, + -0.06042882995560078, + -0.11356148928265902, + -0.14249867913062475, + 0.11471641225723211, + 0.48857267799774884, + 0.029274479383282204, + 0.0018077032375469655, + 0.14145328669603485, + 0.061307914850956956, + -0.08774313950622735, + -0.06463922403863916, + -0.03788451683415152, + 0.050643981111607235, + ] +).reshape(6, 3) + +expected_f2 = np.array( + [ + [-0.6454949, 1.72457783, 0.18897958], + [1.68936514, -0.36995299, -1.36044464], + [-1.09902692, -1.35487928, 1.17416702], + [1.68426111, -0.50835585, 0.98340415], + [0.05771758, 1.12515818, -1.77561531], + [-1.686822, -0.61654789, 0.78950921], + ] +) + +expected_v = -np.array( + [ + 0.9071749098850648, + 0.06394291002323482, + -0.045778841699466444, + 0.05724095081080198, + -0.04587607140012173, + 0.03338900821751993, + -0.08821876554631314, + 0.028921736412500003, + -0.0016941267234178055, + 0.0328481028525373, + 0.011077847594560757, + 0.05737319258976218, + 0.03033379636209457, + -0.007106204143787434, + -0.008933706230224273, + 0.08706716158937683, + -0.007590237086934508, + -0.0465897822325519, + 0.005288635023633567, + -0.04363673459933623, + 0.040896766225094555, + -0.0776988217139129, + -0.04503884467345057, + 0.034987399918229245, + 0.06527106015783832, + 0.036805235933779795, + -0.03289891755994384, + 0.9956154345592723, + 0.11963562102541159, + -0.11601555180804074, + 0.12681453991319047, + 0.01822615751480253, + -0.020439753868777312, + -0.10614750448436672, + -0.018079989970225654, + 0.021624509802219784, + 0.29066664335998216, + -0.017510677635950628, + 0.040767419279345324, + 0.0019631746760569863, + -0.029874379170659604, + 0.047763391313415365, + 0.016942327234135898, + 0.04694015440808539, + -0.07381327572535609, + -0.0318765521505375, + -0.03129174028722404, + 0.04031856154341752, + -0.036436413927534876, + -0.015255351334518573, + 0.020292770612590032, + 0.04264726717944152, + 0.0200622102655479, + -0.026746995544407036, + ] +).reshape(6, 9) +expected_v2 = -np.array( + [ + [ + -0.70008436, + -0.06399891, + 0.63678391, + -0.07642171, + -0.70580035, + 0.20506145, + 0.64098364, + 0.20305781, + -0.57906794, + ], + [ + -0.6372635, + 0.14315552, + 0.51952246, + 0.04604049, + -0.06003681, + -0.02688702, + 0.54489318, + -0.10951559, + -0.43730539, + ], + [ + -0.25090748, + -0.37466262, + 0.34085833, + -0.26690852, + -0.37676917, + 0.29080825, + 0.31600481, + 0.37558276, + -0.33251064, + ], + [ + -0.80195614, + -0.10273138, + 0.06935364, + -0.10429256, + -0.29693811, + 0.45643496, + 0.07247872, + 0.45604679, + -0.71048816, + ], + [ + -0.03840668, + -0.07680205, + 0.10940472, + -0.02374189, + -0.27610266, + 0.4336071, + 0.02465248, + 0.4290638, + -0.67496763, + ], + [ + -0.61475065, + -0.21163135, + 0.26652929, + -0.26134659, + -0.11560267, + 0.15415902, + 0.34343952, + 0.1589482, + -0.21370642, + ], + ] +).reshape(6, 9) + +box = np.array([0, 13, 0, 13, 0, 13, 0, 0, 0]) +coord = np.array( + [ + [12.83, 2.56, 2.18], + [12.09, 2.87, 2.74], + [0.25, 3.32, 1.68], + [3.36, 3.00, 1.81], + [3.51, 2.51, 2.60], + [4.27, 3.22, 1.56], + ] +) +type_OH = np.array([1, 2, 2, 1, 2, 2]) +type_HO = np.array([2, 1, 1, 2, 1, 1]) + + +sp.check_output( + f"{sys.executable} -m deepmd convert-from pbtxt -i {pbtxt_file2.resolve()} -o {pb_file2.resolve()}".split() +) + + +def setup_module(): + write_lmp_data(box, coord, type_OH, data_file) + write_lmp_data(box, coord, type_HO, data_type_map_file) + write_lmp_data( + box * constants.dist_metal2si, + coord * constants.dist_metal2si, + type_OH, + data_file_si, + ) + + +def teardown_module(): + os.remove(data_file) + os.remove(data_type_map_file) + + +def _lammps(data_file, units="metal") -> PyLammps: + lammps = PyLammps() + lammps.units(units) + lammps.boundary("p p p") + lammps.atom_style("atomic") + if units == "metal" or units == "real": + lammps.neighbor("2.0 bin") + elif units == "si": + lammps.neighbor("2.0e-10 bin") + else: + raise ValueError("units should be metal, real, or si") + lammps.neigh_modify("every 10 delay 0 check no") + lammps.read_data(data_file.resolve()) + if units == "metal" or units == "real": + lammps.mass("1 16") + lammps.mass("2 2") + elif units == "si": + lammps.mass("1 %.10e" % (16 * constants.mass_metal2si)) + lammps.mass("2 %.10e" % (2 * constants.mass_metal2si)) + else: + raise ValueError("units should be metal, real, or si") + if units == "metal": + lammps.timestep(0.0005) + elif units == "real": + lammps.timestep(0.5) + elif units == "si": + lammps.timestep(5e-16) + else: + raise ValueError("units should be metal, real, or si") + lammps.fix("1 all nve") + return lammps + + +@pytest.fixture +def lammps(): + lmp = _lammps(data_file=data_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_type_map(): + lmp = _lammps(data_file=data_type_map_file) + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_real(): + lmp = _lammps(data_file=data_file, units="real") + yield lmp + lmp.close() + + +@pytest.fixture +def lammps_si(): + lmp = _lammps(data_file=data_file_si, units="si") + yield lmp + lmp.close() + + +def test_pair_deepmd(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + lammps.run(1) + + +def test_pair_deepmd_virial(lammps): + lammps.pair_style(f"deepmd {pb_file.resolve()}") + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + + +def test_pair_deepmd_model_devi(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_virial(lammps): + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps.pair_coeff("* *") + lammps.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + idx_map = lammps.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps.variables[f"virial{ii}"].value + ) / constants.nktv2p == pytest.approx(expected_v[idx_map, ii]) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_model_devi_atomic_relative_v(lammps): + relative = 1.0 + lammps.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative}" + ) + lammps.pair_coeff("* *") + lammps.run(0) + assert lammps.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps.atoms[ii].force == pytest.approx( + expected_f[lammps.atoms[ii].id - 1] + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) + + +def test_pair_deepmd_type_map(lammps_type_map): + lammps_type_map.pair_style(f"deepmd {pb_file.resolve()}") + lammps_type_map.pair_coeff("* * H O") + lammps_type_map.run(0) + assert lammps_type_map.eval("pe") == pytest.approx(expected_e) + for ii in range(6): + assert lammps_type_map.atoms[ii].force == pytest.approx( + expected_f[lammps_type_map.atoms[ii].id - 1] + ) + lammps_type_map.run(1) + + +def test_pair_deepmd_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + lammps_real.run(1) + + +def test_pair_deepmd_virial_real(lammps_real): + lammps_real.pair_style(f"deepmd {pb_file.resolve()}") + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_virial_real(lammps_real): + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic" + ) + lammps_real.pair_coeff("* *") + lammps_real.compute("virial all centroid/stress/atom NULL pair") + for ii in range(9): + jj = [0, 4, 8, 3, 6, 7, 1, 2, 5][ii] + lammps_real.variable(f"virial{jj} atom c_virial[{ii+1}]") + lammps_real.dump( + "1 all custom 1 dump id " + " ".join([f"v_virial{ii}" for ii in range(9)]) + ) + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + idx_map = lammps_real.lmp.numpy.extract_atom("id") - 1 + for ii in range(9): + assert np.array( + lammps_real.variables[f"virial{ii}"].value + ) / constants.nktv2p_real == pytest.approx( + expected_v[idx_map, ii] * constants.ener_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative {relative * constants.force_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_model_devi_atomic_relative_v_real(lammps_real): + relative = 1.0 + lammps_real.pair_style( + f"deepmd {pb_file.resolve()} {pb_file2.resolve()} out_file {md_file.resolve()} out_freq 1 atomic relative_v {relative * constants.ener_metal2real}" + ) + lammps_real.pair_coeff("* *") + lammps_real.run(0) + assert lammps_real.eval("pe") == pytest.approx( + expected_e * constants.ener_metal2real + ) + for ii in range(6): + assert lammps_real.atoms[ii].force == pytest.approx( + expected_f[lammps_real.atoms[ii].id - 1] * constants.force_metal2real + ) + md = np.loadtxt(md_file.resolve()) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + assert md[7:] == pytest.approx(expected_md_f * constants.force_metal2real) + assert md[4] == pytest.approx(np.max(expected_md_f) * constants.force_metal2real) + assert md[5] == pytest.approx(np.min(expected_md_f) * constants.force_metal2real) + assert md[6] == pytest.approx(np.mean(expected_md_f) * constants.force_metal2real) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + norm = ( + np.abs( + np.mean([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) + ) + / 6 + ) + expected_md_v /= norm + relative + assert md[1] == pytest.approx(np.max(expected_md_v) * constants.ener_metal2real) + assert md[2] == pytest.approx(np.min(expected_md_v) * constants.ener_metal2real) + assert md[3] == pytest.approx( + np.sqrt(np.mean(np.square(expected_md_v))) * constants.ener_metal2real + ) + + +def test_pair_deepmd_si(lammps_si): + lammps_si.pair_style(f"deepmd {pb_file.resolve()}") + lammps_si.pair_coeff("* *") + lammps_si.run(0) + assert lammps_si.eval("pe") == pytest.approx(expected_e * constants.ener_metal2si) + for ii in range(6): + assert lammps_si.atoms[ii].force == pytest.approx( + expected_f[lammps_si.atoms[ii].id - 1] * constants.force_metal2si + ) + lammps_si.run(1) + + +@pytest.mark.skipif( + shutil.which("mpirun") is None, reason="MPI is not installed on this system" +) +@pytest.mark.skipif( + importlib.util.find_spec("mpi4py") is None, reason="mpi4py is not installed" +) +@pytest.mark.parametrize( + ("balance_args",), + [(["--balance"],)], +) +def test_pair_deepmd_mpi(balance_args: list): + with tempfile.NamedTemporaryFile() as f: + sp.check_call( + [ + "mpirun", + "-n", + "2", + sys.executable, + Path(__file__).parent / "run_mpi_pair_deepmd.py", + data_file, + pb_file, + pb_file2, + md_file, + f.name, + *balance_args, + ] + ) + arr = np.loadtxt(f.name, ndmin=1) + pe = arr[0] + + relative = 1.0 + assert pe == pytest.approx(expected_e) + # load model devi + md = np.loadtxt(md_file.resolve()) + norm = np.linalg.norm(np.mean([expected_f, expected_f2], axis=0), axis=1) + expected_md_f = np.linalg.norm(np.std([expected_f, expected_f2], axis=0), axis=1) + expected_md_f /= norm + relative + assert md[7:] == pytest.approx(expected_md_f) + assert md[4] == pytest.approx(np.max(expected_md_f)) + assert md[5] == pytest.approx(np.min(expected_md_f)) + assert md[6] == pytest.approx(np.mean(expected_md_f)) + expected_md_v = ( + np.std([np.sum(expected_v, axis=0), np.sum(expected_v2, axis=0)], axis=0) / 6 + ) + assert md[1] == pytest.approx(np.max(expected_md_v)) + assert md[2] == pytest.approx(np.min(expected_md_v)) + assert md[3] == pytest.approx(np.sqrt(np.mean(np.square(expected_md_v)))) diff --git a/source/tests/infer/deeppot_dpa_sel.pth b/source/tests/infer/deeppot_dpa_sel.pth new file mode 100644 index 0000000000..ff9846c0ce Binary files /dev/null and b/source/tests/infer/deeppot_dpa_sel.pth differ diff --git a/source/tests/universal/common/cases/model/utils.py b/source/tests/universal/common/cases/model/utils.py index 2be55a6337..66b2e64fd3 100644 --- a/source/tests/universal/common/cases/model/utils.py +++ b/source/tests/universal/common/cases/model/utils.py @@ -178,6 +178,11 @@ def test_forward(self): input_dict_lower["extended_spin"] = spin_ext ret_lower.append(module.forward_lower(**input_dict_lower)) + + # use shuffled nlist, simulating the lammps interface + rng.shuffle(input_dict_lower["nlist"], axis=-1) + ret_lower.append(module.forward_lower(**input_dict_lower)) + for kk in ret[0]: subret = [] for rr in ret: