From f4278d70b54c412748a8010fac54c3a5e159f612 Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 13:59:20 -0400 Subject: [PATCH 01/12] tensor.h --- src/TiledArray/tensor/tensor.h | 1028 ++++++++++++++------------------ 1 file changed, 460 insertions(+), 568 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index fe76b07bd0..892e6bb82e 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -34,6 +34,11 @@ namespace TiledArray { template class Tensor; +template +void gemm(Alpha alpha, const Tensor& A, const Tensor& B, + Beta beta, Tensor &C, const math::GemmHelper& gemm_helper); + namespace detail { /// Signals that we can take the trace of a Tensor (for numeric \c T) @@ -46,7 +51,7 @@ struct TraceIsDefined, enable_if_numeric_t> : std::true_type {}; /// \tparam T the value type of this tensor /// \tparam A The allocator type for the data -template +template class Tensor { // meaningful error if T& is not assignable, see // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48101 @@ -55,13 +60,12 @@ class Tensor { "Tensor: T must be an assignable type (e.g. cannot be const)"); public: - typedef Tensor Tensor_; ///< This class type typedef Range range_type; ///< Tensor range type typedef typename range_type::index1_type index1_type; ///< 1-index type typedef typename range_type::ordinal_type ordinal_type; ///< Ordinal type typedef typename range_type::ordinal_type size_type; ///< Size type (to meet the container concept) - typedef A allocator_type; ///< Allocator type + typedef Allocator allocator_type; ///< Allocator type typedef typename allocator_type::value_type value_type; ///< Array element type typedef @@ -84,84 +88,54 @@ class Tensor { template using numeric_t = typename TiledArray::detail::numeric_type::type; - /// Evaluation tensor - - /// This tensor is used as an evaluated intermediate for other tensors. 
- class Impl : public allocator_type { - public: - /// Default constructor - - /// Construct an empty tensor that has no data or dimensions - Impl() : allocator_type(), range_(), data_(NULL) {} - - /// Construct with range - - /// \param range The N-dimensional range for this tensor - explicit Impl(const range_type& range) - : allocator_type(), range_(range), data_(NULL) { - data_ = allocator_type::allocate(range.volume()); - } - - /// Construct with rvalue range - - /// \param range The N-dimensional range for this tensor - explicit Impl(range_type&& range) - : allocator_type(), range_(range), data_(NULL) { - data_ = allocator_type::allocate(range.volume()); - } - - ~Impl() { - math::destroy_vector(range_.volume(), data_); - allocator_type::deallocate(data_, range_.volume()); - data_ = NULL; - } - - range_type range_; ///< Tensor size info - pointer data_; ///< Tensor data - }; // class Impl - template struct is_tensor { static constexpr bool value = detail::is_tensor::value || detail::is_tensor_of_tensor::value; }; - template >::type* = nullptr> - static void default_init(index1_type, U*) {} + using default_construct = bool; + + Tensor(const range_type& range, size_t batch_size, bool default_construct) + : range_(range), batch_size_(batch_size) + { + size_t size = range.volume()*batch_size; + allocator_type allocator; + auto *ptr = allocator.allocate(size); + if (default_construct) { + std::uninitialized_default_construct_n(ptr, size); + //std::uninitialized_value_construct_n(ptr, size); + } + auto deleter = [allocator=std::move(allocator),size](auto &&ptr) mutable { + std::destroy_n(ptr, size); + allocator.deallocate(ptr,size); + }; + this->data_ = std::shared_ptr(ptr, std::move(deleter)); + } - template >::type* = nullptr> - static void default_init(index1_type n, U* u) { - math::uninitialized_fill_vector(n, U(), u); + /// Construct a tensor with a range equal to \c range. The data is + /// uninitialized. 
+ /// \param range The range of the tensor + Tensor(const range_type& range, size_t batch_size, std::shared_ptr data) + : range_(range), batch_size_(batch_size), data_(data) + { } - std::shared_ptr pimpl_; ///< Shared pointer to implementation object - static const range_type empty_range_; ///< Empty range + range_type range_; ///< range + size_t batch_size_ = 1; + std::shared_ptr data_; ///< Shared pointer to implementation object public: - // Compiler generated functions - Tensor() : pimpl_() {} - Tensor(const Tensor_& other) : pimpl_(other.pimpl_) {} - Tensor(Tensor_&& other) : pimpl_(std::move(other.pimpl_)) {} - ~Tensor() {} - Tensor_& operator=(const Tensor_& other) { - pimpl_ = other.pimpl_; - return *this; - } - Tensor_& operator=(Tensor_&& other) { - pimpl_ = std::move(other.pimpl_); - return *this; - } - /// Construct tensor + // Compiler generated functions + Tensor() = default; /// Construct a tensor with a range equal to \c range. The data is /// uninitialized. /// \param range The range of the tensor explicit Tensor(const range_type& range) - : pimpl_(std::make_shared(range)) { - default_init(range.volume(), pimpl_->data_); + : Tensor(range, 1, default_construct{true}) + { } /// Construct a tensor with a fill value @@ -173,9 +147,10 @@ class Tensor { typename std::enable_if::value && detail::is_tensor::value>::type* = nullptr> Tensor(const range_type& range, const Value& value) - : pimpl_(std::make_shared(range)) { - const auto n = pimpl_->range_.volume(); - pointer MADNESS_RESTRICT const data = pimpl_->data_; + : Tensor(range, 1, default_construct{false}) + { + const auto n = this->size(); + pointer MADNESS_RESTRICT const data = this->data(); Clone cloner; for (size_type i = 0ul; i < n; ++i) new (data + i) value_type(cloner(value)); @@ -188,7 +163,8 @@ class Tensor { template >::type* = nullptr> Tensor(const range_type& range, const Value& value) - : pimpl_(std::make_shared(range)) { + : Tensor(range, 1, default_construct{false}) + { 
detail::tensor_init([value]() -> Value { return value; }, *this); } @@ -198,16 +174,18 @@ class Tensor { TiledArray::detail::is_input_iterator::value && !std::is_pointer::value>::type* = nullptr> Tensor(const range_type& range, InIter it) - : pimpl_(std::make_shared(range)) { + : Tensor(range, 1, default_construct{false}) + { auto n = range.volume(); - pointer MADNESS_RESTRICT const data = pimpl_->data_; + pointer MADNESS_RESTRICT const data = this->data(); for (size_type i = 0ul; i < n; ++i, ++it) data[i] = *it; } template Tensor(const Range& range, const U* u) - : pimpl_(std::make_shared(range)) { - math::uninitialized_copy_vector(range.volume(), u, pimpl_->data_); + : Tensor(range, 1, default_construct{false}) + { + math::uninitialized_copy_vector(range.volume(), u, this->data()); } Tensor(const Range& range, std::initializer_list il) @@ -222,10 +200,11 @@ class Tensor { template < typename T1, typename std::enable_if< - is_tensor::value && !std::is_same::value && - !detail::has_conversion_operator_v>::type* = nullptr> + is_tensor::value && !std::is_same::value && + !detail::has_conversion_operator_v>::type* = nullptr> explicit Tensor(const T1& other) - : pimpl_(std::make_shared(detail::clone_range(other))) { + : Tensor(detail::clone_range(other), 1, default_construct{false}) + { auto op = [](const numeric_t arg) -> numeric_t { return arg; }; detail::tensor_init(op, *this, other); @@ -242,18 +221,19 @@ class Tensor { typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> Tensor(const T1& other, const Perm& perm) - : pimpl_(std::make_shared(outer(perm) * other.range())) { + : Tensor(outer(perm) * other.range(), 1, default_construct{false}) + { auto op = [](const numeric_t arg) -> numeric_t { return arg; }; detail::tensor_init(op, outer(perm), *this, other); // If we actually have a ToT the inner permutation was not applied above so // we do that now - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + constexpr bool is_tot = 
detail::is_tensor_of_tensor_v; constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -274,7 +254,8 @@ class Tensor { is_tensor::value && !detail::is_permutation_v>>* = nullptr> Tensor(const T1& other, Op&& op) - : pimpl_(std::make_shared(detail::clone_range(other))) { + : Tensor(detail::clone_range(other), 1, default_construct{false}) + { detail::tensor_init(op, *this, other); } @@ -290,15 +271,16 @@ class Tensor { typename std::enable_if_t::value && detail::is_permutation_v>* = nullptr> Tensor(const T1& other, Op&& op, const Perm& perm) - : pimpl_(std::make_shared(outer(perm) * other.range())) { + : Tensor(outer(perm) * other.range(), 1, default_construct{false}) + { detail::tensor_init(op, outer(perm), *this, other); // If we actually have a ToT the inner permutation was not applied above so // we do that now - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + constexpr bool is_tot = detail::is_tensor_of_tensor_v; constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -319,7 +301,8 @@ class Tensor { template ::value>::type* = nullptr> Tensor(const T1& left, const T2& right, Op&& op) - : pimpl_(std::make_shared(detail::clone_range(left))) { + : Tensor(detail::clone_range(left), 1, default_construct{false}) + { detail::tensor_init(op, *this, left, right); } @@ -338,15 +321,16 @@ class Tensor { typename std::enable_if::value && 
detail::is_permutation_v>::type* = nullptr> Tensor(const T1& left, const T2& right, Op&& op, const Perm& perm) - : pimpl_(std::make_shared(outer(perm) * left.range())) { + : Tensor(outer(perm) * left.range(), 1, default_construct{false}) + { detail::tensor_init(op, outer(perm), *this, left, right); // If we actually have a ToT the inner permutation was not applied above so // we do that now - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + constexpr bool is_tot = detail::is_tensor_of_tensor_v; constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (is_tot && is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -356,10 +340,23 @@ class Tensor { } } - Tensor_ clone() const { - Tensor_ result; - if (pimpl_) { - result = detail::tensor_op( + size_t batch_size() const { return this->batch_size_; } + + Tensor batch(size_t idx) const { + TA_ASSERT(idx < this->batch_size()); + std::shared_ptr data(this->data_, this->data_.get() + idx*this->size()); + return Tensor(this->range(), 1, data); + } + + auto reshape(const range_type& range, size_t batch_size = 1) const { + TA_ASSERT(this->range().volume()*this->batch_size() == range.volume()*batch_size); + return Tensor(range, batch_size, this->data_); + } + + Tensor clone() const { + Tensor result; + if (data_) { + result = detail::tensor_op( [](const numeric_type value) -> numeric_type { return value; }, *this); } @@ -368,8 +365,8 @@ class Tensor { template ::value>::type* = nullptr> - Tensor_& operator=(const T1& other) { - pimpl_ = std::make_shared(detail::clone_range(other)); + Tensor& operator=(const T1& other) { + *this = Tensor(detail::clone_range(other), 1, default_construct{false}); detail::inplace_tensor_op( [](reference MADNESS_RESTRICT tr, typename T1::const_reference 
MADNESS_RESTRICT t1) { tr = t1; }, @@ -382,22 +379,13 @@ class Tensor { /// \return The tensor range object const range_type& range() const { - return (pimpl_ ? pimpl_->range_ : empty_range_); - } - - /// Tensor range object mutable accessor - - /// \return The tensor range object - /// \note asserts that this object has been already initialized - range_type& range() { - TA_ASSERT(pimpl_); - return pimpl_->range_; + return range_; } /// Tensor dimension size accessor /// \return The number of elements in the tensor - ordinal_type size() const { return (pimpl_ ? pimpl_->range_.volume() : 0ul); } + ordinal_type size() const { return (this->range().volume()); } /// Const element accessor @@ -409,9 +397,9 @@ class Tensor { template ::value>* = nullptr> const_reference operator[](const Ordinal ord) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(ord)); - return pimpl_->data_[ord]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(ord)); + return this->data()[ord]; } /// Element accessor @@ -424,9 +412,9 @@ class Tensor { template ::value>* = nullptr> reference operator[](const Ordinal ord) { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(ord)); - return pimpl_->data_[ord]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(ord)); + return this->data()[ord]; } /// Const element accessor @@ -439,9 +427,9 @@ class Tensor { template >* = nullptr> const_reference operator[](const Index& i) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Element accessor @@ -454,9 +442,9 @@ class Tensor { template >* = nullptr> reference operator[](const Index& i) { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return 
this->data()[this->range_.ordinal(i)]; } /// Const element accessor @@ -469,9 +457,9 @@ class Tensor { template >* = nullptr> const_reference operator[](const std::initializer_list& i) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Element accessor @@ -484,9 +472,9 @@ class Tensor { template >* = nullptr> reference operator[](const std::initializer_list& i) { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Const element accessor @@ -499,9 +487,9 @@ class Tensor { template >* = nullptr> const_reference operator()(const Index& i) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Element accessor @@ -514,9 +502,9 @@ class Tensor { template >* = nullptr> reference operator()(const Index& i) { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Const element accessor @@ -529,9 +517,9 @@ class Tensor { template >* = nullptr> const_reference operator()(const std::initializer_list& i) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Element accessor @@ -544,9 +532,9 @@ class Tensor { template >* = nullptr> reference operator()(const std::initializer_list& i) { - 
TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i)); - return pimpl_->data_[pimpl_->range_.ordinal(i)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i)); + return this->data()[this->range_.ordinal(i)]; } /// Const element accessor @@ -559,9 +547,9 @@ class Tensor { typename... Index, std::enable_if_t::value>* = nullptr> const_reference operator()(const Index&... i) const { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i...)); - return pimpl_->data_[pimpl_->range_.ordinal(i...)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i...)); + return this->data()[this->range_.ordinal(i...)]; } /// Element accessor @@ -574,108 +562,83 @@ class Tensor { typename... Index, std::enable_if_t::value>* = nullptr> reference operator()(const Index&... i) { - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.includes(i...)); - return pimpl_->data_[pimpl_->range_.ordinal(i...)]; + //TA_ASSERT(pimpl_); + TA_ASSERT(this->range_.includes(i...)); + return this->data()[this->range_.ordinal(i...)]; } /// Iterator factory /// \return An iterator to the first data element - const_iterator begin() const { return (pimpl_ ? pimpl_->data_ : NULL); } + const_iterator begin() const { return (this->data() ? this->data() : NULL); } /// Iterator factory /// \return An iterator to the first data element - iterator begin() { return (pimpl_ ? pimpl_->data_ : NULL); } + iterator begin() { return (this->data() ? this->data() : NULL); } /// Iterator factory /// \return An iterator to the last data element const_iterator end() const { - return (pimpl_ ? pimpl_->data_ + pimpl_->range_.volume() : NULL); + return (this->data() ? this->data() + this->size() : NULL); } /// Iterator factory /// \return An iterator to the last data element iterator end() { - return (pimpl_ ? pimpl_->data_ + pimpl_->range_.volume() : NULL); + return (this->data() ? 
this->data() + this->size() : NULL); } /// Data direct access /// \return A const pointer to the tensor data - const_pointer data() const { return (pimpl_ ? pimpl_->data_ : NULL); } + const_pointer data() const { return this->data_.get(); } /// Data direct access /// \return A const pointer to the tensor data - pointer data() { return (pimpl_ ? pimpl_->data_ : NULL); } + pointer data() { return this->data_.get(); } /// Test if the tensor is empty /// \return \c true if this tensor was default constructed (contains no /// data), otherwise \c false. - bool empty() const { return !pimpl_; } + bool empty() const { return !this->data_; } /// Output serialization function /// This function enables serialization within MADNESS /// \tparam Archive The output archive type /// \param[out] ar The output archive - template ::value>::type* = nullptr> + template void serialize(Archive& ar) { - if (pimpl_) { - ar & pimpl_->range_.volume(); - ar& madness::archive::wrap(pimpl_->data_, pimpl_->range_.volume()); - ar & pimpl_->range_; - } else { - ar& ordinal_type(0ul); - } - } - - /// Input serialization function - - /// This function implements serialization to/from MADNESS archive objects - /// \tparam Archive The input archive type - /// \param[out] ar The input archive - template ::value>::type* = nullptr> - void serialize(Archive& ar) { - ordinal_type n = 0ul; - ar& n; - if (n) { - std::shared_ptr temp = std::make_shared(); - temp->data_ = temp->allocate(n); - try { - // need to construct elements of data_ using placement new in case its - // default ctor is not trivial N.B. for fundamental types and standard - // alloc this incurs no overhead (Eigen::aligned_alloc OK also) - auto* data_ptr = temp->data_; - for (ordinal_type i = 0; i != n; ++i, ++data_ptr) - new (static_cast(data_ptr)) value_type; - - ar& madness::archive::wrap(temp->data_, n); - ar & temp->range_; - } catch (...) 
{ - temp->deallocate(temp->data_, n); - throw; + bool empty = this->empty(); + auto range = this->range_; + auto batch_size = this->batch_size_; + ar & empty; + if (!empty) { + ar & range; + ar & batch_size; + if (madness::archive::is_input_archive::value) { + *this = Tensor(range, batch_size, default_construct{true}); } - - pimpl_ = temp; - } else { - pimpl_.reset(); + ar & madness::archive::wrap(this->data_.get(), range.volume()*batch_size); + } + else if (madness::archive::is_input_archive::value) { + *this = Tensor{}; } } /// Swap tensor data /// \param other The tensor to swap with this - void swap(Tensor_& other) { std::swap(pimpl_, other.pimpl_); } + void swap(Tensor& other) { + std::swap(data_, other.data_); + std::swap(range_, other.range_); + std::swap(batch_size_, other.batch_size_); + } // clang-format off /// Constructs a view of the block defined by \p lower_bound and \p upper_bound. @@ -703,9 +666,9 @@ class Tensor { detail::is_integral_range_v>> detail::TensorInterface block(const Index1& lower_bound, const Index2& upper_bound) { - TA_ASSERT(pimpl_); + //TA_ASSERT(pimpl_); return detail::TensorInterface( - BlockRange(pimpl_->range_, lower_bound, upper_bound), pimpl_->data_); + BlockRange(this->range_, lower_bound, upper_bound), this->data()); } template >> detail::TensorInterface block( const Index1& lower_bound, const Index2& upper_bound) const { - TA_ASSERT(pimpl_); + //TA_ASSERT(pimpl_); return detail::TensorInterface( - BlockRange(pimpl_->range_, lower_bound, upper_bound), pimpl_->data_); + BlockRange(this->range_, lower_bound, upper_bound), this->data()); } /// @} @@ -744,9 +707,9 @@ class Tensor { detail::TensorInterface block( const std::initializer_list& lower_bound, const std::initializer_list& upper_bound) { - TA_ASSERT(pimpl_); + //TA_ASSERT(pimpl_); return detail::TensorInterface( - BlockRange(pimpl_->range_, lower_bound, upper_bound), pimpl_->data_); + BlockRange(this->range_, lower_bound, upper_bound), this->data()); } template block( 
const std::initializer_list& lower_bound, const std::initializer_list& upper_bound) const { - TA_ASSERT(pimpl_); + //TA_ASSERT(pimpl_); return detail::TensorInterface( - BlockRange(pimpl_->range_, lower_bound, upper_bound), pimpl_->data_); + BlockRange(this->range_, lower_bound, upper_bound), this->data()); } /// @} @@ -800,14 +763,14 @@ class Tensor { detail::TensorInterface block( const PairRange& bounds) const { return detail::TensorInterface( - BlockRange(pimpl_->range_, bounds), pimpl_->data_); + BlockRange(this->range_, bounds), this->data()); } template >> detail::TensorInterface block(const PairRange& bounds) { return detail::TensorInterface( - BlockRange(pimpl_->range_, bounds), pimpl_->data_); + BlockRange(this->range_, bounds), this->data()); } /// @} @@ -831,7 +794,7 @@ class Tensor { detail::TensorInterface block( const std::initializer_list>& bounds) const { return detail::TensorInterface( - BlockRange(pimpl_->range_, bounds), pimpl_->data_); + BlockRange(this->range_, bounds), this->data()); } template block( const std::initializer_list>& bounds) { return detail::TensorInterface( - BlockRange(pimpl_->range_, bounds), pimpl_->data_); + BlockRange(this->range_, bounds), this->data()); } /// @} @@ -850,24 +813,24 @@ class Tensor { /// \return A permuted copy of this tensor template >> - Tensor_ permute(const Perm& perm) const { - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + Tensor permute(const Perm& perm) const { + constexpr bool is_tot = detail::is_tensor_of_tensor_v; [[maybe_unused]] constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (!is_tot) { if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return Tensor_(*this, outer(perm)); + return Tensor(*this, outer(perm)); } 
else - return Tensor_(*this, perm); + return Tensor(*this, perm); } else { // If we have a ToT we need to apply the permutation in two steps. The // first step is identical to the non-ToT case (permute the outer modes) // the second step does the inner modes - Tensor_ rv(*this, outer(perm)); + Tensor rv(*this, outer(perm)); if constexpr (is_bperm) { if (inner_size(perm) != 0) { auto inner_perm = inner(perm); @@ -887,9 +850,9 @@ class Tensor { /// \return A reference to this tensor template >* = nullptr> - Tensor_& shift_to(const Index& bound_shift) { - TA_ASSERT(pimpl_); - pimpl_->range_.inplace_shift(bound_shift); + Tensor& shift_to(const Index& bound_shift) { + //TA_ASSERT(pimpl_); + this->range_.inplace_shift(bound_shift); return *this; } @@ -900,9 +863,9 @@ class Tensor { /// \return A reference to this tensor template >* = nullptr> - Tensor_& shift_to(const std::initializer_list& bound_shift) { - TA_ASSERT(pimpl_); - pimpl_->range_.template inplace_shift>( + Tensor& shift_to(const std::initializer_list& bound_shift) { + //TA_ASSERT(pimpl_); + this->range_.template inplace_shift>( bound_shift); return *this; } @@ -914,9 +877,9 @@ class Tensor { /// \return A shifted copy of this tensor template >* = nullptr> - Tensor_ shift(const Index& bound_shift) const { - TA_ASSERT(pimpl_); - Tensor_ result = clone(); + Tensor shift(const Index& bound_shift) const { + //TA_ASSERT(pimpl_); + Tensor result = clone(); result.shift_to(bound_shift); return result; } @@ -928,9 +891,9 @@ class Tensor { /// \return A shifted copy of this tensor template >* = nullptr> - Tensor_ shift(const std::initializer_list& bound_shift) const { - TA_ASSERT(pimpl_); - Tensor_ result = clone(); + Tensor shift(const std::initializer_list& bound_shift) const { + //TA_ASSERT(pimpl_); + Tensor result = clone(); result.template shift_to>(bound_shift); return result; } @@ -947,8 +910,8 @@ class Tensor { /// \c op(*this[i],other[i]) template ::value>::type* = nullptr> - Tensor_ binary(const Right& 
right, Op&& op) const { - return Tensor_(*this, right, op); + Tensor binary(const Right& right, Op&& op) const { + return Tensor(*this, right, op); } /// Use a binary, element wise operation to construct a new, permuted tensor @@ -965,24 +928,24 @@ class Tensor { typename Right, typename Op, typename Perm, typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> - Tensor_ binary(const Right& right, Op&& op, const Perm& perm) const { - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + Tensor binary(const Right& right, Op&& op, const Perm& perm) const { + constexpr bool is_tot = detail::is_tensor_of_tensor_v; [[maybe_unused]] constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (!is_tot) { if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return Tensor_(*this, right, op, outer(perm)); + return Tensor(*this, right, op, outer(perm)); } else - return Tensor_(*this, right, op, perm); + return Tensor(*this, right, op, perm); } else { // AFAIK the other branch fundamentally relies on raw pointer arithmetic, // which won't work for ToTs. auto temp = binary(right, std::forward(op)); - Permute p; + Permute p; return p(temp, perm); } abort(); // unreachable @@ -1002,7 +965,7 @@ class Tensor { /// \throw TiledArray::Exception When this and \c other are the same. template ::value>::type* = nullptr> - Tensor_& inplace_binary(const Right& right, Op&& op) { + Tensor& inplace_binary(const Right& right, Op&& op) { detail::inplace_tensor_op(op, *this, right); return *this; } @@ -1015,8 +978,8 @@ class Tensor { /// \c op(*this[i]) /// \throw TiledArray::Exception When this tensor is empty. 
template - Tensor_ unary(Op&& op) const { - return Tensor_(*this, op); + Tensor unary(Op&& op) const { + return Tensor(*this, op); } /// Use a unary, element wise operation to construct a new, permuted tensor @@ -1031,22 +994,22 @@ class Tensor { /// that of this tensor. template >> - Tensor_ unary(Op&& op, const Perm& perm) const { - constexpr bool is_tot = detail::is_tensor_of_tensor_v; + Tensor unary(Op&& op, const Perm& perm) const { + constexpr bool is_tot = detail::is_tensor_of_tensor_v; [[maybe_unused]] constexpr bool is_bperm = detail::is_bipartite_permutation_v; // tile ops pass bipartite permutations here even if this is a plain tensor // static_assert(is_tot || (!is_tot && !is_bperm), "Permutation type does - // not match Tensor_"); + // not match Tensor"); if constexpr (!is_tot) { if constexpr (is_bperm) { TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation - return Tensor_(*this, op, outer(perm)); + return Tensor(*this, op, outer(perm)); } else - return Tensor_(*this, op, perm); + return Tensor(*this, op, perm); } else { auto temp = unary(std::forward(op)); - Permute p; + Permute p; return p(temp, perm); } abort(); // unreachable @@ -1059,7 +1022,7 @@ class Tensor { /// \return A reference to this object /// \throw TiledArray::Exception When this tensor is empty. 
template - Tensor_& inplace_unary(Op&& op) { + Tensor& inplace_unary(Op&& op) { detail::inplace_tensor_op(op, *this); return *this; } @@ -1074,7 +1037,7 @@ class Tensor { /// \c factor template >::type* = nullptr> - Tensor_ scale(const Scalar factor) const { + Tensor scale(const Scalar factor) const { return unary( [factor](const numeric_type a) -> numeric_type { return a * factor; }); } @@ -1090,7 +1053,7 @@ class Tensor { template && detail::is_permutation_v>> - Tensor_ scale(const Scalar factor, const Perm& perm) const { + Tensor scale(const Scalar factor, const Perm& perm) const { return unary( [factor](const numeric_type a) -> numeric_type { return a * factor; }, perm); @@ -1103,7 +1066,7 @@ class Tensor { /// \return A reference to this tensor template >::type* = nullptr> - Tensor_& scale_to(const Scalar factor) { + Tensor& scale_to(const Scalar factor) { return inplace_unary( [factor](numeric_type& MADNESS_RESTRICT res) { res *= factor; }); } @@ -1118,7 +1081,7 @@ class Tensor { /// \c this and \c other template ::value>::type* = nullptr> - Tensor_ add(const Right& right) const { + Tensor add(const Right& right) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1138,7 +1101,7 @@ class Tensor { typename Right, typename Perm, typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> - Tensor_ add(const Right& right, const Perm& perm) const { + Tensor add(const Right& right, const Perm& perm) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1159,7 +1122,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_ add(const Right& right, const Scalar factor) const { + Tensor add(const Right& right, const Scalar factor) const { return binary(right, [factor](const numeric_type l, const numeric_t r) -> numeric_type { return (l + r) * factor; }); @@ -1179,7 +1142,7 @@ class Tensor 
{ typename std::enable_if< is_tensor::value && detail::is_numeric_v && detail::is_permutation_v>::type* = nullptr> - Tensor_ add(const Right& right, const Scalar factor, const Perm& perm) const { + Tensor add(const Right& right, const Scalar factor, const Perm& perm) const { return binary( right, [factor](const numeric_type l, const numeric_t r) @@ -1192,7 +1155,7 @@ class Tensor { /// \param value The constant to be added to this tensor /// \return A new tensor where the elements are the sum of the elements of /// \c this and \c value - Tensor_ add(const numeric_type value) const { + Tensor add(const numeric_type value) const { return unary( [value](const numeric_type a) -> numeric_type { return a + value; }); } @@ -1206,7 +1169,7 @@ class Tensor { /// \c this and \c value template >> - Tensor_ add(const numeric_type value, const Perm& perm) const { + Tensor add(const numeric_type value, const Perm& perm) const { return unary( [value](const numeric_type a) -> numeric_type { return a + value; }, perm); @@ -1219,7 +1182,7 @@ class Tensor { /// \return A reference to this tensor template ::value>::type* = nullptr> - Tensor_& add_to(const Right& right) { + Tensor& add_to(const Right& right) { return inplace_binary(right, [](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { l += r; }); } @@ -1235,7 +1198,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_& add_to(const Right& right, const Scalar factor) { + Tensor& add_to(const Right& right, const Scalar factor) { return inplace_binary( right, [factor](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { (l += r) *= factor; }); @@ -1245,7 +1208,7 @@ class Tensor { /// \param value The constant to be added /// \return A reference to this tensor - Tensor_& add_to(const numeric_type value) { + Tensor& add_to(const numeric_type value) { return inplace_unary( [value](numeric_type& MADNESS_RESTRICT res) { res += value; }); } @@ 
-1260,7 +1223,7 @@ class Tensor { /// elements of \c this and \c right template ::value>::type* = nullptr> - Tensor_ subt(const Right& right) const { + Tensor subt(const Right& right) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1280,7 +1243,7 @@ class Tensor { typename Right, typename Perm, typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> - Tensor_ subt(const Right& right, const Perm& perm) const { + Tensor subt(const Right& right, const Perm& perm) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1302,7 +1265,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_ subt(const Right& right, const Scalar factor) const { + Tensor subt(const Right& right, const Scalar factor) const { return binary(right, [factor](const numeric_type l, const numeric_t r) -> numeric_type { return (l - r) * factor; }); @@ -1323,7 +1286,7 @@ class Tensor { typename std::enable_if< is_tensor::value && detail::is_numeric_v && detail::is_permutation_v>::type* = nullptr> - Tensor_ subt(const Right& right, const Scalar factor, + Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const { return binary( right, @@ -1336,7 +1299,7 @@ class Tensor { /// \return A new tensor where the elements are the different between the /// elements of \c this and \c value - Tensor_ subt(const numeric_type value) const { return add(-value); } + Tensor subt(const numeric_type value) const { return add(-value); } /// Subtract a constant from a permuted copy of this tensor @@ -1347,7 +1310,7 @@ class Tensor { /// elements of \c this and \c value template >> - Tensor_ subt(const numeric_type value, const Perm& perm) const { + Tensor subt(const numeric_type value, const Perm& perm) const { return add(-value, perm); } @@ -1358,7 +1321,7 @@ class Tensor { /// \return A reference to this tensor 
template ::value>::type* = nullptr> - Tensor_& subt_to(const Right& right) { + Tensor& subt_to(const Right& right) { return inplace_binary(right, [](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { l -= r; }); } @@ -1374,7 +1337,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_& subt_to(const Right& right, const Scalar factor) { + Tensor& subt_to(const Right& right, const Scalar factor) { return inplace_binary( right, [factor](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { (l -= r) *= factor; }); @@ -1383,7 +1346,7 @@ class Tensor { /// Subtract a constant from this tensor /// \return A reference to this tensor - Tensor_& subt_to(const numeric_type value) { return add_to(-value); } + Tensor& subt_to(const numeric_type value) { return add_to(-value); } // Multiplication operations @@ -1395,7 +1358,7 @@ class Tensor { /// of \c this and \c right template ::value>::type* = nullptr> - Tensor_ mult(const Right& right) const { + Tensor mult(const Right& right) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1415,7 +1378,7 @@ class Tensor { typename Right, typename Perm, typename std::enable_if::value && detail::is_permutation_v>::type* = nullptr> - Tensor_ mult(const Right& right, const Perm& perm) const { + Tensor mult(const Right& right, const Perm& perm) const { return binary( right, [](const numeric_type l, const numeric_t r) -> numeric_type { @@ -1436,7 +1399,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_ mult(const Right& right, const Scalar factor) const { + Tensor mult(const Right& right, const Scalar factor) const { return binary(right, [factor](const numeric_type l, const numeric_t r) -> numeric_type { return (l * r) * factor; }); @@ -1456,7 +1419,7 @@ class Tensor { typename std::enable_if< is_tensor::value && 
detail::is_numeric_v && detail::is_permutation_v>::type* = nullptr> - Tensor_ mult(const Right& right, const Scalar factor, + Tensor mult(const Right& right, const Scalar factor, const Perm& perm) const { return binary( right, @@ -1472,7 +1435,7 @@ class Tensor { /// \return A reference to this tensor template ::value>::type* = nullptr> - Tensor_& mult_to(const Right& right) { + Tensor& mult_to(const Right& right) { return inplace_binary(right, [](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { l *= r; }); } @@ -1488,7 +1451,7 @@ class Tensor { typename Right, typename Scalar, typename std::enable_if::value && detail::is_numeric_v>::type* = nullptr> - Tensor_& mult_to(const Right& right, const Scalar factor) { + Tensor& mult_to(const Right& right, const Scalar factor) { return inplace_binary( right, [factor](numeric_type& MADNESS_RESTRICT l, const numeric_t r) { (l *= r) *= factor; }); @@ -1499,7 +1462,7 @@ class Tensor { /// Create a negated copy of this tensor /// \return A new tensor that contains the negative values of this tensor - Tensor_ neg() const { + Tensor neg() const { return unary([](const numeric_type r) -> numeric_type { return -r; }); } @@ -1510,14 +1473,14 @@ class Tensor { /// \return A new tensor that contains the negative values of this tensor template >> - Tensor_ neg(const Perm& perm) const { + Tensor neg(const Perm& perm) const { return unary([](const numeric_type l) -> numeric_type { return -l; }, perm); } /// Negate elements of this tensor /// \return A reference to this tensor - Tensor_& neg_to() { + Tensor& neg_to() { return inplace_unary([](numeric_type& MADNESS_RESTRICT l) { l = -l; }); } @@ -1525,8 +1488,8 @@ class Tensor { /// \return A copy of this tensor that contains the complex conjugate the /// values - Tensor_ conj() const { - TA_ASSERT(pimpl_); + Tensor conj() const { + //TA_ASSERT(pimpl_); return scale(detail::conj_op()); } @@ -1538,8 +1501,8 @@ class Tensor { /// conjugate the values template >::type* = nullptr> - 
Tensor_ conj(const Scalar factor) const { - TA_ASSERT(pimpl_); + Tensor conj(const Scalar factor) const { + //TA_ASSERT(pimpl_); return scale(detail::conj_op(factor)); } @@ -1551,8 +1514,8 @@ class Tensor { /// conjugate values template >> - Tensor_ conj(const Perm& perm) const { - TA_ASSERT(pimpl_); + Tensor conj(const Perm& perm) const { + //TA_ASSERT(pimpl_); return scale(detail::conj_op(), perm); } @@ -1568,16 +1531,16 @@ class Tensor { typename Scalar, typename Perm, typename std::enable_if && detail::is_permutation_v>::type* = nullptr> - Tensor_ conj(const Scalar factor, const Perm& perm) const { - TA_ASSERT(pimpl_); + Tensor conj(const Scalar factor, const Perm& perm) const { + //TA_ASSERT(pimpl_); return scale(detail::conj_op(factor), perm); } /// Complex conjugate this tensor /// \return A reference to this tensor - Tensor_& conj_to() { - TA_ASSERT(pimpl_); + Tensor& conj_to() { + //TA_ASSERT(pimpl_); return scale_to(detail::conj_op()); } @@ -1588,303 +1551,47 @@ class Tensor { /// \return A reference to this tensor template >::type* = nullptr> - Tensor_& conj_to(const Scalar factor) { - TA_ASSERT(pimpl_); + Tensor& conj_to(const Scalar factor) { + //TA_ASSERT(pimpl_); return scale_to(detail::conj_op(factor)); } // GEMM operations - /// Contract this tensor with \c other - - /// \tparam U The other tensor element type - /// \tparam AU The other tensor allocator type - /// \tparam V The type of \c factor scalar - /// \param other The tensor that will be contracted with this tensor - /// \param factor Multiply the result by this constant - /// \param gemm_helper The *GEMM operation meta data - /// \return A new tensor which is the result of contracting this tensor with - /// \c other and scaled by \c factor - template - Tensor_ gemm(const Tensor& other, const V factor, - const math::GemmHelper& gemm_helper) const { - static_assert(!detail::is_tensor_of_tensor_v>, - "TA::Tensor::gemm without custom element op is only " - "applicable to plain tensors"); - // 
Check that this tensor is not empty and has the correct rank - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.rank() == gemm_helper.left_rank()); - - // Check that the arguments are not empty and have the correct ranks - TA_ASSERT(!other.empty()); - TA_ASSERT(other.range().rank() == gemm_helper.right_rank()); - - // Construct the result Tensor - Tensor_ result(gemm_helper.make_result_range(pimpl_->range_, - other.range())); - - // Check that the inner dimensions of left and right match - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(pimpl_->range_.lobound_data(), - other.range().lobound_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(pimpl_->range_.upbound_data(), - other.range().upbound_data())); - TA_ASSERT(gemm_helper.left_right_congruent(pimpl_->range_.extent_data(), - other.range().extent_data())); - - // Compute gemm dimensions - using integer = TiledArray::math::blas::integer; - integer m = 1, n = 1, k = 1; - gemm_helper.compute_matrix_sizes(m, n, k, pimpl_->range_, other.range()); - - // Get the leading dimension for left and right matrices. - const integer lda = - (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k : m); - const integer ldb = - (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? n : k); - - math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k, - factor, pimpl_->data_, lda, other.data(), ldb, - numeric_type(0), result.data(), n); - -#ifdef TA_ENABLE_TILE_OPS_LOGGING - if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && - TiledArray::TileOpsLogger::get_instance().gemm) { - auto& logger = TiledArray::TileOpsLogger::get_instance(); - auto apply = [](auto& fnptr, const Range& arg) { - return fnptr ? 
fnptr(arg) : arg; - }; - auto tformed_left_range = - apply(logger.gemm_left_range_transform, pimpl_->range_); - auto tformed_right_range = - apply(logger.gemm_right_range_transform, other.range()); - auto tformed_result_range = - apply(logger.gemm_result_range_transform, result.range()); - if ((!logger.gemm_result_range_filter || - logger.gemm_result_range_filter(tformed_result_range)) && - (!logger.gemm_left_range_filter || - logger.gemm_left_range_filter(tformed_left_range)) && - (!logger.gemm_right_range_filter || - logger.gemm_right_range_filter(tformed_right_range))) { - logger << "TA::Tensor::gemm=: left=" << tformed_left_range - << " right=" << tformed_right_range - << " result=" << tformed_result_range << std::endl; - if (TiledArray::TileOpsLogger::get_instance() - .gemm_print_contributions) { - if (!TiledArray::TileOpsLogger::get_instance().gemm_printer) { - // must use custom printer if result's range transformed - if (!logger.gemm_result_range_transform) - logger << result << std::endl; - else - logger << make_map(result.data(), tformed_result_range) - << std::endl; - } else { - TiledArray::TileOpsLogger::get_instance().gemm_printer( - *logger.log, tformed_left_range, this->data(), - tformed_right_range, other.data(), tformed_right_range, - result.data()); - } - } - } - } -#endif // TA_ENABLE_TILE_OPS_LOGGING - + template + Tensor gemm(const Tensor& A, const V alpha, + const math::GemmHelper& gemm_helper) const { + Tensor result; + result.gemm(*this, A, alpha, gemm_helper); return result; } - /// Contract two tensors and accumulate the scaled result to this tensor - - /// GEMM is limited to matrix like contractions. 
For example, the following - /// contractions are supported: - /// \code - /// C[a,b] = A[a,i,j] * B[i,j,b] - /// C[a,b] = A[a,i,j] * B[b,i,j] - /// C[a,b] = A[i,j,a] * B[i,j,b] - /// C[a,b] = A[i,j,a] * B[b,i,j] - /// - /// C[a,b,c,d] = A[a,b,i,j] * B[i,j,c,d] - /// C[a,b,c,d] = A[a,b,i,j] * B[c,d,i,j] - /// C[a,b,c,d] = A[i,j,a,b] * B[i,j,c,d] - /// C[a,b,c,d] = A[i,j,a,b] * B[c,d,i,j] - /// \endcode - /// Notice that in the above contractions, the inner and outer indices of - /// the arguments for exactly two contiguous groups in each tensor and that - /// each group is in the same order in all tensors. That is, the indices of - /// the tensors must fit the one of the following patterns: - /// \code - /// C[M...,N...] = A[M...,K...] * B[K...,N...] - /// C[M...,N...] = A[M...,K...] * B[N...,K...] - /// C[M...,N...] = A[K...,M...] * B[K...,N...] - /// C[M...,N...] = A[K...,M...] * B[N...,K...] - /// \endcode - /// This allows use of optimized BLAS functions to evaluate tensor - /// contractions. Tensor contractions that do not fit this pattern require - /// one or more tensor permutation so that the tensors fit the required - /// pattern. 
- /// \tparam U The left-hand tensor element type - /// \tparam AU The left-hand tensor allocator type - /// \tparam V The right-hand tensor element type - /// \tparam AV The right-hand tensor allocator type - /// \tparam W The type of the scaling factor - /// \param left The left-hand tensor that will be contracted - /// \param right The right-hand tensor that will be contracted - /// \param factor The contraction result will be scaling by this value, then - /// accumulated into \c this \param gemm_helper The *GEMM operation meta data - /// \return A reference to \c this - /// \note if this is uninitialized, i.e., if \c this->empty()==true will - /// this is equivalent to - /// \code - /// return (*this = left.gemm(right, factor, gemm_helper)); - /// \endcode - template - Tensor_& gemm(const Tensor& left, const Tensor& right, - const W factor, const math::GemmHelper& gemm_helper) { - static_assert( - !detail::is_tensor_of_tensor_v, Tensor>, - "TA::Tensor::gemm without custom element op is only applicable to " - "plain tensors"); + template + Tensor& gemm(const Tensor& A, const Tensor& B, const W alpha, + const math::GemmHelper& gemm_helper) + { + numeric_type beta = 1; if (this->empty()) { - *this = left.gemm(right, factor, gemm_helper); - } else { - // Check that this tensor is not empty and has the correct rank - TA_ASSERT(pimpl_); - TA_ASSERT(pimpl_->range_.rank() == gemm_helper.result_rank()); - - // Check that the arguments are not empty and have the correct ranks - TA_ASSERT(!left.empty()); - TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); - TA_ASSERT(!right.empty()); - TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); - - // Check that the outer dimensions of left match the corresponding - // dimensions in result - TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( - left.range().lobound_data(), - pimpl_->range_.lobound_data())); - TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( - 
left.range().upbound_data(), - pimpl_->range_.upbound_data())); - TA_ASSERT(gemm_helper.left_result_congruent( - left.range().extent_data(), pimpl_->range_.extent_data())); - - // Check that the outer dimensions of right match the corresponding - // dimensions in result - TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( - right.range().lobound_data(), - pimpl_->range_.lobound_data())); - TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( - right.range().upbound_data(), - pimpl_->range_.upbound_data())); - TA_ASSERT(gemm_helper.right_result_congruent( - right.range().extent_data(), pimpl_->range_.extent_data())); - - // Check that the inner dimensions of left and right match - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(left.range().lobound_data(), - right.range().lobound_data())); - TA_ASSERT(ignore_tile_position() || - gemm_helper.left_right_congruent(left.range().upbound_data(), - right.range().upbound_data())); - TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(), - right.range().extent_data())); - - // Compute gemm dimensions - using integer = TiledArray::math::blas::integer; - integer m, n, k; - gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); - - // Get the leading dimension for left and right matrices. - const integer lda = - (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k - : m); - const integer ldb = - (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? 
n - : k); - - // may need to split gemm into multiply + accumulate for tracing purposes -#ifdef TA_ENABLE_TILE_OPS_LOGGING - { - const bool twostep = - TiledArray::TileOpsLogger::get_instance().gemm && - TiledArray::TileOpsLogger::get_instance() - .gemm_print_contributions; - std::unique_ptr data_copy; - size_t tile_volume; - if (twostep) { - tile_volume = range().volume(); - data_copy = std::make_unique(tile_volume); - std::copy(pimpl_->data_, pimpl_->data_ + tile_volume, - data_copy.get()); - } - non_distributed::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, - n, k, factor, left.data(), lda, right.data(), ldb, - twostep ? numeric_type(0) : numeric_type(1), - pimpl_->data_, n); - - if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && - TiledArray::TileOpsLogger::get_instance().gemm) { - auto& logger = TiledArray::TileOpsLogger::get_instance(); - auto apply = [](auto& fnptr, const Range& arg) { - return fnptr ? fnptr(arg) : arg; - }; - auto tformed_left_range = - apply(logger.gemm_left_range_transform, left.range()); - auto tformed_right_range = - apply(logger.gemm_right_range_transform, right.range()); - auto tformed_result_range = - apply(logger.gemm_result_range_transform, pimpl_->range_); - if ((!logger.gemm_result_range_filter || - logger.gemm_result_range_filter(tformed_result_range)) && - (!logger.gemm_left_range_filter || - logger.gemm_left_range_filter(tformed_left_range)) && - (!logger.gemm_right_range_filter || - logger.gemm_right_range_filter(tformed_right_range))) { - logger << "TA::Tensor::gemm+: left=" << tformed_left_range - << " right=" << tformed_right_range - << " result=" << tformed_result_range << std::endl; - if (TiledArray::TileOpsLogger::get_instance() - .gemm_print_contributions) { - if (!TiledArray::TileOpsLogger::get_instance() - .gemm_printer) { // default printer - // must use custom printer if result's range transformed - if (!logger.gemm_result_range_transform) - logger << *this << std::endl; - else - logger << 
make_map(pimpl_->data_, tformed_result_range) - << std::endl; - } else { - TiledArray::TileOpsLogger::get_instance().gemm_printer( - *logger.log, tformed_left_range, left.data(), - tformed_right_range, right.data(), tformed_right_range, - pimpl_->data_); - } - } - } - } - - if (twostep) { - for (size_t v = 0; v != tile_volume; ++v) { - pimpl_->data_[v] += data_copy[v]; - } - } - } -#else // TA_ENABLE_TILE_OPS_LOGGING - math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k, - factor, left.data(), lda, right.data(), ldb, - numeric_type(1), pimpl_->data_, n); -#endif // TA_ENABLE_TILE_OPS_LOGGING + range_type range = gemm_helper.make_result_range(A.range_, B.range()); + *this = Tensor(range, A.batch_size(), default_construct{true}); + beta = 0; + } + TA_ASSERT(this->batch_size() == A.batch_size()); + TA_ASSERT(this->batch_size() == B.batch_size()); + for (size_t i = 0; i < this->batch_size(); ++i) { + auto Ci = this->batch(i); + TiledArray::gemm(alpha, A.batch(i), B.batch(i), beta, Ci, gemm_helper); } - return *this; } + template , value_type&, const U&, const V&>>> - Tensor_& gemm(const Tensor& left, const Tensor& right, + Tensor& gemm(const Tensor& left, const Tensor& right, const math::GemmHelper& gemm_helper, ElementMultiplyAddOp&& elem_muladd_op) { // Check that the arguments are not empty and have the correct ranks @@ -1904,30 +1611,30 @@ class Tensor { right.range().extent_data())); if (this->empty()) { // initialize, if empty - *this = Tensor_(gemm_helper.make_result_range(left.range(), + *this = Tensor(gemm_helper.make_result_range(left.range(), right.range())); } else { // Check that the outer dimensions of left match the corresponding // dimensions in result TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( left.range().lobound_data(), - pimpl_->range_.lobound_data())); + this->range_.lobound_data())); TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( left.range().upbound_data(), - 
pimpl_->range_.upbound_data())); + this->range_.upbound_data())); TA_ASSERT(gemm_helper.left_result_congruent( - left.range().extent_data(), pimpl_->range_.extent_data())); + left.range().extent_data(), this->range_.extent_data())); // Check that the outer dimensions of right match the corresponding // dimensions in result TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( right.range().lobound_data(), - pimpl_->range_.lobound_data())); + this->range_.lobound_data())); TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( right.range().upbound_data(), - pimpl_->range_.upbound_data())); + this->range_.upbound_data())); TA_ASSERT(gemm_helper.right_result_congruent( - right.range().extent_data(), pimpl_->range_.extent_data())); + right.range().extent_data(), this->range_.extent_data())); } // Compute gemm dimensions @@ -1953,7 +1660,7 @@ class Tensor { gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? k * ldb + n : n * ldb + k; - elem_muladd_op(*(pimpl_->data_ + c_offset), *(left.data() + a_offset), + elem_muladd_op(*(this->data() + c_offset), *(left.data() + a_offset), *(right.data() + b_offset)); } } @@ -1970,7 +1677,7 @@ class Tensor { /// tensor. /// \return The trace of this tensor /// \throw TiledArray::Exception When this tensor is empty. - template > decltype(auto) trace() const { return TiledArray::trace(*this); @@ -2156,8 +1863,193 @@ class Tensor { }; // class Tensor -template -const typename Tensor::range_type Tensor::empty_range_; + + /// Contract two tensors and accumulate the scaled result to this tensor + + /// GEMM is limited to matrix like contractions. 
For example, the following + /// contractions are supported: + /// \code + /// C[a,b] = A[a,i,j] * B[i,j,b] + /// C[a,b] = A[a,i,j] * B[b,i,j] + /// C[a,b] = A[i,j,a] * B[i,j,b] + /// C[a,b] = A[i,j,a] * B[b,i,j] + /// + /// C[a,b,c,d] = A[a,b,i,j] * B[i,j,c,d] + /// C[a,b,c,d] = A[a,b,i,j] * B[c,d,i,j] + /// C[a,b,c,d] = A[i,j,a,b] * B[i,j,c,d] + /// C[a,b,c,d] = A[i,j,a,b] * B[c,d,i,j] + /// \endcode + /// Notice that in the above contractions, the inner and outer indices of + /// the arguments for exactly two contiguous groups in each tensor and that + /// each group is in the same order in all tensors. That is, the indices of + /// the tensors must fit the one of the following patterns: + /// \code + /// C[M...,N...] = A[M...,K...] * B[K...,N...] + /// C[M...,N...] = A[M...,K...] * B[N...,K...] + /// C[M...,N...] = A[K...,M...] * B[K...,N...] + /// C[M...,N...] = A[K...,M...] * B[N...,K...] + /// \endcode + /// This allows use of optimized BLAS functions to evaluate tensor + /// contractions. Tensor contractions that do not fit this pattern require + /// one or more tensor permutation so that the tensors fit the required + /// pattern. 
+ /// \tparam Alpha The type of the scaling factor applied to the product + /// \tparam Beta The type of the scaling factor applied to \c C + /// \param alpha The contraction of \c A and \c B is scaled by this value + /// \param A The left-hand tensor that will be contracted + /// \param B The right-hand tensor that will be contracted + /// \param beta The scaling factor applied to the existing contents of \c C + /// before the scaled contraction is accumulated into it + /// \param C The tensor that receives the result; its rank and extents must + /// already be congruent with the contraction of \c A and \c B + /// \param gemm_helper The *GEMM operation meta data + /// \note \c C is updated in place and nothing is returned; without tile-op + /// logging the call is forwarded as + /// \code + /// math::blas::gemm(left_op, right_op, m, n, k, alpha, A, lda, B, ldb, beta, C, n); + /// \endcode + template +void gemm(Alpha alpha, const Tensor& A, const Tensor& B, + Beta beta, Tensor &C, const math::GemmHelper& gemm_helper) { + // static_assert( + // !detail::is_tensor_of_tensor_v, Tensor>, + // "TA::Tensor::gemm without custom element op is only applicable to " + // "plain tensors"); + { + // Check that the result tensor has the correct rank + //TA_ASSERT(pimpl_); + TA_ASSERT(C.range().rank() == gemm_helper.result_rank()); + + // Check that the arguments are not empty and have the correct ranks + TA_ASSERT(!A.empty()); + TA_ASSERT(A.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(!B.empty()); + TA_ASSERT(B.range().rank() == gemm_helper.right_rank()); + + // Check that the outer dimensions of left match the corresponding + // dimensions in result + TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( + A.range().lobound_data(), + C.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || gemm_helper.left_result_congruent( + A.range().upbound_data(), + C.range().upbound_data())); + TA_ASSERT(gemm_helper.left_result_congruent( + A.range().extent_data(), C.range().extent_data())); + + // Check that the outer dimensions of right match the corresponding + //
dimensions in result + TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( + B.range().lobound_data(), + C.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || gemm_helper.right_result_congruent( + B.range().upbound_data(), + C.range().upbound_data())); + TA_ASSERT(gemm_helper.right_result_congruent( + B.range().extent_data(), C.range().extent_data())); + + // Check that the inner dimensions of left and right match + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(A.range().lobound_data(), + B.range().lobound_data())); + TA_ASSERT(ignore_tile_position() || + gemm_helper.left_right_congruent(A.range().upbound_data(), + B.range().upbound_data())); + TA_ASSERT(gemm_helper.left_right_congruent(A.range().extent_data(), + B.range().extent_data())); + + // Compute gemm dimensions + using integer = TiledArray::math::blas::integer; + integer m, n, k; + gemm_helper.compute_matrix_sizes(m, n, k, A.range(), B.range()); + + // Get the leading dimension for left and right matrices. + const integer lda = + (gemm_helper.left_op() == TiledArray::math::blas::NoTranspose ? k + : m); + const integer ldb = + (gemm_helper.right_op() == TiledArray::math::blas::NoTranspose ? n + : k); + + // may need to split gemm into multiply + accumulate for tracing purposes +#ifdef TA_ENABLE_TILE_OPS_LOGGING + { + const bool twostep = + TiledArray::TileOpsLogger::get_instance().gemm && + TiledArray::TileOpsLogger::get_instance() + .gemm_print_contributions; + std::unique_ptr data_copy; + size_t tile_volume; + if (twostep) { + tile_volume = C.range().volume(); + data_copy = std::make_unique(tile_volume); + std::copy(C.data(), C.data() + tile_volume, + data_copy.get()); + } + non_distributed::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, + n, k, alpha, A.data(), lda, B.data(), ldb, + twostep ? 
Beta(0) : beta, + C.data(), n); + + if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && + TiledArray::TileOpsLogger::get_instance().gemm) { + auto& logger = TiledArray::TileOpsLogger::get_instance(); + auto apply = [](auto& fnptr, const Range& arg) { + return fnptr ? fnptr(arg) : arg; + }; + auto tformed_left_range = + apply(logger.gemm_left_range_transform, A.range()); + auto tformed_right_range = + apply(logger.gemm_right_range_transform, B.range()); + auto tformed_result_range = + apply(logger.gemm_result_range_transform, C.range()); + if ((!logger.gemm_result_range_filter || + logger.gemm_result_range_filter(tformed_result_range)) && + (!logger.gemm_left_range_filter || + logger.gemm_left_range_filter(tformed_left_range)) && + (!logger.gemm_right_range_filter || + logger.gemm_right_range_filter(tformed_right_range))) { + logger << "TA::Tensor::gemm+: left=" << tformed_left_range + << " right=" << tformed_right_range + << " result=" << tformed_result_range << std::endl; + if (TiledArray::TileOpsLogger::get_instance() + .gemm_print_contributions) { + if (!TiledArray::TileOpsLogger::get_instance() + .gemm_printer) { // default printer + // must use custom printer if result's range transformed + if (!logger.gemm_result_range_transform) + logger << C << std::endl; + else + logger << make_map(C.data(), tformed_result_range) + << std::endl; + } else { + TiledArray::TileOpsLogger::get_instance().gemm_printer( + *logger.log, tformed_left_range, A.data(), + tformed_right_range, B.data(), tformed_right_range, + C.data()); + } + } + } + } + + if (twostep && beta != Beta(0)) { // beta == 0: prior contents of C must not contribute + for (size_t v = 0; v != tile_volume; ++v) { + C.data()[v] += beta * data_copy[v]; + } + } + } +#else // TA_ENABLE_TILE_OPS_LOGGING + math::blas::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, k, + alpha, A.data(), lda, B.data(), ldb, + beta, C.data(), n); +#endif // TA_ENABLE_TILE_OPS_LOGGING + } + } + + +// template +// const typename 
Tensor::range_type Tensor::empty_range_; template bool
operator==(const Tensor& a, const Tensor& b) { From 6f8fa7d2b618c4f4d9bf713e8ae39d9bddcc9b1e Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 18:26:47 -0400 Subject: [PATCH 02/12] perm --- src/TiledArray/permutation.h | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/TiledArray/permutation.h b/src/TiledArray/permutation.h index cb7610a6a0..806c885658 100644 --- a/src/TiledArray/permutation.h +++ b/src/TiledArray/permutation.h @@ -34,6 +34,7 @@ namespace TiledArray { // Forward declarations class Permutation; + bool operator==(const Permutation&, const Permutation&); std::ostream& operator<<(std::ostream&, const Permutation&); template @@ -74,6 +75,32 @@ inline void permute_array(const Perm& perm, const Arg& arg, Result& result) { result[pi] = arg[i]; } } + +template +void permute_n(size_t N, P p, In in, Out out, std::bool_constant) { + for (size_t k = 0; k < N; ++k) { + if constexpr (Inverse) { + out[*p++] = *in++; + } + else { + *out++ = in[*p++]; + } + } +} + +template +auto permute(const P &p, const S &s, std::bool_constant) { + // using std::size; + // using std::begin; + // size_t K = size(p); + // S r(K); + // detail::permute_n(K, begin(p), begin(s), begin(r), args...); + // return r; + if (!p) return s; + if constexpr (Inverse) return p.inv()*s; + else return p*s; +} + } // namespace detail /** @@ -604,6 +631,18 @@ inline std::vector operator*(const Permutation& perm, return result; } +template +S apply(const Permutation &p, const S &s) { + using detail::permute; + return permute(p, s, std::false_type{}); +} + +template +S apply_inverse(const Permutation &p, const S &s) { + using detail::permute; + return permute(p, s, std::true_type{}); +} + /////////////////////////////////// /// Permutation of a bipartite set From 11a8f898930e3fce967ab8b4ffd4e17cecd63b74 Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 19:55:56 -0400 Subject: [PATCH 03/12] tr --- src/TiledArray/tiled_range.h | 7 ++++++- 
1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TiledArray/tiled_range.h b/src/TiledArray/tiled_range.h index 0aeb3ddf4f..80d1b9501f 100644 --- a/src/TiledArray/tiled_range.h +++ b/src/TiledArray/tiled_range.h @@ -90,7 +90,8 @@ class TiledRange { explicit TiledRange(const TRange1Range& range_of_trange1s) : range_(), elements_range_(), - ranges_(begin(range_of_trange1s), end(range_of_trange1s)) { + ranges_(std::begin(range_of_trange1s), std::end(range_of_trange1s)) + { init(); } @@ -287,6 +288,10 @@ class TiledRange { /// Tile dimension boundary array accessor + auto begin() const { return ranges_.begin(); } + auto end() const { return ranges_.end(); } + const auto& at(size_t idx) const { return ranges_.at(idx); } + /// \return A reference to the array of Range1 objects. /// \throw nothing const Ranges& data() const { return ranges_; } From d10c74a10008ebc15ff2c8f6738d95f35d8bef45 Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 14:00:36 -0400 Subject: [PATCH 04/12] index --- src/CMakeLists.txt | 2 + src/TiledArray/util/index.cpp | 25 +++ src/TiledArray/util/index.h | 302 ++++++++++++++++++++++++++++++++++ 3 files changed, 329 insertions(+) create mode 100644 src/TiledArray/util/index.cpp create mode 100644 src/TiledArray/util/index.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ca46d0a419..ec4ba3ac65 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -173,6 +173,8 @@ TiledArray/util/annotation.h TiledArray/util/backtrace.h TiledArray/util/bug.h TiledArray/util/function.h +TiledArray/util/index.h +TiledArray/util/index.cpp TiledArray/util/initializer_list.h TiledArray/util/logger.h TiledArray/util/random.h diff --git a/src/TiledArray/util/index.cpp b/src/TiledArray/util/index.cpp new file mode 100644 index 0000000000..28f74f491a --- /dev/null +++ b/src/TiledArray/util/index.cpp @@ -0,0 +1,25 @@ +// Samuel R. 
Powell, 2021 +#include "TiledArray/util/index.h" +#include "TiledArray/util/annotation.h" +#include "TiledArray/util/string.h" + +namespace TiledArray::index { + +std::vector validate(const std::vector &v) { + return v; +} + +small_vector tokenize(const std::string &s) { + // std::vector r; + // boost::split(r, s, boost::is_any_of(", \t")); + // return r; + auto r = detail::tokenize_index(s, ','); + if (r == std::vector{""}) return {}; + return small_vector (r.begin(), r.end()); // correct? +} + +std::string join(const small_vector &v) { + return string::join(v, ","); +} + +} // namespace TiledArray::index diff --git a/src/TiledArray/util/index.h b/src/TiledArray/util/index.h new file mode 100644 index 0000000000..e86560d970 --- /dev/null +++ b/src/TiledArray/util/index.h @@ -0,0 +1,302 @@ +// Samuel R. Powell, 2021 +#include "TiledArray/expressions/fwd.h" + +#include +#include +#include + +#include +#include + +namespace TiledArray::index { + +template +using small_vector = container::svector; + +small_vector tokenize(const std::string &s); + +small_vector validate(const small_vector &v); + +std::string join(const small_vector &v); + +template +using enable_if_string = std::enable_if_t< std::is_same_v, U>; + +/// an n-index, with n a runtime parameter +template +class Index { +public: + using container_type = small_vector; + using value_type = typename container_type::value_type; + + Index() = default; + Index(container_type &&s) : data_(std::move(s)) {} + + template + Index(const S &s) : data_(s.begin(), s.end()) {} + + template + Index(const std::string &s) : Index(index::tokenize(s)) {} + + template + Index(const char *s) : Index(std::string(s)) {} + + template + operator std::string() const { return index::join(data_); } + + explicit operator bool() const { return !data_.empty(); } + + bool operator==(const Index &other) const { + return (this->data_ == other.data_); + } + + bool operator!=(const Index &other) const { + return !(*this == other); + } + + size_t 
size() const { return data_.size(); } + + auto begin() const { return data_.begin(); } + auto end() const { return data_.end(); } + + auto find(const T& v) const { + return std::find(this->begin(), this->end(), v); + } + + const auto& operator[](size_t idx) const { return data_.at(idx); } + + size_t indexof(const T& v) const { + for (size_t i = 0; i < this->size(); ++i) { + if ((*this)[i] == v) return i; + } + return -1; + } + + /// Returns true if argument exists in the Index object, else returns false + bool contains(const T& v) const { + return (this->find(v) != this->end()); + } + + private: + container_type data_; +}; + +template +std::ostream& operator<<(std::ostream& os, const Index &idx) { + os << std::string(idx); + return os; +} + +/// (stable) intersect of 2 Index objects +/// @param[in] a an Index object +/// @param[in] b an Index object +/// @pre a and b do not have duplicates +template +Index operator&(const Index &a, const Index &b) { + typename Index::container_type r; + for (const auto &s : a) { + if (!b.contains(s)) continue; + r.push_back(s); + } + return Index(r); +} +/// union of 2 Index objects +/// @param[in] a an Index object +/// @param[in] b an Index object +/// @pre a and b do not have duplicates +template +Index operator|(const Index &a, const Index &b) { + typename Index::container_type r; + r.assign(a.begin(), a.end()); + for (const auto &s : b) { + if (a.contains(s)) continue; + r.push_back(s); + } + return Index(r); +} + +/// concatenation of 2 Index objects +/// @param[in] a an Index object +/// @param[in] b an Index object +/// @note unlike operator|, @p a and @p b can have duplicates +template +Index operator+(const Index &a, const Index &b) { + typename Index::container_type r; + r.assign(a.begin(), a.end()); + r.insert(r.end(), b.begin(), b.end()); + return Index(r); +} + +/// "difference" of 2 Index objects, i.e. 
elements of a that are not in b +/// @param[in] a an Index object +/// @param[in] b an Index object +/// @note unline operator& @p a and @p b can have have duplicates +template +Index operator-(const Index &a, const Index &b) { + typename Index::container_type r; + for (const auto &s : a) { + if (b.contains(s)) continue; + r.push_back(s); + } + return Index(r); +} + +/// elements that are exclusively in @p a or @p b +/// @param[in] a an Index object +/// @param[in] b an Index object +/// @pre a and b do not have duplicates +template +inline Index operator^(const Index &a, const Index &b) { + return (a | b) - (a & b); +} + +template +size_t rank(const Index &idx) { return idx.size(); } + +template +Index sorted(const Index& a) { + typename Index::container_type r(a.begin(), a.end()); + std::sort(r.begin(), r.end()); + return Index(r); +} + +template +Permutation permutation(const Index &s, const Index &p) { + assert(sorted(s) == sorted(p)); + small_vector m; + m.reserve(p.size()); + for (size_t i = 0; i != p.size(); ++i) { + m.push_back(s.indexof(p[i])); + } + return Permutation(m); +} + +template +auto permute(const Permutation &p, const Index &s, std::bool_constant) { + if (!p) return s; + using R = typename Index::container_type; + R r(p.size()); + detail::permute_n(p.size(), p.begin(), s.begin(), r.begin(), std::bool_constant{}); + return Index{r}; +} + +/// @brief Index-annotated collection of objects +/// @tparam Value +/// This is a map using Index::element_type as key +template +struct IndexMap { + + using key_type = K; + using value_type = V; + + IndexMap(const Index &keys, std::initializer_list s) + : IndexMap(keys, s.begin(), s.end()) {} + + template + IndexMap(const Index &keys, S &&s) + : IndexMap(keys, s.begin(), s.end()) {} + + template + IndexMap(const Index &keys, It begin, It end) { + auto it = begin; + data_.reserve(keys.size()); + for (auto &&key : keys) { + assert(it != end); + data_.emplace_back(std::pair{key, *it}); + ++it; + } + assert(it == 
end); + } + + IndexMap(const small_vector > &data) : data_(data) { } + + /// @return const iterator pointing to the element associated with @p key + auto find(const key_type &key) const { + return std::find_if( + data_.begin(), data_.end(), + [&key](const auto &v) { return key == v.first; } + ); + } + + /// @return reference to the element associated with @p key + /// @throw TA::Exception if @p key is not in this map + const auto& operator[](const key_type &key) const { + auto it = find(key); + if (it != data_.end()) return it->second; + throw TiledArray::Exception("IndexMap::at(key): key not found"); + } + + /// @param[in] idx an Index object + /// @return directly-addressable sequence of elements corresponding to the + /// keys in @p idx + auto operator[](const Index &idx) const { + small_vector result; + result.reserve(idx.size()); + for (auto &&key : idx) { + result.emplace_back(this->operator[](key)); + } + return result; + } + + auto begin() const { return data_.begin(); } + auto end() const { return data_.end(); } + + private: + small_vector< std::pair > data_; + +}; + +template +bool operator==(const IndexMap& lhs, const IndexMap& rhs) { + for (const auto& [k,v] : lhs) { + if (rhs.find(k) == rhs.end() || v != rhs[k]) return false; + } + for (const auto& [k,v] : rhs) { + if (lhs.find(k) == lhs.end()) return false; + } + return true; +} + +/// TODO to be filled by Sam +template +IndexMap operator|(const IndexMap &a, const IndexMap &b) { + small_vector< std::pair > d(a.begin(), a.end()); + for (const auto [k,v] : b) { + if (a.find(k) != a.end()) { + TA_ASSERT(a[k] == b[k]); + continue; + } + d.push_back(std::pair(k, v)); + } + return IndexMap(d); +} + +} // namespace TiledArray::index + +namespace TiledArray { + +using Index = TiledArray::index::Index; +using TiledArray::index::IndexMap; + +/// converts the annotation of an expression to an Index +template +auto idx(const std::string &s) { + if constexpr (detail::is_tensor_of_tensor_v) { + auto semi = 
std::find(s.begin(), s.end(), ';'); + assert(semi != s.end()); + auto first = std::string(s.begin(), semi); + auto second = std::string(semi+1, s.end()); + return std::tuple{ first, second }; + } + else { + return std::tuple{ s }; + } +} + +/// converts the annotation of an expression to an Index +template +auto idx(const expressions::TsrExpr &e) { + return idx(e.annotation()); +} + +} // namespace TiledArray From d21b59d4e79d447c1d0a0433473aae9f4650bc0a Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 13:16:20 -0400 Subject: [PATCH 05/12] range --- src/CMakeLists.txt | 1 + src/TiledArray/util/range.h | 130 ++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 src/TiledArray/util/range.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ec4ba3ac65..4bf0ec501d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -178,6 +178,7 @@ TiledArray/util/index.cpp TiledArray/util/initializer_list.h TiledArray/util/logger.h TiledArray/util/random.h +TiledArray/util/range.h TiledArray/util/singleton.h TiledArray/util/time.h TiledArray/util/vector.h diff --git a/src/TiledArray/util/range.h b/src/TiledArray/util/range.h new file mode 100644 index 0000000000..98c89324f7 --- /dev/null +++ b/src/TiledArray/util/range.h @@ -0,0 +1,130 @@ +#include + +#include +#include + +namespace TiledArray::range { + +template +using small_vector = container::svector; + +struct Range { + using value_type = int64_t; + using iterator = boost::counting_iterator; + template + explicit Range(Pair &&pair) : Range(pair.first, pair.second) {} + Range(value_type begin, value_type end) : begin_(begin), end_(end) {} + auto begin() const { return iterator(begin_); } + auto end() const { return iterator(end_); } + auto size() const { return end_ - begin_; } +protected: + const value_type begin_, end_; + +}; + +template > +struct RangeProduct { + + using ranges_type = std::vector; + +public: + + RangeProduct() = default; + 
RangeProduct(std::initializer_list ranges) : ranges_(ranges) {} + + RangeProduct& operator *= (R a) { + this->ranges_.push_back(a); + return *this; + } + + const auto& ranges() const { return ranges_; } + + struct iterator { + using iterator1 = decltype(std::begin(ranges_type{}[0])); + auto operator*() const { + T r; + for (auto &it : its_) { r.push_back(*it); } + return r; + } + bool operator!=(const iterator &other) const { + return !(this->p_ == other.p_ && this->its_ == other.its_); + } + iterator& operator++() { + size_t i = its_.size(); + auto &ranges = p_->ranges(); + while (i > 0) { + --i; + ++its_[i]; + if (i == 0) break; + if (its_[i] != std::end(ranges[i])) break; + its_[i] = std::begin(ranges[i]); + } + return *this; + } + private: + friend class RangeProduct; + explicit iterator(const RangeProduct *p, bool End = false) { + this->p_ = p; + using std::begin; + using std::end; + for (const auto& r : p->ranges()) { + auto it = (End ? end(r) : begin(r)); + its_.push_back(it); + End = false; + } + } + private: + const RangeProduct *p_; + small_vector its_; + }; + + auto begin() const { + return iterator(this); + } + + auto end() const { + return iterator(this, true); + } + +protected: + ranges_type ranges_; + +}; + +inline RangeProduct operator*(Range a, Range b){ + return RangeProduct({a, b}); +}; + +template +RangeProduct operator*(const RangeProduct& a, Range b) { + return RangeProduct(a) *= b; +}; + +template +void cartesian_foreach(const std::vector& rs, F f) { + using It = decltype(std::begin(rs[0])); + using T = typename R::value_type; + small_vector its, ends; + for (const auto& r : rs) { + its.push_back(std::begin(r)); + ends.push_back(std::end(r)); + } + while (its.front() != ends.front()) { + small_vector s; + s.reserve(its.size()); + for (auto& it : its) { + s.push_back(*it); + } + f(s); + size_t i = its.size(); + while (i > 0) { + --i; + ++its[i]; + if (i == 0) break; + if (its[i] != ends[i]) break; + its[i] = std::begin(rs[i]); + } + } +} + +} // 
namespace TiledArray::range From 330b4f0a00cd69ca42e8afdbaa27c695db09d94b Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 20:21:55 -0400 Subject: [PATCH 06/12] string --- src/CMakeLists.txt | 1 + src/TiledArray/util/string.h | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+) create mode 100644 src/TiledArray/util/string.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4bf0ec501d..a491a1ec12 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -180,6 +180,7 @@ TiledArray/util/logger.h TiledArray/util/random.h TiledArray/util/range.h TiledArray/util/singleton.h +TiledArray/util/string.h TiledArray/util/time.h TiledArray/util/vector.h ) diff --git a/src/TiledArray/util/string.h b/src/TiledArray/util/string.h new file mode 100644 index 0000000000..545da97259 --- /dev/null +++ b/src/TiledArray/util/string.h @@ -0,0 +1,41 @@ +// +// Created by Samuel R. Powell on 4/15/21. +// +#pragma once + +#ifndef TILEDARRAY_STRING_H +#define TILEDARRAY_STRING_H + +#include +#include +#include +#include +#include + +namespace TiledArray::string { + + // Split delimiter must match completely + inline std::vector split(const std::string& s, char d) { + std::vector res; + return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/); + } + + inline std::string trim(const std::string& s) { + return boost::trim_copy(s); + } + + template + std::string join(const T &s, const std::string& j = "") { + return boost::join(s, j); + } + + template + std::string str(const T& obj) { + std::stringstream ss; + ss << obj; + return ss.str(); + } + +} + +#endif //TILEDARRAY_STRING_H From 584e0162b40dafdf3d48a4d4aff1144f582ee641 Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 13:16:33 -0400 Subject: [PATCH 07/12] einsum --- src/TiledArray/expressions/einsum.h | 183 ++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 src/TiledArray/expressions/einsum.h diff --git 
a/src/TiledArray/expressions/einsum.h b/src/TiledArray/expressions/einsum.h new file mode 100644 index 0000000000..99a5176d59 --- /dev/null +++ b/src/TiledArray/expressions/einsum.h @@ -0,0 +1,183 @@ +#ifndef TILEDARRAY_EINSUM_H__INCLUDED +#define TILEDARRAY_EINSUM_H__INCLUDED + +#include "TiledArray/fwd.h" +#include "TiledArray/expressions/fwd.h" +#include "TiledArray/util/index.h" +#include "TiledArray/util/range.h" +#include "TiledArray/tiled_range1.h" +//#include "TiledArray/util/string.h" + +namespace TiledArray::expressions { + +/// einsum function without result indices assumes every index present +/// in both @p A and @p B is contracted, or, if there are no free indices, +/// pure Hadamard product is performed. +/// @param[in] A first argument to the product +/// @param[in] B second argument to the product +/// @warning just as in the plain expression code, reductions are a special +/// case; use Expr::reduce() +template +auto einsum(TsrExpr A, TsrExpr B) { + printf("einsum(A,B)\n"); + auto a = std::get<0>(idx(A)); + auto b = std::get<0>(idx(B)); + Array R; + R(a ^ b) = A * B; + return R; +} + +/// einsum function with result indices explicitly specified +/// @param[in] A first argument to the product +/// @param[in] B second argument to the product +/// @param[in] r result indices +/// @warning just as in the plain expression code, reductions are a special +/// case; use Expr::reduce() +template +auto einsum( + TsrExpr A, TsrExpr B, + const std::string &cs, + World &world = get_default_world()) +{ + return einsum(A, B, idx(cs), world); +} + +template +auto einsum( + TsrExpr A, TsrExpr B, + std::tuple cs, + World &world) +{ + + printf("einsum(A,B,c)\n"); + + auto a = std::get<0>(idx(A)); + auto b = std::get<0>(idx(B)); + Index c = std::get<0>(cs); + + struct { std::string a, b, c; } inner; + if constexpr (std::tuple_size::value == 2) { + inner.a = ";" + (std::string)std::get<1>(idx(A)); + inner.b = ";" + (std::string)std::get<1>(idx(B)); + inner.c = ";" + 
(std::string)std::get<1>(cs); + } + + // these are "Hadamard" (fused) indices + auto h = a & b & c; + + // no Hadamard indices => standard contraction (or even outer product) + // same a, b, and c => pure Hadamard + if (!h || (!(a ^ b) && !(b ^ c))) { + Array C; + C(std::string(c) + inner.c) = A*B; + return C; + } + + auto e = (a ^ b); + // contracted indices + auto i = (a & b) - h; + + TA_ASSERT(e); + TA_ASSERT(h); + + using range::Range; + using RangeProduct = range::RangeProduct >; + + using RangeMap = IndexMap; + auto range_map = ( + RangeMap(a, A.array().trange()) | + RangeMap(b, B.array().trange()) + ); + + using TiledArray::Permutation; + using TiledArray::index::permutation; + + struct Term { + Array array; + Index idx; + Permutation permutation; + RangeProduct tiles; + Array local; + std::string expr; + }; + + Term AB[2] = { { A.array(), a }, { B.array(), b } }; + + for (auto &term : AB) { + auto ei = (e+i & term.idx); + term.local = Array(world, TiledRange(range_map[ei])); + for (auto idx : ei) { + term.tiles *= Range(range_map[idx].tiles_range()); + } + if (term.idx != h+ei) { + term.permutation = permutation(term.idx, h+ei); + } + term.expr = ei; + } + + Term C = { Array(world, TiledRange(range_map[c])), c }; + for (auto idx : e) { + C.tiles *= Range(range_map[idx].tiles_range()); + } + if (C.idx != h+e) { + C.permutation = permutation(h+e, C.idx); + } + C.expr = e; + + AB[0].expr += inner.a; + AB[1].expr += inner.b; + C.expr += inner.c; + + struct { + RangeProduct tiles; + std::vector< std::vector > batch; + } H; + + for (auto idx : h) { + H.tiles *= Range(range_map[idx].tiles_range()); + H.batch.push_back({}); + for (auto r : range_map[idx]) { + H.batch.back().push_back(Range{r}.size()); + } + } + + // iterates over tiles of hadamard indices + using Index = index::Index; + for (Index h : H.tiles) { + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); + } + for (auto &term : AB) { + term.local = 
Array(term.local.world(), term.local.trange()); + const Permutation &P = term.permutation; + for (Index ei : term.tiles) { + auto tile = term.array.find(apply_inverse(P, h+ei)).get(); + if (P) tile = tile.permute(P); + auto shape = term.local.trange().tile(ei); + tile = tile.reshape(shape, batch); + term.local.set(ei, tile); + } + } + auto& [A,B] = AB; + C.local(C.expr) = A.local(A.expr) * B.local(B.expr); + const Permutation &P = C.permutation; + for (Index e : C.tiles) { + auto c = apply(P, h+e); + auto shape = C.array.trange().tile(c); + shape = apply_inverse(P, shape); + auto tile = C.local.find(e).get(); + assert(tile.batch_size() == batch); + tile = tile.reshape(shape); + if (P) tile = tile.permute(P); + C.array.set(c, tile); + } + } + + return C.array; + +} + +} // namespace TiledArray::expressions + +#endif /* TILEDARRAY_EINSUM_H__INCLUDED */ From 672095c464d1956d5d8a98ffeff7f64463cf1fb2 Mon Sep 17 00:00:00 2001 From: asadchev Date: Mon, 26 Apr 2021 14:00:08 -0400 Subject: [PATCH 08/12] einsum-test --- tests/einsum.cpp | 61 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 9 deletions(-) diff --git a/tests/einsum.cpp b/tests/einsum.cpp index ebd3637be4..5ea2f42151 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -16,11 +16,13 @@ * along with this program. If not, see . 
* */ -#include "TiledArray/expressions/contraction_helpers.h" +#include "TiledArray/expressions/einsum.h" #include "tiledarray.h" #include "unit_test_config.h" #include "tot_array_fixture.h" +#include "TiledArray/expressions/contraction_helpers.h" + using namespace TiledArray; using namespace TiledArray::expressions; @@ -83,8 +85,7 @@ BOOST_AUTO_TEST_CASE(ik_mn_eq_ij_mn_times_jk_mn){ matrix_il corr_il {{corr_elem_0_0, corr_elem_0_1, corr_elem_0_2, corr_elem_0_3, corr_elem_0_4, corr_elem_0_5},{corr_elem_1_0, corr_elem_1_1, corr_elem_1_2, corr_elem_1_3, corr_elem_1_4, corr_elem_1_5},{corr_elem_2_0, corr_elem_2_1, corr_elem_2_2, corr_elem_2_3, corr_elem_2_4, corr_elem_2_5},{corr_elem_3_0, corr_elem_3_1, corr_elem_3_2, corr_elem_3_3, corr_elem_3_4, corr_elem_3_5}}; TiledRange corr_trange{{0, 2, 4},{0, 2, 4, 6}}; dist_array_t corr(world, corr_trange, corr_il); - dist_array_t out; - einsum(out("i,k;m,n"), lhs("i,j;m,n"), rhs("j,k;m,n")); + dist_array_t out = einsum(lhs("i,j;m,n"), rhs("j,k;m,n"), "i,k;m,n"); const bool are_equal = ToTArrayFixture::are_equal(corr, out); BOOST_CHECK(are_equal); } @@ -146,8 +147,7 @@ BOOST_AUTO_TEST_CASE(ik_mn_eq_ij_mn_times_kj_mn){ matrix_il corr_il {{corr_elem_0_0, corr_elem_0_1, corr_elem_0_2, corr_elem_0_3, corr_elem_0_4, corr_elem_0_5},{corr_elem_1_0, corr_elem_1_1, corr_elem_1_2, corr_elem_1_3, corr_elem_1_4, corr_elem_1_5},{corr_elem_2_0, corr_elem_2_1, corr_elem_2_2, corr_elem_2_3, corr_elem_2_4, corr_elem_2_5},{corr_elem_3_0, corr_elem_3_1, corr_elem_3_2, corr_elem_3_3, corr_elem_3_4, corr_elem_3_5}}; TiledRange corr_trange{{0, 2, 4},{0, 2, 4, 6}}; dist_array_t corr(world, corr_trange, corr_il); - dist_array_t out; - einsum(out("i,k;m,n"), lhs("i,j;m,n"), rhs("k,j;m,n")); + dist_array_t out = einsum(lhs("i,j;m,n"), rhs("k,j;m,n"), "i,k;m,n"); const bool are_equal = ToTArrayFixture::are_equal(corr, out); BOOST_CHECK(are_equal); } @@ -189,8 +189,7 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mn_times_ij_mn){ matrix_il corr_il {{corr_elem_0_0, 
corr_elem_0_1},{corr_elem_1_0, corr_elem_1_1},{corr_elem_2_0, corr_elem_2_1},{corr_elem_3_0, corr_elem_3_1}}; TiledRange corr_trange{{0, 2, 4},{0, 2}}; dist_array_t corr(world, corr_trange, corr_il); - dist_array_t out; - einsum(out("i,j;m,n"), lhs("i,j;m,n"), rhs("i,j;m,n")); + dist_array_t out = einsum(lhs("i,j;m,n"), rhs("i,j;m,n"), "i,j;m,n"); const bool are_equal = ToTArrayFixture::are_equal(corr, out); BOOST_CHECK(are_equal); } @@ -232,11 +231,55 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mn_times_ji_mn){ matrix_il corr_il {{corr_elem_0_0, corr_elem_0_1},{corr_elem_1_0, corr_elem_1_1},{corr_elem_2_0, corr_elem_2_1},{corr_elem_3_0, corr_elem_3_1}}; TiledRange corr_trange{{0, 2, 4},{0, 2}}; dist_array_t corr(world, corr_trange, corr_il); - dist_array_t out; - einsum(out("i,j;m,n"), lhs("i,j;m,n"), rhs("j,i;m,n")); + dist_array_t out = einsum(lhs("i,j;m,n"), rhs("j,i;m,n"), "i,j;m,n"); const bool are_equal = ToTArrayFixture::are_equal(corr, out); BOOST_CHECK(are_equal); } +BOOST_AUTO_TEST_CASE(xxx) { + using dist_array_t = DistArray>, DensePolicy>; + using matrix_il = TiledArray::detail::matrix_il>; + auto& world = TiledArray::get_default_world(); + Tensor lhs_elem_0_0(Range{7, 2}, {15, 75, 54, 54, 72, 62, 97, 90, 17, 94, 19, 54, 13, 31}); + Tensor lhs_elem_0_1(Range{8, 3}, {82, 91, 60, 11, 47, 38, 87, 13, 72, 39, 59, 90, 26, 38, 2, 34, 30, 32, 46, 6, 26, 92, 47, 14}); + Tensor lhs_elem_1_0(Range{8, 3}, {53, 88, 72, 12, 58, 85, 55, 6, 50, 76, 51, 52, 77, 13, 4, 99, 30, 12, 16, 21, 60, 75, 55, 99}); + Tensor lhs_elem_1_1(Range{9, 4}, {16, 65, 6, 84, 85, 30, 97, 79, 2, 13, 4, 90, 32, 98, 88, 40, 25, 27, 8, 50, 56, 5, 42, 11, 20, 3, 51, 55, 32, 75, 8, 25, 4, 99, 75, 50}); + Tensor lhs_elem_2_0(Range{9, 4}, {39, 24, 23, 32, 10, 22, 94, 47, 85, 22, 77, 22, 92, 28, 61, 53, 21, 81, 57, 63, 37, 75, 93, 91, 24, 14, 56, 69, 42, 100, 17, 44, 78, 47, 33, 67}); + Tensor lhs_elem_2_1(Range{10, 5}, {93, 27, 38, 15, 87, 88, 48, 19, 54, 81, 6, 60, 70, 75, 1, 21, 34, 6, 74, 26, 5, 5, 
75, 21, 31, 62, 53, 18, 17, 14, 19, 33, 96, 56, 94, 12, 30, 14, 94, 31, 25, 59, 72, 88, 66, 98, 56, 79, 11, 50}); + Tensor lhs_elem_3_0(Range{10, 5}, {49, 46, 13, 98, 77, 100, 23, 99, 77, 64, 10, 31, 10, 70, 30, 18, 89, 45, 81, 24, 45, 39, 83, 31, 3, 89, 35, 93, 70, 84, 43, 26, 96, 59, 57, 1, 3, 33, 27, 53, 33, 3, 53, 7, 80, 54, 47, 77, 62, 23}); + Tensor lhs_elem_3_1(Range{11, 6}, {27, 61, 27, 63, 45, 14, 80, 20, 73, 74, 74, 9, 59, 92, 5, 4, 78, 27, 53, 94, 70, 74, 1, 48, 30, 97, 51, 42, 93, 93, 81, 94, 73, 67, 23, 98, 58, 17, 75, 73, 92, 16, 59, 5, 82, 22, 43, 58, 68, 44, 27, 69, 79, 42, 99, 48, 78, 18, 9, 63, 1, 50, 9, 10, 82, 39}); + matrix_il lhs_il {{lhs_elem_0_0, lhs_elem_0_1},{lhs_elem_1_0, lhs_elem_1_1},{lhs_elem_2_0, lhs_elem_2_1},{lhs_elem_3_0, lhs_elem_3_1}}; + TiledRange lhs_trange{{0, 2, 4},{0, 2}}; + dist_array_t lhs(world, lhs_trange, lhs_il); + Tensor rhs_elem_0_0(Range{7, 2}, {55, 2, 99, 28, 98, 27, 80, 69, 1, 66, 5, 9, 1, 80}); + Tensor rhs_elem_0_1(Range{8, 3}, {19, 23, 52, 93, 6, 89, 68, 10, 4, 23, 24, 20, 99, 85, 81, 36, 82, 54, 36, 46, 26, 85, 15, 28}); + Tensor rhs_elem_0_2(Range{9, 4}, {57, 32, 86, 49, 55, 32, 100, 46, 2, 82, 84, 69, 63, 69, 12, 62, 21, 87, 1, 40, 61, 56, 90, 53, 74, 72, 5, 21, 49, 97, 69, 83, 48, 38, 88, 9}); + Tensor rhs_elem_0_3(Range{10, 5}, {28, 7, 4, 92, 30, 7, 3, 70, 16, 51, 71, 14, 37, 33, 92, 90, 75, 29, 52, 59, 15, 15, 96, 50, 39, 72, 22, 60, 56, 95, 45, 33, 25, 22, 23, 100, 26, 27, 38, 88, 89, 36, 48, 46, 6, 88, 16, 100, 54, 43}); + Tensor rhs_elem_1_0(Range{8, 3}, {55, 21, 79, 3, 77, 82, 65, 83, 66, 12, 100, 9, 40, 55, 8, 75, 82, 85, 100, 78, 39, 42, 65, 56}); + Tensor rhs_elem_1_1(Range{9, 4}, {45, 21, 58, 73, 57, 33, 27, 58, 56, 45, 88, 79, 78, 97, 23, 4, 87, 22, 9, 21, 54, 44, 81, 98, 53, 60, 29, 70, 83, 75, 30, 56, 61, 67, 18, 61}); + Tensor rhs_elem_1_2(Range{10, 5}, {11, 17, 75, 95, 66, 51, 95, 79, 10, 2, 43, 3, 85, 64, 67, 50, 32, 8, 48, 58, 35, 20, 87, 82, 40, 46, 70, 39, 46, 37, 38, 81, 87, 64, 31, 32, 
7, 14, 94, 21, 33, 75, 67, 5, 80, 80, 36, 53, 99, 93}); + Tensor rhs_elem_1_3(Range{11, 6}, {64, 8, 79, 99, 13, 5, 64, 76, 2, 81, 78, 89, 88, 89, 83, 99, 71, 50, 18, 59, 91, 100, 91, 99, 20, 54, 72, 9, 43, 21, 61, 57, 18, 80, 12, 27, 95, 31, 92, 4, 6, 59, 27, 82, 98, 32, 82, 53, 52, 8, 31, 32, 38, 63, 32, 47, 24, 86, 64, 29, 86, 46, 96, 79, 48, 58}); + matrix_il rhs_il {{rhs_elem_0_0, rhs_elem_0_1, rhs_elem_0_2, rhs_elem_0_3},{rhs_elem_1_0, rhs_elem_1_1, rhs_elem_1_2, rhs_elem_1_3}}; + TiledRange rhs_trange{{0, 2},{0, 2, 4}}; + dist_array_t rhs(world, rhs_trange, rhs_il); + Tensor corr_elem_0_0(Range{7, 2}, {825, 150, 5346, 1512, 7056, 1674, 7760, 6210, 17, 6204, 95, 486, 13, 2480}); + Tensor corr_elem_0_1(Range{8, 3}, {4510, 1911, 4740, 33, 3619, 3116, 5655, 1079, 4752, 468, 5900, 810, 1040, 2090, 16, 2550, 2460, 2720, 4600, 468, 1014, 3864, 3055, 784}); + Tensor corr_elem_1_0(Range{8, 3}, {1007, 2024, 3744, 1116, 348, 7565, 3740, 60, 200, 1748, 1224, 1040, 7623, 1105, 324, 3564, 2460, 648, 576, 966, 1560, 6375, 825, 2772}); + Tensor corr_elem_1_1(Range{9, 4}, {720, 1365, 348, 6132, 4845, 990, 2619, 4582, 112, 585, 352, 7110, 2496, 9506, 2024, 160, 2175, 594, 72, 1050, 3024, 220, 3402, 1078, 1060, 180, 1479, 3850, 2656, 5625, 240, 1400, 244, 6633, 1350, 3050}); + Tensor corr_elem_2_0(Range{9, 4}, {2223, 768, 1978, 1568, 550, 704, 9400, 2162, 170, 1804, 6468, 1518, 5796, 1932, 732, 3286, 441, 7047, 57, 2520, 2257, 4200, 8370, 4823, 1776, 1008, 280, 1449, 2058, 9700, 1173, 3652, 3744, 1786, 2904, 603}); + Tensor corr_elem_2_1(Range{10, 5}, {1023, 459, 2850, 1425, 5742, 4488, 4560, 1501, 540, 162, 258, 180, 5950, 4800, 67, 1050, 1088, 48, 3552, 1508, 175, 100, 6525, 1722, 1240, 2852, 3710, 702, 782, 518, 722, 2673, 8352, 3584, 2914, 384, 210, 196, 8836, 651, 825, 4425, 4824, 440, 5280, 7840, 2016, 4187, 1089, 4650}); + Tensor corr_elem_3_0(Range{10, 5}, {1372, 322, 52, 9016, 2310, 700, 69, 6930, 1232, 3264, 710, 434, 370, 2310, 2760, 1620, 6675, 1305, 4212, 1416, 
675, 585, 7968, 1550, 117, 6408, 770, 5580, 3920, 7980, 1935, 858, 2400, 1298, 1311, 100, 78, 891, 1026, 4664, 2937, 108, 2544, 322, 480, 4752, 752, 7700, 3348, 989}); + Tensor corr_elem_3_1(Range{11, 6}, {1728, 488, 2133, 6237, 585, 70, 5120, 1520, 146, 5994, 5772, 801, 5192, 8188, 415, 396, 5538, 1350, 954, 5546, 6370, 7400, 91, 4752, 600, 5238, 3672, 378, 3999, 1953, 4941, 5358, 1314, 5360, 276, 2646, 5510, 527, 6900, 292, 552, 944, 1593, 410, 8036, 704, 3526, 3074, 3536, 352, 837, 2208, 3002, 2646, 3168, 2256, 1872, 1548, 576, 1827, 86, 2300, 864, 790, 3936, 2262}); + matrix_il corr_il {{corr_elem_0_0, corr_elem_0_1},{corr_elem_1_0, corr_elem_1_1},{corr_elem_2_0, corr_elem_2_1},{corr_elem_3_0, corr_elem_3_1}}; + TiledRange corr_trange{{0, 2, 4},{0, 2}}; + dist_array_t corr(world, corr_trange, corr_il); + dist_array_t r0 = einsum(lhs("i,h;m,n"), rhs("h,i;m,n"), "i,h;m,n"); + dist_array_t r1; + einsum(r1("i,h;m,n"), lhs("i,h;m,n"), rhs("h,i;m,n")); + + const bool are_equal = ToTArrayFixture::are_equal(r0, r1); + BOOST_CHECK(are_equal); +} + BOOST_AUTO_TEST_SUITE_END() From 92281e4ec6f76f723bff4444e414e42d5b12102f Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Wed, 16 Jun 2021 16:25:15 -0400 Subject: [PATCH 09/12] added missing #include + cleanup --- src/TiledArray/expressions/einsum.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/TiledArray/expressions/einsum.h b/src/TiledArray/expressions/einsum.h index 99a5176d59..3ecdd9ccd7 100644 --- a/src/TiledArray/expressions/einsum.h +++ b/src/TiledArray/expressions/einsum.h @@ -6,6 +6,7 @@ #include "TiledArray/util/index.h" #include "TiledArray/util/range.h" #include "TiledArray/tiled_range1.h" +#include "TiledArray/tiled_range.h" //#include "TiledArray/util/string.h" namespace TiledArray::expressions { @@ -49,8 +50,6 @@ auto einsum( World &world) { - printf("einsum(A,B,c)\n"); - auto a = std::get<0>(idx(A)); auto b = std::get<0>(idx(B)); Index c = std::get<0>(cs); From 
8e9bdef30515f3e3c45941af0dbe9eb78aa6044d Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Tue, 31 Aug 2021 17:48:28 -0400 Subject: [PATCH 10/12] revived gemm logging --- src/TiledArray/tensor/tensor.h | 63 +++++++++++++++++++--------------- src/TiledArray/util/logger.h | 2 +- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 1a2a8666c6..948e780978 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -153,7 +153,7 @@ class Tensor { std::uninitialized_default_construct_n(ptr, size); // std::uninitialized_value_construct_n(ptr, size); } - auto deleter = [allocator = std::move(allocator), + auto deleter = [this, allocator = std::move(allocator), size](auto&& ptr) mutable { std::destroy_n(ptr, size); allocator.deallocate(ptr, size); @@ -165,7 +165,7 @@ class Tensor { } Tensor(range_type&& range, size_t batch_size, bool default_construct) - : range_(strd::move(range)), batch_size_(batch_size) { + : range_(std::move(range)), batch_size_(batch_size) { size_t size = range_.volume() * batch_size; allocator_type allocator; auto* ptr = allocator.allocate(size); @@ -176,7 +176,7 @@ class Tensor { std::uninitialized_default_construct_n(ptr, size); // std::uninitialized_value_construct_n(ptr, size); } - auto deleter = [allocator = std::move(allocator), + auto deleter = [this, allocator = std::move(allocator), size](auto&& ptr) mutable { std::destroy_n(ptr, size); allocator.deallocate(ptr, size); @@ -1700,15 +1700,16 @@ class Tensor { std::unique_ptr data_copy; size_t tile_volume; if (twostep) { - tile_volume = range().volume(); + tile_volume = range().volume() * batch_size(); data_copy = std::make_unique(tile_volume); - std::copy(pimpl_->data_, pimpl_->data_ + tile_volume, data_copy.get()); + std::copy(data_.get(), data_.get() + tile_volume, data_copy.get()); + } + for (size_t i = 0; i < this->batch_size(); ++i) { + auto Ci = this->batch(i); + 
TiledArray::gemm(alpha, A.batch(i), B.batch(i), + twostep ? numeric_type(0) : numeric_type(1), Ci, + gemm_helper); } - non_distributed::gemm(gemm_helper.left_op(), gemm_helper.right_op(), m, n, - k, factor, left.data(), lda, right.data(), ldb, - twostep ? numeric_type(0) : numeric_type(1), - pimpl_->data_, n); - if (TiledArray::TileOpsLogger::get_instance_ptr() != nullptr && TiledArray::TileOpsLogger::get_instance().gemm) { auto& logger = TiledArray::TileOpsLogger::get_instance(); @@ -1716,11 +1717,11 @@ class Tensor { return fnptr ? fnptr(arg) : arg; }; auto tformed_left_range = - apply(logger.gemm_left_range_transform, left.range()); + apply(logger.gemm_left_range_transform, A.range()); auto tformed_right_range = - apply(logger.gemm_right_range_transform, right.range()); + apply(logger.gemm_right_range_transform, B.range()); auto tformed_result_range = - apply(logger.gemm_result_range_transform, pimpl_->range_); + apply(logger.gemm_result_range_transform, this->range_); if ((!logger.gemm_result_range_filter || logger.gemm_result_range_filter(tformed_result_range)) && (!logger.gemm_left_range_filter || @@ -1738,13 +1739,13 @@ class Tensor { if (!logger.gemm_result_range_transform) logger << *this << std::endl; else - logger << make_map(pimpl_->data_, tformed_result_range) + logger << make_map(this->data_.get(), tformed_result_range) << std::endl; } else { TiledArray::TileOpsLogger::get_instance().gemm_printer( - *logger.log, tformed_left_range, left.data(), - tformed_right_range, right.data(), tformed_right_range, - pimpl_->data_); + *logger.log, tformed_left_range, A.data(), + tformed_right_range, B.data(), tformed_right_range, + this->data(), this->batch_size()); } } } @@ -1752,7 +1753,7 @@ class Tensor { if (twostep) { for (size_t v = 0; v != tile_volume; ++v) { - pimpl_->data_[v] += data_copy[v]; + this->data_.get()[v] += data_copy[v]; } } } @@ -2091,13 +2092,14 @@ template void gemm(Alpha alpha, const Tensor& A, const Tensor& B, Beta beta, Tensor& C, const 
math::GemmHelper& gemm_helper) { - // static_assert( - // !detail::is_tensor_of_tensor_v, Tensor>, - // "TA::Tensor::gemm without custom element op is only applicable to " - // "plain tensors"); + static_assert( + !detail::is_tensor_of_tensor_v, Tensor, + Tensor>, + "TA::Tensor::gemm without custom element op is only applicable to " + "plain tensors"); { - // Check that this tensor is not empty and has the correct rank - // TA_ASSERT(pimpl_); + // Check that tensor C is not empty and has the correct rank + TA_ASSERT(!C.empty()); TA_ASSERT(C.range().rank() == gemm_helper.result_rank()); // Check that the arguments are not empty and have the correct ranks @@ -2106,6 +2108,10 @@ void gemm(Alpha alpha, const Tensor& A, const Tensor& B, TA_ASSERT(!B.empty()); TA_ASSERT(B.range().rank() == gemm_helper.right_rank()); + TA_ASSERT(A.batch_size() == 1); + TA_ASSERT(B.batch_size() == 1); + TA_ASSERT(C.batch_size() == 1); + // Check that the outer dimensions of left match the corresponding // dimensions in result TA_ASSERT(ignore_tile_position() || @@ -2152,13 +2158,15 @@ void gemm(Alpha alpha, const Tensor& A, const Tensor& B, // may need to split gemm into multiply + accumulate for tracing purposes #ifdef TA_ENABLE_TILE_OPS_LOGGING { + using numeric_type = typename Tensor::numeric_type; + using T = numeric_type; const bool twostep = TiledArray::TileOpsLogger::get_instance().gemm && TiledArray::TileOpsLogger::get_instance().gemm_print_contributions; std::unique_ptr data_copy; size_t tile_volume; if (twostep) { - tile_volume = range().volume(); + tile_volume = C.range().volume(); data_copy = std::make_unique(tile_volume); std::copy(C.data(), C.data() + tile_volume, data_copy.get()); } @@ -2193,13 +2201,14 @@ void gemm(Alpha alpha, const Tensor& A, const Tensor& B, .gemm_printer) { // default printer // must use custom printer if result's range transformed if (!logger.gemm_result_range_transform) - logger << *this << std::endl; + logger << C << std::endl; else logger << 
make_map(C.data(), tformed_result_range) << std::endl; } else { TiledArray::TileOpsLogger::get_instance().gemm_printer( *logger.log, tformed_left_range, A.data(), - tformed_right_range, B.data(), tformed_right_range, C.data()); + tformed_right_range, B.data(), tformed_right_range, C.data(), + C.batch_size()); } } } diff --git a/src/TiledArray/util/logger.h b/src/TiledArray/util/logger.h index abbc505172..f33f96a35f 100644 --- a/src/TiledArray/util/logger.h +++ b/src/TiledArray/util/logger.h @@ -41,7 +41,7 @@ struct TileOpsLogger : public Singleton> { using range_filter_t = std::function; using gemm_printer_t = std::function; + const T*, const Range&, const T*, std::size_t)>; // GEMM task logging bool gemm = false; From 76f8ee1c13fb686681b647a86d4be67560832382 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Wed, 1 Sep 2021 09:54:14 -0400 Subject: [PATCH 11/12] dox++ --- src/TiledArray/tensor/tensor.h | 26 ++++++++++++++++++++++---- src/TiledArray/tiled_range.h | 24 ++++++++++++------------ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 948e780978..bcb21812a5 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -55,6 +55,9 @@ struct TraceIsDefined, enable_if_numeric_t> : std::true_type {}; /// An N-dimensional tensor object +/// A contiguous row-major tensor with shallow-copy semantics. +/// As of TiledArray 1.1 Tensor represents a batch of tensors with same Range +/// (the default batch size = 1). /// \tparam T the value type of this tensor /// \tparam A The allocator type for the data template @@ -190,14 +193,15 @@ class Tensor { /// Construct a tensor with a range equal to \c range. The data is /// uninitialized. 
/// \param range The range of the tensor + /// \param batch_size The batch size + /// \param data shared pointer to the data Tensor(const range_type& range, size_t batch_size, std::shared_ptr data) : range_(range), batch_size_(batch_size), data_(data) {} range_type range_; ///< range size_t batch_size_ = 1; - std::shared_ptr - data_; ///< Shared pointer to implementation object + std::shared_ptr data_; ///< Shared pointer to the data public: // Compiler generated functions @@ -401,8 +405,15 @@ class Tensor { } } + /// The batch size accessor + + /// @return the size of tensor batch represented by `*this` size_t batch_size() const { return this->batch_size_; } + /// @param[in] idx the batch index + /// @pre `idx < this->batch_size()` + /// @return (plain, i.e. batch_size=1) Tensor representing element \p idx of + /// the batch Tensor batch(size_t idx) const { TA_ASSERT(idx < this->batch_size()); std::shared_ptr data(this->data_, @@ -410,12 +421,19 @@ class Tensor { return Tensor(this->range(), 1, data); } + /// Returns Tensor representing the data using another range and batch size + + /// @param[in] range the Range of the result + /// @param[in] batch_size the batch size of the result + /// @return Tensor object representing `this->data()` using @p range and @p + /// batch_size auto reshape(const range_type& range, size_t batch_size = 1) const { TA_ASSERT(this->range().volume() * this->batch_size() == range.volume() * batch_size); return Tensor(range, batch_size, this->data_); } + /// @return a deep copy of `*this` Tensor clone() const { Tensor result; if (data_) { @@ -650,12 +668,12 @@ class Tensor { /// \return An iterator to the last data element iterator end() { return (this->data() ? 
this->data() + this->size() : NULL); } - /// Data direct access + /// Read-only access to the data /// \return A const pointer to the tensor data const_pointer data() const { return this->data_.get(); } - /// Data direct access + /// Mutable access to the data /// \return A const pointer to the tensor data pointer data() { return this->data_.get(); } diff --git a/src/TiledArray/tiled_range.h b/src/TiledArray/tiled_range.h index 653c107ed2..7f2a944bd1 100644 --- a/src/TiledArray/tiled_range.h +++ b/src/TiledArray/tiled_range.h @@ -27,8 +27,9 @@ namespace TiledArray { /// Range data of a tiled array -/// TiledRange is a direct (Cartesian) product of 1-dimensional tiled ranges -/// (TiledRange1) +/// TiledRange is a direct (Cartesian) product of 1-dimensional tiled ranges, +/// represented as TiledRange1 objects. Thus TiledRange is a semantically +/// contiguous (C++) range of TiledRange1 objects. class TiledRange { private: /// Constructed with a set of ranges pointed to by [ first, last ). @@ -90,8 +91,7 @@ class TiledRange { explicit TiledRange(const TRange1Range& range_of_trange1s) : range_(), elements_range_(), - ranges_(std::begin(range_of_trange1s), std::end(range_of_trange1s)) - { + ranges_(std::begin(range_of_trange1s), std::end(range_of_trange1s)) { init(); } @@ -303,24 +303,24 @@ class TiledRange { } template >::type* = nullptr> + typename std::enable_if< + madness::is_input_archive_v>::type* = nullptr> void serialize(const Archive& ar) { ar& range_& elements_range_& ranges_; } template >::type* = nullptr> + typename std::enable_if< + madness::is_output_archive_v>::type* = nullptr> void serialize(const Archive& ar) const { ar& range_& elements_range_& ranges_; } private: - range_type range_; ///< Stores information on tile indexing for the range. - range_type elements_range_; ///< Stores information on element indexing for - ///< the range. - Ranges ranges_; ///< Stores tile boundaries for each dimension. 
+ range_type range_; ///< Range of tile indices + range_type elements_range_; ///< Range of element indices + Ranges ranges_; ///< tiled (1d) range, aka TiledRange1, for each mode + ///< `*this` is a direct product of these tilings }; /// TiledRange permutation operator. From 373678e6c5c0185ec47dd88219924614f8371775 Mon Sep 17 00:00:00 2001 From: Eduard Valeyev Date: Wed, 1 Sep 2021 09:55:30 -0400 Subject: [PATCH 12/12] Define TiledRange::{const_iterator,value_type} for older boost to be able to compare it as a collection --- src/TiledArray/tiled_range.h | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/TiledArray/tiled_range.h b/src/TiledArray/tiled_range.h index 7f2a944bd1..51f3f3992c 100644 --- a/src/TiledArray/tiled_range.h +++ b/src/TiledArray/tiled_range.h @@ -70,6 +70,11 @@ class TiledRange { static_assert(std::is_same_v); typedef container::svector Ranges; + /// TiledRange is a contiguous C++ range of TiledRange1 objects + using const_iterator = typename Ranges::const_iterator; + /// TiledRange is a contiguous C++ range of TiledRange1 objects + using value_type = typename Ranges::value_type; + /// Default constructor TiledRange() : range_(), elements_range_(), ranges_() {} @@ -286,13 +291,18 @@ class TiledRange { return ranges_[d]; } - /// Tile dimension boundary array accessor + /// \return iterator pointing to the beginning of the range of TiledRange1 + /// objects + const_iterator begin() const { return ranges_.begin(); } + + /// \return iterator pointing to the end of the range of TiledRange1 objects + const_iterator end() const { return ranges_.end(); } - auto begin() const { return ranges_.begin(); } - auto end() const { return ranges_.end(); } - const auto& at(size_t idx) const { return ranges_.at(idx); } + /// \param[in] d mode index + /// \return const reference to the TiledRange1 object for mode \p d + const TiledRange1& at(size_t d) const { return ranges_.at(d); } - /// \return A reference to the array of 
Range1 objects. + /// \return A reference to the array of TiledRange1 objects. /// \throw nothing const Ranges& data() const { return ranges_; }