diff --git a/dpdata/system.py b/dpdata/system.py index a777ccb6..18b16aaf 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -370,10 +370,10 @@ def map_atom_types( _set2 = set(list(type_map.keys())) assert _set1.issubset(_set2) - atom_types_list = [] - for name, numb in zip(self.get_atom_names(), self.get_atom_numbs()): - atom_types_list.extend([name] * numb) - new_atom_types = np.array([type_map[ii] for ii in atom_types_list], dtype=int) + atom_names = self.get_atom_names() + new_atom_types = np.array( + [type_map[atom_names[ii]] for ii in self.data["atom_types"]], dtype=int + ) return new_atom_types diff --git a/tests/test_system_set_type.py b/tests/test_system_set_type.py index d8362ec7..58bc0cdc 100644 --- a/tests/test_system_set_type.py +++ b/tests/test_system_set_type.py @@ -6,6 +6,24 @@ from context import dpdata +class TestMapAtomTypes(unittest.TestCase): + def test_map_atom_types_preserves_current_atom_order(self): + data = { + "atom_names": ["O", "H"], + "atom_numbs": [1, 2], + "atom_types": np.array([1, 0, 1]), + "orig": np.zeros(3), + "cells": np.eye(3).reshape(1, 3, 3), + "coords": np.zeros((1, 3, 3)), + } + + system = dpdata.System(data=data) + + np.testing.assert_array_equal( + system.map_atom_types({"H": 0, "O": 1}), np.array([0, 1, 0]) + ) + + class TestSetAtomTypes(unittest.TestCase): def setUp(self): self.system_1 = dpdata.LabeledSystem("poscars/vasprun.h2o.md.10.xml")