From 2748e1a8862a8c3fce38fdc42e31d8b8e1b38134 Mon Sep 17 00:00:00 2001 From: SumGuo Date: Mon, 23 Dec 2024 20:08:42 +0800 Subject: [PATCH 1/2] 4424fix --- deepmd/pt/utils/dataset.py | 12 ++++++++++++ deepmd/pt/utils/stat.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 3043839308..569cd29a94 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -39,6 +39,18 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data + + def _build_element_to_frames(self): + """Mapping element types to frame indexes""" + element_to_frames = {element: [] for element in range(self._ntypes)} + for frame_idx in range(len(self)): + frame_data = self._data_system.get_item_torch(frame_idx) + + elements = frame_data["atype"] + for element in set(elements): + if len(element_to_frames[element]) < 10: + element_to_frames[element].append(frame_idx) + return element_to_frames def add_data_requirement(self, data_requirement: list[DataRequirementItem]) -> None: """Add data requirement for this data system.""" diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 1c5e3f1c52..21240e5f90 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -82,6 +82,40 @@ def make_stat_input(datasets, dataloaders, nbatches): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) lst.append(sys_stat) + + all_elements = set() + if datasets and hasattr(datasets[0], 'element_to_frames'): + all_elements.update(datasets[0].element_to_frames.keys()) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + missing_elements = all_elements - collected_elements + + for missing_element in missing_elements: + for i, dataset in enumerate(datasets): + if hasattr(dataset, 'element_to_frames'): + frame_indices = dataset.element_to_frames.get(missing_element, []) + for frame_idx in frame_indices: + if len(lst[i]['atype']) >= nbatches: + break + frame_data = dataset[frame_idx] + for key in frame_data: + if key not in lst[i]: + lst[i][key] = [] + lst[i][key].append(frame_data[key]) + + collected_elements = set() + for sys_stat in lst: + if 'atype' in sys_stat: + collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + + for sys_stat in lst: + for key in sys_stat: + if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + sys_stat[key] = torch.cat(sys_stat[key], dim=0) + return lst From 52e8c345770473e2dba08af312938ac2e1d3e00f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 23 Dec 2024 12:14:04 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/dataset.py | 8 ++++---- deepmd/pt/utils/stat.py | 24 +++++++++++++++--------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py index 569cd29a94..85a4fbea96 100644 --- a/deepmd/pt/utils/dataset.py +++ b/deepmd/pt/utils/dataset.py @@ -39,16 +39,16 @@ def __getitem__(self, index): b_data = self._data_system.get_item_torch(index) b_data["natoms"] = self._natoms_vec return b_data - + def _build_element_to_frames(self): """Mapping element types to frame indexes""" - element_to_frames = {element: [] for element in range(self._ntypes)} + element_to_frames = {element: [] for element in range(self._ntypes)} for frame_idx in range(len(self)): frame_data = self._data_system.get_item_torch(frame_idx) - elements = frame_data["atype"] + elements = frame_data["atype"] for element in set(elements): - if len(element_to_frames[element]) < 10: + if len(element_to_frames[element]) < 10: element_to_frames[element].append(frame_idx) return element_to_frames diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 21240e5f90..c28d1c3703 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -84,21 +84,23 @@ def make_stat_input(datasets, dataloaders, nbatches): lst.append(sys_stat) all_elements = set() - if datasets and hasattr(datasets[0], 'element_to_frames'): + if datasets and hasattr(datasets[0], "element_to_frames"): all_elements.update(datasets[0].element_to_frames.keys()) collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy())) missing_elements = all_elements - collected_elements for missing_element in missing_elements: for i, dataset in enumerate(datasets): - if hasattr(dataset, 'element_to_frames'): - frame_indices = dataset.element_to_frames.get(missing_element, []) + if hasattr(dataset, "element_to_frames"): + frame_indices = dataset.element_to_frames.get( + missing_element, [] + ) for frame_idx in frame_indices: - if len(lst[i]['atype']) >= nbatches: + if len(lst[i]["atype"]) >= nbatches: break frame_data = dataset[frame_idx] for key in frame_data: @@ -108,12 +110,16 @@ def make_stat_input(datasets, dataloaders, nbatches): collected_elements = set() for sys_stat in lst: - if 'atype' in sys_stat: - collected_elements.update(np.unique(sys_stat['atype'].cpu().numpy())) + if "atype" in sys_stat: + collected_elements.update( + np.unique(sys_stat["atype"].cpu().numpy()) + ) for sys_stat in lst: for key in sys_stat: - if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor): + if isinstance(sys_stat[key], list) and isinstance( + sys_stat[key][0], torch.Tensor + ): sys_stat[key] = torch.cat(sys_stat[key], dim=0) return lst