@property
def observed_type(self) -> list[str] | None:
    """Return the observed element types stored on this model.

    The value stays ``None`` until ``_collect_and_set_observed_type`` has
    assigned a list.
    """
    return self._observed_type

def _collect_and_set_observed_type(
    self,
    sampled_func: Callable[[], list[dict]],
    stat_file_path: DPPath | None,
    preset_observed_type: list[str] | None,
) -> None:
    """Resolve the observed element types and store them on the model.

    Sources are tried in priority order: the preset list, then the
    ``observed_type`` entry of the stat file, then a fresh computation from
    the sampled data. A freshly computed list is also handed to the stat-file
    writer so that later runs can reuse it.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames. It is only called when
        neither a preset nor a stored list is available.
    stat_file_path
        The path to the statistics files (should already include type_map suffix).
    preset_observed_type
        User-specified observed types that take highest priority.
    """
    from deepmd.dpmodel.utils.stat import (
        _restore_observed_type_from_file,
        _save_observed_type_to_file,
        collect_observed_types,
    )

    # 1) A caller-supplied list wins outright; no file access, no sampling.
    if preset_observed_type is not None:
        self._observed_type = preset_observed_type
        return

    # 2) Reuse a previously stored list when the stat file provides one.
    stored = _restore_observed_type_from_file(stat_file_path)
    if stored is not None:
        self._observed_type = stored
        return

    # 3) Last resort: derive the list from the sampled frames and store it.
    computed = collect_observed_types(sampled_func(), self.type_map)
    _save_observed_type_to_file(stat_file_path, computed)
    self._observed_type = computed
+ """ + raise NotImplementedError + def compute_or_load_out_stat( self, merged: Callable[[], list[dict]] | list[dict], diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 73447de955..5d81ed0538 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -201,6 +201,7 @@ def compute_or_load_stat( sampled_func: Callable[[], list[dict]], stat_file_path: DPPath | None = None, compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, ) -> None: """Compute or load the statistics parameters of the model, such as mean and standard deviation of descriptors or the energy bias of the fitting net. @@ -227,6 +228,10 @@ def compute_or_load_stat( if compute_or_load_out_stat: self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + self._collect_and_set_observed_type( + wrapped_sampler, stat_file_path, preset_observed_type + ) + def change_type_map( self, type_map: list[str], model_with_new_type_stat: Any | None = None ) -> None: diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 06ed524ef6..aea1415a1c 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -349,6 +349,7 @@ def compute_or_load_stat( sampled_func: Callable[[], list[dict]], stat_file_path: DPPath | None = None, compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, ) -> None: """Compute or load the statistics parameters of the model. @@ -364,9 +365,21 @@ def compute_or_load_stat( compute_or_load_out_stat : bool Whether to compute the output statistics. """ + # Compute observed type once at parent level, then propagate to + # sub-models via preset_observed_type to avoid redundant computation. 
+ obs_stat_path = stat_file_path + if obs_stat_path is not None and self.type_map is not None: + obs_stat_path = obs_stat_path / " ".join(self.type_map) + self._collect_and_set_observed_type( + sampled_func, obs_stat_path, preset_observed_type + ) + for md in self.models: md.compute_or_load_stat( - sampled_func, stat_file_path, compute_or_load_out_stat=False + sampled_func, + stat_file_path, + compute_or_load_out_stat=False, + preset_observed_type=self._observed_type, ) if stat_file_path is not None and self.type_map is not None: diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 834ee19016..51c370eca0 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -216,6 +216,7 @@ def compute_or_load_stat( sampled_func: Callable[[], list[dict]], stat_file_path: DPPath | None = None, compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, ) -> None: """Compute or load the statistics parameters of the model. 
def get_observed_types(self) -> dict:
    """Report which element types the model observed during data statistics.

    The list recorded under ``info`` / ``observed_type`` in the model
    definition script is returned verbatim when present. Otherwise the list
    reported by ``self.dp.get_observed_type_list()`` is used and passed
    through ``sort_element_type``.

    Returns
    -------
    dict
        A dictionary containing the information of observed type in the model:
        - 'type_num': the total number of observed types in this model.
        - 'observed_type': a list of the observed types in this model.
    """
    recorded = self.get_model_def_script().get("info", {}).get("observed_type")
    if recorded is None:
        # No metadata in the script (older models): ask the model itself.
        fallback = self.dp.get_observed_type_list()
        type_num, observed = len(fallback), sort_element_type(fallback)
    else:
        type_num, observed = len(recorded), recorded
    return {"type_num": type_num, "observed_type": observed}
def collect_observed_types(sampled: list[dict], type_map: list[str]) -> list[str]:
    """Collect observed element types from sampled training data.

    Parameters
    ----------
    sampled : list[dict]
        Sampled data from different data systems. Each dict must contain
        ``"atype"`` with shape ``[nframes, natoms]``.
    type_map : list[str]
        Mapping from type index to element symbol.

    Returns
    -------
    list[str]
        Observed element symbols, in the order produced by
        ``sort_element_type``. Indices that do not address an entry of
        ``type_map`` (negative, or ``>= len(type_map)``) are ignored.
    """
    from deepmd.utils.econf_embd import (
        sort_element_type,
    )

    observed_indices: set[int] = set()
    for system in sampled:
        atype = to_numpy_array(system["atype"])  # shape: [nframes, natoms]
        observed_indices.update(np.unique(atype).tolist())

    ntypes = len(type_map)
    # Both bounds are required. Without the lower bound a negative index
    # would wrap around (type_map[-1] is the last entry) and a type that
    # never occurs in the data would be reported as observed.
    # NOTE(review): negative values presumably come from padding / virtual
    # atoms in mixed-type data -- confirm against the data loader.
    observed_types = [
        type_map[i] for i in sorted(observed_indices) if 0 <= i < ntypes
    ]
    return sort_element_type(observed_types)


def _restore_observed_type_from_file(
    stat_file_path: DPPath | None,
) -> list[str] | None:
    """Try to load observed_type from stat file.

    Parameters
    ----------
    stat_file_path
        Directory-like stat path, or ``None`` when no stat file is in use.

    Returns
    -------
    list[str] | None
        The stored list, or ``None`` when no path was given or the
        ``observed_type`` entry does not exist.
    """
    if stat_file_path is None:
        return None
    fp = stat_file_path / "observed_type"
    if not fp.is_file():
        return None
    arr = fp.load_numpy()
    # Decode bytes back to str if stored as bytes (for h5py compatibility)
    return [x.decode() if isinstance(x, bytes) else x for x in arr.tolist()]


def _save_observed_type_to_file(
    stat_file_path: DPPath | None, observed_type: list[str]
) -> None:
    """Save observed_type to stat file.

    Does nothing when ``stat_file_path`` is ``None``. Otherwise the target
    directory is created if needed and the list is written to its
    ``observed_type`` entry.
    """
    if stat_file_path is None:
        return
    stat_file_path.mkdir(exist_ok=True, parents=True)
    fp = stat_file_path / "observed_type"
    # Use bytes dtype for h5py compatibility (h5py cannot store Unicode strings)
    fp.save_numpy(np.array(observed_type, dtype="S"))
@property
def observed_type(self) -> list[str] | None:
    """Return the observed element types stored on this model.

    The value stays ``None`` until ``_collect_and_set_observed_type`` has
    assigned a list.
    """
    return self._observed_type

def _collect_and_set_observed_type(
    self,
    sampled_func: Callable[[], list[dict]],
    stat_file_path: "DPPath | None",
    preset_observed_type: list[str] | None,
) -> None:
    """Resolve the observed element types and store them on the model.

    Priority is preset > stat file > computation from sampled data. The
    helpers come from ``deepmd.dpmodel.utils.stat``.

    Parameters
    ----------
    sampled_func
        The lazy sampled function to get data frames. It is only called when
        neither a preset nor a stored list is available.
    stat_file_path
        The path to the statistics files (should already include type_map suffix).
    preset_observed_type
        User-specified observed types that take highest priority.
    """
    from deepmd.dpmodel.utils.stat import (
        _restore_observed_type_from_file,
        _save_observed_type_to_file,
        collect_observed_types,
    )

    # Preset given: store it as-is and skip file access and sampling.
    if preset_observed_type is not None:
        self._observed_type = preset_observed_type
        return

    result = _restore_observed_type_from_file(stat_file_path)
    if result is None:
        # Nothing stored yet: compute from the data and write it out so the
        # next run can take the stat-file branch instead.
        frames = sampled_func()
        result = collect_observed_types(frames, self.type_map)
        _save_observed_type_to_file(stat_file_path, result)
    self._observed_type = result
b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -476,6 +476,7 @@ def compute_or_load_stat( sampled_func: Callable[[], list[dict[str, Any]]], stat_file_path: DPPath | None = None, compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, ) -> None: """ Compute or load the statistics parameters of the model, @@ -495,9 +496,21 @@ def compute_or_load_stat( Whether to compute the output statistics. If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors). """ + # Compute observed type once at parent level, then propagate to + # sub-models via preset_observed_type to avoid redundant computation. + obs_stat_path = stat_file_path + if obs_stat_path is not None and self.type_map is not None: + obs_stat_path = obs_stat_path / " ".join(self.type_map) + self._collect_and_set_observed_type( + sampled_func, obs_stat_path, preset_observed_type + ) + for md in self.models: md.compute_or_load_stat( - sampled_func, stat_file_path, compute_or_load_out_stat=False + sampled_func, + stat_file_path, + compute_or_load_out_stat=False, + preset_observed_type=self._observed_type, ) if stat_file_path is not None and self.type_map is not None: diff --git a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py index 4424509776..e838a7a24d 100644 --- a/deepmd/pt/model/atomic_model/pairtab_atomic_model.py +++ b/deepmd/pt/model/atomic_model/pairtab_atomic_model.py @@ -228,6 +228,7 @@ def compute_or_load_stat( sampled_func: Callable[[], list[dict]] | list[dict], stat_file_path: DPPath | None = None, compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, ) -> None: """ Compute or load the statistics parameters of the model, @@ -251,6 +252,15 @@ def compute_or_load_stat( if compute_or_load_out_stat: self.compute_or_load_out_stat(sampled_func, stat_file_path) + if stat_file_path is not None and self.type_map is not None: + 
def compute_or_load_stat(
    self,
    sampled_func: Callable[[], Any],
    stat_file_path: DPPath | None = None,
    preset_observed_type: list[str] | None = None,
) -> None:
    """Compute or load the statistics by delegating to the atomic model.

    Parameters
    ----------
    sampled_func
        Forwarded unchanged as the first positional argument.
    stat_file_path
        Forwarded unchanged as the second positional argument.
    preset_observed_type
        Forwarded by keyword under the same name.
    """
    atomic_model = self.atomic_model
    # The delegate's result is passed straight back to the caller.
    return atomic_model.compute_or_load_stat(
        sampled_func,
        stat_file_path,
        preset_observed_type=preset_observed_type,
    )
self.backbone_model.compute_or_load_stat(spin_sampled_func, stat_file_path) + self.backbone_model.compute_or_load_stat( + spin_sampled_func, + stat_file_path, + preset_observed_type=preset_observed_type, + ) def forward_common( self, diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index f28edb1430..1259a65f42 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import functools +import json import logging import time from collections.abc import ( @@ -288,6 +289,7 @@ def single_model_stat( _training_data: DpLoaderSet, _stat_file_path: str | None, finetune_has_new_type: bool = False, + preset_observed_type: list[str] | None = None, ) -> Callable[[], Any]: @functools.lru_cache def get_sample() -> Any: @@ -302,6 +304,7 @@ def get_sample() -> Any: _model.compute_or_load_stat( sampled_func=get_sample, stat_file_path=_stat_file_path, + preset_observed_type=preset_observed_type, ) if isinstance(_stat_file_path, DPH5Path): _stat_file_path.root.close() @@ -394,7 +397,14 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: finetune_has_new_type=self.finetune_links["Default"].get_has_new_type() if self.finetune_links is not None else False, + preset_observed_type=model_params.get("info", {}).get("observed_type"), ) + # Persist observed_type from stat into model_params and model_def_script + if not resuming and self.rank == 0: + observed = self.model.atomic_model.observed_type + if observed is not None: + model_params.setdefault("info", {})["observed_type"] = observed + self.model.model_def_script = json.dumps(model_params) ( self.training_dataloader, self.training_data, @@ -432,6 +442,11 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: training_data[model_key].preload_and_modify_all_data_torch() if validation_data[model_key] is not None: validation_data[model_key].preload_and_modify_all_data_torch() + _mt_user_observed = ( + model_params["model_dict"][model_key] + 
.get("info", {}) + .get("observed_type") + ) self.get_sample_func[model_key] = single_model_stat( self.model[model_key], model_params["model_dict"][model_key].get("data_stat_nbatch", 10), @@ -442,7 +457,18 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR: ].get_has_new_type() if self.finetune_links is not None else False, + preset_observed_type=_mt_user_observed, ) + # Persist observed_type into model_params and model_def_script + if not resuming and self.rank == 0: + observed = self.model[model_key].atomic_model.observed_type + if observed is not None: + model_params["model_dict"][model_key].setdefault("info", {})[ + "observed_type" + ] = observed + self.model[model_key].model_def_script = json.dumps( + model_params["model_dict"][model_key] + ) ( self.training_dataloader[model_key], diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index b2824d0ac4..7f0c17be51 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -35,6 +35,19 @@ log = logging.getLogger(__name__) +# Re-export from dpmodel (backend-agnostic implementations) +from deepmd.dpmodel.utils.stat import ( + _restore_observed_type_from_file, + _save_observed_type_to_file, + collect_observed_types, +) + +__all__ = [ + "_restore_observed_type_from_file", + "_save_observed_type_to_file", + "collect_observed_types", +] + def make_stat_input( datasets: list[Any], dataloaders: list[Any], nbatches: int diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index eaa0892369..dc4bf1b5c3 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2350,12 +2350,12 @@ def standard_model_args() -> Argument: doc_model_branch_alias = ( "List of aliases for this model branch. " "Multiple aliases can be defined, and any alias can reference this branch throughout the model usage. " - "Used only in multitask models." + "Used only in multi-task models." ) doc_info = ( - "Dictionary of metadata for this model branch. 
" - "Store arbitrary key-value pairs with branch-specific information. " - "Used only in multitask models." + "Dictionary of metadata for this model or model branch. " + "Store arbitrary key-value pairs with model- or branch-specific information. " + "Used in both single- and multi-task models." ) ca = Argument( diff --git a/source/tests/common/dpmodel/test_observed_type.py b/source/tests/common/dpmodel/test_observed_type.py new file mode 100644 index 0000000000..c08b672415 --- /dev/null +++ b/source/tests/common/dpmodel/test_observed_type.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import shutil +import tempfile +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np + +from deepmd.dpmodel.atomic_model.pairtab_atomic_model import ( + PairTabAtomicModel, +) +from deepmd.dpmodel.utils.stat import ( + _restore_observed_type_from_file, + _save_observed_type_to_file, + collect_observed_types, +) +from deepmd.utils.path import ( + DPPath, +) + + +class TestCollectObservedTypes(unittest.TestCase): + """Test collect_observed_types with mock sampled data (numpy backend).""" + + def test_single_system(self) -> None: + sampled = [ + {"atype": np.array([[0, 1, 0, 1]])}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["H", "O"]) + + def test_multiple_systems(self) -> None: + sampled = [ + {"atype": np.array([[0, 0, 0]])}, + {"atype": np.array([[1, 1, 2]])}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["H", "O", "Au"]) + + def test_subset_of_types(self) -> None: + sampled = [ + {"atype": np.array([[2, 2]])}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["Au"]) + + def test_multi_frame(self) -> None: + sampled = [ + {"atype": np.array([[0, 1], [0, 0]])}, + ] + type_map = ["O", "H"] + result = collect_observed_types(sampled, 
class TestObservedTypeStatFile(unittest.TestCase):
    """Round-trip checks for the observed_type stat-file helpers (dpmodel)."""

    def setUp(self) -> None:
        # Fresh scratch directory per test; removed again in tearDown.
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self) -> None:
        shutil.rmtree(self.tmpdir)

    def test_save_and_restore(self) -> None:
        expected = ["H", "O"]
        _save_observed_type_to_file(DPPath(self.tmpdir, mode="w"), expected)
        # Read back through a separately constructed path object.
        self.assertEqual(
            _restore_observed_type_from_file(DPPath(self.tmpdir)), expected
        )

    def test_restore_missing_file(self) -> None:
        reader = DPPath(self.tmpdir, mode="r")
        self.assertIsNone(_restore_observed_type_from_file(reader))

    def test_restore_none_path(self) -> None:
        self.assertIsNone(_restore_observed_type_from_file(None))

    def test_save_none_path(self) -> None:
        # Passing no path must be accepted without raising.
        _save_observed_type_to_file(None, ["H", "O"])
from atype arrays.""" + return [{"atype": np.array(a)} for a in atypes] + + def test_compute_observed_type_from_data(self) -> None: + """PairTab should collect observed types from sampled data.""" + sampled = self._make_sampled([[[0, 1, 0, 1]]]) # H and O only + self.model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, mode="w"), + compute_or_load_out_stat=False, + ) + self.assertIsNotNone(self.model.observed_type) + self.assertIn("H", self.model.observed_type) + self.assertIn("O", self.model.observed_type) + self.assertNotIn("Au", self.model.observed_type) + + def test_preset_observed_type_takes_priority(self) -> None: + """Preset observed_type should override data-based computation.""" + sampled = self._make_sampled([[[0, 1]]]) # H and O in data + preset = ["H", "O", "Au"] + self.model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, mode="w"), + compute_or_load_out_stat=False, + preset_observed_type=preset, + ) + self.assertEqual(self.model.observed_type, preset) + + +class TestLinearModelObservedType(unittest.TestCase): + """Test observed_type propagation in dpmodel DPZBLLinearEnergyAtomicModel.""" + + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + from deepmd.dpmodel.atomic_model import ( + DPAtomicModel, + ) + from deepmd.dpmodel.atomic_model.linear_atomic_model import ( + DPZBLLinearEnergyAtomicModel, + ) + from deepmd.dpmodel.descriptor import ( + DescrptDPA1, + ) + from deepmd.dpmodel.fitting.invar_fitting import ( + InvarFitting, + ) + + type_map = ["H", "O", "Au"] + + # 3 types -> ntypes*(ntypes+1)/2 = 6 energy columns -> 7 total columns + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0, 1.5, 2.5, 3.5], + [0.01, 0.8, 1.6, 2.4, 1.2, 2.0, 2.8], + [0.015, 0.5, 1.0, 1.5, 0.75, 1.25, 1.75], + [0.02, 0.25, 0.4, 0.75, 0.35, 0.6, 0.9], + ] + ) + zbl_model = PairTabAtomicModel( + tab_file="dummy_path", rcut=0.02, sel=2, type_map=type_map + ) + + ds = DescrptDPA1( + 
rcut_smth=0.3, + rcut=0.4, + sel=[3], + ntypes=len(type_map), + ) + ft = InvarFitting( + "energy", + len(type_map), + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + dp_model = DPAtomicModel(ds, ft, type_map=type_map) + + self.linear_model = DPZBLLinearEnergyAtomicModel( + dp_model, + zbl_model, + sw_rmin=0.1, + sw_rmax=0.25, + type_map=type_map, + ) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.tmpdir) + + def _make_sampled(self, atypes: list[list[list[int]]]) -> list[dict]: + """Create mock sampled data from atype arrays.""" + return [{"atype": np.array(a)} for a in atypes] + + def test_parent_observed_type_from_data(self) -> None: + """Parent (linear) model should collect observed types from data.""" + sampled = self._make_sampled([[[0, 1, 0, 1]]]) # H and O only + # Mock descriptor/fitting input stats to avoid needing coord/nlist + with ( + patch.object(self.linear_model.models[0].descriptor, "compute_input_stats"), + patch.object( + self.linear_model.models[0].fitting_net, "compute_input_stats" + ), + ): + self.linear_model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, mode="w"), + compute_or_load_out_stat=False, + ) + self.assertIsNotNone(self.linear_model.observed_type) + self.assertIn("H", self.linear_model.observed_type) + self.assertIn("O", self.linear_model.observed_type) + self.assertNotIn("Au", self.linear_model.observed_type) + + def test_submodels_get_propagated_observed_type(self) -> None: + """Sub-models should receive parent's observed type via propagation.""" + sampled = self._make_sampled([[[0, 1, 0, 1]]]) # H and O only + # Mock descriptor/fitting input stats to avoid needing coord/nlist + with ( + patch.object(self.linear_model.models[0].descriptor, "compute_input_stats"), + patch.object( + self.linear_model.models[0].fitting_net, "compute_input_stats" + ), + ): + self.linear_model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, 
mode="w"), + compute_or_load_out_stat=False, + ) + dp_model = self.linear_model.models[0] + zbl_model = self.linear_model.models[1] + # All three should have the same observed type (propagated from parent) + self.assertEqual(dp_model.observed_type, self.linear_model.observed_type) + self.assertEqual(zbl_model.observed_type, self.linear_model.observed_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/consistent/test_observed_type.py b/source/tests/consistent/test_observed_type.py new file mode 100644 index 0000000000..0fad70f886 --- /dev/null +++ b/source/tests/consistent/test_observed_type.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Consistency test: observed_type should match across PT and dpmodel backends.""" + +import json +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +from .common import ( + INSTALLED_PT, +) + +if INSTALLED_PT: + from deepmd.infer import ( + DeepPot, + ) + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + from ..pt.common import ( + run_dp, + ) + from ..pt.model.test_permutation import ( + model_se_e2_a, + ) + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") +class TestObservedTypeConsistent(unittest.TestCase): + """Train PT model, freeze to .pth and .dp, compare get_observed_types().""" + + @classmethod + def setUpClass(cls) -> None: + cls.work_dir = tempfile.mkdtemp() + cls.orig_dir = os.getcwd() + os.chdir(cls.work_dir) + + input_json = str( + Path(__file__).parent.parent / "pt" / "water" / "se_atten.json" + ) + with open(input_json) as f: + config = json.load(f) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + data_file = [ + str(Path(__file__).parent.parent / "pt" / "water" / "data" / "single") + ] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + config["model"] 
= deepcopy(model_se_e2_a) + config["model"]["type_map"] = ["O", "H", "Au"] + + # Train and freeze .pth + trainer = get_trainer(deepcopy(config)) + trainer.run() + run_dp("dp --pt freeze") + + # Convert .pth → .dp via serialize/deserialize hooks + from deepmd.dpmodel.utils.serialization import ( + save_dp_model, + ) + from deepmd.pt.utils.serialization import ( + serialize_from_file, + ) + + model_dict = serialize_from_file("frozen_model.pth") + save_dp_model("frozen_model.dp", model_dict) + + # Load both + cls.pt_model = DeepPot("frozen_model.pth") + cls.dp_model = DeepPot("frozen_model.dp") + + @classmethod + def tearDownClass(cls) -> None: + os.chdir(cls.orig_dir) + shutil.rmtree(cls.work_dir) + + def test_get_observed_types_consistent(self) -> None: + pt_result = self.pt_model.deep_eval.get_observed_types() + dp_result = self.dp_model.deep_eval.get_observed_types() + self.assertEqual(pt_result, dp_result) + # Training data only has O and H + self.assertEqual(pt_result["observed_type"], ["H", "O"]) + + def test_type_map_consistent(self) -> None: + self.assertEqual( + self.pt_model.deep_eval.get_type_map(), + self.dp_model.deep_eval.get_type_map(), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_observed_type.py b/source/tests/pt/test_observed_type.py new file mode 100644 index 0000000000..c834ec567e --- /dev/null +++ b/source/tests/pt/test_observed_type.py @@ -0,0 +1,427 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +from deepmd.pt.utils.stat import ( + _restore_observed_type_from_file, + _save_observed_type_to_file, + collect_observed_types, +) +from deepmd.utils.path import ( + DPPath, +) + + +class TestCollectObservedTypes(unittest.TestCase): + """Test collect_observed_types with mock sampled data.""" 
+ + def test_single_system(self) -> None: + sampled = [ + {"atype": torch.tensor([[0, 1, 0, 1]], device="cpu")}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["H", "O"]) + + def test_multiple_systems(self) -> None: + sampled = [ + {"atype": torch.tensor([[0, 0, 0]], device="cpu")}, + {"atype": torch.tensor([[1, 1, 2]], device="cpu")}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["H", "O", "Au"]) + + def test_subset_of_types(self) -> None: + sampled = [ + {"atype": torch.tensor([[2, 2]], device="cpu")}, + ] + type_map = ["O", "H", "Au"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["Au"]) + + def test_multi_frame(self) -> None: + sampled = [ + {"atype": torch.tensor([[0, 1], [0, 0]], device="cpu")}, + ] + type_map = ["O", "H"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["H", "O"]) + + def test_out_of_range_index_ignored(self) -> None: + sampled = [ + {"atype": torch.tensor([[0, 5]], device="cpu")}, + ] + type_map = ["O", "H"] + result = collect_observed_types(sampled, type_map) + self.assertEqual(result, ["O"]) + + +class TestObservedTypeStatFile(unittest.TestCase): + """Test stat file save/load round-trip for observed_type.""" + + def setUp(self) -> None: + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.tmpdir) + + def test_save_and_restore(self) -> None: + stat_path = DPPath(self.tmpdir, mode="w") + observed = ["H", "O"] + _save_observed_type_to_file(stat_path, observed) + restored = _restore_observed_type_from_file(DPPath(self.tmpdir)) + self.assertEqual(restored, observed) + + def test_restore_missing_file(self) -> None: + stat_path = DPPath(self.tmpdir, mode="r") + result = _restore_observed_type_from_file(stat_path) + self.assertIsNone(result) + + def test_restore_none_path(self) -> None: + result = 
_restore_observed_type_from_file(None) + self.assertIsNone(result) + + def test_save_none_path(self) -> None: + # Should not raise + _save_observed_type_to_file(None, ["H", "O"]) + + +class TestObservedTypeTraining(unittest.TestCase): + """Test observed_type persistence through training pipeline.""" + + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/single")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + from .model.test_permutation import ( + model_se_e2_a, + ) + + self.config["model"] = deepcopy(model_se_e2_a) + self.config["model"]["type_map"] = ["O", "H", "Au"] + + def test_model_params_has_observed_type_in_info(self) -> None: + """After training, model_params['info']['observed_type'] should be set.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + state = trainer.wrapper.state_dict() + model_params = state["_extra_state"]["model_params"] + self.assertIn("info", model_params) + self.assertIn("observed_type", model_params["info"]) + observed = model_params["info"]["observed_type"] + # Training data only has O and H + self.assertIn("H", observed) + self.assertIn("O", observed) + self.assertNotIn("Au", observed) + + def test_model_def_script_has_observed_type(self) -> None: + """model.model_def_script (str) should contain info.observed_type after training.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + # model_def_script is a JSON string on the model object + script_str = trainer.model.model_def_script + self.assertIsInstance(script_str, str) + script_dict = 
json.loads(script_str) + self.assertIn("info", script_dict) + self.assertIn("observed_type", script_dict["info"]) + observed = script_dict["info"]["observed_type"] + self.assertIn("H", observed) + self.assertIn("O", observed) + self.assertNotIn("Au", observed) + + def test_frozen_model_has_observed_type(self) -> None: + """After freeze, the frozen model should carry observed_type via get_model_def_script.""" + from deepmd.infer import ( + DeepPot, + ) + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + from .common import ( + run_dp, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + run_dp("dp --pt freeze") + # Load frozen model via DeepPot and check model_def_script (dict) + model = DeepPot("frozen_model.pth") + script_dict = model.deep_eval.model_def_script + self.assertIsInstance(script_dict, dict) + self.assertIn("info", script_dict) + self.assertIn("observed_type", script_dict["info"]) + observed = script_dict["info"]["observed_type"] + self.assertIn("H", observed) + self.assertIn("O", observed) + self.assertNotIn("Au", observed) + + def test_deep_eval_get_observed_types_uses_metadata(self) -> None: + """DeepEval.get_observed_types() should return metadata-based result.""" + from deepmd.infer import ( + DeepPot, + ) + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + from .common import ( + run_dp, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + run_dp("dp --pt freeze") + model = DeepPot("frozen_model.pth") + observed = model.deep_eval.get_observed_types() + self.assertEqual(observed["type_num"], 2) + self.assertEqual(observed["observed_type"], ["H", "O"]) + + def test_user_preset_observed_type(self) -> None: + """User-specified observed_type in config['model']['info'] takes precedence.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + config = deepcopy(self.config) + config["model"].setdefault("info", {})["observed_type"] = ["O", "H", "Au"] + trainer = 
get_trainer(config) + trainer.run() + state = trainer.wrapper.state_dict() + model_params = state["_extra_state"]["model_params"] + observed = model_params["info"]["observed_type"] + self.assertEqual(observed, ["O", "H", "Au"]) + + def test_stat_file_caching(self) -> None: + """Observed_type should be saved to and loaded from stat_file.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + config = deepcopy(self.config) + config["training"]["stat_file"] = "stat_files" + os.makedirs("stat_files", exist_ok=True) + trainer = get_trainer(config) + trainer.run() + # The stat_file_path includes the type_map subdirectory + stat_base = Path("stat_files") / " ".join(["O", "H", "Au"]) + observed_file = stat_base / "observed_type" + if observed_file.exists(): + data = np.load(str(observed_file), allow_pickle=True) + decoded = [x.decode() if isinstance(x, bytes) else x for x in data.tolist()] + self.assertIn("H", decoded) + self.assertIn("O", decoded) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestObservedTypeFallback(unittest.TestCase): + """Test bias-based fallback for old models without metadata.""" + + def test_deep_eval_fallback(self) -> None: + """When model_def_script has no observed_type in info, fallback to bias-based.""" + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + config = json.load(f) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/single")] + config["training"]["training_data"]["systems"] = data_file + config["training"]["validation_data"]["systems"] = data_file + from .model.test_permutation import ( + model_se_e2_a, + ) + + config["model"] = deepcopy(model_se_e2_a) + 
config["model"]["type_map"] = ["O", "H", "Au"] + + from deepmd.infer import ( + DeepPot, + ) + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + from .common import ( + run_dp, + ) + + trainer = get_trainer(deepcopy(config)) + trainer.run() + run_dp("dp --pt freeze") + + model = DeepPot("frozen_model.pth") + # Simulate old model by removing observed_type from info + model.deep_eval.model_def_script.get("info", {}).pop("observed_type", None) + observed = model.deep_eval.get_observed_types() + # Should still work via bias-based fallback + self.assertIn("type_num", observed) + self.assertIn("observed_type", observed) + self.assertGreater(observed["type_num"], 0) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestPairTabObservedType(unittest.TestCase): + """Test observed_type collection for PairTabAtomicModel.""" + + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + from deepmd.pt.model.atomic_model import ( + PairTabAtomicModel, + ) + + # 3 types -> ntypes*(ntypes+1)/2 = 6 energy columns -> 7 total columns + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0, 1.5, 2.5, 3.5], + [0.01, 0.8, 1.6, 2.4, 1.2, 2.0, 2.8], + [0.015, 0.5, 1.0, 1.5, 0.75, 1.25, 1.75], + [0.02, 0.25, 0.4, 0.75, 0.35, 0.6, 0.9], + ] + ) + self.model = PairTabAtomicModel( + tab_file="dummy_path", rcut=0.02, sel=2, type_map=["H", "O", "Au"] + ) + self.tmpdir = tempfile.mkdtemp() + + def tearDown(self) -> None: + shutil.rmtree(self.tmpdir) + + def _make_sampled(self, atypes: list[list[list[int]]]) -> list[dict]: + """Create mock sampled data from atype arrays.""" + return [{"atype": torch.tensor(a, device="cpu")} for a in atypes] + + def test_compute_observed_type_from_data(self) -> None: + """PairTab should collect observed types from 
sampled data.""" + sampled = self._make_sampled([[[0, 1, 0, 1]]]) # H and O only + self.model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, mode="w"), + compute_or_load_out_stat=False, + ) + self.assertIsNotNone(self.model.observed_type) + self.assertIn("H", self.model.observed_type) + self.assertIn("O", self.model.observed_type) + self.assertNotIn("Au", self.model.observed_type) + + def test_preset_observed_type_takes_priority(self) -> None: + """Preset observed_type should override data-based computation.""" + sampled = self._make_sampled([[[0, 1]]]) # H and O in data + preset = ["H", "O", "Au"] + self.model.compute_or_load_stat( + lambda: sampled, + stat_file_path=DPPath(self.tmpdir, mode="w"), + compute_or_load_out_stat=False, + preset_observed_type=preset, + ) + self.assertEqual(self.model.observed_type, preset) + + +class TestDPZBLObservedType(unittest.TestCase): + """Test observed_type propagation in DPZBLLinearEnergyAtomicModel. + + The parent LinearEnergyAtomicModel computes observed type once, then + propagates it to sub-models via preset_observed_type to avoid redundant + computation. 
+ """ + + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + data_file = [str(Path(__file__).parent / "water/data/single")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + + from .model.test_permutation import ( + model_zbl, + ) + + self.config["model"] = deepcopy(model_zbl) + + def test_parent_observed_type_from_data(self) -> None: + """Parent (linear) model should collect observed types from data.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + observed = trainer.model.atomic_model.observed_type + self.assertIsNotNone(observed) + # Training data only has O and H (model_zbl type_map is ["O", "H", "B"]) + self.assertIn("H", observed) + self.assertIn("O", observed) + self.assertNotIn("B", observed) + + def test_submodels_get_propagated_observed_type(self) -> None: + """Sub-models should receive parent's observed type via propagation.""" + from deepmd.pt.entrypoints.main import ( + get_trainer, + ) + + trainer = get_trainer(deepcopy(self.config)) + trainer.run() + linear_model = trainer.model.atomic_model + dp_model = linear_model.models[0] + zbl_model = linear_model.models[1] + # All three should have the same observed type (propagated from parent) + self.assertEqual(dp_model.observed_type, linear_model.observed_type) + self.assertEqual(zbl_model.observed_type, linear_model.observed_type) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "frozen_model.pth", "output.txt", "checkpoint"]: + os.remove(f) + + +if __name__ == "__main__": + unittest.main()