From 73a2fdf1e797a92385a06834c759e3ef2868b71d Mon Sep 17 00:00:00 2001 From: Fei Yang <2501213217@stu.pku.edu.cn> Date: Tue, 30 Jun 2026 15:39:11 +0800 Subject: [PATCH 1/5] add mpi in neighbor_search --- .../module_neighlist/neighbor_search.cpp | 541 +++++++++++++++--- .../module_neighlist/neighbor_search.h | 59 +- .../module_neighlist/test/CMakeLists.txt | 23 +- .../test/neighbor_search_mpi_benchmark.cpp | 311 ++++++++++ .../test/neighbor_search_test.cpp | 148 ++++- source/source_esolver/esolver_lj.cpp | 140 ++++- 6 files changed, 1122 insertions(+), 100 deletions(-) create mode 100644 source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index 912515bf9d5..69e33cdf93b 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -3,6 +3,116 @@ #include #include #include +#include +#include +#include + +namespace +{ +struct OriginalAtom +{ + std::array frac; + int atom_type = 0; + int atom_index = 0; +}; + +struct PeriodicInterval +{ + double lo = 0.0; + double hi = 0.0; + int shift = 0; +}; + +struct FractionalDomain +{ + std::array lo; + std::array hi; +}; + +double dot_product(const ModuleBase::Vector3& a, const ModuleBase::Vector3& b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +ModuleBase::Vector3 cross_product(const ModuleBase::Vector3& a, + const ModuleBase::Vector3& b) +{ + return ModuleBase::Vector3(a.y * b.z - a.z * b.y, + a.z * b.x - a.x * b.z, + a.x * b.y - a.y * b.x); +} + +double norm(const ModuleBase::Vector3& v) +{ + return std::sqrt(dot_product(v, v)); +} + +double wrap_fractional(double value) +{ + value -= std::floor(value); + if (value >= 1.0 - 1.0e-12) + { + return 0.0; + } + if (value < 1.0e-12) + { + return 0.0; + } + return value; +} + +int clamp_index(int value, int low, int high) +{ + return std::min(std::max(value, low), high); +} + +int fractional_domain_index(double frac, int n) +{ + return clamp_index(static_cast(std::floor(frac * n)), 0, n - 1); +} + +long long bin_key(int ix, int iy, int iz, const std::array& nbin) +{ + return (static_cast(ix) * nbin[1] + iy) * nbin[2] + iz; +} + +std::vector split_periodic_interval(double lo, double hi) +{ + std::vector intervals; + if (hi <= lo) + { + return intervals; + } + + const int first_shift = static_cast(std::floor(lo)); + const int last_shift = static_cast(std::ceil(hi)) - 1; + for (int shift = first_shift; shift <= last_shift; ++shift) + { + const double local_lo = std::max(0.0, lo - shift); + const double local_hi = std::min(1.0, hi - shift); + if (local_lo < local_hi) + { + intervals.push_back({local_lo, local_hi, shift}); + } + } + return intervals; +} + +bool inside_interval(double value, double lo, double hi) +{ + return value >= lo && value < hi; +} + +bool inside_block(const OriginalAtom& atom, + const PeriodicInterval& bx, + const PeriodicInterval& by, + const PeriodicInterval& bz) +{ + return inside_interval(atom.frac[0], bx.lo, bx.hi) && + inside_interval(atom.frac[1], by.lo, by.hi) && + inside_interval(atom.frac[2], bz.lo, bz.hi); +} +} // namespace // ========== Getter methods ========== @@ -211,9 +321,238 @@ void NeighborSearch::set_member_variables(const AtomProvider& ucell) } } +void NeighborSearch::set_local_member_variables(const AtomProvider& ucell, + const InputAtoms& atoms, + int nx, + int ny, + int nz) +{ + all_atoms_.clear(); + inside_atoms_.clear(); + ghost_atoms_.clear(); + + ModuleBase::Vector3 vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13); + ModuleBase::Vector3 vec2(ucell.get_latvec().e21, ucell.get_latvec().e22, ucell.get_latvec().e23); + ModuleBase::Vector3 vec3(ucell.get_latvec().e31, ucell.get_latvec().e32, ucell.get_latvec().e33); + + const auto domain_index = [](double position, double low, double wide, int n) { + if (wide < coord_tolerance) + { + return std::abs(position - low) < coord_tolerance ? 0 : std::numeric_limits::max(); + } + return std::min(std::max(static_cast(std::floor((position - low) / wide)), 0), n - 1); + }; + + const double search_radius2 = search_radius_ * search_radius_; + for (int ix = -glayerX_minus_; ix < glayerX_; ix++) + { + for (int iy = -glayerY_minus_; iy < glayerY_; iy++) + { + for (int iz = -glayerZ_minus_; iz < glayerZ_; iz++) + { + const bool central_image = (ix == 0 && iy == 0 && iz == 0); + for (int it = 0; it < ucell.get_ntype(); it++) + { + for (int ia = 0; ia < ucell.get_na(it); ia++) + { + double atom_x = ucell.get_tau(it, ia).x + vec1[0] * ix + vec2[0] * iy + vec3[0] * iz; + double atom_y = ucell.get_tau(it, ia).y + vec1[1] * ix + vec2[1] * iy + vec3[1] * iz; + double atom_z = ucell.get_tau(it, ia).z + vec1[2] * ix + vec2[2] * iy + vec3[2] * iz; + + const int in_x = domain_index(atom_x, atoms.x_low, wide_x_, nx); + const int in_y = domain_index(atom_y, atoms.y_low, wide_y_, ny); + const int in_z = domain_index(atom_z, atoms.z_low, wide_z_, nz); + + const bool owned = central_image && + in_x == x_ && + in_y == y_ && + in_z == z_ && + atom_x <= atoms.x_high && + atom_y <= atoms.y_high && + atom_z <= atoms.z_high; + const bool ghost = !owned && + distance(atom_x, atom_y, atom_z, atoms.x_low, atoms.y_low, atoms.z_low) + <= search_radius2; + + if (!owned && !ghost) + { + continue; + } + + NeighborAtom atom(atom_x, atom_y, atom_z, it, ia, static_cast(all_atoms_.size())); + atom.is_inside = owned; + all_atoms_.push_back(atom); + if (owned) + { + inside_atoms_.push_back(atom); + } + else + { + ghost_atoms_.push_back(atom); + } + } + } + } + } + } +} + +void NeighborSearch::set_local_member_variables_by_halo(const AtomProvider& ucell, int nx, int ny, int nz) +{ + all_atoms_.clear(); + inside_atoms_.clear(); + ghost_atoms_.clear(); + + const ModuleBase::Matrix3& lat = ucell.get_latvec(); + const ModuleBase::Matrix3 inv_lat = lat.Inverse(); + + const ModuleBase::Vector3 a1(lat.e11, lat.e12, lat.e13); + const ModuleBase::Vector3 a2(lat.e21, lat.e22, lat.e23); + const ModuleBase::Vector3 a3(lat.e31, lat.e32, lat.e33); + + const ModuleBase::Vector3 a2xa3 = cross_product(a2, a3); + const ModuleBase::Vector3 a3xa1 = cross_product(a3, a1); + const ModuleBase::Vector3 a1xa2 = cross_product(a1, a2); + + const double volume = std::abs(dot_product(a1, a2xa3)); + assert(volume > coord_tolerance); + + const std::array heights = { + volume / norm(a2xa3), + volume / norm(a3xa1), + volume / norm(a1xa2) + }; + + std::array margin = { + search_radius_ / heights[0] + coord_tolerance, + search_radius_ / heights[1] + coord_tolerance, + search_radius_ / heights[2] + coord_tolerance + }; + + FractionalDomain domain{ + {static_cast(x_) / nx, static_cast(y_) / ny, static_cast(z_) / nz}, + {static_cast(x_ + 1) / nx, static_cast(y_ + 1) / ny, static_cast(z_ + 1) / nz} + }; + + FractionalDomain halo{ + {domain.lo[0] - margin[0], domain.lo[1] - margin[1], domain.lo[2] - margin[2]}, + {domain.hi[0] + margin[0], domain.hi[1] + margin[1], domain.hi[2] + margin[2]} + }; + + std::vector original_atoms; + original_atoms.reserve(ucell.get_natom()); + for (int it = 0; it < ucell.get_ntype(); ++it) + { + for (int ia = 0; ia < ucell.get_na(it); ++ia) + { + const ModuleBase::Vector3 cart = ucell.get_tau(it, ia); + const ModuleBase::Vector3 frac = cart * inv_lat; + original_atoms.push_back({ + {wrap_fractional(frac.x), wrap_fractional(frac.y), wrap_fractional(frac.z)}, + it, + ia + }); + } + } + + std::array nbin; + for (int idim = 0; idim < 3; ++idim) + { + nbin[idim] = std::max(1, static_cast(std::ceil(1.0 / std::max(margin[idim], coord_tolerance)))); + } + + std::unordered_map> bins; + bins.reserve(original_atoms.size()); + for (int iat = 0; iat < static_cast(original_atoms.size()); ++iat) + { + const OriginalAtom& atom = original_atoms[iat]; + const int ix = clamp_index(static_cast(std::floor(atom.frac[0] * nbin[0])), 0, nbin[0] - 1); + const int iy = clamp_index(static_cast(std::floor(atom.frac[1] * nbin[1])), 0, nbin[1] - 1); + const int iz = clamp_index(static_cast(std::floor(atom.frac[2] * nbin[2])), 0, nbin[2] - 1); + bins[bin_key(ix, iy, iz, nbin)].push_back(iat); + } + + const std::vector intervals_x = split_periodic_interval(halo.lo[0], halo.hi[0]); + const std::vector intervals_y = split_periodic_interval(halo.lo[1], halo.hi[1]); + const std::vector intervals_z = split_periodic_interval(halo.lo[2], halo.hi[2]); + + for (const PeriodicInterval& bx : intervals_x) + { + const int ix_begin = clamp_index(static_cast(std::floor(bx.lo * nbin[0])), 0, nbin[0] - 1); + const int ix_end = clamp_index(static_cast(std::ceil(bx.hi * nbin[0])) - 1, 0, nbin[0] - 1); + for (const PeriodicInterval& by : intervals_y) + { + const int iy_begin = clamp_index(static_cast(std::floor(by.lo * nbin[1])), 0, nbin[1] - 1); + const int iy_end = clamp_index(static_cast(std::ceil(by.hi * nbin[1])) - 1, 0, nbin[1] - 1); + for (const PeriodicInterval& bz : intervals_z) + { + const int iz_begin = clamp_index(static_cast(std::floor(bz.lo * nbin[2])), 0, nbin[2] - 1); + const int iz_end = clamp_index(static_cast(std::ceil(bz.hi * nbin[2])) - 1, 0, nbin[2] - 1); + + for (int ix = ix_begin; ix <= ix_end; ++ix) + { + for (int iy = iy_begin; iy <= iy_end; ++iy) + { + for (int iz = iz_begin; iz <= iz_end; ++iz) + { + const auto bin_iter = bins.find(bin_key(ix, iy, iz, nbin)); + if (bin_iter == bins.end()) + { + continue; + } + + for (const int atom_id : bin_iter->second) + { + const OriginalAtom& original = original_atoms[atom_id]; + if (!inside_block(original, bx, by, bz)) + { + continue; + } + + const bool central_image = bx.shift == 0 && by.shift == 0 && bz.shift == 0; + const bool owned = central_image && + fractional_domain_index(original.frac[0], nx) == x_ && + fractional_domain_index(original.frac[1], ny) == y_ && + fractional_domain_index(original.frac[2], nz) == z_; + + const ModuleBase::Vector3 frac_image(original.frac[0] + bx.shift, + original.frac[1] + by.shift, + original.frac[2] + bz.shift); + const ModuleBase::Vector3 cart_image = frac_image * lat; + + NeighborAtom atom(cart_image.x, + cart_image.y, + cart_image.z, + original.atom_type, + original.atom_index, + static_cast(all_atoms_.size())); + atom.is_inside = owned; + all_atoms_.push_back(atom); + if (owned) + { + inside_atoms_.push_back(atom); + } + else + { + ghost_atoms_.push_back(atom); + } + } + } + } + } + } + } + } +} + // ========== Main public interface ========== void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) +{ + this->init(ucell, sr, mpi_rank, 1); +} + +void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size) { // clear possible residual data from previous runs inside_atoms_.clear(); @@ -224,104 +563,87 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) search_radius_ = sr / ucell.get_lat0(); check_expand_condition(ucell); - set_member_variables(ucell); InputAtoms atoms = ucell_to_input_atoms(ucell); - int mpi_size = 1; + assert(mpi_size > 0); + assert(mpi_rank >= 0); + + const double span_x = atoms.x_high - atoms.x_low; + const double span_y = atoms.y_high - atoms.y_low; + const double span_z = atoms.z_high - atoms.z_low; + int nx, ny, nz; - decompose(mpi_size, nx, ny, nz); + decompose(mpi_size, span_x, span_y, span_z, nx, ny, nz); - z_ = mpi_rank / (nx * ny); - y_ = (mpi_rank % (nx * ny)) / nx; - x_ = mpi_rank % (nx * ny) % nx; + const int active_size = nx * ny * nz; + assert(active_size > 0); + assert(active_size <= mpi_size); - wide_x_ = (atoms.x_high - atoms.x_low) / nx; - wide_y_ = (atoms.y_high - atoms.y_low) / ny; - wide_z_ = (atoms.z_high - atoms.z_low) / nz; + wide_x_ = span_x / nx; + wide_y_ = span_y / ny; + wide_z_ = span_z / nz; assert(wide_x_ >= 0); assert(wide_y_ >= 0); assert(wide_z_ >= 0); - int in_x, in_y, in_z; + if (mpi_rank >= active_size) + { + x_ = -1; + y_ = -1; + z_ = -1; + neighbor_list_.initialize(0, neighbor_reserve_factor); + return; + } - for (size_t i = 0; i < all_atoms_.size(); i++) + z_ = mpi_rank / (nx * ny); + y_ = (mpi_rank % (nx * ny)) / nx; + x_ = mpi_rank % nx; + + if (mpi_size > 1) { - if(wide_x_ < coord_tolerance) - { - if(std::abs(all_atoms_[i].position_x - atoms.x_low) < coord_tolerance) - { - in_x = x_; - } - else - { - in_x = std::numeric_limits::max(); - } - } - else - { - in_x = std::min( - static_cast(std::floor((all_atoms_[i].position_x - atoms.x_low) / wide_x_)), - nx - 1 - ); - } - if(wide_y_ < coord_tolerance) - { - if(std::abs(all_atoms_[i].position_y - atoms.y_low) < coord_tolerance) - { - in_y = y_; - } - else + set_local_member_variables_by_halo(ucell, nx, ny, nz); + } + else + { + set_member_variables(ucell); + + int in_x, in_y, in_z; + const auto domain_index = [](double position, double low, double wide, int n) { + if (wide < coord_tolerance) { - in_y = std::numeric_limits::max(); + return std::abs(position - low) < coord_tolerance ? 0 : std::numeric_limits::max(); } - } - else - { - in_y = std::min( - static_cast(std::floor((all_atoms_[i].position_y - atoms.y_low) / wide_y_)), - ny - 1 - ); - } - if(wide_z_ < coord_tolerance) + return std::min(std::max(static_cast(std::floor((position - low) / wide)), 0), n - 1); + }; + + for (size_t i = 0; i < all_atoms_.size(); i++) { - if(std::abs(all_atoms_[i].position_z - atoms.z_low) < coord_tolerance) + in_x = domain_index(all_atoms_[i].position_x, atoms.x_low, wide_x_, nx); + in_y = domain_index(all_atoms_[i].position_y, atoms.y_low, wide_y_, ny); + in_z = domain_index(all_atoms_[i].position_z, atoms.z_low, wide_z_, nz); + + if (in_x == x_ && in_y == y_ && in_z == z_ && + all_atoms_[i].position_x <= atoms.x_high && + all_atoms_[i].position_y <= atoms.y_high && + all_atoms_[i].position_z <= atoms.z_high && + all_atoms_[i].is_inside) { - in_z = z_; + inside_atoms_.push_back(all_atoms_[i]); } - else + else if (distance( + all_atoms_[i].position_x, + all_atoms_[i].position_y, + all_atoms_[i].position_z, + atoms.x_low, + atoms.y_low, + atoms.z_low) <= search_radius_ * search_radius_) { - in_z = std::numeric_limits::max(); + ghost_atoms_.push_back(all_atoms_[i]); } } - else - { - in_z = std::min( - static_cast(std::floor((all_atoms_[i].position_z - atoms.z_low) / wide_z_)), - nz - 1 - ); - } - - if (in_x == x_ && in_y == y_ && in_z == z_ && - all_atoms_[i].position_x <= atoms.x_high && - all_atoms_[i].position_y <= atoms.y_high && - all_atoms_[i].position_z <= atoms.z_high && - all_atoms_[i].is_inside) - { - inside_atoms_.push_back(all_atoms_[i]); - } - else if (distance( - all_atoms_[i].position_x, - all_atoms_[i].position_y, - all_atoms_[i].position_z, - atoms.x_low, - atoms.y_low, - atoms.z_low) <= search_radius_ * search_radius_) - { - ghost_atoms_.push_back(all_atoms_[i]); - } } - neighbor_list_.initialize(inside_atoms_.size(), all_atoms_.size() * neighbor_reserve_factor); + neighbor_list_.initialize(inside_atoms_.size(), std::max(1, static_cast(all_atoms_.size()) * neighbor_reserve_factor)); } void NeighborSearch::build_neighbors() @@ -349,6 +671,7 @@ double NeighborSearch::distance( void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz) { + assert(mpi_size > 0); nx = 1; ny = 1; nz = mpi_size; @@ -374,4 +697,66 @@ void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz) break; } } -} \ No newline at end of file +} + +void NeighborSearch::decompose(int mpi_size, double span_x, double span_y, double span_z, int& nx, int& ny, int& nz) +{ + assert(mpi_size > 0); + + nx = 1; + ny = 1; + nz = 1; + + span_x = std::max(0.0, span_x); + span_y = std::max(0.0, span_y); + span_z = std::max(0.0, span_z); + + const bool can_split_x = span_x > coord_tolerance; + const bool can_split_y = span_y > coord_tolerance; + const bool can_split_z = span_z > coord_tolerance; + if (!can_split_x && !can_split_y && !can_split_z) + { + return; + } + + std::vector factors; + int remaining = mpi_size; + for (int factor = 2; factor * factor <= remaining; ++factor) + { + while (remaining % factor == 0) + { + factors.push_back(factor); + remaining /= factor; + } + } + if (remaining > 1) + { + factors.push_back(remaining); + } + std::sort(factors.rbegin(), factors.rend()); + + for (const int factor : factors) + { + int* best_dim = nullptr; + double best_score = -1.0; + + const auto try_dimension = [&](bool can_split, double span, int& dim) { + if (!can_split) + { + return; + } + const double score = span / dim; + if (score > best_score) + { + best_score = score; + best_dim = &dim; + } + }; + + try_dimension(can_split_x, span_x, nx); + try_dimension(can_split_y, span_y, ny); + try_dimension(can_split_z, span_z, nz); + assert(best_dim != nullptr); + *best_dim *= factor; + } +} diff --git a/source/source_cell/module_neighlist/neighbor_search.h b/source/source_cell/module_neighlist/neighbor_search.h index b75a8926d5c..4e6a1f06177 100644 --- a/source/source_cell/module_neighlist/neighbor_search.h +++ b/source/source_cell/module_neighlist/neighbor_search.h @@ -45,6 +45,19 @@ class NeighborSearch */ void init(const AtomProvider& ucell, double sr, int mpi_rank); + /** + * @brief Initialize the neighbor search with explicit MPI rank and size. + * + * This overload keeps the single-rank interface intact while allowing + * callers to decompose central atoms across MPI ranks. + * + * @param ucell Unit cell providing atom positions and lattice info. + * @param sr Search radius (cutoff distance) in Bohr. + * @param mpi_rank MPI rank of this process. + * @param mpi_size Total number of MPI processes. + */ + void init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size); + /** * @brief Build the neighbor list for all inside atoms. * @@ -100,6 +113,22 @@ class NeighborSearch */ void decompose(int mpi_size, int& nx, int& ny, int& nz); + /** + * @brief Decompose MPI ranks only along directions with nonzero span. + * + * Directions whose atom-coordinate span is zero are assigned one domain + * layer so ownership is not duplicated across that direction. + * + * @param mpi_size Total number of MPI processes. + * @param span_x Atom-coordinate span in X. + * @param span_y Atom-coordinate span in Y. + * @param span_z Atom-coordinate span in Z. + * @param nx Output: number of divisions in X. + * @param ny Output: number of divisions in Y. + * @param nz Output: number of divisions in Z. + */ + void decompose(int mpi_size, double span_x, double span_y, double span_z, int& nx, int& ny, int& nz); + // ========== Getter methods ========== /** @@ -252,6 +281,34 @@ class NeighborSearch */ void set_member_variables(const AtomProvider& ucell); + /** + * @brief Generate only atoms needed by the local MPI domain. + * + * The resulting all_atoms_ is a rank-local index space containing local + * inside atoms and cutoff-relevant ghost/image atoms. + * + * @param ucell Unit cell providing atom positions. + * @param atoms Original unit-cell atom bounds. + * @param nx Number of MPI divisions in X. + * @param ny Number of MPI divisions in Y. + * @param nz Number of MPI divisions in Z. + */ + void set_local_member_variables(const AtomProvider& ucell, const InputAtoms& atoms, int nx, int ny, int nz); + + /** + * @brief Generate local atoms by querying fractional-coordinate halo bins. + * + * Ownership is still defined only for atoms in the primary unit cell. Periodic + * images are generated only when they overlap the local cutoff halo and are + * stored as ghost atoms. + * + * @param ucell Unit cell providing atom positions. + * @param nx Number of MPI divisions in fractional X. + * @param ny Number of MPI divisions in fractional Y. + * @param nz Number of MPI divisions in fractional Z. + */ + void set_local_member_variables_by_halo(const AtomProvider& ucell, int nx, int ny, int nz); + /** * @brief Compute the norm of the cross product of two 3D vectors. * @@ -313,4 +370,4 @@ class NeighborSearch static constexpr int neighbor_reserve_factor = 2; }; -#endif // NEIGHBOR_SEARCH_H \ No newline at end of file +#endif // NEIGHBOR_SEARCH_H diff --git a/source/source_cell/module_neighlist/test/CMakeLists.txt b/source/source_cell/module_neighlist/test/CMakeLists.txt index ec6df835d5e..3ee10db0bcc 100644 --- a/source/source_cell/module_neighlist/test/CMakeLists.txt +++ b/source/source_cell/module_neighlist/test/CMakeLists.txt @@ -34,4 +34,25 @@ AddTest( ../page_allocator.cpp ) - +if(ENABLE_MPI) + add_executable(MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark + neighbor_search_mpi_benchmark.cpp + ../neighbor_search.cpp + ../bin_manager.cpp + ../page_allocator.cpp + ../unitcell_lite.cpp + ) + target_link_libraries(MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark + parameter ${math_libs} base device Threads::Threads MPI::MPI_CXX + ) + if(USE_OPENMP) + target_link_libraries(MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark OpenMP::OpenMP_CXX) + endif() + install(TARGETS MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark DESTINATION ${CMAKE_BINARY_DIR}/tests) + add_test(NAME MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark_np4 + COMMAND ${MPIEXEC_EXECUTABLE} ${MPIEXEC_NUMPROC_FLAG} 4 + $ + 12 12 12 2 1.75 1.0 0.2 1 + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + ) +endif() diff --git a/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp new file mode 100644 index 00000000000..57c9e2f2ebb --- /dev/null +++ b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp @@ -0,0 +1,311 @@ +#include "source_cell/module_neighlist/neighbor_search.h" +#include "source_cell/module_neighlist/unitcell_lite.h" + +#include + +#include +#include +#include +#include +#include +#include + +namespace +{ +int read_int_arg(int argc, char** argv, int index, int fallback) +{ + return argc <= index ? fallback : std::atoi(argv[index]); +} + +double read_double_arg(int argc, char** argv, int index, double fallback) +{ + return argc <= index ? fallback : std::atof(argv[index]); +} + +double cell_volume(const ModuleBase::Matrix3& latvec) +{ + const double cx = latvec.e22 * latvec.e33 - latvec.e23 * latvec.e32; + const double cy = latvec.e23 * latvec.e31 - latvec.e21 * latvec.e33; + const double cz = latvec.e21 * latvec.e32 - latvec.e22 * latvec.e31; + return std::abs(latvec.e11 * cx + latvec.e12 * cy + latvec.e13 * cz); +} + +ModuleBase::Vector3 direct_to_cartesian(const ModuleBase::Matrix3& latvec, + double fx, + double fy, + double fz) +{ + return ModuleBase::Vector3(fx * latvec.e11 + fy * latvec.e21 + fz * latvec.e31, + fx * latvec.e12 + fy * latvec.e22 + fz * latvec.e32, + fx * latvec.e13 + fy * latvec.e23 + fz * latvec.e33); +} + +UnitCellLite make_simple_lattice_ucell(int nx, int ny, int nz, double spacing, double skew) +{ + ModuleBase::Matrix3 latvec; + latvec.e11 = nx * spacing; + latvec.e12 = 0.0; + latvec.e13 = 0.0; + latvec.e21 = skew * ny * spacing; + latvec.e22 = ny * spacing; + latvec.e23 = 0.0; + latvec.e31 = 0.25 * skew * nz * spacing; + latvec.e32 = 0.5 * skew * nz * spacing; + latvec.e33 = nz * spacing; + + std::vector> tau; + tau.reserve(static_cast(nx) * ny * nz); + for (int ix = 0; ix < nx; ++ix) + { + for (int iy = 0; iy < ny; ++iy) + { + for (int iz = 0; iz < nz; ++iz) + { + tau.push_back(direct_to_cartesian(latvec, + static_cast(ix) / nx, + static_cast(iy) / ny, + static_cast(iz) / nz)); + } + } + } + + UnitCellLite ucell; + const double omega = cell_volume(latvec); + ucell.set_lattice(1.0, omega, latvec); + ucell.set_atoms(1, {static_cast(tau.size())}, tau); + return ucell; +} + +long long count_neighbor_pairs(const NeighborList& list) +{ + long long pairs = 0; + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + { + pairs += list.get_numneigh(local_i); + } + return pairs; +} + +long long square_sum(long long n) +{ + return n * (n - 1) * (2 * n - 1) / 6; +} +} // namespace + +int main(int argc, char** argv) +{ + MPI_Init(&argc, &argv); + + int mpi_rank = 0; + int mpi_size = 1; + MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank); + MPI_Comm_size(MPI_COMM_WORLD, &mpi_size); + + if (argc > 1 && std::string(argv[1]) == "--help") + { + if (mpi_rank == 0) + { + std::cout << "Usage: neighbor_search_mpi_benchmark [nx ny nz repeat cutoff spacing skew check_serial]\n" + << "Defaults: nx=16 ny=16 nz=16 repeat=5 cutoff=1.75 spacing=1.0 skew=0.0 check_serial=1\n"; + } + MPI_Finalize(); + return 0; + } + + const int nx = read_int_arg(argc, argv, 1, 16); + const int ny = read_int_arg(argc, argv, 2, 16); + const int nz = read_int_arg(argc, argv, 3, 16); + const int repeat = read_int_arg(argc, argv, 4, 5); + const double cutoff = read_double_arg(argc, argv, 5, 1.75); + const double spacing = read_double_arg(argc, argv, 6, 1.0); + const double skew = read_double_arg(argc, argv, 7, 0.0); + const int check_serial = read_int_arg(argc, argv, 8, 1); + + if (nx <= 0 || ny <= 0 || nz <= 0 || repeat <= 0 || cutoff <= 0.0 || spacing <= 0.0) + { + if (mpi_rank == 0) + { + std::cerr << "All dimensions, repeat, cutoff, and spacing must be positive.\n"; + } + MPI_Finalize(); + return 2; + } + + UnitCellLite ucell = make_simple_lattice_ucell(nx, ny, nz, spacing, skew); + + long long serial_all_atoms = -1; + long long serial_neighbor_pairs = -1; + double serial_init_time = 0.0; + double serial_build_time = 0.0; + if (mpi_rank == 0 && check_serial) + { + NeighborSearch serial; + const double t0 = MPI_Wtime(); + serial.init(ucell, cutoff, 0); + const double t1 = MPI_Wtime(); + serial.build_neighbors(); + const double t2 = MPI_Wtime(); + serial_all_atoms = static_cast(serial.get_all_atoms().size()); + serial_neighbor_pairs = count_neighbor_pairs(serial.get_neighbor_list()); + serial_init_time = t1 - t0; + serial_build_time = t2 - t1; + } + MPI_Bcast(&serial_all_atoms, 1, MPI_LONG_LONG, 0, MPI_COMM_WORLD); + MPI_Bcast(&serial_neighbor_pairs, 1, MPI_LONG_LONG, 0, MPI_COMM_WORLD); + + double init_time = 0.0; + double build_time = 0.0; + double total_time = 0.0; + long long last_inside = 0; + long long last_ghost = 0; + long long last_all = 0; + long long last_pairs = 0; + long long inside_index_sum = 0; + long long inside_index_square_sum = 0; + int local_failure = 0; + + for (int i = 0; i < repeat; ++i) + { + MPI_Barrier(MPI_COMM_WORLD); + const double t0 = MPI_Wtime(); + NeighborSearch ns; + ns.init(ucell, cutoff, mpi_rank, mpi_size); + const double t1 = MPI_Wtime(); + ns.build_neighbors(); + const double t2 = MPI_Wtime(); + + init_time += t1 - t0; + build_time += t2 - t1; + total_time += t2 - t0; + + if (i == repeat - 1) + { + const auto& inside_atoms = ns.get_inside_atoms(); + const auto& ghost_atoms = ns.get_ghost_atoms(); + const auto& all_atoms = ns.get_all_atoms(); + const auto& list = ns.get_neighbor_list(); + + last_inside = static_cast(inside_atoms.size()); + last_ghost = static_cast(ghost_atoms.size()); + last_all = static_cast(all_atoms.size()); + last_pairs = 0; + inside_index_sum = 0; + inside_index_square_sum = 0; + + for (size_t atom_id = 0; atom_id < all_atoms.size(); ++atom_id) + { + if (all_atoms[atom_id].atom_id != static_cast(atom_id)) + { + local_failure = 1; + } + } + + for (const NeighborAtom& atom : inside_atoms) + { + inside_index_sum += atom.atom_index; + inside_index_square_sum += static_cast(atom.atom_index) * atom.atom_index; + } + + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + { + last_pairs += list.get_numneigh(local_i); + for (int ad = 0; ad < list.get_numneigh(local_i); ++ad) + { + const int neighbor_id = list.get_firstneigh(local_i)[ad]; + if (neighbor_id < 0 || neighbor_id >= static_cast(all_atoms.size())) + { + local_failure = 1; + } + } + } + } + } + + long long global_inside = 0; + long long global_ghost = 0; + long long global_all = 0; + long long global_pairs = 0; + long long global_index_sum = 0; + long long global_index_square_sum = 0; + long long min_all = 0; + long long max_all = 0; + long long min_inside = 0; + long long max_inside = 0; + long long min_ghost = 0; + long long max_ghost = 0; + long long min_pairs = 0; + long long max_pairs = 0; + int global_failure = 0; + MPI_Allreduce(&last_inside, &global_inside, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&last_ghost, &global_ghost, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&last_all, &global_all, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&last_pairs, &global_pairs, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&inside_index_sum, &global_index_sum, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&inside_index_square_sum, &global_index_square_sum, 1, MPI_LONG_LONG, MPI_SUM, MPI_COMM_WORLD); + MPI_Allreduce(&last_all, &min_all, 1, MPI_LONG_LONG, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(&last_all, &max_all, 1, MPI_LONG_LONG, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(&last_inside, &min_inside, 1, MPI_LONG_LONG, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(&last_inside, &max_inside, 1, MPI_LONG_LONG, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(&last_ghost, &min_ghost, 1, MPI_LONG_LONG, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(&last_ghost, &max_ghost, 1, MPI_LONG_LONG, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(&last_pairs, &min_pairs, 1, MPI_LONG_LONG, MPI_MIN, MPI_COMM_WORLD); + MPI_Allreduce(&last_pairs, &max_pairs, 1, MPI_LONG_LONG, MPI_MAX, MPI_COMM_WORLD); + MPI_Allreduce(&local_failure, &global_failure, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD); + + double max_init_time = 0.0; + double max_build_time = 0.0; + double max_total_time = 0.0; + MPI_Reduce(&init_time, &max_init_time, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD); + MPI_Reduce(&build_time, &max_build_time, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD); + MPI_Reduce(&total_time, &max_total_time, 1, MPI_DOUBLE, MPI_MAX, 0, MPI_COMM_WORLD); + + const long long nat = static_cast(ucell.get_natom()); + const bool ownership_ok = global_inside == nat && + global_index_sum == nat * (nat - 1) / 2 && + global_index_square_sum == square_sum(nat); + const bool neighbor_pairs_ok = !check_serial || global_pairs == serial_neighbor_pairs; + const bool all_ok = ownership_ok && global_failure == 0 && neighbor_pairs_ok; + + if (mpi_rank == 0) + { + std::cout << "NeighborSearch MPI halo benchmark\n" + << "algorithm fractional_halo_bins\n" + << "np " << mpi_size << "\n" + << "atoms " << nat << "\n" + << "grid " << nx << " " << ny << " " << nz << "\n" + << "repeat " << repeat << "\n" + << "cutoff " << cutoff << "\n" + << "spacing " << spacing << "\n" + << "skew " << skew << "\n" + << "check_serial " << check_serial << "\n" + << "serial_all_atoms " << serial_all_atoms << "\n" + << "serial_neighbor_pairs " << serial_neighbor_pairs << "\n" + << "inside_sum " << global_inside << "\n" + << "inside_min " << min_inside << "\n" + << "inside_max " << max_inside << "\n" + << "ghost_sum " << global_ghost << "\n" + << "ghost_min " << min_ghost << "\n" + << "ghost_max " << max_ghost << "\n" + << "all_atoms_sum " << global_all << "\n" + << "all_atoms_min " << min_all << "\n" + << "all_atoms_max " << max_all << "\n" + << "neighbor_pairs_sum " << global_pairs << "\n" + << "neighbor_pairs_min " << min_pairs << "\n" + << "neighbor_pairs_max " << max_pairs << "\n" + << "time_serial_ref_init " << serial_init_time << "\n" + << "time_serial_ref_build " << serial_build_time << "\n" + << "time_serial_ref_total " << serial_init_time + serial_build_time << "\n" + << "time_init_max_total " << max_init_time << "\n" + << "time_build_max_total " << max_build_time << "\n" + << "time_total_max_total " << max_total_time << "\n" + << "time_init_max_avg " << max_init_time / repeat << "\n" + << "time_build_max_avg " << max_build_time / repeat << "\n" + << "time_total_max_avg " << max_total_time / repeat << "\n" + << "ownership_ok " << (ownership_ok ? 1 : 0) << "\n" + << "neighbor_pairs_ok " << (neighbor_pairs_ok ? 1 : 0) << "\n" + << "neighbor_ids_ok " << (global_failure == 0 ? 1 : 0) << "\n"; + } + + MPI_Finalize(); + return all_ok ? 0 : 1; +} diff --git a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp index b8a5bdf0ef3..f2bbe790b8a 100644 --- a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp +++ b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp @@ -1,6 +1,7 @@ #include #include "../neighbor_search.h" #include "../unitcell_lite.h" +#include // Helper function to create a simple UnitCellLite for testing static UnitCellLite make_test_ucell(double lat0, double omega, @@ -108,6 +109,26 @@ TEST(NeighborSearchUnit, DecomposePrimeNumber) EXPECT_EQ(nz, 13); } +TEST(NeighborSearchUnit, DecomposeSkipsZeroSpanDirections) +{ + NeighborSearch ns; + int nx, ny, nz; + + ns.decompose(8, 1.0, 1.0, 0.0, nx, ny, nz); + EXPECT_EQ(nx * ny * nz, 8); + EXPECT_EQ(nz, 1); + + ns.decompose(4, 4.0, 0.0, 0.0, nx, ny, nz); + EXPECT_EQ(nx, 4); + EXPECT_EQ(ny, 1); + EXPECT_EQ(nz, 1); + + ns.decompose(4, 0.0, 0.0, 0.0, nx, ny, nz); + EXPECT_EQ(nx, 1); + EXPECT_EQ(ny, 1); + EXPECT_EQ(nz, 1); +} + TEST(NeighborSearchUnit, NonOrthogonalLatticeExpand) { ModuleBase::Matrix3 latvec; @@ -169,6 +190,131 @@ TEST(NeighborSearchInit_MpiRankIndexing, RankValues) EXPECT_EQ(ns0.get_z(), 0); } +TEST(NeighborSearchInit_MpiOwnership, SingleAtomZeroSpanIsOwnedOnce) +{ + ModuleBase::Matrix3 latvec; + latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; + latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; + latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; + + UnitCellLite ucell = make_test_ucell( + 1.0, 1.0, latvec, 1, {1}, + {{0.0, 0.0, 0.0}} + ); + + size_t total_inside = 0; + for (int rank = 0; rank < 4; ++rank) + { + NeighborSearch ns; + ns.init(ucell, 0.1, rank, 4); + total_inside += ns.get_inside_atoms().size(); + EXPECT_EQ(ns.get_neighbor_list().get_nlocal(), static_cast(ns.get_inside_atoms().size())); + EXPECT_EQ(ns.get_inside_atoms().size(), rank == 0 ? 1U : 0U); + } + EXPECT_EQ(total_inside, 1U); +} + +TEST(NeighborSearchInit_MpiOwnership, SplitsOnlyNonzeroSpanDirection) +{ + ModuleBase::Matrix3 latvec; + latvec.e11 = 4; latvec.e12 = 0; latvec.e13 = 0; + latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; + latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; + + UnitCellLite ucell = make_test_ucell( + 1.0, 4.0, latvec, 1, {4}, + {{0.0, 0.0, 0.0}, + {1.0, 0.0, 0.0}, + {2.0, 0.0, 0.0}, + {3.0, 0.0, 0.0}} + ); + + size_t total_inside = 0; + for (int rank = 0; rank < 4; ++rank) + { + NeighborSearch ns; + ns.init(ucell, 0.1, rank, 4); + total_inside += ns.get_inside_atoms().size(); + EXPECT_EQ(ns.get_y(), 0); + EXPECT_EQ(ns.get_z(), 0); + EXPECT_EQ(ns.get_inside_atoms().size(), 1U); + } + EXPECT_EQ(total_inside, 4U); +} + +TEST(NeighborSearchInit_MpiLocalAtoms, LocalIdsAreValidAndAllAtomsShrink) +{ + ModuleBase::Matrix3 latvec; + latvec.e11 = 4; latvec.e12 = 0; latvec.e13 = 0; + latvec.e21 = 0; latvec.e22 = 4; latvec.e23 = 0; + latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 4; + + std::vector> tau; + for (int ix = 0; ix < 4; ++ix) + { + for (int iy = 0; iy < 4; ++iy) + { + for (int iz = 0; iz < 4; ++iz) + { + tau.emplace_back(ix, iy, iz); + } + } + } + + UnitCellLite ucell = make_test_ucell(1.0, 64.0, latvec, 1, {static_cast(tau.size())}, tau); + + NeighborSearch serial; + serial.init(ucell, 1.1, 0); + serial.build_neighbors(); + const size_t serial_all_atoms = serial.get_all_atoms().size(); + size_t serial_neighbor_pairs = 0; + for (int local_i = 0; local_i < serial.get_neighbor_list().get_nlocal(); ++local_i) + { + serial_neighbor_pairs += serial.get_neighbor_list().get_numneigh(local_i); + } + + size_t total_inside = 0; + size_t parallel_all_atoms_sum = 0; + size_t parallel_all_atoms_max = 0; + size_t parallel_neighbor_pairs = 0; + for (int rank = 0; rank < 4; ++rank) + { + NeighborSearch ns; + ns.init(ucell, 1.1, rank, 4); + ns.build_neighbors(); + + const auto& all_atoms = ns.get_all_atoms(); + const auto& list = ns.get_neighbor_list(); + total_inside += ns.get_inside_atoms().size(); + parallel_all_atoms_sum += all_atoms.size(); + parallel_all_atoms_max = std::max(parallel_all_atoms_max, all_atoms.size()); + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + { + parallel_neighbor_pairs += list.get_numneigh(local_i); + } + + for (size_t i = 0; i < all_atoms.size(); ++i) + { + EXPECT_EQ(all_atoms[i].atom_id, static_cast(i)); + } + + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + { + for (int ad = 0; ad < list.get_numneigh(local_i); ++ad) + { + const int neighbor_id = list.get_firstneigh(local_i)[ad]; + EXPECT_GE(neighbor_id, 0); + EXPECT_LT(neighbor_id, static_cast(all_atoms.size())); + } + } + } + + EXPECT_EQ(total_inside, tau.size()); + EXPECT_EQ(parallel_neighbor_pairs, serial_neighbor_pairs); + EXPECT_LT(parallel_all_atoms_max, serial_all_atoms); + EXPECT_LT(parallel_all_atoms_sum, serial_all_atoms * 4); +} + TEST(NeighborSearchDistance_OutsideCases, VariousAxes) { NeighborSearch ns; @@ -231,4 +377,4 @@ TEST(NeighborSearchUnit, ExpansionLayersAndAtomCount) EXPECT_EQ(static_cast(ns.get_all_atoms().size()), expected); } -// end of additional tests \ No newline at end of file +// end of additional tests diff --git a/source/source_esolver/esolver_lj.cpp b/source/source_esolver/esolver_lj.cpp index 3ddba86adc2..4997d42f2b2 100644 --- a/source/source_esolver/esolver_lj.cpp +++ b/source/source_esolver/esolver_lj.cpp @@ -5,7 +5,15 @@ #include "source_io/module_output/output_log.h" #include "source_io/module_output/cif_io.h" #include "source_cell/module_neighlist/neighbor_search.h" +#include "source_base/global_variable.h" +#include "source_base/timer.h" +#ifdef __MPI +#include "source_base/parallel_reduce.h" +#endif +#include +#include +#include namespace ModuleESolver @@ -57,29 +65,54 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) { UnitCellLite ucell_lite = change_from_ucell_to_ucell_lite(ucell); NeighborSearch neighbor_search; - neighbor_search.init(ucell_lite, search_radius, 0); - neighbor_search.build_neighbors(); - - double distance = 0.0; - int index = 0; // Important! potential, force, virial must be zero per step lj_potential = 0; lj_force.zero_out(); lj_virial.zero_out(); + double distance = 0.0; ModuleBase::Vector3 tau1, tau2, dtau; - const NeighborList& neighbor_list = neighbor_search.get_neighbor_list(); - const std::vector& all_atoms = neighbor_search.get_all_atoms(); - for (int it = 0; it < ucell.ntype; ++it) + +#ifdef __MPI + if (GlobalV::NPROC > 1) { - Atom* atom1 = &ucell.atoms[it]; - for (int ia = 0; ia < atom1->na; ++ia) + ModuleBase::timer::start("ESolverLJ", "mpi_total"); + ModuleBase::timer::start("ESolverLJ", "neigh_init"); + neighbor_search.init(ucell_lite, search_radius, GlobalV::MY_RANK, GlobalV::NPROC); + ModuleBase::timer::end("ESolverLJ", "neigh_init"); + ModuleBase::timer::start("ESolverLJ", "neigh_bld"); + neighbor_search.build_neighbors(); + ModuleBase::timer::end("ESolverLJ", "neigh_bld"); + + const NeighborList& neighbor_list = neighbor_search.get_neighbor_list(); + const std::vector& inside_atoms = neighbor_search.get_inside_atoms(); + const std::vector& all_atoms = neighbor_search.get_all_atoms(); + + std::vector atom_start(ucell.ntype + 1, 0); + for (int it = 0; it < ucell.ntype; ++it) { - tau1 = atom1->tau[ia]; - for (int ad = 0; ad < neighbor_list.get_numneigh(index); ++ad) + atom_start[it + 1] = atom_start[it] + ucell.atoms[it].na; + } + + std::vector potential_by_atom(ucell.nat, 0.0); + std::vector virial_by_atom(ucell.nat * 9, 0.0); + + ModuleBase::timer::start("ESolverLJ", "force_loc"); + for (int local_i = 0; local_i < neighbor_list.get_nlocal(); ++local_i) + { + const NeighborAtom& center_atom = inside_atoms[local_i]; + const int it = center_atom.atom_type; + const int ia = center_atom.atom_index; + const int global_i = atom_start[it] + ia; + + tau1.x = center_atom.position_x; + tau1.y = center_atom.position_y; + tau1.z = center_atom.position_z; + + for (int ad = 0; ad < neighbor_list.get_numneigh(local_i); ++ad) { - const NeighborAtom& neighbor_atom = all_atoms[neighbor_list.get_firstneigh(index)[ad]]; + const NeighborAtom& neighbor_atom = all_atoms[neighbor_list.get_firstneigh(local_i)[ad]]; tau2.x = neighbor_atom.position_x; tau2.y = neighbor_atom.position_y; tau2.z = neighbor_atom.position_z; @@ -88,16 +121,85 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) distance = dtau.norm(); if (distance < lj_rcut(it, it2)) { - lj_potential += LJ_energy(distance, it, it2) - en_shift(it, it2); + potential_by_atom[global_i] += LJ_energy(distance, it, it2) - en_shift(it, it2); ModuleBase::Vector3 f_ij = LJ_force(dtau, it, it2); - lj_force(index, 0) += f_ij.x; - lj_force(index, 1) += f_ij.y; - lj_force(index, 2) += f_ij.z; - LJ_virial(f_ij, dtau); + lj_force(global_i, 0) += f_ij.x; + lj_force(global_i, 1) += f_ij.y; + lj_force(global_i, 2) += f_ij.z; + for (int i = 0; i < 3; ++i) + { + for (int j = 0; j < 3; ++j) + { + virial_by_atom[global_i * 9 + i * 3 + j] += dtau[i] * f_ij[j]; + } + } } } - index++; } + ModuleBase::timer::end("ESolverLJ", "force_loc"); + + ModuleBase::timer::start("ESolverLJ", "reduce"); + Parallel_Reduce::reduce_all(potential_by_atom.data(), static_cast(potential_by_atom.size())); + Parallel_Reduce::reduce_all(lj_force.c, lj_force.nr * lj_force.nc); + Parallel_Reduce::reduce_all(virial_by_atom.data(), static_cast(virial_by_atom.size())); + ModuleBase::timer::end("ESolverLJ", "reduce"); + + for (int iat = 0; iat < ucell.nat; ++iat) + { + lj_potential += potential_by_atom[iat]; + for (int i = 0; i < 3; ++i) + { + for (int j = 0; j < 3; ++j) + { + lj_virial(i, j) += virial_by_atom[iat * 9 + i * 3 + j]; + } + } + } + ModuleBase::timer::end("ESolverLJ", "mpi_total"); + } + else +#endif + { + ModuleBase::timer::start("ESolverLJ", "serial_tot"); + ModuleBase::timer::start("ESolverLJ", "ser_neigh"); + neighbor_search.init(ucell_lite, search_radius, 0); + neighbor_search.build_neighbors(); + ModuleBase::timer::end("ESolverLJ", "ser_neigh"); + + int index = 0; + const NeighborList& neighbor_list = neighbor_search.get_neighbor_list(); + const std::vector& all_atoms = neighbor_search.get_all_atoms(); + ModuleBase::timer::start("ESolverLJ", "ser_force"); + for (int it = 0; it < ucell.ntype; ++it) + { + Atom* atom1 = &ucell.atoms[it]; + for (int ia = 0; ia < atom1->na; ++ia) + { + tau1 = atom1->tau[ia]; + for (int ad = 0; ad < neighbor_list.get_numneigh(index); ++ad) + { + const NeighborAtom& neighbor_atom = all_atoms[neighbor_list.get_firstneigh(index)[ad]]; + tau2.x = neighbor_atom.position_x; + tau2.y = neighbor_atom.position_y; + tau2.z = neighbor_atom.position_z; + int it2 = neighbor_atom.atom_type; + dtau = (tau1 - tau2) * ucell.lat0; + distance = dtau.norm(); + if (distance < lj_rcut(it, it2)) + { + lj_potential += LJ_energy(distance, it, it2) - en_shift(it, it2); + ModuleBase::Vector3 f_ij = LJ_force(dtau, it, it2); + lj_force(index, 0) += f_ij.x; + lj_force(index, 1) += f_ij.y; + lj_force(index, 2) += f_ij.z; + LJ_virial(f_ij, dtau); + } + } + index++; + } + } + ModuleBase::timer::end("ESolverLJ", "ser_force"); + ModuleBase::timer::end("ESolverLJ", "serial_tot"); } From aa64209c5d9fe2e85a12b878b08e61cedd2aecb6 Mon Sep 17 00:00:00 2001 From: Fei Yang <2501213217@stu.pku.edu.cn> Date: Tue, 30 Jun 2026 17:58:50 +0800 Subject: [PATCH 2/5] fix neighbor search build without mpi --- source/source_cell/module_neighlist/neighbor_search.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index 69e33cdf93b..b818d80abec 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -11,6 +11,11 @@ namespace { struct OriginalAtom { + OriginalAtom(const std::array& frac_in, int atom_type_in, int atom_index_in) + : frac(frac_in), atom_type(atom_type_in), atom_index(atom_index_in) + { + } + std::array frac; int atom_type = 0; int atom_index = 0; @@ -18,6 +23,10 @@ struct OriginalAtom struct PeriodicInterval { + PeriodicInterval(double lo_in, double hi_in, int shift_in) : lo(lo_in), hi(hi_in), shift(shift_in) + { + } + double lo = 0.0; double hi = 0.0; int shift = 0; From ed1f4b2af7751ceb4431f9ab3cdb1d6a2d777958 Mon Sep 17 00:00:00 2001 From: Fei Yang <2501213217@stu.pku.edu.cn> Date: Tue, 30 Jun 2026 18:28:37 +0800 Subject: [PATCH 3/5] make neighbor search initialization cxx11 compatible --- .../module_neighlist/neighbor_search.cpp | 62 +++++++++++++------ 1 file changed, 42 insertions(+), 20 deletions(-) diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index b818d80abec..09db6fa58fb 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -34,6 +34,11 @@ struct PeriodicInterval struct FractionalDomain { + FractionalDomain(const std::array& lo_in, const std::array& hi_in) + : lo(lo_in), hi(hi_in) + { + } + std::array lo; std::array hi; }; @@ -101,7 +106,7 @@ std::vector split_periodic_interval(double lo, double hi) const double local_hi = std::min(1.0, hi - shift); if (local_lo < local_hi) { - intervals.push_back({local_lo, local_hi, shift}); + intervals.push_back(PeriodicInterval(local_lo, local_hi, shift)); } } return intervals; @@ -123,6 +128,8 @@ bool inside_block(const OriginalAtom& atom, } } // namespace +constexpr double NeighborSearch::coord_tolerance; + // ========== Getter methods ========== double NeighborSearch::get_search_radius() const { @@ -426,27 +433,41 @@ void NeighborSearch::set_local_member_variables_by_halo(const AtomProvider& ucel const double volume = std::abs(dot_product(a1, a2xa3)); assert(volume > coord_tolerance); - const std::array heights = { + const std::array heights = {{ volume / norm(a2xa3), volume / norm(a3xa1), volume / norm(a1xa2) - }; + }}; - std::array margin = { + std::array margin = {{ search_radius_ / heights[0] + coord_tolerance, search_radius_ / heights[1] + coord_tolerance, search_radius_ / heights[2] + coord_tolerance - }; - - FractionalDomain domain{ - {static_cast(x_) / nx, static_cast(y_) / ny, static_cast(z_) / nz}, - {static_cast(x_ + 1) / nx, static_cast(y_ + 1) / ny, static_cast(z_ + 1) / nz} - }; - - FractionalDomain halo{ - {domain.lo[0] - margin[0], domain.lo[1] - margin[1], domain.lo[2] - margin[2]}, - {domain.hi[0] + margin[0], domain.hi[1] + margin[1], domain.hi[2] + margin[2]} - }; + }}; + + const std::array domain_lo = {{ + static_cast(x_) / nx, + static_cast(y_) / ny, + static_cast(z_) / nz + }}; + const std::array domain_hi = {{ + static_cast(x_ + 1) / nx, + static_cast(y_ + 1) / ny, + static_cast(z_ + 1) / nz + }}; + const FractionalDomain domain(domain_lo, domain_hi); + + const std::array halo_lo = {{ + domain.lo[0] - margin[0], + domain.lo[1] - margin[1], + domain.lo[2] - margin[2] + }}; + const std::array halo_hi = {{ + domain.hi[0] + margin[0], + domain.hi[1] + margin[1], + domain.hi[2] + margin[2] + }}; + const FractionalDomain halo(halo_lo, halo_hi); std::vector original_atoms; original_atoms.reserve(ucell.get_natom()); @@ -456,11 +477,12 @@ void NeighborSearch::set_local_member_variables_by_halo(const AtomProvider& ucel { const ModuleBase::Vector3 cart = ucell.get_tau(it, ia); const ModuleBase::Vector3 frac = cart * inv_lat; - original_atoms.push_back({ - {wrap_fractional(frac.x), wrap_fractional(frac.y), wrap_fractional(frac.z)}, - it, - ia - }); + const std::array wrapped_frac = {{ + wrap_fractional(frac.x), + wrap_fractional(frac.y), + wrap_fractional(frac.z) + }}; + original_atoms.push_back(OriginalAtom(wrapped_frac, it, ia)); } } From 09dc1e6bd30e761f010a7372a268c1d4b7835504 Mon Sep 17 00:00:00 2001 From: Fei Yang <2501213217@stu.pku.edu.cn> Date: Tue, 30 Jun 2026 20:36:11 +0800 Subject: [PATCH 4/5] add distributed neighbor decomposition --- source/Makefile.Objects | 3 + .../module_neighlist/CMakeLists.txt | 16 +- .../module_neighlist/domain_decomposition.cpp | 387 ++++++++++++++++++ .../module_neighlist/domain_decomposition.h | 91 ++++ .../source_cell/module_neighlist/local_atom.h | 53 +++ .../module_neighlist/neighbor_atom.h | 31 +- .../module_neighlist/neighbor_search.cpp | 56 +++ .../module_neighlist/neighbor_search.h | 17 + .../module_neighlist/test/CMakeLists.txt | 1 + .../test/neighbor_search_mpi_benchmark.cpp | 9 +- .../test/neighbor_search_test.cpp | 33 ++ source/source_esolver/esolver_lj.cpp | 9 +- 12 files changed, 698 insertions(+), 8 deletions(-) create mode 100644 source/source_cell/module_neighlist/domain_decomposition.cpp create mode 100644 source/source_cell/module_neighlist/domain_decomposition.h create mode 100644 source/source_cell/module_neighlist/local_atom.h diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 4f4c105781a..e11fec13efe 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -415,6 +415,9 @@ OBJS_NEIGHBOR=sltk_atom.o\ OBJS_NEIGHBOR_SEARCH=neighbor_search.o\ bin_manager.o\ + domain_decomposition.o\ + page_allocator.o\ + unitcell_lite.o\ OBJS_ORBITAL=ORB_atomic.o\ diff --git a/source/source_cell/module_neighlist/CMakeLists.txt b/source/source_cell/module_neighlist/CMakeLists.txt index b40a0d4de29..b8ec7f2d055 100644 --- a/source/source_cell/module_neighlist/CMakeLists.txt +++ b/source/source_cell/module_neighlist/CMakeLists.txt @@ -1,12 +1,20 @@ -add_library( - neighbor_search - OBJECT +set(neighbor_search_sources bin_manager.cpp neighbor_search.cpp page_allocator.cpp unitcell_lite.cpp ) +if(ENABLE_MPI) + list(APPEND neighbor_search_sources domain_decomposition.cpp) +endif() + +add_library( + neighbor_search + OBJECT + ${neighbor_search_sources} +) + if(ENABLE_COVERAGE) add_coverage(neighbor_search) endif() @@ -15,4 +23,4 @@ if(BUILD_TESTING) if(ENABLE_MPI) add_subdirectory(test) endif() -endif() \ No newline at end of file +endif() diff --git a/source/source_cell/module_neighlist/domain_decomposition.cpp b/source/source_cell/module_neighlist/domain_decomposition.cpp new file mode 100644 index 00000000000..974196829fd --- /dev/null +++ b/source/source_cell/module_neighlist/domain_decomposition.cpp @@ -0,0 +1,387 @@ +#include "source_cell/module_neighlist/domain_decomposition.h" + +#include +#include +#include + +DomainDecomposition::DomainDecomposition() + : comm_(MPI_COMM_NULL), + cart_comm_(MPI_COMM_NULL), + owns_cart_comm_(false), + rank_(0), + size_(1), + dims_(), + coords_(), + margin_(), + latvec_(), + inv_latvec_(), + lat0_(1.0), + cutoff_(0.0), + skin_(0.0) +{ + dims_[0] = dims_[1] = dims_[2] = 1; + coords_[0] = coords_[1] = coords_[2] = 0; + margin_[0] = margin_[1] = margin_[2] = 0.0; +} + +DomainDecomposition::~DomainDecomposition() +{ + if (owns_cart_comm_ && cart_comm_ != MPI_COMM_NULL) + { + MPI_Comm_free(&cart_comm_); + } +} + +double DomainDecomposition::wrap_fractional(double value) +{ + value -= std::floor(value); + if (value >= 1.0 - 1.0e-12) + { + return 0.0; + } + if (value < 1.0e-12) + { + return 0.0; + } + return value; +} + +int DomainDecomposition::floor_div(int value, int divisor) +{ + int quotient = value / divisor; + const int remainder = value % divisor; + if (remainder != 0 && ((remainder < 0) != (divisor < 0))) + { + --quotient; + } + return quotient; +} + +int DomainDecomposition::positive_mod(int value, int divisor) +{ + int result = value % divisor; + if (result < 0) + { + result += divisor; + } + return result; +} + +double DomainDecomposition::dot_product(const ModuleBase::Vector3& a, + const ModuleBase::Vector3& b) +{ + return a.x * b.x + a.y * b.y + a.z * b.z; +} + +ModuleBase::Vector3 DomainDecomposition::cross_product(const ModuleBase::Vector3& a, + const ModuleBase::Vector3& b) +{ + return ModuleBase::Vector3(a.y * b.z - a.z * b.y, + a.z * b.x - a.x * b.z, + a.x * b.y - a.y * b.x); +} + +double DomainDecomposition::norm(const ModuleBase::Vector3& value) +{ + return std::sqrt(dot_product(value, value)); +} + +void DomainDecomposition::init(MPI_Comm comm, + const ModuleBase::Matrix3& latvec, + double lat0, + double cutoff, + double skin) +{ + comm_ = comm; + MPI_Comm_rank(comm_, &rank_); + MPI_Comm_size(comm_, &size_); + + latvec_ = latvec; + inv_latvec_ = latvec_.Inverse(); + lat0_ = lat0; + cutoff_ = cutoff; + skin_ = skin; + + int dims[3] = {0, 0, 0}; + MPI_Dims_create(size_, 3, dims); + dims_[0] = std::max(1, dims[0]); + dims_[1] = std::max(1, dims[1]); + dims_[2] = std::max(1, dims[2]); + + int periods[3] = {1, 1, 1}; + if (owns_cart_comm_ && cart_comm_ != MPI_COMM_NULL) + { + MPI_Comm_free(&cart_comm_); + cart_comm_ = MPI_COMM_NULL; + owns_cart_comm_ = false; + } + MPI_Cart_create(comm_, 3, dims, periods, 0, &cart_comm_); + owns_cart_comm_ = cart_comm_ != MPI_COMM_NULL; + MPI_Comm_rank(cart_comm_, &rank_); + int coords[3] = {0, 0, 0}; + MPI_Cart_coords(cart_comm_, rank_, 3, coords); + coords_[0] = coords[0]; + coords_[1] = coords[1]; + coords_[2] = coords[2]; + + const ModuleBase::Vector3 a1(latvec_.e11, latvec_.e12, latvec_.e13); + const ModuleBase::Vector3 a2(latvec_.e21, latvec_.e22, latvec_.e23); + const ModuleBase::Vector3 a3(latvec_.e31, latvec_.e32, latvec_.e33); + const ModuleBase::Vector3 a2xa3 = cross_product(a2, a3); + const ModuleBase::Vector3 a3xa1 = cross_product(a3, a1); + const ModuleBase::Vector3 a1xa2 = cross_product(a1, a2); + + const double volume = std::abs(dot_product(a1, a2xa3)); + const double heights[3] = { + volume / norm(a2xa3), + volume / norm(a3xa1), + volume / norm(a1xa2) + }; + const double cutoff_lat0 = (cutoff_ + skin_) / lat0_; + for (int idim = 0; idim < 3; ++idim) + { + margin_[idim] = cutoff_lat0 / heights[idim] + 1.0e-12; + } +} + +const std::array& DomainDecomposition::dims() const +{ + return dims_; +} + +const std::array& DomainDecomposition::coords() const +{ + return coords_; +} + +int DomainDecomposition::rank() const +{ + return rank_; +} + +int DomainDecomposition::size() const +{ + return size_; +} + +ModuleBase::Vector3 DomainDecomposition::wrapped_frac_from_cart( + const ModuleBase::Vector3& cart) const +{ + const ModuleBase::Vector3 frac = cart * inv_latvec_; + return ModuleBase::Vector3(wrap_fractional(frac.x), + wrap_fractional(frac.y), + wrap_fractional(frac.z)); +} + +int DomainDecomposition::rank_from_coords(const std::array& coords) const +{ + int raw_coords[3] = {coords[0], coords[1], coords[2]}; + int rank = 0; + MPI_Cart_rank(cart_comm_, raw_coords, &rank); + return rank; +} + +int DomainDecomposition::owner_rank_from_frac(const ModuleBase::Vector3& frac) const +{ + std::array owner_coords; + const double values[3] = { + wrap_fractional(frac.x), + wrap_fractional(frac.y), + wrap_fractional(frac.z) + }; + for (int idim = 0; idim < 3; ++idim) + { + int index = static_cast(std::floor(values[idim] * dims_[idim])); + index = std::min(std::max(index, 0), dims_[idim] - 1); + owner_coords[idim] = index; + } + return rank_from_coords(owner_coords); +} + +void DomainDecomposition::split_owned_atoms_from_ucell(const AtomProvider& ucell, + std::vector& owned_atoms) const +{ + owned_atoms.clear(); + owned_atoms.reserve(static_cast(ucell.get_natom() / std::max(1, size_) + 1)); + + long long global_id = 0; + for (int it = 0; it < ucell.get_ntype(); ++it) + { + for (int ia = 0; ia < ucell.get_na(it); ++ia) + { + const ModuleBase::Vector3 original_cart = ucell.get_tau(it, ia); + const ModuleBase::Vector3 frac = wrapped_frac_from_cart(original_cart); + const int owner = owner_rank_from_frac(frac); + if (owner == rank_) + { + const ModuleBase::Vector3 wrapped_cart = frac * latvec_; + owned_atoms.push_back(LocalAtom(wrapped_cart, frac, it, ia, global_id, owner, false)); + } + ++global_id; + } + } +} + +void DomainDecomposition::target_for_offset(const std::array& offset, + std::array& target_coords, + std::array& image_shift) const +{ + for (int idim = 0; idim < 3; ++idim) + { + const int unwrapped = coords_[idim] + offset[idim]; + const int period_shift = floor_div(unwrapped, dims_[idim]); + target_coords[idim] = positive_mod(unwrapped, dims_[idim]); + image_shift[idim] = -period_shift; + } +} + +bool DomainDecomposition::atom_overlaps_target_halo( + const LocalAtom& atom, + const std::array& target_coords, + const std::array& image_shift) const +{ + const double frac_values[3] = { + atom.frac.x + image_shift[0], + atom.frac.y + image_shift[1], + atom.frac.z + image_shift[2] + }; + for (int idim = 0; idim < 3; ++idim) + { + const double lo = static_cast(target_coords[idim]) / dims_[idim]; + const double hi = static_cast(target_coords[idim] + 1) / dims_[idim]; + if (frac_values[idim] < lo - margin_[idim] || + frac_values[idim] >= hi + margin_[idim]) + { + return false; + } + } + return true; +} + +int DomainDecomposition::neighbor_layer(int dim) const +{ + return std::max(1, static_cast(std::ceil(margin_[dim] * dims_[dim]))); +} + +DomainDecomposition::PackedAtom DomainDecomposition::pack_atom( + const LocalAtom& atom, + const std::array& image_shift) const +{ + PackedAtom packed; + packed.frac[0] = atom.frac.x; + packed.frac[1] = atom.frac.y; + packed.frac[2] = atom.frac.z; + packed.image_shift[0] = image_shift[0]; + packed.image_shift[1] = image_shift[1]; + packed.image_shift[2] = image_shift[2]; + packed.type = atom.type; + packed.type_index = atom.type_index; + packed.global_id = atom.global_id; + packed.owner_rank = atom.owner_rank; + return packed; +} + +LocalAtom DomainDecomposition::unpack_ghost_atom(const PackedAtom& packed) const +{ + const ModuleBase::Vector3 frac(packed.frac[0], packed.frac[1], packed.frac[2]); + const ModuleBase::Vector3 image_frac(packed.frac[0] + packed.image_shift[0], + packed.frac[1] + packed.image_shift[1], + packed.frac[2] + packed.image_shift[2]); + const ModuleBase::Vector3 cart = image_frac * latvec_; + return LocalAtom(cart, + frac, + packed.type, + packed.type_index, + packed.global_id, + packed.owner_rank, + true); +} + +void DomainDecomposition::exchange_ghost_atoms(const std::vector& owned_atoms, + std::vector& ghost_atoms) const +{ + ghost_atoms.clear(); + + const int nlayer_x = neighbor_layer(0); + const int nlayer_y = neighbor_layer(1); + const int nlayer_z = neighbor_layer(2); + + for (int dx = -nlayer_x; dx <= nlayer_x; ++dx) + { + for (int dy = -nlayer_y; dy <= nlayer_y; ++dy) + { + for (int dz = -nlayer_z; dz <= nlayer_z; ++dz) + { + if (dx == 0 && dy == 0 && dz == 0) + { + continue; + } + + std::array offset = {{dx, dy, dz}}; + std::array recv_offset = {{-dx, -dy, -dz}}; + std::array target_coords; + std::array image_shift; + std::array recv_coords; + std::array recv_image_shift; + target_for_offset(offset, target_coords, image_shift); + target_for_offset(recv_offset, recv_coords, recv_image_shift); + const int send_rank = rank_from_coords(target_coords); + const int recv_rank = rank_from_coords(recv_coords); + + std::vector send_atoms; + for (size_t iat = 0; iat < owned_atoms.size(); ++iat) + { + if (atom_overlaps_target_halo(owned_atoms[iat], target_coords, image_shift)) + { + send_atoms.push_back(pack_atom(owned_atoms[iat], image_shift)); + } + } + + if (send_rank == rank_ && recv_rank == rank_) + { + for (size_t i = 0; i < send_atoms.size(); ++i) + { + ghost_atoms.push_back(unpack_ghost_atom(send_atoms[i])); + } + continue; + } + + int send_count = static_cast(send_atoms.size()); + int recv_count = 0; + MPI_Sendrecv(&send_count, + 1, + MPI_INT, + send_rank, + 9100, + &recv_count, + 1, + MPI_INT, + recv_rank, + 9100, + cart_comm_, + MPI_STATUS_IGNORE); + + std::vector recv_atoms(static_cast(recv_count)); + const int send_bytes = static_cast(send_atoms.size() * sizeof(PackedAtom)); + const int recv_bytes = static_cast(recv_atoms.size() * sizeof(PackedAtom)); + MPI_Sendrecv(send_atoms.empty() ? NULL : &send_atoms[0], + send_bytes, + MPI_BYTE, + send_rank, + 9101, + recv_atoms.empty() ? NULL : &recv_atoms[0], + recv_bytes, + MPI_BYTE, + recv_rank, + 9101, + cart_comm_, + MPI_STATUS_IGNORE); + + for (size_t i = 0; i < recv_atoms.size(); ++i) + { + ghost_atoms.push_back(unpack_ghost_atom(recv_atoms[i])); + } + } + } + } +} diff --git a/source/source_cell/module_neighlist/domain_decomposition.h b/source/source_cell/module_neighlist/domain_decomposition.h new file mode 100644 index 00000000000..3585c9b1c38 --- /dev/null +++ b/source/source_cell/module_neighlist/domain_decomposition.h @@ -0,0 +1,91 @@ +#ifndef DOMAIN_DECOMPOSITION_H +#define DOMAIN_DECOMPOSITION_H + +#include "source_cell/module_neighlist/atom_provider.h" +#include "source_cell/module_neighlist/local_atom.h" + +#include +#include + +#include + +/** + * @brief MPI domain decomposition for distributed neighbor-search input. + * + * The decomposition is performed in fractional coordinates. Owned atoms are + * selected by wrapped fractional position, and ghost atoms are exchanged as + * shifted periodic images. + */ +class DomainDecomposition +{ +public: + DomainDecomposition(); + ~DomainDecomposition(); + + void init(MPI_Comm comm, + const ModuleBase::Matrix3& latvec, + double lat0, + double cutoff, + double skin); + + int owner_rank_from_frac(const ModuleBase::Vector3& frac) const; + + void split_owned_atoms_from_ucell(const AtomProvider& ucell, + std::vector& owned_atoms) const; + + void exchange_ghost_atoms(const std::vector& owned_atoms, + std::vector& ghost_atoms) const; + + const std::array& dims() const; + const std::array& coords() const; + int rank() const; + int size() const; + +private: + struct PackedAtom + { + double frac[3]; + int image_shift[3]; + int type; + int type_index; + long long global_id; + int owner_rank; + }; + + MPI_Comm comm_; + MPI_Comm cart_comm_; + bool owns_cart_comm_; + int rank_; + int size_; + std::array dims_; + std::array coords_; + std::array margin_; + ModuleBase::Matrix3 latvec_; + ModuleBase::Matrix3 inv_latvec_; + double lat0_; + double cutoff_; + double skin_; + + static double wrap_fractional(double value); + static int floor_div(int value, int divisor); + static int positive_mod(int value, int divisor); + static double dot_product(const ModuleBase::Vector3& a, + const ModuleBase::Vector3& b); + static ModuleBase::Vector3 cross_product(const ModuleBase::Vector3& a, + const ModuleBase::Vector3& b); + static double norm(const ModuleBase::Vector3& value); + + ModuleBase::Vector3 wrapped_frac_from_cart(const ModuleBase::Vector3& cart) const; + int rank_from_coords(const std::array& coords) const; + void target_for_offset(const std::array& offset, + std::array& target_coords, + std::array& image_shift) const; + bool atom_overlaps_target_halo(const LocalAtom& atom, + const std::array& target_coords, + const std::array& image_shift) const; + int neighbor_layer(int dim) const; + PackedAtom pack_atom(const LocalAtom& atom, const std::array& image_shift) const; + LocalAtom unpack_ghost_atom(const PackedAtom& packed) const; +}; + +#endif // DOMAIN_DECOMPOSITION_H diff --git a/source/source_cell/module_neighlist/local_atom.h b/source/source_cell/module_neighlist/local_atom.h new file mode 100644 index 00000000000..be9804a624e --- /dev/null +++ b/source/source_cell/module_neighlist/local_atom.h @@ -0,0 +1,53 @@ +#ifndef LOCAL_ATOM_H +#define LOCAL_ATOM_H + +#include "source_base/vector3.h" + +/** + * @brief Atom record owned by a distributed neighbor-search rank. + * + * cart is in lattice-coordinate units, matching UnitCell::tau and the existing + * NeighborSearch implementation. frac is wrapped into [0, 1) for owned atoms. + * Ghost atoms may have shifted cartesian coordinates while retaining the + * original wrapped frac coordinate for ownership metadata. + */ +struct LocalAtom +{ + ModuleBase::Vector3 cart; + ModuleBase::Vector3 frac; + int type; + int type_index; + long long global_id; + int owner_rank; + bool is_ghost; + + LocalAtom() + : cart(0.0, 0.0, 0.0), + frac(0.0, 0.0, 0.0), + type(0), + type_index(0), + global_id(-1), + owner_rank(0), + is_ghost(false) + { + } + + LocalAtom(const ModuleBase::Vector3& cart_in, + const ModuleBase::Vector3& frac_in, + int type_in, + int type_index_in, + long long global_id_in, + int owner_rank_in, + bool is_ghost_in) + : cart(cart_in), + frac(frac_in), + type(type_in), + type_index(type_index_in), + global_id(global_id_in), + owner_rank(owner_rank_in), + is_ghost(is_ghost_in) + { + } +}; + +#endif // LOCAL_ATOM_H diff --git a/source/source_cell/module_neighlist/neighbor_atom.h b/source/source_cell/module_neighlist/neighbor_atom.h index 9e9525c9e97..5304c2c0bb7 100644 --- a/source/source_cell/module_neighlist/neighbor_atom.h +++ b/source/source_cell/module_neighlist/neighbor_atom.h @@ -31,6 +31,12 @@ class NeighborAtom /// Unique atom ID across all domains and periodic images int atom_id; + /// Global atom ID in the primary cell. Rank-local images share this ID. + long long global_id; + + /// MPI rank that owns the primary atom. + int owner_rank; + /// Whether this atom is inside the local MPI domain bool is_inside; @@ -46,7 +52,28 @@ class NeighborAtom */ NeighborAtom(double x, double y, double z, int type, int index, int id) : position_x(x), position_y(y), position_z(z), - atom_type(type), atom_index(index), atom_id(id), is_inside(false) {} + atom_type(type), atom_index(index), atom_id(id), + global_id(id), owner_rank(0), is_inside(false) {} + + NeighborAtom(double x, + double y, + double z, + int type, + int index, + int id, + long long global_id_in, + int owner_rank_in) + : position_x(x), + position_y(y), + position_z(z), + atom_type(type), + atom_index(index), + atom_id(id), + global_id(global_id_in), + owner_rank(owner_rank_in), + is_inside(false) + { + } }; /** @@ -91,4 +118,4 @@ class InputAtoms : x_low(0), x_high(0), y_low(0), y_high(0), z_low(0), z_high(0), n_atoms(0) {} }; -#endif // NEIGHBOR_ATOM_H \ No newline at end of file +#endif // NEIGHBOR_ATOM_H diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index 09db6fa58fb..5461fb47bbb 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -583,6 +583,62 @@ void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) this->init(ucell, sr, mpi_rank, 1); } +void NeighborSearch::init_distributed(const std::vector& owned_atoms, + const std::vector& ghost_atoms, + double sr, + double lat0) +{ + inside_atoms_.clear(); + ghost_atoms_.clear(); + all_atoms_.clear(); + bin_manager_.clear(); + + search_radius_ = sr / lat0; + glayerX_ = glayerY_ = glayerZ_ = 0; + glayerX_minus_ = glayerY_minus_ = glayerZ_minus_ = 0; + x_ = y_ = z_ = 0; + wide_x_ = wide_y_ = wide_z_ = 0.0; + + all_atoms_.reserve(owned_atoms.size() + ghost_atoms.size()); + inside_atoms_.reserve(owned_atoms.size()); + ghost_atoms_.reserve(ghost_atoms.size()); + + for (size_t iat = 0; iat < owned_atoms.size(); ++iat) + { + const LocalAtom& local = owned_atoms[iat]; + NeighborAtom atom(local.cart.x, + local.cart.y, + local.cart.z, + local.type, + local.type_index, + static_cast(all_atoms_.size()), + local.global_id, + local.owner_rank); + atom.is_inside = true; + all_atoms_.push_back(atom); + inside_atoms_.push_back(atom); + } + + for (size_t iat = 0; iat < ghost_atoms.size(); ++iat) + { + const LocalAtom& local = ghost_atoms[iat]; + NeighborAtom atom(local.cart.x, + local.cart.y, + local.cart.z, + local.type, + local.type_index, + static_cast(all_atoms_.size()), + local.global_id, + local.owner_rank); + atom.is_inside = false; + all_atoms_.push_back(atom); + ghost_atoms_.push_back(atom); + } + + neighbor_list_.initialize(inside_atoms_.size(), + std::max(1, static_cast(all_atoms_.size()) * neighbor_reserve_factor)); +} + void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size) { // clear possible residual data from previous runs diff --git a/source/source_cell/module_neighlist/neighbor_search.h b/source/source_cell/module_neighlist/neighbor_search.h index 4e6a1f06177..b752fa94e4a 100644 --- a/source/source_cell/module_neighlist/neighbor_search.h +++ b/source/source_cell/module_neighlist/neighbor_search.h @@ -5,6 +5,7 @@ #include "source_cell/module_neighlist/bin_manager.h" #include "source_cell/module_neighlist/neighbor_list.h" #include "source_cell/module_neighlist/atom_provider.h" +#include "source_cell/module_neighlist/local_atom.h" /** * @brief Neighbor search algorithm for building atom neighbor lists. @@ -58,6 +59,22 @@ class NeighborSearch */ void init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size); + /** + * @brief Initialize from rank-local owned atoms and exchanged ghost atoms. + * + * This distributed entry point does not inspect a global UnitCell. The + * caller is responsible for domain ownership and ghost exchange. + * + * @param owned_atoms Atoms owned by this rank and used as list centers. + * @param ghost_atoms Cutoff halo atoms received from neighboring ranks. + * @param sr Search radius (cutoff distance) in Bohr. + * @param lat0 Lattice constant in Bohr. + */ + void init_distributed(const std::vector& owned_atoms, + const std::vector& ghost_atoms, + double sr, + double lat0); + /** * @brief Build the neighbor list for all inside atoms. * diff --git a/source/source_cell/module_neighlist/test/CMakeLists.txt b/source/source_cell/module_neighlist/test/CMakeLists.txt index 3ee10db0bcc..94e328575fa 100644 --- a/source/source_cell/module_neighlist/test/CMakeLists.txt +++ b/source/source_cell/module_neighlist/test/CMakeLists.txt @@ -37,6 +37,7 @@ AddTest( if(ENABLE_MPI) add_executable(MODULE_CELL_NEIGHBOR_neighbor_search_mpi_benchmark neighbor_search_mpi_benchmark.cpp + ../domain_decomposition.cpp ../neighbor_search.cpp ../bin_manager.cpp ../page_allocator.cpp diff --git a/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp index 57c9e2f2ebb..917be5104fe 100644 --- a/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp +++ b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp @@ -1,4 +1,5 @@ #include "source_cell/module_neighlist/neighbor_search.h" +#include "source_cell/module_neighlist/domain_decomposition.h" #include "source_cell/module_neighlist/unitcell_lite.h" #include @@ -168,8 +169,14 @@ int main(int argc, char** argv) { MPI_Barrier(MPI_COMM_WORLD); const double t0 = MPI_Wtime(); + DomainDecomposition decomp; + std::vector owned_atoms; + std::vector ghost_atoms; NeighborSearch ns; - ns.init(ucell, cutoff, mpi_rank, mpi_size); + decomp.init(MPI_COMM_WORLD, ucell.get_latvec(), ucell.get_lat0(), cutoff, 0.0); + decomp.split_owned_atoms_from_ucell(ucell, owned_atoms); + decomp.exchange_ghost_atoms(owned_atoms, ghost_atoms); + ns.init_distributed(owned_atoms, ghost_atoms, cutoff, ucell.get_lat0()); const double t1 = MPI_Wtime(); ns.build_neighbors(); const double t2 = MPI_Wtime(); diff --git a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp index f2bbe790b8a..266c0b7f2ad 100644 --- a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp +++ b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp @@ -64,6 +64,39 @@ TEST(NeighborSearchTest, NoNeighbor) EXPECT_EQ(list.get_numneigh(1), 0); } +TEST(NeighborSearchTest, DistributedInputUsesOwnedCentersAndGhostNeighbors) +{ + std::vector owned_atoms; + std::vector ghost_atoms; + owned_atoms.push_back(LocalAtom(ModuleBase::Vector3(0.0, 0.0, 0.0), + ModuleBase::Vector3(0.0, 0.0, 0.0), + 0, + 0, + 0, + 0, + false)); + ghost_atoms.push_back(LocalAtom(ModuleBase::Vector3(0.5, 0.0, 0.0), + ModuleBase::Vector3(0.5, 0.0, 0.0), + 0, + 1, + 1, + 1, + true)); + + NeighborSearch ns; + ns.init_distributed(owned_atoms, ghost_atoms, 1.0, 1.0); + ns.build_neighbors(); + + const NeighborList& list = ns.get_neighbor_list(); + ASSERT_EQ(list.get_nlocal(), 1); + ASSERT_EQ(list.get_numneigh(0), 1); + const int neighbor_id = list.get_firstneigh(0)[0]; + ASSERT_GE(neighbor_id, 0); + ASSERT_LT(neighbor_id, static_cast(ns.get_all_atoms().size())); + EXPECT_EQ(ns.get_all_atoms()[neighbor_id].global_id, 1); + EXPECT_EQ(ns.get_all_atoms()[neighbor_id].owner_rank, 1); +} + TEST(NeighborSearchUnit, DistanceBox) { NeighborSearch ns; diff --git a/source/source_esolver/esolver_lj.cpp b/source/source_esolver/esolver_lj.cpp index 4997d42f2b2..6ee2c40b48a 100644 --- a/source/source_esolver/esolver_lj.cpp +++ b/source/source_esolver/esolver_lj.cpp @@ -8,6 +8,7 @@ #include "source_base/global_variable.h" #include "source_base/timer.h" #ifdef __MPI +#include "source_cell/module_neighlist/domain_decomposition.h" #include "source_base/parallel_reduce.h" #endif @@ -79,7 +80,13 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) { ModuleBase::timer::start("ESolverLJ", "mpi_total"); ModuleBase::timer::start("ESolverLJ", "neigh_init"); - neighbor_search.init(ucell_lite, search_radius, GlobalV::MY_RANK, GlobalV::NPROC); + DomainDecomposition decomp; + decomp.init(MPI_COMM_WORLD, ucell_lite.get_latvec(), ucell_lite.get_lat0(), search_radius, 0.0); + std::vector owned_atoms; + std::vector ghost_atoms; + decomp.split_owned_atoms_from_ucell(ucell_lite, owned_atoms); + decomp.exchange_ghost_atoms(owned_atoms, ghost_atoms); + neighbor_search.init_distributed(owned_atoms, ghost_atoms, search_radius, ucell_lite.get_lat0()); ModuleBase::timer::end("ESolverLJ", "neigh_init"); ModuleBase::timer::start("ESolverLJ", "neigh_bld"); neighbor_search.build_neighbors(); From 010a6acf23b87522f025463a2603d6d1944b3114 Mon Sep 17 00:00:00 2001 From: Fei Yang <2501213217@stu.pku.edu.cn> Date: Fri, 3 Jul 2026 00:41:18 +0800 Subject: [PATCH 5/5] add mpi in neighbor_search2 --- .../module_neighlist/bin_manager.cpp | 116 ++- .../module_neighlist/bin_manager.h | 54 +- .../module_neighlist/domain_decomposition.cpp | 4 +- .../module_neighlist/domain_decomposition.h | 2 +- .../source_cell/module_neighlist/local_atom.h | 5 +- .../module_neighlist/neighbor_atom.h | 69 +- .../module_neighlist/neighbor_list.h | 10 +- .../module_neighlist/neighbor_search.cpp | 840 +++--------------- .../module_neighlist/neighbor_search.h | 241 +---- .../module_neighlist/page_allocator.cpp | 11 + .../test/bin_manager_test.cpp | 98 +- .../test/neighbor_search_mpi_benchmark.cpp | 8 +- .../test/neighbor_search_test.cpp | 456 +++------- .../module_neighlist/unitcell_lite.cpp | 9 +- source/source_esolver/esolver_lj.cpp | 97 +- 15 files changed, 453 insertions(+), 1567 deletions(-) diff --git a/source/source_cell/module_neighlist/bin_manager.cpp b/source/source_cell/module_neighlist/bin_manager.cpp index 1077b91dd75..0cae41420fe 100644 --- a/source/source_cell/module_neighlist/bin_manager.cpp +++ b/source/source_cell/module_neighlist/bin_manager.cpp @@ -2,24 +2,13 @@ #include #include #include +#include #include "bin_manager.h" // ========== Bin class implementation ========== -int Bin::get_id_x() const { - return id_x_; -} - -int Bin::get_id_y() const { - return id_y_; -} - -int Bin::get_id_z() const { - return id_z_; -} - -const std::vector& Bin::get_atoms() const { - return atoms_; +const std::vector& Bin::get_atom_indices() const { + return atom_indices_; } void Bin::set_id(int ix, int iy, int iz) { @@ -29,11 +18,11 @@ void Bin::set_id(int ix, int iy, int iz) { } void Bin::clear_atoms() { - atoms_.clear(); + atom_indices_.clear(); } -void Bin::add_atom(const NeighborAtom& atom) { - atoms_.push_back(atom); +void Bin::add_atom_index(ModuleNeighList::LocalAtomIndex atom_index) { + atom_indices_.push_back(atom_index); } // ========== BinManager getter methods ========== @@ -51,26 +40,30 @@ int BinManager::get_nbinz() const { } int BinManager::get_total_bins() const { - return static_cast(bins_.size()); + return ModuleNeighList::checked_int_size(bins_.size(), "BinManager total bin count"); } int BinManager::get_bin_atom_count(int bin_index) const { - if (bin_index < 0 || bin_index >= static_cast(bins_.size())) { + if (bin_index < 0 || static_cast(bin_index) >= bins_.size()) { return 0; } - return static_cast(bins_[bin_index].get_atoms().size()); + return ModuleNeighList::checked_int_size(bins_[bin_index].get_atom_indices().size(), + "Bin atom count"); } // ========== BinManager main methods ========== void BinManager::init_bins( double sr, - const std::vector& inside_atoms, - const std::vector& ghost_atoms + const std::vector& all_atoms ) { sradius_ = sr; - if(inside_atoms.empty() && ghost_atoms.empty()) + if (!std::isfinite(sradius_) || sradius_ <= 0.0) + { + throw std::invalid_argument("BinManager search radius must be finite and positive."); + } + if(all_atoms.empty()) { x_min_ = y_min_ = z_min_ = 0; x_max_ = y_max_ = z_max_ = 0; @@ -98,20 +91,34 @@ void BinManager::init_bins( } }; - update_bounds(inside_atoms); - update_bounds(ghost_atoms); + update_bounds(all_atoms); bin_sizex_ = bin_sizey_ = bin_sizez_ = sradius_; - nbinx_ = std::ceil((x_max_ - x_min_) / bin_sizex_); - nbiny_ = std::ceil((y_max_ - y_min_) / bin_sizey_); - nbinz_ = std::ceil((z_max_ - z_min_) / bin_sizez_); + const auto checked_bin_dimension = [](const double span, const double bin_size, const char* context) { + const double count = std::ceil(span / bin_size); + if (!std::isfinite(count) || count > static_cast(std::numeric_limits::max())) + { + throw std::overflow_error(std::string(context) + " exceeds int range."); + } + return static_cast(count); + }; + + nbinx_ = checked_bin_dimension(x_max_ - x_min_, bin_sizex_, "BinManager X bin count"); + nbiny_ = checked_bin_dimension(y_max_ - y_min_, bin_sizey_, "BinManager Y bin count"); + nbinz_ = checked_bin_dimension(z_max_ - z_min_, bin_sizez_, "BinManager Z bin count"); nbinx_ = std::max(1, nbinx_); nbiny_ = std::max(1, nbiny_); nbinz_ = std::max(1, nbinz_); - int nbins = nbinx_ * nbiny_ * nbinz_; + const std::size_t nbins_xy = ModuleNeighList::checked_size_product(static_cast(nbinx_), + static_cast(nbiny_), + "BinManager bin count"); + const std::size_t nbins_size = ModuleNeighList::checked_size_product(nbins_xy, + static_cast(nbinz_), + "BinManager bin count"); + const int nbins = ModuleNeighList::checked_int_size(nbins_size, "BinManager bin count"); bins_.clear(); bins_.resize(nbins); @@ -132,12 +139,17 @@ void BinManager::init_bins( } void BinManager::do_binning( - const std::vector& inside_atoms, - const std::vector& ghost_atoms + const std::vector& atoms ) { - auto bin_atom = [&](const NeighborAtom& atom) + if (atoms.size() > static_cast(std::numeric_limits::max())) { + throw std::overflow_error("BinManager binned atom count exceeds local atom index range."); + } + + for (std::size_t iatom = 0; iatom < atoms.size(); ++iatom) + { + const NeighborAtom& atom = atoms[iatom]; int ix = std::min( std::max(int((atom.position_x - x_min_) / bin_sizex_), 0), nbinx_ - 1 @@ -155,11 +167,9 @@ void BinManager::do_binning( int idx = bin_index(ix, iy, iz); - bins_[idx].add_atom(atom); - }; - - for (const auto& atom : inside_atoms) bin_atom(atom); - for (const auto& atom : ghost_atoms) bin_atom(atom); + const ModuleNeighList::LocalAtomIndex atom_index = static_cast(iatom); + bins_[idx].add_atom_index(atom_index); + } } int BinManager::bin_index(int ix, int iy, int iz) const { @@ -168,7 +178,8 @@ int BinManager::bin_index(int ix, int iy, int iz) const { void BinManager::build_atom_neighbors( NeighborList& neighbor_list, - std::vector& atoms + const std::vector& atoms, + const std::vector& binned_atoms ) { assert(atoms.size() == static_cast(neighbor_list.get_nlocal())); @@ -179,22 +190,24 @@ void BinManager::build_atom_neighbors( std::vector neigh_tmp; - for (int i = 0; i < atoms.size(); i++) + const int nlocal = neighbor_list.get_nlocal(); + for (int i = 0; i < nlocal; i++) { neigh_tmp.clear(); + const NeighborAtom& atom = atoms[i]; int ix = std::min( - std::max(int((atoms[i].position_x - x_min_) / bin_sizex_), 0), + std::max(int((atom.position_x - x_min_) / bin_sizex_), 0), nbinx_ - 1 ); int iy = std::min( - std::max(int((atoms[i].position_y - y_min_) / bin_sizey_), 0), + std::max(int((atom.position_y - y_min_) / bin_sizey_), 0), nbiny_ - 1 ); int iz = std::min( - std::max(int((atoms[i].position_z - z_min_) / bin_sizez_), 0), + std::max(int((atom.position_z - z_min_) / bin_sizez_), 0), nbinz_ - 1 ); @@ -215,15 +228,20 @@ void BinManager::build_atom_neighbors( int nidx = bin_index(jx, jy, jz); - for (const NeighborAtom& natom : bins_[nidx].get_atoms()) + for (const ModuleNeighList::LocalAtomIndex binned_atom_index : bins_[nidx].get_atom_indices()) { - double dx = atoms[i].position_x - natom.position_x; - double dy = atoms[i].position_y - natom.position_y; - double dz = atoms[i].position_z - natom.position_z; + const NeighborAtom& natom = binned_atoms[static_cast(binned_atom_index)]; + double dx = atom.position_x - natom.position_x; + double dy = atom.position_y - natom.position_y; + double dz = atom.position_z - natom.position_z; double dist2 = dx * dx + dy * dy + dz * dz; - if (dist2 <= sradius2 && dist2 != 0) + if (natom.atom_id == atom.atom_id) + { + continue; + } + if (dist2 <= sradius2) { neigh_tmp.push_back(natom.atom_id); } @@ -232,7 +250,7 @@ void BinManager::build_atom_neighbors( } } - int n = neigh_tmp.size(); + const int n = ModuleNeighList::checked_int_size(neigh_tmp.size(), "BinManager neighbor count"); int* ptr = neighbor_list.allocator_.allocate(n); @@ -255,4 +273,4 @@ void BinManager::clear() } bins_.clear(); -} \ No newline at end of file +} diff --git a/source/source_cell/module_neighlist/bin_manager.h b/source/source_cell/module_neighlist/bin_manager.h index 22b94d394a0..ffb470e8724 100644 --- a/source/source_cell/module_neighlist/bin_manager.h +++ b/source/source_cell/module_neighlist/bin_manager.h @@ -4,11 +4,12 @@ #include #include "source_cell/module_neighlist/neighbor_atom.h" #include "source_cell/module_neighlist/neighbor_list.h" +#include "source_cell/module_neighlist/neighbor_types.h" /** * @brief A single bin in the 3D binning grid for neighbor search. * - * Each bin stores atoms that fall within its spatial region, + * Each bin stores indices of atoms that fall within its spatial region, * along with its position indices in the 3D grid. */ class Bin @@ -27,28 +28,10 @@ class Bin // ========== Getter methods ========== /** - * @brief Get the X index of this bin in the grid. - * @return X index. + * @brief Get the atom indices stored in this bin. + * @return Const reference to the atom-index vector. */ - int get_id_x() const; - - /** - * @brief Get the Y index of this bin in the grid. - * @return Y index. - */ - int get_id_y() const; - - /** - * @brief Get the Z index of this bin in the grid. - * @return Z index. - */ - int get_id_z() const; - - /** - * @brief Get the atoms stored in this bin. - * @return Const reference to the atom vector. - */ - const std::vector& get_atoms() const; + const std::vector& get_atom_indices() const; // ========== Setter methods (internal use) ========== @@ -66,10 +49,10 @@ class Bin void clear_atoms(); /** - * @brief Add an atom to this bin. - * @param atom The atom to add. + * @brief Add an atom index to this bin. + * @param atom_index Index of the atom in the vector passed to BinManager::do_binning(). */ - void add_atom(const NeighborAtom& atom); + void add_atom_index(ModuleNeighList::LocalAtomIndex atom_index); private: /// X index in the 3D bin grid @@ -81,8 +64,8 @@ class Bin /// Z index in the 3D bin grid int id_z_ = 0; - /// Atoms contained in this bin - std::vector atoms_; + /// Indices into the atom vector passed to BinManager::do_binning(). + std::vector atom_indices_; }; /** @@ -113,8 +96,7 @@ class BinManager */ void init_bins( double sr, - const std::vector& inside_atoms, - const std::vector& ghost_atoms + const std::vector& all_atoms ); /** @@ -123,13 +105,9 @@ class BinManager * Must be called after init_bins(). Each atom is placed into the * bin that contains its spatial position. * - * @param inside_atoms Atoms inside the local MPI domain. - * @param ghost_atoms Ghost atoms from neighboring domains. + * @param atoms All atoms to assign to bins. */ - void do_binning( - const std::vector& inside_atoms, - const std::vector& ghost_atoms - ); + void do_binning(const std::vector& atoms); /** * @brief Build neighbor list by searching adjacent bins. @@ -139,10 +117,12 @@ class BinManager * * @param neighbor_list Output neighbor list to populate. * @param atoms Atoms for which to build neighbors. + * @param binned_atoms All atoms assigned to bins by do_binning(). */ void build_atom_neighbors( NeighborList& neighbor_list, - std::vector& atoms + const std::vector& atoms, + const std::vector& binned_atoms ); /** @@ -220,4 +200,4 @@ class BinManager int bin_index(int ix, int iy, int iz) const; }; -#endif // BIN_MANAGER_H \ No newline at end of file +#endif // BIN_MANAGER_H diff --git a/source/source_cell/module_neighlist/domain_decomposition.cpp b/source/source_cell/module_neighlist/domain_decomposition.cpp index 974196829fd..5cb24b02eba 100644 --- a/source/source_cell/module_neighlist/domain_decomposition.cpp +++ b/source/source_cell/module_neighlist/domain_decomposition.cpp @@ -1,6 +1,7 @@ #include "source_cell/module_neighlist/domain_decomposition.h" #include +#include #include #include @@ -48,6 +49,7 @@ double DomainDecomposition::wrap_fractional(double value) int DomainDecomposition::floor_div(int value, int divisor) { + assert(divisor!=0); int quotient = value / divisor; const int remainder = value % divisor; if (remainder != 0 && ((remainder < 0) != (divisor < 0))) @@ -204,7 +206,7 @@ void DomainDecomposition::split_owned_atoms_from_ucell(const AtomProvider& ucell owned_atoms.clear(); owned_atoms.reserve(static_cast(ucell.get_natom() / std::max(1, size_) + 1)); - long long global_id = 0; + ModuleNeighList::GlobalAtomId global_id = 0; for (int it = 0; it < ucell.get_ntype(); ++it) { for (int ia = 0; ia < ucell.get_na(it); ++ia) diff --git a/source/source_cell/module_neighlist/domain_decomposition.h b/source/source_cell/module_neighlist/domain_decomposition.h index 3585c9b1c38..19fb1be371c 100644 --- a/source/source_cell/module_neighlist/domain_decomposition.h +++ b/source/source_cell/module_neighlist/domain_decomposition.h @@ -48,7 +48,7 @@ class DomainDecomposition int image_shift[3]; int type; int type_index; - long long global_id; + ModuleNeighList::GlobalAtomId global_id; int owner_rank; }; diff --git a/source/source_cell/module_neighlist/local_atom.h b/source/source_cell/module_neighlist/local_atom.h index be9804a624e..f48a8da8f75 100644 --- a/source/source_cell/module_neighlist/local_atom.h +++ b/source/source_cell/module_neighlist/local_atom.h @@ -1,6 +1,7 @@ #ifndef LOCAL_ATOM_H #define LOCAL_ATOM_H +#include "source_cell/module_neighlist/neighbor_types.h" #include "source_base/vector3.h" /** @@ -17,7 +18,7 @@ struct LocalAtom ModuleBase::Vector3 frac; int type; int type_index; - long long global_id; + ModuleNeighList::GlobalAtomId global_id; int owner_rank; bool is_ghost; @@ -36,7 +37,7 @@ struct LocalAtom const ModuleBase::Vector3& frac_in, int type_in, int type_index_in, - long long global_id_in, + ModuleNeighList::GlobalAtomId global_id_in, int owner_rank_in, bool is_ghost_in) : cart(cart_in), diff --git a/source/source_cell/module_neighlist/neighbor_atom.h b/source/source_cell/module_neighlist/neighbor_atom.h index 5304c2c0bb7..3f62d30571a 100644 --- a/source/source_cell/module_neighlist/neighbor_atom.h +++ b/source/source_cell/module_neighlist/neighbor_atom.h @@ -1,6 +1,8 @@ #ifndef NEIGHBOR_ATOM_H #define NEIGHBOR_ATOM_H +#include "source_cell/module_neighlist/neighbor_types.h" + #include /** @@ -28,18 +30,15 @@ class NeighborAtom /// Index of the atom within its type int atom_index; - /// Unique atom ID across all domains and periodic images - int atom_id; + /// Rank-local atom ID used by the neighbor list. + ModuleNeighList::LocalAtomIndex atom_id; /// Global atom ID in the primary cell. Rank-local images share this ID. - long long global_id; + ModuleNeighList::GlobalAtomId global_id; /// MPI rank that owns the primary atom. int owner_rank; - /// Whether this atom is inside the local MPI domain - bool is_inside; - /** * @brief Construct a NeighborAtom. * @@ -50,18 +49,23 @@ class NeighborAtom * @param index Index within the atom type. * @param id Unique atom ID. */ - NeighborAtom(double x, double y, double z, int type, int index, int id) + NeighborAtom(double x, + double y, + double z, + int type, + int index, + ModuleNeighList::LocalAtomIndex id) : position_x(x), position_y(y), position_z(z), atom_type(type), atom_index(index), atom_id(id), - global_id(id), owner_rank(0), is_inside(false) {} + global_id(id), owner_rank(0) {} NeighborAtom(double x, double y, double z, int type, int index, - int id, - long long global_id_in, + ModuleNeighList::LocalAtomIndex id, + ModuleNeighList::GlobalAtomId global_id_in, int owner_rank_in) : position_x(x), position_y(y), @@ -70,52 +74,9 @@ class NeighborAtom atom_index(index), atom_id(id), global_id(global_id_in), - owner_rank(owner_rank_in), - is_inside(false) + owner_rank(owner_rank_in) { } }; -/** - * @brief Input structure for neighbor search initialization. - * - * Contains atom data and spatial bounds computed from input atoms, - * used to initialize the binning grid. - */ -class InputAtoms -{ -public: - /// List of input atoms - std::vector InputAtom; - - /// Minimum X coordinate of the atom bounding box - double x_low; - - /// Maximum X coordinate of the atom bounding box - double x_high; - - /// Minimum Y coordinate of the atom bounding box - double y_low; - - /// Maximum Y coordinate of the atom bounding box - double y_high; - - /// Minimum Z coordinate of the atom bounding box - double z_low; - - /// Maximum Z coordinate of the atom bounding box - double z_high; - - /// Total number of atoms - int n_atoms; - - /** - * @brief Default constructor. - * - * Initializes bounds to zero and atom count to zero. - */ - InputAtoms() - : x_low(0), x_high(0), y_low(0), y_high(0), z_low(0), z_high(0), n_atoms(0) {} -}; - #endif // NEIGHBOR_ATOM_H diff --git a/source/source_cell/module_neighlist/neighbor_list.h b/source/source_cell/module_neighlist/neighbor_list.h index c14c80535f0..fe93f7da004 100644 --- a/source/source_cell/module_neighlist/neighbor_list.h +++ b/source/source_cell/module_neighlist/neighbor_list.h @@ -1,6 +1,8 @@ #ifndef NEIGHBOR_LIST_H #define NEIGHBOR_LIST_H +#include "source_cell/module_neighlist/neighbor_types.h" + #include #include "page_allocator.h" @@ -10,10 +12,10 @@ class NeighborList NeighborList() = default; ~NeighborList() = default; - void initialize(int nlocal, int pgsize) + void initialize(std::size_t nlocal, std::size_t pgsize) { - nlocal_ = nlocal; - allocator_ = PageAllocator(pgsize); + nlocal_ = ModuleNeighList::checked_int_size(nlocal, "NeighborList local atom count"); + allocator_ = PageAllocator(ModuleNeighList::checked_int_size(pgsize, "NeighborList page size")); numneigh_.assign(nlocal, 0); firstneigh_.assign(nlocal, nullptr); } @@ -39,4 +41,4 @@ class NeighborList friend class BinManager; }; -#endif // NEIGHBOR_LIST_H \ No newline at end of file +#endif // NEIGHBOR_LIST_H diff --git a/source/source_cell/module_neighlist/neighbor_search.cpp b/source/source_cell/module_neighlist/neighbor_search.cpp index 5461fb47bbb..74e21cfac69 100644 --- a/source/source_cell/module_neighlist/neighbor_search.cpp +++ b/source/source_cell/module_neighlist/neighbor_search.cpp @@ -4,186 +4,16 @@ #include #include #include +#include #include #include -namespace -{ -struct OriginalAtom -{ - OriginalAtom(const std::array& frac_in, int atom_type_in, int atom_index_in) - : frac(frac_in), atom_type(atom_type_in), atom_index(atom_index_in) - { - } - - std::array frac; - int atom_type = 0; - int atom_index = 0; -}; - -struct PeriodicInterval -{ - PeriodicInterval(double lo_in, double hi_in, int shift_in) : lo(lo_in), hi(hi_in), shift(shift_in) - { - } - - double lo = 0.0; - double hi = 0.0; - int shift = 0; -}; - -struct FractionalDomain -{ - FractionalDomain(const std::array& lo_in, const std::array& hi_in) - : lo(lo_in), hi(hi_in) - { - } - - std::array lo; - std::array hi; -}; - -double dot_product(const ModuleBase::Vector3& a, const ModuleBase::Vector3& b) -{ - return a.x * b.x + a.y * b.y + a.z * b.z; -} - -ModuleBase::Vector3 cross_product(const ModuleBase::Vector3& a, - const ModuleBase::Vector3& b) -{ - return ModuleBase::Vector3(a.y * b.z - a.z * b.y, - a.z * b.x - a.x * b.z, - a.x * b.y - a.y * b.x); -} - -double norm(const ModuleBase::Vector3& v) -{ - return std::sqrt(dot_product(v, v)); -} - -double wrap_fractional(double value) -{ - value -= std::floor(value); - if (value >= 1.0 - 1.0e-12) - { - return 0.0; - } - if (value < 1.0e-12) - { - return 0.0; - } - return value; -} - -int clamp_index(int value, int low, int high) -{ - return std::min(std::max(value, low), high); -} - -int fractional_domain_index(double frac, int n) -{ - return clamp_index(static_cast(std::floor(frac * n)), 0, n - 1); -} - -long long bin_key(int ix, int iy, int iz, const std::array& nbin) -{ - return (static_cast(ix) * nbin[1] + iy) * nbin[2] + iz; -} - -std::vector split_periodic_interval(double lo, double hi) -{ - std::vector intervals; - if (hi <= lo) - { - return intervals; - } - - const int first_shift = static_cast(std::floor(lo)); - const int last_shift = static_cast(std::ceil(hi)) - 1; - for (int shift = first_shift; shift <= last_shift; ++shift) - { - const double local_lo = std::max(0.0, lo - shift); - const double local_hi = std::min(1.0, hi - shift); - if (local_lo < local_hi) - { - intervals.push_back(PeriodicInterval(local_lo, local_hi, shift)); - } - } - return intervals; -} - -bool inside_interval(double value, double lo, double hi) -{ - return value >= lo && value < hi; -} - -bool inside_block(const OriginalAtom& atom, - const PeriodicInterval& bx, - const PeriodicInterval& by, - const PeriodicInterval& bz) -{ - return inside_interval(atom.frac[0], bx.lo, bx.hi) && - inside_interval(atom.frac[1], by.lo, by.hi) && - inside_interval(atom.frac[2], bz.lo, bz.hi); -} -} // namespace - -constexpr double NeighborSearch::coord_tolerance; - // ========== Getter methods ========== double NeighborSearch::get_search_radius() const { return search_radius_; } -int NeighborSearch::get_x() const { - return x_; -} - -int NeighborSearch::get_y() const { - return y_; -} - -int NeighborSearch::get_z() const { - return z_; -} - -double NeighborSearch::get_wide_x() const { - return wide_x_; -} - -double NeighborSearch::get_wide_y() const { - return wide_y_; -} - -double NeighborSearch::get_wide_z() const { - return wide_z_; -} - -int NeighborSearch::get_glayerX() const { - return glayerX_; -} - -int NeighborSearch::get_glayerY() const { - return glayerY_; -} - -int NeighborSearch::get_glayerZ() const { - return glayerZ_; -} - -int NeighborSearch::get_glayerX_minus() const { - return glayerX_minus_; -} - -int NeighborSearch::get_glayerY_minus() const { - return glayerY_minus_; -} - -int NeighborSearch::get_glayerZ_minus() const { - return glayerZ_minus_; -} - const std::vector& NeighborSearch::get_all_atoms() const { return all_atoms_; } @@ -204,385 +34,8 @@ const NeighborList& NeighborSearch::get_neighbor_list() const { return neighbor_list_; } -// ========== Setter methods ========== - -void NeighborSearch::set_search_radius(double sr) { - search_radius_ = sr; -} - -void NeighborSearch::set_position(int x, int y, int z) { - x_ = x; - y_ = y; - z_ = z; -} - -void NeighborSearch::set_width(double wx, double wy, double wz) { - wide_x_ = wx; - wide_y_ = wy; - wide_z_ = wz; -} - -// ========== Internal methods ========== - -double NeighborSearch::cross_product_norm(double a1, double a2, double a3, - double b1, double b2, double b3) -{ - double c1 = a2 * b3 - a3 * b2; - double c2 = a3 * b1 - a1 * b3; - double c3 = a1 * b2 - a2 * b1; - return sqrt(c1 * c1 + c2 * c2 + c3 * c3); -} - -InputAtoms NeighborSearch::ucell_to_input_atoms(const AtomProvider& ucell) -{ - InputAtoms input_atoms; - int atom_count = 0; - assert(ucell.get_natom() > 0); - - input_atoms.x_low = input_atoms.y_low = input_atoms.z_low = std::numeric_limits::max(); - input_atoms.x_high = input_atoms.y_high = input_atoms.z_high = std::numeric_limits::lowest(); - - for (int i = 0; i < ucell.get_ntype(); i++) - { - for (int j = 0; j < ucell.get_na(i); j++) - { - NeighborAtom atom( - ucell.get_tau(i,j).x, - ucell.get_tau(i,j).y, - ucell.get_tau(i,j).z, - i, - j, - atom_count - ); - input_atoms.InputAtom.push_back(atom); - - input_atoms.x_low = std::min(input_atoms.x_low, atom.position_x); - input_atoms.x_high = std::max(input_atoms.x_high, atom.position_x); - input_atoms.y_low = std::min(input_atoms.y_low, atom.position_y); - input_atoms.y_high = std::max(input_atoms.y_high, atom.position_y); - input_atoms.z_low = std::min(input_atoms.z_low, atom.position_z); - input_atoms.z_high = std::max(input_atoms.z_high, atom.position_z); - - atom_count++; - } - } - - input_atoms.n_atoms = atom_count; - return input_atoms; -} - -void NeighborSearch::check_expand_condition(const AtomProvider& ucell) -{ - const auto& lat = ucell.get_latvec(); - const double omega = ucell.get_omega(); - const double lat0 = ucell.get_lat0(); - const double lat0_cubed = lat0 * lat0 * lat0; - - double a23_norm = cross_product_norm(lat.e21, lat.e22, lat.e23, lat.e31, lat.e32, lat.e33); - int extend_d11 = std::ceil(a23_norm * search_radius_ / omega * lat0_cubed); - - double a31_norm = cross_product_norm(lat.e31, lat.e32, lat.e33, lat.e11, lat.e12, lat.e13); - int extend_d22 = std::ceil(a31_norm * search_radius_ / omega * lat0_cubed); - - double a12_norm = cross_product_norm(lat.e11, lat.e12, lat.e13, lat.e21, lat.e22, lat.e23); - int extend_d33 = std::ceil(a12_norm * search_radius_ / omega * lat0_cubed); - - glayerX_ = extend_d11 + positive_layer_offset; - glayerY_ = extend_d22 + positive_layer_offset; - glayerZ_ = extend_d33 + positive_layer_offset; - glayerX_minus_ = extend_d11; - glayerY_minus_ = extend_d22; - glayerZ_minus_ = extend_d33; -} - -void NeighborSearch::set_member_variables(const AtomProvider& ucell) -{ - all_atoms_.clear(); - - ModuleBase::Vector3 vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13); - ModuleBase::Vector3 vec2(ucell.get_latvec().e21, ucell.get_latvec().e22, ucell.get_latvec().e23); - ModuleBase::Vector3 vec3(ucell.get_latvec().e31, ucell.get_latvec().e32, ucell.get_latvec().e33); - - int atom_count = 0; - - for (int ix = -glayerX_minus_; ix < glayerX_; ix++) - { - for (int iy = -glayerY_minus_; iy < glayerY_; iy++) - { - for (int iz = -glayerZ_minus_; iz < glayerZ_; iz++) - { - for (int i = 0; i < ucell.get_ntype(); i++) - { - for (int j = 0; j < ucell.get_na(i); j++) - { - double atom_x = ucell.get_tau(i,j).x + vec1[0] * ix + vec2[0] * iy + vec3[0] * iz; - double atom_y = ucell.get_tau(i,j).y + vec1[1] * ix + vec2[1] * iy + vec3[1] * iz; - double atom_z = ucell.get_tau(i,j).z + vec1[2] * ix + vec2[2] * iy + vec3[2] * iz; - - NeighborAtom atom(atom_x, atom_y, atom_z, i, j, atom_count); - if(ix==0 && iy==0 && iz==0) - { - atom.is_inside = true; - } - else - { - atom.is_inside = false; - } - all_atoms_.push_back(atom); - atom_count++; - } - } - } - } - } -} - -void NeighborSearch::set_local_member_variables(const AtomProvider& ucell, - const InputAtoms& atoms, - int nx, - int ny, - int nz) -{ - all_atoms_.clear(); - inside_atoms_.clear(); - ghost_atoms_.clear(); - - ModuleBase::Vector3 vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13); - ModuleBase::Vector3 vec2(ucell.get_latvec().e21, ucell.get_latvec().e22, ucell.get_latvec().e23); - ModuleBase::Vector3 vec3(ucell.get_latvec().e31, ucell.get_latvec().e32, ucell.get_latvec().e33); - - const auto domain_index = [](double position, double low, double wide, int n) { - if (wide < coord_tolerance) - { - return std::abs(position - low) < coord_tolerance ? 0 : std::numeric_limits::max(); - } - return std::min(std::max(static_cast(std::floor((position - low) / wide)), 0), n - 1); - }; - - const double search_radius2 = search_radius_ * search_radius_; - for (int ix = -glayerX_minus_; ix < glayerX_; ix++) - { - for (int iy = -glayerY_minus_; iy < glayerY_; iy++) - { - for (int iz = -glayerZ_minus_; iz < glayerZ_; iz++) - { - const bool central_image = (ix == 0 && iy == 0 && iz == 0); - for (int it = 0; it < ucell.get_ntype(); it++) - { - for (int ia = 0; ia < ucell.get_na(it); ia++) - { - double atom_x = ucell.get_tau(it, ia).x + vec1[0] * ix + vec2[0] * iy + vec3[0] * iz; - double atom_y = ucell.get_tau(it, ia).y + vec1[1] * ix + vec2[1] * iy + vec3[1] * iz; - double atom_z = ucell.get_tau(it, ia).z + vec1[2] * ix + vec2[2] * iy + vec3[2] * iz; - - const int in_x = domain_index(atom_x, atoms.x_low, wide_x_, nx); - const int in_y = domain_index(atom_y, atoms.y_low, wide_y_, ny); - const int in_z = domain_index(atom_z, atoms.z_low, wide_z_, nz); - - const bool owned = central_image && - in_x == x_ && - in_y == y_ && - in_z == z_ && - atom_x <= atoms.x_high && - atom_y <= atoms.y_high && - atom_z <= atoms.z_high; - const bool ghost = !owned && - distance(atom_x, atom_y, atom_z, atoms.x_low, atoms.y_low, atoms.z_low) - <= search_radius2; - - if (!owned && !ghost) - { - continue; - } - - NeighborAtom atom(atom_x, atom_y, atom_z, it, ia, static_cast(all_atoms_.size())); - atom.is_inside = owned; - all_atoms_.push_back(atom); - if (owned) - { - inside_atoms_.push_back(atom); - } - else - { - ghost_atoms_.push_back(atom); - } - } - } - } - } - } -} - -void NeighborSearch::set_local_member_variables_by_halo(const AtomProvider& ucell, int nx, int ny, int nz) -{ - all_atoms_.clear(); - inside_atoms_.clear(); - ghost_atoms_.clear(); - - const ModuleBase::Matrix3& lat = ucell.get_latvec(); - const ModuleBase::Matrix3 inv_lat = lat.Inverse(); - - const ModuleBase::Vector3 a1(lat.e11, lat.e12, lat.e13); - const ModuleBase::Vector3 a2(lat.e21, lat.e22, lat.e23); - const ModuleBase::Vector3 a3(lat.e31, lat.e32, lat.e33); - - const ModuleBase::Vector3 a2xa3 = cross_product(a2, a3); - const ModuleBase::Vector3 a3xa1 = cross_product(a3, a1); - const ModuleBase::Vector3 a1xa2 = cross_product(a1, a2); - - const double volume = std::abs(dot_product(a1, a2xa3)); - assert(volume > coord_tolerance); - - const std::array heights = {{ - volume / norm(a2xa3), - volume / norm(a3xa1), - volume / norm(a1xa2) - }}; - - std::array margin = {{ - search_radius_ / heights[0] + coord_tolerance, - search_radius_ / heights[1] + coord_tolerance, - search_radius_ / heights[2] + coord_tolerance - }}; - - const std::array domain_lo = {{ - static_cast(x_) / nx, - static_cast(y_) / ny, - static_cast(z_) / nz - }}; - const std::array domain_hi = {{ - static_cast(x_ + 1) / nx, - static_cast(y_ + 1) / ny, - static_cast(z_ + 1) / nz - }}; - const FractionalDomain domain(domain_lo, domain_hi); - - const std::array halo_lo = {{ - domain.lo[0] - margin[0], - domain.lo[1] - margin[1], - domain.lo[2] - margin[2] - }}; - const std::array halo_hi = {{ - domain.hi[0] + margin[0], - domain.hi[1] + margin[1], - domain.hi[2] + margin[2] - }}; - const FractionalDomain halo(halo_lo, halo_hi); - - std::vector original_atoms; - original_atoms.reserve(ucell.get_natom()); - for (int it = 0; it < ucell.get_ntype(); ++it) - { - for (int ia = 0; ia < ucell.get_na(it); ++ia) - { - const ModuleBase::Vector3 cart = ucell.get_tau(it, ia); - const ModuleBase::Vector3 frac = cart * inv_lat; - const std::array wrapped_frac = {{ - wrap_fractional(frac.x), - wrap_fractional(frac.y), - wrap_fractional(frac.z) - }}; - original_atoms.push_back(OriginalAtom(wrapped_frac, it, ia)); - } - } - - std::array nbin; - for (int idim = 0; idim < 3; ++idim) - { - nbin[idim] = std::max(1, static_cast(std::ceil(1.0 / std::max(margin[idim], coord_tolerance)))); - } - - std::unordered_map> bins; - bins.reserve(original_atoms.size()); - for (int iat = 0; iat < static_cast(original_atoms.size()); ++iat) - { - const OriginalAtom& atom = original_atoms[iat]; - const int ix = clamp_index(static_cast(std::floor(atom.frac[0] * nbin[0])), 0, nbin[0] - 1); - const int iy = clamp_index(static_cast(std::floor(atom.frac[1] * nbin[1])), 0, nbin[1] - 1); - const int iz = clamp_index(static_cast(std::floor(atom.frac[2] * nbin[2])), 0, nbin[2] - 1); - bins[bin_key(ix, iy, iz, nbin)].push_back(iat); - } - - const std::vector intervals_x = split_periodic_interval(halo.lo[0], halo.hi[0]); - const std::vector intervals_y = split_periodic_interval(halo.lo[1], halo.hi[1]); - const std::vector intervals_z = split_periodic_interval(halo.lo[2], halo.hi[2]); - - for (const PeriodicInterval& bx : intervals_x) - { - const int ix_begin = clamp_index(static_cast(std::floor(bx.lo * nbin[0])), 0, nbin[0] - 1); - const int ix_end = clamp_index(static_cast(std::ceil(bx.hi * nbin[0])) - 1, 0, nbin[0] - 1); - for (const PeriodicInterval& by : intervals_y) - { - const int iy_begin = clamp_index(static_cast(std::floor(by.lo * nbin[1])), 0, nbin[1] - 1); - const int iy_end = clamp_index(static_cast(std::ceil(by.hi * nbin[1])) - 1, 0, nbin[1] - 1); - for (const PeriodicInterval& bz : intervals_z) - { - const int iz_begin = clamp_index(static_cast(std::floor(bz.lo * nbin[2])), 0, nbin[2] - 1); - const int iz_end = clamp_index(static_cast(std::ceil(bz.hi * nbin[2])) - 1, 0, nbin[2] - 1); - - for (int ix = ix_begin; ix <= ix_end; ++ix) - { - for (int iy = iy_begin; iy <= iy_end; ++iy) - { - for (int iz = iz_begin; iz <= iz_end; ++iz) - { - const auto bin_iter = bins.find(bin_key(ix, iy, iz, nbin)); - if (bin_iter == bins.end()) - { - continue; - } - - for (const int atom_id : bin_iter->second) - { - const OriginalAtom& original = original_atoms[atom_id]; - if (!inside_block(original, bx, by, bz)) - { - continue; - } - - const bool central_image = bx.shift == 0 && by.shift == 0 && bz.shift == 0; - const bool owned = central_image && - fractional_domain_index(original.frac[0], nx) == x_ && - fractional_domain_index(original.frac[1], ny) == y_ && - fractional_domain_index(original.frac[2], nz) == z_; - - const ModuleBase::Vector3 frac_image(original.frac[0] + bx.shift, - original.frac[1] + by.shift, - original.frac[2] + bz.shift); - const ModuleBase::Vector3 cart_image = frac_image * lat; - - NeighborAtom atom(cart_image.x, - cart_image.y, - cart_image.z, - original.atom_type, - original.atom_index, - static_cast(all_atoms_.size())); - atom.is_inside = owned; - all_atoms_.push_back(atom); - if (owned) - { - inside_atoms_.push_back(atom); - } - else - { - ghost_atoms_.push_back(atom); - } - } - } - } - } - } - } - } -} - // ========== Main public interface ========== -void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank) -{ - this->init(ucell, sr, mpi_rank, 1); -} - void NeighborSearch::init_distributed(const std::vector& owned_atoms, const std::vector& ghost_atoms, double sr, @@ -594,12 +47,16 @@ void NeighborSearch::init_distributed(const std::vector& owned_atoms, bin_manager_.clear(); search_radius_ = sr / lat0; - glayerX_ = glayerY_ = glayerZ_ = 0; - glayerX_minus_ = glayerY_minus_ = glayerZ_minus_ = 0; - x_ = y_ = z_ = 0; - wide_x_ = wide_y_ = wide_z_ = 0.0; - all_atoms_.reserve(owned_atoms.size() + ghost_atoms.size()); + const std::size_t total_atoms = ModuleNeighList::checked_size_sum(owned_atoms.size(), + ghost_atoms.size(), + "NeighborSearch distributed atom count"); + if (total_atoms > static_cast(std::numeric_limits::max())) + { + throw std::overflow_error("NeighborSearch distributed atom count exceeds local atom index range."); + } + + all_atoms_.reserve(total_atoms); inside_atoms_.reserve(owned_atoms.size()); ghost_atoms_.reserve(ghost_atoms.size()); @@ -611,10 +68,10 @@ void NeighborSearch::init_distributed(const std::vector& owned_atoms, local.cart.z, local.type, local.type_index, - static_cast(all_atoms_.size()), + ModuleNeighList::checked_local_atom_index(all_atoms_.size(), + "NeighborSearch owned atom id"), local.global_id, local.owner_rank); - atom.is_inside = true; all_atoms_.push_back(atom); inside_atoms_.push_back(atom); } @@ -627,223 +84,142 @@ void NeighborSearch::init_distributed(const std::vector& owned_atoms, local.cart.z, local.type, local.type_index, - static_cast(all_atoms_.size()), + ModuleNeighList::checked_local_atom_index(all_atoms_.size(), + "NeighborSearch ghost atom id"), local.global_id, local.owner_rank); - atom.is_inside = false; all_atoms_.push_back(atom); ghost_atoms_.push_back(atom); } - neighbor_list_.initialize(inside_atoms_.size(), - std::max(1, static_cast(all_atoms_.size()) * neighbor_reserve_factor)); + const std::size_t page_size = ModuleNeighList::checked_size_product(all_atoms_.size(), + neighbor_reserve_factor, + "NeighborSearch page size"); + neighbor_list_.initialize(inside_atoms_.size(), page_size); } -void NeighborSearch::init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size) +void NeighborSearch::init(const AtomProvider& ucell, double sr) { + search_radius_ = sr / ucell.get_lat0(); + // clear possible residual data from previous runs inside_atoms_.clear(); ghost_atoms_.clear(); all_atoms_.clear(); - // clear any existing bin manager state bin_manager_.clear(); - search_radius_ = sr / ucell.get_lat0(); - check_expand_condition(ucell); - InputAtoms atoms = ucell_to_input_atoms(ucell); - - assert(mpi_size > 0); - assert(mpi_rank >= 0); - - const double span_x = atoms.x_high - atoms.x_low; - const double span_y = atoms.y_high - atoms.y_low; - const double span_z = atoms.z_high - atoms.z_low; - - int nx, ny, nz; - decompose(mpi_size, span_x, span_y, span_z, nx, ny, nz); - - const int active_size = nx * ny * nz; - assert(active_size > 0); - assert(active_size <= mpi_size); - - wide_x_ = span_x / nx; - wide_y_ = span_y / ny; - wide_z_ = span_z / nz; - assert(wide_x_ >= 0); - assert(wide_y_ >= 0); - assert(wide_z_ >= 0); - - if (mpi_rank >= active_size) - { - x_ = -1; - y_ = -1; - z_ = -1; - neighbor_list_.initialize(0, neighbor_reserve_factor); - return; - } - - z_ = mpi_rank / (nx * ny); - y_ = (mpi_rank % (nx * ny)) / nx; - x_ = mpi_rank % nx; - - if (mpi_size > 1) - { - set_local_member_variables_by_halo(ucell, nx, ny, nz); - } - else + for (int i = 0; i < ucell.get_ntype(); i++) { - set_member_variables(ucell); - - int in_x, in_y, in_z; - const auto domain_index = [](double position, double low, double wide, int n) { - if (wide < coord_tolerance) - { - return std::abs(position - low) < coord_tolerance ? 0 : std::numeric_limits::max(); - } - return std::min(std::max(static_cast(std::floor((position - low) / wide)), 0), n - 1); - }; - - for (size_t i = 0; i < all_atoms_.size(); i++) + for (int j = 0; j < ucell.get_na(i); j++) { - in_x = domain_index(all_atoms_[i].position_x, atoms.x_low, wide_x_, nx); - in_y = domain_index(all_atoms_[i].position_y, atoms.y_low, wide_y_, ny); - in_z = domain_index(all_atoms_[i].position_z, atoms.z_low, wide_z_, nz); - - if (in_x == x_ && in_y == y_ && in_z == z_ && - all_atoms_[i].position_x <= atoms.x_high && - all_atoms_[i].position_y <= atoms.y_high && - all_atoms_[i].position_z <= atoms.z_high && - all_atoms_[i].is_inside) - { - inside_atoms_.push_back(all_atoms_[i]); - } - else if (distance( - all_atoms_[i].position_x, - all_atoms_[i].position_y, - all_atoms_[i].position_z, - atoms.x_low, - atoms.y_low, - atoms.z_low) <= search_radius_ * search_radius_) - { - ghost_atoms_.push_back(all_atoms_[i]); - } + const ModuleNeighList::LocalAtomIndex atom_count + = ModuleNeighList::checked_local_atom_index(all_atoms_.size(), + "NeighborSearch atom id"); + NeighborAtom atom( + ucell.get_tau(i,j).x, + ucell.get_tau(i,j).y, + ucell.get_tau(i,j).z, + i, + j, + atom_count + ); + inside_atoms_.push_back(atom); + all_atoms_.push_back(atom); } } - neighbor_list_.initialize(inside_atoms_.size(), std::max(1, static_cast(all_atoms_.size()) * neighbor_reserve_factor)); + int glayerX ; + int glayerY ; + int glayerZ ; + + int glayerX_minus ; + int glayerY_minus ; + int glayerZ_minus ; + + check_expand_condition(ucell, glayerX_minus, glayerX, glayerY_minus, glayerY, glayerZ_minus, glayerZ); + set_member_variables(ucell, glayerX_minus, glayerX, glayerY_minus, glayerY, glayerZ_minus, glayerZ); + const std::size_t page_size = ModuleNeighList::checked_size_product(all_atoms_.size(), + neighbor_reserve_factor, + "NeighborSearch page size"); + neighbor_list_.initialize(inside_atoms_.size(), page_size); } void NeighborSearch::build_neighbors() { - bin_manager_.init_bins(search_radius_, inside_atoms_, ghost_atoms_); - bin_manager_.do_binning(inside_atoms_, ghost_atoms_); - bin_manager_.build_atom_neighbors(neighbor_list_, inside_atoms_); + bin_manager_.init_bins(search_radius_, all_atoms_); + bin_manager_.do_binning(all_atoms_); + bin_manager_.build_atom_neighbors(neighbor_list_, inside_atoms_, all_atoms_); } -// ========== Utility methods ========== -double NeighborSearch::distance( - double position_x, - double position_y, - double position_z, - double x_low, - double y_low, - double z_low) +// ========== Internal methods ========== + +double NeighborSearch::cross_product_norm(double a1, double a2, double a3, + double b1, double b2, double b3) { - double dx = std::max(0.0, std::max(x_low + x_ * wide_x_ - position_x, position_x - (x_low + (x_ + 1) * wide_x_))); - double dy = std::max(0.0, std::max(y_low + y_ * wide_y_ - position_y, position_y - (y_low + (y_ + 1) * wide_y_))); - double dz = std::max(0.0, std::max(z_low + z_ * wide_z_ - position_z, position_z - (z_low + (z_ + 1) * wide_z_))); - return dx * dx + dy * dy + dz * dz; + double c1 = a2 * b3 - a3 * b2; + double c2 = a3 * b1 - a1 * b3; + double c3 = a1 * b2 - a2 * b1; + return sqrt(c1 * c1 + c2 * c2 + c3 * c3); } -void NeighborSearch::decompose(int mpi_size, int &nx, int &ny, int &nz) +void NeighborSearch::check_expand_condition(const AtomProvider& ucell, int& glayerX_minus, int& glayerX, int& glayerY_minus, int& glayerY, int& glayerZ_minus, int& glayerZ) { - assert(mpi_size > 0); - nx = 1; - ny = 1; - nz = mpi_size; - - int cube = static_cast(cbrt(mpi_size)); - for (int i = cube; i >= 1; i--) - { - if (mpi_size % i == 0) - { - nx = i; - ny = mpi_size / i; - break; - } - } + const auto& lat = ucell.get_latvec(); + const double omega = ucell.get_omega(); + const double lat0 = ucell.get_lat0(); + const double lat0_cubed = lat0 * lat0 * lat0; - int sq = static_cast(sqrt(ny)); - for (int i = sq; i >= 1; i--) - { - if (ny % i == 0) - { - nz = ny / i; - ny = i; - break; - } - } -} + double a23_norm = cross_product_norm(lat.e21, lat.e22, lat.e23, lat.e31, lat.e32, lat.e33); + int extend_d11 = std::ceil(a23_norm * search_radius_ / omega * lat0_cubed); -void NeighborSearch::decompose(int mpi_size, double span_x, double span_y, double span_z, int& nx, int& ny, int& nz) -{ - assert(mpi_size > 0); + double a31_norm = cross_product_norm(lat.e31, lat.e32, lat.e33, lat.e11, lat.e12, lat.e13); + int extend_d22 = std::ceil(a31_norm * search_radius_ / omega * lat0_cubed); - nx = 1; - ny = 1; - nz = 1; + double a12_norm = cross_product_norm(lat.e11, lat.e12, lat.e13, lat.e21, lat.e22, lat.e23); + int extend_d33 = std::ceil(a12_norm * search_radius_ / omega * lat0_cubed); - span_x = std::max(0.0, span_x); - span_y = std::max(0.0, span_y); - span_z = std::max(0.0, span_z); + glayerX = extend_d11 + positive_layer_offset; + glayerY = extend_d22 + positive_layer_offset; + glayerZ = extend_d33 + positive_layer_offset; + glayerX_minus = extend_d11; + glayerY_minus = extend_d22; + glayerZ_minus = extend_d33; +} - const bool can_split_x = span_x > coord_tolerance; - const bool can_split_y = span_y > coord_tolerance; - const bool can_split_z = span_z > coord_tolerance; - if (!can_split_x && !can_split_y && !can_split_z) - { - return; - } +void NeighborSearch::set_member_variables(const AtomProvider& ucell, int glayerX_minus, int glayerX, int glayerY_minus, int glayerY, int glayerZ_minus, int glayerZ) +{ + ModuleBase::Vector3 vec1(ucell.get_latvec().e11, ucell.get_latvec().e12, ucell.get_latvec().e13); + ModuleBase::Vector3 vec2(ucell.get_latvec().e21, ucell.get_latvec().e22, ucell.get_latvec().e23); + ModuleBase::Vector3 vec3(ucell.get_latvec().e31, ucell.get_latvec().e32, ucell.get_latvec().e33); - std::vector factors; - int remaining = mpi_size; - for (int factor = 2; factor * factor <= remaining; ++factor) + for (int ix = -glayerX_minus; ix < glayerX; ix++) { - while (remaining % factor == 0) + for (int iy = -glayerY_minus; iy < glayerY; iy++) { - factors.push_back(factor); - remaining /= factor; - } - } - if (remaining > 1) - { - factors.push_back(remaining); - } - std::sort(factors.rbegin(), factors.rend()); - - for (const int factor : factors) - { - int* best_dim = nullptr; - double best_score = -1.0; - - const auto try_dimension = [&](bool can_split, double span, int& dim) { - if (!can_split) + for (int iz = -glayerZ_minus; iz < glayerZ; iz++) { - return; - } - const double score = span / dim; - if (score > best_score) - { - best_score = score; - best_dim = &dim; - } - }; + if(ix==0 && iy==0 && iz==0) + { + continue; + } + for (int i = 0; i < ucell.get_ntype(); i++) + { + for (int j = 0; j < ucell.get_na(i); j++) + { + double atom_x = ucell.get_tau(i,j).x + vec1[0] * ix + vec2[0] * iy + vec3[0] * iz; + double atom_y = ucell.get_tau(i,j).y + vec1[1] * ix + vec2[1] * iy + vec3[1] * iz; + double atom_z = ucell.get_tau(i,j).z + vec1[2] * ix + vec2[2] * iy + vec3[2] * iz; - try_dimension(can_split_x, span_x, nx); - try_dimension(can_split_y, span_y, ny); - try_dimension(can_split_z, span_z, nz); - assert(best_dim != nullptr); - *best_dim *= factor; + const ModuleNeighList::LocalAtomIndex atom_count + = ModuleNeighList::checked_local_atom_index(all_atoms_.size(), + "NeighborSearch atom id"); + NeighborAtom atom(atom_x, atom_y, atom_z, i, j, atom_count); + ghost_atoms_.push_back(atom); + all_atoms_.push_back(atom); + } + } + } + } } } diff --git a/source/source_cell/module_neighlist/neighbor_search.h b/source/source_cell/module_neighlist/neighbor_search.h index b752fa94e4a..b95ace5bc6d 100644 --- a/source/source_cell/module_neighlist/neighbor_search.h +++ b/source/source_cell/module_neighlist/neighbor_search.h @@ -42,22 +42,8 @@ class NeighborSearch * * @param ucell Unit cell providing atom positions and lattice info. * @param sr Search radius (cutoff distance) in Bohr. - * @param mpi_rank MPI rank of this process. */ - void init(const AtomProvider& ucell, double sr, int mpi_rank); - - /** - * @brief Initialize the neighbor search with explicit MPI rank and size. - * - * This overload keeps the single-rank interface intact while allowing - * callers to decompose central atoms across MPI ranks. - * - * @param ucell Unit cell providing atom positions and lattice info. - * @param sr Search radius (cutoff distance) in Bohr. - * @param mpi_rank MPI rank of this process. - * @param mpi_size Total number of MPI processes. - */ - void init(const AtomProvider& ucell, double sr, int mpi_rank, int mpi_size); + void init(const AtomProvider& ucell, double sr); /** * @brief Initialize from rank-local owned atoms and exchanged ghost atoms. @@ -83,6 +69,8 @@ class NeighborSearch */ void build_neighbors(); + + // ========== Getter methods ========== /** * @brief Get the constructed neighbor list. * @return Reference to the NeighborList object. @@ -95,137 +83,12 @@ class NeighborSearch */ const NeighborList& get_neighbor_list() const; - // ========== Utility methods (public for testing) ========== - - /** - * @brief Calculate squared distance from a point to the local domain box. - * - * Used to determine if an atom is within the search radius of the - * local MPI domain. - * - * @param position_x X coordinate of the point. - * @param position_y Y coordinate of the point. - * @param position_z Z coordinate of the point. - * @param x_low Lower bound of the global domain in X. - * @param y_low Lower bound of the global domain in Y. - * @param z_low Lower bound of the global domain in Z. - * @return Squared distance to the domain box. - */ - double distance(double position_x, - double position_y, - double position_z, - double x_low, - double y_low, - double z_low); - - /** - * @brief Decompose MPI size into a 3D grid. - * - * Finds a balanced decomposition of mpi_size into nx * ny * nz. - * - * @param mpi_size Total number of MPI processes. - * @param nx Output: number of divisions in X. - * @param ny Output: number of divisions in Y. - * @param nz Output: number of divisions in Z. - */ - void decompose(int mpi_size, int& nx, int& ny, int& nz); - - /** - * @brief Decompose MPI ranks only along directions with nonzero span. - * - * Directions whose atom-coordinate span is zero are assigned one domain - * layer so ownership is not duplicated across that direction. - * - * @param mpi_size Total number of MPI processes. - * @param span_x Atom-coordinate span in X. - * @param span_y Atom-coordinate span in Y. - * @param span_z Atom-coordinate span in Z. - * @param nx Output: number of divisions in X. - * @param ny Output: number of divisions in Y. - * @param nz Output: number of divisions in Z. - */ - void decompose(int mpi_size, double span_x, double span_y, double span_z, int& nx, int& ny, int& nz); - - // ========== Getter methods ========== - /** * @brief Get the search radius. * @return Search radius in lattice units. */ double get_search_radius() const; - /** - * @brief Get the X position of this MPI domain. - * @return Domain index in X. - */ - int get_x() const; - - /** - * @brief Get the Y position of this MPI domain. - * @return Domain index in Y. - */ - int get_y() const; - - /** - * @brief Get the Z position of this MPI domain. - * @return Domain index in Z. - */ - int get_z() const; - - /** - * @brief Get the width of this MPI domain in X. - * @return Domain width in X. - */ - double get_wide_x() const; - - /** - * @brief Get the width of this MPI domain in Y. - * @return Domain width in Y. - */ - double get_wide_y() const; - - /** - * @brief Get the width of this MPI domain in Z. - * @return Domain width in Z. - */ - double get_wide_z() const; - - /** - * @brief Get the number of expansion layers in +X direction. - * @return Number of layers. - */ - int get_glayerX() const; - - /** - * @brief Get the number of expansion layers in +Y direction. - * @return Number of layers. - */ - int get_glayerY() const; - - /** - * @brief Get the number of expansion layers in +Z direction. - * @return Number of layers. - */ - int get_glayerZ() const; - - /** - * @brief Get the number of expansion layers in -X direction. - * @return Number of layers. - */ - int get_glayerX_minus() const; - - /** - * @brief Get the number of expansion layers in -Y direction. - * @return Number of layers. - */ - int get_glayerY_minus() const; - - /** - * @brief Get the number of expansion layers in -Z direction. - * @return Number of layers. - */ - int get_glayerZ_minus() const; - /** * @brief Get all atoms (including periodic images). * @return Const reference to the vector of all atoms. @@ -244,39 +107,11 @@ class NeighborSearch */ const std::vector& get_ghost_atoms() const; - // ========== Setter methods ========== - - /** - * @brief Set the search radius. - * @param sr Search radius in lattice units. - */ - void set_search_radius(double sr); - - /** - * @brief Set the position of this MPI domain. - * @param x Domain index in X. - * @param y Domain index in Y. - * @param z Domain index in Z. - */ - void set_position(int x, int y, int z); - - /** - * @brief Set the width of this MPI domain. - * @param wx Domain width in X. - * @param wy Domain width in Y. - * @param wz Domain width in Z. - */ - void set_width(double wx, double wy, double wz); - private: // ========== Internal methods ========== - /** - * @brief Convert unit cell atoms to InputAtoms format. - * @param ucell Unit cell providing atom info. - * @return InputAtoms structure for processing. - */ - InputAtoms ucell_to_input_atoms(const AtomProvider& ucell); + double cross_product_norm(double a1, double a2, double a3, + double b1, double b2, double b3); /** * @brief Check and compute expansion layer counts. @@ -286,7 +121,7 @@ class NeighborSearch * * @param ucell Unit cell providing lattice vectors. */ - void check_expand_condition(const AtomProvider& ucell); + void check_expand_condition(const AtomProvider& ucell, int& glayerX_minus, int& glayerX, int& glayerY_minus, int& glayerY, int& glayerZ_minus, int& glayerZ); /** * @brief Set member variables by generating periodic images. @@ -296,71 +131,13 @@ class NeighborSearch * * @param ucell Unit cell providing atom positions. */ - void set_member_variables(const AtomProvider& ucell); - - /** - * @brief Generate only atoms needed by the local MPI domain. - * - * The resulting all_atoms_ is a rank-local index space containing local - * inside atoms and cutoff-relevant ghost/image atoms. - * - * @param ucell Unit cell providing atom positions. - * @param atoms Original unit-cell atom bounds. - * @param nx Number of MPI divisions in X. - * @param ny Number of MPI divisions in Y. - * @param nz Number of MPI divisions in Z. - */ - void set_local_member_variables(const AtomProvider& ucell, const InputAtoms& atoms, int nx, int ny, int nz); - - /** - * @brief Generate local atoms by querying fractional-coordinate halo bins. - * - * Ownership is still defined only for atoms in the primary unit cell. Periodic - * images are generated only when they overlap the local cutoff halo and are - * stored as ghost atoms. - * - * @param ucell Unit cell providing atom positions. - * @param nx Number of MPI divisions in fractional X. - * @param ny Number of MPI divisions in fractional Y. - * @param nz Number of MPI divisions in fractional Z. - */ - void set_local_member_variables_by_halo(const AtomProvider& ucell, int nx, int ny, int nz); - - /** - * @brief Compute the norm of the cross product of two 3D vectors. - * - * @param a1, a2, a3 Components of the first vector. - * @param b1, b2, b3 Components of the second vector. - * @return Norm of the cross product. - */ - static double cross_product_norm(double a1, double a2, double a3, - double b1, double b2, double b3); + void set_member_variables(const AtomProvider& ucell, int glayerX_minus, int glayerX, int glayerY_minus, int glayerY, int glayerZ_minus, int glayerZ); // ========== Data members ========== /// Search radius in lattice units double search_radius_ = 0.0; - /// Position of this MPI domain in the 3D grid - int x_ = 0; - int y_ = 0; - int z_ = 0; - - /// Width of this MPI domain - double wide_x_ = 0.0; - double wide_y_ = 0.0; - double wide_z_ = 0.0; - - /// Number of expansion layers in positive directions - int glayerX_ = 0; - int glayerY_ = 0; - int glayerZ_ = 0; - - /// Number of expansion layers in negative directions - int glayerX_minus_ = 0; - int glayerY_minus_ = 0; - int glayerZ_minus_ = 0; - /// All atoms including periodic images std::vector all_atoms_; @@ -375,10 +152,8 @@ class NeighborSearch /// Bin manager for efficient neighbor search BinManager bin_manager_; - // ========== Compile-time constants ========== - /// Tolerance for coordinate comparisons in lattice units - static constexpr double coord_tolerance = 1e-8; + // ========== Compile-time constants ========== /// Offset added to expansion layers in positive directions static constexpr int positive_layer_offset = 1; diff --git a/source/source_cell/module_neighlist/page_allocator.cpp b/source/source_cell/module_neighlist/page_allocator.cpp index 959ea79154f..43bb4d6bc68 100644 --- a/source/source_cell/module_neighlist/page_allocator.cpp +++ b/source/source_cell/module_neighlist/page_allocator.cpp @@ -1,6 +1,9 @@ #include "page_allocator.h" #include "source_base/tool_quit.h" +#include +#include + PageAllocator::PageAllocator() : pgsize_(default_pgsize) { new_page_(); @@ -8,6 +11,10 @@ PageAllocator::PageAllocator() : pgsize_(default_pgsize) PageAllocator::PageAllocator(int pgsize) : pgsize_(pgsize) { + if (pgsize_ <= 0) + { + throw std::invalid_argument("PageAllocator page size must be positive."); + } new_page_(); } @@ -60,6 +67,10 @@ int PageAllocator::get_pgsize() const void PageAllocator::new_page_() { + if (pgsize_ <= 0) + { + throw std::invalid_argument("PageAllocator page size must be positive."); + } Page p; p.capacity = pgsize_; p.offset = 0; diff --git a/source/source_cell/module_neighlist/test/bin_manager_test.cpp b/source/source_cell/module_neighlist/test/bin_manager_test.cpp index d28b386afb1..3853786b97d 100644 --- a/source/source_cell/module_neighlist/test/bin_manager_test.cpp +++ b/source/source_cell/module_neighlist/test/bin_manager_test.cpp @@ -11,14 +11,14 @@ TEST(BinManagerUnit, InitAndBinning) inside.emplace_back(0.5, 0.0, 0.0, 0, 1, 1); BinManager bm; - bm.init_bins(1.0, inside, ghost); + bm.init_bins(1.0, inside); EXPECT_EQ(bm.get_nbinx(), 1); EXPECT_EQ(bm.get_nbiny(), 1); EXPECT_EQ(bm.get_nbinz(), 1); EXPECT_EQ(bm.get_total_bins(), bm.get_nbinx() * bm.get_nbiny() * bm.get_nbinz()); - bm.do_binning(inside, ghost); + bm.do_binning(inside); int total_atoms_in_bins = 0; for (int i = 0; i < bm.get_total_bins(); ++i) { @@ -34,11 +34,8 @@ TEST(BinManagerUnit, InitBins) atoms.emplace_back(0.5, 0.0, 0.0, 0, 1, 1); atoms.emplace_back(4.9, 0.0, 0.0, 0, 2, 2); - std::vector inside = atoms; - std::vector ghost; - BinManager bm; - bm.init_bins(1.0, inside, ghost); + bm.init_bins(1.0, atoms); EXPECT_EQ(bm.get_nbinx(), 5); EXPECT_EQ(bm.get_nbiny(), 1); EXPECT_EQ(bm.get_nbinz(), 1); @@ -51,22 +48,19 @@ TEST(BinManagerUnit, BuildNeighborsAndClear) atoms.emplace_back(0.5, 0.0, 0.0, 0, 1, 1); atoms.emplace_back(5.0, 0.0, 0.0, 0, 2, 2); - std::vector inside = atoms; - std::vector ghost; - BinManager bm; - bm.init_bins(1.0, inside, ghost); + bm.init_bins(1.0, atoms); EXPECT_EQ(bm.get_nbinx(), 5); EXPECT_EQ(bm.get_nbiny(), 1); EXPECT_EQ(bm.get_nbinz(), 1); EXPECT_EQ(bm.get_total_bins(), bm.get_nbinx() * bm.get_nbiny() * bm.get_nbinz()); - bm.do_binning(inside, ghost); + bm.do_binning(atoms); NeighborList nl; nl.initialize(static_cast(atoms.size()), 1024); - bm.build_atom_neighbors(nl, atoms); + bm.build_atom_neighbors(nl, atoms, atoms); EXPECT_EQ(nl.get_numneigh(0), 1); EXPECT_EQ(nl.get_numneigh(1), 1); @@ -82,12 +76,12 @@ TEST(BinManagerUnit, EmptyAtomsBuildNeighbors) std::vector ghost; BinManager bm; - bm.init_bins(1.0, atoms, ghost); + bm.init_bins(1.0, atoms); NeighborList nl; nl.initialize(0, 16); - bm.build_atom_neighbors(nl, atoms); + bm.build_atom_neighbors(nl, atoms, atoms); EXPECT_EQ(nl.get_nlocal(), 0); } @@ -98,23 +92,20 @@ TEST(BinManagerUnit, BoundaryAndExactRadius) atoms.emplace_back(1.0, 0.0, 0.0, 0, 1, 1); atoms.emplace_back(0.9, 0.0, 0.0, 0, 2, 2); - std::vector inside = atoms; - std::vector ghost; - BinManager bm; - bm.init_bins(1.0, inside, ghost); - bm.do_binning(inside, ghost); + bm.init_bins(1.0, atoms); + bm.do_binning(atoms); NeighborList nl; - nl.initialize(static_cast(inside.size()), 64); + nl.initialize(atoms.size(), 64); - bm.build_atom_neighbors(nl, inside); + bm.build_atom_neighbors(nl, atoms, atoms); EXPECT_EQ(nl.get_numneigh(0), 2); - for (int i = 0; i < static_cast(inside.size()); ++i) { + for (int i = 0; i < static_cast(atoms.size()); ++i) { for (int j = 0; j < nl.get_numneigh(i); ++j) { int id = nl.get_firstneigh(i)[j]; - EXPECT_NE(id, inside[i].atom_id); + EXPECT_NE(id, atoms[i].atom_id); } } } @@ -128,7 +119,7 @@ TEST(BinManagerUnit, InitWithGhostOnly) ghost.emplace_back(2.0, 0.0, 0.0, 0, 1, 1); BinManager bm; - bm.init_bins(1.0, inside, ghost); + bm.init_bins(1.0, ghost); EXPECT_EQ(bm.get_nbinx(), 3); EXPECT_EQ(bm.get_nbiny(), 1); @@ -141,17 +132,14 @@ TEST(BinManagerUnit, BuildNeighborsNoNeighborsFirstneighNull) atoms.emplace_back(0.0, 0.0, 0.0, 0, 0, 0); atoms.emplace_back(100.0, 100.0, 100.0, 0, 1, 1); - std::vector inside = atoms; - std::vector ghost; - BinManager bm; - bm.init_bins(1.0, inside, ghost); - bm.do_binning(inside, ghost); + bm.init_bins(1.0, atoms); + bm.do_binning(atoms); NeighborList nl; - nl.initialize(static_cast(inside.size()), 8); + nl.initialize(atoms.size(), 8); - bm.build_atom_neighbors(nl, inside); + bm.build_atom_neighbors(nl, atoms, atoms); EXPECT_EQ(nl.get_numneigh(0), 0); EXPECT_EQ(nl.get_numneigh(1), 0); @@ -165,28 +153,53 @@ TEST(BinManagerUnit, GhostAtomsAreCounted) std::vector ghost; inside.emplace_back(0.0, 0.0, 0.0, 0, 0, 0); - ghost.emplace_back(0.4, 0.0, 0.0, 0, 1, 3); + ghost.emplace_back(0.4, 0.0, 0.0, 0, 1, 1, 3, 1); BinManager bm; - bm.init_bins(1.0, inside, ghost); - bm.do_binning(inside, ghost); + std::vector all_atoms = inside; + all_atoms.insert(all_atoms.end(), ghost.begin(), ghost.end()); + bm.init_bins(1.0, all_atoms); + bm.do_binning(all_atoms); NeighborList nl; nl.initialize(static_cast(inside.size()), 32); - bm.build_atom_neighbors(nl, inside); + bm.build_atom_neighbors(nl, inside, all_atoms); EXPECT_EQ(nl.get_nlocal(), 1); EXPECT_EQ(nl.get_numneigh(0), 1); bool found = false; if (nl.get_numneigh(0) > 0 && nl.get_firstneigh(0) != nullptr) { for (int k = 0; k < nl.get_numneigh(0); ++k) { - if (nl.get_firstneigh(0)[k] == 3) found = true; + if (nl.get_firstneigh(0)[k] == 1) found = true; } } EXPECT_TRUE(found); } +TEST(BinManagerUnit, SamePositionDifferentAtomsAreNeighbors) +{ + std::vector atoms; + atoms.emplace_back(0.0, 0.0, 0.0, 0, 0, 0); + atoms.emplace_back(0.0, 0.0, 0.0, 0, 1, 1); + + BinManager bm; + bm.init_bins(1.0, atoms); + bm.do_binning(atoms); + + NeighborList nl; + nl.initialize(atoms.size(), 16); + + bm.build_atom_neighbors(nl, atoms, atoms); + + EXPECT_EQ(nl.get_numneigh(0), 1); + EXPECT_EQ(nl.get_numneigh(1), 1); + ASSERT_NE(nl.get_firstneigh(0), nullptr); + ASSERT_NE(nl.get_firstneigh(1), nullptr); + EXPECT_EQ(nl.get_firstneigh(0)[0], 1); + EXPECT_EQ(nl.get_firstneigh(1)[0], 0); +} + TEST(BinManagerUnit, MultipleBinsNeighborSearch) { std::vector atoms; @@ -196,18 +209,15 @@ TEST(BinManagerUnit, MultipleBinsNeighborSearch) for (int z = 0; z < 3; ++z) atoms.emplace_back(x * 1.0, y * 1.0, z * 1.0, 0, 0, id++); - std::vector inside = atoms; - std::vector ghost; - BinManager bm; - bm.init_bins(1.0, inside, ghost); - bm.do_binning(inside, ghost); + bm.init_bins(1.0, atoms); + bm.do_binning(atoms); NeighborList nl; - nl.initialize(static_cast(inside.size()), 16); + nl.initialize(atoms.size(), 16); - bm.build_atom_neighbors(nl, inside); + bm.build_atom_neighbors(nl, atoms, atoms); int center_index = 13; EXPECT_EQ(nl.get_numneigh(center_index), 6); -} \ No newline at end of file +} diff --git a/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp index 917be5104fe..add4efcd529 100644 --- a/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp +++ b/source/source_cell/module_neighlist/test/neighbor_search_mpi_benchmark.cpp @@ -1,5 +1,6 @@ #include "source_cell/module_neighlist/neighbor_search.h" #include "source_cell/module_neighlist/domain_decomposition.h" +#include "source_cell/module_neighlist/neighbor_types.h" #include "source_cell/module_neighlist/unitcell_lite.h" #include @@ -142,7 +143,7 @@ int main(int argc, char** argv) { NeighborSearch serial; const double t0 = MPI_Wtime(); - serial.init(ucell, cutoff, 0); + serial.init(ucell, cutoff); const double t1 = MPI_Wtime(); serial.build_neighbors(); const double t2 = MPI_Wtime(); @@ -201,7 +202,8 @@ int main(int argc, char** argv) for (size_t atom_id = 0; atom_id < all_atoms.size(); ++atom_id) { - if (all_atoms[atom_id].atom_id != static_cast(atom_id)) + if (all_atoms[atom_id].atom_id != + ModuleNeighList::checked_local_atom_index(atom_id, "benchmark atom id")) { local_failure = 1; } @@ -219,7 +221,7 @@ int main(int argc, char** argv) for (int ad = 0; ad < list.get_numneigh(local_i); ++ad) { const int neighbor_id = list.get_firstneigh(local_i)[ad]; - if (neighbor_id < 0 || neighbor_id >= static_cast(all_atoms.size())) + if (neighbor_id < 0 || static_cast(neighbor_id) >= all_atoms.size()) { local_failure = 1; } diff --git a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp index 266c0b7f2ad..2d984d1e54a 100644 --- a/source/source_cell/module_neighlist/test/neighbor_search_test.cpp +++ b/source/source_cell/module_neighlist/test/neighbor_search_test.cpp @@ -1,69 +1,115 @@ #include + +#include "../local_atom.h" #include "../neighbor_search.h" #include "../unitcell_lite.h" -#include -// Helper function to create a simple UnitCellLite for testing -static UnitCellLite make_test_ucell(double lat0, double omega, - const ModuleBase::Matrix3& latvec, - int ntype, const std::vector& na, - const std::vector>& tau) { +#include +#include + +namespace +{ +UnitCellLite make_test_ucell(double lat0, + double omega, + const ModuleBase::Matrix3& latvec, + int ntype, + const std::vector& na, + const std::vector>& tau) +{ UnitCellLite ucell; ucell.set_lattice(lat0, omega, latvec); ucell.set_atoms(ntype, na, tau); return ucell; } -TEST(NeighborSearchTest, TwoAtomsNeighbor) +ModuleBase::Matrix3 identity_lattice() { ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; + latvec.e11 = 1.0; + latvec.e12 = 0.0; + latvec.e13 = 0.0; + latvec.e21 = 0.0; + latvec.e22 = 1.0; + latvec.e23 = 0.0; + latvec.e31 = 0.0; + latvec.e32 = 0.0; + latvec.e33 = 1.0; + return latvec; +} + +std::size_t count_pairs(const NeighborList& list) +{ + std::size_t pairs = 0; + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + { + pairs += static_cast(list.get_numneigh(local_i)); + } + return pairs; +} +} // namespace - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {2}, - {{0.0, 0.0, 0.0}, {0.5, 0.0, 0.0}} - ); +TEST(NeighborSearchTest, TwoAtomsNeighbor) +{ + UnitCellLite ucell = make_test_ucell(1.0, + 1.0, + identity_lattice(), + 1, + {2}, + {{0.0, 0.0, 0.0}, {0.5, 0.0, 0.0}}); NeighborSearch ns; - double cutoff = 1.0; - - ns.init(ucell, cutoff, 0); + ns.init(ucell, 1.0); ns.build_neighbors(); - auto &list = ns.get_neighbor_list(); - + const NeighborList& list = ns.get_neighbor_list(); ASSERT_EQ(list.get_nlocal(), 2); - EXPECT_EQ(list.get_numneigh(0), 8); EXPECT_EQ(list.get_numneigh(1), 8); } TEST(NeighborSearchTest, NoNeighbor) { - ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {2}, - {{0.0, 0.0, 0.0}, {5.0, 0.0, 0.0}} - ); + UnitCellLite ucell = make_test_ucell(1.0, + 1.0, + identity_lattice(), + 1, + {2}, + {{0.0, 0.0, 0.0}, {5.0, 0.0, 0.0}}); NeighborSearch ns; - - // use a smaller search radius to avoid counting periodic-image neighbors - ns.init(ucell, 0.1, 0); + ns.init(ucell, 0.1); ns.build_neighbors(); - auto &list = ns.get_neighbor_list(); - + const NeighborList& list = ns.get_neighbor_list(); + ASSERT_EQ(list.get_nlocal(), 2); EXPECT_EQ(list.get_numneigh(0), 0); EXPECT_EQ(list.get_numneigh(1), 0); } +TEST(NeighborSearchTest, SerialInitOwnsCentralAtomsAndBuildsImages) +{ + UnitCellLite ucell = make_test_ucell(1.0, + 1.0, + identity_lattice(), + 1, + {2}, + {{0.0, 0.0, 0.0}, {0.5, 0.0, 0.0}}); + + NeighborSearch ns; + ns.init(ucell, 1.0); + + EXPECT_EQ(ns.get_inside_atoms().size(), 2U); + EXPECT_EQ(ns.get_neighbor_list().get_nlocal(), 2); + EXPECT_EQ(ns.get_all_atoms().size(), 54U); + + const std::vector& all_atoms = ns.get_all_atoms(); + for (std::size_t i = 0; i < all_atoms.size(); ++i) + { + EXPECT_EQ(all_atoms[i].atom_id, + ModuleNeighList::checked_local_atom_index(i, "test atom id")); + } +} + TEST(NeighborSearchTest, DistributedInputUsesOwnedCentersAndGhostNeighbors) { std::vector owned_atoms; @@ -90,324 +136,54 @@ TEST(NeighborSearchTest, DistributedInputUsesOwnedCentersAndGhostNeighbors) const NeighborList& list = ns.get_neighbor_list(); ASSERT_EQ(list.get_nlocal(), 1); ASSERT_EQ(list.get_numneigh(0), 1); + const int neighbor_id = list.get_firstneigh(0)[0]; ASSERT_GE(neighbor_id, 0); - ASSERT_LT(neighbor_id, static_cast(ns.get_all_atoms().size())); + ASSERT_LT(static_cast(neighbor_id), ns.get_all_atoms().size()); EXPECT_EQ(ns.get_all_atoms()[neighbor_id].global_id, 1); EXPECT_EQ(ns.get_all_atoms()[neighbor_id].owner_rank, 1); } -TEST(NeighborSearchUnit, DistanceBox) +TEST(NeighborSearchTest, DistributedNeighborIdsStayLocalToAllAtoms) { - NeighborSearch ns; - // set a single cell region at x=0..1,y=0..1,z=0..1 - ns.set_position(0, 0, 0); - ns.set_width(1.0, 1.0, 1.0); - - double inside = ns.distance(0.2, 0.5, 0.5, 0.0, 0.0, 0.0); - EXPECT_DOUBLE_EQ(inside, 0.0); - - double outside = ns.distance(2.0, 0.5, 0.5, 0.0, 0.0, 0.0); - // squared distance should be (2-1)^2 = 1 - EXPECT_DOUBLE_EQ(outside, 1.0); -} - -TEST(NeighborSearchUnit, DecomposeCases) -{ - NeighborSearch ns; - int nx, ny, nz; - - ns.decompose(8, nx, ny, nz); - EXPECT_EQ(nx * ny * nz, 8); - // expect somewhat balanced cube factors for 8 - EXPECT_EQ(nx, 2); - EXPECT_EQ(ny, 2); - EXPECT_EQ(nz, 2); - - ns.decompose(7, nx, ny, nz); - EXPECT_EQ(nx * ny * nz, 7); - EXPECT_EQ(nx, 1); - EXPECT_EQ(ny, 1); - EXPECT_EQ(nz, 7); -} - -TEST(NeighborSearchUnit, DecomposePrimeNumber) -{ - NeighborSearch ns; - int nx, ny, nz; - ns.decompose(13, nx, ny, nz); - EXPECT_EQ(nx * ny * nz, 13); - EXPECT_EQ(nx, 1); - EXPECT_EQ(ny, 1); - EXPECT_EQ(nz, 13); -} - -TEST(NeighborSearchUnit, DecomposeSkipsZeroSpanDirections) -{ - NeighborSearch ns; - int nx, ny, nz; - - ns.decompose(8, 1.0, 1.0, 0.0, nx, ny, nz); - EXPECT_EQ(nx * ny * nz, 8); - EXPECT_EQ(nz, 1); - - ns.decompose(4, 4.0, 0.0, 0.0, nx, ny, nz); - EXPECT_EQ(nx, 4); - EXPECT_EQ(ny, 1); - EXPECT_EQ(nz, 1); - - ns.decompose(4, 0.0, 0.0, 0.0, nx, ny, nz); - EXPECT_EQ(nx, 1); - EXPECT_EQ(ny, 1); - EXPECT_EQ(nz, 1); -} - -TEST(NeighborSearchUnit, NonOrthogonalLatticeExpand) -{ - ModuleBase::Matrix3 latvec; - // skewed lattice - latvec.e11 = 1; latvec.e12 = 0.3; latvec.e13 = 0.0; - latvec.e21 = 0.1; latvec.e22 = 1.0; latvec.e23 = 0.0; - latvec.e31 = 0.0; latvec.e32 = 0.0; latvec.e33 = 1.0; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {1}, - {{0.0, 0.0, 0.0}} - ); - - NeighborSearch ns; - ns.init(ucell, 2.5, 0); - // for skewed lattice, expansion layers should be >= 1 - EXPECT_GE(ns.get_glayerX(), 1); - EXPECT_GE(ns.get_glayerY(), 1); - EXPECT_GE(ns.get_glayerZ(), 1); -} - -TEST(NeighborSearchInit_WideZero_CentralInside, SingleAtomCell) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {1}, - {{0.0, 0.0, 0.0}} - ); + std::vector owned_atoms; + std::vector ghost_atoms; + owned_atoms.push_back(LocalAtom(ModuleBase::Vector3(0.0, 0.0, 0.0), + ModuleBase::Vector3(0.0, 0.0, 0.0), + 0, + 10, + 0, + 0, + false)); + owned_atoms.push_back(LocalAtom(ModuleBase::Vector3(2.0, 0.0, 0.0), + ModuleBase::Vector3(2.0, 0.0, 0.0), + 0, + 11, + 1, + 0, + false)); + ghost_atoms.push_back(LocalAtom(ModuleBase::Vector3(0.5, 0.0, 0.0), + ModuleBase::Vector3(0.5, 0.0, 0.0), + 0, + 20, + 2, + 1, + true)); NeighborSearch ns; - // choose sr small enough; with mpi_size fixed to 1 in init, wide_* become 0 - ns.init(ucell, 0.1, 0); - // central cell atom should be counted as inside - EXPECT_EQ(ns.get_inside_atoms().size(), 1); - EXPECT_EQ(ns.get_neighbor_list().get_nlocal(), static_cast(ns.get_inside_atoms().size())); -} - -TEST(NeighborSearchInit_MpiRankIndexing, RankValues) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {1}, - {{0.0, 0.0, 0.0}} - ); - - NeighborSearch ns0; - ns0.init(ucell, 0.5, 0); - // with mpi_size fixed to 1 in init, nx=ny=nz=1; for rank 0 expect x=y=0,z=0 - EXPECT_EQ(ns0.get_x(), 0); - EXPECT_EQ(ns0.get_y(), 0); - EXPECT_EQ(ns0.get_z(), 0); -} - -TEST(NeighborSearchInit_MpiOwnership, SingleAtomZeroSpanIsOwnedOnce) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {1}, - {{0.0, 0.0, 0.0}} - ); - - size_t total_inside = 0; - for (int rank = 0; rank < 4; ++rank) - { - NeighborSearch ns; - ns.init(ucell, 0.1, rank, 4); - total_inside += ns.get_inside_atoms().size(); - EXPECT_EQ(ns.get_neighbor_list().get_nlocal(), static_cast(ns.get_inside_atoms().size())); - EXPECT_EQ(ns.get_inside_atoms().size(), rank == 0 ? 1U : 0U); - } - EXPECT_EQ(total_inside, 1U); -} - -TEST(NeighborSearchInit_MpiOwnership, SplitsOnlyNonzeroSpanDirection) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 4; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 4.0, latvec, 1, {4}, - {{0.0, 0.0, 0.0}, - {1.0, 0.0, 0.0}, - {2.0, 0.0, 0.0}, - {3.0, 0.0, 0.0}} - ); - - size_t total_inside = 0; - for (int rank = 0; rank < 4; ++rank) - { - NeighborSearch ns; - ns.init(ucell, 0.1, rank, 4); - total_inside += ns.get_inside_atoms().size(); - EXPECT_EQ(ns.get_y(), 0); - EXPECT_EQ(ns.get_z(), 0); - EXPECT_EQ(ns.get_inside_atoms().size(), 1U); - } - EXPECT_EQ(total_inside, 4U); -} - -TEST(NeighborSearchInit_MpiLocalAtoms, LocalIdsAreValidAndAllAtomsShrink) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 4; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 4; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 4; - - std::vector> tau; - for (int ix = 0; ix < 4; ++ix) - { - for (int iy = 0; iy < 4; ++iy) - { - for (int iz = 0; iz < 4; ++iz) - { - tau.emplace_back(ix, iy, iz); - } - } - } - - UnitCellLite ucell = make_test_ucell(1.0, 64.0, latvec, 1, {static_cast(tau.size())}, tau); - - NeighborSearch serial; - serial.init(ucell, 1.1, 0); - serial.build_neighbors(); - const size_t serial_all_atoms = serial.get_all_atoms().size(); - size_t serial_neighbor_pairs = 0; - for (int local_i = 0; local_i < serial.get_neighbor_list().get_nlocal(); ++local_i) - { - serial_neighbor_pairs += serial.get_neighbor_list().get_numneigh(local_i); - } + ns.init_distributed(owned_atoms, ghost_atoms, 0.75, 1.0); + ns.build_neighbors(); - size_t total_inside = 0; - size_t parallel_all_atoms_sum = 0; - size_t parallel_all_atoms_max = 0; - size_t parallel_neighbor_pairs = 0; - for (int rank = 0; rank < 4; ++rank) + const NeighborList& list = ns.get_neighbor_list(); + const std::vector& all_atoms = ns.get_all_atoms(); + EXPECT_EQ(count_pairs(list), 1U); + for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) { - NeighborSearch ns; - ns.init(ucell, 1.1, rank, 4); - ns.build_neighbors(); - - const auto& all_atoms = ns.get_all_atoms(); - const auto& list = ns.get_neighbor_list(); - total_inside += ns.get_inside_atoms().size(); - parallel_all_atoms_sum += all_atoms.size(); - parallel_all_atoms_max = std::max(parallel_all_atoms_max, all_atoms.size()); - for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) - { - parallel_neighbor_pairs += list.get_numneigh(local_i); - } - - for (size_t i = 0; i < all_atoms.size(); ++i) - { - EXPECT_EQ(all_atoms[i].atom_id, static_cast(i)); - } - - for (int local_i = 0; local_i < list.get_nlocal(); ++local_i) + for (int ad = 0; ad < list.get_numneigh(local_i); ++ad) { - for (int ad = 0; ad < list.get_numneigh(local_i); ++ad) - { - const int neighbor_id = list.get_firstneigh(local_i)[ad]; - EXPECT_GE(neighbor_id, 0); - EXPECT_LT(neighbor_id, static_cast(all_atoms.size())); - } + const int neighbor_id = list.get_firstneigh(local_i)[ad]; + EXPECT_GE(neighbor_id, 0); + EXPECT_LT(static_cast(neighbor_id), all_atoms.size()); } } - - EXPECT_EQ(total_inside, tau.size()); - EXPECT_EQ(parallel_neighbor_pairs, serial_neighbor_pairs); - EXPECT_LT(parallel_all_atoms_max, serial_all_atoms); - EXPECT_LT(parallel_all_atoms_sum, serial_all_atoms * 4); -} - -TEST(NeighborSearchDistance_OutsideCases, VariousAxes) -{ - NeighborSearch ns; - ns.set_position(0, 0, 0); - ns.set_width(2.0, 3.0, 4.0); - - // position inside box along x (no dx), but outside along y by above high bound - double d = ns.distance(0.5, 4.5, 1.0, 0.0, 0.0, 0.0); - // dy = position_y - (y_low + (y+1)*wide_y) = 4.5 - 3.0 = 1.5 -> squared 2.25 - // dx = 0, dz = 0 -> total 2.25 - EXPECT_DOUBLE_EQ(d, 2.25); - - // position left of low bound on x - double d2 = ns.distance(-1.0, 1.0, 1.0, 0.0, 0.0, 0.0); - // dx = x_low - position_x = 0 - (-1) = 1 -> squared 1 - EXPECT_DOUBLE_EQ(d2, 1.0); -} - -TEST(NeighborSearchDecompose_SmallSizes, TwoAndOne) -{ - NeighborSearch ns; - int nx, ny, nz; - ns.decompose(2, nx, ny, nz); - EXPECT_EQ(nx * ny * nz, 2); - // possible decomposition is nx=1, ny=1, nz=2 (or nx=1, ny=2, nz=1 depending on algorithm) - EXPECT_EQ(nx, 1); - - ns.decompose(1, nx, ny, nz); - EXPECT_EQ(nx, 1); - EXPECT_EQ(ny, 1); - EXPECT_EQ(nz, 1); -} - -TEST(NeighborSearchUnit, ExpansionLayersAndAtomCount) -{ - ModuleBase::Matrix3 latvec; - latvec.e11 = 1; latvec.e12 = 0; latvec.e13 = 0; - latvec.e21 = 0; latvec.e22 = 1; latvec.e23 = 0; - latvec.e31 = 0; latvec.e32 = 0; latvec.e33 = 1; - - UnitCellLite ucell = make_test_ucell( - 1.0, 1.0, latvec, 1, {2}, - {{0.0, 0.0, 0.0}, {0.5, 0.0, 0.0}} - ); - - NeighborSearch ns; - ns.init(ucell, 1.0, 0); - - // For identity lattice with search_radius=1 expected ceil produce values - EXPECT_EQ(ns.get_glayerX(), 2); - EXPECT_EQ(ns.get_glayerY(), 2); - EXPECT_EQ(ns.get_glayerZ(), 2); - EXPECT_EQ(ns.get_glayerX_minus(), 1); - - // Check atom count - int images_x = ns.get_glayerX() + ns.get_glayerX_minus(); - int images_y = ns.get_glayerY() + ns.get_glayerY_minus(); - int images_z = ns.get_glayerZ() + ns.get_glayerZ_minus(); - int expected = images_x * images_y * images_z * 2; // 2 atoms per cell - EXPECT_EQ(static_cast(ns.get_all_atoms().size()), expected); } - -// end of additional tests diff --git a/source/source_cell/module_neighlist/unitcell_lite.cpp b/source/source_cell/module_neighlist/unitcell_lite.cpp index 8475f81cf29..ddabbfdfa35 100644 --- a/source/source_cell/module_neighlist/unitcell_lite.cpp +++ b/source/source_cell/module_neighlist/unitcell_lite.cpp @@ -1,4 +1,5 @@ #include "unitcell_lite.h" +#include "source_cell/module_neighlist/neighbor_types.h" #include @@ -69,10 +70,12 @@ void UnitCellLite::set_atoms(int ntype, tau_ = tau; // compute total number of atoms - nat_ = 0; + std::size_t nat = 0; for (int i = 0; i < ntype_; ++i) { - nat_ += na_[i]; + assert(na_[i] >= 0); + nat += static_cast(na_[i]); } + nat_ = ModuleNeighList::checked_int_size(nat, "UnitCellLite atom count"); assert(tau_.size() == static_cast(nat_)); // compute cumulative counts @@ -89,4 +92,4 @@ void UnitCellLite::compute_naa_() { for (size_t i = 1; i < naa_.size(); ++i) { naa_[i] = naa_[i - 1] + na_[i]; } -} \ No newline at end of file +} diff --git a/source/source_esolver/esolver_lj.cpp b/source/source_esolver/esolver_lj.cpp index 6ee2c40b48a..7b212c027ed 100644 --- a/source/source_esolver/esolver_lj.cpp +++ b/source/source_esolver/esolver_lj.cpp @@ -4,6 +4,7 @@ #include "source_cell/module_neighbor/sltk_grid_driver.h" #include "source_io/module_output/output_log.h" #include "source_io/module_output/cif_io.h" +#include "source_cell/module_neighlist/neighbor_types.h" #include "source_cell/module_neighlist/neighbor_search.h" #include "source_base/global_variable.h" #include "source_base/timer.h" @@ -13,6 +14,7 @@ #endif #include +#include #include #include @@ -75,8 +77,7 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) double distance = 0.0; ModuleBase::Vector3 tau1, tau2, dtau; -#ifdef __MPI - if (GlobalV::NPROC > 1) + #ifdef __MPI { ModuleBase::timer::start("ESolverLJ", "mpi_total"); ModuleBase::timer::start("ESolverLJ", "neigh_init"); @@ -102,8 +103,10 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) atom_start[it + 1] = atom_start[it] + ucell.atoms[it].na; } - std::vector potential_by_atom(ucell.nat, 0.0); - std::vector virial_by_atom(ucell.nat * 9, 0.0); + const std::size_t local_virial_size + = ModuleNeighList::checked_size_product(inside_atoms.size(), 9, "ESolver_LJ local virial size"); + std::vector potential_by_local_atom(inside_atoms.size(), 0.0); + std::vector virial_by_local_atom(local_virial_size, 0.0); ModuleBase::timer::start("ESolverLJ", "force_loc"); for (int local_i = 0; local_i < neighbor_list.get_nlocal(); ++local_i) @@ -128,7 +131,7 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) distance = dtau.norm(); if (distance < lj_rcut(it, it2)) { - potential_by_atom[global_i] += LJ_energy(distance, it, it2) - en_shift(it, it2); + potential_by_local_atom[local_i] += LJ_energy(distance, it, it2) - en_shift(it, it2); ModuleBase::Vector3 f_ij = LJ_force(dtau, it, it2); lj_force(global_i, 0) += f_ij.x; lj_force(global_i, 1) += f_ij.y; @@ -137,7 +140,7 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) { for (int j = 0; j < 3; ++j) { - virial_by_atom[global_i * 9 + i * 3 + j] += dtau[i] * f_ij[j]; + virial_by_local_atom[local_i * 9 + i * 3 + j] += dtau[i] * f_ij[j]; } } } @@ -145,31 +148,41 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) } ModuleBase::timer::end("ESolverLJ", "force_loc"); + double local_potential = 0.0; + std::array local_virial{}; + for (std::size_t local_i = 0; local_i < potential_by_local_atom.size(); ++local_i) + { + local_potential += potential_by_local_atom[local_i]; + for (int component = 0; component < 9; ++component) + { + local_virial[component] += virial_by_local_atom[local_i * 9 + component]; + } + } + ModuleBase::timer::start("ESolverLJ", "reduce"); - Parallel_Reduce::reduce_all(potential_by_atom.data(), static_cast(potential_by_atom.size())); + Parallel_Reduce::reduce_all(&local_potential, 1); + Parallel_Reduce::reduce_all(local_virial.data(), static_cast(local_virial.size())); + // Existing MD code expects a full global force matrix on each rank. + // Keeping this reduction preserves current behavior; removing the global + // force layout requires a distributed MD data model. Parallel_Reduce::reduce_all(lj_force.c, lj_force.nr * lj_force.nc); - Parallel_Reduce::reduce_all(virial_by_atom.data(), static_cast(virial_by_atom.size())); ModuleBase::timer::end("ESolverLJ", "reduce"); - for (int iat = 0; iat < ucell.nat; ++iat) + lj_potential += local_potential; + for (int i = 0; i < 3; ++i) { - lj_potential += potential_by_atom[iat]; - for (int i = 0; i < 3; ++i) + for (int j = 0; j < 3; ++j) { - for (int j = 0; j < 3; ++j) - { - lj_virial(i, j) += virial_by_atom[iat * 9 + i * 3 + j]; - } + lj_virial(i, j) += local_virial[i * 3 + j]; } } ModuleBase::timer::end("ESolverLJ", "mpi_total"); } - else -#endif + #else { ModuleBase::timer::start("ESolverLJ", "serial_tot"); ModuleBase::timer::start("ESolverLJ", "ser_neigh"); - neighbor_search.init(ucell_lite, search_radius, 0); + neighbor_search.init(ucell_lite, search_radius); neighbor_search.build_neighbors(); ModuleBase::timer::end("ESolverLJ", "ser_neigh"); @@ -208,51 +221,7 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) ModuleBase::timer::end("ESolverLJ", "ser_force"); ModuleBase::timer::end("ESolverLJ", "serial_tot"); } - - - /*Grid_Driver grid_neigh(PARAM.inp.test_deconstructor, PARAM.inp.test_grid); - atom_arrange::search(PARAM.globalv.search_pbc, - GlobalV::ofs_running, - grid_neigh, - ucell, - search_radius, - PARAM.inp.test_atom_input); - - double distance = 0.0; - int index = 0; - - // Important! potential, force, virial must be zero per step - lj_potential = 0; - lj_force.zero_out(); - lj_virial.zero_out(); - - ModuleBase::Vector3 tau1, tau2, dtau; - for (int it = 0; it < ucell.ntype; ++it) - { - Atom* atom1 = &ucell.atoms[it]; - for (int ia = 0; ia < atom1->na; ++ia) - { - tau1 = atom1->tau[ia]; - grid_neigh.Find_atom(ucell, tau1, it, ia); - for (int ad = 0; ad < grid_neigh.getAdjacentNum(); ++ad) - { - tau2 = grid_neigh.getAdjacentTau(ad); - int it2 = grid_neigh.getType(ad); - dtau = (tau1 - tau2) * ucell.lat0; - distance = dtau.norm(); - if (distance < lj_rcut(it, it2)) - { - lj_potential += LJ_energy(distance, it, it2) - en_shift(it, it2); - ModuleBase::Vector3 f_ij = LJ_force(dtau, it, it2); - lj_force(index, 0) += f_ij.x; - lj_force(index, 1) += f_ij.y; - lj_force(index, 2) += f_ij.z; - LJ_virial(f_ij, dtau); - } - } - index++; - } - }*/ + #endif lj_potential /= 2.0; GlobalV::ofs_running << " #TOTAL ENERGY# " << std::setprecision(11) << lj_potential * ModuleBase::Ry_to_eV << " eV" @@ -266,7 +235,7 @@ void ESolver_LJ::runner(UnitCell& ucell, const int istep) lj_virial(i, j) /= (2.0 * ucell.omega); } } - } +} double ESolver_LJ::cal_energy() {