diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bc24d9494a..f5ed90793b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -89,6 +89,10 @@ TiledArray/dist_eval/binary_eval.h TiledArray/dist_eval/contraction_eval.h TiledArray/dist_eval/dist_eval.h TiledArray/dist_eval/unary_eval.h +TiledArray/einsum/index.h +TiledArray/einsum/index.cpp +TiledArray/einsum/range.h +TiledArray/einsum/string.h TiledArray/expressions/add_engine.h TiledArray/expressions/add_expr.h TiledArray/expressions/binary_engine.h @@ -179,14 +183,10 @@ TiledArray/util/annotation.h TiledArray/util/backtrace.h TiledArray/util/bug.h TiledArray/util/function.h -TiledArray/util/index.h -TiledArray/util/index.cpp TiledArray/util/initializer_list.h TiledArray/util/logger.h TiledArray/util/random.h -TiledArray/util/range.h TiledArray/util/singleton.h -TiledArray/util/string.h TiledArray/util/threads.h TiledArray/util/threads.cpp TiledArray/util/time.h diff --git a/src/TiledArray/einsum/eigen.h b/src/TiledArray/einsum/eigen.h new file mode 100644 index 0000000000..2a47ba1b29 --- /dev/null +++ b/src/TiledArray/einsum/eigen.h @@ -0,0 +1,102 @@ +#ifndef TILEDARRAY_EINSUM_EIGEN_H__INCLUDED +#define TILEDARRAY_EINSUM_EIGEN_H__INCLUDED + +#include "TiledArray/fwd.h" +#include "TiledArray/external/eigen.h" +#include "TiledArray/einsum/index.h" +#include "TiledArray/einsum/range.h" +#include "TiledArray/einsum/string.h" + +namespace Eigen { + +template +const Derived& derived(const TensorBase &t) { + return static_cast(t); +} + +template +void einsum( + std::string expr, + const Eigen::TensorBase &A, + const Eigen::TensorBase &B, + TC &C) +{ + + static_assert((TA::NumDimensions+TB::NumDimensions) >= TC::NumDimensions); + //static_assert((TA::NumDimensions+TB::NumDimensions)%2 == TC::NumDimensions%2); + + using Index = TiledArray::Einsum::Index; + using IndexDims = TiledArray::Einsum::IndexMap; + using TiledArray::Einsum::string::split2; + + auto permutation = [](auto src, auto dst) { + return TiledArray::Einsum::index::permutation(dst, src); + }; + + Index a,b,c; + std::string ab; + std::tie(ab,c) = split2(expr, "->"); + std::tie(a,b) = split2(ab, ","); + + // these are "Hadamard" (fused) indices + auto h = a & b & c; + + auto e = (a ^ b); + auto he = h+e; + + // contracted indices + auto i = (a & b) - h; + + eigen_assert(a.size() == A.NumDimensions); + eigen_assert(b.size() == B.NumDimensions); + eigen_assert(c.size() == C.NumDimensions); + eigen_assert(he.size() == C.NumDimensions); + + IndexDims dimensions = ( + IndexDims(a, derived(A).dimensions()) | + IndexDims(b, derived(B).dimensions()) + ); + + auto product = [](auto &&dims) { + int64_t n = 1; + for (auto dim : dims) { n *= dim; } + return n; + }; + + int64_t nh = product(dimensions[h]); + int64_t na = product(dimensions[a&e]); + int64_t nb = product(dimensions[b&e]); + int64_t ni = product(dimensions[i]); + + auto pA = A.shuffle(permutation(a, h+(e&a)+i)).reshape(std::array{nh,na,ni}).eval(); + auto pB = B.shuffle(permutation(b, h+(e&b)+i)).reshape(std::array{nh,nb,ni}).eval(); + Eigen::Tensor C3(nh,na,nb); + + for (int64_t h = 0; h < nh; ++h) { + //Eigen::array, 1> axis = { Eigen::IndexPair(0, 0) }; + C3.chip(h,0) = pA.chip(h,0).contract(pB.chip(h,0), std::array{ std::pair{1,1} }); + } + + std::array permuted_shape; + for (int k = 0; k < permuted_shape.size(); ++k) { + permuted_shape[k] = dimensions[he[k]]; + } + + C = C3.reshape(permuted_shape).shuffle(permutation(he, c)); + +} + +template +T einsum( + std::string expr, + const Eigen::TensorBase &A, + const Eigen::TensorBase &B) +{ + T AB; + einsum(expr, A, B, AB); + return AB; +} + +} // namespace Eigen + +#endif /* TILEDARRAY_EINSUM_EIGEN_H__INCLUDED */ diff --git a/src/TiledArray/util/index.cpp b/src/TiledArray/einsum/index.cpp similarity index 78% rename from src/TiledArray/util/index.cpp rename to src/TiledArray/einsum/index.cpp index 28f74f491a..e8bb636f54 100644 --- a/src/TiledArray/util/index.cpp +++ b/src/TiledArray/einsum/index.cpp @@ -1,9 +1,9 @@ // Samuel R. Powell, 2021 -#include "TiledArray/util/index.h" +#include "TiledArray/einsum/index.h" +#include "TiledArray/einsum/string.h" #include "TiledArray/util/annotation.h" -#include "TiledArray/util/string.h" -namespace TiledArray::index { +namespace TiledArray::Einsum::index { std::vector validate(const std::vector &v) { return v; @@ -22,4 +22,4 @@ std::string join(const small_vector &v) { return string::join(v, ","); } -} // namespace TiledArray::index +} // namespace TiledArray::Einsum::index diff --git a/src/TiledArray/util/index.h b/src/TiledArray/einsum/index.h similarity index 91% rename from src/TiledArray/util/index.h rename to src/TiledArray/einsum/index.h index 3652da6e66..ee00425964 100644 --- a/src/TiledArray/util/index.h +++ b/src/TiledArray/einsum/index.h @@ -1,24 +1,25 @@ -// Samuel R. Powell, 2021 +#ifndef TILEDARRAY_EINSUM_INDEX_H__INCLUDED +#define TILEDARRAY_EINSUM_INDEX_H__INCLUDED + #include "TiledArray/expressions/fwd.h" #include #include #include +#include #include #include -namespace TiledArray::index { +namespace TiledArray::Einsum::index { template -using small_vector = container::svector; +using small_vector = TiledArray::container::svector; small_vector tokenize(const std::string &s); small_vector validate(const small_vector &v); -std::string join(const small_vector &v); - template using enable_if_string = std::enable_if_t, U>; @@ -43,7 +44,7 @@ class Index { template operator std::string() const { - return index::join(data_); + return string::join(data_, ","); } explicit operator bool() const { return !data_.empty(); } @@ -276,21 +277,23 @@ IndexMap operator|(const IndexMap &a, const IndexMap &b) { return IndexMap(d); } -} // namespace TiledArray::index +} // namespace TiledArray::Einsum::index -namespace TiledArray { +namespace TiledArray::Einsum { -using Index = TiledArray::index::Index; -using TiledArray::index::IndexMap; +using TiledArray::Einsum::index::Index; +using TiledArray::Einsum::index::IndexMap; /// converts the annotation of an expression to an Index template auto idx(const std::string &s) { + using Index = Einsum::Index; if constexpr (detail::is_tensor_of_tensor_v) { auto semi = std::find(s.begin(), s.end(), ';'); TA_ASSERT(semi != s.end()); - auto first = std::string(s.begin(), semi); - auto second = std::string(semi + 1, s.end()); + auto [first,second] = string::split2(s, ";"); + TA_ASSERT(!first.empty()); + TA_ASSERT(!second.empty()); return std::tuple{first, second}; } else { return std::tuple{s}; @@ -299,8 +302,10 @@ auto idx(const std::string &s) { /// converts the annotation of an expression to an Index template -auto idx(const expressions::TsrExpr &e) { +auto idx(const TiledArray::expressions::TsrExpr &e) { return idx(e.annotation()); } -} // namespace TiledArray +} // namespace TiledArray::Einsum + +#endif /* TILEDARRAY_EINSUM_INDEX_H__INCLUDED */ diff --git a/src/TiledArray/util/range.h b/src/TiledArray/einsum/range.h similarity index 90% rename from src/TiledArray/util/range.h rename to src/TiledArray/einsum/range.h index df2f8377b3..dc5accf0cb 100644 --- a/src/TiledArray/util/range.h +++ b/src/TiledArray/einsum/range.h @@ -1,9 +1,12 @@ +#ifndef TILEDARRAY_EINSUM_RANGE_H__INCLUDED +#define TILEDARRAY_EINSUM_RANGE_H__INCLUDED + #include #include #include -namespace TiledArray::range { +namespace TiledArray::Einsum::range { template using small_vector = container::svector; @@ -127,4 +130,11 @@ void cartesian_foreach(const std::vector& rs, F f) { } } -} // namespace TiledArray::expressions +} // namespace TiledArray::Einsum::range + +namespace TiledArray::Einsum { + using range::Range; + using range::RangeProduct; +} + +#endif /* TILEDARRAY_EINSUM_RANGE_H__INCLUDED */ diff --git a/src/TiledArray/einsum/string.h b/src/TiledArray/einsum/string.h new file mode 100644 index 0000000000..b2998a4f7e --- /dev/null +++ b/src/TiledArray/einsum/string.h @@ -0,0 +1,50 @@ +#ifndef TILEDARRAY_EINSUM_STRING_H +#define TILEDARRAY_EINSUM_STRING_H + +#include +#include +#include +#include +#include + +namespace TiledArray::Einsum::string { +namespace { + + // Split delimiter must match completely + template + std::pair split2(const std::string& s, const std::string &d) { + auto pos = s.find(d); + if (pos == s.npos) return { T(s), U("") }; + return { T(s.substr(0,pos)), U(s.substr(pos+d.size())) }; + } + + // Split delimiter must match completely + std::vector split(const std::string& s, char d) { + std::vector res; + return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/); + } + + std::string trim(const std::string& s) { + return boost::trim_copy(s); + } + + template + std::string str(const T& obj) { + std::stringstream ss; + ss << obj; + return ss.str(); + } + + template + std::string join(const T &s, const U& j = U("")) { + std::vector strings; + for (auto e : s) { + strings.push_back(str(e)); + } + return boost::join(strings, j); + } + +} +} + +#endif //TILEDARRAY_EINSUM_STRING_H diff --git a/src/TiledArray/einsum/tiledarray.h b/src/TiledArray/einsum/tiledarray.h new file mode 100644 index 0000000000..50afa01152 --- /dev/null +++ b/src/TiledArray/einsum/tiledarray.h @@ -0,0 +1,214 @@ +#ifndef TILEDARRAY_EINSUM_H__INCLUDED +#define TILEDARRAY_EINSUM_H__INCLUDED + +#include "TiledArray/fwd.h" +#include "TiledArray/expressions/fwd.h" +#include "TiledArray/einsum/index.h" +#include "TiledArray/einsum/range.h" +#include "TiledArray/tiled_range1.h" +#include "TiledArray/tiled_range.h" +//#include "TiledArray/util/string.h" + +namespace TiledArray::expressions { + +/// einsum function without result indices assumes every index present +/// in both @p A and @p B is contracted, or, if there are no free indices, +/// pure Hadamard product is performed. +/// @param[in] A first argument to the product +/// @param[in] B second argument to the product +/// @warning just as in the plain expression code, reductions are a special +/// case; use Expr::reduce() +template +auto einsum(expressions::TsrExpr A, expressions::TsrExpr B) { + //printf("einsum(A,B)\n"); + auto a = std::get<0>(idx(A)); + auto b = std::get<0>(idx(B)); + return einsum(A, B, std::string(a^b)); +} + +/// einsum function with result indices explicitly specified +/// @param[in] A first argument to the product +/// @param[in] B second argument to the product +/// @param[in] r result indices +/// @warning just as in the plain expression code, reductions are a special +/// case; use Expr::reduce() +template +auto einsum( + expressions::TsrExpr A, + expressions::TsrExpr B, + const std::string &cs, + World &world = get_default_world()) +{ + static_assert(std::is_same::value); + using E = expressions::TsrExpr; + return einsum(E(A), E(B), Einsum::idx(cs), world); +} + +template +auto einsum( + expressions::TsrExpr A, + expressions::TsrExpr B, + std::tuple,Indices...> cs, + World &world) +{ + + using Array = std::remove_cv_t; + + auto a = std::get<0>(Einsum::idx(A)); + auto b = std::get<0>(Einsum::idx(B)); + Einsum::Index c = std::get<0>(cs); + + struct { std::string a, b, c; } inner; + if constexpr (std::tuple_size::value == 2) { + inner.a = ";" + (std::string)std::get<1>(Einsum::idx(A)); + inner.b = ";" + (std::string)std::get<1>(Einsum::idx(B)); + inner.c = ";" + (std::string)std::get<1>(cs); + } + + // these are "Hadamard" (fused) indices + auto h = a & b & c; + + // no Hadamard indices => standard contraction (or even outer product) + // same a, b, and c => pure Hadamard + if (!h || (!(a ^ b) && !(b ^ c))) { + Array C; + C(std::string(c) + inner.c) = A*B; + return C; + } + + auto e = (a ^ b); + // contracted indices + auto i = (a & b) - h; + + TA_ASSERT(e); + TA_ASSERT(h); + + using Einsum::index::small_vector; + using Range = Einsum::Range; + using RangeProduct = Einsum::RangeProduct >; + + using RangeMap = Einsum::IndexMap; + auto range_map = ( + RangeMap(a, A.array().trange()) | + RangeMap(b, B.array().trange()) + ); + + using TiledArray::Permutation; + using Einsum::index::permutation; + + struct Term { + Array array; + Einsum::Index idx; + Permutation permutation; + RangeProduct tiles; + Array local; + std::string expr; + }; + + Term AB[2] = { { A.array(), a }, { B.array(), b } }; + + for (auto &term : AB) { + auto ei = (e+i & term.idx); + if (term.idx != h+ei) { + term.permutation = permutation(term.idx, h+ei); + } + term.expr = ei; + } + + Term C = { Array(world, TiledRange(range_map[c])), c }; + for (auto idx : e) { + C.tiles *= Range(range_map[idx].tiles_range()); + } + if (C.idx != h+e) { + C.permutation = permutation(h+e, C.idx); + } + C.expr = e; + + AB[0].expr += inner.a; + AB[1].expr += inner.b; + C.expr += inner.c; + + struct { + RangeProduct tiles; + std::vector< std::vector > batch; + } H; + + for (auto idx : h) { + H.tiles *= Range(range_map[idx].tiles_range()); + H.batch.push_back({}); + for (auto r : range_map[idx]) { + H.batch.back().push_back(Range{r}.size()); + } + } + + for (auto &term : AB) { + auto ei = (e+i & term.idx); + term.local = Array(world, TiledRange(range_map[ei])); + for (auto idx : ei) { + term.tiles *= Range(range_map[idx].tiles_range()); + } + } + + // iterates over tiles of hadamard indices + using Index = Einsum::Index; + for (Index h : H.tiles) { + size_t batch = 1; + for (size_t i = 0; i < h.size(); ++i) { + batch *= H.batch[i].at(h[i]); + } + for (auto &term : AB) { + term.local = Array(term.local.world(), term.local.trange()); + const Permutation &P = term.permutation; + for (Index ei : term.tiles) { + auto tile = term.array.find(apply_inverse(P, h+ei)).get(); + if (P) tile = tile.permute(P); + auto shape = term.local.trange().tile(ei); + tile = tile.reshape(shape, batch); + term.local.set(ei, tile); + } + } + auto& [A,B] = AB; + C.local(C.expr) = A.local(A.expr) * B.local(B.expr); + const Permutation &P = C.permutation; + for (Index e : C.tiles) { + auto c = apply(P, h+e); + auto shape = C.array.trange().tile(c); + shape = apply_inverse(P, shape); + auto tile = C.local.find(e).get(); + assert(tile.batch_size() == batch); + tile = tile.reshape(shape); + if (P) tile = tile.permute(P); + C.array.set(c, tile); + } + } + + return C.array; + +} + +} // namespace TiledArray::expressions + +namespace TiledArray { + +using expressions::einsum; + +template +auto einsum( + const std::string &expr, + const DistArray &A, + const DistArray &B, + World &world = get_default_world()) +{ + namespace string = Einsum::string; + auto [lhs,rhs] = string::split2(expr, "->"); + auto [a,b] = string::split2(lhs,","); + return einsum( + A(string::join(a,",")), + B(string::join(b,",")), + string::join(rhs,",") + ); +} + +} + +#endif /* TILEDARRAY_EINSUM_H__INCLUDED */ diff --git a/src/TiledArray/expressions/einsum.h b/src/TiledArray/expressions/einsum.h index 1ef190d169..247a5769cc 100644 --- a/src/TiledArray/expressions/einsum.h +++ b/src/TiledArray/expressions/einsum.h @@ -1,183 +1,12 @@ -#ifndef TILEDARRAY_EINSUM_H__INCLUDED -#define TILEDARRAY_EINSUM_H__INCLUDED +#ifndef TILEDARRAY_EXPRESSIONS_EINSUM_H__INCLUDED +#define TILEDARRAY_EXPRESSIONS_EINSUM_H__INCLUDED -#include "TiledArray/fwd.h" -#include "TiledArray/expressions/fwd.h" -#include "TiledArray/util/index.h" -#include "TiledArray/util/range.h" -#include "TiledArray/tiled_range1.h" -#include "TiledArray/tiled_range.h" -//#include "TiledArray/util/string.h" +#include "TiledArray/einsum/tiledarray.h" namespace TiledArray::expressions { -/// einsum function without result indices assumes every index present -/// in both @p A and @p B is contracted, or, if there are no free indices, -/// pure Hadamard product is performed. -/// @param[in] A first argument to the product -/// @param[in] B second argument to the product -/// @warning just as in the plain expression code, reductions are a special -/// case; use Expr::reduce() -template -auto einsum(TsrExpr A, TsrExpr B) { - //printf("einsum(A,B)\n"); - auto a = std::get<0>(idx(A)); - auto b = std::get<0>(idx(B)); - return einsum(A, B, std::string(a^b)); -} - -/// einsum function with result indices explicitly specified -/// @param[in] A first argument to the product -/// @param[in] B second argument to the product -/// @param[in] r result indices -/// @warning just as in the plain expression code, reductions are a special -/// case; use Expr::reduce() -template -auto einsum( - TsrExpr A, TsrExpr B, - const std::string &cs, - World &world = get_default_world()) -{ - static_assert(std::is_same::value); - using E = TsrExpr; - return einsum(E(A), E(B), idx(cs), world); -} - -template -auto einsum( - TsrExpr A, TsrExpr B, - std::tuple cs, - World &world) -{ - using Array = std::remove_cv_t; - - auto a = std::get<0>(idx(A)); - auto b = std::get<0>(idx(B)); - Index c = std::get<0>(cs); - - struct { std::string a, b, c; } inner; - if constexpr (std::tuple_size::value == 2) { - inner.a = ";" + (std::string)std::get<1>(idx(A)); - inner.b = ";" + (std::string)std::get<1>(idx(B)); - inner.c = ";" + (std::string)std::get<1>(cs); - } - - // these are "Hadamard" (fused) indices - auto h = a & b & c; - - // no Hadamard indices => standard contraction (or even outer product) - // same a, b, and c => pure Hadamard - if (!h || (!(a ^ b) && !(b ^ c))) { - Array C; - C(std::string(c) + inner.c) = A*B; - return C; - } - - auto e = (a ^ b); - // contracted indices - auto i = (a & b) - h; - - TA_ASSERT(e); - TA_ASSERT(h); - - using range::Range; - using RangeProduct = range::RangeProduct >; - - using RangeMap = IndexMap; - auto range_map = ( - RangeMap(a, A.array().trange()) | - RangeMap(b, B.array().trange()) - ); - - using TiledArray::Permutation; - using TiledArray::index::permutation; - - struct Term { - Array array; - Index idx; - Permutation permutation; - RangeProduct tiles; - Array local; - std::string expr; - }; - - Term AB[2] = { { A.array(), a }, { B.array(), b } }; - - for (auto &term : AB) { - auto ei = (e+i & term.idx); - term.local = Array(world, TiledRange(range_map[ei])); - for (auto idx : ei) { - term.tiles *= Range(range_map[idx].tiles_range()); - } - if (term.idx != h+ei) { - term.permutation = permutation(term.idx, h+ei); - } - term.expr = ei; - } - - Term C = { Array(world, TiledRange(range_map[c])), c }; - for (auto idx : e) { - C.tiles *= Range(range_map[idx].tiles_range()); - } - if (C.idx != h+e) { - C.permutation = permutation(h+e, C.idx); - } - C.expr = e; - - AB[0].expr += inner.a; - AB[1].expr += inner.b; - C.expr += inner.c; - - struct { - RangeProduct tiles; - std::vector< std::vector > batch; - } H; - - for (auto idx : h) { - H.tiles *= Range(range_map[idx].tiles_range()); - H.batch.push_back({}); - for (auto r : range_map[idx]) { - H.batch.back().push_back(Range{r}.size()); - } - } - - // iterates over tiles of hadamard indices - using Index = index::Index; - for (Index h : H.tiles) { - size_t batch = 1; - for (size_t i = 0; i < h.size(); ++i) { - batch *= H.batch[i].at(h[i]); - } - for (auto &term : AB) { - term.local = Array(term.local.world(), term.local.trange()); - const Permutation &P = term.permutation; - for (Index ei : term.tiles) { - auto tile = term.array.find(apply_inverse(P, h+ei)).get(); - if (P) tile = tile.permute(P); - auto shape = term.local.trange().tile(ei); - tile = tile.reshape(shape, batch); - term.local.set(ei, tile); - } - } - auto& [A,B] = AB; - C.local(C.expr) = A.local(A.expr) * B.local(B.expr); - const Permutation &P = C.permutation; - for (Index e : C.tiles) { - auto c = apply(P, h+e); - auto shape = C.array.trange().tile(c); - shape = apply_inverse(P, shape); - auto tile = C.local.find(e).get(); - assert(tile.batch_size() == batch); - tile = tile.reshape(shape); - if (P) tile = tile.permute(P); - C.array.set(c, tile); - } - } - - return C.array; - -} + using TiledArray::einsum; } // namespace TiledArray::expressions -#endif /* TILEDARRAY_EINSUM_H__INCLUDED */ +#endif /* TILEDARRAY_EXPRESSIONS_EINSUM_H__INCLUDED */ diff --git a/src/TiledArray/util/string.h b/src/TiledArray/util/string.h deleted file mode 100644 index 545da97259..0000000000 --- a/src/TiledArray/util/string.h +++ /dev/null @@ -1,41 +0,0 @@ -// -// Created by Samuel R. Powell on 4/15/21. -// -#pragma once - -#ifndef TILEDARRAY_STRING_H -#define TILEDARRAY_STRING_H - -#include -#include -#include -#include -#include - -namespace TiledArray::string { - - // Split delimiter must match completely - std::vector split(const std::string& s, char d) { - std::vector res; - return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/); - } - - std::string trim(const std::string& s) { - return boost::trim_copy(s); - } - - template - std::string join(const T &s, const std::string& j = "") { - return boost::join(s, j); - } - - template - std::string str(const T& obj) { - std::stringstream ss; - ss << obj; - return ss.str(); - } - -} - -#endif //TILEDARRAY_STRING_H diff --git a/tests/einsum.cpp b/tests/einsum.cpp index 5ea2f42151..d3cd7fe86e 100644 --- a/tests/einsum.cpp +++ b/tests/einsum.cpp @@ -16,9 +16,11 @@ * along with this program. If not, see . * */ + +#include "unit_test_config.h" + #include "TiledArray/expressions/einsum.h" #include "tiledarray.h" -#include "unit_test_config.h" #include "tot_array_fixture.h" #include "TiledArray/expressions/contraction_helpers.h" @@ -282,4 +284,317 @@ BOOST_AUTO_TEST_CASE(xxx) { } +BOOST_AUTO_TEST_SUITE_END() + +#include "TiledArray/einsum/eigen.h" + +template +bool isApprox( + const Eigen::TensorBase &A, + const Eigen::TensorBase &B) +{ + Eigen::Tensor r = (derived(A) == derived(B)).all(); + return r.coeffRef(); +} + +// Eigen einsum expressions +BOOST_AUTO_TEST_SUITE(einsum_eigen, TA_UT_LABEL_SERIAL) + +template +auto random(Args ... args) { + Eigen::Tensor t(args...); + t.setRandom(); + return t; +} + +template +void einsum_eigen_contract_check( + Eigen::Tensor A, + Eigen::Tensor B, + std::string expr, + std::array< std::pair, NI> i, + std::array p) +{ + using Eigen::einsum; + using std::tuple; + static_assert(NC == NA+NB-2*NI); + auto C = Eigen::einsum< Eigen::Tensor >(expr, A, B); + BOOST_CHECK(isApprox(C, A.contract(B,i).shuffle(p))); +} + +template +void einsum_eigen_hadamard_check( + Eigen::Tensor A, + Eigen::Tensor B, + std::string expr, + std::array< std::pair, NI> i, // contracted pairs + std::array< std::pair, 1> h, // hadamard pairs + std::array p) +{ + static_assert(NC == NA+NB-2*NI-1); + auto [ha,hb] = h[0]; + size_t nh = A.dimension(ha); + size_t hc = ha; + // decrement internal indices above h-index + for (auto& [ia,ib] : i) { + if (ia < ha) --hc; + if (ia > ha) ia -= 1; + if (ib > hb) ib -= 1; + } + // shuffle C to A*B order + std::vector p_ab(p.size()); + for (size_t i = 0; i < p.size(); ++i) { + p_ab.at(p[i]) = i; + } + auto C = Eigen::einsum< Eigen::Tensor >(expr, A, B); + // validate h-dims + eigen_assert(nh == A.dimension(ha)); + eigen_assert(nh == B.dimension(hb)); + eigen_assert(nh == C.dimension(p_ab.at(hc))); + // validate + for (size_t h = 0; h < nh; ++h) { + auto ah = A.chip(h,ha); + auto bh = B.chip(h,hb); + auto result = C.shuffle(p_ab).chip(h,hc); + auto expected = ah.contract(bh,i); + BOOST_CHECK(isApprox(result, expected)); + } +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_ak_bk_ab) { + using std::array; + using std::pair; + einsum_eigen_contract_check( + random(11,7), + random(13,7), + "ak,bk->ab", + array{ pair{1,1} }, + array{ 0, 1 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_ka_bk_ba) { + using std::array; + using std::pair; + einsum_eigen_contract_check( + random(7,11), + random(13,7), + "ka,bk->ba", + array{ pair{0,1} }, + array{ 1, 0 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_abi_cdi_cdab) { + using std::array; + using std::pair; + einsum_eigen_contract_check( + random(21,22,3), + random(24,25,3), + "abi,cdi->cdab", + array{ pair{2,2} }, + array{ 2, 3, 0, 1 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_icd_ai_abcd) { + using std::array; + using std::pair; + einsum_eigen_contract_check( + random(3,12,13), + random(14,15,3), + "icd,bai->abcd", + array{ pair{0,2} }, + array{ 3, 2, 0, 1 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_cdji_ibja_abcd) { + using std::array; + using std::pair; + einsum_eigen_contract_check( + random(14,15,3,5), + random(5,12,3,13), + "cdji,ibja->abcd", + array{ pair{3,0}, pair{2,2} }, + array{ 3, 2, 0, 1 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_hai_hbi_hab) { + using std::array; + using std::pair; + einsum_eigen_hadamard_check( + random(7,14,3), + random(7,15,3), + "hai,hbi->hab", + array{ pair{2,2} }, + array{ pair{ 0, 0 } }, + array{ 0, 1, 2 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_iah_hib_bha) { + using std::array; + using std::pair; + einsum_eigen_hadamard_check( + random(7,14,3), + random(3,7,15), + "iah,hib->bha", + array{ pair{0,1} }, + array{ pair{ 2, 0 } }, + array{ 2, 1, 0 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_iah_hib_abh) { + using std::array; + using std::pair; + einsum_eigen_hadamard_check( + random(7,14,3), + random(3,7,15), + "iah,hib->abh", + array{ pair{0,1} }, + array{ pair{ 2, 0 } }, + array{ 0, 2, 1 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_hi_hi_h) { + using std::array; + using std::pair; + einsum_eigen_hadamard_check( + random(7,14), + random(7,14), + "hi,hi->h", + array{ pair{1,1} }, + array{ pair{ 0, 0 } }, + array{ 0 } + ); +} + +BOOST_AUTO_TEST_CASE(einsum_eigen_hji_jih_hj) { + using std::array; + using T = int; + Eigen::Index nh = 3, nj = 5, ni = 4; + auto A = random(nh,nj,ni); + auto B = random(nj,ni,nh); + using R = Eigen::Tensor; + R result = einsum("hji,jih->hj", A, B); + R reference = einsum< Eigen::Tensor >( + "hi,hi->h", + A.shuffle(array{0,1,2}).reshape(array{nh*nj,ni}), + B.shuffle(array{2,0,1}).reshape(array{nh*nj,ni}) + ).reshape(array{nh,nj}); + BOOST_CHECK(isApprox(reference, result)); +} + +BOOST_AUTO_TEST_SUITE_END() + +// TiledArray einsum expressions +BOOST_AUTO_TEST_SUITE(einsum_tiledarray, TA_UT_LABEL_SERIAL) + +template, typename ... Args> +auto random(Args ... args) { + TiledArray::TiledRange tr{ {0, args}... }; + auto& world = TiledArray::get_default_world(); + TiledArray::DistArray t(world,tr); + t.fill_random(); + return t; +} + +template +void einsum_tiledarray_check( + const TiledArray::DistArray &A, + const TiledArray::DistArray &B, + std::string expr) +{ + using Eigen::Tensor; + using U = typename T::value_type; + using TC = Tensor; + auto result = einsum(expr, A, B); + BOOST_CHECK( + isApprox( + array_to_eigen_tensor(result), + einsum( + expr, + array_to_eigen_tensor>(A), + array_to_eigen_tensor>(B) + ) + ) + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_ak_bk_ab) { + einsum_tiledarray_check<2,2,2>( + random(11,7), + random(13,7), + "ak,bk->ab" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_ka_bk_ba) { + einsum_tiledarray_check<2,2,2>( + random(7,11), + random(13,7), + "ka,bk->ba" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_abi_cdi_cdab) { + einsum_tiledarray_check<3,3,4>( + random(21,22,3), + random(24,25,3), + "abi,cdi->cdab" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_icd_ai_abcd) { + einsum_tiledarray_check<3,3,4>( + random(3,12,13), + random(14,15,3), + "icd,bai->abcd" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_cdji_ibja_abcd) { + einsum_tiledarray_check<4,4,4>( + random(14,15,3,5), + random(5,12,3,13), + "cdji,ibja->abcd" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_hai_hbi_hab) { + einsum_tiledarray_check<3,3,3>( + random(7,14,3), + random(7,15,3), + "hai,hbi->hab" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_bha) { + einsum_tiledarray_check<3,3,3>( + random(7,14,3), + random(3,7,15), + "iah,hib->bha" + ); +} + +BOOST_AUTO_TEST_CASE(einsum_tiledarray_iah_hib_abh) { + einsum_tiledarray_check<3,3,3>( + random(7,14,3), + random(3,7,15), + "iah,hib->abh" + ); +} + +// BOOST_AUTO_TEST_CASE(einsum_tiledarray_hi_hi_h) { +// einsum_tiledarray_check<2,2,1>( +// random(7,14), +// random(7,14), +// "hi,hi->h" +// ); +// } + BOOST_AUTO_TEST_SUITE_END()