diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..85a4fbea96 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -40,6 +40,18 @@ def __getitem__(self, index): b_data["natoms"] = self._natoms_vec return b_data + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames + def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" for data_item in data_requirement: diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..c28d1c3703 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -82,6 +82,46 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) + + all_elements = set() + if datasets and hasattr(datasets[0], "element_to_frames"): + all_elements.update(datasets[0].element_to_frames.keys()) + + collected_elements = set() + for sys_stat in lst: + if "atype" in sys_stat: + collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, "element_to_frames"): + frame_indices = dataset.element_to_frames.get( + missing_element, [] + ) + for frame_idx in frame_indices: + if len(lst[i]["atype"]) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if "atype" in sys_stat: + collected_elements.update( + np.unique(sys_stat["atype"].cpu().numpy()) + ) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance( + sys_stat[key][0], torch.Tensor + ): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + return lst