diff --git a/deepmd/dpmodel/fitting/dipole_fitting.py b/deepmd/dpmodel/fitting/dipole_fitting.py index 0fc3a1fefa..5e39622bb9 100644 --- a/deepmd/dpmodel/fitting/dipole_fitting.py +++ b/deepmd/dpmodel/fitting/dipole_fitting.py @@ -198,6 +198,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=self.r_differentiable, c_differentiable=self.c_differentiable, ), + *self._middle_output_def(), ] ) @@ -239,9 +240,8 @@ def call( nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # (nframes, nloc, m1) - out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = results[self.var_name] # (nframes * nloc, 1, m1) out = xp.reshape(out, (-1, 1, self.embedding_width)) # (nframes * nloc, m1, 3) @@ -249,5 +249,5 @@ def call( # (nframes, nloc, 3) # out = np.einsum("bim,bmj->bij", out, gr).squeeze(-2).reshape(nframes, nloc, 3) out = out @ gr - out = xp.reshape(out, (nframes, nloc, 3)) - return {self.var_name: out} + results[self.var_name] = xp.reshape(out, (nframes, nloc, 3)) + return results diff --git a/deepmd/dpmodel/fitting/dos_fitting.py b/deepmd/dpmodel/fitting/dos_fitting.py index c8f145ce15..8d088656d5 100644 --- a/deepmd/dpmodel/fitting/dos_fitting.py +++ b/deepmd/dpmodel/fitting/dos_fitting.py @@ -89,6 +89,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=False, c_differentiable=False, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 4761703a19..3a3012440c 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -24,6 +24,9 @@ get_xp_precision, to_numpy_array, ) +from deepmd.dpmodel.output_def import ( + OutputVariableDef, +) from deepmd.dpmodel.utils import ( AtomExcludeMask, FittingNet, @@ -168,6 +171,7 @@ def __init__( if self.spin is not None: raise NotImplementedError("spin is not supported") self.remove_vaccum_contribution = remove_vaccum_contribution + self.eval_return_middle_output = False net_dim_out = self._net_out_dim() # init constants @@ -424,6 +428,39 @@ def get_default_fparam(self) -> list[float] | None: """Get the default frame parameters.""" return self.default_fparam + def set_return_middle_output(self, enable: bool) -> None: + """Enable or disable returning the middle (pre-last-layer) output. + + When enabled, the fitting network's ``call`` method will include + a ``"middle_output"`` key in the returned dict, containing the + hidden-layer activations before the final linear layer. Shape: + ``[nframes, nloc, neuron[-1]]``. + + Raises + ------ + ValueError + If ``enable`` is True but ``neuron`` is empty (no hidden layers). + """ + if enable and len(self.neuron) == 0: + raise ValueError( + "middle_output requires at least one hidden layer (neuron=[])" + ) + self.eval_return_middle_output = enable + + def _middle_output_def(self) -> list[OutputVariableDef]: + """Return extra OutputVariableDefs for middle_output when enabled.""" + if self.eval_return_middle_output and len(self.neuron) > 0: + return [ + OutputVariableDef( + "middle_output", + [self.neuron[-1]], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + return [] + def get_sel_type(self) -> list[int]: """Get the selected atom types of this model. @@ -690,6 +727,12 @@ def _call_common( dtype=get_xp_precision(xp, self.precision), device=array_api_compat.device(descriptor), ) + if self.eval_return_middle_output and len(self.neuron) > 0: + middle_outs = xp.zeros( + [nf, nloc, self.neuron[-1]], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), + ) for type_i in range(self.ntypes): mask = xp.tile( xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) @@ -705,10 +748,20 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + if self.eval_return_middle_output and len(self.neuron) > 0: + mid = self.nets[(type_i,)].call_until_last(xx) + mid_mask = xp.tile( + xp.reshape((atype == type_i), (nf, nloc, 1)), + (1, 1, self.neuron[-1]), + ) + mid = xp.where(mid_mask, mid, xp.zeros_like(mid)) + middle_outs = middle_outs + mid else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) + if self.eval_return_middle_output and len(self.neuron) > 0: + middle_outs = self.nets[()].call_until_last(xx) outs += xp.reshape( xp.take( xp.astype(self.bias_atom_e[...], outs.dtype), @@ -723,4 +776,6 @@ def _call_common( # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) results[self.var_name] = outs + if self.eval_return_middle_output and len(self.neuron) > 0: + results["middle_output"] = middle_outs return results diff --git a/deepmd/dpmodel/fitting/invar_fitting.py b/deepmd/dpmodel/fitting/invar_fitting.py index f771f927cd..f7daf6d98f 100644 --- a/deepmd/dpmodel/fitting/invar_fitting.py +++ b/deepmd/dpmodel/fitting/invar_fitting.py @@ -210,6 +210,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=True, c_differentiable=True, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/fitting/polarizability_fitting.py b/deepmd/dpmodel/fitting/polarizability_fitting.py index 361f033a68..e55a29c36a 100644 --- a/deepmd/dpmodel/fitting/polarizability_fitting.py +++ b/deepmd/dpmodel/fitting/polarizability_fitting.py @@ -250,6 +250,7 @@ def output_def(self) -> FittingOutputDef: r_differentiable=False, c_differentiable=False, ), + *self._middle_output_def(), ] ) @@ -326,9 +327,8 @@ def call( "Must provide the rotation matrix for polarizability fitting." ) # (nframes, nloc, _net_out_dim) - out = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + results = self._call_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = results.pop(self.var_name) # out = out * self.scale[atype, ...] scale_atype = xp.reshape( xp.take(xp.astype(self.scale, out.dtype), xp.reshape(atype, (-1,)), axis=0), @@ -371,4 +371,5 @@ def call( # (nframes, nloc, 3, 3) bias = bias[..., None] * eye out = out + bias - return {"polarizability": out} + results["polarizability"] = out + return results diff --git a/deepmd/dpmodel/fitting/property_fitting.py b/deepmd/dpmodel/fitting/property_fitting.py index b0841bafc2..78082ad4b9 100644 --- a/deepmd/dpmodel/fitting/property_fitting.py +++ b/deepmd/dpmodel/fitting/property_fitting.py @@ -129,6 +129,7 @@ def output_def(self) -> FittingOutputDef: c_differentiable=False, intensive=self.intensive, ), + *self._middle_output_def(), ] ) diff --git a/deepmd/dpmodel/model/frozen.py b/deepmd/dpmodel/model/frozen.py index a8bf74323b..ed9da3bafe 100644 --- a/deepmd/dpmodel/model/frozen.py +++ b/deepmd/dpmodel/model/frozen.py @@ -1,9 +1,15 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( + TYPE_CHECKING, Any, NoReturn, ) +if TYPE_CHECKING: + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + from deepmd.dpmodel.common import ( NativeOP, ) @@ -131,6 +137,10 @@ def get_observed_type_list(self) -> list[str]: """Get observed types (elements) of the model during data statistics.""" return self.model.get_observed_type_list() + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel by delegating to the inner model.""" + return self.model.get_dp_atomic_model() + def serialize(self) -> dict: """Serialize the model. diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index f70b668a87..597f8ea006 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,9 +3,15 @@ Callable, ) from typing import ( + TYPE_CHECKING, Any, ) +if TYPE_CHECKING: + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + import array_api_compat import numpy as np @@ -704,6 +710,21 @@ def is_aparam_nall(self) -> bool: """ return self.atomic_model.is_aparam_nall() + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel with descriptor and fitting_net. + + Returns the ``atomic_model`` if it is a ``DPAtomicModel`` instance + (i.e. has both ``descriptor`` and ``fitting_net``). Returns ``None`` + for composite atomic models such as ``LinearEnergyAtomicModel``. + """ + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + + if isinstance(self.atomic_model, DPAtomicModel): + return self.atomic_model + return None + def get_rcut(self) -> float: """Get the cut-off radius.""" return self.atomic_model.get_rcut() diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index e60bf809f9..be6566e303 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -547,6 +547,10 @@ def __getattr__(self, name: str) -> Any: raise AttributeError(name) return getattr(self.backbone_model, name) + def get_dp_atomic_model(self) -> "DPAtomicModel | None": + """Get the underlying DPAtomicModel by delegating to the backbone model.""" + return self.backbone_model.get_dp_atomic_model() + def serialize(self) -> dict: return { "type": "spin_ener", diff --git a/deepmd/dpmodel/output_def.py b/deepmd/dpmodel/output_def.py index 5028bc43a3..5705c822ef 100644 --- a/deepmd/dpmodel/output_def.py +++ b/deepmd/dpmodel/output_def.py @@ -102,8 +102,9 @@ def __call__( **kwargs: Any, ) -> Any: ret = cls.__call__(self, *args, **kwargs) - for kk in self.md.keys(): - dd = self.md[kk] + md = self.output_def() + for kk in md.keys(): + dd = md[kk] check_var(ret[kk], dd) return ret diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index dd1831a4ba..19476a8537 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -624,15 +624,22 @@ def _build_nlist_ase_single( return extended_coord, extended_atype, nlist, mapping - def _eval_model( + def _prepare_inputs( self, coords: np.ndarray, cells: np.ndarray | None, atom_types: np.ndarray, fparam: np.ndarray | None, aparam: np.ndarray | None, - request_defs: list[OutputVariableDef], - ) -> tuple[np.ndarray, ...]: + ) -> tuple: + """Prepare tensor inputs for model evaluation. + + Returns + ------- + tuple + (ext_coord_t, ext_atype_t, nlist_t, mapping_t, + fparam_t, aparam_t, nframes, natoms) + """ nframes = coords.shape[0] if len(atom_types.shape) == 1: natoms = len(atom_types) @@ -715,6 +722,37 @@ def _eval_model( else: aparam_t = None + return ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) + + def _eval_model( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + nframes, + natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + # Call the model (forward_common_lower interface, internal keys) if self._is_pt2: # AOTInductor's __call__ unflattens output using stored out_spec, @@ -955,3 +993,185 @@ def get_model(self) -> torch.nn.Module: The exported model module. """ return self.exported_module + + def _is_spin_model(self) -> bool: + """Check if the underlying dpmodel is a SpinModel.""" + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + + return isinstance(self._dpmodel, SpinModel) + + def eval_typeebd(self) -> np.ndarray: + """Evaluate type embedding. + + Returns + ------- + np.ndarray + Type embedding array of shape ``(ntypes, tebd_dim)``. + + Raises + ------ + KeyError + If the model has no type embedding networks. + """ + from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP + + model = self._dpmodel + if self._is_spin_model(): + model = model.backbone_model + out = [] + for mm in model.modules(): + if isinstance(mm, TypeEmbedNetDP): + out.append(mm()) + if not out: + raise KeyError("The model has no type embedding networks.") + typeebd = torch.cat(out, dim=1) + return typeebd.detach().cpu().numpy() + + def eval_descriptor( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> np.ndarray: + """Evaluate descriptor. + + Parameters + ---------- + coords + Coordinates, shape ``(nframes, natoms, 3)``. + cells + Cell vectors, shape ``(nframes, 3, 3)`` or ``None``. + atom_types + Atom types, shape ``(natoms,)`` or ``(nframes, natoms)``. + fparam + Frame parameters, optional. + aparam + Atom parameters, optional. + + Returns + ------- + np.ndarray + Descriptor output, shape ``(nframes, nloc, dim_descrpt)``. + """ + coords = np.array(coords) + atom_types = np.array(atom_types, dtype=np.int32) + if cells is not None: + cells = np.array(cells) + if self._is_spin_model(): + raise NotImplementedError( + "eval_descriptor is not supported for spin models." + ) + dp_am = self._dpmodel.get_dp_atomic_model() + if dp_am is None: + raise NotImplementedError( + "eval_descriptor is not supported for this model type " + f"({type(self._dpmodel).__name__})." + ) + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + _aparam_t, + _nframes, + _natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + with torch.no_grad(): + fparam_for_des = ( + fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None + ) + descriptor, *_ = dp_am.descriptor( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping=mapping_t, + fparam=fparam_for_des, + ) + return descriptor.detach().cpu().numpy() + + def eval_fitting_last_layer( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> np.ndarray: + """Evaluate the last hidden layer of the fitting network. + + Parameters + ---------- + coords + Coordinates, shape ``(nframes, natoms, 3)``. + cells + Cell vectors, shape ``(nframes, 3, 3)`` or ``None``. + atom_types + Atom types, shape ``(natoms,)`` or ``(nframes, natoms)``. + fparam + Frame parameters, optional. + aparam + Atom parameters, optional. + + Returns + ------- + np.ndarray + Middle-layer output, shape ``(nframes, nloc, neuron[-1])``. + """ + coords = np.array(coords) + atom_types = np.array(atom_types, dtype=np.int32) + if cells is not None: + cells = np.array(cells) + if self._is_spin_model(): + raise NotImplementedError( + "eval_fitting_last_layer is not supported for spin models." + ) + dp_am = self._dpmodel.get_dp_atomic_model() + if dp_am is None: + raise NotImplementedError( + "eval_fitting_last_layer is not supported for this model type " + f"({type(self._dpmodel).__name__})." + ) + ( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping_t, + fparam_t, + aparam_t, + _nframes, + natoms, + ) = self._prepare_inputs(coords, cells, atom_types, fparam, aparam) + with torch.no_grad(): + fparam_for_des = ( + fparam_t if getattr(dp_am, "add_chg_spin_ebd", False) else None + ) + descriptor, rot_mat, g2, h2, _sw = dp_am.descriptor( + ext_coord_t, + ext_atype_t, + nlist_t, + mapping=mapping_t, + fparam=fparam_for_des, + ) + atype = ext_atype_t[:, :natoms] + fitting_net = dp_am.fitting_net + fitting_net.set_return_middle_output(True) + try: + ret = fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam_t, + aparam=aparam_t, + ) + finally: + fitting_net.set_return_middle_output(False) + return ret["middle_output"].detach().cpu().numpy() diff --git a/source/tests/common/dpmodel/test_fitting_middle_output.py b/source/tests/common/dpmodel/test_fitting_middle_output.py new file mode 100644 index 0000000000..9703a324cf --- /dev/null +++ b/source/tests/common/dpmodel/test_fitting_middle_output.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .case_single_frame_with_nlist import ( + TestCaseSingleFrameWithNlist, +) + + +class TestFittingMiddleOutput(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def _build_fitting(self, mixed_types, nfp=0, nap=0): + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + atype = self.atype_ext[:, : self.nloc] + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, # dim_out + mixed_types=mixed_types, + numb_fparam=nfp, + numb_aparam=nap, + seed=GLOBAL_SEED, + ) + return ft, descriptor, atype + + def test_middle_output_disabled_by_default(self) -> None: + """Middle output should not be present when not enabled.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ret = ft.call(descriptor, atype) + self.assertIn("energy", ret) + self.assertNotIn("middle_output", ret) + + def test_middle_output_enabled_mixed_types(self) -> None: + """When enabled, middle_output key is present with correct shape for mixed_types=True.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_enabled_per_type(self) -> None: + """When enabled, middle_output key is present with correct shape for mixed_types=False.""" + ft, descriptor, atype = self._build_fitting(mixed_types=False) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_toggle(self) -> None: + """Verify toggling set_return_middle_output on/off works.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + + # Enable + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype) + self.assertIn("middle_output", ret) + + # Disable + ft.set_return_middle_output(False) + ret = ft.call(descriptor, atype) + self.assertNotIn("middle_output", ret) + + def test_middle_output_with_fparam_aparam(self) -> None: + """Middle output works with fparam and aparam.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True, nfp=2, nap=3) + nf, nloc, _ = descriptor.shape + fparam = np.zeros([nf, 2], dtype=np.float64) + aparam = np.zeros([nf, nloc, 3], dtype=np.float64) + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, fparam=fparam, aparam=aparam) + self.assertIn("middle_output", ret) + expected_shape = (nf, nloc, ft.neuron[-1]) + self.assertEqual(ret["middle_output"].shape, expected_shape) + + def test_middle_output_registered_in_output_def(self) -> None: + """middle_output should appear in output_def when enabled.""" + ft, _, _ = self._build_fitting(mixed_types=True) + # Not registered by default + self.assertNotIn("middle_output", ft.output_def().keys()) + # Registered after enabling + ft.set_return_middle_output(True) + self.assertIn("middle_output", ft.output_def().keys()) + odef = ft.output_def()["middle_output"] + self.assertEqual(odef.shape, [ft.neuron[-1]]) + self.assertFalse(odef.reducible) + self.assertFalse(odef.r_differentiable) + self.assertFalse(odef.c_differentiable) + # Removed after disabling + ft.set_return_middle_output(False) + self.assertNotIn("middle_output", ft.output_def().keys()) + + def test_middle_output_checked_by_decorator(self) -> None: + """fitting_check_output decorator validates middle_output shape.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + # __call__ goes through fitting_check_output which validates output_def + ret = ft(descriptor, atype) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + + def test_middle_output_empty_neuron_mixed_types(self) -> None: + """neuron=[] should raise ValueError when enabling middle_output.""" + ft = InvarFitting( + "energy", + self.nt, + 4, # dim_descrpt (arbitrary) + 1, + neuron=[], + mixed_types=True, + seed=GLOBAL_SEED, + ) + with self.assertRaises(ValueError): + ft.set_return_middle_output(True) + + def test_middle_output_empty_neuron_per_type(self) -> None: + """neuron=[] with mixed_types=False should raise ValueError.""" + ft = InvarFitting( + "energy", + self.nt, + 4, + 1, + neuron=[], + mixed_types=False, + seed=GLOBAL_SEED, + ) + with self.assertRaises(ValueError): + ft.set_return_middle_output(True) + + def test_middle_output_deterministic(self) -> None: + """Middle output should be deterministic.""" + ft, descriptor, atype = self._build_fitting(mixed_types=True) + ft.set_return_middle_output(True) + ret1 = ft.call(descriptor, atype) + ret2 = ft.call(descriptor, atype) + np.testing.assert_array_equal(ret1["middle_output"], ret2["middle_output"]) + + def test_middle_output_dipole_fitting(self) -> None: + """middle_output flows through DipoleFitting.call().""" + from deepmd.dpmodel.fitting.dipole_fitting import ( + DipoleFitting, + ) + + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + rot_mat = dd[1] + atype = self.atype_ext[:, : self.nloc] + ft = DipoleFitting( + ntypes=self.nt, + dim_descrpt=ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + seed=GLOBAL_SEED, + ) + # Disabled: no middle_output + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertNotIn("middle_output", ret) + # Enabled: middle_output present + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + # Primary output still correct shape + self.assertEqual(ret["dipole"].shape, (nf, nloc, 3)) + + def test_middle_output_polar_fitting(self) -> None: + """middle_output flows through PolarFitting.call().""" + from deepmd.dpmodel.fitting.polarizability_fitting import ( + PolarFitting, + ) + + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + descriptor = dd[0] + rot_mat = dd[1] + atype = self.atype_ext[:, : self.nloc] + ft = PolarFitting( + ntypes=self.nt, + dim_descrpt=ds.get_dim_out(), + embedding_width=ds.get_dim_emb(), + seed=GLOBAL_SEED, + ) + # Disabled: no middle_output + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertNotIn("middle_output", ret) + # Enabled: middle_output present + ft.set_return_middle_output(True) + ret = ft.call(descriptor, atype, gr=rot_mat) + self.assertIn("middle_output", ret) + nf, nloc, _ = descriptor.shape + self.assertEqual(ret["middle_output"].shape, (nf, nloc, ft.neuron[-1])) + # Primary output still correct shape + self.assertEqual(ret["polarizability"].shape, (nf, nloc, 3, 3)) diff --git a/source/tests/infer/fparam_aparam-testcase.yaml b/source/tests/infer/fparam_aparam-testcase.yaml index 1f300e31dd..5bb9ab471b 100644 --- a/source/tests/infer/fparam_aparam-testcase.yaml +++ b/source/tests/infer/fparam_aparam-testcase.yaml @@ -159,3 +159,999 @@ results: -2.486859433265622629e-02, 2.875323131744185121e-02, ] + descriptor: + [ + 6.083714331120656515e-01, + 4.502505554576420876e-01, + 1.594398494919171405e-01, + 4.727519861103077758e-01, + 6.727545551660457646e-01, + -3.105848623409598885e-01, + -2.034772733369290820e-01, + -2.166933245535891117e-01, + 4.502505554576420876e-01, + 1.139649142647525260e+00, + 4.969499257159836203e-01, + 4.301892148257391857e-01, + 1.270001504410809723e+00, + -5.140452446715122470e-01, + 3.263504683974155496e-01, + -9.091582204946582202e-01, + 1.594398494919171405e-01, + 4.969499257159836203e-01, + 2.260654681080056239e-01, + 1.624402806893359419e-01, + 5.386383726647075987e-01, + -2.137567595770779316e-01, + 1.714384661213686489e-01, + -4.035449616399264805e-01, + 4.727519861103077758e-01, + 4.301892148257391857e-01, + 1.624402806893359419e-01, + 3.754699534763070723e-01, + 5.995953110586557111e-01, + -2.694918713461056381e-01, + -1.105341613089115826e-01, + -2.422835589130273026e-01, + 6.727545551660457646e-01, + 1.270001504410809723e+00, + 5.386383726647075987e-01, + 5.995953110586557111e-01, + 1.483309072850735211e+00, + -6.156618440021969230e-01, + 2.315666690444803111e-01, + -9.570105981549420493e-01, + -3.105848623409598885e-01, + -5.140452446715122470e-01, + -2.137567595770779316e-01, + -2.694918713461056381e-01, + -6.156618440021969230e-01, + 2.589390132149291257e-01, + -6.406882715059965261e-02, + 3.754908530130299793e-01, + -2.034772733369290820e-01, + 3.263504683974155496e-01, + 1.714384661213686489e-01, + -1.105341613089115826e-01, + 2.315666690444803111e-01, + -6.406882715059965261e-02, + 3.502027394021257067e-01, + -3.698375855067269069e-01, + -2.166933245535891117e-01, + -9.091582204946582202e-01, + -4.035449616399264805e-01, + -2.422835589130273026e-01, + -9.570105981549420493e-01, + 3.754908530130299793e-01, + -3.698375855067269069e-01, + 7.767038256079379366e-01, + -2.981667553358613998e-01, + -7.391122602545086018e-01, + -3.205454587664270938e-01, + -2.831620287994112140e-01, + -8.262423426451125374e-01, + 3.351661217421733063e-01, + -2.067609428898770363e-01, + 5.886302766667602659e-01, + 1.240976071619029480e-01, + 4.973224483159546239e-01, + 2.234717192757742499e-01, + 1.368572698960379141e-01, + 5.254712190591790399e-01, + -2.061836993159548159e-01, + 1.983387449937808189e-01, + -4.204030104006635793e-01, + 5.581270077279179009e-01, + 4.082193934796568002e-01, + 1.441148337259589574e-01, + 4.332383483055207152e-01, + 6.125584989446182238e-01, + -2.832083358508835080e-01, + -1.895286928385826741e-01, + -1.942100761080858551e-01, + 6.378556323516176851e-01, + 1.001982931038889690e+00, + 4.168948972325110525e-01, + 5.485284807739624346e-01, + 1.212661151645888724e+00, + -5.122423732761739457e-01, + 1.001406309724684474e-01, + -7.186443216927924649e-01, + -7.152415656398955490e-02, + 1.054901580743528022e-01, + 6.232449399777988119e-02, + -3.894272340885426148e-02, + 7.205094434523259816e-02, + -1.805014026654702675e-02, + 1.183066919255474220e-01, + -1.161656443653441007e-01, + 5.022149866663377926e-01, + 6.976700285197479090e-01, + 2.862873989982428480e-01, + 4.229179663891721730e-01, + 8.673481304519186086e-01, + -3.709822075657827067e-01, + 2.498047634770347511e-02, + -4.803400360910294875e-01, + 5.637582607285308578e-01, + 1.079696549143701167e+00, + 4.590639095981799978e-01, + 5.040455670483138251e-01, + 1.257759447946239506e+00, + -5.212815054748195509e-01, + 2.032328956811573217e-01, + -8.159605088714143584e-01, + -1.350693558539911299e-01, + -3.220522605062802191e-01, + -1.397884793071107190e-01, + -1.270778868078931978e-01, + -3.620099554771795125e-01, + 1.472182089511498149e-01, + -8.617472043402696347e-02, + 2.543117761117610343e-01, + -1.464136841504689035e-01, + 1.217686222663357798e-01, + 6.951144078262200265e-02, + -9.088638084027333974e-02, + 5.842502264744667967e-02, + -6.392865301134783088e-03, + 1.850518592878056534e-01, + -1.617257895507598076e-01, + 9.094458269279154239e-02, + -3.087692407825532448e-01, + -1.487259125583649533e-01, + 3.376356883973294365e-02, + -2.598628715751217788e-01, + 8.690297349942077698e-02, + -2.524037647949817775e-01, + 3.202540825701963301e-01, + -3.751637757936663320e-01, + -3.982900935111258200e-01, + -1.544156549323568994e-01, + -3.034653687171679493e-01, + -5.304295281929436445e-01, + 2.341575952346893752e-01, + 5.419559866725016245e-02, + 2.461408379536359770e-01, + 3.132239475444893451e-01, + 8.765556472050846093e-01, + 3.875192888559441107e-01, + 3.079155531500631371e-01, + 9.635319863477369573e-01, + -3.866632840469169663e-01, + 2.768015304833532642e-01, + -7.082444142190988945e-01, + 1.430144157827138329e+00, + 8.371406802726794050e-01, + 3.217236266667155009e-01, + 1.101372195257868780e+00, + 1.309710128726800038e+00, + -6.265729912157642634e-01, + -5.976692872812503499e-01, + -2.230698942337786261e-01, + 8.371406802726794050e-01, + 1.023033322697949554e+00, + 4.121430074516977404e-01, + 6.841961394490133630e-01, + 1.371741299045477591e+00, + -5.865551302740638073e-01, + -4.901372130550269574e-02, + -7.126229372870805934e-01, + 3.217236266667155009e-01, + 4.121430074516977404e-01, + 1.667923268135688974e-01, + 2.645504158349288315e-01, + 5.474427363099826360e-01, + -2.328003308815323191e-01, + -7.917678800792056840e-03, + -2.933583025540117961e-01, + 1.101372195257868780e+00, + 6.841961394490133630e-01, + 2.645504158349288315e-01, + 8.512102245969188630e-01, + 1.052791432993813947e+00, + -4.985889765492535486e-01, + -4.378685187396406109e-01, + -2.142944536866224214e-01, + 1.309710128726800038e+00, + 1.371741299045477591e+00, + 5.474427363099826360e-01, + 1.052791432993813947e+00, + 1.890936188225563086e+00, + -8.248805096784416202e-01, + -2.065440769036151480e-01, + -8.692732098846506217e-01, + -6.265729912157642634e-01, + -5.865551302740638073e-01, + -2.328003308815323191e-01, + -4.985889765492535486e-01, + -8.248805096784416202e-01, + 3.656750155213921438e-01, + 1.380477511643647681e-01, + 3.391836108469640787e-01, + -5.976692872812503499e-01, + -4.901372130550269574e-02, + -7.917678800792056840e-03, + -4.378685187396406109e-01, + -2.065440769036151480e-01, + 1.380477511643647681e-01, + 4.196794682686730837e-01, + -2.346170286615520129e-01, + -2.230698942337786261e-01, + -7.126229372870805934e-01, + -2.933583025540117961e-01, + -2.142944536866224214e-01, + -8.692732098846506217e-01, + 3.391836108469640787e-01, + -2.346170286615520129e-01, + 6.743229971375143128e-01, + -5.123920094325192798e-01, + -7.074779841860322493e-01, + -2.855206215044043816e-01, + -4.243401890981060021e-01, + -9.350347600855188901e-01, + 3.936059152956140617e-01, + -1.539375914738916849e-02, + 5.278501097255398067e-01, + 1.466478863200821325e-01, + 4.137911975084271332e-01, + 1.699008976760616185e-01, + 1.368122661427808029e-01, + 5.095046380033478872e-01, + -2.004656295627817442e-01, + 1.233528182536938744e-01, + -3.836611231041641701e-01, + 1.307379055422948300e+00, + 7.738567876659411260e-01, + 2.975463412155298082e-01, + 1.007378233374699450e+00, + 1.207607062406870435e+00, + -5.765232705711489380e-01, + -5.416165483742040321e-01, + -2.138262638647728253e-01, + 1.359643853640744249e+00, + 1.143610489764218752e+00, + 4.519272049237009758e-01, + 1.072869654075972567e+00, + 1.639797498390930031e+00, + -7.390349747630275967e-01, + -3.719220608199254596e-01, + -5.916890292398517825e-01, + -1.552513570223078287e-01, + -1.897934780868526863e-02, + -4.210073071744310953e-03, + -1.139855766718844976e-01, + -6.215846401873216520e-02, + 3.892738523383156707e-02, + 1.057189870558455924e-01, + -5.276144303775173738e-02, + 1.104538870923048366e+00, + 8.526603431257446797e-01, + 3.351761963971622293e-01, + 8.659673339527532709e-01, + 1.245043447077310406e+00, + -5.687481602434459882e-01, + -3.451881390779078518e-01, + -3.969124013792098005e-01, + 1.089349537510279520e+00, + 1.164895568404513027e+00, + 4.652591969475391998e-01, + 8.773598369767227068e-01, + 1.600472758596691580e+00, + -6.961382514444349745e-01, + -1.583528216631740482e-01, + -7.496262147675080145e-01, + -2.480012474417669366e-01, + -3.065647179173028869e-01, + -1.233673686142572196e-01, + -2.028456491009781593e-01, + -4.110517151842552619e-01, + 1.754473861346198538e-01, + 1.266109575260882664e-02, + 2.155799593938188818e-01, + -3.978403937119566747e-01, + -9.991332892649396058e-02, + -3.340409898200767669e-02, + -2.963873874846542633e-01, + -2.143409093118101016e-01, + 1.197953317221483738e-01, + 2.414569614468069225e-01, + -8.227296457316374267e-02, + 3.690821426046725362e-01, + -8.055988299054643587e-02, + -4.082341030522432940e-02, + 2.626029367695669747e-01, + -1.042803680946458424e-03, + -3.859858812343196915e-02, + -3.212762213243168796e-01, + 2.685029558806420469e-01, + -8.323969526289327625e-01, + -5.886919464746885877e-01, + -2.295559263038976228e-01, + -6.484017939892503524e-01, + -8.785152580554453916e-01, + 4.068776062722544995e-01, + 2.907761085338105289e-01, + 2.415838031028531563e-01, + 5.332734738415039200e-01, + 8.033746811924462605e-01, + 3.253994932488234459e-01, + 4.466281431516882505e-01, + 1.049148736467505838e+00, + -4.372511097323733553e-01, + 5.389587131807656306e-02, + -6.224441065274786133e-01, + 9.652167274938354691e-01, + 5.313446375815668032e-01, + 1.595985340042652689e-01, + 7.366646596146324555e-01, + 8.456546075061203149e-01, + -4.160141501041305645e-01, + -4.255582314137071887e-01, + -1.459893329129947348e-01, + 5.313446375815668032e-01, + 7.843080248152077827e-01, + 2.839722764135915178e-01, + 4.378789672765076579e-01, + 1.043223182916673375e+00, + -4.397200153695126068e-01, + 3.886481651629476036e-02, + -6.392727633140400378e-01, + 1.595985340042652689e-01, + 2.839722764135915178e-01, + 1.158293104700064968e-01, + 1.354240804871464643e-01, + 3.740428049816837408e-01, + -1.520077295574986109e-01, + 3.906416117492179929e-02, + -2.414007968687422734e-01, + 7.366646596146324555e-01, + 4.378789672765076579e-01, + 1.354240804871464643e-01, + 5.644438045269433157e-01, + 6.833253054234037505e-01, + -3.312184769506364979e-01, + -3.067479524964061288e-01, + -1.475897350138534181e-01, + 8.456546075061203149e-01, + 1.043223182916673375e+00, + 3.740428049816837408e-01, + 6.833253054234037505e-01, + 1.423660291885181062e+00, + -6.125441018811154104e-01, + -5.221076190477519363e-02, + -7.845287167838036479e-01, + -4.160141501041305645e-01, + -4.397200153695126068e-01, + -1.520077295574986109e-01, + -3.312184769506364979e-01, + -6.125441018811154104e-01, + 2.698786254482626323e-01, + 6.654507192479590383e-02, + 3.033699870391518005e-01, + -4.255582314137071887e-01, + 3.886481651629476036e-02, + 3.906416117492179929e-02, + -3.067479524964061288e-01, + -5.221076190477519363e-02, + 6.654507192479590383e-02, + 3.393861782856803511e-01, + -2.455063645724311627e-01, + -1.459893329129947348e-01, + -6.392727633140400378e-01, + -2.414007968687422734e-01, + -1.475897350138534181e-01, + -7.845287167838036479e-01, + 3.033699870391518005e-01, + -2.455063645724311627e-01, + 6.614884222618355736e-01, + -3.265542808407963515e-01, + -5.693028167157063724e-01, + -2.101678679286862750e-01, + -2.746923718860757591e-01, + -7.457901247019582680e-01, + 3.081124176255724545e-01, + -7.215648382777055392e-02, + 4.929607152885880916e-01, + 7.289144396716068508e-02, + 3.543740733174537971e-01, + 1.413074915636731432e-01, + 7.633121055347411033e-02, + 4.362453637185305655e-01, + -1.663244632123324074e-01, + 1.423151678522617813e-01, + -3.675326287502824751e-01, + 8.796846030153004925e-01, + 4.949259267405346496e-01, + 1.507390539711822852e-01, + 6.721064859018730520e-01, + 7.839949395373813079e-01, + -3.837608242990865337e-01, + -3.819288553812149045e-01, + -1.449413796860413162e-01, + 8.907246324644731983e-01, + 8.102877153047098879e-01, + 2.761347910508374359e-01, + 7.009365398991779239e-01, + 1.156667158090711078e+00, + -5.208994952781673682e-01, + -2.149701864086924763e-01, + -4.977028399023077365e-01, + -1.544386539268297331e-01, + -3.135035718915457625e-02, + 7.480831051489745913e-03, + -1.135803115961424858e-01, + -6.844467004321277970e-02, + 4.444629092177178331e-02, + 9.844907620866516496e-02, + -3.171928598253399845e-02, + 7.230832319002997721e-01, + 5.844662077789191112e-01, + 1.962525977491133278e-01, + 5.642907623475763579e-01, + 8.531813387800899484e-01, + -3.913192904287589591e-01, + -2.151528698629770198e-01, + -3.199599599522977011e-01, + 6.988775778606687306e-01, + 8.886437601116213836e-01, + 3.210311239935199623e-01, + 5.665061413069706342e-01, + 1.208524116908088031e+00, + -5.175924737174160128e-01, + -2.842688210451090741e-02, + -6.780668764946560234e-01, + -1.565472119713383070e-01, + -2.381925212230145972e-01, + -8.722212678513849293e-02, + -1.294711626807042715e-01, + -3.164147548509885222e-01, + 1.326809813472158428e-01, + -1.537634923757306461e-02, + 1.963938295947420254e-01, + -2.733270100864703123e-01, + -3.208757074499291734e-02, + 3.495293544523320217e-04, + -2.008501478596516066e-01, + -1.016318578095670166e-01, + 6.715777107960560488e-02, + 1.862588844222162754e-01, + -9.355767339931092552e-02, + 2.405851005618613869e-01, + -1.543923965166731360e-01, + -6.881011534701647614e-02, + 1.652293323598198749e-01, + -1.251034413458357586e-01, + 1.988420677015985405e-02, + -2.649793288818961257e-01, + 2.931400138614652096e-01, + -5.591003147909873183e-01, + -4.063284579846219158e-01, + -1.315208571745482835e-01, + -4.331117440048848355e-01, + -6.061090378540733292e-01, + 2.833828635442544042e-01, + 1.918609571473225972e-01, + 1.970841104720812420e-01, + 3.097713438305200739e-01, + 6.338697878210927117e-01, + 2.441240075513875507e-01, + 2.671479762635345367e-01, + 8.207590450429802509e-01, + -3.321173802625960736e-01, + 1.208172718023899639e-01, + -5.709923069462401468e-01, + 4.378905518390653340e-01, + 2.803383952198099105e-01, + 4.818083515067635852e-02, + 3.288723036125778543e-01, + 4.568212815808281313e-01, + -2.205844558239555830e-01, + -1.790930632302553671e-01, + -1.600790746380182095e-01, + 2.803383952198099105e-01, + 1.184924484528727451e+00, + 5.280975986537399525e-01, + 3.217531805643514264e-01, + 1.211448561860737794e+00, + -4.775164830407008787e-01, + 4.909065147537653440e-01, + -9.839528741162539838e-01, + 4.818083515067635852e-02, + 5.280975986537399525e-01, + 2.570327553754042649e-01, + 9.166999066254227779e-02, + 5.052143239146044129e-01, + -1.897550899742722341e-01, + 2.801748154632567878e-01, + -4.496013116553528310e-01, + 3.288723036125778543e-01, + 3.217531805643514264e-01, + 9.166999066254227779e-02, + 2.593365269807865192e-01, + 4.447725051818553488e-01, + -2.027935980045048603e-01, + -6.749380249578725011e-02, + -2.173901573078995675e-01, + 4.568212815808281313e-01, + 1.211448561860737794e+00, + 5.052143239146044129e-01, + 4.447725051818553488e-01, + 1.316578719027061473e+00, + -5.374363093691157944e-01, + 3.666830608571204353e-01, + -9.723643805738775292e-01, + -2.205844558239555830e-01, + -4.775164830407008787e-01, + -1.897550899742722341e-01, + -2.027935980045048603e-01, + -5.374363093691157944e-01, + 2.237214263935641545e-01, + -1.122709912087574979e-01, + 3.760362192894101119e-01, + -1.790930632302553671e-01, + 4.909065147537653440e-01, + 2.801748154632567878e-01, + -6.749380249578725011e-02, + 3.666830608571204353e-01, + -1.122709912087574979e-01, + 4.379920339886785308e-01, + -4.651501679265258593e-01, + -1.600790746380182095e-01, + -9.839528741162539838e-01, + -4.496013116553528310e-01, + -2.173901573078995675e-01, + -9.723643805738775292e-01, + 3.760362192894101119e-01, + -4.651501679265258593e-01, + 8.339826836209653926e-01, + -2.019987455578689528e-01, + -7.458382466808392008e-01, + -3.265377399991231666e-01, + -2.198489170601660714e-01, + -7.741919203025848795e-01, + 3.080513688669864747e-01, + -2.886593869231082743e-01, + 6.147449159647125905e-01, + 7.131193565321254646e-02, + 5.249313970468063584e-01, + 2.458973988601513838e-01, + 1.066531993642767101e-01, + 5.125435760822050213e-01, + -1.960840086240358826e-01, + 2.595614234479117211e-01, + -4.455571070106654208e-01, + 4.023781103462598097e-01, + 2.439986598498889436e-01, + 3.791545332624686460e-02, + 3.007270943868774471e-01, + 4.073693169405033787e-01, + -1.980921968504558150e-01, + -1.727376141424090128e-01, + -1.349192056315940413e-01, + 4.254523519602377291e-01, + 9.354988128396966029e-01, + 3.754166251626768758e-01, + 3.929303613668370665e-01, + 1.049997075315301398e+00, + -4.360231065678457640e-01, + 2.254237742176496972e-01, + -7.364430001865647224e-01, + -1.096490937195783028e-01, + 1.580908847286551788e-01, + 1.068758541049165955e-01, + -5.659607022330204879e-02, + 9.474341305433385541e-02, + -2.025508820369920082e-02, + 1.827501136383326286e-01, + -1.559512090836832898e-01, + 3.324229255148461459e-01, + 6.295495345956941824e-01, + 2.438920609066078660e-01, + 2.958568755730026645e-01, + 7.277856895597104581e-01, + -3.066669744223792793e-01, + 1.151140737695940108e-01, + -4.860400475978431944e-01, + 3.800286645551473885e-01, + 1.030846678700064523e+00, + 4.321346638586625599e-01, + 3.725917742884784500e-01, + 1.116358772562584312e+00, + -4.547357119141249293e-01, + 3.189530464986248259e-01, + -8.288056457618550033e-01, + -8.708450843312832979e-02, + -3.261851498223737877e-01, + -1.434582664155235332e-01, + -9.532600255607970308e-02, + -3.380393053647171020e-01, + 1.343028283928490718e-01, + -1.272684286809669663e-01, + 2.688332125008372486e-01, + -1.182987729627547613e-01, + 2.218483536179890747e-01, + 1.335847595358480577e-01, + -5.597980621885078473e-02, + 1.485338228432657171e-01, + -4.002313195306946014e-02, + 2.275728678460948173e-01, + -2.180286303147872251e-01, + 5.203035227557934600e-02, + -4.039081787251208033e-01, + -2.070299876563956587e-01, + -8.991377771592213783e-03, + -3.450698857581925294e-01, + 1.205275942669613021e-01, + -2.843701930822792323e-01, + 3.666462601777820685e-01, + -2.705826322727860056e-01, + -3.132742354154710029e-01, + -9.859239844920053564e-02, + -2.186706418403572272e-01, + -4.102509343627187000e-01, + 1.832099469879538423e-01, + 2.634881197618781090e-02, + 2.219945829310046026e-01, + 1.807596106975759565e-01, + 8.994429150699673192e-01, + 4.103717482408931194e-01, + 2.226815410105222326e-01, + 9.051338172534704185e-01, + -3.527781448781722728e-01, + 3.982850361239002601e-01, + -7.512015699620381293e-01, + 1.300102742718015625e+00, + 6.054999403594980567e-01, + 2.097253231448496791e-01, + 9.827165263729973343e-01, + 1.052183733453920800e+00, + -5.195049333939609770e-01, + -6.381641027068527539e-01, + -7.364899813768920056e-02, + 6.054999403594980567e-01, + 9.039638177949659292e-01, + 3.486898489277805435e-01, + 5.059022952874115964e-01, + 1.170428721365738722e+00, + -4.925183973772490065e-01, + 5.701278884854733137e-02, + -7.030345936214288383e-01, + 2.097253231448496791e-01, + 3.486898489277805435e-01, + 1.391402783708104718e-01, + 1.776896403010444214e-01, + 4.499490466203024952e-01, + -1.859372353636008224e-01, + 3.956661747329379836e-02, + -2.828032361752622625e-01, + 9.827165263729973343e-01, + 5.059022952874115964e-01, + 1.776896403010444214e-01, + 7.465734159825266891e-01, + 8.476440027475331540e-01, + -4.120299431266394308e-01, + -4.548760980436803991e-01, + -1.074321953463584917e-01, + 1.052183733453920800e+00, + 1.170428721365738722e+00, + 4.499490466203024952e-01, + 8.476440027475331540e-01, + 1.603809031278819575e+00, + -6.959819635037184371e-01, + -1.296058402624230832e-01, + -7.926580902398893125e-01, + -5.195049333939609770e-01, + -4.925183973772490065e-01, + -1.859372353636008224e-01, + -4.120299431266394308e-01, + -6.959819635037184371e-01, + 3.087929131865239096e-01, + 1.124120074707089523e-01, + 2.991135908127298126e-01, + -6.381641027068527539e-01, + 5.701278884854733137e-02, + 3.956661747329379836e-02, + -4.548760980436803991e-01, + -1.296058402624230832e-01, + 1.124120074707089523e-01, + 5.150411483832132431e-01, + -3.445907142048604821e-01, + -7.364899813768920056e-02, + -7.030345936214288383e-01, + -2.828032361752622625e-01, + -1.074321953463584917e-01, + -7.926580902398893125e-01, + 2.991135908127298126e-01, + -3.445907142048604821e-01, + 7.234936414299397711e-01, + -3.732452792355842597e-01, + -6.351284941938279971e-01, + -2.493644712191097901e-01, + -3.176635771364644811e-01, + -8.109315931741168937e-01, + 3.357484199566660443e-01, + -7.917224627668460746e-02, + 5.179274965636704309e-01, + 6.103879963205482984e-02, + 3.926374684337851639e-01, + 1.610613109134209353e-01, + 7.405475779758788346e-02, + 4.538001478377386433e-01, + -1.721953254916427645e-01, + 1.769878680633666101e-01, + -3.961762226546653443e-01, + 1.191157188899728991e+00, + 5.605474901628545448e-01, + 1.954757521943455700e-01, + 9.007411238079842120e-01, + 9.717425968880002429e-01, + -4.785550805028060961e-01, + -5.815036985906476552e-01, + -7.396165185977580936e-02, + 1.132707043019949067e+00, + 9.319961358271104945e-01, + 3.464538551202276939e-01, + 8.875056954741523674e-01, + 1.359864912960755401e+00, + -6.156841792369970312e-01, + -3.257019742083836489e-01, + -4.991680044292492457e-01, + -1.825740192995925215e-01, + -1.057937522442242186e-02, + 4.460297560427498031e-03, + -1.325177201685841522e-01, + -6.089046756479399997e-02, + 4.198807823721553006e-02, + 1.315877118315471328e-01, + -7.070128940557268704e-02, + 9.376629691212025053e-01, + 6.766047344331500568e-01, + 2.487179767062538960e-01, + 7.273093171529937395e-01, + 1.022193095621619419e+00, + -4.714878433053744100e-01, + -3.236961738617304407e-01, + -3.112248563810770530e-01, + 8.715441272817934237e-01, + 9.957608706283338496e-01, + 3.843068460465810521e-01, + 7.040812625593925178e-01, + 1.358621587856315260e+00, + -5.873343112872276839e-01, + -9.250561013998101489e-02, + -6.850860344768691101e-01, + -1.839873054978174161e-01, + -2.697675766943509745e-01, + -1.049195435465825332e-01, + -1.532710511861232472e-01, + -3.516162193145014880e-01, + 1.479201485584544506e-01, + -1.441931622091501171e-02, + 2.085957883977366534e-01, + -4.079828637765540722e-01, + -2.847368118114416286e-02, + -2.222134405267033215e-03, + -2.957398893337482293e-01, + -1.557361823194183814e-01, + 9.835505242304752593e-02, + 2.924429004194846748e-01, + -1.501487313940521318e-01, + 3.982233761717781650e-01, + -1.588944981924644839e-01, + -7.408969697337658422e-02, + 2.742647039835456879e-01, + -5.353523938839521018e-02, + -2.055105489653940437e-02, + -3.916635706554859042e-01, + 3.475448753197058482e-01, + -7.360058674331342310e-01, + -4.568919993931332968e-01, + -1.655906071391073375e-01, + -5.651141307901779154e-01, + -7.216348754100389007e-01, + 3.402776887471806178e-01, + 2.963766262538370944e-01, + 1.646029265129388963e-01, + 3.634385605888428850e-01, + 7.177088151935814286e-01, + 2.862230103263159187e-01, + 3.167632415194596152e-01, + 9.029296450871021618e-01, + -3.677650232872346492e-01, + 1.332516815067029126e-01, + -6.119232793393495351e-01, + 1.470617714177075541e+00, + 7.908160637269205928e-01, + 2.967611009382591924e-01, + 1.126498485385288628e+00, + 1.276801732029178904e+00, + -6.188661256643098740e-01, + -6.555616565594825085e-01, + -1.622820106502845250e-01, + 7.908160637269205928e-01, + 1.019847269601365625e+00, + 4.140808136319966692e-01, + 6.501651569008026765e-01, + 1.353982396050527948e+00, + -5.750737982370550672e-01, + -1.571684107553555573e-02, + -7.288264621550111233e-01, + 2.967611009382591924e-01, + 4.140808136319966692e-01, + 1.688255123604280317e-01, + 2.463489561234036285e-01, + 5.431181093161604467e-01, + -2.285272873930015991e-01, + 1.190173024325155877e-02, + -3.072061588815623856e-01, + 1.126498485385288628e+00, + 6.501651569008026765e-01, + 2.463489561234036285e-01, + 8.663277218548093295e-01, + 1.027034106828141224e+00, + -4.918915403527173713e-01, + -4.768929805439556802e-01, + -1.715930912354062809e-01, + 1.276801732029178904e+00, + 1.353982396050527948e+00, + 5.431181093161604467e-01, + 1.027034106828141224e+00, + 1.863909708290216072e+00, + -8.111554978845125774e-01, + -1.920275777457542132e-01, + -8.656707854826148907e-01, + -6.188661256643098740e-01, + -5.750737982370550672e-01, + -2.285272873930015991e-01, + -4.918915403527173713e-01, + -8.111554978845125774e-01, + 3.597341191090284718e-01, + 1.389081190062548543e-01, + 3.311289995354575466e-01, + -6.555616565594825085e-01, + -1.571684107553555573e-02, + 1.190173024325155877e-02, + -4.768929805439556802e-01, + -1.920275777457542132e-01, + 1.389081190062548543e-01, + 4.831489109598711695e-01, + -2.903972000340614423e-01, + -1.622820106502845250e-01, + -7.288264621550111233e-01, + -3.072061588815623856e-01, + -1.715930912354062809e-01, + -8.656707854826148907e-01, + 3.311289995354575466e-01, + -2.903972000340614423e-01, + 7.136178988114953992e-01, + -4.801359830987989574e-01, + -7.072071231901065902e-01, + -2.889551561531346069e-01, + -4.007779804898234932e-01, + -9.248816752531052732e-01, + 3.862384773157101492e-01, + -3.973170595224694707e-02, + 5.404661605326807061e-01, + 1.135129783310338814e-01, + 4.223520933079102369e-01, + 1.774298690373650178e-01, + 1.133778094061134029e-01, + 5.082598909026351253e-01, + -1.962563595187890308e-01, + 1.534628752553490738e-01, + -4.054692993770361853e-01, + 1.343835858041975362e+00, + 7.321279204278162700e-01, + 2.752106405208382123e-01, + 1.029969906078544550e+00, + 1.178294959727800517e+00, + -5.696579444659997105e-01, + -5.938038404477978816e-01, + -1.592103103982526191e-01, + 1.351996704100781876e+00, + 1.115390913195200673e+00, + 4.390524493312819576e-01, + 1.064629486341718945e+00, + 1.609751622898790480e+00, + -7.272045864926830472e-01, + -3.826981143794290774e-01, + -5.682521760868141092e-01, + -1.738843102777075378e-01, + -8.601801098605380003e-03, + 1.274148486290022043e-03, + -1.267748650962710610e-01, + -5.627986373366875017e-02, + 3.876373515740432624e-02, + 1.256954388033035552e-01, + -7.196565882631425493e-02, + 1.107641644352691745e+00, + 8.264319183306835237e-01, + 3.223156539290735756e-01, + 8.657358177050420434e-01, + 1.220391344258679345e+00, + -5.600641558420980104e-01, + -3.629685669468760700e-01, + -3.710150121819992819e-01, + 1.060129731097080752e+00, + 1.150689072388683787e+00, + 4.622609209910348849e-01, + 8.546229501566764419e-01, + 1.578087848817897987e+00, + -6.845234077831010566e-01, + -1.445537517343584433e-01, + -7.479025509047609876e-01, + -2.354087887671292845e-01, + -3.052486440759409336e-01, + -1.239441248907928866e-01, + -1.935320869814817213e-01, + -4.059118589235164998e-01, + 1.721661764329455446e-01, + 3.880989047236374587e-03, + 2.194887259681063407e-01, + -4.276691974179438471e-01, + -8.120234886415216013e-02, + -2.257704216701504271e-02, + -3.162911706641049481e-01, + -2.057942213675230947e-01, + 1.197982705434366413e-01, + 2.751309243674305294e-01, + -1.122352637998940972e-01, + 4.130879341289221407e-01, + -1.041506808305002196e-01, + -5.622961801539583249e-02, + 2.923612513907042354e-01, + -9.829481080015531683e-03, + -4.020157399527740649e-02, + -3.686448663143884197e-01, + 3.081438262949895979e-01, + -8.451881571484531896e-01, + -5.653460146886296611e-01, + -2.179605774453889921e-01, + -6.555215826327438489e-01, + -8.595224164327244232e-01, + 4.012363786307169833e-01, + 3.141583866609505282e-01, + 2.138404247587386564e-01, + 4.901499623543186246e-01, + 8.064861280426489643e-01, + 3.311403671368340351e-01, + 4.153443650103284535e-01, + 1.039836036114208717e+00, + -4.289759458199104425e-01, + 8.832851884678478982e-02, + -6.435274972517068814e-01, + ] + fit_ll: + [ + 1.212985093937938963e+00, + -1.381549085068657590e+00, + 1.085133808142466849e+00, + 1.321675583158594902e+00, + 9.996861251087960643e-01, + 1.241922641459953125e+00, + -1.382821053720374227e+00, + 1.086264457711279219e+00, + 1.331897794223656728e+00, + 1.002200480399525473e+00, + 1.217151059829681081e+00, + -1.381911178384377825e+00, + 1.086244130315317191e+00, + 1.327760578088094334e+00, + 1.001454813411551736e+00, + 1.178374034136443704e+00, + -1.379130236372588580e+00, + 1.081631474710690766e+00, + 1.304438695799553294e+00, + 9.939766256981953374e-01, + 1.237617107943818962e+00, + -1.382695233447721828e+00, + 1.086338611560096368e+00, + 1.331179464750698704e+00, + 1.002117303895980971e+00, + 1.242018332379969614e+00, + -1.382821853835294101e+00, + 1.086264346741943942e+00, + 1.331911247257247188e+00, + 1.002203014460946617e+00, + ] diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 44e3de30cb..b8183585ac 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -174,11 +174,23 @@ def test_descriptor(self) -> None: for ii, result in enumerate(self.case.results): if result.descriptor is None: continue - descpt = self.dp.eval_descriptor(result.coord, result.box, result.atype) + descpt = self.dp.eval_descriptor( + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, + ) expected_descpt = result.descriptor np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) # See #4533 - descpt = self.dp.eval_descriptor(result.coord, result.box, result.atype) + descpt = self.dp.eval_descriptor( + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, + ) expected_descpt = result.descriptor np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) @@ -190,12 +202,20 @@ def test_fitting_last_layer(self) -> None: if result.fit_ll is None: continue fit_ll = self.dp.eval_fitting_last_layer( - result.coord, result.box, result.atype + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, ) expected_fit_ll = result.fit_ll np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel()) fit_ll = self.dp.eval_fitting_last_layer( - result.coord, result.box, result.atype + result.coord, + result.box, + result.atype, + fparam=result.fparam, + aparam=result.aparam, ) expected_fit_ll = result.fit_ll np.testing.assert_almost_equal(fit_ll.ravel(), expected_fit_ll.ravel()) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index f0bef34cf8..6797fa2c03 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -6,6 +6,7 @@ """ import importlib +import os import tempfile import unittest import zipfile @@ -1299,5 +1300,636 @@ def test_pt2_vs_pte_consistency(self) -> None: np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") +class TestEvalTypeEbd(unittest.TestCase): + """Test eval_typeebd for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # se_e2_a model (no type embedding) + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + + # DPA1 model (has type embedding) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def test_typeebd_dpa1(self) -> None: + """DPA1 model has type embedding, should return valid array.""" + typeebd = self.dp_dpa1.deep_eval.eval_typeebd() + self.assertEqual(typeebd.ndim, 2) + # DPA1 TypeEmbedNet outputs (ntypes+1) rows (padding type included) + self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) + self.assertGreater(typeebd.shape[1], 0) + + def test_typeebd_sea_raises(self) -> None: + """se_e2_a model has no type embedding, should raise KeyError.""" + with self.assertRaises(KeyError): + self.dp_sea.deep_eval.eval_typeebd() + + +class TestEvalDescriptor(unittest.TestCase): + """Test eval_descriptor for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # se_e2_a model + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + cls.dim_descrpt_sea = ds_sea.get_dim_out() + + # DPA1 model + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + cls.dim_descrpt_dpa1 = ds_dpa1.get_dim_out() + + # se_e2_a model with fparam (dim_fparam=1, no default; swap _dpmodel + # because se_e2_a + fparam hits GuardOnDataDependentSymNode in export) + ds_fp = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_fp = EnergyFittingNet( + cls.nt, + ds_fp.get_dim_out(), + mixed_types=ds_fp.mixed_types(), + numb_fparam=1, + seed=GLOBAL_SEED, + ) + model_fp = EnergyModel(ds_fp, ft_fp, type_map=cls.type_map) + pte_fp = os.path.join(cls._tmpdir, "fp.pte") + deserialize_to_file(pte_fp, {"model": model_sea.serialize()}) + cls.dp_fp = DeepPot(pte_fp) + cls.dp_fp.deep_eval._dpmodel = model_fp + cls.dim_descrpt_fp = ds_fp.get_dim_out() + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_descriptor_shape_sea(self) -> None: + """se_e2_a descriptor has correct shape.""" + coords, cells, atom_types = self._make_inputs() + descpt = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_sea)) + + def test_descriptor_shape_dpa1(self) -> None: + """DPA1 descriptor has correct shape.""" + coords, cells, atom_types = self._make_inputs() + descpt = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_dpa1)) + + def test_descriptor_deterministic_sea(self) -> None: + """Calling eval_descriptor twice gives same result for se_e2_a.""" + coords, cells, atom_types = self._make_inputs() + d1 = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + d2 = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + np.testing.assert_array_equal(d1, d2) + + def test_descriptor_deterministic_dpa1(self) -> None: + """Calling eval_descriptor twice gives same result for DPA1.""" + coords, cells, atom_types = self._make_inputs() + d1 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + d2 = self.dp_dpa1.deep_eval.eval_descriptor(coords, cells, atom_types) + np.testing.assert_array_equal(d1, d2) + + def test_descriptor_with_fparam(self) -> None: + """eval_descriptor works with fparam.""" + coords, cells, atom_types = self._make_inputs() + fparam = np.array([0.5], dtype=np.float64) + descpt = self.dp_fp.deep_eval.eval_descriptor( + coords, cells, atom_types, fparam=fparam + ) + self.assertEqual(descpt.shape, (1, 6, self.dim_descrpt_fp)) + + def test_descriptor_without_fparam_raises(self) -> None: + """eval_descriptor raises when fparam is required but not provided.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(ValueError): + self.dp_fp.deep_eval.eval_descriptor(coords, cells, atom_types) + + def test_descriptor_accepts_list_inputs(self) -> None: + """eval_descriptor accepts Python lists (not just np.ndarray).""" + coords, cells, atom_types = self._make_inputs() + descpt_arr = self.dp_sea.deep_eval.eval_descriptor(coords, cells, atom_types) + descpt_list = self.dp_sea.deep_eval.eval_descriptor( + coords.tolist(), cells.tolist(), atom_types.tolist() + ) + np.testing.assert_allclose(descpt_arr, descpt_list) + + def test_fitting_last_layer_accepts_list_inputs(self) -> None: + """eval_fitting_last_layer accepts Python lists (not just np.ndarray).""" + coords, cells, atom_types = self._make_inputs() + mid_arr = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + mid_list = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords.tolist(), cells.tolist(), atom_types.tolist() + ) + np.testing.assert_allclose(mid_arr, mid_list) + + +class TestEvalFittingLastLayer(unittest.TestCase): + """Test eval_fitting_last_layer for pt_expt models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + cls.neuron = [120, 120, 120] # default fitting net neurons + + # se_e2_a model (mixed_types=False) + ds_sea = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_sea = EnergyFittingNet( + cls.nt, + ds_sea.get_dim_out(), + mixed_types=ds_sea.mixed_types(), + seed=GLOBAL_SEED, + ) + model_sea = EnergyModel(ds_sea, ft_sea, type_map=cls.type_map) + model_sea = model_sea.to(torch.float64) + model_sea.eval() + cls._tmpdir = tempfile.mkdtemp() + pte_sea = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_sea, {"model": model_sea.serialize()}) + cls.dp_sea = DeepPot(pte_sea) + + # DPA1 model (mixed_types=True) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + + ds_dpa1 = DescrptDPA1( + cls.rcut, + cls.rcut_smth, + cls.sel, + ntypes=cls.nt, + seed=GLOBAL_SEED, + ) + ft_dpa1 = EnergyFittingNet( + cls.nt, + ds_dpa1.get_dim_out(), + mixed_types=ds_dpa1.mixed_types(), + seed=GLOBAL_SEED, + ) + model_dpa1 = EnergyModel(ds_dpa1, ft_dpa1, type_map=cls.type_map) + model_dpa1 = model_dpa1.to(torch.float64) + model_dpa1.eval() + pte_dpa1 = os.path.join(cls._tmpdir, "dpa1.pte") + deserialize_to_file(pte_dpa1, {"model": model_dpa1.serialize()}) + cls.dp_dpa1 = DeepPot(pte_dpa1) + + # se_e2_a model with fparam and aparam (swap _dpmodel because + # se_e2_a + fparam hits GuardOnDataDependentSymNode in export) + ds_fp = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft_fp = EnergyFittingNet( + cls.nt, + ds_fp.get_dim_out(), + mixed_types=ds_fp.mixed_types(), + numb_fparam=1, + numb_aparam=2, + seed=GLOBAL_SEED, + ) + model_fp = EnergyModel(ds_fp, ft_fp, type_map=cls.type_map) + pte_fp = os.path.join(cls._tmpdir, "fp.pte") + deserialize_to_file(pte_fp, {"model": model_sea.serialize()}) + cls.dp_fp = DeepPot(pte_fp) + cls.dp_fp.deep_eval._dpmodel = model_fp + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_fitting_ll_shape_sea(self) -> None: + """se_e2_a fitting last layer has correct shape.""" + coords, cells, atom_types = self._make_inputs() + fit_ll = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_shape_dpa1(self) -> None: + """DPA1 fitting last layer has correct shape.""" + coords, cells, atom_types = self._make_inputs() + fit_ll = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_deterministic_sea(self) -> None: + """Verify calling twice gives the same result for se_e2_a.""" + coords, cells, atom_types = self._make_inputs() + fit_ll1 = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + fit_ll2 = self.dp_sea.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + np.testing.assert_array_equal(fit_ll1, fit_ll2) + + def test_fitting_ll_deterministic_dpa1(self) -> None: + """Verify calling twice gives the same result for DPA1.""" + coords, cells, atom_types = self._make_inputs() + fit_ll1 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + fit_ll2 = self.dp_dpa1.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + np.testing.assert_array_equal(fit_ll1, fit_ll2) + + def test_fitting_ll_with_fparam_aparam(self) -> None: + """eval_fitting_last_layer works with fparam and aparam.""" + coords, cells, atom_types = self._make_inputs() + fparam = np.array([0.5], dtype=np.float64) + aparam = np.zeros((1, 6, 2), dtype=np.float64) + fit_ll = self.dp_fp.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types, fparam=fparam, aparam=aparam + ) + self.assertEqual(fit_ll.shape, (1, 6, self.neuron[-1])) + + def test_fitting_ll_without_fparam_raises(self) -> None: + """eval_fitting_last_layer raises when fparam is required but not provided.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(ValueError): + self.dp_fp.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + +class TestGetDpAtomicModel(unittest.TestCase): + """Test get_dp_atomic_model() API on various model types.""" + + def test_energy_model(self) -> None: + """Standard energy model returns a DPAtomicModel.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + dp_am = model.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + self.assertTrue(hasattr(dp_am, "descriptor")) + self.assertTrue(hasattr(dp_am, "fitting_net")) + + def test_zbl_model_returns_none(self) -> None: + """DPZBLModel wraps a LinearEnergyAtomicModel, so get_dp_atomic_model returns None.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, + ) + from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, + ) + from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP + + ds = DescrptDPA1DP(4.0, 0.5, [14], ntypes=2) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=True, seed=GLOBAL_SEED) + dp_am = DPAtomicModelDP(ds, ft, type_map=["foo", "bar"]) + pair_tab = PairTabAtomicModel( + tab_file=None, rcut=4.0, sel=14, type_map=["foo", "bar"] + ) + zbl_am = DPZBLLinearEnergyAtomicModel( + dp_am, pair_tab, sw_rmin=1.0, sw_rmax=2.0, type_map=["foo", "bar"] + ) + # LinearEnergyAtomicModel is not a DPAtomicModel + self.assertFalse(isinstance(zbl_am, DPAtomicModelDP)) + + def test_spin_model_delegates(self) -> None: + """SpinModel.get_dp_atomic_model() delegates to backbone.""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + from deepmd.utils.spin import ( + Spin, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + spin = Spin( + use_spin=[False, False], + virtual_scale=[0.0, 0.0], + ) + spin_model = SpinModel(backbone_model=model, spin=spin) + dp_am = spin_model.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + + def test_deserialized_model_delegates(self) -> None: + """Model deserialized from .pte exposes get_dp_atomic_model().""" + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel as DPAtomicModelDP, + ) + + ds = DescrptSeA(4.0, 0.5, [8, 6]) + ft = EnergyFittingNet(2, ds.get_dim_out(), mixed_types=False, seed=GLOBAL_SEED) + model = EnergyModel(ds, ft, type_map=["foo", "bar"]) + model = model.to(torch.float64) + model.eval() + + tmpdir = tempfile.mkdtemp() + try: + pte_path = os.path.join(tmpdir, "test.pte") + deserialize_to_file(pte_path, {"model": model.serialize()}) + dp = DeepPot(pte_path) + dp_am = dp.deep_eval._dpmodel.get_dp_atomic_model() + self.assertIsNotNone(dp_am) + self.assertIsInstance(dp_am, DPAtomicModelDP) + finally: + import shutil + + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestEvalDiagSpinModel(unittest.TestCase): + """Test eval diagnostic methods on spin models.""" + + @classmethod + def setUpClass(cls) -> None: + from deepmd.dpmodel.model.spin_model import ( + SpinModel, + ) + from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, + ) + from deepmd.utils.spin import ( + Spin, + ) + + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # DPA1 model with spin wrapper + ds = DescrptDPA1( + cls.rcut, cls.rcut_smth, cls.sel, ntypes=cls.nt, seed=GLOBAL_SEED + ) + ft = EnergyFittingNet( + cls.nt, ds.get_dim_out(), mixed_types=ds.mixed_types(), seed=GLOBAL_SEED + ) + backbone = EnergyModel(ds, ft, type_map=cls.type_map) + backbone = backbone.to(torch.float64) + backbone.eval() + + spin = Spin(use_spin=[True, False], virtual_scale=[0.5, 0.0]) + cls.spin_model = SpinModel(backbone_model=backbone, spin=spin) + + # Export backbone as .pte, then swap _dpmodel to SpinModel + cls._tmpdir = tempfile.mkdtemp() + pte_path = os.path.join(cls._tmpdir, "spin.pte") + deserialize_to_file(pte_path, {"model": backbone.serialize()}) + cls.dp = DeepPot(pte_path) + cls.dp.deep_eval._dpmodel = cls.spin_model + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + def _make_inputs(self): + nframes = 1 + natoms = 6 + coords = ( + np.random.default_rng(42).random((nframes, natoms, 3)).astype(np.float64) + ) + cells = 5.0 * np.eye(3, dtype=np.float64).reshape(1, 3, 3).repeat( + nframes, axis=0 + ) + atom_types = np.array([0, 0, 0, 1, 1, 1], dtype=int) + return coords, cells, atom_types + + def test_eval_typeebd_spin(self) -> None: + """eval_typeebd traverses backbone_model for spin models.""" + typeebd = self.dp.deep_eval.eval_typeebd() + self.assertEqual(typeebd.ndim, 2) + # DPA1 TypeEmbedNet outputs ntypes or ntypes+1 + self.assertIn(typeebd.shape[0], (self.nt, self.nt + 1)) + self.assertGreater(typeebd.shape[1], 0) + + def test_eval_descriptor_spin_raises(self) -> None: + """eval_descriptor raises NotImplementedError for spin models.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(NotImplementedError): + self.dp.deep_eval.eval_descriptor(coords, cells, atom_types) + + def test_eval_fitting_last_layer_spin_raises(self) -> None: + """eval_fitting_last_layer raises NotImplementedError for spin models.""" + coords, cells, atom_types = self._make_inputs() + with self.assertRaises(NotImplementedError): + self.dp.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + +class TestEvalDescriptorASE(unittest.TestCase): + """Test eval_descriptor with ASE neighbor list.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + model = EnergyModel(ds, ft, type_map=cls.type_map) + model = model.to(torch.float64) + model.eval() + cls.dim_descrpt = ds.get_dim_out() + + cls._tmpdir = tempfile.mkdtemp() + pte_path = os.path.join(cls._tmpdir, "sea.pte") + deserialize_to_file(pte_path, {"model": model.serialize()}) + cls.dp_native = DeepPot(pte_path) + + @classmethod + def tearDownClass(cls) -> None: + import shutil + + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_eval_descriptor_ase_vs_native(self) -> None: + """eval_descriptor with ASE nlist matches native nlist.""" + import ase.neighborlist + + pte_path = os.path.join(self._tmpdir, "sea.pte") + dp_ase = DeepPot( + pte_path, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, bothways=True + ), + ) + + rng = np.random.default_rng(GLOBAL_SEED + 99) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + d_native = self.dp_native.deep_eval.eval_descriptor(coords, cells, atom_types) + d_ase = dp_ase.deep_eval.eval_descriptor(coords, cells, atom_types) + + self.assertEqual(d_native.shape, d_ase.shape) + np.testing.assert_allclose(d_native, d_ase, rtol=1e-10, atol=1e-10) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_eval_fitting_last_layer_ase_vs_native(self) -> None: + """eval_fitting_last_layer with ASE nlist matches native nlist.""" + import ase.neighborlist + + pte_path = os.path.join(self._tmpdir, "sea.pte") + dp_ase = DeepPot( + pte_path, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, bothways=True + ), + ) + + rng = np.random.default_rng(GLOBAL_SEED + 99) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + f_native = self.dp_native.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types + ) + f_ase = dp_ase.deep_eval.eval_fitting_last_layer(coords, cells, atom_types) + + self.assertEqual(f_native.shape, f_ase.shape) + np.testing.assert_allclose(f_native, f_ase, rtol=1e-10, atol=1e-10) + + if __name__ == "__main__": unittest.main()