From 139e559b126211f9e0084ed421fad4c30f68b95d Mon Sep 17 00:00:00 2001 From: kunjrathod2005 Date: Fri, 3 Jul 2026 12:10:44 -0700 Subject: [PATCH] fix: preserve atom order in type mapping --- dpdata/system.py | 8 ++++---- tests/test_system_set_type.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/dpdata/system.py b/dpdata/system.py index a777ccb6b..18b16aaf8 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -370,10 +370,10 @@ def map_atom_types( _set2 = set(list(type_map.keys())) assert _set1.issubset(_set2) - atom_types_list = [] - for name, numb in zip(self.get_atom_names(), self.get_atom_numbs()): - atom_types_list.extend([name] * numb) - new_atom_types = np.array([type_map[ii] for ii in atom_types_list], dtype=int) + atom_names = self.get_atom_names() + new_atom_types = np.array( + [type_map[atom_names[ii]] for ii in self.data["atom_types"]], dtype=int + ) return new_atom_types diff --git a/tests/test_system_set_type.py b/tests/test_system_set_type.py index d8362ec7b..58bc0cdcc 100644 --- a/tests/test_system_set_type.py +++ b/tests/test_system_set_type.py @@ -6,6 +6,24 @@ from context import dpdata +class TestMapAtomTypes(unittest.TestCase): + def test_map_atom_types_preserves_current_atom_order(self): + data = { + "atom_names": ["O", "H"], + "atom_numbs": [1, 2], + "atom_types": np.array([1, 0, 1]), + "orig": np.zeros(3), + "cells": np.eye(3).reshape(1, 3, 3), + "coords": np.zeros((1, 3, 3)), + } + + system = dpdata.System(data=data) + + np.testing.assert_array_equal( + system.map_atom_types({"H": 0, "O": 1}), np.array([0, 1, 0]) + ) + + class TestSetAtomTypes(unittest.TestCase): def setUp(self): self.system_1 = dpdata.LabeledSystem("poscars/vasprun.h2o.md.10.xml")