Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 112 additions & 0 deletions src/parallel/include/timpi/communicator.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,118 @@ class Communicator
void maxloc(std::vector<T,A1> & r,
std::vector<unsigned int,A2> & max_id) const;

/**
* Take a local variable and replace it with the product of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the product
* of the values of \p r from all processors.
*/
template <typename T>
inline
void product(T & r) const;

/**
* Non-blocking product of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the product.
* \param o Receives the product of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void product(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the logical_and of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the logical_and
* of the values of \p r from all processors.
*/
template <typename T>
inline
void logical_and(T & r) const;

/**
* Non-blocking logical_and of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the logical_and.
* \param o Receives the logical_and of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void logical_and(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the bitwise_and of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the bitwise_and
* of the values of \p r from all processors.
*/
template <typename T>
inline
void bitwise_and(T & r) const;

/**
* Non-blocking bitwise_and of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the bitwise_and.
* \param o Receives the bitwise_and of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void bitwise_and(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the logical_or of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the logical_or
* of the values of \p r from all processors.
*/
template <typename T>
inline
void logical_or(T & r) const;

/**
* Non-blocking logical_or of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the logical_or.
* \param o Receives the logical_or of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void logical_or(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the bitwise_or of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the bitwise_or
* of the values of \p r from all processors.
*/
template <typename T>
inline
void bitwise_or(T & r) const;

/**
* Non-blocking bitwise_or of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the bitwise_or.
* \param o Receives the bitwise_or of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void bitwise_or(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the logical_xor of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the logical_xor
* of the values of \p r from all processors.
*/
template <typename T>
inline
void logical_xor(T & r) const;

/**
* Non-blocking logical_xor of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the logical_xor.
* \param o Receives the logical_xor of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void logical_xor(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the bitwise_xor of its values
* on all processors. Containers are replaced element-wise.
*
* \param r On entry, this processor's value; on return, the bitwise_xor
* of the values of \p r from all processors.
*/
template <typename T>
inline
void bitwise_xor(T & r) const;

/**
* Non-blocking bitwise_xor of the local value \p r into \p o
* with the request \p req.
*
* \param r This processor's contribution to the bitwise_xor.
* \param o Receives the bitwise_xor of \p r over all processors.
* \param req Request associated with the non-blocking operation.
* NOTE(review): \p r and \p o presumably must stay alive and untouched
* until \p req has completed - confirm against the Request API.
*/
template <typename T>
inline
void bitwise_xor(const T & r, T & o, Request & req) const;

/**
* Take a local variable and replace it with the sum of its values
* on all processors. Containers are replaced element-wise.
Expand Down
28 changes: 28 additions & 0 deletions src/parallel/include/timpi/parallel_communicator_specializations
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,34 @@
inline
void sum(std::unordered_map<K,V,H,E,A> &r) const;

/**
* Vector overload of product(): replaces each entry of \p r with the
* product of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void product(std::vector<T,A> &r) const;

/**
* Vector overload of logical_and(): replaces each entry of \p r with the
* logical_and of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void logical_and(std::vector<T,A> &r) const;

/**
* Vector overload of bitwise_and(): replaces each entry of \p r with the
* bitwise_and of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void bitwise_and(std::vector<T,A> &r) const;

/**
* Vector overload of logical_or(): replaces each entry of \p r with the
* logical_or of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void logical_or(std::vector<T,A> &r) const;

/**
* Vector overload of bitwise_or(): replaces each entry of \p r with the
* bitwise_or of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void bitwise_or(std::vector<T,A> &r) const;

/**
* Vector overload of logical_xor(): replaces each entry of \p r with the
* logical_xor of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void logical_xor(std::vector<T,A> &r) const;

/**
* Vector overload of bitwise_xor(): replaces each entry of \p r with the
* bitwise_xor of the corresponding entries on all processors.
*/
template <typename T, typename A>
inline
void bitwise_xor(std::vector<T,A> &r) const;

template <typename T, typename C, typename A>
inline
void set_union(std::set<T,C,A> &data,
Expand Down
213 changes: 46 additions & 167 deletions src/parallel/include/timpi/parallel_implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -2174,64 +2174,6 @@ inline bool Communicator::semiverify(const std::vector<T,A> * r) const




/**
 * Non-blocking minimum of the local value \p r into \p o, tracked by
 * the request \p req.  When size() is at most 1 no reduction is
 * issued: \p r is copied into \p o and \p req becomes the null request.
 */
template <typename T>
inline void Communicator::min(const T & r,
                              T & o,
                              Request & req) const
{
  // Nothing to reduce: the local value is already the answer.
  if (this->size() <= 1)
    {
      o = r;
      req = Request::null_request;
      return;
    }

  TIMPI_LOG_SCOPE("min()", "Parallel");

  timpi_call_mpi
    (TIMPI_IALLREDUCE(&r, &o, 1, StandardType<T>(&r),
                      OpFunction<T>::min(), this->get(),
                      req.get()));
}



/**
 * Blocking minimum: replaces \p r, in place, with the reduction of its
 * values using OpFunction<T>::min().  Does nothing when size() is at
 * most 1.
 */
template <typename T>
inline void Communicator::min(T & timpi_mpi_var(r)) const
{
  if (this->size() <= 1)
    return;

  TIMPI_LOG_SCOPE("min(scalar)", "Parallel");

  timpi_call_mpi
    (TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1,
                     StandardType<T>(&r), OpFunction<T>::min(),
                     this->get()));
}



/**
 * Blocking element-wise minimum of the vector \p r, reduced in place
 * with OpFunction<T>::min().  Does nothing when size() is at most 1 or
 * \p r is empty; otherwise asserts this->verify(r.size()) first.
 */
template <typename T, typename A>
inline void Communicator::min(std::vector<T,A> & r) const
{
  if (this->size() <= 1 || r.empty())
    return;

  TIMPI_LOG_SCOPE("min(vector)", "Parallel");

  timpi_assert(this->verify(r.size()));

  timpi_call_mpi
    (TIMPI_ALLREDUCE(MPI_IN_PLACE, r.data(),
                     cast_int<CountType>(r.size()),
                     StandardType<T>(r.data()),
                     OpFunction<T>::min(),
                     this->get()));
}


template <typename A>
inline void Communicator::min(std::vector<bool,A> & r) const
{
Expand Down Expand Up @@ -2352,59 +2294,6 @@ inline void Communicator::minloc(std::vector<bool,A1> & r,
}


/**
 * Non-blocking maximum of the local value \p r into \p o, tracked by
 * the request \p req.  When size() is at most 1 no reduction is
 * issued: \p r is copied into \p o and \p req becomes the null request.
 */
template <typename T>
inline void Communicator::max(const T & r,
                              T & o,
                              Request & req) const
{
  // Nothing to reduce: the local value is already the answer.
  if (this->size() <= 1)
    {
      o = r;
      req = Request::null_request;
      return;
    }

  TIMPI_LOG_SCOPE("max()", "Parallel");

  timpi_call_mpi
    (TIMPI_IALLREDUCE(&r, &o, 1, StandardType<T>(&r),
                      OpFunction<T>::max(), this->get(),
                      req.get()));
}


/**
 * Blocking maximum: replaces \p r, in place, with the reduction of its
 * values using OpFunction<T>::max().  Does nothing when size() is at
 * most 1.
 */
template <typename T>
inline void Communicator::max(T & timpi_mpi_var(r)) const
{
  if (this->size() <= 1)
    return;

  TIMPI_LOG_SCOPE("max(scalar)", "Parallel");

  timpi_call_mpi
    (TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1,
                     StandardType<T>(&r), OpFunction<T>::max(),
                     this->get()));
}


/**
 * Blocking element-wise maximum of the vector \p r, reduced in place
 * with OpFunction<T>::max().  Does nothing when size() is at most 1 or
 * \p r is empty; otherwise asserts this->verify(r.size()) first.
 */
template <typename T, typename A>
inline void Communicator::max(std::vector<T,A> & r) const
{
  if (this->size() <= 1 || r.empty())
    return;

  TIMPI_LOG_SCOPE("max(vector)", "Parallel");

  timpi_assert(this->verify(r.size()));

  timpi_call_mpi
    (TIMPI_ALLREDUCE(MPI_IN_PLACE, r.data(),
                     cast_int<CountType>(r.size()),
                     StandardType<T>(r.data()),
                     OpFunction<T>::max(),
                     this->get()));
}


template <typename A>
inline void Communicator::max(std::vector<bool,A> & r) const
Expand Down Expand Up @@ -2643,64 +2532,54 @@ inline void Communicator::maxloc(std::vector<bool,A1> & r,
}
}

/**
 * Defines three Communicator member templates for the reduction named
 * OPNAME, each using OpFunction<T>::OPNAME() as the reduction operation:
 *
 *  - OPNAME(T &): blocking, in-place reduction of a single value;
 *  - OPNAME(std::vector<T,A> &): blocking, in-place, element-wise
 *    reduction, guarded by an assertion on this->verify(r.size());
 *  - OPNAME(const T &, T &, Request &): non-blocking reduction of a
 *    single value into a separate output.
 *
 * No reduction is issued when size() is at most 1 (or, for the vector
 * overload, when the vector is empty); in that case the non-blocking
 * overload copies its input to its output and sets the request to
 * Request::null_request.
 */
#define TIMPI_DEFINE_COMMUNICATOR_OPS(OPNAME) \
  template <typename T> \
  inline void Communicator::OPNAME(T & timpi_mpi_var(r)) const \
  { \
    if (this->size() <= 1) \
      return; \
    \
    TIMPI_LOG_SCOPE(#OPNAME "(scalar, blocking)", "Parallel"); \
    \
    timpi_call_mpi(TIMPI_ALLREDUCE( \
      MPI_IN_PLACE, &r, 1, StandardType<T>(&r), \
      OpFunction<T>::OPNAME(), this->get())); \
  } \
  \
  template <typename T, typename A> \
  inline void Communicator::OPNAME(std::vector<T, A> & r) const \
  { \
    if (this->size() <= 1 || r.empty()) \
      return; \
    \
    TIMPI_LOG_SCOPE(#OPNAME "(vector, blocking)", "Parallel"); \
    \
    timpi_assert(this->verify(r.size())); \
    \
    timpi_call_mpi(TIMPI_ALLREDUCE( \
      MPI_IN_PLACE, r.data(), cast_int<CountType>(r.size()), \
      StandardType<T>(r.data()), OpFunction<T>::OPNAME(), this->get())); \
  } \
  \
  template <typename T> \
  inline void Communicator::OPNAME(const T & r, T & o, Request & req) const \
  { \
    if (this->size() <= 1) \
      { \
        o = r; \
        req = Request::null_request; \
        return; \
      } \
    \
    TIMPI_LOG_SCOPE(#OPNAME "(scalar, nonblocking)", "Parallel"); \
    \
    timpi_call_mpi(TIMPI_IALLREDUCE( \
      &r, &o, 1, StandardType<T>(&r), \
      OpFunction<T>::OPNAME(), this->get(), req.get())); \
  }

/**
 * Non-blocking sum of the local value \p r into \p o, tracked by the
 * request \p req.  When TIMPI_HAVE_MPI is not defined, or size() is at
 * most 1, no reduction is issued: \p r is copied into \p o and \p req
 * becomes the null request.
 */
template <typename T>
inline void Communicator::sum(const T & r,
                              T & o,
                              Request & req) const
{
#ifdef TIMPI_HAVE_MPI
  if (this->size() > 1)
    {
      TIMPI_LOG_SCOPE("sum()", "Parallel");

      timpi_call_mpi
        (TIMPI_IALLREDUCE(&r, &o, 1, StandardType<T>(&r),
                          OpFunction<T>::sum(), this->get(),
                          req.get()));
      return;
    }
#endif

  // Nothing to reduce: the local value is already the answer.
  o = r;
  req = Request::null_request;
}


/**
 * Blocking sum: replaces \p r, in place, with the reduction of its
 * values using OpFunction<T>::sum().  Does nothing when size() is at
 * most 1.
 */
template <typename T>
inline void Communicator::sum(T & timpi_mpi_var(r)) const
{
  if (this->size() <= 1)
    return;

  TIMPI_LOG_SCOPE("sum()", "Parallel");

  timpi_call_mpi
    (TIMPI_ALLREDUCE(MPI_IN_PLACE, &r, 1,
                     StandardType<T>(&r),
                     OpFunction<T>::sum(),
                     this->get()));
}


// Blocking element-wise sum of the vector r, reduced in place with
// OpFunction<T>::sum().  Skipped when size() is at most 1 or r is empty;
// otherwise this->verify(r.size()) is asserted before the reduction.
template <typename T, typename A>
inline void Communicator::sum(std::vector<T,A> & r) const
{
if (this->size() > 1 && !r.empty())
{
TIMPI_LOG_SCOPE("sum()", "Parallel");

timpi_assert(this->verify(r.size()));
// NOTE(review): the TIMPI_DEFINE_COMMUNICATOR_OPS(...) lines below show up
// in the middle of this body in the current view.  Each one expands to the
// scalar, vector and non-blocking overloads for the named reduction, so
// they look like namespace-scope instantiations interleaved by a diff
// rendering - confirm their real position before relying on this block.
TIMPI_DEFINE_COMMUNICATOR_OPS(sum)
TIMPI_DEFINE_COMMUNICATOR_OPS(max)
TIMPI_DEFINE_COMMUNICATOR_OPS(min)
TIMPI_DEFINE_COMMUNICATOR_OPS(product)
TIMPI_DEFINE_COMMUNICATOR_OPS(logical_and)
TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_and)
TIMPI_DEFINE_COMMUNICATOR_OPS(logical_or)
TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_or)
TIMPI_DEFINE_COMMUNICATOR_OPS(logical_xor)
TIMPI_DEFINE_COMMUNICATOR_OPS(bitwise_xor)

timpi_call_mpi
(TIMPI_ALLREDUCE(MPI_IN_PLACE, r.data(),
cast_int<CountType>(r.size()),
StandardType<T>(r.data()),
OpFunction<T>::sum(),
this->get()));
}
}


// We still do function overloading for complex sums - in a perfect
Expand Down