diff --git a/source/api_c/tests/test_deeppot_a_ptexpt.cc b/source/api_c/tests/test_deeppot_a_ptexpt.cc index 064b3b6c7b..59a9832d7b 100644 --- a/source/api_c/tests/test_deeppot_a_ptexpt.cc +++ b/source/api_c/tests/test_deeppot_a_ptexpt.cc @@ -169,16 +169,7 @@ TEST_F(TestInferDeepPotAPtExptC, numb_types_spin) { TEST_F(TestInferDeepPotAPtExptC, type_map) { const char* type_map = DP_DeepPotGetTypeMap(dp); - std::string type_map_str(type_map); - std::istringstream iss(type_map_str); - std::vector types; - std::string token; - while (iss >> token) { - types.push_back(token); - } - EXPECT_EQ(types.size(), 2); - EXPECT_EQ(types[0], "O"); - EXPECT_EQ(types[1], "H"); + EXPECT_STREQ(type_map, "O H"); DP_DeleteChar(type_map); } diff --git a/source/api_cc/src/DeepPotPD.cc b/source/api_cc/src/DeepPotPD.cc index 51d45a9182..49f409ca4d 100644 --- a/source/api_cc/src/DeepPotPD.cc +++ b/source/api_cc/src/DeepPotPD.cc @@ -677,6 +677,7 @@ void DeepPotPD::get_type_map(std::string& type_map) { std::vector type_map_arr(type_map_size, 0); type_map_tensor->CopyToCpu(type_map_arr.data()); + type_map.clear(); for (auto char_c : type_map_arr) { type_map += std::string(1, char_c); } diff --git a/source/api_cc/src/DeepPotPT.cc b/source/api_cc/src/DeepPotPT.cc index d69dbb8f82..641a375162 100644 --- a/source/api_cc/src/DeepPotPT.cc +++ b/source/api_cc/src/DeepPotPT.cc @@ -452,9 +452,12 @@ template void DeepPotPT::compute>( const bool atomic); void DeepPotPT::get_type_map(std::string& type_map) { auto ret = module.run_method("get_type_map").toList(); + type_map.clear(); for (const torch::IValue& element : ret) { - type_map += torch::str(element); // Convert each element to a string - type_map += " "; // Add a space between elements + if (!type_map.empty()) { + type_map += " "; + } + type_map += torch::str(element); } } diff --git a/source/api_cc/src/DeepPotPTExpt.cc b/source/api_cc/src/DeepPotPTExpt.cc index c8c1bfcfad..845e804b40 100644 --- a/source/api_cc/src/DeepPotPTExpt.cc +++ b/source/api_cc/src/DeepPotPTExpt.cc @@ -947,9 +947,12 @@ template void DeepPotPTExpt::compute>( const bool atomic); void DeepPotPTExpt::get_type_map(std::string& type_map_str) { + type_map_str.clear(); for (const auto& t : type_map) { + if (!type_map_str.empty()) { + type_map_str += " "; + } type_map_str += t; - type_map_str += " "; } } diff --git a/source/api_cc/src/DeepSpinPT.cc b/source/api_cc/src/DeepSpinPT.cc index 5add377045..aa4e05591d 100644 --- a/source/api_cc/src/DeepSpinPT.cc +++ b/source/api_cc/src/DeepSpinPT.cc @@ -476,9 +476,12 @@ template void DeepSpinPT::compute>( const bool atomic); void DeepSpinPT::get_type_map(std::string& type_map) { auto ret = module.run_method("get_type_map").toList(); + type_map.clear(); for (const torch::IValue& element : ret) { - type_map += torch::str(element); // Convert each element to a string - type_map += " "; // Add a space between elements + if (!type_map.empty()) { + type_map += " "; + } + type_map += torch::str(element); } } diff --git a/source/api_cc/src/DeepSpinPTExpt.cc b/source/api_cc/src/DeepSpinPTExpt.cc index 9edd51474b..f5870247f4 100644 --- a/source/api_cc/src/DeepSpinPTExpt.cc +++ b/source/api_cc/src/DeepSpinPTExpt.cc @@ -1028,9 +1028,12 @@ template void DeepSpinPTExpt::compute>( const bool atomic); void DeepSpinPTExpt::get_type_map(std::string& type_map_str) { + type_map_str.clear(); for (const auto& t : type_map) { + if (!type_map_str.empty()) { + type_map_str += " "; + } type_map_str += t; - type_map_str += " "; } } diff --git a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc index 7d9b04ddc0..a356151c76 100644 --- a/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc +++ b/source/api_cc/tests/test_deeppot_dpa_ptexpt_spin.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include "DeepSpin.h" @@ -96,6 +97,12 @@ TYPED_TEST(TestInferDeepSpinDpaPtExpt, test_get_use_spin) { EXPECT_FALSE(use_spin[2]); // H has no spin } +TYPED_TEST(TestInferDeepSpinDpaPtExpt, type_map) { + std::string type_map; + this->dp.get_type_map(type_map); + EXPECT_EQ(type_map, "Ni O H"); +} + TYPED_TEST(TestInferDeepSpinDpaPtExpt, cpu_build_nlist) { using VALUETYPE = TypeParam; const std::vector& coord = this->coord; diff --git a/source/api_cc/tests/test_deeppot_pd.cc b/source/api_cc/tests/test_deeppot_pd.cc index 6da9ee643b..bede9b11aa 100644 --- a/source/api_cc/tests/test_deeppot_pd.cc +++ b/source/api_cc/tests/test_deeppot_pd.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "DeepPot.h" @@ -89,6 +90,12 @@ class TestInferDeepPotAPd : public ::testing::Test { TYPED_TEST_SUITE(TestInferDeepPotAPd, PDValueTypes); +TYPED_TEST(TestInferDeepPotAPd, type_map) { + std::string type_map = "stale"; + this->dp.get_type_map(type_map); + EXPECT_EQ(type_map, "O H"); +} + TYPED_TEST(TestInferDeepPotAPd, cpu_build_nlist) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/api_cc/tests/test_deeppot_pt.cc b/source/api_cc/tests/test_deeppot_pt.cc index 144ee8da8c..7f527296b1 100644 --- a/source/api_cc/tests/test_deeppot_pt.cc +++ b/source/api_cc/tests/test_deeppot_pt.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "DeepPot.h" @@ -82,6 +83,12 @@ class TestInferDeepPotAPt : public ::testing::Test { TYPED_TEST_SUITE(TestInferDeepPotAPt, ValueTypes); +TYPED_TEST(TestInferDeepPotAPt, type_map) { + std::string type_map; + this->dp.get_type_map(type_map); + EXPECT_EQ(type_map, "O H"); +} + TYPED_TEST(TestInferDeepPotAPt, cpu_build_nlist) { using VALUETYPE = TypeParam; std::vector& coord = this->coord; diff --git a/source/api_cc/tests/test_deeppot_ptexpt.cc b/source/api_cc/tests/test_deeppot_ptexpt.cc index 4f90555839..cbf1303633 100644 --- a/source/api_cc/tests/test_deeppot_ptexpt.cc +++ b/source/api_cc/tests/test_deeppot_ptexpt.cc @@ -944,8 +944,7 @@ TYPED_TEST_SUITE(TestDeepPotPTExptMetadata, ValueTypes); TYPED_TEST(TestDeepPotPTExptMetadata, type_map) { std::string type_map; this->dp.get_type_map(type_map); - EXPECT_NE(type_map.find("O"), std::string::npos); - EXPECT_NE(type_map.find("H"), std::string::npos); + EXPECT_EQ(type_map, "O H"); } TYPED_TEST(TestDeepPotPTExptMetadata, cutoff) {