From bea33efe80eff60ca3c74d516dbbd46c1c90791c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 23:03:45 +0800 Subject: [PATCH 1/6] refact(dpmodel): model output made the same as pt backend --- deepmd/dpmodel/infer/deep_eval.py | 4 +- deepmd/dpmodel/model/dipole_model.py | 81 +++++++++++++++++ deepmd/dpmodel/model/dos_model.py | 63 +++++++++++++ deepmd/dpmodel/model/dp_zbl_model.py | 85 +++++++++++++++++ deepmd/dpmodel/model/ener_model.py | 69 ++++++++++++++ deepmd/dpmodel/model/make_model.py | 12 ++- deepmd/dpmodel/model/polar_model.py | 63 +++++++++++++ deepmd/dpmodel/model/property_model.py | 69 ++++++++++++++ deepmd/dpmodel/model/spin_model.py | 6 +- deepmd/jax/infer/deep_eval.py | 4 +- deepmd/jax/jax2tf/serialization.py | 2 +- deepmd/jax/utils/serialization.py | 2 +- deepmd/pt/model/model/dipole_model.py | 16 ++-- deepmd/pt/model/model/dp_zbl_model.py | 2 +- deepmd/pt/model/model/ener_model.py | 2 +- deepmd/pt_expt/model/ener_model.py | 4 +- source/tests/common/dpmodel/test_dp_model.py | 10 +- .../common/dpmodel/test_padding_atoms.py | 8 +- source/tests/consistent/model/test_dipole.py | 17 ++-- source/tests/consistent/model/test_dos.py | 17 ++-- source/tests/consistent/model/test_dpa1.py | 42 +++------ source/tests/consistent/model/test_ener.py | 91 ++++++------------- source/tests/consistent/model/test_frozen.py | 8 +- source/tests/consistent/model/test_polar.py | 17 ++-- .../tests/consistent/model/test_property.py | 7 +- .../tests/consistent/model/test_zbl_ener.py | 20 ++-- source/tests/jax/test_dp_hessian_model.py | 24 ++--- source/tests/jax/test_make_hessian_model.py | 6 +- source/tests/jax/test_padding_atoms.py | 8 +- source/tests/pd/model/test_dp_model.py | 6 +- source/tests/pt/model/test_dp_model.py | 6 +- source/tests/pt_expt/model/test_ener_model.py | 4 +- 32 files changed, 572 insertions(+), 203 deletions(-) diff --git a/deepmd/dpmodel/infer/deep_eval.py b/deepmd/dpmodel/infer/deep_eval.py index 8088ba1d2f..c80898ec74 100644 --- a/deepmd/dpmodel/infer/deep_eval.py +++ b/deepmd/dpmodel/infer/deep_eval.py @@ -358,9 +358,7 @@ def _eval_model( results = [] for odef in request_defs: - # it seems not doing conversion - # dp_name = self._OUTDEF_DP2BACKEND[odef.name] - dp_name = odef.name + dp_name = self._OUTDEF_DP2BACKEND[odef.name] if dp_name in batch_output: shape = self._get_output_shape(odef, nframes, natoms) if batch_output[dp_name] is not None: diff --git a/deepmd/dpmodel/model/dipole_model.py b/deepmd/dpmodel/model/dipole_model.py index d213514551..e01e309973 100644 --- a/deepmd/dpmodel/model/dipole_model.py +++ b/deepmd/dpmodel/model/dipole_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPDipoleAtomicModel, ) @@ -31,3 +34,81 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPDipoleModel_.__init__(self, *args, **kwargs) + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["dipole"] = model_ret["dipole"] + model_predict["global_dipole"] = model_ret["dipole_redu"] + if self.do_grad_r("dipole") and model_ret["dipole_derv_r"] is not None: + model_predict["force"] = model_ret["dipole_derv_r"] + if self.do_grad_c("dipole") and model_ret["dipole_derv_c_redu"] is not None: + model_predict["virial"] = model_ret["dipole_derv_c_redu"] + if do_atomic_virial and model_ret["dipole_derv_c"] is not None: + model_predict["atom_virial"] = model_ret["dipole_derv_c"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["dipole"] = model_ret["dipole"] + model_predict["global_dipole"] = model_ret["dipole_redu"] + if self.do_grad_r("dipole") and model_ret.get("dipole_derv_r") is not None: + model_predict["extended_force"] = model_ret["dipole_derv_r"] + if self.do_grad_c("dipole") and model_ret.get("dipole_derv_c_redu") is not None: + model_predict["virial"] = model_ret["dipole_derv_c_redu"] + if do_atomic_virial and model_ret.get("dipole_derv_c") is not None: + model_predict["extended_virial"] = model_ret["dipole_derv_c"] + return model_predict + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "dipole": out_def_data["dipole"], + "global_dipole": out_def_data["dipole_redu"], + } + if self.do_grad_r("dipole"): + output_def["force"] = out_def_data["dipole_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("dipole"): + output_def["virial"] = out_def_data["dipole_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["dipole_derv_c"] + output_def["atom_virial"].squeeze(-2) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def diff --git a/deepmd/dpmodel/model/dos_model.py b/deepmd/dpmodel/model/dos_model.py index 5c5d2a5e90..dded7b076a 100644 --- a/deepmd/dpmodel/model/dos_model.py +++ b/deepmd/dpmodel/model/dos_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPDOSAtomicModel, ) @@ -31,3 +34,63 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPDOSModel_.__init__(self, *args, **kwargs) + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_dos"] = model_ret["dos"] + model_predict["dos"] = model_ret["dos_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_dos"] = model_ret["dos"] + model_predict["dos"] = model_ret["dos_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_dos": out_def_data["dos"], + "dos": out_def_data["dos_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def diff --git a/deepmd/dpmodel/model/dp_zbl_model.py b/deepmd/dpmodel/model/dp_zbl_model.py index b5940f4707..81c2476447 100644 --- a/deepmd/dpmodel/model/dp_zbl_model.py +++ b/deepmd/dpmodel/model/dp_zbl_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model.linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, ) @@ -34,6 +37,88 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy") and model_ret["energy_derv_r"] is not None: + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy") and model_ret["energy_derv_c_redu"] is not None: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret["energy_derv_c"] is not None: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None: + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy") and model_ret.get("energy_derv_c_redu") is not None: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret.get("energy_derv_c") is not None: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-2) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def + @classmethod def update_sel( cls, diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py index 27d6db811e..ac90d94fc5 100644 --- a/deepmd/dpmodel/model/ener_model.py +++ b/deepmd/dpmodel/model/ener_model.py @@ -6,6 +6,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPEnergyAtomicModel, ) @@ -48,6 +51,72 @@ def atomic_output_def(self) -> FittingOutputDef: return self.hess_fitting_def return super().atomic_output_def() + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy") and model_ret["energy_derv_r"] is not None: + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy") and model_ret["energy_derv_c_redu"] is not None: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret["energy_derv_c"] is not None: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if self._enable_hessian and model_ret.get("energy_derv_r_derv_r") is not None: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3) + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None: + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy") and model_ret.get("energy_derv_c_redu") is not None: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret.get("energy_derv_c") is not None: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + def translated_output_def(self) -> dict[str, Any]: """Get the translated output definition. diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index e115478df5..786b12a2f0 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -223,7 +223,7 @@ def enable_compression( check_frequency, ) - def call( + def call_common( self, coord: Array, atype: Array, @@ -262,7 +262,7 @@ def call( ) del coord, box, fparam, aparam model_predict = model_call_from_call_lower( - call_lower=self.call_lower, + call_lower=self.call_common_lower, rcut=self.get_rcut(), sel=self.get_sel(), mixed_types=self.mixed_types(), @@ -277,7 +277,7 @@ def call( model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def call_lower( + def call_common_lower( self, extended_coord: Array, extended_atype: Array, @@ -365,9 +365,11 @@ def forward_common_atomic( mask=atomic_ret["mask"] if "mask" in atomic_ret else None, ) + call = call_common + call_lower = call_common_lower forward_lower = call_lower - forward_common = call - forward_common_lower = call_lower + forward_common = call_common + forward_common_lower = call_common_lower def get_out_bias(self) -> Array: """Get the output bias.""" diff --git a/deepmd/dpmodel/model/polar_model.py b/deepmd/dpmodel/model/polar_model.py index b898eababd..057410f280 100644 --- a/deepmd/dpmodel/model/polar_model.py +++ b/deepmd/dpmodel/model/polar_model.py @@ -3,6 +3,9 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPPolarAtomicModel, ) @@ -31,3 +34,63 @@ def __init__( ) -> None: DPModelCommon.__init__(self) DPPolarModel_.__init__(self, *args, **kwargs) + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["polar"] = model_ret["polarizability"] + model_predict["global_polar"] = model_ret["polarizability_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "polar": out_def_data["polarizability"], + "global_polar": out_def_data["polarizability_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def diff --git a/deepmd/dpmodel/model/property_model.py b/deepmd/dpmodel/model/property_model.py index 20c884cd20..ea34609393 100644 --- a/deepmd/dpmodel/model/property_model.py +++ b/deepmd/dpmodel/model/property_model.py @@ -3,12 +3,18 @@ Any, ) +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.atomic_model import ( DPPropertyAtomicModel, ) from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) from .dp_model import ( DPModelCommon, @@ -33,3 +39,66 @@ def __init__( def get_var_name(self) -> str: """Get the name of the property.""" return self.get_fitting_net().var_name + + def call( + self, + coord: Array, + atype: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + var_name = self.get_var_name() + model_predict = {} + model_predict[f"atom_{var_name}"] = model_ret[var_name] + model_predict[var_name] = model_ret[f"{var_name}_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + var_name = self.get_var_name() + model_predict = {} + model_predict[f"atom_{var_name}"] = model_ret[var_name] + model_predict[var_name] = model_ret[f"{var_name}_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict + + def translated_output_def(self) -> dict[str, OutputVariableDef]: + out_def_data = self.model_output_def().get_data() + var_name = self.get_var_name() + output_def = { + f"atom_{var_name}": out_def_data[var_name], + var_name: out_def_data[f"{var_name}_redu"], + } + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + return output_def diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index 521978bdde..079ee3519e 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -377,7 +377,7 @@ def call( coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call( + model_predict = self.backbone_model.call_common( coord_updated, atype_updated, box, @@ -447,7 +447,7 @@ def call_lower( ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call_lower( + model_predict = self.backbone_model.call_common_lower( extended_coord_updated, extended_atype_updated, nlist_updated, @@ -465,5 +465,3 @@ def call_lower( )[0] # for now omit the grad output return model_predict - - forward_lower = call_lower diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 4008d75a53..30ff28680c 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -388,9 +388,7 @@ def _eval_model( results = [] for odef in request_defs: - # it seems not doing conversion - # dp_name = self._OUTDEF_DP2BACKEND[odef.name] - dp_name = odef.name + dp_name = self._OUTDEF_DP2BACKEND[odef.name] if dp_name in batch_output: shape = self._get_output_shape(odef, nframes, natoms) if batch_output[dp_name] is not None: diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 31d0d7eb82..4881ca98f8 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -34,7 +34,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: if model_file.endswith(".savedmodel"): model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"] - call_lower = model.call_lower + call_lower = model.call_common_lower tf_model = tf.Module() diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 5d3432aab8..14386d9f3d 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -49,7 +49,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None: elif model_file.endswith(".hlo"): model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"] - call_lower = model.call_lower + call_lower = model.call_common_lower nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index 5cfebb4b03..a5c86b3337 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -74,13 +74,11 @@ def forward( model_predict["dipole"] = model_ret["dipole"] model_predict["global_dipole"] = model_ret["dipole_redu"] if self.do_grad_r("dipole"): - model_predict["force"] = model_ret["dipole_derv_r"].squeeze(-2) + model_predict["force"] = model_ret["dipole_derv_r"] if self.do_grad_c("dipole"): - model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2) + model_predict["virial"] = model_ret["dipole_derv_c_redu"] if do_atomic_virial: - model_predict["atom_virial"] = model_ret["dipole_derv_c"].squeeze( - -3 - ) + model_predict["atom_virial"] = model_ret["dipole_derv_c"] if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] else: @@ -116,13 +114,11 @@ def forward_lower( model_predict["dipole"] = model_ret["dipole"] model_predict["global_dipole"] = model_ret["dipole_redu"] if self.do_grad_r("dipole"): - model_predict["extended_force"] = model_ret["dipole_derv_r"].squeeze(-2) + model_predict["extended_force"] = model_ret["dipole_derv_r"] if self.do_grad_c("dipole"): - model_predict["virial"] = model_ret["dipole_derv_c_redu"].squeeze(-2) + model_predict["virial"] = model_ret["dipole_derv_c_redu"] if do_atomic_virial: - model_predict["extended_virial"] = model_ret[ - "dipole_derv_c" - ].squeeze(-2) + model_predict["extended_virial"] = model_ret["dipole_derv_c"] else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index 07f0732687..f44cb926d0 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -120,7 +120,7 @@ def forward_lower( model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( - -3 + -2 ) else: assert model_ret["dforce"] is not None diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 36beb33ff6..4b680013e6 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -124,7 +124,7 @@ def forward( if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] if self._hessian_enabled: - model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3) else: model_predict = model_ret model_predict["updated_coord"] += coord diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 5547543d27..5f30f3a227 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -42,7 +42,7 @@ def forward( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: - model_ret = self.call( + model_ret = self.call_common( coord, atype, box, @@ -73,7 +73,7 @@ def _forward_lower( aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: - model_ret = self.call_lower( + model_ret = self.call_common_lower( extended_coord, extended_atype, nlist, diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py index af4eea624d..29b73a75ec 100644 --- a/source/tests/common/dpmodel/test_dp_model.py +++ b/source/tests/common/dpmodel/test_dp_model.py @@ -49,8 +49,8 @@ def test_self_consistency( ret0 = md0.call_lower(self.coord_ext, self.atype_ext, self.nlist) ret1 = md1.call_lower(self.coord_ext, self.atype_ext, self.nlist) + np.testing.assert_allclose(ret0["atom_energy"], ret1["atom_energy"]) np.testing.assert_allclose(ret0["energy"], ret1["energy"]) - np.testing.assert_allclose(ret0["energy_redu"], ret1["energy_redu"]) def test_prec_consistency(self) -> None: rng = np.random.default_rng(GLOBAL_SEED) @@ -83,10 +83,12 @@ def test_prec_consistency(self) -> None: model_l_ret_64 = md1.call_lower(*args64, fparam=fparam, aparam=aparam) model_l_ret_32 = md1.call_lower(*args32, fparam=fparam, aparam=aparam) + # After translation, reduced keys are "energy" and "virial" + _REDUCED_KEYS = {"energy", "virial"} for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: continue - if ii[-4:] == "redu": + if ii in _REDUCED_KEYS: self.assertEqual(model_l_ret_32[ii].dtype, np.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, np.float32) @@ -137,10 +139,12 @@ def test_prec_consistency(self) -> None: model_l_ret_64 = md1.call(*args64, fparam=fparam, aparam=aparam) model_l_ret_32 = md1.call(*args32, fparam=fparam, aparam=aparam) + # After translation, reduced keys are "energy" and "virial" + _REDUCED_KEYS = {"energy", "virial"} for ii in model_l_ret_32.keys(): if model_l_ret_32[ii] is None: continue - if ii[-4:] == "redu": + if ii in _REDUCED_KEYS: self.assertEqual(model_l_ret_32[ii].dtype, np.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, np.float32) diff --git a/source/tests/common/dpmodel/test_padding_atoms.py b/source/tests/common/dpmodel/test_padding_atoms.py index d4ea39f598..29e34c09a9 100644 --- a/source/tests/common/dpmodel/test_padding_atoms.py +++ b/source/tests/common/dpmodel/test_padding_atoms.py @@ -69,8 +69,8 @@ def test_padding_atoms_consistency(self): result = model.call(*args) # test intensive np.testing.assert_allclose( - result[f"{var_name}_redu"], - np.mean(result[f"{var_name}"], axis=1), + result[var_name], + np.mean(result[f"atom_{var_name}"], axis=1), atol=self.atol, ) # test padding atoms @@ -93,8 +93,8 @@ def test_padding_atoms_consistency(self): args = [coord_padding, atype_padding, self.cell] result_padding = model.call(*args) np.testing.assert_allclose( - result[f"{var_name}_redu"], - result_padding[f"{var_name}_redu"], + result[var_name], + result_padding[var_name], atol=self.atol, ) diff --git a/source/tests/consistent/model/test_dipole.py b/source/tests/consistent/model/test_dipole.py index 339dcae7c3..bcd199a633 100644 --- a/source/tests/consistent/model/test_dipole.py +++ b/source/tests/consistent/model/test_dipole.py @@ -188,21 +188,20 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + if backend is self.RefBackend.TF: return ( - ret["dipole_redu"].ravel(), - ret["dipole"].ravel(), + ret[0].ravel(), + ret[1].ravel(), ) - elif backend is self.RefBackend.PT: + elif backend in { + self.RefBackend.DP, + self.RefBackend.PT, + self.RefBackend.JAX, + }: return ( ret["global_dipole"].ravel(), ret["dipole"].ravel(), ) - elif backend is self.RefBackend.TF: - return ( - ret[0].ravel(), - ret[1].ravel(), - ) raise ValueError(f"Unknown backend: {backend}") def test_atom_exclude_types(self): diff --git a/source/tests/consistent/model/test_dos.py b/source/tests/consistent/model/test_dos.py index ef72e9096b..f967973913 100644 --- a/source/tests/consistent/model/test_dos.py +++ b/source/tests/consistent/model/test_dos.py @@ -182,19 +182,18 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + if backend is self.RefBackend.TF: return ( - ret["dos_redu"].ravel(), - ret["dos"].ravel(), + ret[0].ravel(), + ret[1].ravel(), ) - elif backend is self.RefBackend.PT: + elif backend in { + self.RefBackend.DP, + self.RefBackend.PT, + self.RefBackend.JAX, + }: return ( ret["dos"].ravel(), ret["atom_dos"].ravel(), ) - elif backend is self.RefBackend.TF: - return ( - ret[0].ravel(), - ret[1].ravel(), - ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py index 8b8fab7ae1..bacca12413 100644 --- a/source/tests/consistent/model/test_dpa1.py +++ b/source/tests/consistent/model/test_dpa1.py @@ -221,15 +221,27 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend is self.RefBackend.DP: + if backend is self.RefBackend.TF: + return ( + ret[0].ravel(), + ret[1].ravel(), + ret[2].ravel(), + ret[3].ravel(), + ret[4].ravel(), + ) + elif backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, SKIP_FLAG, ) - elif backend is self.RefBackend.PT: + elif backend in { + self.RefBackend.PT, + self.RefBackend.PD, + self.RefBackend.JAX, + }: return ( ret["energy"].ravel(), ret["atom_energy"].ravel(), @@ -237,28 +249,4 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["atom_virial"].ravel(), ) - elif backend is self.RefBackend.TF: - return ( - ret[0].ravel(), - ret[1].ravel(), - ret[2].ravel(), - ret[3].ravel(), - ret[4].ravel(), - ) - elif backend is self.RefBackend.PD: - return ( - ret["energy"].flatten(), - ret["atom_energy"].flatten(), - ret["force"].flatten(), - ret["virial"].flatten(), - ret["atom_virial"].flatten(), - ) - elif backend is self.RefBackend.JAX: - return ( - ret["energy_redu"].ravel(), - ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), - ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 6cceb1c640..5364cd1a6f 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -269,23 +269,28 @@ def eval_pd(self, pd_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend is self.RefBackend.DP: + if backend is self.RefBackend.TF: return ( - ret["energy_redu"].ravel(), - ret["energy"].ravel(), - SKIP_FLAG, - SKIP_FLAG, - SKIP_FLAG, + ret[0].ravel(), + ret[1].ravel(), + ret[2].ravel(), + ret[3].ravel(), + ret[4].ravel(), ) - elif backend is self.RefBackend.PT: + elif backend is self.RefBackend.DP: return ( ret["energy"].ravel(), ret["atom_energy"].ravel(), - ret["force"].ravel(), - ret["virial"].ravel(), - ret["atom_virial"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + SKIP_FLAG, ) - elif backend is self.RefBackend.PT_EXPT: + elif backend in { + self.RefBackend.PT, + self.RefBackend.PT_EXPT, + self.RefBackend.JAX, + self.RefBackend.PD, + }: return ( ret["energy"].ravel(), ret["atom_energy"].ravel(), @@ -293,30 +298,6 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["atom_virial"].ravel(), ) - elif backend is self.RefBackend.TF: - return ( - ret[0].ravel(), - ret[1].ravel(), - ret[2].ravel(), - ret[3].ravel(), - ret[4].ravel(), - ) - elif backend is self.RefBackend.JAX: - return ( - ret["energy_redu"].ravel(), - ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), - ) - elif backend is self.RefBackend.PD: - return ( - ret["energy"].flatten(), - ret["atom_energy"].flatten(), - ret["force"].flatten(), - ret["virial"].flatten(), - ret["atom_virial"].flatten(), - ) raise ValueError(f"Unknown backend: {backend}") @@ -536,20 +517,12 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... if backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, SKIP_FLAG, ) - elif backend is self.RefBackend.PT: - return ( - ret["energy"].ravel(), - ret["atom_energy"].ravel(), - ret["extended_force"].ravel(), - ret["virial"].ravel(), - ret["extended_virial"].ravel(), - ) elif backend is self.RefBackend.PT_EXPT: return ( ret["energy_redu"].ravel(), @@ -558,21 +531,17 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["energy_derv_c_redu"].ravel(), ret["energy_derv_c"].ravel(), ) - elif backend is self.RefBackend.JAX: + elif backend in { + self.RefBackend.PT, + self.RefBackend.JAX, + self.RefBackend.PD, + }: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ret["energy_derv_c"].ravel(), - ) - elif backend is self.RefBackend.PD: - return ( - ret["energy"].flatten(), - ret["atom_energy"].flatten(), - ret["extended_force"].flatten(), - ret["virial"].flatten(), - ret["extended_virial"].flatten(), + ret["atom_energy"].ravel(), + ret["extended_force"].ravel(), + ret["virial"].ravel(), + ret["extended_virial"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") @@ -726,8 +695,8 @@ def test_set_out_bias(self) -> None: ) def test_forward_common_alias(self) -> None: - """forward_common should be the same as call on dpmodel.""" - ret_call = self.dp_model.call( + """forward_common should be the same as call_common on dpmodel.""" + ret_call = self.dp_model.call_common( self.coords, self.atype, box=self.box, @@ -741,8 +710,8 @@ def test_forward_common_alias(self) -> None: np.testing.assert_equal(ret_call[key], ret_fc[key]) def test_forward_common_lower_alias(self) -> None: - """forward_common_lower should be the same as call_lower on dpmodel.""" - ret_call = self.dp_model.call_lower( + """forward_common_lower should be the same as call_common_lower on dpmodel.""" + ret_call = self.dp_model.call_common_lower( self.extended_coord, self.extended_atype, self.nlist, diff --git a/source/tests/consistent/model/test_frozen.py b/source/tests/consistent/model/test_frozen.py index ff7d651e7e..d7dfcfe735 100644 --- a/source/tests/consistent/model/test_frozen.py +++ b/source/tests/consistent/model/test_frozen.py @@ -150,10 +150,8 @@ def eval_pt(self, pt_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend is self.RefBackend.DP: - return (ret["energy_redu"].ravel(), ret["energy"].ravel()) - elif backend is self.RefBackend.PT: - return (ret["energy"].ravel(), ret["atom_energy"].ravel()) - elif backend is self.RefBackend.TF: + if backend is self.RefBackend.TF: return (ret[0].ravel(), ret[1].ravel()) + elif backend in {self.RefBackend.PT}: + return (ret["energy"].ravel(), ret["atom_energy"].ravel()) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/consistent/model/test_polar.py b/source/tests/consistent/model/test_polar.py index 1405814f03..0b9b94a599 100644 --- a/source/tests/consistent/model/test_polar.py +++ b/source/tests/consistent/model/test_polar.py @@ -182,21 +182,20 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend in {self.RefBackend.DP, self.RefBackend.JAX}: + if backend is self.RefBackend.TF: return ( - ret["polarizability_redu"].ravel(), - ret["polarizability"].ravel(), + ret[0].ravel(), + ret[1].ravel(), ) - elif backend is self.RefBackend.PT: + elif backend in { + self.RefBackend.DP, + self.RefBackend.PT, + self.RefBackend.JAX, + }: return ( ret["global_polar"].ravel(), ret["polar"].ravel(), ) - elif backend is self.RefBackend.TF: - return ( - ret[0].ravel(), - ret[1].ravel(), - ) raise ValueError(f"Unknown backend: {backend}") def test_atom_exclude_types(self): diff --git a/source/tests/consistent/model/test_property.py b/source/tests/consistent/model/test_property.py index 75aded98fd..33e63af98e 100644 --- a/source/tests/consistent/model/test_property.py +++ b/source/tests/consistent/model/test_property.py @@ -184,12 +184,7 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... property_name = self.data["fitting_net"]["property_name"] - if backend in {self.RefBackend.DP, self.RefBackend.JAX}: - return ( - ret[f"{property_name}_redu"].ravel(), - ret[property_name].ravel(), - ) - elif backend is self.RefBackend.PT: + if backend in {self.RefBackend.DP, self.RefBackend.PT, self.RefBackend.JAX}: return ( ret[property_name].ravel(), ret[f"atom_{property_name}"].ravel(), diff --git a/source/tests/consistent/model/test_zbl_ener.py b/source/tests/consistent/model/test_zbl_ener.py index 6fb44a59ed..2783bb4a02 100644 --- a/source/tests/consistent/model/test_zbl_ener.py +++ b/source/tests/consistent/model/test_zbl_ener.py @@ -207,27 +207,23 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... - if backend is self.RefBackend.DP: + if backend is self.RefBackend.TF: + return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) + elif backend is self.RefBackend.DP: return ( - ret["energy_redu"].ravel(), ret["energy"].ravel(), + ret["atom_energy"].ravel(), SKIP_FLAG, SKIP_FLAG, ) - elif backend is self.RefBackend.PT: + elif backend in { + self.RefBackend.PT, + self.RefBackend.JAX, + }: return ( ret["energy"].ravel(), ret["atom_energy"].ravel(), ret["force"].ravel(), ret["virial"].ravel(), ) - elif backend is self.RefBackend.TF: - return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel()) - elif backend is self.RefBackend.JAX: - return ( - ret["energy_redu"].ravel(), - ret["energy"].ravel(), - ret["energy_derv_r"].ravel(), - ret["energy_derv_c_redu"].ravel(), - ) raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/jax/test_dp_hessian_model.py b/source/tests/jax/test_dp_hessian_model.py index 89c066e980..00d9a7adee 100644 --- a/source/tests/jax/test_dp_hessian_model.py +++ b/source/tests/jax/test_dp_hessian_model.py @@ -82,34 +82,34 @@ def test_self_consistency(self): ret0 = md0.call(*args) ret1 = md1.call(*args) np.testing.assert_allclose( - to_numpy_array(ret0["energy"]), - to_numpy_array(ret1["energy"]), + to_numpy_array(ret0["atom_energy"]), + to_numpy_array(ret1["atom_energy"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_redu"]), - to_numpy_array(ret1["energy_redu"]), + to_numpy_array(ret0["energy"]), + to_numpy_array(ret1["energy"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_r"]), - to_numpy_array(ret1["energy_derv_r"]), + to_numpy_array(ret0["force"]), + to_numpy_array(ret1["force"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_c_redu"]), - to_numpy_array(ret1["energy_derv_c_redu"]), + to_numpy_array(ret0["virial"]), + to_numpy_array(ret1["virial"]), atol=self.atol, ) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_r_derv_r"]), - to_numpy_array(ret1["energy_derv_r_derv_r"]), + to_numpy_array(ret0["hessian"]), + to_numpy_array(ret1["hessian"]), atol=self.atol, ) ret0 = md0.call(*args, do_atomic_virial=True) ret1 = md1.call(*args, do_atomic_virial=True) np.testing.assert_allclose( - to_numpy_array(ret0["energy_derv_c"]), - to_numpy_array(ret1["energy_derv_c"]), + to_numpy_array(ret0["atom_virial"]), + to_numpy_array(ret1["atom_virial"]), atol=self.atol, ) diff --git a/source/tests/jax/test_make_hessian_model.py b/source/tests/jax/test_make_hessian_model.py index 8666ff4ad4..bb25bd67ca 100644 --- a/source/tests/jax/test_make_hessian_model.py +++ b/source/tests/jax/test_make_hessian_model.py @@ -100,7 +100,7 @@ def test( ) # compare hess and value models np.testing.assert_allclose(ret_dict0["energy"], ret_dict1["energy"]) - ana_hess = ret_dict0["energy_derv_r_derv_r"] + ana_hess = ret_dict0["hessian"] # compute finite difference fnt_hess = [] @@ -121,13 +121,13 @@ def np_infer( return ret def ff(xx): - return np_infer(xx)["energy_redu"] + return np_infer(xx)["energy"] xx = to_numpy_array(coord[ii]) fnt_hess.append(finite_hessian(ff, xx, delta=delta).squeeze()) # compare finite difference with autodiff - fnt_hess = np.stack(fnt_hess).reshape([nf, nv, natoms * 3, natoms * 3]) + fnt_hess = np.stack(fnt_hess).reshape([nf, natoms * 3, natoms * 3]) np.testing.assert_almost_equal( fnt_hess, to_numpy_array(ana_hess), decimal=places ) diff --git a/source/tests/jax/test_padding_atoms.py b/source/tests/jax/test_padding_atoms.py index b63b464721..0f1b569821 100644 --- a/source/tests/jax/test_padding_atoms.py +++ b/source/tests/jax/test_padding_atoms.py @@ -89,8 +89,8 @@ def test_padding_atoms_consistency(self): result = model.call(*args) # test intensive np.testing.assert_allclose( - to_numpy_array(result[f"{var_name}_redu"]), - np.mean(to_numpy_array(result[f"{var_name}"]), axis=1), + to_numpy_array(result[var_name]), + np.mean(to_numpy_array(result[f"atom_{var_name}"]), axis=1), atol=self.atol, ) # test padding atoms @@ -115,8 +115,8 @@ def test_padding_atoms_consistency(self): ] result_padding = model.call(*args) np.testing.assert_allclose( - to_numpy_array(result[f"{var_name}_redu"]), - to_numpy_array(result_padding[f"{var_name}_redu"]), + to_numpy_array(result[var_name]), + to_numpy_array(result_padding[var_name]), atol=self.atol, ) diff --git a/source/tests/pd/model/test_dp_model.py b/source/tests/pd/model/test_dp_model.py index a281851f14..5e30b5ebaa 100644 --- a/source/tests/pd/model/test_dp_model.py +++ b/source/tests/pd/model/test_dp_model.py @@ -140,7 +140,7 @@ def test_dp_consistency(self): args1 = [to_paddle_tensor(ii) for ii in [self.coord, self.atype, self.cell]] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_paddle_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -179,7 +179,7 @@ def test_dp_consistency_nopbc(self): args1 = [to_paddle_tensor(ii) for ii in args0] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_paddle_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -313,7 +313,7 @@ def test_dp_consistency(self): args1 = [ to_paddle_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.call_lower(*args0) + ret0 = md0.call_common_lower(*args0) ret1 = md1.forward_common_lower(*args1) np.testing.assert_allclose( ret0["energy"], diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index 93153ce6d5..f4e350869a 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -140,7 +140,7 @@ def test_dp_consistency(self) -> None: args1 = [to_torch_tensor(ii) for ii in [self.coord, self.atype, self.cell]] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -179,7 +179,7 @@ def test_dp_consistency_nopbc(self) -> None: args1 = [to_torch_tensor(ii) for ii in args0] kwargs0 = {"fparam": fparam, "aparam": aparam} kwargs1 = {kk: to_torch_tensor(vv) for kk, vv in kwargs0.items()} - ret0 = md0.call(*args0, **kwargs0) + ret0 = md0.call_common(*args0, **kwargs0) ret1 = md1.forward_common(*args1, **kwargs1) np.testing.assert_allclose( ret0["energy"], @@ -313,7 +313,7 @@ def test_dp_consistency(self) -> None: args1 = [ to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] ] - ret0 = md0.call_lower(*args0) + ret0 = md0.call_common_lower(*args0) ret1 = md1.forward_common_lower(*args1) np.testing.assert_allclose( ret0["energy"], diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index 6e5006661e..588b0ec2a9 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -364,13 +364,13 @@ def test_dp_consistency(self) -> None: ret_pt = md_pt(coord, self.atype, self.cell.reshape(1, 9)) np.testing.assert_allclose( - ret_dp["energy_redu"], + ret_dp["energy"], ret_pt["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, ) np.testing.assert_allclose( - ret_dp["energy"], + ret_dp["atom_energy"], ret_pt["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, From 9709c302ad96c731aa0d02a5ebf069cbf6289516 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 23:34:18 +0800 Subject: [PATCH 2/6] rm forward* from dpmodel backend --- deepmd/dpmodel/model/make_model.py | 3 -- deepmd/pt_expt/model/make_model.py | 10 +++++++ source/tests/consistent/model/test_ener.py | 32 ---------------------- 3 files changed, 10 insertions(+), 35 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 786b12a2f0..9c51be464f 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -367,9 +367,6 @@ def forward_common_atomic( call = call_common call_lower = call_common_lower - forward_lower = call_lower - forward_common = call_common - forward_common_lower = call_common_lower def get_out_bias(self) -> Array: """Get the output bias.""" diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index d26733696d..2785fade14 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -46,6 +46,16 @@ def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: """ return self.call(*args, **kwargs) + def forward_common(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: + """Forward common delegates to call_common().""" + return self.call_common(*args, **kwargs) + + def forward_common_lower( + self, *args: Any, **kwargs: Any + ) -> dict[str, torch.Tensor]: + """Forward common lower delegates to call_common_lower().""" + return self.call_common_lower(*args, **kwargs) + def forward_common_atomic( self, extended_coord: torch.Tensor, diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 5364cd1a6f..0bb1c343bd 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -694,38 +694,6 @@ def test_set_out_bias(self) -> None: atol=1e-10, ) - def test_forward_common_alias(self) -> None: - """forward_common should be the same as call_common on dpmodel.""" - ret_call = self.dp_model.call_common( - self.coords, - self.atype, - box=self.box, - ) - ret_fc = self.dp_model.forward_common( - self.coords, - self.atype, - box=self.box, - ) - for key in ret_call: - np.testing.assert_equal(ret_call[key], ret_fc[key]) - - def test_forward_common_lower_alias(self) -> None: - """forward_common_lower should be the same as call_common_lower on dpmodel.""" - ret_call = self.dp_model.call_common_lower( - self.extended_coord, - self.extended_atype, - self.nlist, - self.mapping, - ) - ret_fc = self.dp_model.forward_common_lower( - self.extended_coord, - self.extended_atype, - self.nlist, - self.mapping, - ) - for key in ret_call: - np.testing.assert_equal(ret_call[key], ret_fc[key]) - def test_model_output_def(self) -> None: """model_output_def should return the same keys and shapes on dp and pt.""" dp_def = self.dp_model.model_output_def().get_data() From c859576931925a308dc8867e2baf417332fd8cf9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 23:34:35 +0800 Subject: [PATCH 3/6] fix: missing mask output --- deepmd/dpmodel/model/dipole_model.py | 2 ++ deepmd/pt/model/model/dipole_model.py | 2 ++ deepmd/pt/model/model/dos_model.py | 3 ++- deepmd/pt/model/model/dp_zbl_model.py | 2 ++ deepmd/pt/model/model/ener_model.py | 2 ++ deepmd/pt/model/model/polar_model.py | 2 ++ 6 files changed, 12 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/model/dipole_model.py b/deepmd/dpmodel/model/dipole_model.py index e01e309973..421dd1b10f 100644 --- a/deepmd/dpmodel/model/dipole_model.py +++ b/deepmd/dpmodel/model/dipole_model.py @@ -93,6 +93,8 @@ def call_lower( model_predict["virial"] = model_ret["dipole_derv_c_redu"] if do_atomic_virial and model_ret.get("dipole_derv_c") is not None: model_predict["extended_virial"] = model_ret["dipole_derv_c"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] return model_predict def translated_output_def(self) -> dict[str, Any]: diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index a5c86b3337..9bd52dd428 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -119,6 +119,8 @@ def forward_lower( model_predict["virial"] = model_ret["dipole_derv_c_redu"] if do_atomic_virial: model_predict["extended_virial"] = model_ret["dipole_derv_c"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/dos_model.py b/deepmd/pt/model/model/dos_model.py index 75f89c141a..d28487ed9c 100644 --- a/deepmd/pt/model/model/dos_model.py +++ b/deepmd/pt/model/model/dos_model.py @@ -105,7 +105,8 @@ def forward_lower( model_predict = {} model_predict["atom_dos"] = model_ret["dos"] model_predict["dos"] = model_ret["dos_redu"] - + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index f44cb926d0..ea2cd17f38 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -125,6 +125,8 @@ def forward_lower( else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] return model_predict @classmethod diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 4b680013e6..7dcd035412 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -168,6 +168,8 @@ def forward_lower( else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/polar_model.py b/deepmd/pt/model/model/polar_model.py index 7210117823..7c9550dc3a 100644 --- a/deepmd/pt/model/model/polar_model.py +++ b/deepmd/pt/model/model/polar_model.py @@ -102,6 +102,8 @@ def forward_lower( model_predict = {} model_predict["polar"] = model_ret["polarizability"] model_predict["global_polar"] = model_ret["polarizability_redu"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] else: model_predict = model_ret return model_predict From 8c356b4ee8cf820209ea3178b16ac5aedc016d7e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 23:47:38 +0800 Subject: [PATCH 4/6] fix bugs --- deepmd/jax/infer/deep_eval.py | 4 +++- source/tests/universal/dpmodel/backend.py | 2 ++ source/tests/universal/dpmodel/model/test_model.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index 30ff28680c..2e028225f7 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -388,7 +388,9 @@ def _eval_model( results = [] for odef in request_defs: - dp_name = self._OUTDEF_DP2BACKEND[odef.name] + # HLO and TFModelWrapper return raw internal keys (not translated), + # so no key mapping is needed here. + dp_name = odef.name if dp_name in batch_output: shape = self._get_output_shape(odef, nframes, natoms) if batch_output[dp_name] is not None: diff --git a/source/tests/universal/dpmodel/backend.py b/source/tests/universal/dpmodel/backend.py index 2f15efe1e6..916b4c4d7b 100644 --- a/source/tests/universal/dpmodel/backend.py +++ b/source/tests/universal/dpmodel/backend.py @@ -21,6 +21,8 @@ class DPTestCase(BackendTestCase): """DP module to test.""" def forward_wrapper(self, x): + if not hasattr(x, "forward_lower") and hasattr(x, "call_lower"): + x.forward_lower = x.call_lower return x def forward_wrapper_cpu_ref(self, x): diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index 815c612bb0..edee8454a1 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -164,7 +164,7 @@ def setUpClass(cls) -> None: ft, type_map=cls.expected_type_map, ) - cls.output_def = cls.module.model_output_def().get_data() + cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() From 20b9a40afaf67d8d498e16231ade06c35f6a2ed4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 20 Feb 2026 15:50:03 +0800 Subject: [PATCH 5/6] make spin model output consistent between dpmodel and pt. add consistency test for spin model --- deepmd/dpmodel/model/spin_model.py | 255 +++++++++++- .../tests/consistent/model/test_spin_ener.py | 389 ++++++++++++++++++ source/tests/pt/model/test_ener_spin_model.py | 4 +- .../universal/dpmodel/model/test_model.py | 4 +- 4 files changed, 630 insertions(+), 22 deletions(-) create mode 100644 source/tests/consistent/model/test_spin_ener.py diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index 079ee3519e..85e23df3cc 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) from typing import ( Any, ) @@ -311,10 +314,9 @@ def model_output_def(self) -> ModelOutputDef: def __getattr__(self, name: str) -> Any: """Get attribute from the wrapped model.""" - if name in self.__dict__: - return self.__dict__[name] - else: - return getattr(self.backbone_model, name) + if "backbone_model" not in self.__dict__: + raise AttributeError(name) + return getattr(self.backbone_model, name) def serialize(self) -> dict: return { @@ -333,7 +335,7 @@ def deserialize(cls, data: dict) -> "SpinModel": spin=spin, ) - def call( + def call_common( self, coord: Array, atype: Array, @@ -343,7 +345,7 @@ def call( aparam: Array | None = None, do_atomic_virial: bool = False, ) -> dict[str, Array]: - """Return model prediction. + """Return model prediction with raw internal keys. Parameters ---------- @@ -377,7 +379,7 @@ def call( coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call_common( + model_ret = self.backbone_model.call_common( coord_updated, atype_updated, box, @@ -389,13 +391,104 @@ def call( if "mask" in model_output_type: model_output_type.pop(model_output_type.index("mask")) var_name = model_output_type[0] - model_predict[f"{var_name}"] = np.split( - model_predict[f"{var_name}"], [nloc], axis=1 - )[0] - # for now omit the grad output + model_ret[f"{var_name}"] = np.split(model_ret[f"{var_name}"], [nloc], axis=1)[0] + if ( + self.backbone_model.do_grad_r(var_name) + and model_ret.get(f"{var_name}_derv_r") is not None + ): + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output(atype, model_ret[f"{var_name}_derv_r"]) + if ( + self.backbone_model.do_grad_c(var_name) + and do_atomic_virial + and model_ret.get(f"{var_name}_derv_c") is not None + ): + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output( + atype, + model_ret[f"{var_name}_derv_c"], + add_mag=False, + virtual_scale=False, + ) + # Always compute mask_mag from atom types (even when forces are unavailable) + if "mask_mag" not in model_ret: + nframes_m, nloc_m = atype.shape[:2] + atomic_mask = self.virtual_scale_mask[atype].reshape([nframes_m, nloc_m, 1]) + model_ret["mask_mag"] = atomic_mask > 0.0 + return model_ret + + def call( + self, + coord: Array, + atype: Array, + spin: Array, + box: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + """Return model prediction with translated user-facing keys. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + spin + The spins of the atoms. + shape: nf x (nloc x 3) + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict with translated keys, e.g. + ``atom_energy``, ``energy``, ``force``, ``force_mag``. + + """ + model_ret = self.call_common( + coord, + atype, + spin, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_predict = {} + model_predict[f"atom_{var_name}"] = model_ret[var_name] + model_predict[var_name] = model_ret[f"{var_name}_redu"] + if "mask_mag" in model_ret: + model_predict["mask_mag"] = model_ret["mask_mag"] + if ( + self.backbone_model.do_grad_r(var_name) + and model_ret.get(f"{var_name}_derv_r") is not None + ): + model_predict["force"] = model_ret[f"{var_name}_derv_r"].squeeze(-2) + model_predict["force_mag"] = model_ret[f"{var_name}_derv_r_mag"].squeeze(-2) + # not support virial by far return model_predict - def call_lower( + def call_common_lower( self, extended_coord: Array, extended_atype: Array, @@ -406,7 +499,7 @@ def call_lower( aparam: Array | None = None, do_atomic_virial: bool = False, ) -> dict[str, Array]: - """Return model prediction. Lower interface that takes + """Return model prediction with raw internal keys. Lower interface that takes extended atomic coordinates, types and spins, nlist, and mapping as input, and returns the predictions on the extended region. The predictions are not reduced. @@ -447,7 +540,7 @@ def call_lower( ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) - model_predict = self.backbone_model.call_common_lower( + model_ret = self.backbone_model.call_common_lower( extended_coord_updated, extended_atype_updated, nlist_updated, @@ -460,8 +553,134 @@ def call_lower( if "mask" in model_output_type: model_output_type.pop(model_output_type.index("mask")) var_name = model_output_type[0] - model_predict[f"{var_name}"] = np.split( - model_predict[f"{var_name}"], [nloc], axis=1 - )[0] - # for now omit the grad output + model_ret[f"{var_name}"] = np.split(model_ret[f"{var_name}"], [nloc], axis=1)[0] + if ( + self.backbone_model.do_grad_r(var_name) + and model_ret.get(f"{var_name}_derv_r") is not None + ): + ( + model_ret[f"{var_name}_derv_r"], + model_ret[f"{var_name}_derv_r_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, model_ret[f"{var_name}_derv_r"], nloc + ) + if ( + self.backbone_model.do_grad_c(var_name) + and do_atomic_virial + and model_ret.get(f"{var_name}_derv_c") is not None + ): + ( + model_ret[f"{var_name}_derv_c"], + model_ret[f"{var_name}_derv_c_mag"], + model_ret["mask_mag"], + ) = self.process_spin_output_lower( + extended_atype, + model_ret[f"{var_name}_derv_c"], + nloc, + add_mag=False, + virtual_scale=False, + ) + # Always compute mask_mag from atom types (even when forces are unavailable) + if "mask_mag" not in model_ret: + nall = extended_atype.shape[1] + atomic_mask = self.virtual_scale_mask[extended_atype].reshape( + [nframes, nall, 1] + ) + model_ret["mask_mag"] = atomic_mask > 0.0 + return model_ret + + def call_lower( + self, + extended_coord: Array, + extended_atype: Array, + extended_spin: Array, + nlist: Array, + mapping: Array | None = None, + fparam: Array | None = None, + aparam: Array | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, Array]: + """Return model prediction with translated user-facing keys. Lower interface. + + Parameters + ---------- + extended_coord + coordinates in extended region. nf x (nall x 3). + extended_atype + atomic type in extended region. nf x nall. + extended_spin + spins in extended region. nf x (nall x 3). + nlist + neighbor list. nf x nloc x nsel. + mapping + maps the extended indices to local indices. nf x nall. + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + whether calculate atomic virial + + Returns + ------- + result_dict + The result dict with translated keys, e.g. + ``atom_energy``, ``energy``, ``extended_force``, ``extended_force_mag``. + + """ + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + model_predict = {} + model_predict[f"atom_{var_name}"] = model_ret[var_name] + model_predict[var_name] = model_ret[f"{var_name}_redu"] + if "mask_mag" in model_ret: + model_predict["extended_mask_mag"] = model_ret["mask_mag"] + if ( + self.backbone_model.do_grad_r(var_name) + and model_ret.get(f"{var_name}_derv_r") is not None + ): + model_predict["extended_force"] = model_ret[f"{var_name}_derv_r"].squeeze( + -2 + ) + model_predict["extended_force_mag"] = model_ret[ + f"{var_name}_derv_r_mag" + ].squeeze(-2) + # not support virial by far return model_predict + + def translated_output_def(self) -> dict[str, Any]: + """Get the translated output definition. + + Maps internal output names to user-facing names, e.g. + ``energy`` -> ``atom_energy``, ``energy_redu`` -> ``energy``, + ``energy_derv_r`` -> ``force``, ``energy_derv_r_mag`` -> ``force_mag``. + """ + out_def_data = self.model_output_def().get_data() + model_output_type = self.backbone_model.model_output_type() + if "mask" in model_output_type: + model_output_type.pop(model_output_type.index("mask")) + var_name = model_output_type[0] + output_def = { + f"atom_{var_name}": out_def_data[var_name], + var_name: out_def_data[f"{var_name}_redu"], + "mask_mag": out_def_data["mask_mag"], + } + if self.backbone_model.do_grad_r(var_name): + output_def["force"] = deepcopy(out_def_data[f"{var_name}_derv_r"]) + output_def["force"].squeeze(-2) + output_def["force_mag"] = deepcopy(out_def_data[f"{var_name}_derv_r_mag"]) + output_def["force_mag"].squeeze(-2) + return output_def diff --git a/source/tests/consistent/model/test_spin_ener.py b/source/tests/consistent/model/test_spin_ener.py new file mode 100644 index 0000000000..61c26b350b --- /dev/null +++ b/source/tests/consistent/model/test_spin_ener.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.dpmodel.model.spin_model import SpinModel as SpinModelDP +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_PT, + CommonTest, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.spin_model import SpinEnergyModel as SpinEnergyModelPT + from deepmd.pt.utils.utils import to_numpy_array as torch_to_numpy + from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch +else: + SpinEnergyModelPT = None + +from deepmd.utils.argcheck import ( + model_args, +) + +SPIN_DATA = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140], + }, +} + + +class TestSpinEner(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return SPIN_DATA + + tf_class = None + dp_class = SpinModelDP + pt_class = SpinEnergyModelPT + pd_class = None + pt_expt_class = None + jax_class = None + args = model_args() + + skip_tf = True + skip_jax = True + skip_pt_expt = True + skip_pd = True + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = copy.deepcopy(data) + if cls is SpinModelDP: + return get_model_dp(data) + elif cls is SpinEnergyModelPT: + return get_model_pt(data) + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 3 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 0, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 3, 3, 0], dtype=np.int32) + self.spin = np.array( + [ + 0.50, + 0.30, + 0.20, + 0.40, + 0.25, + 0.15, + 0.10, + 0.05, + 0.08, + 0.12, + 0.07, + 0.09, + 0.45, + 0.35, + 0.28, + 0.11, + 0.06, + 0.03, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + raise NotImplementedError("no TF in this test") + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj( + self.coords, + self.atype, + self.spin, + box=self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return { + kk: torch_to_numpy(vv) + for kk, vv in pt_obj( + numpy_to_torch(self.coords), + numpy_to_torch(self.atype), + numpy_to_torch(self.spin), + box=numpy_to_torch(self.box), + ).items() + } + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + from ..common import ( + SKIP_FLAG, + ) + + if backend is self.RefBackend.DP: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["mask_mag"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["mask_mag"].ravel(), + ret["force"].ravel(), + ret["force_mag"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") + + +class TestSpinEnerLower(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + return SPIN_DATA + + tf_class = None + dp_class = SpinModelDP + pt_class = SpinEnergyModelPT + pd_class = None + pt_expt_class = None + jax_class = None + args = model_args() + + skip_tf = True + skip_jax = True + skip_pt_expt = True + skip_pd = True + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. + """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = copy.deepcopy(data) + if cls is SpinModelDP: + return get_model_dp(data) + elif cls is SpinEnergyModelPT: + return get_model_pt(data) + return cls(**data, **self.additional_data) + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 3 + coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + atype = np.array([0, 0, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.spin = np.array( + [ + 0.50, + 0.30, + 0.20, + 0.40, + 0.25, + 0.15, + 0.10, + 0.05, + 0.08, + 0.12, + 0.07, + 0.09, + 0.45, + 0.35, + 0.28, + 0.11, + 0.06, + 0.03, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + + rcut = 4.0 + nframes, nloc = atype.shape[:2] + coord_normalized = normalize_coord( + coords.reshape(nframes, nloc, 3), + box.reshape(nframes, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, atype, box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + [20, 20, 20], + distinguish_types=True, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + self.nlist = nlist + self.extended_coord = extended_coord + self.extended_atype = extended_atype + self.mapping = mapping + + # Build extended spin from mapping + nall = extended_coord.shape[1] + self.extended_spin = np.take_along_axis( + self.spin, + np.repeat(mapping[:, :, np.newaxis], 3, axis=2), + axis=1, + ) + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + raise NotImplementedError("no TF in this test") + + def eval_dp(self, dp_obj: Any) -> Any: + return dp_obj.call_lower( + self.extended_coord, + self.extended_atype, + self.extended_spin, + self.nlist, + self.mapping, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return { + kk: torch_to_numpy(vv) + for kk, vv in pt_obj.forward_lower( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.extended_spin), + numpy_to_torch(self.nlist), + numpy_to_torch(self.mapping), + ).items() + } + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + from ..common import ( + SKIP_FLAG, + ) + + if backend is self.RefBackend.DP: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_mask_mag"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_mask_mag"].ravel(), + ret["extended_force"].ravel(), + ret["extended_force_mag"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") diff --git a/source/tests/pt/model/test_ener_spin_model.py b/source/tests/pt/model/test_ener_spin_model.py index ddea392f33..eb933baed5 100644 --- a/source/tests/pt/model/test_ener_spin_model.py +++ b/source/tests/pt/model/test_ener_spin_model.py @@ -313,7 +313,7 @@ def test_dp_consistency(self) -> None: return dp_model = DPSpinModel.deserialize(self.model.serialize()) # test call - dp_ret = dp_model.call( + dp_ret = dp_model.call_common( to_numpy_array(self.coord), to_numpy_array(self.atype), to_numpy_array(self.spin), @@ -355,7 +355,7 @@ def test_dp_consistency(self) -> None: extended_spin = torch.gather( self.spin, index=mapping.unsqueeze(-1).tile((1, 1, 3)), dim=1 ) - dp_ret_lower = dp_model.call_lower( + dp_ret_lower = dp_model.call_common_lower( to_numpy_array(extended_coord), to_numpy_array(extended_atype), to_numpy_array(extended_spin), diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index edee8454a1..092f7d4ae9 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -164,7 +164,7 @@ def setUpClass(cls) -> None: ft, type_map=cls.expected_type_map, ) - cls.output_def = cls.module.translated_output_def() + cls.output_def = cls.module.model_output_def().get_data() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() @@ -271,7 +271,7 @@ def setUpClass(cls) -> None: pair_exclude_types=pair_exclude_types, ) cls.module = SpinModel(backbone_model=backbone_model, spin=spin) - cls.output_def = cls.module.model_output_def().get_data() + cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam() From a3dfaa64468daa9d70123b99a455da17c43eb3cd Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 20 Feb 2026 16:35:01 +0800 Subject: [PATCH 6/6] fix bugs --- source/tests/consistent/model/test_spin_ener.py | 1 - source/tests/universal/dpmodel/model/test_model.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/source/tests/consistent/model/test_spin_ener.py b/source/tests/consistent/model/test_spin_ener.py index 61c26b350b..f6019732cf 100644 --- a/source/tests/consistent/model/test_spin_ener.py +++ b/source/tests/consistent/model/test_spin_ener.py @@ -333,7 +333,6 @@ def setUp(self) -> None: self.mapping = mapping # Build extended spin from mapping - nall = extended_coord.shape[1] self.extended_spin = np.take_along_axis( self.spin, np.repeat(mapping[:, :, np.newaxis], 3, axis=2), diff --git a/source/tests/universal/dpmodel/model/test_model.py b/source/tests/universal/dpmodel/model/test_model.py index 092f7d4ae9..c82074c601 100644 --- a/source/tests/universal/dpmodel/model/test_model.py +++ b/source/tests/universal/dpmodel/model/test_model.py @@ -164,7 +164,7 @@ def setUpClass(cls) -> None: ft, type_map=cls.expected_type_map, ) - cls.output_def = cls.module.model_output_def().get_data() + cls.output_def = cls.module.translated_output_def() cls.expected_has_message_passing = ds.has_message_passing() cls.expected_sel_type = ft.get_sel_type() cls.expected_dim_fparam = ft.get_dim_fparam()