diff --git a/deepmd/entrypoints/embedding.py b/deepmd/entrypoints/embedding.py new file mode 100644 index 0000000000..d9d2cb1f86 --- /dev/null +++ b/deepmd/entrypoints/embedding.py @@ -0,0 +1,210 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Evaluate model embeddings using a trained DeePMD-kit model.""" + +import logging +import os +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import h5py +import numpy as np + +from deepmd.common import ( + expand_sys_str, +) +from deepmd.infer.deep_eval import ( + DeepEval, +) +from deepmd.utils.data import ( + DeepmdData, +) + +__all__ = ["embedding"] + +log = logging.getLogger(__name__) + +# Byte shuffle plus gzip gives a strong compression ratio on floating-point +# embeddings without any optional HDF5 plugin: shuffle groups the equal-order +# bytes of neighboring values so the deflate stage finds longer runs. +_HDF5_DATASET_KWARGS = { + "compression": "gzip", + "compression_opts": 9, + "shuffle": True, +} + + +def _unique_group_name(system_path: str, used_names: set[str]) -> str: + """ + Return a collision-free HDF5 group name derived from a system path. + + Parameters + ---------- + system_path : str + The source system directory. + used_names : set[str] + Group names already assigned within the output file. + + Returns + ------- + str + A unique group name based on the system directory's base name. + """ + base = os.path.basename(system_path.rstrip("/")) or "system" + name = base + idx = 1 + while name in used_names: + name = f"{base}_{idx}" + idx += 1 + used_names.add(name) + return name + + +def embedding( + *, + model: str, + system: str, + datafile: str, + output: str = "embedding.hdf5", + head: str | None = None, + dtype: str = "fp32", + **kwargs: Any, +) -> None: + """Evaluate embeddings for the given systems and store them in one HDF5 file. + + Three embeddings are produced per system in a single forward pass: the + per-atom ``descriptor``, the per-atom ``atomic_feature`` (the activation + after the last fitting hidden layer), and the per-structure + ``structural_feature`` (the masked atom-sum of ``atomic_feature``). + + Parameters + ---------- + model : str + Path where the model is stored. + system : str + System directory; systems are detected recursively. + datafile : str + Path to a file listing system directories, one per line. + output : str + Output HDF5 file. Each system becomes a group holding the three + embedding datasets. + head : str, optional + (Supported backend: PyTorch) Task head if in multi-task mode. + dtype : str + Output dtype for embedding arrays: ``"fp32"``, ``"fp64"``, or + ``"native"``. + **kwargs + Additional arguments. + + Notes + ----- + The output HDF5 file stores one group per system. The group name is the + system directory's base name (de-duplicated on collision), and the source + directory is recorded in the group's ``system`` attribute. Each group holds + the datasets ``descriptor`` (nframes, natoms, dim_descriptor), + ``atomic_feature`` (nframes, natoms, dim_hidden), + ``structural_feature`` (nframes, dim_hidden), and ``atom_types`` + (nframes, natoms), together with an ``nframes`` attribute; the frame axis + follows the system's frame order. The model ``type_map`` is stored as a + file-level attribute. The three embedding datasets are stored using the + selected ``dtype``, and all datasets use gzip + shuffle compression. + + Raises + ------ + RuntimeError + If no valid system was found. + """ + if datafile is not None: + with open(datafile) as datalist: + all_sys = [line.strip() for line in datalist if line.strip()] + else: + all_sys = expand_sys_str(system) + + if len(all_sys) == 0: + raise RuntimeError("Did not find valid system") + + dp = DeepEval(model, head=head) + + output_path = Path(output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with h5py.File(output_path, "w") as h5file: + h5file.attrs["type_map"] = np.array( + dp.get_type_map(), dtype=h5py.string_dtype() + ) + used_names: set[str] = set() + for system_path in all_sys: + log.info("# -------output of embedding------- ") + log.info(f"# processing system : {system_path}") + + tmap = dp.get_type_map() + data = DeepmdData( + system_path, + set_prefix="set", + shuffle_test=False, + type_map=tmap, + sort_atoms=False, + ) + + test_data = data.get_test() + mixed_type = data.mixed_type + nframes = test_data["box"].shape[0] + + coord = test_data["coord"].reshape([nframes, -1]) + box = test_data["box"] + if not data.pbc: + box = None + if mixed_type: + atype = test_data["type"].reshape([nframes, -1]) + else: + atype = test_data["type"][0] + + fparam = None + if dp.get_dim_fparam() > 0 and "fparam" in test_data: + fparam = test_data["fparam"] + aparam = None + if dp.get_dim_aparam() > 0 and "aparam" in test_data: + aparam = test_data["aparam"] + + log.info(f"# evaluating embeddings for {nframes} frames") + descriptor, atomic_feature, structural_feature = dp.eval_embedding( + coord, + box, + atype, + fparam=fparam, + aparam=aparam, + mixed_type=mixed_type, + dtype=dtype, + ) + + group_name = _unique_group_name(system_path, used_names) + group = h5file.create_group(group_name) + group.create_dataset("descriptor", data=descriptor, **_HDF5_DATASET_KWARGS) + group.create_dataset( + "atomic_feature", data=atomic_feature, **_HDF5_DATASET_KWARGS + ) + group.create_dataset( + "structural_feature", + data=structural_feature, + **_HDF5_DATASET_KWARGS, + ) + atom_types = np.asarray(atype, dtype=np.int32) + if atom_types.ndim == 1: + atom_types = np.tile(atom_types, (nframes, 1)) + group.create_dataset("atom_types", data=atom_types, **_HDF5_DATASET_KWARGS) + group.attrs["nframes"] = int(nframes) + group.attrs["system"] = str(system_path) + + log.info( + f"# stored group '{group_name}': " + f"descriptor {descriptor.shape}, " + f"atomic_feature {atomic_feature.shape}, " + f"structural_feature {structural_feature.shape}" + ) + log.info("# ----------------------------------- ") + + log.info(f"# embeddings saved to {output_path}") + log.info("# embedding completed successfully") diff --git a/deepmd/entrypoints/eval_desc.py b/deepmd/entrypoints/eval_desc.py index dc5f8df955..3fed69b3a4 100644 --- a/deepmd/entrypoints/eval_desc.py +++ b/deepmd/entrypoints/eval_desc.py @@ -34,6 +34,7 @@ def eval_desc( datafile: str, output: str = "desc", head: str | None = None, + dtype: str = "native", **kwargs: Any, ) -> None: """Evaluate descriptors for given systems. @@ -50,6 +51,9 @@ def eval_desc( output directory for descriptor files head : Optional[str], optional (Supported backend: PyTorch) Task head if in multi-task mode. + dtype : str + Output dtype for descriptor arrays: ``"fp32"``, ``"fp64"``, or + ``"native"``. **kwargs additional arguments @@ -65,7 +69,7 @@ def eval_desc( """ if datafile is not None: with open(datafile) as datalist: - all_sys = datalist.read().splitlines() + all_sys = [line.strip() for line in datalist if line.strip()] else: all_sys = expand_sys_str(system) @@ -129,6 +133,7 @@ def eval_desc( fparam=fparam, aparam=aparam, mixed_type=mixed_type, + dtype=dtype, ) # descriptors are kept in 3D format (nframes, natoms, ndesc) diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index 86c9687bd4..26a89238d6 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -18,6 +18,9 @@ from deepmd.entrypoints.doc import ( doc_train_input, ) +from deepmd.entrypoints.embedding import ( + embedding, +) from deepmd.entrypoints.eval_desc import ( eval_desc, ) @@ -79,6 +82,14 @@ def main(args: argparse.Namespace) -> None: strict_prefer=False, ) eval_desc(**dict_args) + elif args.command == "embed": + dict_args["model"] = format_model_suffix( + dict_args["model"], + feature=Backend.Feature.DEEP_EVAL, + preferred_backend=args.backend, + strict_prefer=False, + ) + embedding(**dict_args) elif args.command == "doc-train-input": doc_train_input(**dict_args) elif args.command == "model-devi": diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 557f3ddd23..b4f2b99449 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -271,6 +271,59 @@ def eval_fitting_last_layer( """ raise NotImplementedError + def eval_embedding( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + dtype: str = "fp32", + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Evaluate the descriptor, atomic feature, and structural feature. + + A single forward pass produces all three embeddings without force or + virial autograd. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. + + Returns + ------- + descriptor + The per-atom descriptor, of size nframes x natoms x dim_descriptor. + atomic_feature + The per-atom last hidden activation, of size + nframes x natoms x dim_hidden. + structural_feature + The per-structure pooled feature, of size nframes x dim_hidden. + """ + raise NotImplementedError + def eval_typeebd(self) -> np.ndarray: """Evaluate output of type embedding network by using this model. @@ -381,6 +434,35 @@ def get_model(self) -> Any: """ +def _cast_output_dtype(array: np.ndarray, dtype: str) -> np.ndarray: + """Cast a backend evaluation output to the requested output dtype. + + The cast is performed in this backend-agnostic wrapper so every backend + shares identical ``--dtype`` behavior: the backend always returns its + native precision, and this high-level API decides the emitted dtype. + + Parameters + ---------- + array + The array returned by a backend evaluation. + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. ``"native"`` + leaves the backend precision unchanged. + + Returns + ------- + np.ndarray + The array cast to the requested precision. + """ + if dtype == "native": + return array + if dtype == "fp32": + return array.astype(np.float32) + if dtype == "fp64": + return array.astype(np.float64) + raise ValueError(f"Unknown dtype {dtype!r}; expected 'fp32', 'fp64', or 'native'.") + + class DeepEval(ABC): """High-level Deep Evaluator interface. @@ -503,6 +585,7 @@ def eval_descriptor( fparam: np.ndarray | None = None, aparam: np.ndarray | None = None, mixed_type: bool = False, + dtype: str = "native", **kwargs: Any, ) -> np.ndarray: """Evaluate descriptors by using this DP. @@ -537,6 +620,8 @@ def eval_descriptor( Whether to perform the mixed_type mode. If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), in which frames in a system may have different natoms_vec(s), with the same nloc. + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. Returns ------- @@ -553,14 +638,9 @@ def eval_descriptor( natoms, ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type) descriptor = self.deep_eval.eval_descriptor( - coords, - cells, - atom_types, - fparam=fparam, - aparam=aparam, - **kwargs, + coords, cells, atom_types, fparam=fparam, aparam=aparam, **kwargs ) - return descriptor + return _cast_output_dtype(descriptor, dtype) def eval_fitting_last_layer( self, @@ -570,6 +650,7 @@ def eval_fitting_last_layer( fparam: np.ndarray | None = None, aparam: np.ndarray | None = None, mixed_type: bool = False, + dtype: str = "native", **kwargs: Any, ) -> np.ndarray: """Evaluate fitting before last layer by using this DP. @@ -604,6 +685,8 @@ def eval_fitting_last_layer( Whether to perform the mixed_type mode. If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), in which frames in a system may have different natoms_vec(s), with the same nloc. + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. Returns ------- @@ -620,14 +703,95 @@ def eval_fitting_last_layer( natoms, ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type) fitting = self.deep_eval.eval_fitting_last_layer( + coords, cells, atom_types, fparam=fparam, aparam=aparam, **kwargs + ) + return _cast_output_dtype(fitting, dtype) + + def eval_embedding( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + mixed_type: bool = False, + dtype: str = "fp32", + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Evaluate the descriptor, atomic feature, and structural feature. + + A single forward pass produces all three embeddings without force or + virial autograd. The descriptor is the per-atom local-environment + representation; the atomic feature is the activation after the last + fitting hidden layer; the structural feature is the masked atom-sum of + the atomic feature, a whole-structure summary. For models with a single + shared fitting network, projecting the structural feature through the + fitting output layer reproduces the (bias-free) total energy. The output + precision is selected by ``dtype`` and defaults to float32. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + mixed_type + Whether to perform the mixed_type mode. + If True, the input data has the mixed_type format (see doc/model/train_se_atten.md), + in which frames in a system may have different natoms_vec(s), with the same nloc. + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. + + Returns + ------- + descriptor + The per-atom descriptor, of size nframes x natoms x dim_descriptor. + atomic_feature + The per-atom last hidden activation, of size + nframes x natoms x dim_hidden. + structural_feature + The per-structure pooled feature, of size nframes x dim_hidden. + + Raises + ------ + NotImplementedError + If the loaded model does not support embedding extraction. + """ + ( + coords, + cells, + atom_types, + fparam, + aparam, + nframes, + natoms, + ) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type) + return self.deep_eval.eval_embedding( coords, cells, atom_types, fparam=fparam, aparam=aparam, + dtype=dtype, **kwargs, ) - return fitting def eval_typeebd(self) -> np.ndarray: """Evaluate output of type embedding network by using this model. diff --git a/deepmd/main.py b/deepmd/main.py index bf59dfdad5..43f40dc214 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -492,6 +492,73 @@ def main_parser() -> argparse.ArgumentParser: help="Output directory for descriptor files. Descriptors will be saved as desc/(system_name).npy", ) parser_eval_desc.add_argument( + "--dtype", + choices=["fp32", "fp64", "native"], + default="native", + type=str, + help="Output dtype for descriptors. `native` keeps the model output precision (default).", + ) + parser_eval_desc.add_argument( + "--head", + "--model-branch", + default=None, + type=str, + help="(Supported backend: PyTorch) Task head (alias: model branch) to use if in multi-task mode.", + ) + + # * embedding script ************************************************************* + parser_embedding = subparsers.add_parser( + "embed", + parents=[parser_log], + help="evaluate model embeddings (descriptor, atomic feature, structural feature)", + formatter_class=RawTextArgumentDefaultsHelpFormatter, + epilog=textwrap.dedent( + """\ + examples: + dp embed -m model.ckpt.pt -s /path/to/system -o embedding.hdf5 + """ + ), + ) + parser_embedding.add_argument( + "-m", + "--model", + default="model.ckpt.pt", + type=str, + help="(Supported backend: PyTorch) Energy model to import: a training " + "checkpoint (suffix .pt) or a frozen model (suffix .pth). SeZM/DPA4 " + "only supports the .pt checkpoint; the frozen .pt2 package is not supported.", + ) + parser_embedding_subgroup = parser_embedding.add_mutually_exclusive_group() + parser_embedding_subgroup.add_argument( + "-s", + "--system", + default=".", + type=str, + help="The system dir. Recursively detect systems in this directory", + ) + parser_embedding_subgroup.add_argument( + "-f", + "--datafile", + default=None, + type=str, + help="The path to the datafile, each line of which is a path to one data system.", + ) + parser_embedding.add_argument( + "-o", + "--output", + default="embedding.hdf5", + type=str, + help="Output HDF5 file. Each system is stored as a group holding the " + "descriptor, atomic_feature, and structural_feature datasets.", + ) + parser_embedding.add_argument( + "--dtype", + choices=["fp32", "fp64", "native"], + default="fp32", + type=str, + help="Output dtype for embeddings. `native` keeps the model output precision.", + ) + parser_embedding.add_argument( "--head", "--model-branch", default=None, @@ -1032,6 +1099,7 @@ def main(args: list[str] | None = None) -> None: if args.command in ( "test", "eval-desc", + "embed", "doc-train-input", "model-devi", "neighbor-stat", diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index e3f195ac67..d822fdc462 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -98,6 +98,12 @@ log = logging.getLogger(__name__) +_EMBEDDING_DTYPE_TO_TORCH = { + "fp32": torch.float32, + "fp64": torch.float64, +} + + def _is_sezm_model_params(model_params: dict[str, Any]) -> bool: """Return whether the params describe a SeZM / DPA4 model.""" model_type = str(model_params.get("type", "")).lower() @@ -1030,6 +1036,13 @@ def eval_descriptor( ) -> np.ndarray: """Evaluate descriptors by using this DP. + .. deprecated:: + Use :meth:`eval_embedding` instead, which returns the descriptor + together with the atomic and structural features in a single + forward pass. This method is a thin wrapper kept for compatibility. + For models frozen before ``forward_embedding`` existed, it falls + back to the descriptor hook baked into that TorchScript module. + Parameters ---------- coords @@ -1060,30 +1073,27 @@ def eval_descriptor( Descriptors. """ model = self.dp.model["Default"] - while True: - if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(True) - model.set_eval_descriptor_hook(True) - retry = False - try: - self.eval( - coords, - cells, - atom_types, - atomic=False, - fparam=fparam, - aparam=aparam, - **kwargs, - ) - descriptor = model.eval_descriptor() - except RetrySignal: - retry = True - finally: - model.set_eval_descriptor_hook(False) - if self.auto_batch_size is not None: - self.auto_batch_size.set_oom_retry_mode(False) - if not retry: - return to_numpy_array(descriptor) + if not hasattr(model, "forward_embedding"): + return self._eval_legacy_feature( + coords, + cells, + atom_types, + fparam, + aparam, + enable_hook=model.set_eval_descriptor_hook, + read_feature=model.eval_descriptor, + **kwargs, + ) + descriptor, _, _ = self.eval_embedding( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + dtype="native", + **kwargs, + ) + return descriptor def eval_fitting_last_layer( self, @@ -1096,6 +1106,13 @@ def eval_fitting_last_layer( ) -> np.ndarray: """Evaluate fitting before last layer by using this DP. + .. deprecated:: + Use :meth:`eval_embedding` instead, which returns this activation as + the ``atomic_feature`` output. This method is a thin wrapper kept + for compatibility. For models frozen before ``forward_embedding`` + existed, it falls back to the fitting-last-layer hook baked into + that TorchScript module. + Parameters ---------- coords @@ -1126,10 +1143,51 @@ def eval_fitting_last_layer( Fitting output before last layer. """ model = self.dp.model["Default"] + if not hasattr(model, "forward_embedding"): + return self._eval_legacy_feature( + coords, + cells, + atom_types, + fparam, + aparam, + enable_hook=model.set_eval_fitting_last_layer_hook, + read_feature=model.eval_fitting_last_layer, + **kwargs, + ) + _, atomic_feature, _ = self.eval_embedding( + coords, + cells, + atom_types, + fparam=fparam, + aparam=aparam, + dtype="native", + **kwargs, + ) + return atomic_feature + + def _eval_legacy_feature( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + *, + enable_hook: Callable[[bool], None], + read_feature: Callable[[], torch.Tensor], + **kwargs: Any, + ) -> np.ndarray: + """Extract a cached descriptor or fitting feature from a legacy frozen model. + + Models frozen before ``forward_embedding`` expose these features only + through a hook that caches them during a forward pass. The retry loop + keeps the cache consistent when the auto batch size splits the forward + and hits an out-of-memory condition. + """ while True: if self.auto_batch_size is not None: self.auto_batch_size.set_oom_retry_mode(True) - model.set_eval_fitting_last_layer_hook(True) + enable_hook(True) retry = False try: self.eval( @@ -1141,12 +1199,184 @@ def eval_fitting_last_layer( aparam=aparam, **kwargs, ) - fitting_net = model.eval_fitting_last_layer() + feature = read_feature() except RetrySignal: retry = True finally: - model.set_eval_fitting_last_layer_hook(False) + enable_hook(False) if self.auto_batch_size is not None: self.auto_batch_size.set_oom_retry_mode(False) if not retry: - return to_numpy_array(fitting_net) + return to_numpy_array(feature) + + def eval_embedding( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + charge_spin: np.ndarray | None = None, + dtype: str = "fp32", + **kwargs: Any, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Evaluate the descriptor, atomic feature, and structural feature. + + A single forward pass produces all three embeddings without any + force or virial autograd. The descriptor is the per-atom + local-environment representation; the atomic feature is the activation + after the last fitting hidden layer; the structural feature is the + masked atom-sum of the atomic feature, a whole-structure summary. For + models with a single shared fitting network, projecting the structural + feature through the fitting output layer reproduces the (bias-free) + total energy. The output precision is selected by ``dtype`` and + defaults to float32. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + charge_spin + The frame-level charge and spin conditions. + The array should be of size nframes x 2 + dtype + Output dtype: ``"fp32"``, ``"fp64"``, or ``"native"``. + + Returns + ------- + descriptor + The per-atom descriptor, of size nframes x natoms x dim_descriptor. + atomic_feature + The per-atom last hidden activation, of size + nframes x natoms x dim_hidden. + structural_feature + The per-structure pooled feature, of size nframes x dim_hidden. + + Raises + ------ + NotImplementedError + If the loaded model does not support embedding extraction. + """ + if self._has_spin: + raise NotImplementedError( + "eval_embedding is not supported for spin models in the " + "PyTorch backend." + ) + if dtype not in ("fp32", "fp64", "native"): + raise ValueError("dtype must be one of 'fp32', 'fp64', or 'native'.") + if not hasattr(self.dp.model["Default"], "forward_embedding"): + raise NotImplementedError( + "eval_embedding requires a model frozen with forward_embedding " + "support. Please re-freeze the model with a newer DeePMD-kit " + "version." + ) + atom_types = np.array(atom_types, dtype=np.int32) + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, len(atom_types.shape) > 1 + ) + return self._eval_func(self._eval_embedding, numb_test, natoms)( + coords, cells, atom_types, fparam, aparam, charge_spin, dtype + ) + + def _eval_embedding( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + charge_spin: np.ndarray | None, + dtype: str, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + self.dp.to(DEVICE) + # A data modifier augments physical outputs after the model forward in + # ModelWrapper. It does not define or transform descriptor/fitting + # features, so embeddings are intentionally taken from the neural model. + model = self.dp.model["Default"] + prec = NP_PRECISION_DICT[RESERVED_PRECISION_DICT[GLOBAL_PT_FLOAT_PRECISION]] + + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = torch.tensor( + coords.reshape([nframes, natoms, 3]).astype(prec), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + type_input = torch.tensor( + atom_types.astype(NP_PRECISION_DICT[RESERVED_PRECISION_DICT[torch.long]]), + dtype=torch.long, + device=DEVICE, + ) + box_input = ( + torch.tensor( + cells.reshape([nframes, 3, 3]).astype(prec), + dtype=GLOBAL_PT_FLOAT_PRECISION, + device=DEVICE, + ) + if cells is not None + else None + ) + fparam_input = ( + to_torch_tensor(fparam.reshape(nframes, self.get_dim_fparam())) + if fparam is not None + else None + ) + aparam_input = ( + to_torch_tensor(aparam.reshape(nframes, natoms, self.get_dim_aparam())) + if aparam is not None + else None + ) + charge_spin_input = ( + to_torch_tensor(charge_spin.reshape(nframes, 2)) + if charge_spin is not None + else None + ) + out = model.forward_embedding( + coord_input, + type_input, + box=box_input, + fparam=fparam_input, + aparam=aparam_input, + charge_spin=charge_spin_input, + ) + + def cast_output(value: torch.Tensor) -> np.ndarray: + value = value.detach() + if dtype != "native": + value = value.to(_EMBEDDING_DTYPE_TO_TORCH[dtype]) + return value.cpu().numpy() + + # Single output-precision boundary: the model produces embeddings in its + # native precision, and this API chooses the emitted dtype. + return ( + cast_output(out["descriptor"]), + cast_output(out["atomic_feature"]), + cast_output(out["structural_feature"]), + ) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 8605db9359..e700a49b32 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -418,6 +418,35 @@ def forward( charge_spin=charge_spin, ) + def has_embedding(self) -> bool: + """Whether this atomic model can produce ``forward_embedding`` outputs. + + False for atomic models without a descriptor-fitting pair (e.g. a pure + tabulated pair potential); linear combinations report True when any + sub-model supports it. + """ + return False + + def forward_embedding( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Extract model embeddings; only implemented by descriptor-fitting models. + + Defined here so the model-level ``forward_embedding`` (and TorchScript) + always resolves the call; atomic models without a fitting net (e.g. a + pure tabulated pair potential) inherit this guard. + """ + raise NotImplementedError( + "forward_embedding is not supported for this atomic model." + ) + def change_type_map( self, type_map: list[str], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index d59d518cab..a5ce444fd3 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -69,44 +69,6 @@ def __init__( self.add_chg_spin_ebd: bool = getattr( self.descriptor, "add_chg_spin_ebd", False ) - self.enable_eval_descriptor_hook = False - self.enable_eval_fitting_last_layer_hook = False - self.eval_descriptor_list = [] - self.eval_fitting_last_layer_list = [] - - eval_descriptor_list: list[torch.Tensor] - eval_fitting_last_layer_list: list[torch.Tensor] - - def set_eval_descriptor_hook(self, enable: bool) -> None: - """Set the hook for evaluating descriptor and clear the cache for descriptor list.""" - self.enable_eval_descriptor_hook = enable - # = [] does not work; See #4533 - self.eval_descriptor_list.clear() - - def eval_descriptor(self) -> torch.Tensor: - """Evaluate the descriptor.""" - if not self.eval_descriptor_list: - raise RuntimeError( - "eval_descriptor_list is empty. " - "Call set_eval_descriptor_hook(True) and perform a forward pass first." - ) - return torch.concat(self.eval_descriptor_list) - - def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: - """Set the hook for evaluating fitting last layer output and clear the cache for fitting last layer output list.""" - self.enable_eval_fitting_last_layer_hook = enable - self.fitting_net.set_return_middle_output(enable) - # = [] does not work; See #4533 - self.eval_fitting_last_layer_list.clear() - - def eval_fitting_last_layer(self) -> torch.Tensor: - """Evaluate the fitting last layer output.""" - if not self.eval_fitting_last_layer_list: - raise RuntimeError( - "eval_fitting_last_layer_list is empty. " - "Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first." - ) - return torch.concat(self.eval_fitting_last_layer_list) @torch.jit.export def fitting_output_def(self) -> FittingOutputDef: @@ -245,6 +207,7 @@ def forward_atomic( aparam: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, charge_spin: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: """Return atomic prediction. @@ -262,16 +225,28 @@ def forward_atomic( frame parameter. nf x ndf aparam atomic parameter. nf x nloc x nda + return_atomic_feature + When True, run the fitting net only up to its last hidden layer with + no force/virial autograd, and additionally return the raw per-atom + ``descriptor`` and the last hidden ``atomic_feature``. Used by the + embedding path. Returns ------- result_dict - the result dict, defined by the `FittingOutputDef`. + the result dict, defined by the `FittingOutputDef`. When + ``return_atomic_feature`` is True, it also contains ``descriptor`` + and ``atomic_feature``. """ nframes, nloc, nnei = nlist.shape atype = extended_atype[:, :nloc] - if (self.do_grad_r() or self.do_grad_c()) and not extended_coord.requires_grad: + # The embedding path produces no force and never allocates an autograd leaf. + if ( + not return_atomic_feature + and (self.do_grad_r() or self.do_grad_c()) + and not extended_coord.requires_grad + ): extended_coord = extended_coord.clone().requires_grad_(True) # Handle default chg_spin if descriptor supports it @@ -290,9 +265,19 @@ def forward_atomic( charge_spin=charge_spin if self.add_chg_spin_ebd else None, ) assert descriptor is not None - if self.enable_eval_descriptor_hook: - self.eval_descriptor_list.append(descriptor.detach()) - # energy, force + if return_atomic_feature: + fit_ret = self.fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + return_atomic_feature=True, + ) + fit_ret["descriptor"] = descriptor + return fit_ret fit_ret = self.fitting_net( descriptor, atype, @@ -302,15 +287,85 @@ def forward_atomic( fparam=fparam, aparam=aparam, ) - if self.enable_eval_fitting_last_layer_hook: - assert "middle_output" in fit_ret, ( - "eval_fitting_last_layer not supported for this fitting net!" - ) - self.eval_fitting_last_layer_list.append( - fit_ret.pop("middle_output").detach() - ) return fit_ret + def has_embedding(self) -> bool: + """A standard descriptor-fitting atomic model supports embeddings.""" + return True + + def forward_embedding( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Extract embeddings, reusing the descriptor and fitting forward. + + The neighbor/type masking mirrors `forward_common_atomic` so that the + descriptor matches the energy forward, and the heavy descriptor and + fitting work is delegated to `forward_atomic` with + ``return_atomic_feature=True``. + + Parameters + ---------- + extended_coord + Extended coordinates with shape (nf, nall, 3). + extended_atype + Extended atom types with shape (nf, nall). + nlist + Neighbor list with shape (nf, nloc, nnei). + mapping + Extended-to-local index map with shape (nf, nall), or None. + fparam + Frame parameters with shape (nf, dim_fparam), or None. + aparam + Atomic parameters with shape (nf, nloc, dim_aparam), or None. + charge_spin + Frame-level charge and spin conditions with shape (nf, 2), or None. + + Returns + ------- + dict[str, torch.Tensor] + ``descriptor`` with shape (nf, nloc, d), ``atomic_feature`` (the last + fitting hidden activation) with shape (nf, nloc, h), and + ``structural_feature`` (the masked atom-sum of ``atomic_feature``) + with shape (nf, h). + """ + _, nloc, _ = nlist.shape + # Original local types drive the output mask; masked types feed the nets. + atype = extended_atype[:, :nloc] + if self.pair_excl is not None: + pair_mask = self.pair_excl(nlist, extended_atype) + nlist = torch.where(pair_mask == 1, nlist, -1) + ext_atom_mask = self.make_atom_mask(extended_atype) + fit_ret = self.forward_atomic( + extended_coord, + torch.where(ext_atom_mask, extended_atype, 0), + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + return_atomic_feature=True, + ) + atomic_feature = fit_ret["atomic_feature"] + # nf x nloc + atom_mask = ext_atom_mask[:, :nloc].to(torch.int32) + if self.atom_excl is not None: + atom_mask = atom_mask * self.atom_excl(atype) + structural_feature = ( + atomic_feature * atom_mask[:, :, None].to(atomic_feature.dtype) + ).sum(dim=1) + return { + "descriptor": fit_ret["descriptor"], + "atomic_feature": atomic_feature, + "structural_feature": structural_feature, + } + def compute_or_load_stat( self, sampled_func: Callable[[], list[dict]], diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index 6c620f6f5b..245341a460 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -271,12 +271,14 @@ def forward_atomic( raw_nlists = [ nlists[get_multiple_nlist_key(rcut, sel)] - for rcut, sel in zip(self.get_model_rcuts(), self.get_model_nsels()) + for rcut, sel in zip( + self.get_model_rcuts(), self.get_model_nsels(), strict=True + ) ] nlists_ = [ nl if mt else nlist_distinguish_types(nl, extended_atype, sel) for mt, nl, sel in zip( - self.mixed_types_list, raw_nlists, self.get_model_sels() + self.mixed_types_list, raw_nlists, self.get_model_sels(), strict=True ) ] ener_list = [] @@ -306,6 +308,66 @@ def forward_atomic( } # (nframes, nloc, 1) return fit_ret + def has_embedding(self) -> bool: + """A linear model supports embeddings if any sub-model does.""" + for model in self.models: + if model.has_embedding(): + return True + return False + + def forward_embedding( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Return the embedding of the first descriptor-fitting sub-model. + + A linear/ZBL combination carries one learned descriptor-fitting model + (e.g. the DP part of a DP+ZBL model) whose embedding is well defined; the + per-sub-model neighbor lists are rebuilt exactly as in `forward_atomic`. + """ + nframes, nloc, nnei = nlist.shape + extended_coord = extended_coord.view(nframes, -1, 3) + sorted_rcuts, sorted_sels = self._sort_rcuts_sels() + nlists = build_multiple_neighbor_list( + extended_coord.detach(), + nlist, + sorted_rcuts, + sorted_sels, + ) + raw_nlists = [ + nlists[get_multiple_nlist_key(rcut, sel)] + for rcut, sel in zip( + self.get_model_rcuts(), self.get_model_nsels(), strict=True + ) + ] + nlists_ = [ + nl if mt else nlist_distinguish_types(nl, extended_atype, sel) + for mt, nl, sel in zip( + self.mixed_types_list, raw_nlists, self.get_model_sels(), strict=True + ) + ] + for i, model in enumerate(self.models): + if model.has_embedding(): + type_map_model = self.mapping_list[i].to(extended_atype.device) + return model.forward_embedding( + extended_coord, + type_map_model[extended_atype], + nlists_[i], + mapping, + fparam, + aparam, + charge_spin, + ) + raise NotImplementedError( + "This linear model has no embedding-capable sub-model." + ) + def apply_out_stat( self, ret: dict[str, torch.Tensor], diff --git a/deepmd/pt/model/atomic_model/sezm_atomic_model.py b/deepmd/pt/model/atomic_model/sezm_atomic_model.py index e96cd6b761..d27dad8b38 100644 --- a/deepmd/pt/model/atomic_model/sezm_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sezm_atomic_model.py @@ -498,23 +498,6 @@ def fitting_output_def(self) -> FittingOutputDef: return super().fitting_output_def() return active_fitting.output_def() - def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: - """ - Set the fitting-last-layer evaluation hook for the active fitting path. - - Parameters - ---------- - enable - Whether to enable the hook. - """ - self.enable_eval_fitting_last_layer_hook = enable - active_fitting = self.get_active_fitting_net() - if active_fitting is not None and hasattr( - active_fitting, "set_return_middle_output" - ): - active_fitting.set_return_middle_output(enable) - self.eval_fitting_last_layer_list.clear() - def change_type_map( self, type_map: list[str], diff --git a/deepmd/pt/model/descriptor/sezm_nn/dens.py b/deepmd/pt/model/descriptor/sezm_nn/dens.py index e08c6bccf7..8dcf872188 100644 --- a/deepmd/pt/model/descriptor/sezm_nn/dens.py +++ b/deepmd/pt/model/descriptor/sezm_nn/dens.py @@ -475,7 +475,6 @@ def __init__( self.trainable = copy.deepcopy(trainable) self.atom_ener = atom_ener self.use_aparam_as_mask = bool(use_aparam_as_mask) - self._return_middle_output = False self.has_force_embedding_latent = self.condition_lmax >= 1 self.has_vector_latent = self.latent_lmax >= 1 trainable_flag = ( @@ -584,11 +583,6 @@ def get_sel_type(self) -> list[int]: """Return selected atom types of the energy branch.""" return self.energy_head.get_sel_type() - def set_return_middle_output(self, enable: bool) -> None: - """Enable or disable forwarding of the scalar energy hidden activations.""" - self._return_middle_output = bool(enable) - self.energy_head.set_return_middle_output(enable) - def build_force_embedding( self, force_input: torch.Tensor, @@ -703,8 +697,6 @@ def forward( "energy": energy_ret["energy"], "dforce": mixed_force.to(dtype=descriptor.dtype), } - if "middle_output" in energy_ret: - result["middle_output"] = energy_ret["middle_output"] if return_components: result["clean_dforce"] = clean_force result["denoising_dforce"] = denoising_force diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index a8b5b55584..264c01e8ee 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -1,7 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import torch - from deepmd.pt.model.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -52,23 +50,3 @@ def get_fitting_net(self): # noqa: ANN201 def get_descriptor(self): # noqa: ANN201 """Get the descriptor.""" return self.atomic_model.descriptor - - @torch.jit.export - def set_eval_descriptor_hook(self, enable: bool) -> None: - """Set the hook for evaluating descriptor and clear the cache for descriptor list.""" - self.atomic_model.set_eval_descriptor_hook(enable) - - @torch.jit.export - def eval_descriptor(self) -> torch.Tensor: - """Evaluate the descriptor.""" - return self.atomic_model.eval_descriptor() - - @torch.jit.export - def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: - """Set the hook for evaluating fitting_last_layer and clear the cache for fitting_last_layer list.""" - self.atomic_model.set_eval_fitting_last_layer_hook(enable) - - @torch.jit.export - def eval_fitting_last_layer(self) -> torch.Tensor: - """Evaluate the fitting_last_layer.""" - return self.atomic_model.eval_fitting_last_layer() diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 6f5e347e68..465ff0af19 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -216,6 +216,88 @@ def forward_common( model_predict = self._output_type_cast(model_predict, input_prec) return model_predict + @torch.jit.export + def forward_embedding( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Extract embeddings in a single forward, without force/virial autograd. + + One descriptor and fitting forward yields the per-atom descriptor, the + per-atom last fitting hidden activation, and the per-structure pooled + feature (the masked atom-sum of the atomic feature). The neighbor + list is built exactly as in `forward_common`, so the descriptor and + atomic feature match the energy forward. + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc*3) or (nf, nloc, 3). + atype + Atom types with shape (nf, nloc). + box + Simulation box with shape (nf, 9), or None. + fparam + Frame parameters with shape (nf, ndf), or None. + aparam + Atomic parameters with shape (nf, nloc, nda), or None. + charge_spin + Frame-level charge and spin conditions with shape (nf, 2), or None. + + Returns + ------- + dict[str, torch.Tensor] + ``descriptor`` with shape (nf, nloc, d), ``atomic_feature`` with + shape (nf, nloc, h), and ``structural_feature`` with shape + (nf, h), in the model's native precision. The DeepEval embedding + API casts these to the requested output dtype (float32 by + default). + + Raises + ------ + RuntimeError + If called in training mode; call ``model.eval()`` first. + """ + if self.training: + raise RuntimeError( + "Embedding extraction requires eval mode; call model.eval() first." + ) + cc, bb, fp, ap, _ = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + del coord, box, fparam, aparam + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + self.get_rcut(), + self.get_sel(), + # types are distinguished by `format_nlist` below when needed + mixed_types=True, + box=bb, + ) + extended_coord = extended_coord.view(extended_atype.shape[0], -1, 3) + nlist = self.format_nlist(extended_coord, extended_atype, nlist) + with torch.no_grad(): + return self.atomic_model.forward_embedding( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fp, + aparam=ap, + charge_spin=charge_spin, + ) + def get_out_bias(self) -> torch.Tensor: return self.atomic_model.get_out_bias() diff --git a/deepmd/pt/model/model/sezm_model.py b/deepmd/pt/model/model/sezm_model.py index a369fd028a..6537f9d84e 100644 --- a/deepmd/pt/model/model/sezm_model.py +++ b/deepmd/pt/model/model/sezm_model.py @@ -650,15 +650,8 @@ def _sezm_structure_key(model: SeZMModel) -> tuple[Any, ...]: descriptor.inner_clamp_r_outer, int(descriptor.get_dim_chg_spin()), ) - fitting_state = ( - _int_tuple(fitting.exclude_types), - bool(fitting.eval_return_middle_output), - ) - atomic_state = ( - _int_tuple(atomic_model.atom_exclude_types), - bool(atomic_model.enable_eval_descriptor_hook), - bool(atomic_model.enable_eval_fitting_last_layer_hook), - ) + fitting_state = (_int_tuple(fitting.exclude_types),) + atomic_state = (_int_tuple(atomic_model.atom_exclude_types),) model_state = ( str(model.bridging_method), model.inter_potential is not None, @@ -729,6 +722,8 @@ def __init__( # full, or EMA full -- therefore reuses cached compile products # instead of evicting the other mode. object.__setattr__(self, "compiled_core_compute_cache", {}) + object.__setattr__(self, "compiled_embedding", None) + object.__setattr__(self, "_embedding_task_buf_order", None) object.__setattr__(self, "compiled_dens_compute", None) # Maps cache_key -> task_buf_order for this instance so forward() # knows which buffers to pass and in what order. @@ -881,6 +876,69 @@ def forward( model_predict["updated_coord"] += coord return model_predict + def forward_embedding( + self, + coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], + atype: Int[Tensor, "nf nloc"], + box: Float[Tensor, "nf 9"] | None = None, + fparam: Float[Tensor, "nf ndf"] | None = None, + aparam: Float[Tensor, "nf nloc nda"] | None = None, + charge_spin: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """ + Extract embeddings in a single forward, without force/virial autograd. + + Reuses the standard ``ener`` neighbor-list and descriptor path and honors + the ``DP_COMPILE_INFER`` compile setting through a dedicated embedding + graph cache. + + Parameters + ---------- + coord + Coordinates with shape (nf, nloc*3) or (nf, nloc, 3) in Å. + atype + Atom types with shape (nf, nloc). + box + Box tensor with shape (nf, 9) in Å, or None. + fparam + Frame parameters with shape (nf, ndf) or None. + aparam + Atomic parameters with shape (nf, nloc, nda) or None. + charge_spin + Frame-level charge and spin conditions with shape (nf, 2). + + Returns + ------- + dict[str, torch.Tensor] + ``descriptor`` with shape (nf, nloc, d), ``atomic_feature`` with + shape (nf, nloc, h), and ``structural_feature`` with shape (nf, h). + + Raises + ------ + NotImplementedError + If the model is not in the ``ener`` execution mode. + RuntimeError + If called in training mode; call ``model.eval()`` first. + """ + if self.get_active_mode() != "ener": + raise NotImplementedError( + "Embedding extraction is only supported in the SeZM `ener` mode." + ) + if self.training: + raise RuntimeError( + "Embedding extraction requires eval mode; call model.eval() first." + ) + with torch.no_grad(): + return self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + charge_spin=charge_spin, + embedding_only=True, + ) + def forward_common( self, coord: Float[Tensor, "nf nloc 3"] | Float[Tensor, "nf nloc_x3"], @@ -892,6 +950,7 @@ def forward_common( force_input: Float[Tensor, "nf nloc 3"] | None = None, noise_mask: torch.Tensor | None = None, charge_spin: torch.Tensor | None = None, + embedding_only: bool = False, ) -> dict[str, torch.Tensor]: """ Return model prediction using standard neighbor list. @@ -975,6 +1034,7 @@ def forward_common( aparam=ap, charge_spin=charge_spin, input_prec=input_prec, + embedding_only=embedding_only, ) def forward_common_lower( @@ -992,6 +1052,7 @@ def forward_common_lower( charge_spin: torch.Tensor | None = None, input_prec: torch.dtype | None = None, use_compile: bool | None = None, + embedding_only: bool = False, ) -> dict[str, torch.Tensor]: """ Run the conservative SeZM lower interface on explicit edge vectors. @@ -1045,25 +1106,44 @@ def forward_common_lower( device=coord.device, ) has_coord_corr = extended_coord_corr is not None - cache_key = (bool(self.training), has_coord_corr) - if cache_key not in self.compiled_core_compute_cache: - self.trace_and_compile( - coord, - atype, - edge_index, - edge_vec, - edge_scatter_index, - edge_mask, - fp, - ap, - charge_spin, - extended_coord_corr=extended_coord_corr, - ) - compiled_core_compute = self.compiled_core_compute_cache[cache_key] - task_buf_vals = get_task_buffer_values( - self, - self._task_buf_order_cache[cache_key], - ) + if embedding_only: + # Eval-only graph: a single slot, with no + # (training, coord_corr) key and no cross-task sharing. + if self.compiled_embedding is None: + self.trace_and_compile( + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fp, + ap, + charge_spin, + embedding_only=True, + ) + cache_key = None + compiled_core_compute = self.compiled_embedding + task_buf_order = self._embedding_task_buf_order + else: + cache_key = (bool(self.training), has_coord_corr) + if cache_key not in self.compiled_core_compute_cache: + self.trace_and_compile( + coord, + atype, + edge_index, + edge_vec, + edge_scatter_index, + edge_mask, + fp, + ap, + charge_spin, + extended_coord_corr=extended_coord_corr, + ) + compiled_core_compute = self.compiled_core_compute_cache[cache_key] + task_buf_order = self._task_buf_order_cache[cache_key] + assert task_buf_order is not None + task_buf_vals = get_task_buffer_values(self, task_buf_order) grad_ctx: Any = nullcontext() if self.training else torch.no_grad() with nvtx_range("SeZM/core_compute"), grad_ctx: if extended_coord_corr is None: @@ -1120,6 +1200,7 @@ def forward_common_lower( aparam=ap, charge_spin=charge_spin, extended_coord_corr=extended_coord_corr, + embedding_only=embedding_only, ) return self._output_type_cast(model_predict, input_prec) @@ -1261,6 +1342,7 @@ def core_compute( charge_spin: torch.Tensor | None = None, comm_dict: dict[str, torch.Tensor] | None = None, extended_coord_corr: torch.Tensor | None = None, + embedding_only: bool = False, ) -> dict[str, torch.Tensor]: """ Compute SeZM lower outputs from the unified edge-vector schema. @@ -1295,6 +1377,9 @@ def core_compute( extended_coord_corr Coordinates correction for virial with shape ``(nf, nscatter, 3)`` or ``None``. + embedding_only + When ``True``, return only the embedding outputs and skip the + force/virial autograd entirely. Returns ------- @@ -1302,7 +1387,9 @@ def core_compute( DeePMD lower-style outputs (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu, mask). The per-atom virial (energy_derv_c) is always produced; callers decide whether to keep - it. + it. When ``embedding_only`` is ``True``, instead returns + ``descriptor`` (nf, nloc, d), ``atomic_feature`` (nf, nloc, h), and + ``structural_feature`` (nf, h). """ del comm_dict nf, nloc = atype.shape[:2] @@ -1315,8 +1402,10 @@ def core_compute( # SeZM differentiates only the pure map ``(edge_vec, theta) -> E``. # This keeps coordinate gathering and shift application outside the # differentiated region while preserving conservative forces through the - # scatter indices below. - edge_vec = edge_vec.detach().requires_grad_(True) + # scatter indices below. The embedding path produces no force, so it + # keeps ``edge_vec`` detached and never allocates an autograd leaf. + if not embedding_only: + edge_vec = edge_vec.detach().requires_grad_(True) # === Step 2. Descriptor forward === with nvtx_range("SeZM/descriptor"): @@ -1328,31 +1417,47 @@ def core_compute( edge_mask=edge_mask, charge_spin=charge_spin, ) - if self.atomic_model.enable_eval_descriptor_hook: - self.atomic_model.eval_descriptor_list.append(descriptor.detach()) - # === Step 3. Fitting net + output statistics === + # === Atom mask === + atom_mask = self.atomic_model.make_atom_mask(atype).to(torch.int32) + if self.atomic_model.atom_excl is not None: + atom_mask = atom_mask * self.atomic_model.atom_excl(atype) + + # === Step 3. Fitting net === + # The same fitting forward serves both modes; ``embedding_only`` only asks + # it to also return the last hidden activation. with nvtx_range("SeZM/fitting_net"): fit_ret = self.atomic_model.fitting_net( descriptor, atype, fparam=fparam, aparam=aparam, + return_atomic_feature=embedding_only, ) - if self.atomic_model.enable_eval_fitting_last_layer_hook: - assert "middle_output" in fit_ret, ( - "eval_fitting_last_layer not supported for this fitting net!" - ) - self.atomic_model.eval_fitting_last_layer_list.append( - fit_ret.pop("middle_output").detach() - ) + + # === Embedding short circuit === + # The embedding path returns three plain forward outputs: the per-atom + # descriptor, the per-atom last hidden activation, and the + # structure-level pooled feature (the masked atom-sum of the last hidden + # activation). All force/virial autograd below is skipped; the outputs + # stay in native precision and are cast to float32 by the DeepEval + # embedding API. + if embedding_only: + atomic_feature = fit_ret["atomic_feature"] + structural_feature = ( + atomic_feature * atom_mask[:, :, None].to(atomic_feature.dtype) + ).sum(dim=1) + return { + "descriptor": descriptor, + "atomic_feature": atomic_feature, + "structural_feature": structural_feature, + } + + # === Step 3b. Output statistics === with nvtx_range("SeZM/apply_out_stat"): fit_ret = self.atomic_model.apply_out_stat(fit_ret, atype) # === Step 4. Apply atom mask === - atom_mask = self.atomic_model.make_atom_mask(atype).to(torch.int32) - if self.atomic_model.atom_excl is not None: - atom_mask *= self.atomic_model.atom_excl(atype) for key in fit_ret.keys(): out_shape = fit_ret[key].shape flat_dim = 1 @@ -1492,8 +1597,6 @@ def core_compute_dens( force_embedding=force_embedding, charge_spin=charge_spin, ) - if self.atomic_model.enable_eval_descriptor_hook: - self.atomic_model.eval_descriptor_list.append(descriptor.detach()) # === Step 4. Dens fitting net === with nvtx_range("SeZM/dens_fitting_net"): @@ -1506,13 +1609,6 @@ def core_compute_dens( aparam=aparam, return_components=True, ) - if self.atomic_model.enable_eval_fitting_last_layer_hook: - assert "middle_output" in fit_ret, ( - "eval_fitting_last_layer not supported for this fitting net!" - ) - self.atomic_model.eval_fitting_last_layer_list.append( - fit_ret.pop("middle_output").detach() - ) return torch.cat( [ fit_ret["energy"], @@ -1647,6 +1743,7 @@ def trace_and_compile( ap: torch.Tensor, charge_spin: torch.Tensor, extended_coord_corr: torch.Tensor | None = None, + embedding_only: bool = False, ) -> None: """Trace ``core_compute()`` with ``make_fx`` and cache the compiled callable. @@ -1676,7 +1773,7 @@ def trace_and_compile( structure_key = _sezm_structure_key(self) cache_key = (bool(self.training), has_coord_corr) full_cache_key = structure_key + cache_key - if full_cache_key in _SEZM_COMPILE_CACHE: + if not embedding_only and full_cache_key in _SEZM_COMPILE_CACHE: self.compiled_core_compute_cache[cache_key] = _SEZM_COMPILE_CACHE[ full_cache_key ] @@ -1795,6 +1892,7 @@ def compute_fn( fparam=fp, aparam=ap, charge_spin=charge_spin, + embedding_only=embedding_only, ) finally: _restore_task_bufs(_saved) @@ -1830,6 +1928,7 @@ def compute_fn( # type: ignore[misc] aparam=ap, charge_spin=charge_spin, extended_coord_corr=extended_coord_corr, + embedding_only=embedding_only, ) finally: _restore_task_bufs(_saved) @@ -2056,6 +2155,17 @@ def _inductor_inference_compiler( def compiled(*args: Any, _fn: Any = _compiled_flat) -> dict[str, Any]: return dict(zip(_keys, _fn(*args))) + # The embedding graph is eval-only with a single slot (cache key + # ``None``) and is not shared across tasks. It reuses the pending-compile + # timer so its compile time is logged after the first compiled call, + # robust whether the AOT lowering is eager or lazy. + if embedding_only: + object.__setattr__(self, "compiled_embedding", compiled) + object.__setattr__(self, "_embedding_task_buf_order", task_buf_names) + self._core_compute_pending_compile_t0 = _compile_t0 + self._core_compute_pending_compile_key = None + return + # Populate both per-instance and module-level shared caches. # The shared cache (_SEZM_COMPILE_CACHE) lets a second task with the # same structure key skip re-tracing and re-compiling entirely. @@ -2666,8 +2776,11 @@ def reset_head_for_mode(self, mode: str) -> None: self._core_compute_pending_compile_t0 = None self._core_compute_pending_compile_key = None # Drop every compile slot so the next forward retraces against the - # reinitialised fitting head. + # reinitialised fitting head. The embedding graph reads the same + # fitting head, so it is invalidated together with the energy graph. self.compiled_core_compute_cache.clear() + object.__setattr__(self, "compiled_embedding", None) + object.__setattr__(self, "_embedding_task_buf_order", None) # ========================================================================= # Bridging Helpers diff --git a/deepmd/pt/model/task/dipole.py b/deepmd/pt/model/task/dipole.py index 6e9b0340ff..245903a9c5 100644 --- a/deepmd/pt/model/task/dipole.py +++ b/deepmd/pt/model/task/dipole.py @@ -187,22 +187,34 @@ def forward( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: nframes, nloc, _ = descriptor.shape assert gr is not None, "Must provide the rotation matrix for dipole fitting." # cast the input to internal precsion gr = gr.to(self.prec) + fit_ret = self._forward_common( + descriptor, + atype, + gr, + g2, + h2, + fparam, + aparam, + return_atomic_feature=return_atomic_feature, + ) # (nframes, nloc, m1) - out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + out = fit_ret[self.var_name] # (nframes * nloc, 1, m1) out = out.view(-1, 1, self.embedding_width) # (nframes * nloc, m1, 3) gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes, nloc, 3) out = torch.bmm(out, gr).squeeze(-2).view(nframes, nloc, 3) - return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + result = {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + if return_atomic_feature: + result["atomic_feature"] = fit_ret["atomic_feature"] + return result # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 39dabf99f8..2aec091e65 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -204,6 +204,7 @@ def forward( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> tuple[torch.Tensor, None]: """Based on embedding net output, alculate total energy. @@ -215,6 +216,10 @@ def forward( ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ + if return_atomic_feature: + raise NotImplementedError( + "EnergyFittingNetDirect does not expose an atomic feature." + ) nframes, nloc, _ = inputs.size() if self.use_tebd: # if atype_tebd is not None: diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 439d3d11d9..0bc322f331 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -548,8 +548,6 @@ def __init__( for param in self.parameters(): param.requires_grad = self.trainable - self.eval_return_middle_output = False - def reinit_exclude( self, exclude_types: list[int] = [], @@ -681,9 +679,6 @@ def set_case_embd(self, case_idx: int) -> None: case_idx ] - def set_return_middle_output(self, return_middle_output: bool = True) -> None: - self.eval_return_middle_output = return_middle_output - def __setitem__(self, key: str, value: torch.Tensor) -> None: if key in ["bias_atom_e"]: value = value.view([self.ntypes, self._net_out_dim()]) @@ -745,6 +740,7 @@ def _forward_common( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: # cast the input to internal precsion xx = descriptor.to(self.prec) @@ -846,8 +842,8 @@ def _forward_common( if self.mixed_types: atom_property = self.filter_layers.networks[0](xx) - if self.eval_return_middle_output: - results["middle_output"] = self.filter_layers.networks[ + if return_atomic_feature: + results["atomic_feature"] = self.filter_layers.networks[ 0 ].call_until_last(xx) if xx_zeros is not None: @@ -856,23 +852,28 @@ def _forward_common( outs + atom_property + self.bias_atom_e[atype].to(self.prec) ) # Shape is [nframes, natoms[0], net_dim_out] else: - if self.eval_return_middle_output: - outs_middle = torch.zeros( - (nf, nloc, self.neuron[-1]), - dtype=self.prec, - device=descriptor.device, - ) # jit assertion + if return_atomic_feature: + # Each atom carries the last hidden activation of its own type + # network, gathered by summing the type-masked contributions. + atomic_feature_type: torch.Tensor = self.filter_layers.networks[ + 0 + ].call_until_last(xx) + mask = (atype == 0).unsqueeze(-1) + atomic_feature = torch.where( + mask, + atomic_feature_type, + torch.zeros_like(atomic_feature_type), + ) for type_i, ll in enumerate(self.filter_layers.networks): - mask = (atype == type_i).unsqueeze(-1) - mask = torch.tile(mask, (1, 1, net_dim_out)) - middle_output_type = ll.call_until_last(xx) - middle_output_type = torch.where( - torch.tile(mask, (1, 1, self.neuron[-1])), - middle_output_type, - 0.0, - ) - outs_middle = outs_middle + middle_output_type - results["middle_output"] = outs_middle + if type_i > 0: + mask = (atype == type_i).unsqueeze(-1) + atomic_feature_type = ll.call_until_last(xx) + atomic_feature = atomic_feature + torch.where( + mask, + atomic_feature_type, + torch.zeros_like(atomic_feature_type), + ) + results["atomic_feature"] = atomic_feature for type_i, ll in enumerate(self.filter_layers.networks): mask = (atype == type_i).unsqueeze(-1) mask = torch.tile(mask, (1, 1, net_dim_out)) diff --git a/deepmd/pt/model/task/invar_fitting.py b/deepmd/pt/model/task/invar_fitting.py index a8953fcd2b..584915321b 100644 --- a/deepmd/pt/model/task/invar_fitting.py +++ b/deepmd/pt/model/task/invar_fitting.py @@ -174,27 +174,33 @@ def forward( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: """Based on embedding net output, alculate total energy. Args: - inputs: Embedding matrix. Its shape is [nframes, natoms[0], self.dim_descrpt]. - natoms: Tell atom count and element count. Its shape is [2+self.ntypes]. + - return_atomic_feature: also return the last hidden activation under the + ``atomic_feature`` key. Returns ------- - `torch.Tensor`: Total energy with shape [nframes, natoms[0]]. """ - out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam) + out = self._forward_common( + descriptor, + atype, + gr, + g2, + h2, + fparam, + aparam, + return_atomic_feature=return_atomic_feature, + ) result = {self.var_name: out[self.var_name].to(env.GLOBAL_PT_FLOAT_PRECISION)} - if "middle_output" in out: - result.update( - { - "middle_output": out["middle_output"].to( - env.GLOBAL_PT_FLOAT_PRECISION - ) - } - ) + if return_atomic_feature: + result["atomic_feature"] = out["atomic_feature"] return result # make jit happy with torch 2.0.0 diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index c3a7ed52a1..3d463a7723 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -235,6 +235,7 @@ def forward( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: nframes, nloc, _ = descriptor.shape assert gr is not None, ( @@ -242,10 +243,18 @@ def forward( ) # cast the input to internal precsion gr = gr.to(self.prec) + fit_ret = self._forward_common( + descriptor, + atype, + gr, + g2, + h2, + fparam, + aparam, + return_atomic_feature=return_atomic_feature, + ) # (nframes, nloc, _net_out_dim) - out = self._forward_common(descriptor, atype, gr, g2, h2, fparam, aparam)[ - self.var_name - ] + out = fit_ret[self.var_name] out = out * (self.scale.to(atype.device).to(self.prec))[atype] gr = gr.view(nframes * nloc, self.embedding_width, 3) # (nframes * nloc, m1, 3) @@ -261,7 +270,10 @@ def forward( "bim,bmj->bij", gr.transpose(1, 2), out ) # (nframes * nloc, 3, 3) out = out.view(nframes, nloc, 3, 3) - return {"polarizability": out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + result = {"polarizability": out.to(env.GLOBAL_PT_FLOAT_PRECISION)} + if return_atomic_feature: + result["atomic_feature"] = fit_ret["atomic_feature"] + return result # make jit happy with torch 2.0.0 exclude_types: list[int] diff --git a/deepmd/pt/model/task/sezm_ener.py b/deepmd/pt/model/task/sezm_ener.py index c6af12fb5f..0932ec7086 100644 --- a/deepmd/pt/model/task/sezm_ener.py +++ b/deepmd/pt/model/task/sezm_ener.py @@ -645,6 +645,7 @@ def _forward_common( h2: torch.Tensor | None = None, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: """Run the SeZM fitting path with optional case FiLM.""" if not self.case_film_embd: @@ -656,8 +657,15 @@ def _forward_common( h2, fparam, aparam, + return_atomic_feature=return_atomic_feature, ) - return self._forward_case_film(descriptor, atype, fparam, aparam) + return self._forward_case_film( + descriptor, + atype, + fparam, + aparam, + return_atomic_feature=return_atomic_feature, + ) def _forward_case_film( self, @@ -665,6 +673,7 @@ def _forward_case_film( atype: torch.Tensor, fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, + return_atomic_feature: bool = False, ) -> dict[str, torch.Tensor]: """ Forward path for SeZM case FiLM. @@ -679,6 +688,9 @@ def _forward_case_film( Frame parameters with shape (nf, numb_fparam). aparam Atomic parameters with shape (nf, nloc, numb_aparam). + return_atomic_feature + When True, also return the last hidden activation under the + ``atomic_feature`` key. Returns ------- @@ -752,8 +764,8 @@ def _forward_case_film( fitting = self.filter_layers.networks[0] atom_property = fitting(xx, self.case_embd) - if self.eval_return_middle_output: - results["middle_output"] = fitting.call_until_last(xx, self.case_embd) + if return_atomic_feature: + results["atomic_feature"] = fitting.call_until_last(xx, self.case_embd) if xx_zeros is not None: atom_property -= fitting(xx_zeros, self.case_embd) outs = outs + atom_property + self.bias_atom_e[atype].to(self.prec) diff --git a/doc/inference/embedding.md b/doc/inference/embedding.md new file mode 100644 index 0000000000..b4509c74ee --- /dev/null +++ b/doc/inference/embedding.md @@ -0,0 +1,92 @@ +# Model embeddings + +A trained model can export learned representations ("embeddings") for downstream +analysis, such as clustering, visualization, or training auxiliary models. A +single forward pass produces the embeddings without computing forces or virials. + +:::{note} +**Supported backends**: PyTorch {{ pytorch_icon }}, for energy models (including +DPA4/SeZM and DP+ZBL / linear combinations, where the embedding comes from the +descriptor-fitting sub-model). It also works for other descriptor-fitting models +(dipole, polarizability, dos, property), though the `structural_feature` is only +physically meaningful for energy models. Spin models are not supported. +::: + +Three embeddings are produced for each frame: + +- `descriptor`: the per-atom local-environment representation, with shape + (nframes, natoms, dim_descriptor). +- `atomic_feature`: the per-atom activation after the last fitting hidden layer + (before the final output projection), with shape (nframes, natoms, dim_hidden). +- `structural_feature`: a per-structure summary obtained by summing + `atomic_feature` over the atoms of each frame, with shape + (nframes, dim_hidden). + +## Command line + +The embeddings of a model can be evaluated and saved using `dp embed`. A +typical usage is + +```bash +dp embed -m model.ckpt.pt -s /path/to/system -o embedding.hdf5 +``` + +where `-m` gives the model (a training checkpoint `.pt`, or a frozen `.pth` for +standard energy models; SeZM/DPA4 only supports the `.pt` checkpoint), `-s` the +path to the system directory (or `-f` for a datafile listing system directories, +one per line), and `-o` the output HDF5 file. Use `--dtype` to choose the output +precision (`fp32`, `fp64`, or `native`; default `fp32`). Reading from a +multi-task model additionally accepts `--head` to select the model branch. + +Several other command line options can be passed to `dp embed`, which can be +checked with + +```bash +dp embed --help +``` + +## Output format + +The output is a single HDF5 file. Each system is stored as a group named after +the system directory, with the source directory recorded in the group's +`system` attribute. Each group holds the datasets `descriptor`, +`atomic_feature`, `structural_feature`, and `atom_types` (with shape +(nframes, natoms); the frame axis follows the system's frame order), together +with an `nframes` attribute. The model `type_map` is stored as a file-level +attribute. The three embedding datasets are stored using the selected output +dtype, and all datasets use gzip and byte-shuffle compression. + +The file can be read back with `h5py`: + +```python +import h5py + +with h5py.File("embedding.hdf5", "r") as f: + type_map = f.attrs["type_map"] + for system_name in f: + group = f[system_name] + source = group.attrs["system"] + descriptor = group["descriptor"][:] + atomic_feature = group["atomic_feature"][:] + structural_feature = group["structural_feature"][:] +``` + +## Python interface + +The same embeddings are available from the Python inference interface: + +```python +from deepmd.infer import DeepPot +import numpy as np + +dp = DeepPot("model.ckpt.pt") +coord = np.array([[1, 0, 0], [0, 0, 1.5], [1, 0, 3]]).reshape([1, -1]) +cell = np.diag(10 * np.ones(3)).reshape([1, -1]) +atype = [1, 0, 1] +descriptor, atomic_feature, structural_feature = dp.eval_embedding(coord, cell, atype) +``` + +The embeddings are returned as float32 by default (both from the Python interface +and the `dp embed` command), which is ample for downstream analysis. Pass +`dtype="fp64"` or `dtype="native"` to {meth}`DeepPot.eval_embedding` (or +`--dtype fp64/native` to `dp embed`) when a different output precision is needed. diff --git a/doc/inference/index.rst b/doc/inference/index.rst index dd935aac84..3d360db4fb 100644 --- a/doc/inference/index.rst +++ b/doc/inference/index.rst @@ -7,5 +7,6 @@ Note that the model for inference is required to be compatible with the DeePMD-k :maxdepth: 1 python + embedding cxx nodejs diff --git a/doc/inference/python.md b/doc/inference/python.md index 361db7b64f..c53ade8c31 100644 --- a/doc/inference/python.md +++ b/doc/inference/python.md @@ -34,6 +34,15 @@ descriptors = dp.eval_descriptor(coord, cell, atype) where `descriptors` is the descriptor matrix of the system. This can also be done using the command line interface `dp eval-desc` as described in the [test documentation](../test/test.md). +:::{note} +`eval_descriptor` is the descriptor-only interface supported across backends. In +the PyTorch backend, [`eval_embedding`](embedding.md) additionally returns the +descriptor, per-atom feature, and per-structure feature in a single forward pass. +PyTorch descriptor/embedding APIs accept `dtype="fp32"`, `"fp64"`, or `"native"`; +`eval_descriptor` defaults to `native`, while `eval_embedding` defaults to +`fp32`. +::: + Furthermore, one can use the python interface to calculate model deviation. ```python diff --git a/doc/model/dpa4.md b/doc/model/dpa4.md index 78d4787a5d..6dc9949f42 100644 --- a/doc/model/dpa4.md +++ b/doc/model/dpa4.md @@ -552,6 +552,44 @@ pair_coeff * * O H The ordinary TorchScript freeze path is not used for DPA4/SeZM checkpoints. A small LAMMPS example is in `examples/water/dpa4/lmp/`. +## Embedding extraction + +A trained DPA4/SeZM model can export learned representations for downstream +analysis with `dp embed`. A single forward pass (no force or virial +computation) produces three embeddings per system: + +- `descriptor`: the per-atom local-environment representation, with shape + (nframes, natoms, dim_descriptor). +- `atomic_feature`: the per-atom activation after the last fitting hidden layer, + with shape (nframes, natoms, dim_hidden). +- `structural_feature`: a whole-structure summary obtained by summing + `atomic_feature` over atoms, with shape (nframes, dim_hidden). + +A typical invocation operates on the PyTorch checkpoint (`.pt`): + +```bash +dp embed -m model.ckpt.pt -s /path/to/system -o embedding.hdf5 +``` + +The results are written to a single HDF5 file in which each system is a group +holding the three float32 datasets above. They can be read back with `h5py`: + +```python +import h5py + +with h5py.File("embedding.hdf5", "r") as f: + type_map = f.attrs["type_map"] + group = f[next(iter(f.keys()))] + descriptor = group["descriptor"][:] + atomic_feature = group["atomic_feature"][:] + structural_feature = group["structural_feature"][:] +``` + +This command is available for DPA4/SeZM energy models in the PyTorch backend and +honors both `DP_COMPILE_INFER` and `DP_TRITON_INFER`. It operates on the training checkpoint (`.pt`); the +frozen `.pt2` package is not supported. See +[model embeddings](../inference/embedding.md) for the full description. + ## Data format DPA4/SeZM uses the [standard DeePMD-kit data format](../data/system.md). Keep diff --git a/doc/test/test.md b/doc/test/test.md index 9d399cb1ed..c7c2dc1271 100644 --- a/doc/test/test.md +++ b/doc/test/test.md @@ -26,7 +26,7 @@ The descriptors of a model can be evaluated and saved using `dp eval-desc`. A ty dp eval-desc -m graph.pb -s /path/to/system -o desc ``` -where `-m` gives the model file, `-s` the path to the system directory (or `-f` for a datafile containing paths to systems), and `-o` the output directory where descriptor files will be saved. The descriptors for each system will be saved as `.npy` files with the format `desc/(system_name).npy`. Each descriptor file contains a 3D array with shape (nframes, natoms, ndesc). +where `-m` gives the model file, `-s` the path to the system directory (or `-f` for a datafile containing paths to systems), and `-o` the output directory where descriptor files will be saved. Use `--dtype` to choose the output precision (`fp32`, `fp64`, or `native`; default `native`). The descriptors for each system will be saved as `.npy` files with the format `desc/(system_name).npy`. Each descriptor file contains a 3D array with shape (nframes, natoms, ndesc). Several other command line options can be passed to `dp eval-desc`, which can be checked with diff --git a/source/tests/pt/model/test_embedding.py b/source/tests/pt/model/test_embedding.py new file mode 100644 index 0000000000..2617a0afba --- /dev/null +++ b/source/tests/pt/model/test_embedding.py @@ -0,0 +1,437 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for embedding extraction (descriptor, atomic and structural features).""" + +import os +import shutil +import tempfile +import unittest +from types import ( + SimpleNamespace, +) + +import h5py +import numpy as np +import torch +from packaging.version import parse as parse_version + +from deepmd.infer.deep_pot import ( + DeepPot, +) +from deepmd.pt.infer.deep_eval import DeepEval as PTDeepEval +from deepmd.pt.model.model import ( + get_model, + get_sezm_model, +) +from deepmd.pt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt.utils import ( + env, +) + +# The SeZM compile path is validated on torch 2.11.x / 2.12.x only. +_TORCH_VERSION = parse_version(torch.__version__) +_SKIP_COMPILE = (_TORCH_VERSION.major, _TORCH_VERSION.minor) not in {(2, 11), (2, 12)} +_SKIP_COMPILE_REASON = ( + "SeZM's torch.compile path is only supported on torch 2.11.x and 2.12.x." +) + + +def _sezm_params() -> dict: + """Return a small SeZM/DPA4 model configuration for fast tests.""" + return { + "type": "SeZM", + "type_map": ["A", "B"], + "descriptor": { + "type": "SeZM", + "sel": [2, 2], + "rcut": 3.0, + "channels": 4, + "n_focus": 1, + "n_radial": 3, + "radial_mlp": [6], + "use_env_seed": True, + "l_schedule": [1, 0], + "mmax": 1, + "so2_layers": 1, + "n_atten_head": 1, + "ffn_neurons": 8, + "ffn_blocks": 1, + "use_amp": False, + "precision": "float32", + "seed": 7, + }, + "fitting_net": { + "neuron": [8], + "activation_function": "silu", + "precision": "float32", + "seed": 7, + }, + } + + +def _se_e2_a_params() -> dict: + """Return a small standard ``se_e2_a`` energy model configuration.""" + return { + "type_map": ["A", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [4, 4], + "rcut_smth": 0.5, + "rcut": 3.0, + "neuron": [4, 8], + "axis_neuron": 4, + "seed": 1, + }, + "fitting_net": { + "neuron": [8], + "seed": 1, + }, + } + + +def _randomize(model: torch.nn.Module, seed: int = 1234) -> None: + """Fill parameters with small random values to expose masked paths.""" + torch.manual_seed(seed) + with torch.no_grad(): + for param in model.parameters(): + param.copy_(torch.randn_like(param) * 0.1) + + +def _make_frame( + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build one deterministic 7-atom frame (coord, atype, box).""" + coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [1.1, 0.3, 0.0], + [0.2, 1.5, 0.4], + [1.7, 1.2, 0.2], + [2.3, 0.1, 1.0], + [0.8, 2.2, 1.1], + [2.6, 1.8, 1.5], + ], + ], + device=device, + dtype=torch.float32, + ) + atype = torch.tensor([[0, 1, 0, 1, 0, 1, 0]], device=device, dtype=torch.int32) + box = torch.tensor( + [[8.0, 0.0, 0.0, 0.0, 8.0, 0.0, 0.0, 0.0, 8.0]], + device=device, + dtype=torch.float32, + ) + return coord, atype, box + + +class TestSeZMEmbeddingForward(unittest.TestCase): + """Validate ``forward_embedding`` against the independent energy forward.""" + + def setUp(self) -> None: + torch.manual_seed(2024) + self.device = env.DEVICE + self.model = get_sezm_model(_sezm_params()).to(self.device) + _randomize(self.model) + self.model.eval() + self.coord, self.atype, self.box = _make_frame(self.device) + + def test_embedding_reconstructs_energy(self) -> None: + # Projecting the atomic / structural features through the fitting output + # layer reproduces the per-atom / total energy of the independent energy + # forward. With the per-type bias zeroed and no excluded atoms, the + # output layer is linear, so this validates both the atomic feature (it + # is the last hidden activation) and the structural feature (it is the + # atom-pooled feature). + fitting = self.model.atomic_model.fitting_net + with torch.no_grad(): + fitting.bias_atom_e.zero_() + out = self.model(self.coord, self.atype, box=self.box) + emb = self.model.forward_embedding(self.coord, self.atype, box=self.box) + output_layer = fitting.filter_layers.networks[0].output_layer + + recon_atom = output_layer(emb["atomic_feature"].to(fitting.prec)) + torch.testing.assert_close( + recon_atom.double(), + out["atom_energy"].double(), + atol=1e-4, + rtol=1e-4, + check_dtype=False, + ) + recon_total = output_layer(emb["structural_feature"].to(fitting.prec)) + torch.testing.assert_close( + recon_total.flatten().double(), + out["energy"].flatten().double(), + atol=1e-4, + rtol=1e-4, + check_dtype=False, + ) + + @unittest.skipIf(_SKIP_COMPILE, _SKIP_COMPILE_REASON) + def test_compiled_matches_eager(self) -> None: + with unittest.mock.patch.dict( + os.environ, {"DP_COMPILE_INFER": "1"}, clear=False + ): + model_cmp = get_sezm_model(_sezm_params()).to(self.device) + model_cmp.load_state_dict(self.model.state_dict()) + model_cmp.eval() + self.assertTrue(model_cmp.should_use_compile()) + + eager = self.model.forward_embedding(self.coord, self.atype, box=self.box) + compiled = model_cmp.forward_embedding(self.coord, self.atype, box=self.box) + # Inductor reductions can differ from eager by ~1e-3 in float32 on GPU. + atol = 1e-5 if self.device == torch.device("cpu") else 2e-3 + rtol = 1e-5 if self.device == torch.device("cpu") else 3e-3 + for key in ("descriptor", "atomic_feature", "structural_feature"): + torch.testing.assert_close( + eager[key], compiled[key], atol=atol, rtol=rtol, msg=key + ) + + +class TestEmbeddingDeepEvalAPI(unittest.TestCase): + """Validate the ``DeepEval`` embedding API and the unsupported boundary.""" + + def setUp(self) -> None: + torch.manual_seed(2024) + self.device = env.DEVICE + self._tmp = tempfile.mkdtemp() + coord, atype, box = _make_frame(self.device) + self.coord_np = coord.cpu().numpy() + self.cell_np = box.cpu().numpy() + self.atype_np = atype[0].cpu().numpy() + + def tearDown(self) -> None: + shutil.rmtree(self._tmp, ignore_errors=True) + + def _save_checkpoint(self, model: torch.nn.Module, params: dict, name: str) -> str: + path = os.path.join(self._tmp, name) + wrapper = ModelWrapper(model, model_params=params) + torch.save({"model": wrapper.state_dict()}, path) + return path + + def test_sezm_eval_embedding(self) -> None: + params = _sezm_params() + model = get_sezm_model(params) + _randomize(model) + path = self._save_checkpoint(model, params, "sezm.pt") + + dp = DeepPot(path) + descriptor, atomic_feature, structural_feature = dp.eval_embedding( + self.coord_np, self.cell_np, self.atype_np + ) + + natoms = int(self.atype_np.shape[0]) + self.assertEqual(descriptor.shape[:2], (1, natoms)) + self.assertEqual(atomic_feature.shape[:2], (1, natoms)) + self.assertEqual(structural_feature.shape, (1, atomic_feature.shape[2])) + # Embeddings are returned as float32. + self.assertEqual(descriptor.dtype, np.float32) + self.assertEqual(atomic_feature.dtype, np.float32) + self.assertEqual(structural_feature.dtype, np.float32) + + def test_standard_model_embedding(self) -> None: + # Standard energy models reuse the base-class embedding path; the + # descriptor and atomic feature flow through the same neighbor list as + # the energy forward. + params = _se_e2_a_params() + model = get_model(params) + _randomize(model) + path = self._save_checkpoint(model, params, "se_e2_a.pt") + + dp = DeepPot(path) + descriptor, atomic_feature, structural_feature = dp.eval_embedding( + self.coord_np, self.cell_np, self.atype_np + ) + + natoms = int(self.atype_np.shape[0]) + self.assertEqual(descriptor.shape[:2], (1, natoms)) + self.assertEqual(atomic_feature.shape[:2], (1, natoms)) + self.assertEqual(structural_feature.shape, (1, atomic_feature.shape[2])) + # Embeddings are returned as float32. + self.assertEqual(descriptor.dtype, np.float32) + self.assertEqual(atomic_feature.dtype, np.float32) + self.assertEqual(structural_feature.dtype, np.float32) + # The structural feature is the atom-sum of the atomic feature. + np.testing.assert_allclose( + structural_feature[0], + atomic_feature[0].sum(axis=0), + rtol=1e-4, + atol=1e-5, + ) + # eval_descriptor / eval_fitting_last_layer are thin wrappers that slice + # the embedding output (independent forward passes match to round-off). + np.testing.assert_allclose( + dp.eval_descriptor( + self.coord_np, self.cell_np, self.atype_np, dtype="fp32" + ), + descriptor, + rtol=1e-10, + atol=1e-12, + ) + np.testing.assert_allclose( + dp.eval_fitting_last_layer( + self.coord_np, self.cell_np, self.atype_np, dtype="fp32" + ), + atomic_feature, + rtol=1e-10, + atol=1e-12, + ) + + def test_standard_model_embedding_without_hidden_layers(self) -> None: + params = _se_e2_a_params() + params["fitting_net"]["neuron"] = [] + model = get_model(params) + _randomize(model) + path = self._save_checkpoint(model, params, "se_e2_a_no_hidden.pt") + + dp = DeepPot(path) + descriptor, atomic_feature, structural_feature = dp.eval_embedding( + self.coord_np, self.cell_np, self.atype_np + ) + + self.assertEqual(atomic_feature.shape, descriptor.shape) + self.assertEqual(structural_feature.shape, (1, atomic_feature.shape[2])) + self.assertEqual(atomic_feature.dtype, np.float32) + + def test_eval_embedding_dtype_fp64(self) -> None: + params = _se_e2_a_params() + model = get_model(params) + path = self._save_checkpoint(model, params, "se_e2_a_fp64.pt") + + dp = DeepPot(path) + descriptor, atomic_feature, structural_feature = dp.eval_embedding( + self.coord_np, self.cell_np, self.atype_np, dtype="fp64" + ) + + self.assertEqual(descriptor.dtype, np.float64) + self.assertEqual(atomic_feature.dtype, np.float64) + self.assertEqual(structural_feature.dtype, np.float64) + self.assertEqual( + dp.eval_descriptor( + self.coord_np, self.cell_np, self.atype_np, dtype="fp64" + ).dtype, + np.float64, + ) + + def test_legacy_frozen_model_uses_baked_in_hook(self) -> None: + # Frozen ``.pth`` files predating ``forward_embedding`` still carry the + # descriptor / fitting hooks baked into the TorchScript module. The + # deprecated ``eval_descriptor`` / ``eval_fitting_last_layer`` must drive + # those hooks instead of raising, so existing artifacts keep working. + toggles: list[tuple[str, bool]] = [] + with torch.device("cpu"): + desc = torch.arange(6, dtype=torch.float64).reshape(1, 2, 3) + fit = torch.arange(8, dtype=torch.float64).reshape(1, 2, 4) + + class LegacyModel: + def set_eval_descriptor_hook(self, enable: bool) -> None: + toggles.append(("descriptor", enable)) + + def eval_descriptor(self) -> torch.Tensor: + return desc + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + toggles.append(("fitting", enable)) + + def eval_fitting_last_layer(self) -> torch.Tensor: + return fit + + backend = object.__new__(PTDeepEval) + backend.auto_batch_size = None + backend.dp = SimpleNamespace(model={"Default": LegacyModel()}) + # The hook caches are populated by the (here mocked) forward pass. + backend.eval = lambda *args, **kwargs: None + + np.testing.assert_array_equal( + PTDeepEval.eval_descriptor( + backend, self.coord_np, self.cell_np, self.atype_np + ), + desc.numpy(), + ) + np.testing.assert_array_equal( + PTDeepEval.eval_fitting_last_layer( + backend, self.coord_np, self.cell_np, self.atype_np + ), + fit.numpy(), + ) + # Each hook is enabled for the forward pass and disabled afterwards. + self.assertEqual( + toggles, + [ + ("descriptor", True), + ("descriptor", False), + ("fitting", True), + ("fitting", False), + ], + ) + + +class TestEmbeddingEntrypoint(unittest.TestCase): + """Validate the HDF5 entrypoint without requiring on-disk systems.""" + + def setUp(self) -> None: + self._tmp = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self._tmp, ignore_errors=True) + + def test_embedding_writes_hdf5_from_datafile(self) -> None: + import importlib + + embedding_module = importlib.import_module("deepmd.entrypoints.embedding") + datafile = os.path.join(self._tmp, "systems.txt") + output = os.path.join(self._tmp, "embedding.hdf5") + with open(datafile, "w") as fp: + fp.write("/tmp/a/sys\n\n/tmp/b/sys\n") + + class FakeData: + mixed_type = False + pbc = False + + def __init__(self, *args, **kwargs) -> None: + pass + + def get_test(self) -> dict[str, np.ndarray]: + return { + "coord": np.zeros((1, 2, 3), dtype=np.float64), + "box": np.zeros((1, 9), dtype=np.float64), + "type": np.array([[0, 1]], dtype=np.int32), + } + + descriptor = np.zeros((1, 2, 3), dtype=np.float64) + atomic_feature = np.ones((1, 2, 4), dtype=np.float64) + structural_feature = atomic_feature.sum(axis=1) + dp = unittest.mock.Mock() + dp.get_type_map.return_value = ["A", "B"] + dp.get_dim_fparam.return_value = 0 + dp.get_dim_aparam.return_value = 0 + dp.eval_embedding.return_value = ( + descriptor, + atomic_feature, + structural_feature, + ) + + with ( + unittest.mock.patch.object(embedding_module, "DeepEval", return_value=dp), + unittest.mock.patch.object(embedding_module, "DeepmdData", FakeData), + ): + embedding_module.embedding( + model="model.pt", + system=".", + datafile=datafile, + output=output, + dtype="fp64", + ) + + self.assertEqual(dp.eval_embedding.call_count, 2) + self.assertIsNone(dp.eval_embedding.call_args.args[1]) + self.assertEqual(dp.eval_embedding.call_args.kwargs["dtype"], "fp64") + with h5py.File(output, "r") as h5file: + self.assertEqual(set(h5file.keys()), {"sys", "sys_1"}) + self.assertEqual(h5file["sys"]["descriptor"].dtype, np.float64) + self.assertEqual(h5file["sys_1"].attrs["system"], "/tmp/b/sys") + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_linear_atomic_model.py b/source/tests/pt/model/test_linear_atomic_model.py index e693e8b90e..5b178ea38e 100644 --- a/source/tests/pt/model/test_linear_atomic_model.py +++ b/source/tests/pt/model/test_linear_atomic_model.py @@ -218,6 +218,22 @@ def test_jit(self) -> None: # self.assertEqual(md3.get_rcut(), self.rcut) # self.assertEqual(md3.get_type_map(), ["foo", "bar"]) + def test_forward_embedding(self) -> None: + # The embedding of a DP+ZBL model is taken from its DP sub-model. + self.assertTrue(self.md0.has_embedding()) + args = [ + to_torch_tensor(ii) for ii in [self.coord_ext, self.atype_ext, self.nlist] + ] + emb = self.md0.forward_embedding(*args) + for key in ("descriptor", "atomic_feature", "structural_feature"): + self.assertIn(key, emb) + self.assertEqual(tuple(emb["descriptor"].shape[:2]), (self.nf, self.nloc)) + self.assertEqual(tuple(emb["atomic_feature"].shape[:2]), (self.nf, self.nloc)) + self.assertEqual( + tuple(emb["structural_feature"].shape), + (self.nf, emb["atomic_feature"].shape[2]), + ) + class TestRemmapMethod(unittest.TestCase): def test_valid(self) -> None: diff --git a/source/tests/pt/test_eval_desc.py b/source/tests/pt/test_eval_desc.py index ff79a0a376..f11f0300fc 100644 --- a/source/tests/pt/test_eval_desc.py +++ b/source/tests/pt/test_eval_desc.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +# +# NOTE: `dp eval-desc` is deprecated; prefer `dp embed` (see test_embedding.py). +# These tests cover the compatibility path, which now redirects to the embedding +# implementation internally and extracts the descriptor from its output. import json import os import shutil diff --git a/source/tests/pt/test_oom_retry.py b/source/tests/pt/test_oom_retry.py deleted file mode 100644 index cc9fe118aa..0000000000 --- a/source/tests/pt/test_oom_retry.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import unittest -from types import ( - SimpleNamespace, -) -from typing import ( - Any, -) -from unittest.mock import ( - MagicMock, - call, - patch, -) - -import numpy as np - -from deepmd.utils.batch_size import ( - RetrySignal, -) - - -class DummyAutoBatchSize: - def __init__(self) -> None: - self.oom_retry_mode = False - self.modes: list[bool] = [] - - def set_oom_retry_mode(self, enable: bool) -> None: - self.oom_retry_mode = enable - self.modes.append(enable) - - -class TestPtOOMRetry(unittest.TestCase): - def _make_backend(self, method_name: str) -> tuple[Any, MagicMock]: - try: - from deepmd.pt.infer.deep_eval import ( - DeepEval, - ) - except ModuleNotFoundError as exc: - self.skipTest("pt backend dependencies are unavailable: " + str(exc)) - - abstract_methods = getattr(DeepEval, "__abstractmethods__", frozenset()) - try: - DeepEval.__abstractmethods__ = frozenset() - deep_eval = object.__new__(DeepEval) - finally: - DeepEval.__abstractmethods__ = abstract_methods - - model = MagicMock() - model.eval_descriptor.return_value = np.array([1.0, 2.0, 3.0]) - model.eval_fitting_last_layer.return_value = np.array([4.0, 5.0, 6.0]) - - deep_eval.dp = SimpleNamespace(model={"Default": model}) - deep_eval.auto_batch_size = DummyAutoBatchSize() - return deep_eval, model - - def _assert_retry_clears_hook_between_attempts( - self, - method_name: str, - hook_name: str, - expected: np.ndarray, - ) -> None: - deep_eval, model = self._make_backend(method_name) - with patch.object( - deep_eval, "eval", side_effect=[RetrySignal, None] - ) as eval_mock: - result = getattr(deep_eval, method_name)( - coords=np.zeros((3, 1, 3)), - cells=None, - atom_types=np.array([0]), - ) - self.assertEqual(eval_mock.call_count, 2) - np.testing.assert_array_equal(result, expected) - self.assertEqual( - getattr(model, hook_name).call_args_list, - [call(True), call(False), call(True), call(False)], - ) - self.assertFalse(deep_eval.auto_batch_size.oom_retry_mode) - self.assertEqual(deep_eval.auto_batch_size.modes, [True, False, True, False]) - - def _assert_runtime_error_clears_state( - self, - method_name: str, - hook_name: str, - ) -> None: - deep_eval, model = self._make_backend(method_name) - with patch.object( - deep_eval, - "eval", - side_effect=RuntimeError("non-retry failure"), - ): - with self.assertRaisesRegex(RuntimeError, "non-retry failure"): - getattr(deep_eval, method_name)( - coords=np.zeros((3, 1, 3)), - cells=None, - atom_types=np.array([0]), - ) - self.assertEqual( - getattr(model, hook_name).call_args_list, [call(True), call(False)] - ) - self.assertFalse(deep_eval.auto_batch_size.oom_retry_mode) - self.assertEqual(deep_eval.auto_batch_size.modes, [True, False]) - - def test_eval_descriptor_retry_clears_hook_between_attempts(self) -> None: - self._assert_retry_clears_hook_between_attempts( - "eval_descriptor", - "set_eval_descriptor_hook", - np.array([1.0, 2.0, 3.0]), - ) - - def test_eval_fitting_last_layer_retry_clears_hook_between_attempts( - self, - ) -> None: - self._assert_retry_clears_hook_between_attempts( - "eval_fitting_last_layer", - "set_eval_fitting_last_layer_hook", - np.array([4.0, 5.0, 6.0]), - ) - - def test_eval_descriptor_runtime_error_clears_state(self) -> None: - self._assert_runtime_error_clears_state( - "eval_descriptor", - "set_eval_descriptor_hook", - ) - - def test_eval_fitting_last_layer_runtime_error_clears_state(self) -> None: - self._assert_runtime_error_clears_state( - "eval_fitting_last_layer", - "set_eval_fitting_last_layer_hook", - ) - - -if __name__ == "__main__": - unittest.main()