diff --git a/src/parallel/include/timpi/communicator.h b/src/parallel/include/timpi/communicator.h index fa77314..9857fac 100644 --- a/src/parallel/include/timpi/communicator.h +++ b/src/parallel/include/timpi/communicator.h @@ -464,6 +464,118 @@ class Communicator void maxloc(std::vector & r, std::vector & max_id) const; + /** + * Take a local variable and replace it with the product of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void product(T & r) const; + + /** + * Non-blocking product over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void product(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the logical_and of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void logical_and(T & r) const; + + /** + * Non-blocking logical_and over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void logical_and(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the bitwise_and of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void bitwise_and(T & r) const; + + /** + * Non-blocking bitwise_and over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void bitwise_and(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the logical_or of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void logical_or(T & r) const; + + /** + * Non-blocking logical_or over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void logical_or(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the bitwise_or of its values + * on all processors. 
Containers are replaced element-wise. + */ + template + inline + void bitwise_or(T & r) const; + + /** + * Non-blocking bitwise_or over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void bitwise_or(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the logical_xor of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void logical_xor(T & r) const; + + /** + * Non-blocking logical_xor over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void logical_xor(const T & r, T & o, Request & req) const; + + /** + * Take a local variable and replace it with the bitwise_xor of its values + * on all processors. Containers are replaced element-wise. + */ + template + inline + void bitwise_xor(T & r) const; + + /** + * Non-blocking bitwise_xor over all processors of the local value \p r, stored into \p o + * and tracked by the request \p req; if size() <= 1, \p o is just set to \p r and \p req to the null request. + */ + template + inline + void bitwise_xor(const T & r, T & o, Request & req) const; + /** * Take a local variable and replace it with the sum of it's values * on all processors. Containers are replaced element-wise. 
diff --git a/src/parallel/include/timpi/parallel_communicator_specializations b/src/parallel/include/timpi/parallel_communicator_specializations index fd59c43..56b3485 100644 --- a/src/parallel/include/timpi/parallel_communicator_specializations +++ b/src/parallel/include/timpi/parallel_communicator_specializations @@ -98,6 +98,34 @@ inline void sum(std::unordered_map &r) const; + template + inline + void product(std::vector &r) const; + + template + inline + void logical_and(std::vector &r) const; + + template + inline + void bitwise_and(std::vector &r) const; + + template + inline + void logical_or(std::vector &r) const; + + template + inline + void bitwise_or(std::vector &r) const; + + template + inline + void logical_xor(std::vector &r) const; + + template + inline + void bitwise_xor(std::vector &r) const; + template inline void set_union(std::set &data, diff --git a/src/parallel/include/timpi/parallel_implementation.h b/src/parallel/include/timpi/parallel_implementation.h index 26bdc53..1e28a09 100644 --- a/src/parallel/include/timpi/parallel_implementation.h +++ b/src/parallel/include/timpi/parallel_implementation.h @@ -2174,64 +2174,6 @@ inline bool Communicator::semiverify(const std::vector * r) const - -template -inline void Communicator::min(const T & r, - T & o, - Request & req) const -{ - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("min()", "Parallel"); - - timpi_call_mpi - (TIMPI_IALLREDUCE(&r, &o, 1, StandardType(&r), - OpFunction::min(), this->get(), - req.get())); - } - else - { - o = r; - req = Request::null_request; - } -} - - - -template -inline void Communicator::min(T & timpi_mpi_var(r)) const -{ - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("min(scalar)", "Parallel"); - - timpi_call_mpi - (TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1, - StandardType(&r), OpFunction::min(), - this->get())); - } -} - - - -template -inline void Communicator::min(std::vector & r) const -{ - if (this->size() > 1 && !r.empty()) - { - TIMPI_LOG_SCOPE("min(vector)", 
"Parallel"); - - timpi_assert(this->verify(r.size())); - - timpi_call_mpi - (TIMPI_ALLREDUCE - (MPI_IN_PLACE, r.data(), cast_int(r.size()), - StandardType(r.data()), OpFunction::min(), - this->get())); - } -} - - template inline void Communicator::min(std::vector & r) const { @@ -2352,59 +2294,6 @@ inline void Communicator::minloc(std::vector & r, } -template -inline void Communicator::max(const T & r, - T & o, - Request & req) const -{ - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("max()", "Parallel"); - - timpi_call_mpi - (TIMPI_IALLREDUCE(&r, &o, 1, StandardType(&r), - OpFunction::max(), this->get(), - req.get())); - } - else - { - o = r; - req = Request::null_request; - } -} - - -template -inline void Communicator::max(T & timpi_mpi_var(r)) const -{ - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("max(scalar)", "Parallel"); - - timpi_call_mpi - (TIMPI_ALLREDUCE (MPI_IN_PLACE, &r, 1, StandardType(&r), - OpFunction::max(), this->get())); - } -} - - -template -inline void Communicator::max(std::vector & r) const -{ - if (this->size() > 1 && !r.empty()) - { - TIMPI_LOG_SCOPE("max(vector)", "Parallel"); - - timpi_assert(this->verify(r.size())); - - timpi_call_mpi - (TIMPI_ALLREDUCE (MPI_IN_PLACE, r.data(), - cast_int(r.size()), - StandardType(r.data()), - OpFunction::max(), this->get())); - } -} - template inline void Communicator::max(std::vector & r) const @@ -2643,64 +2532,54 @@ inline void Communicator::maxloc(std::vector & r, } } +#define TIMPI_DEFINE_COMMUNICATOR_OPS(OPNAME) /* Defines three Communicator::OPNAME overloads that all reduce with OpFunction::OPNAME(): in-place scalar, in-place std::vector, and non-blocking scalar with a Request. NOTE(review): no matching undef of this macro is visible in this hunk -- confirm that is intended. */ \ + template \ + inline void Communicator::OPNAME(T &timpi_mpi_var(r)) const { \ + if (this->size() > 1) { \ + TIMPI_LOG_SCOPE(#OPNAME "(scalar, blocking)", "Parallel"); \ + /* MPI_IN_PLACE is passed with &r, so the reduced value overwrites r. */ \ + timpi_call_mpi(TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1, StandardType(&r), \ + OpFunction::OPNAME(), this->get())); \ + } \ + } \ + /* Vector form: does nothing when size() <= 1 or r is empty; asserts verify(r.size()) and then reduces r.size() elements in place. */ \ + template \ + inline void Communicator::OPNAME(std::vector &r) const { \ + if (this->size() > 1 && !r.empty()) { \ + TIMPI_LOG_SCOPE(#OPNAME "(vector, blocking)", "Parallel"); \ + \ + 
timpi_assert(this->verify(r.size())); \ + \ + timpi_call_mpi(TIMPI_ALLREDUCE( \ + MPI_IN_PLACE, r.data(), cast_int(r.size()), \ + StandardType(r.data()), OpFunction::OPNAME(), this->get())); \ + } \ + } \ + template /* Non-blocking scalar form: reads r, writes o, and hands the pending operation to req. NOTE(review): the previous hand-written sum() wrapped the size() > 1 branch in an ifdef on TIMPI_HAVE_MPI; this macro has no such guard -- confirm non-MPI builds are unaffected. */ \ + inline void Communicator::OPNAME(const T &r, T &o, Request &req) const { \ + if (this->size() > 1) { \ + TIMPI_LOG_SCOPE(#OPNAME "(scalar, nonblocking)", "Parallel"); \ + \ + timpi_call_mpi(TIMPI_IALLREDUCE(&r, &o, 1, StandardType(&r), \ + OpFunction::OPNAME(), this->get(), \ + req.get())); \ + } else { /* nothing to reduce: copy the input through and return the null request */ \ + o = r; \ + req = Request::null_request; \ + } \ + } -template -inline void Communicator::sum(const T & r, - T & o, - Request & req) const -{ -#ifdef TIMPI_HAVE_MPI - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("sum()", "Parallel"); - - timpi_call_mpi - (TIMPI_IALLREDUCE(&r, &o, 1, StandardType(&r), - OpFunction::sum(), this->get(), - req.get())); - } - else -#endif - { - o = r; - req = Request::null_request; - } -} - - -template -inline void Communicator::sum(T & timpi_mpi_var(r)) const -{ - if (this->size() > 1) - { - TIMPI_LOG_SCOPE("sum()", "Parallel"); - - timpi_call_mpi - (TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1, - StandardType(&r), - OpFunction::sum(), - this->get())); - } -} - - -template -inline void Communicator::sum(std::vector & r) const -{ - if (this->size() > 1 && !r.empty()) - { - TIMPI_LOG_SCOPE("sum()", "Parallel"); - - timpi_assert(this->verify(r.size())); +TIMPI_DEFINE_COMMUNICATOR_OPS(sum) +TIMPI_DEFINE_COMMUNICATOR_OPS(max) +TIMPI_DEFINE_COMMUNICATOR_OPS(min) +TIMPI_DEFINE_COMMUNICATOR_OPS(product) +TIMPI_DEFINE_COMMUNICATOR_OPS(logical_and) +TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_and) +TIMPI_DEFINE_COMMUNICATOR_OPS(logical_or) +TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_or) +TIMPI_DEFINE_COMMUNICATOR_OPS(logical_xor) +TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_xor) - timpi_call_mpi - (TIMPI_ALLREDUCE(MPI_IN_PLACE, r.data(), - cast_int(r.size()), - StandardType(r.data()), - OpFunction::sum(), - this->get())); - } -} // We still do function 
overloading for complex sums - in a perfect