Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ TiledArray/dist_eval/binary_eval.h
TiledArray/dist_eval/contraction_eval.h
TiledArray/dist_eval/dist_eval.h
TiledArray/dist_eval/unary_eval.h
TiledArray/einsum/index.h
TiledArray/einsum/index.cpp
TiledArray/einsum/range.h
TiledArray/einsum/string.h
TiledArray/expressions/add_engine.h
TiledArray/expressions/add_expr.h
TiledArray/expressions/binary_engine.h
Expand Down Expand Up @@ -179,14 +183,10 @@ TiledArray/util/annotation.h
TiledArray/util/backtrace.h
TiledArray/util/bug.h
TiledArray/util/function.h
TiledArray/util/index.h
TiledArray/util/index.cpp
TiledArray/util/initializer_list.h
TiledArray/util/logger.h
TiledArray/util/random.h
TiledArray/util/range.h
TiledArray/util/singleton.h
TiledArray/util/string.h
TiledArray/util/threads.h
TiledArray/util/threads.cpp
TiledArray/util/time.h
Expand Down
102 changes: 102 additions & 0 deletions src/TiledArray/einsum/eigen.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#ifndef TILEDARRAY_EINSUM_EIGEN_H__INCLUDED
#define TILEDARRAY_EINSUM_EIGEN_H__INCLUDED

#include "TiledArray/fwd.h"
#include "TiledArray/external/eigen.h"
#include "TiledArray/einsum/index.h"
#include "TiledArray/einsum/range.h"
#include "TiledArray/einsum/string.h"

namespace Eigen {

template<typename Derived, int Options>
const Derived& derived(const TensorBase<Derived,Options> &t) {
return static_cast<const Derived&>(t);
}

template<typename TA, typename TB, typename TC>
void einsum(
std::string expr,
const Eigen::TensorBase<TA,Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB,Eigen::ReadOnlyAccessors> &B,
TC &C)
{

static_assert((TA::NumDimensions+TB::NumDimensions) >= TC::NumDimensions);
//static_assert((TA::NumDimensions+TB::NumDimensions)%2 == TC::NumDimensions%2);

using Index = TiledArray::Einsum::Index<char>;
using IndexDims = TiledArray::Einsum::IndexMap<char,size_t>;
using TiledArray::Einsum::string::split2;

auto permutation = [](auto src, auto dst) {
return TiledArray::Einsum::index::permutation(dst, src);
};

Index a,b,c;
std::string ab;
std::tie(ab,c) = split2(expr, "->");
std::tie(a,b) = split2(ab, ",");

// these are "Hadamard" (fused) indices
auto h = a & b & c;

auto e = (a ^ b);
auto he = h+e;

// contracted indices
auto i = (a & b) - h;

eigen_assert(a.size() == A.NumDimensions);
eigen_assert(b.size() == B.NumDimensions);
eigen_assert(c.size() == C.NumDimensions);
eigen_assert(he.size() == C.NumDimensions);

IndexDims dimensions = (
IndexDims(a, derived(A).dimensions()) |
IndexDims(b, derived(B).dimensions())
);

auto product = [](auto &&dims) {
int64_t n = 1;
for (auto dim : dims) { n *= dim; }
return n;
};

int64_t nh = product(dimensions[h]);
int64_t na = product(dimensions[a&e]);
int64_t nb = product(dimensions[b&e]);
int64_t ni = product(dimensions[i]);

auto pA = A.shuffle(permutation(a, h+(e&a)+i)).reshape(std::array{nh,na,ni}).eval();
auto pB = B.shuffle(permutation(b, h+(e&b)+i)).reshape(std::array{nh,nb,ni}).eval();
Eigen::Tensor<typename TC::Scalar,3> C3(nh,na,nb);

for (int64_t h = 0; h < nh; ++h) {
//Eigen::array<Eigen::IndexPair<int>, 1> axis = { Eigen::IndexPair<int>(0, 0) };
C3.chip(h,0) = pA.chip(h,0).contract(pB.chip(h,0), std::array{ std::pair{1,1} });
}

std::array<int,TC::NumDimensions> permuted_shape;
for (int k = 0; k < permuted_shape.size(); ++k) {
permuted_shape[k] = dimensions[he[k]];
}

C = C3.reshape(permuted_shape).shuffle(permutation(he, c));

}

template<typename T, typename TA, typename TB>
T einsum(
std::string expr,
const Eigen::TensorBase<TA,Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB,Eigen::ReadOnlyAccessors> &B)
{
T AB;
einsum(expr, A, B, AB);
return AB;
}

} // namespace Eigen

#endif /* TILEDARRAY_EINSUM_EIGEN_H__INCLUDED */
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// Samuel R. Powell, 2021
#include "TiledArray/util/index.h"
#include "TiledArray/einsum/index.h"
#include "TiledArray/einsum/string.h"
#include "TiledArray/util/annotation.h"
#include "TiledArray/util/string.h"

namespace TiledArray::index {
namespace TiledArray::Einsum::index {

std::vector<std::string> validate(const std::vector<std::string> &v) {
return v;
Expand All @@ -22,4 +22,4 @@ std::string join(const small_vector<std::string> &v) {
return string::join(v, ",");
}

} // namespace TiledArray::index
} // namespace TiledArray::Einsum::index
33 changes: 19 additions & 14 deletions src/TiledArray/util/index.h → src/TiledArray/einsum/index.h
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
// Samuel R. Powell, 2021
#ifndef TILEDARRAY_EINSUM_INDEX_H__INCLUDED
#define TILEDARRAY_EINSUM_INDEX_H__INCLUDED

#include "TiledArray/expressions/fwd.h"

#include <TiledArray/error.h>
#include <TiledArray/permutation.h>
#include <TiledArray/util/vector.h>
#include <TiledArray/einsum/string.h>

#include <iosfwd>
#include <string>

namespace TiledArray::index {
namespace TiledArray::Einsum::index {

template <typename T>
using small_vector = container::svector<T>;
using small_vector = TiledArray::container::svector<T>;

small_vector<std::string> tokenize(const std::string &s);

small_vector<std::string> validate(const small_vector<std::string> &v);

std::string join(const small_vector<std::string> &v);

template <typename T, typename U>
using enable_if_string = std::enable_if_t<std::is_same_v<T, std::string>, U>;

Expand All @@ -43,7 +44,7 @@ class Index {

template <typename U = void>
operator std::string() const {
return index::join(data_);
return string::join(data_, ",");
}

explicit operator bool() const { return !data_.empty(); }
Expand Down Expand Up @@ -276,21 +277,23 @@ IndexMap<K, V> operator|(const IndexMap<K, V> &a, const IndexMap<K, V> &b) {
return IndexMap(d);
}

} // namespace TiledArray::index
} // namespace TiledArray::Einsum::index

namespace TiledArray {
namespace TiledArray::Einsum {

using Index = TiledArray::index::Index<std::string>;
using TiledArray::index::IndexMap;
using TiledArray::Einsum::index::Index;
using TiledArray::Einsum::index::IndexMap;

/// converts the annotation of an expression to an Index
template <typename Array>
auto idx(const std::string &s) {
using Index = Einsum::Index<std::string>;
if constexpr (detail::is_tensor_of_tensor_v<typename Array::value_type>) {
auto semi = std::find(s.begin(), s.end(), ';');
TA_ASSERT(semi != s.end());
auto first = std::string(s.begin(), semi);
auto second = std::string(semi + 1, s.end());
auto [first,second] = string::split2(s, ";");
TA_ASSERT(!first.empty());
TA_ASSERT(!second.empty());
return std::tuple<Index, Index>{first, second};
} else {
return std::tuple<Index>{s};
Expand All @@ -299,8 +302,10 @@ auto idx(const std::string &s) {

/// converts the annotation of an expression to an Index
template <typename A, bool Alias>
auto idx(const expressions::TsrExpr<A, Alias> &e) {
auto idx(const TiledArray::expressions::TsrExpr<A, Alias> &e) {
return idx<A>(e.annotation());
}

} // namespace TiledArray
} // namespace TiledArray::Einsum

#endif /* TILEDARRAY_EINSUM_INDEX_H__INCLUDED */
14 changes: 12 additions & 2 deletions src/TiledArray/util/range.h → src/TiledArray/einsum/range.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
#ifndef TILEDARRAY_EINSUM_RANGE_H__INCLUDED
#define TILEDARRAY_EINSUM_RANGE_H__INCLUDED

#include <TiledArray/util/vector.h>

#include <vector>
#include <boost/iterator/counting_iterator.hpp>

namespace TiledArray::range {
namespace TiledArray::Einsum::range {

template<typename T>
using small_vector = container::svector<T>;
Expand Down Expand Up @@ -127,4 +130,11 @@ void cartesian_foreach(const std::vector<R>& rs, F f) {
}
}

} // namespace TiledArray::expressions
} // namespace TiledArray::Einsum::range

namespace TiledArray::Einsum {
using range::Range;
using range::RangeProduct;
}

#endif /* TILEDARRAY_EINSUM_RANGE_H__INCLUDED */
50 changes: 50 additions & 0 deletions src/TiledArray/einsum/string.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef TILEDARRAY_EINSUM_STRING_H
#define TILEDARRAY_EINSUM_STRING_H

#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/trim.hpp>
#include <boost/algorithm/string/join.hpp>
#include <string>
#include <vector>

namespace TiledArray::Einsum::string {
namespace {

// Split delimiter must match completely
template<typename T = std::string, typename U = T>
std::pair<T,U> split2(const std::string& s, const std::string &d) {
auto pos = s.find(d);
if (pos == s.npos) return { T(s), U("") };
return { T(s.substr(0,pos)), U(s.substr(pos+d.size())) };
}

// Split delimiter must match completely
std::vector<std::string> split(const std::string& s, char d) {
std::vector<std::string> res;
return boost::split(res, s, [&d](char c) { return c == d; } /*boost::is_any_of(d)*/);
}

std::string trim(const std::string& s) {
return boost::trim_copy(s);
}

template <typename T>
std::string str(const T& obj) {
std::stringstream ss;
ss << obj;
return ss.str();
}

template<typename T, typename U = std::string>
std::string join(const T &s, const U& j = U("")) {
std::vector<std::string> strings;
for (auto e : s) {
strings.push_back(str(e));
}
return boost::join(strings, j);
}

}
}

#endif //TILEDARRAY_EINSUM_STRING_H
Loading