From 34d4cb52c9c104050dccb46e8ac4d87392ba0ea2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 7 Apr 2026 21:03:03 +0800 Subject: [PATCH 1/9] test: add .pte and .pt2 tests for dp convert-backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add pt_expt backend (.pte/.pt2) to the parameterized extensions in test_models.py to verify convert-backend works for the new exportable formats. The fparam_aparam model (1 atom type) is switched from type_one_side=False to type_one_side=True (with ndim 2→1), which is equivalent for single-type models but enables pt_expt export. Models with type_one_side=False and multiple types (se_e2_a, se_e2_r) are skipped for .pte/.pt2 as make_fx cannot trace data-dependent indexing in NetworkCollection(ndim=2). --- .../tests/infer/fparam_aparam-testcase.yaml | 2 +- source/tests/infer/fparam_aparam.yaml | 6 ++--- source/tests/infer/fparam_aparam_default.yaml | 6 ++--- source/tests/infer/test_models.py | 26 ++++++++++++------- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 220b2df209..1f300e31dd 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -26,7 +26,7 @@ model_def_script: "set_davg_zero": False, "trainable": True, "type": "se_e2_a", - "type_one_side": False, + "type_one_side": True, }, "fitting_net": { diff --git a/source/tests/infer/fparam_aparam.yaml b/source/tests/infer/fparam_aparam.yaml index e0654e142f..3f22bbcf6c 100644 --- a/source/tests/infer/fparam_aparam.yaml +++ b/source/tests/infer/fparam_aparam.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2012,7 +2012,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/fparam_aparam_default.yaml b/source/tests/infer/fparam_aparam_default.yaml index 6d64bfc328..5798817325 100644 --- a/source/tests/infer/fparam_aparam_default.yaml +++ b/source/tests/infer/fparam_aparam_default.yaml @@ -526,7 +526,7 @@ model: embeddings: "@class": NetworkCollection "@version": 1 - ndim: 2 + ndim: 1 network_type: embedding_network networks: - "@class": EmbeddingNetwork @@ -916,7 +916,7 @@ model: type: se_e2_a type_map: &id001 - O - type_one_side: false + type_one_side: true fitting: "@class": Fitting "@variables": @@ -2021,7 +2021,7 @@ model_def_script: set_davg_zero: false trainable: true type: se_e2_a - type_one_side: false + type_one_side: true fitting_net: activation_function: tanh atom_ener: *id004 diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 7f7b7cc21c..500622c664 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -13,6 +13,7 @@ ) from ..consistent.common import ( + INSTALLED_PT_EXPT, parameterized, ) from .case import ( @@ -28,7 +29,7 @@ "se_e2_r", "fparam_aparam", ), # key - (".pb", ".pth"), # model extension + (".pb", ".pth", ".pte", ".pt2"), # model extension ) class TestDeepPot(unittest.TestCase): # moved from tests/tf/test_deeppot_a.py @@ -36,6 +37,16 @@ class TestDeepPot(unittest.TestCase): @classmethod def setUpClass(cls) -> None: key, extension = cls.param + if extension in (".pte", ".pt2") and not INSTALLED_PT_EXPT: + raise unittest.SkipTest("pt_expt backend not installed") + if key in ("se_e2_a", "se_e2_r") and extension in (".pte", ".pt2"): + raise unittest.SkipTest( + "type_one_side=False is not supported for pt_expt export" + ) + if key == "se_e2_r" and extension == ".pth": + raise unittest.SkipTest( + "se_e2_r type_one_side is not supported for PyTorch models" + ) cls.case = get_cases()[key] cls.model_name = cls.case.get_model(extension) cls.dp = DeepEval(cls.model_name) @@ -44,13 +55,6 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: cls.dp = None - def setUp(self) -> None: - key, extension = self.param - if key == "se_e2_r" and extension == ".pth": - self.skipTest( - reason="se_e2_r type_one_side is not supported for PyTorch models" - ) - def test_attrs(self) -> None: assert isinstance(self.dp, DeepPot) self.assertEqual(self.dp.get_ntypes(), self.case.ntypes) @@ -153,6 +157,8 @@ def test_1frame_atm(self) -> None: def test_descriptor(self) -> None: _, extension = self.param + if extension in (".pte", ".pt2"): + self.skipTest("eval_descriptor not supported for pt_expt models") for ii, result in enumerate(self.case.results): if result.descriptor is None: continue @@ -166,8 +172,8 @@ def test_descriptor(self) -> None: def test_fitting_last_layer(self) -> None: _, extension = self.param - if extension == ".pb": - self.skipTest("fitting_last_layer not supported for TensorFlow models") + if extension in (".pb", ".pte", ".pt2"): + self.skipTest("fitting_last_layer not supported for this backend") for ii, result in enumerate(self.case.results): if result.fit_ll is None: continue From 281989d1589427a8dc2cddfbf90f5890f9ef7b19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 8 Apr 2026 08:24:08 +0800 Subject: [PATCH 2/9] fix: reset default device before .pt2 AOTInductor compilation tests/pt/__init__.py may set a fake default device for CPU fallback, which poisons AOTInductor compilation. Temporarily clear the default device before converting to .pt2, matching the pattern used in test_change_bias.py. --- source/tests/infer/test_models.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 500622c664..44e3de30cb 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -48,7 +48,19 @@ def setUpClass(cls) -> None: "se_e2_r type_one_side is not supported for PyTorch models" ) cls.case = get_cases()[key] - cls.model_name = cls.case.get_model(extension) + if extension == ".pt2": + import torch + + # Clear default device: tests/pt/__init__.py may set a fake + # device for CPU fallback, which poisons AOTInductor compilation. + saved_device = torch.get_default_device() + torch.set_default_device(None) + try: + cls.model_name = cls.case.get_model(extension) + finally: + torch.set_default_device(saved_device) + else: + cls.model_name = cls.case.get_model(extension) cls.dp = DeepEval(cls.model_name) @classmethod From fbc3b9bc5117e967e2088d89dd416c04861db296 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 8 Apr 2026 22:41:01 +0800 Subject: [PATCH 3/9] feat(pt_expt): add eval_typeebd, eval_descriptor, eval_fitting_last_layer Implement three diagnostic eval methods for the pt_expt backend's DeepEval, using the eager _dpmodel (deserialized from model.json in the archive) to compute intermediate results directly. Key changes: - Add get_dp_atomic_model() API to model hierarchy (make_model CM, FrozenModel, SpinModel) for clean access to DPAtomicModel - Add set_return_middle_output / middle_output support in GeneralFitting - Extract _prepare_inputs helper from _eval_model for reuse - Preserve _call_common output dict in DipoleFitting/PolarFitting call() so middle_output flows through without ad-hoc propagation - Remove .pte/.pt2 skips in cross-backend test_descriptor/test_fitting_last_layer - Add fparam_aparam descriptor/fit_ll reference values to testcase YAML Limitations: - SpinModel: eval_descriptor/eval_fitting_last_layer raise NotImplementedError (virtual atom preprocessing required) - DPZBLModel/LinearEnergyModel: raise NotImplementedError (no single descriptor/fitting_net) - se_e2_a + fparam: cannot export to .pte (torch.export limitation) --- deepmd/dpmodel/fitting/dipole_fitting.py | 9 +- deepmd/dpmodel/fitting/general_fitting.py | 29 + .../dpmodel/fitting/polarizability_fitting.py | 8 +- deepmd/dpmodel/model/frozen.py | 10 + deepmd/dpmodel/model/make_model.py | 21 + deepmd/dpmodel/model/spin_model.py | 4 + deepmd/pt_expt/infer/deep_eval.py | 227 +++- .../dpmodel/test_fitting_middle_output.py | 101 ++ .../tests/infer/fparam_aparam-testcase.yaml | 996 ++++++++++++++++++ source/tests/infer/test_models.py | 34 +- source/tests/pt_expt/infer/test_deep_eval.py | 412 ++++++++ 11 files changed, 1828 insertions(+), 23 deletions(-) create mode 100644 source/tests/common/dpmodel/test_fitting_middle_output.py diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 0fc3a1fefa..818b2eeffd 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -239,9 +239,8 @@ def call( nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # (nframes, nloc, m1) - out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = results[self.var_name] # (nframes * nloc, 1, m1) out = xp.reshape(out, (-1, 1, self.embedding_width)) # (nframes * nloc, m1, 3) @@ -249,5 +248,5 @@ def call( # (nframes, nloc, 3) # out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) out = out @ gr - out = xp.reshape(out, (nframes, nloc, 3)) - return {self.var_name: out} + results[self.var_name] = xp.reshape(out, (nframes, nloc, 3)) + return results diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 4761703a19..f505bc556c 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -168,6 +168,7 @@ def __init__( if self.spin is not None: raise NotImplementedError("spin is not supported") self.remove_vaccum_contribution = remove_vaccum_contribution + self.eval_return_middle_output = False net_dim_out = self._net_out_dim() # init constants @@ -424,6 +425,16 @@ def get_default_fparam(self) -> list[float] | None: """Get the default frame parameters.""" return self.default_fparam + def set_return_middle_output(self, enable: bool) -> None: + """Enable or disable returning the middle (pre-last-layer) output. + + When enabled, the fitting network's ``call`` method will include + a ``"middle_output"`` key in the returned dict, containing the + hidden-layer activations before the final linear layer. Shape: + ``[nframes, nloc, neuron[-1]]``. + """ + self.eval_return_middle_output = enable + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. @@ -690,6 +701,12 @@ def _call_common( dtype=get_xp_precision(xp, self.precision), device=array_api_compat.device(descriptor), ) + if self.eval_return_middle_output: + middle_outs = xp.zeros( + [nf, nloc, self.neuron[-1]], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), + ) for type_i in range(self.ntypes): mask = xp.tile( xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) @@ -705,10 +722,20 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + if self.eval_return_middle_output: + mid = self.nets[(type_i,)].call_until_last(xx) + mid_mask = xp.tile( + xp.reshape((atype == type_i), (nf, nloc, 1)), + (1, 1, self.neuron[-1]), + ) + mid = xp.where(mid_mask, mid, xp.zeros_like(mid)) + middle_outs = middle_outs + mid else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) + if self.eval_return_middle_output: + middle_outs = self.nets[()].call_until_last(xx) outs += xp.reshape( xp.take( xp.astype(self.bias_atom_e[...], outs.dtype), @@ -723,4 +750,6 @@ def _call_common( # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) results[self.var_name] = outs + if self.eval_return_middle_output: + results["middle_output"] = middle_outs return results diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 361f033a68..42d36000a2 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -326,9 +326,8 @@ def call( "Must provide the rotation matrix for polarizability fitting." ) # (nframes, nloc, _net_out_dim) - out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = results.pop(self.var_name) # out = out * self.scale[atype, ...] scale_atype = xp.reshape( xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0), @@ -371,4 +370,5 @@ def call( # (nframes, nloc, 3, 3) bias = bias[..., None] * eye out = out + bias - return {"polarizability": out} + results["polarizability"] = out + return results diff --git a/deepmd/dpmodel/model/frozen.py b/deepmd/dpmodel/model/frozen.py index a8bf74323b..ed9da3bafe 100644 --- a/deepmd/dpmodel/model/frozen.py +++ b/deepmd/dpmodel/model/frozen.py @@ -1,9 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + TYPE_CHECKING, Any, NoReturn, ) +if TYPE_CHECKING: + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + from deepmd.dpmodel.common import ( NativeOP, ) @@ -131,6 +137,10 @@ def get_observed_type_list(self) -> list[str]: """Get observed types (elements) of the model during data statistics.""" return self.model.get_observed_type_list() + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel by delegating to the inner model.""" + return self.model.get_dp_atomic_model() + def serialize(self) -> dict: """Serialize the model. diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index f70b668a87..597f8ea006 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,9 +3,15 @@ Callable, ) from typing import ( + TYPE_CHECKING, Any, ) +if TYPE_CHECKING: + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + import array_api_compat import numpy as np @@ -704,6 +710,21 @@ def is_aparam_nall(self) -> bool: """ return self.atomic_model.is_aparam_nall() + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel with descriptor and fitting_net. + + Returns the ``atomic_model`` if it is a ``DPAtomicModel`` instance + (i.e. has both ``descriptor`` and ``fitting_net``). Returns ``None`` + for composite atomic models such as ``LinearEnergyAtomicModel``. + """ + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + + if isinstance(self.atomic_model, DPAtomicModel): + return self.atomic_model + return None + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.atomic_model.get_rcut() diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index d9f185a60d..1a92d9d87f 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -547,6 +547,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError(name) return getattr(self.backbone_model, name) + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel by delegating to the backbone model.""" + return self.backbone_model.get_dp_atomic_model() + def serialize(self) -> dict: return { "backbone_model": self.backbone_model.serialize(), diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 290f2ec923..a2d2891b82 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -566,15 +566,22 @@ def _build_nlist_ase_single( return extended_coord, extended_atype, nlist, mapping - def _eval_model( + def _prepare_inputs( self, coords: np.ndarray, cells: np.ndarray | None, atom_types: np.ndarray, fparam: np.ndarray | None, aparam: np.ndarray | None, - request_defs: list[OutputVariableDef], - ) -> tuple[np.ndarray, ...]: + ) -> tuple: + """Prepare tensor inputs for model evaluation. + + Returns + ------- + tuple + (ext_coord_t, ext_atype_t, nlist_t, mapping_t, + fparam_t, aparam_t, nframes, natoms) + """ nframes = coords.shape[0] if len(atom_types.shape) == 1: natoms = len(atom_types) @@ -622,9 +629,7 @@ def _eval_model( dtype=torch.float64, device=DEVICE, ) - elif self._is_pt2 and self.get_dim_fparam() > 0: - # .pt2 models are compiled with fparam as a required input. - # When the user omits fparam, fill with default values from metadata. + elif self.get_dim_fparam() > 0: default_fp = self.metadata.get("default_fparam") if default_fp is not None: fparam_t = ( @@ -647,9 +652,45 @@ def _eval_model( dtype=torch.float64, device=DEVICE, ) + elif self.get_dim_aparam() > 0: + raise ValueError( + f"aparam is required for this model (dim_aparam={self.get_dim_aparam()}) " + "but was not provided." + ) else: aparam_t = None + return ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) + + def _eval_model( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + # Call the model (forward_common_lower interface, internal keys) if self._is_pt2: # AOTInductor's __call__ unflattens output using stored out_spec, @@ -732,3 +773,177 @@ def get_model(self) -> torch.nn.Module: The exported model module. """ return self.exported_module + + def _is_spin_model(self) -> bool: + """Check if the underlying dpmodel is a SpinModel.""" + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + + return isinstance(self._dpmodel, SpinModel) + + def eval_typeebd(self) -> np.ndarray: + """Evaluate type embedding. + + Returns + ------- + np.ndarray + Type embedding array of shape ``(ntypes, tebd_dim)``. + + Raises + ------ + KeyError + If the model has no type embedding networks. + """ + from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP + + model = self._dpmodel + if self._is_spin_model(): + model = model.backbone_model + out = [] + for mm in model.modules(): + if isinstance(mm, TypeEmbedNetDP): + out.append(mm()) + if not out: + raise KeyError("The model has no type embedding networks.") + typeebd = torch.cat(out, dim=1) + return typeebd.detach().cpu().numpy() + + def eval_descriptor( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> np.ndarray: + """Evaluate descriptor. + + Parameters + ---------- + coords + Coordinates, shape ``(nframes, natoms, 3)``. + cells + Cell vectors, shape ``(nframes, 3, 3)`` or ``None``. + atom_types + Atom types, shape ``(natoms,)`` or ``(nframes, natoms)``. + fparam + Frame parameters, optional. + aparam + Atom parameters, optional. + + Returns + ------- + np.ndarray + Descriptor output, shape ``(nframes, nloc, dim_descrpt)``. + """ + if self._is_spin_model(): + raise NotImplementedError( + "eval_descriptor is not supported for spin models." + ) + dp_am = self._dpmodel.get_dp_atomic_model() + if dp_am is None: + raise NotImplementedError( + "eval_descriptor is not supported for this model type " + f"({type(self._dpmodel).__name__})." + ) + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + with torch.no_grad(): + fparam_for_des = ( + fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None + ) + descriptor, rot_mat, g2, h2, sw = dp_am.descriptor( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping=mapping_t, + fparam=fparam_for_des, + ) + return descriptor.detach().cpu().numpy() + + def eval_fitting_last_layer( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> np.ndarray: + """Evaluate the last hidden layer of the fitting network. + + Parameters + ---------- + coords + Coordinates, shape ``(nframes, natoms, 3)``. + cells + Cell vectors, shape ``(nframes, 3, 3)`` or ``None``. + atom_types + Atom types, shape ``(natoms,)`` or ``(nframes, natoms)``. + fparam + Frame parameters, optional. + aparam + Atom parameters, optional. + + Returns + ------- + np.ndarray + Middle-layer output, shape ``(nframes, nloc, neuron[-1])``. + """ + if self._is_spin_model(): + raise NotImplementedError( + "eval_fitting_last_layer is not supported for spin models." + ) + dp_am = self._dpmodel.get_dp_atomic_model() + if dp_am is None: + raise NotImplementedError( + "eval_fitting_last_layer is not supported for this model type " + f"({type(self._dpmodel).__name__})." + ) + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + with torch.no_grad(): + fparam_for_des = ( + fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None + ) + descriptor, rot_mat, g2, h2, sw = dp_am.descriptor( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping=mapping_t, + fparam=fparam_for_des, + ) + atype = ext_atype_t[:, :natoms] + fitting_net = dp_am.fitting_net + fitting_net.set_return_middle_output(True) + try: + ret = fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam_t, + aparam=aparam_t, + ) + finally: + fitting_net.set_return_middle_output(False) + return ret["middle_output"].detach().cpu().numpy() diff --git a/source/tests/common/dpmodel/test_fitting_middle_output.py b/source/tests/common/dpmodel/test_fitting_middle_output.py new file mode 100644 index 0000000000..0e53f13e90 --- /dev/null +++ b/source/tests/common/dpmodel/test_fitting_middle_output.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestFittingMiddleOutput(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def _build_fitting(self, mixed_types, nfp=0, nap=0): + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + atype = self.atype_ext[:, : self.nloc] + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, # dim_out + mixed_types=mixed_types, + numb_fparam=nfp, + numb_aparam=nap, + seed=GLOBAL_SEED, + ) + return ft, descriptor, atype + + def test_middle_output_disabled_by_default(self) -> None: + """Middle output should not be present when not enabled.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ret = ft.call(descriptor, atype) + self.assertIn("energy", ret) + self.assertNotIn("middle_output", ret) + + def test_middle_output_enabled_mixed_types(self) -> None: + """When enabled, middle_output key is present with correct shape for mixed_types=True.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_enabled_per_type(self) -> None: + """When enabled, middle_output key is present with correct shape for mixed_types=False.""" + ft, descriptor, atype = self._build_fitting(mixed_types=False) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_toggle(self) -> None: + """Verify toggling set_return_middle_output on/off works.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + + # Enable + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + + # Disable + ft.set_return_middle_output(False) + ret = ft.call(descriptor, atype) + self.assertNotIn("middle_output", ret) + + def test_middle_output_with_fparam_aparam(self) -> None: + """Middle output works with fparam and aparam.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True, nfp=2, nap=3) + nf, nloc, _ = descriptor.shape + fparam = np.zeros([nf, 2], dtype=np.float64) + aparam = np.zeros([nf, nloc, 3], dtype=np.float64) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, fparam=fparam, aparam=aparam) + self.assertIn("middle_output", ret) + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_deterministic(self) -> None: + """Middle output should be deterministic.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + ret1 = ft.call(descriptor, atype) + ret2 = ft.call(descriptor, atype) + np.testing.assert_array_equal(ret1["middle_output"], ret2["middle_output"]) diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 1f300e31dd..5bb9ab471b 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -159,3 +159,999 @@ results: -2.486859433265622629e-02, 2.875323131744185121e-02, ] + descriptor: + [ + 6.083714331120656515e-01, + 4.502505554576420876e-01, + 1.594398494919171405e-01, + 4.727519861103077758e-01, + 6.727545551660457646e-01, + -3.105848623409598885e-01, + -2.034772733369290820e-01, + -2.166933245535891117e-01, + 4.502505554576420876e-01, + 1.139649142647525260e+00, + 4.969499257159836203e-01, + 4.301892148257391857e-01, + 1.270001504410809723e+00, + -5.140452446715122470e-01, + 3.263504683974155496e-01, + -9.091582204946582202e-01, + 1.594398494919171405e-01, + 4.969499257159836203e-01, + 2.260654681080056239e-01, + 1.624402806893359419e-01, + 5.386383726647075987e-01, + -2.137567595770779316e-01, + 1.714384661213686489e-01, + -4.035449616399264805e-01, + 4.727519861103077758e-01, + 4.301892148257391857e-01, + 1.624402806893359419e-01, + 3.754699534763070723e-01, + 5.995953110586557111e-01, + -2.694918713461056381e-01, + -1.105341613089115826e-01, + -2.422835589130273026e-01, + 6.727545551660457646e-01, + 1.270001504410809723e+00, + 5.386383726647075987e-01, + 5.995953110586557111e-01, + 1.483309072850735211e+00, + -6.156618440021969230e-01, + 2.315666690444803111e-01, + -9.570105981549420493e-01, + -3.105848623409598885e-01, + -5.140452446715122470e-01, + -2.137567595770779316e-01, + -2.694918713461056381e-01, + -6.156618440021969230e-01, + 2.589390132149291257e-01, + -6.406882715059965261e-02, + 3.754908530130299793e-01, + -2.034772733369290820e-01, + 3.263504683974155496e-01, + 1.714384661213686489e-01, + -1.105341613089115826e-01, + 2.315666690444803111e-01, + -6.406882715059965261e-02, + 3.502027394021257067e-01, + -3.698375855067269069e-01, + -2.166933245535891117e-01, + -9.091582204946582202e-01, + -4.035449616399264805e-01, + -2.422835589130273026e-01, + -9.570105981549420493e-01, + 3.754908530130299793e-01, + -3.698375855067269069e-01, + 7.767038256079379366e-01, + -2.981667553358613998e-01, + -7.391122602545086018e-01, + -3.205454587664270938e-01, + -2.831620287994112140e-01, + -8.262423426451125374e-01, + 3.351661217421733063e-01, + -2.067609428898770363e-01, + 5.886302766667602659e-01, + 1.240976071619029480e-01, + 4.973224483159546239e-01, + 2.234717192757742499e-01, + 1.368572698960379141e-01, + 5.254712190591790399e-01, + -2.061836993159548159e-01, + 1.983387449937808189e-01, + -4.204030104006635793e-01, + 5.581270077279179009e-01, + 4.082193934796568002e-01, + 1.441148337259589574e-01, + 4.332383483055207152e-01, + 6.125584989446182238e-01, + -2.832083358508835080e-01, + -1.895286928385826741e-01, + -1.942100761080858551e-01, + 6.378556323516176851e-01, + 1.001982931038889690e+00, + 4.168948972325110525e-01, + 5.485284807739624346e-01, + 1.212661151645888724e+00, + -5.122423732761739457e-01, + 1.001406309724684474e-01, + -7.186443216927924649e-01, + -7.152415656398955490e-02, + 1.054901580743528022e-01, + 6.232449399777988119e-02, + -3.894272340885426148e-02, + 7.205094434523259816e-02, + -1.805014026654702675e-02, + 1.183066919255474220e-01, + -1.161656443653441007e-01, + 5.022149866663377926e-01, + 6.976700285197479090e-01, + 2.862873989982428480e-01, + 4.229179663891721730e-01, + 8.673481304519186086e-01, + -3.709822075657827067e-01, + 2.498047634770347511e-02, + -4.803400360910294875e-01, + 5.637582607285308578e-01, + 1.079696549143701167e+00, + 4.590639095981799978e-01, + 5.040455670483138251e-01, + 1.257759447946239506e+00, + -5.212815054748195509e-01, + 2.032328956811573217e-01, + -8.159605088714143584e-01, + -1.350693558539911299e-01, + -3.220522605062802191e-01, + -1.397884793071107190e-01, + -1.270778868078931978e-01, + -3.620099554771795125e-01, + 1.472182089511498149e-01, + -8.617472043402696347e-02, + 2.543117761117610343e-01, + -1.464136841504689035e-01, + 1.217686222663357798e-01, + 6.951144078262200265e-02, + -9.088638084027333974e-02, + 5.842502264744667967e-02, + -6.392865301134783088e-03, + 1.850518592878056534e-01, + -1.617257895507598076e-01, + 9.094458269279154239e-02, + -3.087692407825532448e-01, + -1.487259125583649533e-01, + 3.376356883973294365e-02, + -2.598628715751217788e-01, + 8.690297349942077698e-02, + -2.524037647949817775e-01, + 3.202540825701963301e-01, + -3.751637757936663320e-01, + -3.982900935111258200e-01, + -1.544156549323568994e-01, + -3.034653687171679493e-01, + -5.304295281929436445e-01, + 2.341575952346893752e-01, + 5.419559866725016245e-02, + 2.461408379536359770e-01, + 3.132239475444893451e-01, + 8.765556472050846093e-01, + 3.875192888559441107e-01, + 3.079155531500631371e-01, + 9.635319863477369573e-01, + -3.866632840469169663e-01, + 2.768015304833532642e-01, + -7.082444142190988945e-01, + 1.430144157827138329e+00, + 8.371406802726794050e-01, + 3.217236266667155009e-01, + 1.101372195257868780e+00, + 1.309710128726800038e+00, + -6.265729912157642634e-01, + -5.976692872812503499e-01, + -2.230698942337786261e-01, + 8.371406802726794050e-01, + 1.023033322697949554e+00, + 4.121430074516977404e-01, + 6.841961394490133630e-01, + 1.371741299045477591e+00, + -5.865551302740638073e-01, + -4.901372130550269574e-02, + -7.126229372870805934e-01, + 3.217236266667155009e-01, + 4.121430074516977404e-01, + 1.667923268135688974e-01, + 2.645504158349288315e-01, + 5.474427363099826360e-01, + -2.328003308815323191e-01, + -7.917678800792056840e-03, + -2.933583025540117961e-01, + 1.101372195257868780e+00, + 6.841961394490133630e-01, + 2.645504158349288315e-01, + 8.512102245969188630e-01, + 1.052791432993813947e+00, + -4.985889765492535486e-01, + -4.378685187396406109e-01, + -2.142944536866224214e-01, + 1.309710128726800038e+00, + 1.371741299045477591e+00, + 5.474427363099826360e-01, + 1.052791432993813947e+00, + 1.890936188225563086e+00, + -8.248805096784416202e-01, + -2.065440769036151480e-01, + -8.692732098846506217e-01, + -6.265729912157642634e-01, + -5.865551302740638073e-01, + -2.328003308815323191e-01, + -4.985889765492535486e-01, + -8.248805096784416202e-01, + 3.656750155213921438e-01, + 1.380477511643647681e-01, + 3.391836108469640787e-01, + -5.976692872812503499e-01, + -4.901372130550269574e-02, + -7.917678800792056840e-03, + -4.378685187396406109e-01, + -2.065440769036151480e-01, + 1.380477511643647681e-01, + 4.196794682686730837e-01, + -2.346170286615520129e-01, + -2.230698942337786261e-01, + -7.126229372870805934e-01, + -2.933583025540117961e-01, + -2.142944536866224214e-01, + -8.692732098846506217e-01, + 3.391836108469640787e-01, + -2.346170286615520129e-01, + 6.743229971375143128e-01, + -5.123920094325192798e-01, + -7.074779841860322493e-01, + -2.855206215044043816e-01, + -4.243401890981060021e-01, + -9.350347600855188901e-01, + 3.936059152956140617e-01, + -1.539375914738916849e-02, + 5.278501097255398067e-01, + 1.466478863200821325e-01, + 4.137911975084271332e-01, + 1.699008976760616185e-01, + 1.368122661427808029e-01, + 5.095046380033478872e-01, + -2.004656295627817442e-01, + 1.233528182536938744e-01, + -3.836611231041641701e-01, + 1.307379055422948300e+00, + 7.738567876659411260e-01, + 2.975463412155298082e-01, + 1.007378233374699450e+00, + 1.207607062406870435e+00, + -5.765232705711489380e-01, + -5.416165483742040321e-01, + -2.138262638647728253e-01, + 1.359643853640744249e+00, + 1.143610489764218752e+00, + 4.519272049237009758e-01, + 1.072869654075972567e+00, + 1.639797498390930031e+00, + -7.390349747630275967e-01, + -3.719220608199254596e-01, + -5.916890292398517825e-01, + -1.552513570223078287e-01, + -1.897934780868526863e-02, + -4.210073071744310953e-03, + -1.139855766718844976e-01, + -6.215846401873216520e-02, + 3.892738523383156707e-02, + 1.057189870558455924e-01, + -5.276144303775173738e-02, + 1.104538870923048366e+00, + 8.526603431257446797e-01, + 3.351761963971622293e-01, + 8.659673339527532709e-01, + 1.245043447077310406e+00, + -5.687481602434459882e-01, + -3.451881390779078518e-01, + -3.969124013792098005e-01, + 1.089349537510279520e+00, + 1.164895568404513027e+00, + 4.652591969475391998e-01, + 8.773598369767227068e-01, + 1.600472758596691580e+00, + -6.961382514444349745e-01, + -1.583528216631740482e-01, + -7.496262147675080145e-01, + -2.480012474417669366e-01, + -3.065647179173028869e-01, + -1.233673686142572196e-01, + -2.028456491009781593e-01, + -4.110517151842552619e-01, + 1.754473861346198538e-01, + 1.266109575260882664e-02, + 2.155799593938188818e-01, + -3.978403937119566747e-01, + -9.991332892649396058e-02, + -3.340409898200767669e-02, + -2.963873874846542633e-01, + -2.143409093118101016e-01, + 1.197953317221483738e-01, + 2.414569614468069225e-01, + -8.227296457316374267e-02, + 3.690821426046725362e-01, + -8.055988299054643587e-02, + -4.082341030522432940e-02, + 2.626029367695669747e-01, + -1.042803680946458424e-03, + -3.859858812343196915e-02, + -3.212762213243168796e-01, + 2.685029558806420469e-01, + -8.323969526289327625e-01, + -5.886919464746885877e-01, + -2.295559263038976228e-01, + -6.484017939892503524e-01, + -8.785152580554453916e-01, + 4.068776062722544995e-01, + 2.907761085338105289e-01, + 2.415838031028531563e-01, + 5.332734738415039200e-01, + 8.033746811924462605e-01, + 3.253994932488234459e-01, + 4.466281431516882505e-01, + 1.049148736467505838e+00, + -4.372511097323733553e-01, + 5.389587131807656306e-02, + -6.224441065274786133e-01, + 9.652167274938354691e-01, + 5.313446375815668032e-01, + 1.595985340042652689e-01, + 7.366646596146324555e-01, + 8.456546075061203149e-01, + -4.160141501041305645e-01, + -4.255582314137071887e-01, + -1.459893329129947348e-01, + 5.313446375815668032e-01, + 7.843080248152077827e-01, + 2.839722764135915178e-01, + 4.378789672765076579e-01, + 1.043223182916673375e+00, + -4.397200153695126068e-01, + 3.886481651629476036e-02, + -6.392727633140400378e-01, + 1.595985340042652689e-01, + 2.839722764135915178e-01, + 1.158293104700064968e-01, + 1.354240804871464643e-01, + 3.740428049816837408e-01, + -1.520077295574986109e-01, + 3.906416117492179929e-02, + -2.414007968687422734e-01, + 7.366646596146324555e-01, + 4.378789672765076579e-01, + 1.354240804871464643e-01, + 5.644438045269433157e-01, + 6.833253054234037505e-01, + -3.312184769506364979e-01, + -3.067479524964061288e-01, + -1.475897350138534181e-01, + 8.456546075061203149e-01, + 1.043223182916673375e+00, + 3.740428049816837408e-01, + 6.833253054234037505e-01, + 1.423660291885181062e+00, + -6.125441018811154104e-01, + -5.221076190477519363e-02, + -7.845287167838036479e-01, + -4.160141501041305645e-01, + -4.397200153695126068e-01, + -1.520077295574986109e-01, + -3.312184769506364979e-01, + -6.125441018811154104e-01, + 2.698786254482626323e-01, + 6.654507192479590383e-02, + 3.033699870391518005e-01, + -4.255582314137071887e-01, + 3.886481651629476036e-02, + 3.906416117492179929e-02, + -3.067479524964061288e-01, + -5.221076190477519363e-02, + 6.654507192479590383e-02, + 3.393861782856803511e-01, + -2.455063645724311627e-01, + -1.459893329129947348e-01, + -6.392727633140400378e-01, + -2.414007968687422734e-01, + -1.475897350138534181e-01, + -7.845287167838036479e-01, + 3.033699870391518005e-01, + -2.455063645724311627e-01, + 6.614884222618355736e-01, + -3.265542808407963515e-01, + -5.693028167157063724e-01, + -2.101678679286862750e-01, + -2.746923718860757591e-01, + -7.457901247019582680e-01, + 3.081124176255724545e-01, + -7.215648382777055392e-02, + 4.929607152885880916e-01, + 7.289144396716068508e-02, + 3.543740733174537971e-01, + 1.413074915636731432e-01, + 7.633121055347411033e-02, + 4.362453637185305655e-01, + -1.663244632123324074e-01, + 1.423151678522617813e-01, + -3.675326287502824751e-01, + 8.796846030153004925e-01, + 4.949259267405346496e-01, + 1.507390539711822852e-01, + 6.721064859018730520e-01, + 7.839949395373813079e-01, + -3.837608242990865337e-01, + -3.819288553812149045e-01, + -1.449413796860413162e-01, + 8.907246324644731983e-01, + 8.102877153047098879e-01, + 2.761347910508374359e-01, + 7.009365398991779239e-01, + 1.156667158090711078e+00, + -5.208994952781673682e-01, + -2.149701864086924763e-01, + -4.977028399023077365e-01, + -1.544386539268297331e-01, + -3.135035718915457625e-02, + 7.480831051489745913e-03, + -1.135803115961424858e-01, + -6.844467004321277970e-02, + 4.444629092177178331e-02, + 9.844907620866516496e-02, + -3.171928598253399845e-02, + 7.230832319002997721e-01, + 5.844662077789191112e-01, + 1.962525977491133278e-01, + 5.642907623475763579e-01, + 8.531813387800899484e-01, + -3.913192904287589591e-01, + -2.151528698629770198e-01, + -3.199599599522977011e-01, + 6.988775778606687306e-01, + 8.886437601116213836e-01, + 3.210311239935199623e-01, + 5.665061413069706342e-01, + 1.208524116908088031e+00, + -5.175924737174160128e-01, + -2.842688210451090741e-02, + -6.780668764946560234e-01, + -1.565472119713383070e-01, + -2.381925212230145972e-01, + -8.722212678513849293e-02, + -1.294711626807042715e-01, + -3.164147548509885222e-01, + 1.326809813472158428e-01, + -1.537634923757306461e-02, + 1.963938295947420254e-01, + -2.733270100864703123e-01, + -3.208757074499291734e-02, + 3.495293544523320217e-04, + -2.008501478596516066e-01, + -1.016318578095670166e-01, + 6.715777107960560488e-02, + 1.862588844222162754e-01, + -9.355767339931092552e-02, + 2.405851005618613869e-01, + -1.543923965166731360e-01, + -6.881011534701647614e-02, + 1.652293323598198749e-01, + -1.251034413458357586e-01, + 1.988420677015985405e-02, + -2.649793288818961257e-01, + 2.931400138614652096e-01, + -5.591003147909873183e-01, + -4.063284579846219158e-01, + -1.315208571745482835e-01, + -4.331117440048848355e-01, + -6.061090378540733292e-01, + 2.833828635442544042e-01, + 1.918609571473225972e-01, + 1.970841104720812420e-01, + 3.097713438305200739e-01, + 6.338697878210927117e-01, + 2.441240075513875507e-01, + 2.671479762635345367e-01, + 8.207590450429802509e-01, + -3.321173802625960736e-01, + 1.208172718023899639e-01, + -5.709923069462401468e-01, + 4.378905518390653340e-01, + 2.803383952198099105e-01, + 4.818083515067635852e-02, + 3.288723036125778543e-01, + 4.568212815808281313e-01, + -2.205844558239555830e-01, + -1.790930632302553671e-01, + -1.600790746380182095e-01, + 2.803383952198099105e-01, + 1.184924484528727451e+00, + 5.280975986537399525e-01, + 3.217531805643514264e-01, + 1.211448561860737794e+00, + -4.775164830407008787e-01, + 4.909065147537653440e-01, + -9.839528741162539838e-01, + 4.818083515067635852e-02, + 5.280975986537399525e-01, + 2.570327553754042649e-01, + 9.166999066254227779e-02, + 5.052143239146044129e-01, + -1.897550899742722341e-01, + 2.801748154632567878e-01, + -4.496013116553528310e-01, + 3.288723036125778543e-01, + 3.217531805643514264e-01, + 9.166999066254227779e-02, + 2.593365269807865192e-01, + 4.447725051818553488e-01, + -2.027935980045048603e-01, + -6.749380249578725011e-02, + -2.173901573078995675e-01, + 4.568212815808281313e-01, + 1.211448561860737794e+00, + 5.052143239146044129e-01, + 4.447725051818553488e-01, + 1.316578719027061473e+00, + -5.374363093691157944e-01, + 3.666830608571204353e-01, + -9.723643805738775292e-01, + -2.205844558239555830e-01, + -4.775164830407008787e-01, + -1.897550899742722341e-01, + -2.027935980045048603e-01, + -5.374363093691157944e-01, + 2.237214263935641545e-01, + -1.122709912087574979e-01, + 3.760362192894101119e-01, + -1.790930632302553671e-01, + 4.909065147537653440e-01, + 2.801748154632567878e-01, + -6.749380249578725011e-02, + 3.666830608571204353e-01, + -1.122709912087574979e-01, + 4.379920339886785308e-01, + -4.651501679265258593e-01, + -1.600790746380182095e-01, + -9.839528741162539838e-01, + -4.496013116553528310e-01, + -2.173901573078995675e-01, + -9.723643805738775292e-01, + 3.760362192894101119e-01, + -4.651501679265258593e-01, + 8.339826836209653926e-01, + -2.019987455578689528e-01, + -7.458382466808392008e-01, + -3.265377399991231666e-01, + -2.198489170601660714e-01, + -7.741919203025848795e-01, + 3.080513688669864747e-01, + -2.886593869231082743e-01, + 6.147449159647125905e-01, + 7.131193565321254646e-02, + 5.249313970468063584e-01, + 2.458973988601513838e-01, + 1.066531993642767101e-01, + 5.125435760822050213e-01, + -1.960840086240358826e-01, + 2.595614234479117211e-01, + -4.455571070106654208e-01, + 4.023781103462598097e-01, + 2.439986598498889436e-01, + 3.791545332624686460e-02, + 3.007270943868774471e-01, + 4.073693169405033787e-01, + -1.980921968504558150e-01, + -1.727376141424090128e-01, + -1.349192056315940413e-01, + 4.254523519602377291e-01, + 9.354988128396966029e-01, + 3.754166251626768758e-01, + 3.929303613668370665e-01, + 1.049997075315301398e+00, + -4.360231065678457640e-01, + 2.254237742176496972e-01, + -7.364430001865647224e-01, + -1.096490937195783028e-01, + 1.580908847286551788e-01, + 1.068758541049165955e-01, + -5.659607022330204879e-02, + 9.474341305433385541e-02, + -2.025508820369920082e-02, + 1.827501136383326286e-01, + -1.559512090836832898e-01, + 3.324229255148461459e-01, + 6.295495345956941824e-01, + 2.438920609066078660e-01, + 2.958568755730026645e-01, + 7.277856895597104581e-01, + -3.066669744223792793e-01, + 1.151140737695940108e-01, + -4.860400475978431944e-01, + 3.800286645551473885e-01, + 1.030846678700064523e+00, + 4.321346638586625599e-01, + 3.725917742884784500e-01, + 1.116358772562584312e+00, + -4.547357119141249293e-01, + 3.189530464986248259e-01, + -8.288056457618550033e-01, + -8.708450843312832979e-02, + -3.261851498223737877e-01, + -1.434582664155235332e-01, + -9.532600255607970308e-02, + -3.380393053647171020e-01, + 1.343028283928490718e-01, + -1.272684286809669663e-01, + 2.688332125008372486e-01, + -1.182987729627547613e-01, + 2.218483536179890747e-01, + 1.335847595358480577e-01, + -5.597980621885078473e-02, + 1.485338228432657171e-01, + -4.002313195306946014e-02, + 2.275728678460948173e-01, + -2.180286303147872251e-01, + 5.203035227557934600e-02, + -4.039081787251208033e-01, + -2.070299876563956587e-01, + -8.991377771592213783e-03, + -3.450698857581925294e-01, + 1.205275942669613021e-01, + -2.843701930822792323e-01, + 3.666462601777820685e-01, + -2.705826322727860056e-01, + -3.132742354154710029e-01, + -9.859239844920053564e-02, + -2.186706418403572272e-01, + -4.102509343627187000e-01, + 1.832099469879538423e-01, + 2.634881197618781090e-02, + 2.219945829310046026e-01, + 1.807596106975759565e-01, + 8.994429150699673192e-01, + 4.103717482408931194e-01, + 2.226815410105222326e-01, + 9.051338172534704185e-01, + -3.527781448781722728e-01, + 3.982850361239002601e-01, + -7.512015699620381293e-01, + 1.300102742718015625e+00, + 6.054999403594980567e-01, + 2.097253231448496791e-01, + 9.827165263729973343e-01, + 1.052183733453920800e+00, + -5.195049333939609770e-01, + -6.381641027068527539e-01, + -7.364899813768920056e-02, + 6.054999403594980567e-01, + 9.039638177949659292e-01, + 3.486898489277805435e-01, + 5.059022952874115964e-01, + 1.170428721365738722e+00, + -4.925183973772490065e-01, + 5.701278884854733137e-02, + -7.030345936214288383e-01, + 2.097253231448496791e-01, + 3.486898489277805435e-01, + 1.391402783708104718e-01, + 1.776896403010444214e-01, + 4.499490466203024952e-01, + -1.859372353636008224e-01, + 3.956661747329379836e-02, + -2.828032361752622625e-01, + 9.827165263729973343e-01, + 5.059022952874115964e-01, + 1.776896403010444214e-01, + 7.465734159825266891e-01, + 8.476440027475331540e-01, + -4.120299431266394308e-01, + -4.548760980436803991e-01, + -1.074321953463584917e-01, + 1.052183733453920800e+00, + 1.170428721365738722e+00, + 4.499490466203024952e-01, + 8.476440027475331540e-01, + 1.603809031278819575e+00, + -6.959819635037184371e-01, + -1.296058402624230832e-01, + -7.926580902398893125e-01, + -5.195049333939609770e-01, + -4.925183973772490065e-01, + -1.859372353636008224e-01, + -4.120299431266394308e-01, + -6.959819635037184371e-01, + 3.087929131865239096e-01, + 1.124120074707089523e-01, + 2.991135908127298126e-01, + -6.381641027068527539e-01, + 5.701278884854733137e-02, + 3.956661747329379836e-02, + -4.548760980436803991e-01, + -1.296058402624230832e-01, + 1.124120074707089523e-01, + 5.150411483832132431e-01, + -3.445907142048604821e-01, + -7.364899813768920056e-02, + -7.030345936214288383e-01, + -2.828032361752622625e-01, + -1.074321953463584917e-01, + -7.926580902398893125e-01, + 2.991135908127298126e-01, + -3.445907142048604821e-01, + 7.234936414299397711e-01, + -3.732452792355842597e-01, + -6.351284941938279971e-01, + -2.493644712191097901e-01, + -3.176635771364644811e-01, + -8.109315931741168937e-01, + 3.357484199566660443e-01, + -7.917224627668460746e-02, + 5.179274965636704309e-01, + 6.103879963205482984e-02, + 3.926374684337851639e-01, + 1.610613109134209353e-01, + 7.405475779758788346e-02, + 4.538001478377386433e-01, + -1.721953254916427645e-01, + 1.769878680633666101e-01, + -3.961762226546653443e-01, + 1.191157188899728991e+00, + 5.605474901628545448e-01, + 1.954757521943455700e-01, + 9.007411238079842120e-01, + 9.717425968880002429e-01, + -4.785550805028060961e-01, + -5.815036985906476552e-01, + -7.396165185977580936e-02, + 1.132707043019949067e+00, + 9.319961358271104945e-01, + 3.464538551202276939e-01, + 8.875056954741523674e-01, + 1.359864912960755401e+00, + -6.156841792369970312e-01, + -3.257019742083836489e-01, + -4.991680044292492457e-01, + -1.825740192995925215e-01, + -1.057937522442242186e-02, + 4.460297560427498031e-03, + -1.325177201685841522e-01, + -6.089046756479399997e-02, + 4.198807823721553006e-02, + 1.315877118315471328e-01, + -7.070128940557268704e-02, + 9.376629691212025053e-01, + 6.766047344331500568e-01, + 2.487179767062538960e-01, + 7.273093171529937395e-01, + 1.022193095621619419e+00, + -4.714878433053744100e-01, + -3.236961738617304407e-01, + -3.112248563810770530e-01, + 8.715441272817934237e-01, + 9.957608706283338496e-01, + 3.843068460465810521e-01, + 7.040812625593925178e-01, + 1.358621587856315260e+00, + -5.873343112872276839e-01, + -9.250561013998101489e-02, + -6.850860344768691101e-01, + -1.839873054978174161e-01, + -2.697675766943509745e-01, + -1.049195435465825332e-01, + -1.532710511861232472e-01, + -3.516162193145014880e-01, + 1.479201485584544506e-01, + -1.441931622091501171e-02, + 2.085957883977366534e-01, + -4.079828637765540722e-01, + -2.847368118114416286e-02, + -2.222134405267033215e-03, + -2.957398893337482293e-01, + -1.557361823194183814e-01, + 9.835505242304752593e-02, + 2.924429004194846748e-01, + -1.501487313940521318e-01, + 3.982233761717781650e-01, + -1.588944981924644839e-01, + -7.408969697337658422e-02, + 2.742647039835456879e-01, + -5.353523938839521018e-02, + -2.055105489653940437e-02, + -3.916635706554859042e-01, + 3.475448753197058482e-01, + -7.360058674331342310e-01, + -4.568919993931332968e-01, + -1.655906071391073375e-01, + -5.651141307901779154e-01, + -7.216348754100389007e-01, + 3.402776887471806178e-01, + 2.963766262538370944e-01, + 1.646029265129388963e-01, + 3.634385605888428850e-01, + 7.177088151935814286e-01, + 2.862230103263159187e-01, + 3.167632415194596152e-01, + 9.029296450871021618e-01, + -3.677650232872346492e-01, + 1.332516815067029126e-01, + -6.119232793393495351e-01, + 1.470617714177075541e+00, + 7.908160637269205928e-01, + 2.967611009382591924e-01, + 1.126498485385288628e+00, + 1.276801732029178904e+00, + -6.188661256643098740e-01, + -6.555616565594825085e-01, + -1.622820106502845250e-01, + 7.908160637269205928e-01, + 1.019847269601365625e+00, + 4.140808136319966692e-01, + 6.501651569008026765e-01, + 1.353982396050527948e+00, + -5.750737982370550672e-01, + -1.571684107553555573e-02, + -7.288264621550111233e-01, + 2.967611009382591924e-01, + 4.140808136319966692e-01, + 1.688255123604280317e-01, + 2.463489561234036285e-01, + 5.431181093161604467e-01, + -2.285272873930015991e-01, + 1.190173024325155877e-02, + -3.072061588815623856e-01, + 1.126498485385288628e+00, + 6.501651569008026765e-01, + 2.463489561234036285e-01, + 8.663277218548093295e-01, + 1.027034106828141224e+00, + -4.918915403527173713e-01, + -4.768929805439556802e-01, + -1.715930912354062809e-01, + 1.276801732029178904e+00, + 1.353982396050527948e+00, + 5.431181093161604467e-01, + 1.027034106828141224e+00, + 1.863909708290216072e+00, + -8.111554978845125774e-01, + -1.920275777457542132e-01, + -8.656707854826148907e-01, + -6.188661256643098740e-01, + -5.750737982370550672e-01, + -2.285272873930015991e-01, + -4.918915403527173713e-01, + -8.111554978845125774e-01, + 3.597341191090284718e-01, + 1.389081190062548543e-01, + 3.311289995354575466e-01, + -6.555616565594825085e-01, + -1.571684107553555573e-02, + 1.190173024325155877e-02, + -4.768929805439556802e-01, + -1.920275777457542132e-01, + 1.389081190062548543e-01, + 4.831489109598711695e-01, + -2.903972000340614423e-01, + -1.622820106502845250e-01, + -7.288264621550111233e-01, + -3.072061588815623856e-01, + -1.715930912354062809e-01, + -8.656707854826148907e-01, + 3.311289995354575466e-01, + -2.903972000340614423e-01, + 7.136178988114953992e-01, + -4.801359830987989574e-01, + -7.072071231901065902e-01, + -2.889551561531346069e-01, + -4.007779804898234932e-01, + -9.248816752531052732e-01, + 3.862384773157101492e-01, + -3.973170595224694707e-02, + 5.404661605326807061e-01, + 1.135129783310338814e-01, + 4.223520933079102369e-01, + 1.774298690373650178e-01, + 1.133778094061134029e-01, + 5.082598909026351253e-01, + -1.962563595187890308e-01, + 1.534628752553490738e-01, + -4.054692993770361853e-01, + 1.343835858041975362e+00, + 7.321279204278162700e-01, + 2.752106405208382123e-01, + 1.029969906078544550e+00, + 1.178294959727800517e+00, + -5.696579444659997105e-01, + -5.938038404477978816e-01, + -1.592103103982526191e-01, + 1.351996704100781876e+00, + 1.115390913195200673e+00, + 4.390524493312819576e-01, + 1.064629486341718945e+00, + 1.609751622898790480e+00, + -7.272045864926830472e-01, + -3.826981143794290774e-01, + -5.682521760868141092e-01, + -1.738843102777075378e-01, + -8.601801098605380003e-03, + 1.274148486290022043e-03, + -1.267748650962710610e-01, + -5.627986373366875017e-02, + 3.876373515740432624e-02, + 1.256954388033035552e-01, + -7.196565882631425493e-02, + 1.107641644352691745e+00, + 8.264319183306835237e-01, + 3.223156539290735756e-01, + 8.657358177050420434e-01, + 1.220391344258679345e+00, + -5.600641558420980104e-01, + -3.629685669468760700e-01, + -3.710150121819992819e-01, + 1.060129731097080752e+00, + 1.150689072388683787e+00, + 4.622609209910348849e-01, + 8.546229501566764419e-01, + 1.578087848817897987e+00, + -6.845234077831010566e-01, + -1.445537517343584433e-01, + -7.479025509047609876e-01, + -2.354087887671292845e-01, + -3.052486440759409336e-01, + -1.239441248907928866e-01, + -1.935320869814817213e-01, + -4.059118589235164998e-01, + 1.721661764329455446e-01, + 3.880989047236374587e-03, + 2.194887259681063407e-01, + -4.276691974179438471e-01, + -8.120234886415216013e-02, + -2.257704216701504271e-02, + -3.162911706641049481e-01, + -2.057942213675230947e-01, + 1.197982705434366413e-01, + 2.751309243674305294e-01, + -1.122352637998940972e-01, + 4.130879341289221407e-01, + -1.041506808305002196e-01, + -5.622961801539583249e-02, + 2.923612513907042354e-01, + -9.829481080015531683e-03, + -4.020157399527740649e-02, + -3.686448663143884197e-01, + 3.081438262949895979e-01, + -8.451881571484531896e-01, + -5.653460146886296611e-01, + -2.179605774453889921e-01, + -6.555215826327438489e-01, + -8.595224164327244232e-01, + 4.012363786307169833e-01, + 3.141583866609505282e-01, + 2.138404247587386564e-01, + 4.901499623543186246e-01, + 8.064861280426489643e-01, + 3.311403671368340351e-01, + 4.153443650103284535e-01, + 1.039836036114208717e+00, + -4.289759458199104425e-01, + 8.832851884678478982e-02, + -6.435274972517068814e-01, + ] + fit_ll: + [ + 1.212985093937938963e+00, + -1.381549085068657590e+00, + 1.085133808142466849e+00, + 1.321675583158594902e+00, + 9.996861251087960643e-01, + 1.241922641459953125e+00, + -1.382821053720374227e+00, + 1.086264457711279219e+00, + 1.331897794223656728e+00, + 1.002200480399525473e+00, + 1.217151059829681081e+00, + -1.381911178384377825e+00, + 1.086244130315317191e+00, + 1.327760578088094334e+00, + 1.001454813411551736e+00, + 1.178374034136443704e+00, + -1.379130236372588580e+00, + 1.081631474710690766e+00, + 1.304438695799553294e+00, + 9.939766256981953374e-01, + 1.237617107943818962e+00, + -1.382695233447721828e+00, + 1.086338611560096368e+00, + 1.331179464750698704e+00, + 1.002117303895980971e+00, + 1.242018332379969614e+00, + -1.382821853835294101e+00, + 1.086264346741943942e+00, + 1.331911247257247188e+00, + 1.002203014460946617e+00, + ] diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 44e3de30cb..535dc1c2d3 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -169,33 +169,51 @@ def test_1frame_atm(self) -> None: def test_descriptor(self) -> None: _, extension = self.param - if extension in (".pte", ".pt2"): - self.skipTest("eval_descriptor not supported for pt_expt models") for ii, result in enumerate(self.case.results): if result.descriptor is None: continue - descpt = self.dp.eval_descriptor(result.coord, result.box, result.atype) + descpt = self.dp.eval_descriptor( + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, + ) expected_descpt = result.descriptor np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) # See #4533 - descpt = self.dp.eval_descriptor(result.coord, result.box, result.atype) + descpt = self.dp.eval_descriptor( + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, + ) expected_descpt = result.descriptor np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) def test_fitting_last_layer(self) -> None: _, extension = self.param - if extension in (".pb", ".pte", ".pt2"): - self.skipTest("fitting_last_layer not supported for this backend") + if extension == ".pb": + self.skipTest("fitting_last_layer not supported for TensorFlow models") for ii, result in enumerate(self.case.results): if result.fit_ll is None: continue fit_ll = self.dp.eval_fitting_last_layer( - result.coord, result.box, result.atype + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, ) expected_fit_ll = result.fit_ll np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel()) fit_ll = self.dp.eval_fitting_last_layer( - result.coord, result.box, result.atype + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, ) expected_fit_ll = result.fit_ll np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel()) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 112ada5dc7..d652938527 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -6,6 +6,7 @@ """ import importlib +import os import tempfile import unittest import zipfile @@ -855,5 +856,416 @@ def test_pt2_vs_pte_consistency(self) -> None: ) +class TestDeepEvalEnerDefaultFparam(unittest.TestCase): + """Test .pte inference with default fparam (non-spin model).""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + numb_fparam=1, + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.tmpfile.close() + deserialize_to_file(cls.tmpfile.name, cls.model_data) + + cls.dp = DeepPot(cls.tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_eval_with_fparam(self) -> None: + """Model with fparam works when fparam is explicitly provided.""" + coords, cells, atom_types = self._make_inputs() + fparam = np.array([0.5], dtype=np.float64) + ee, ff, vv = self.dp.eval(coords, cells, atom_types, fparam=fparam)[:3] + self.assertEqual(ee.shape, (1, 1)) + self.assertEqual(ff.shape, (1, 6, 3)) + + def test_eval_without_fparam_has_default(self) -> None: + """When fparam is omitted but default exists, should use default.""" + coords, cells, atom_types = self._make_inputs() + # The model has dim_fparam=1 but default_fparam is None by default. + # Without a default, omitting fparam should raise ValueError. + with self.assertRaises(ValueError): + self.dp.eval(coords, cells, atom_types) + + +class TestEvalTypeEbd(unittest.TestCase): + """Test eval_typeebd for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # se_e2_a model (no type embedding) + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + + # DPA1 model (has type embedding) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def test_typeebd_dpa1(self) -> None: + """DPA1 model has type embedding, should return valid array.""" + typeebd = self.dp_dpa1.deep_eval.eval_typeebd() + self.assertEqual(typeebd.ndim, 2) + # DPA1 TypeEmbedNet outputs (ntypes+1) rows (padding type included) + self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) + self.assertTrue(typeebd.shape[1] > 0) + + def test_typeebd_sea_raises(self) -> None: + """se_e2_a model has no type embedding, should raise KeyError.""" + with self.assertRaises(KeyError): + self.dp_sea.deep_eval.eval_typeebd() + + +class TestEvalDescriptor(unittest.TestCase): + """Test eval_descriptor for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # se_e2_a model + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + cls.dim_descrpt_sea = ds_sea.get_dim_out() + + # DPA1 model + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + cls.dim_descrpt_dpa1 = ds_dpa1.get_dim_out() + + # se_e2_a model with fparam (dim_fparam=1, no default; swap _dpmodel + # because se_e2_a + fparam hits GuardOnDataDependentSymNode in export) + ds_fp = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_fp = EnergyFittingNet( + cls.nt, + ds_fp.get_dim_out(), + mixed_types=ds_fp.mixed_types(), + numb_fparam=1, + seed=GLOBAL_SEED, + ) + model_fp = EnergyModel(ds_fp, ft_fp, type_map=cls.type_map) + pte_fp = os.path.join(cls._tmpdir, "fp.pte") + deserialize_to_file(pte_fp, {"model": model_sea.serialize()}) + cls.dp_fp = DeepPot(pte_fp) + cls.dp_fp.deep_eval._dpmodel = model_fp + cls.dim_descrpt_fp = ds_fp.get_dim_out() + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_descriptor_shape_sea(self) -> None: + """se_e2_a descriptor has correct shape.""" + coords, cells, atom_types = self._make_inputs() + descpt = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_sea)) + + def test_descriptor_shape_dpa1(self) -> None: + """DPA1 descriptor has correct shape.""" + coords, cells, atom_types = self._make_inputs() + descpt = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_dpa1)) + + def test_descriptor_deterministic_sea(self) -> None: + """Calling eval_descriptor twice gives same result for se_e2_a.""" + coords, cells, atom_types = self._make_inputs() + d1 = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + d2 = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + np.testing.assert_array_equal(d1, d2) + + def test_descriptor_deterministic_dpa1(self) -> None: + """Calling eval_descriptor twice gives same result for DPA1.""" + coords, cells, atom_types = self._make_inputs() + d1 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + d2 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + np.testing.assert_array_equal(d1, d2) + + def test_descriptor_with_fparam(self) -> None: + """eval_descriptor works with fparam.""" + coords, cells, atom_types = self._make_inputs() + fparam = np.array([0.5], dtype=np.float64) + descpt = self.dp_fp.deep_eval.eval_descriptor( + coords, cells, atom_types, fparam=fparam + ) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_fp)) + + def test_descriptor_without_fparam_raises(self) -> None: + """eval_descriptor raises when fparam is required but not provided.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(ValueError): + self.dp_fp.deep_eval.eval_descriptor(coords, cells, atom_types) + + +class TestEvalFittingLastLayer(unittest.TestCase): + """Test eval_fitting_last_layer for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.neuron = [120, 120, 120] # default fitting net neurons + + # se_e2_a model (mixed_types=False) + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + + # DPA1 model (mixed_types=True) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + + # se_e2_a model with fparam and aparam (swap _dpmodel because + # se_e2_a + fparam hits GuardOnDataDependentSymNode in export) + ds_fp = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_fp = EnergyFittingNet( + cls.nt, + ds_fp.get_dim_out(), + mixed_types=ds_fp.mixed_types(), + numb_fparam=1, + numb_aparam=2, + seed=GLOBAL_SEED, + ) + model_fp = EnergyModel(ds_fp, ft_fp, type_map=cls.type_map) + pte_fp = os.path.join(cls._tmpdir, "fp.pte") + deserialize_to_file(pte_fp, {"model": model_sea.serialize()}) + cls.dp_fp = DeepPot(pte_fp) + cls.dp_fp.deep_eval._dpmodel = model_fp + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_fitting_ll_shape_sea(self) -> None: + """se_e2_a fitting last layer has correct shape.""" + coords, cells, atom_types = self._make_inputs() + fit_ll = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_shape_dpa1(self) -> None: + """DPA1 fitting last layer has correct shape.""" + coords, cells, atom_types = self._make_inputs() + fit_ll = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_deterministic_sea(self) -> None: + """Verify calling twice gives the same result for se_e2_a.""" + coords, cells, atom_types = self._make_inputs() + fit_ll1 = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + fit_ll2 = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + np.testing.assert_array_equal(fit_ll1, fit_ll2) + + def test_fitting_ll_deterministic_dpa1(self) -> None: + """Verify calling twice gives the same result for DPA1.""" + coords, cells, atom_types = self._make_inputs() + fit_ll1 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + fit_ll2 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + np.testing.assert_array_equal(fit_ll1, fit_ll2) + + def test_fitting_ll_with_fparam_aparam(self) -> None: + """eval_fitting_last_layer works with fparam and aparam.""" + coords, cells, atom_types = self._make_inputs() + fparam = np.array([0.5], dtype=np.float64) + aparam = np.zeros((1, 6, 2), dtype=np.float64) + fit_ll = self.dp_fp.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types, fparam=fparam, aparam=aparam + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_without_fparam_raises(self) -> None: + """eval_fitting_last_layer raises when fparam is required but not provided.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(ValueError): + self.dp_fp.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + if __name__ == "__main__": unittest.main() From da7d4f975e4e40f54ac59a330b5a09c39a9e9ab3 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 8 Apr 2026 23:24:54 +0800 Subject: [PATCH 4/9] test: cover untested code paths in eval diagnostic methods --- .../dpmodel/test_fitting_middle_output.py | 58 ++++ source/tests/pt_expt/infer/test_deep_eval.py | 267 ++++++++++++++++++ 2 files changed, 325 insertions(+) diff --git a/source/tests/common/dpmodel/test_fitting_middle_output.py b/source/tests/common/dpmodel/test_fitting_middle_output.py index 0e53f13e90..74d67177a1 100644 --- a/source/tests/common/dpmodel/test_fitting_middle_output.py +++ b/source/tests/common/dpmodel/test_fitting_middle_output.py @@ -99,3 +99,61 @@ def test_middle_output_deterministic(self) -> None: ret1 = ft.call(descriptor, atype) ret2 = ft.call(descriptor, atype) np.testing.assert_array_equal(ret1["middle_output"], ret2["middle_output"]) + + def test_middle_output_dipole_fitting(self) -> None: + """middle_output flows through DipoleFitting.call().""" + from deepmd.dpmodel.fitting.dipole_fitting import ( + DipoleFitting, + ) + + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + rot_mat = dd[1] + atype = self.atype_ext[:, : self.nloc] + ft = DipoleFitting( + ntypes=self.nt, + dim_descrpt=ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + seed=GLOBAL_SEED, + ) + # Disabled: no middle_output + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertNotIn("middle_output", ret) + # Enabled: middle_output present + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + # Primary output still correct shape + self.assertEqual(ret["dipole"].shape, (nf, nloc, 3)) + + def test_middle_output_polar_fitting(self) -> None: + """middle_output flows through PolarFitting.call().""" + from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting, + ) + + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + rot_mat = dd[1] + atype = self.atype_ext[:, : self.nloc] + ft = PolarFitting( + ntypes=self.nt, + dim_descrpt=ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + seed=GLOBAL_SEED, + ) + # Disabled: no middle_output + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertNotIn("middle_output", ret) + # Enabled: middle_output present + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + # Primary output still correct shape + self.assertEqual(ret["polarizability"].shape, (nf, nloc, 3, 3)) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index d652938527..06636b5cdc 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -1267,5 +1267,272 @@ def test_fitting_ll_without_fparam_raises(self) -> None: self.dp_fp.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) +class TestGetDpAtomicModel(unittest.TestCase): + """Test get_dp_atomic_model() API on various model types.""" + + def test_energy_model(self) -> None: + """Standard energy model returns a DPAtomicModel.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + dp_am = model.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + self.assertTrue(hasattr(dp_am, "descriptor")) + self.assertTrue(hasattr(dp_am, "fitting_net")) + + def test_zbl_model_returns_none(self) -> None: + """DPZBLModel's LinearEnergyAtomicModel should return None.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, + ) + from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, + ) + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + + # LinearAtomicModel requires mixed_type descriptors + ds = DescrptDPA1DP(4.0, 0.5, [14], ntypes=2) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=True, seed=GLOBAL_SEED) + dp_am = DPAtomicModelDP(ds, ft, type_map=["foo", "bar"]) + pair_tab = PairTabAtomicModel( + tab_file=None, rcut=4.0, sel=14, type_map=["foo", "bar"] + ) + zbl_am = DPZBLLinearEnergyAtomicModel( + dp_am, pair_tab, sw_rmin=1.0, sw_rmax=2.0, type_map=["foo", "bar"] + ) + # zbl_am is a LinearEnergyAtomicModel, not DPAtomicModel + self.assertFalse(isinstance(zbl_am, DPAtomicModelDP)) + + def test_spin_model_delegates(self) -> None: + """SpinModel.get_dp_atomic_model() delegates to backbone.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + from deepmd.utils.spin import ( + Spin, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + spin = Spin( + use_spin=[False, False], + virtual_scale=[0.0, 0.0], + ) + spin_model = SpinModel(backbone_model=model, spin=spin) + dp_am = spin_model.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + + def test_frozen_model_delegates(self) -> None: + """FrozenModel.get_dp_atomic_model() delegates to inner model.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + model = model.to(torch.float64) + model.eval() + + tmpdir = tempfile.mkdtemp() + try: + pte_path = os.path.join(tmpdir, "test.pte") + deserialize_to_file(pte_path, {"model": model.serialize()}) + dp = DeepPot(pte_path) + # The _dpmodel deserialized from .pte is a regular energy model + dp_am = dp.deep_eval._dpmodel.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + finally: + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestEvalDiagSpinModel(unittest.TestCase): + """Test eval diagnostic methods on spin models.""" + + @classmethod + def setUpClass(cls) -> None: + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.utils.spin import ( + Spin, + ) + + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # DPA1 model with spin wrapper + ds = DescrptDPA1( + cls.rcut, cls.rcut_smth, cls.sel, ntypes=cls.nt, seed=GLOBAL_SEED + ) + ft = EnergyFittingNet( + cls.nt, ds.get_dim_out(), mixed_types=ds.mixed_types(), seed=GLOBAL_SEED + ) + backbone = EnergyModel(ds, ft, type_map=cls.type_map) + backbone = backbone.to(torch.float64) + backbone.eval() + + spin = Spin(use_spin=[True, False], virtual_scale=[0.5, 0.0]) + cls.spin_model = SpinModel(backbone_model=backbone, spin=spin) + + # Export backbone as .pte, then swap _dpmodel to SpinModel + cls._tmpdir = tempfile.mkdtemp() + pte_path = os.path.join(cls._tmpdir, "spin.pte") + deserialize_to_file(pte_path, {"model": backbone.serialize()}) + cls.dp = DeepPot(pte_path) + cls.dp.deep_eval._dpmodel = cls.spin_model + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_eval_typeebd_spin(self) -> None: + """eval_typeebd traverses backbone_model for spin models.""" + typeebd = self.dp.deep_eval.eval_typeebd() + self.assertEqual(typeebd.ndim, 2) + # DPA1 TypeEmbedNet outputs ntypes or ntypes+1 + self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) + self.assertTrue(typeebd.shape[1] > 0) + + def test_eval_descriptor_spin_raises(self) -> None: + """eval_descriptor raises NotImplementedError for spin models.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(NotImplementedError): + self.dp.deep_eval.eval_descriptor(coords, cells, atom_types) + + def test_eval_fitting_last_layer_spin_raises(self) -> None: + """eval_fitting_last_layer raises NotImplementedError for spin models.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(NotImplementedError): + self.dp.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + +class TestEvalDescriptorASE(unittest.TestCase): + """Test eval_descriptor with ASE neighbor list.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + model = EnergyModel(ds, ft, type_map=cls.type_map) + model = model.to(torch.float64) + model.eval() + cls.dim_descrpt = ds.get_dim_out() + + cls._tmpdir = tempfile.mkdtemp() + pte_path = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_path, {"model": model.serialize()}) + cls.dp_native = DeepPot(pte_path) + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_eval_descriptor_ase_vs_native(self) -> None: + """eval_descriptor with ASE nlist matches native nlist.""" + import ase.neighborlist + + pte_path = os.path.join(self._tmpdir, "sea.pte") + dp_ase = DeepPot( + pte_path, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, bothways=True + ), + ) + + rng = np.random.default_rng(GLOBAL_SEED + 99) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + d_native = self.dp_native.deep_eval.eval_descriptor(coords, cells, atom_types) + d_ase = dp_ase.deep_eval.eval_descriptor(coords, cells, atom_types) + + self.assertEqual(d_native.shape, d_ase.shape) + np.testing.assert_allclose(d_native, d_ase, rtol=1e-10, atol=1e-10) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_eval_fitting_last_layer_ase_vs_native(self) -> None: + """eval_fitting_last_layer with ASE nlist matches native nlist.""" + import ase.neighborlist + + pte_path = os.path.join(self._tmpdir, "sea.pte") + dp_ase = DeepPot( + pte_path, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, bothways=True + ), + ) + + rng = np.random.default_rng(GLOBAL_SEED + 99) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + f_native = self.dp_native.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + f_ase = dp_ase.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + self.assertEqual(f_native.shape, f_ase.shape) + np.testing.assert_allclose(f_native, f_ase, rtol=1e-10, atol=1e-10) + + if __name__ == "__main__": unittest.main() From 67ffce3aec466c38a62e3ec72cd0ac5d10c57699 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 9 Apr 2026 00:37:34 +0800 Subject: [PATCH 5/9] feat(dpmodel): register middle_output in FittingOutputDef Make fitting_check_output evaluate output_def() dynamically so that middle_output is validated when enabled via set_return_middle_output. Each fitting subclass includes _middle_output_def() in its output_def(). --- deepmd/dpmodel/fitting/dipole_fitting.py | 1 + deepmd/dpmodel/fitting/dos_fitting.py | 1 + deepmd/dpmodel/fitting/general_fitting.py | 17 ++++++++++++ deepmd/dpmodel/fitting/invar_fitting.py | 1 + .../dpmodel/fitting/polarizability_fitting.py | 1 + deepmd/dpmodel/fitting/property_fitting.py | 1 + deepmd/dpmodel/output_def.py | 5 ++-- .../dpmodel/test_fitting_middle_output.py | 27 +++++++++++++++++++ 8 files changed, 52 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 818b2eeffd..5e39622bb9 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -198,6 +198,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=self.r_differentiable, c_differentiable=self.c_differentiable, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py index c8f145ce15..8d088656d5 100644 --- a/deepmd/dpmodel/fitting/dos_fitting.py +++ b/deepmd/dpmodel/fitting/dos_fitting.py @@ -89,6 +89,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=False, c_differentiable=False, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index f505bc556c..ec2e676682 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -24,6 +24,9 @@ get_xp_precision, to_numpy_array, ) +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) from deepmd.dpmodel.utils import ( AtomExcludeMask, FittingNet, @@ -435,6 +438,20 @@ def set_return_middle_output(self, enable: bool) -> None: """ self.eval_return_middle_output = enable + def _middle_output_def(self) -> list[OutputVariableDef]: + """Return extra OutputVariableDefs for middle_output when enabled.""" + if self.eval_return_middle_output: + return [ + OutputVariableDef( + "middle_output", + [self.neuron[-1]], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + return [] + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index f771f927cd..f7daf6d98f 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -210,6 +210,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=True, c_differentiable=True, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 42d36000a2..e55a29c36a 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -250,6 +250,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=False, c_differentiable=False, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/property_fitting.py b/deepmd/dpmodel/fitting/property_fitting.py index b0841bafc2..78082ad4b9 100644 --- a/deepmd/dpmodel/fitting/property_fitting.py +++ b/deepmd/dpmodel/fitting/property_fitting.py @@ -129,6 +129,7 @@ def output_def(self) -> FittingOutputDef: c_differentiable=False, intensive=self.intensive, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index 5028bc43a3..5705c822ef 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -102,8 +102,9 @@ def __call__( **kwargs: Any, ) -> Any: ret = cls.__call__(self, *args, **kwargs) - for kk in self.md.keys(): - dd = self.md[kk] + md = self.output_def() + for kk in md.keys(): + dd = md[kk] check_var(ret[kk], dd) return ret diff --git a/source/tests/common/dpmodel/test_fitting_middle_output.py b/source/tests/common/dpmodel/test_fitting_middle_output.py index 74d67177a1..739b0ecb12 100644 --- a/source/tests/common/dpmodel/test_fitting_middle_output.py +++ b/source/tests/common/dpmodel/test_fitting_middle_output.py @@ -92,6 +92,33 @@ def test_middle_output_with_fparam_aparam(self) -> None: expected_shape = (nf, nloc, ft.neuron[-1]) self.assertEqual(ret["middle_output"].shape, expected_shape) + def test_middle_output_registered_in_output_def(self) -> None: + """middle_output should appear in output_def when enabled.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + # Not registered by default + self.assertNotIn("middle_output", ft.output_def().keys()) + # Registered after enabling + ft.set_return_middle_output(True) + self.assertIn("middle_output", ft.output_def().keys()) + odef = ft.output_def()["middle_output"] + self.assertEqual(odef.shape, [ft.neuron[-1]]) + self.assertFalse(odef.reducible) + self.assertFalse(odef.r_differentiable) + self.assertFalse(odef.c_differentiable) + # Removed after disabling + ft.set_return_middle_output(False) + self.assertNotIn("middle_output", ft.output_def().keys()) + + def test_middle_output_checked_by_decorator(self) -> None: + """fitting_check_output decorator validates middle_output shape.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + # __call__ goes through fitting_check_output which validates output_def + ret = ft(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + def test_middle_output_deterministic(self) -> None: """Middle output should be deterministic.""" ft, descriptor, atype = self._build_fitting(mixed_types=True) From 13fefa612fcd316d002dead1c7c50d1fff4b56b3 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 11 Apr 2026 14:47:12 +0800 Subject: [PATCH 6/9] fix: use assertGreater instead of assertTrue for CodeQL --- source/tests/pt_expt/infer/test_deep_eval.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index cb4c3adb6a..69c2aadd66 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -1364,7 +1364,7 @@ def test_typeebd_dpa1(self) -> None: self.assertEqual(typeebd.ndim, 2) # DPA1 TypeEmbedNet outputs (ntypes+1) rows (padding type included) self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) - self.assertTrue(typeebd.shape[1] > 0) + self.assertGreater(typeebd.shape[1], 0) def test_typeebd_sea_raises(self) -> None: """se_e2_a model has no type embedding, should raise KeyError.""" @@ -1807,7 +1807,7 @@ def test_eval_typeebd_spin(self) -> None: self.assertEqual(typeebd.ndim, 2) # DPA1 TypeEmbedNet outputs ntypes or ntypes+1 self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) - self.assertTrue(typeebd.shape[1] > 0) + self.assertGreater(typeebd.shape[1], 0) def test_eval_descriptor_spin_raises(self) -> None: """eval_descriptor raises NotImplementedError for spin models.""" From f012b212f2fb9543f8f9ecfb68300bec6991aa58 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 11 Apr 2026 15:10:11 +0800 Subject: [PATCH 7/9] fix: address review comments from coderabbitai and chatgpt-codex - Raise ValueError (not warning) when set_return_middle_output(True) with neuron=[] (no hidden layers) - Fix _is_spin_model() to use precomputed self._is_spin flag instead of isinstance check against dpmodel SpinModel (which misses pt_expt SpinModel) - Fix RUF059 unused unpacked variables in eval_descriptor and eval_fitting_last_layer - Fix RUF059 unused unpack in test_middle_output_registered_in_output_def - Rename test_frozen_model_delegates to test_deserialized_model_delegates to match what it actually tests --- deepmd/dpmodel/fitting/general_fitting.py | 19 ++++++++---- deepmd/pt_expt/infer/deep_eval.py | 18 +++++------ .../dpmodel/test_fitting_middle_output.py | 30 ++++++++++++++++++- source/tests/pt_expt/infer/test_deep_eval.py | 10 +++---- 4 files changed, 54 insertions(+), 23 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index ec2e676682..3a3012440c 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -435,12 +435,21 @@ def set_return_middle_output(self, enable: bool) -> None: a ``"middle_output"`` key in the returned dict, containing the hidden-layer activations before the final linear layer. Shape: ``[nframes, nloc, neuron[-1]]``. + + Raises + ------ + ValueError + If ``enable`` is True but ``neuron`` is empty (no hidden layers). """ + if enable and len(self.neuron) == 0: + raise ValueError( + "middle_output requires at least one hidden layer (neuron=[])" + ) self.eval_return_middle_output = enable def _middle_output_def(self) -> list[OutputVariableDef]: """Return extra OutputVariableDefs for middle_output when enabled.""" - if self.eval_return_middle_output: + if self.eval_return_middle_output and len(self.neuron) > 0: return [ OutputVariableDef( "middle_output", @@ -718,7 +727,7 @@ def _call_common( dtype=get_xp_precision(xp, self.precision), device=array_api_compat.device(descriptor), ) - if self.eval_return_middle_output: + if self.eval_return_middle_output and len(self.neuron) > 0: middle_outs = xp.zeros( [nf, nloc, self.neuron[-1]], dtype=get_xp_precision(xp, self.precision), @@ -739,7 +748,7 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - if self.eval_return_middle_output: + if self.eval_return_middle_output and len(self.neuron) > 0: mid = self.nets[(type_i,)].call_until_last(xx) mid_mask = xp.tile( xp.reshape((atype == type_i), (nf, nloc, 1)), @@ -751,7 +760,7 @@ def _call_common( outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) - if self.eval_return_middle_output: + if self.eval_return_middle_output and len(self.neuron) > 0: middle_outs = self.nets[()].call_until_last(xx) outs += xp.reshape( xp.take( @@ -767,6 +776,6 @@ def _call_common( # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) results[self.var_name] = outs - if self.eval_return_middle_output: + if self.eval_return_middle_output and len(self.neuron) > 0: results["middle_output"] = middle_outs return results diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index d74d3c599a..9bc9579f0c 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -996,11 +996,7 @@ def get_model(self) -> torch.nn.Module: def _is_spin_model(self) -> bool: """Check if the underlying dpmodel is a SpinModel.""" - from deepmd.dpmodel.model.spin_model import ( - SpinModel, - ) - - return isinstance(self._dpmodel, SpinModel) + return getattr(self, "_is_spin", False) def eval_typeebd(self) -> np.ndarray: """Evaluate type embedding. @@ -1074,15 +1070,15 @@ def eval_descriptor( nlist_t, mapping_t, fparam_t, - aparam_t, - nframes, - natoms, + _aparam_t, + _nframes, + _natoms, ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) with torch.no_grad(): fparam_for_des = ( fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None ) - descriptor, rot_mat, g2, h2, sw = dp_am.descriptor( + descriptor, *_ = dp_am.descriptor( ext_coord_t, ext_atype_t, nlist_t, @@ -1137,14 +1133,14 @@ def eval_fitting_last_layer( mapping_t, fparam_t, aparam_t, - nframes, + _nframes, natoms, ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) with torch.no_grad(): fparam_for_des = ( fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None ) - descriptor, rot_mat, g2, h2, sw = dp_am.descriptor( + descriptor, rot_mat, g2, h2, _sw = dp_am.descriptor( ext_coord_t, ext_atype_t, nlist_t, diff --git a/source/tests/common/dpmodel/test_fitting_middle_output.py b/source/tests/common/dpmodel/test_fitting_middle_output.py index 739b0ecb12..9703a324cf 100644 --- a/source/tests/common/dpmodel/test_fitting_middle_output.py +++ b/source/tests/common/dpmodel/test_fitting_middle_output.py @@ -94,7 +94,7 @@ def test_middle_output_with_fparam_aparam(self) -> None: def test_middle_output_registered_in_output_def(self) -> None: """middle_output should appear in output_def when enabled.""" - ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft, _, _ = self._build_fitting(mixed_types=True) # Not registered by default self.assertNotIn("middle_output", ft.output_def().keys()) # Registered after enabling @@ -119,6 +119,34 @@ def test_middle_output_checked_by_decorator(self) -> None: nf, nloc, _ = descriptor.shape self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + def test_middle_output_empty_neuron_mixed_types(self) -> None: + """neuron=[] should raise ValueError when enabling middle_output.""" + ft = InvarFitting( + "energy", + self.nt, + 4, # dim_descrpt (arbitrary) + 1, + neuron=[], + mixed_types=True, + seed=GLOBAL_SEED, + ) + with self.assertRaises(ValueError): + ft.set_return_middle_output(True) + + def test_middle_output_empty_neuron_per_type(self) -> None: + """neuron=[] with mixed_types=False should raise ValueError.""" + ft = InvarFitting( + "energy", + self.nt, + 4, + 1, + neuron=[], + mixed_types=False, + seed=GLOBAL_SEED, + ) + with self.assertRaises(ValueError): + ft.set_return_middle_output(True) + def test_middle_output_deterministic(self) -> None: """Middle output should be deterministic.""" ft, descriptor, atype = self._build_fitting(mixed_types=True) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 69c2aadd66..e29b4fdd2a 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -1665,7 +1665,7 @@ def test_energy_model(self) -> None: self.assertTrue(hasattr(dp_am, "fitting_net")) def test_zbl_model_returns_none(self) -> None: - """DPZBLModel's LinearEnergyAtomicModel should return None.""" + """DPZBLModel wraps a LinearEnergyAtomicModel, so get_dp_atomic_model returns None.""" from deepmd.dpmodel.atomic_model.dp_atomic_model import ( DPAtomicModel as DPAtomicModelDP, ) @@ -1677,7 +1677,6 @@ def test_zbl_model_returns_none(self) -> None: ) from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP - # LinearAtomicModel requires mixed_type descriptors ds = DescrptDPA1DP(4.0, 0.5, [14], ntypes=2) ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=True, seed=GLOBAL_SEED) dp_am = DPAtomicModelDP(ds, ft, type_map=["foo", "bar"]) @@ -1687,7 +1686,7 @@ def test_zbl_model_returns_none(self) -> None: zbl_am = DPZBLLinearEnergyAtomicModel( dp_am, pair_tab, sw_rmin=1.0, sw_rmax=2.0, type_map=["foo", "bar"] ) - # zbl_am is a LinearEnergyAtomicModel, not DPAtomicModel + # LinearEnergyAtomicModel is not a DPAtomicModel self.assertFalse(isinstance(zbl_am, DPAtomicModelDP)) def test_spin_model_delegates(self) -> None: @@ -1714,8 +1713,8 @@ def test_spin_model_delegates(self) -> None: self.assertIsNotNone(dp_am) self.assertIsInstance(dp_am, DPAtomicModelDP) - def test_frozen_model_delegates(self) -> None: - """FrozenModel.get_dp_atomic_model() delegates to inner model.""" + def test_deserialized_model_delegates(self) -> None: + """Model deserialized from .pte exposes get_dp_atomic_model().""" from deepmd.dpmodel.atomic_model.dp_atomic_model import ( DPAtomicModel as DPAtomicModelDP, ) @@ -1731,7 +1730,6 @@ def test_frozen_model_delegates(self) -> None: pte_path = os.path.join(tmpdir, "test.pte") deserialize_to_file(pte_path, {"model": model.serialize()}) dp = DeepPot(pte_path) - # The _dpmodel deserialized from .pte is a regular energy model dp_am = dp.deep_eval._dpmodel.get_dp_atomic_model() self.assertIsNotNone(dp_am) self.assertIsInstance(dp_am, DPAtomicModelDP) From 1df931a691544b9a9279c5f24fed9faf1e581a24 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 11 Apr 2026 15:35:29 +0800 Subject: [PATCH 8/9] fix: restore isinstance check in _is_spin_model pt_expt.SpinModel inherits from dpmodel.SpinModel, so isinstance works correctly. The coderabbitai suggestion to use self._is_spin was wrong. --- deepmd/pt_expt/infer/deep_eval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 9bc9579f0c..351b072ce7 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -996,7 +996,11 @@ def get_model(self) -> torch.nn.Module: def _is_spin_model(self) -> bool: """Check if the underlying dpmodel is a SpinModel.""" - return getattr(self, "_is_spin", False) + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + + return isinstance(self._dpmodel, SpinModel) def eval_typeebd(self) -> np.ndarray: """Evaluate type embedding. From 4f4983d6a6d5fd1f9a64068345f0948153a8df7d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 11 Apr 2026 16:20:27 +0800 Subject: [PATCH 9/9] fix: normalize inputs in eval_descriptor and eval_fitting_last_layer Accept Python lists (not just np.ndarray) by coercing coords, cells, atom_types with np.array() at the top, matching the eval() API. --- deepmd/pt_expt/infer/deep_eval.py | 8 ++++++++ source/tests/pt_expt/infer/test_deep_eval.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 351b072ce7..19476a8537 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -1058,6 +1058,10 @@ def eval_descriptor( np.ndarray Descriptor output, shape ``(nframes, nloc, dim_descrpt)``. """ + coords = np.array(coords) + atom_types = np.array(atom_types, dtype=np.int32) + if cells is not None: + cells = np.array(cells) if self._is_spin_model(): raise NotImplementedError( "eval_descriptor is not supported for spin models." @@ -1120,6 +1124,10 @@ def eval_fitting_last_layer( np.ndarray Middle-layer output, shape ``(nframes, nloc, neuron[-1])``. """ + coords = np.array(coords) + atom_types = np.array(atom_types, dtype=np.int32) + if cells is not None: + cells = np.array(cells) if self._is_spin_model(): raise NotImplementedError( "eval_fitting_last_layer is not supported for spin models." diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index e29b4fdd2a..6797fa2c03 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -1502,6 +1502,26 @@ def test_descriptor_without_fparam_raises(self) -> None: with self.assertRaises(ValueError): self.dp_fp.deep_eval.eval_descriptor(coords, cells, atom_types) + def test_descriptor_accepts_list_inputs(self) -> None: + """eval_descriptor accepts Python lists (not just np.ndarray).""" + coords, cells, atom_types = self._make_inputs() + descpt_arr = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + descpt_list = self.dp_sea.deep_eval.eval_descriptor( + coords.tolist(), cells.tolist(), atom_types.tolist() + ) + np.testing.assert_allclose(descpt_arr, descpt_list) + + def test_fitting_last_layer_accepts_list_inputs(self) -> None: + """eval_fitting_last_layer accepts Python lists (not just np.ndarray).""" + coords, cells, atom_types = self._make_inputs() + mid_arr = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + mid_list = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords.tolist(), cells.tolist(), atom_types.tolist() + ) + np.testing.assert_allclose(mid_arr, mid_list) + class TestEvalFittingLastLayer(unittest.TestCase): """Test eval_fitting_last_layer for pt_expt models."""