Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions src/TiledArray/conversions/foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,9 +383,9 @@ inline std::
/// Policy::shape_type::value_type and used to construct the shape of the
/// result, whereas in the former case the shape of the result is computed from
/// the shapes of the DistArray arguments (e.g. assigned to the shape of the
/// first DistArray argument). \note \c foreach/foreach_inplace are collective,
/// with sparse variants synchronizing due to the need to compute and replicate
/// shapes.
/// first DistArray argument).
/// \note \c foreach/foreach_inplace are collective, with sparse variants
/// synchronizing due to the need to compute and replicate shapes.

/// @{

Expand Down Expand Up @@ -459,10 +459,12 @@ inline std::enable_if_t<is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
/// \param op The mutating tile function
/// \param fence A flag that indicates fencing behavior. If \c true this
/// function will fence before data is modified.
/// \warning This function fences by default to avoid data race conditions.
/// \warning
/// - This function fences by default to avoid data race conditions.
/// Only disable the fence if you can ensure, the data is not being read by
/// another thread.
/// \warning If there is a another copy of \c arg that was created via (or
/// \warning
/// - If there is a another copy of \c arg that was created via (or
/// arg was created by) the \c Array copy constructor or copy assignment
/// operator, this function will modify the data of that array since the data
/// of a tile is held in a \c std::shared_ptr. If you need to ensure other
Expand Down Expand Up @@ -509,12 +511,16 @@ inline std::enable_if_t<is_dense_v<Policy>, void> foreach_inplace(
/// const Tile& arg_tile);
/// \endcode
/// where in the case of standard Policy (i.e. SparsePolicy) the return value of
/// \c op is the 2-norm (Frobenius norm) of the result tile. \note This function
/// should not be used to initialize the tiles of an array object. \tparam
/// ResultTile The tile type of the result \tparam Tile The tile type of \c arg
/// \c op is the 2-norm (Frobenius norm) of the result tile.
/// \note This function should not be used to initialize the tiles of an array
/// object.
/// \tparam ResultTile The tile type of the result
/// \tparam Tile The tile type of \c arg
/// \tparam Policy The policy type of \c arg; \c is_dense_v<Policy> must be
/// false \tparam Op Tile operation \param arg The argument array \param op The
/// tile function
/// false
/// \tparam Op Tile operation
/// \param arg The argument array
/// \param op The tile function
template <typename ResultTile, typename ArgTile, typename Policy, typename Op,
typename = typename std::enable_if<
!std::is_same<ResultTile, ArgTile>::value>::type>
Expand Down Expand Up @@ -567,10 +573,12 @@ inline std::enable_if_t<!is_dense_v<Policy>, DistArray<Tile, Policy>> foreach (
/// \param op The mutating tile function
/// \param fence A flag that indicates fencing behavior. If \c true this
/// function will fence before data is modified.
/// \warning This function fences by default to avoid data race conditions.
/// \warning
/// - This function fences by default to avoid data race conditions.
/// Only disable the fence if you can ensure, the data is not being read by
/// another thread.
/// \warning If there is a another copy of \c arg that was created via (or
/// \warning
/// - If there is a another copy of \c arg that was created via (or
/// arg was created by) the \c Array copy constructor or copy assignment
/// operator, this function will modify the data of that array since the data
/// of a tile is held in a \c std::shared_ptr. If you need to ensure other
Expand Down
103 changes: 63 additions & 40 deletions src/TiledArray/einsum/eigen.h
Original file line number Diff line number Diff line change
@@ -1,48 +1,68 @@
#ifndef TILEDARRAY_EINSUM_EIGEN_H__INCLUDED
#define TILEDARRAY_EINSUM_EIGEN_H__INCLUDED

#include "TiledArray/fwd.h"
#include "TiledArray/external/eigen.h"
#include "TiledArray/einsum/index.h"
#include "TiledArray/einsum/range.h"
#include "TiledArray/einsum/string.h"
#include "TiledArray/external/eigen.h"
#include "TiledArray/fwd.h"

namespace Eigen {

template<typename Derived, int Options>
const Derived& derived(const TensorBase<Derived,Options> &t) {
return static_cast<const Derived&>(t);
template <typename Derived, int Options>
const Derived &derived(const TensorBase<Derived, Options> &t) {
return static_cast<const Derived &>(t);
}

template<typename TA, typename TB, typename TC>
void einsum(
std::string expr,
const Eigen::TensorBase<TA,Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB,Eigen::ReadOnlyAccessors> &B,
TC &C)
{

static_assert((TA::NumDimensions+TB::NumDimensions) >= TC::NumDimensions);
//static_assert((TA::NumDimensions+TB::NumDimensions)%2 == TC::NumDimensions%2);
/// Evaluates a binary tensor product specified by a string that follows
/// the
/// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
/// format
/// \tparam TA a class derived from Eigen::TensorBase, e.g. Eigen::Tensor
/// \tparam TB a class derived from Eigen::TensorBase, e.g. Eigen::Tensor
/// \tparam TC a class derived from Eigen::TensorBase, e.g. Eigen::Tensor
/// \param expr string specification of the tensor product;
/// follows the _explicit_
/// [numpy.einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html)
/// format, i.e. `"<AnnA>,<AnnB>-><AnnC>"`, where `<AnnX>` is an index
/// annotation for tensor `X`; note that slicing is not currently
/// supported, i.e. an index can only annotate a single mode of
/// a given tensor; for example, `"ii,k->ik"` is an invalid
/// value for \p expr .
/// \param A first argument
/// \param B second argument
/// \param C result
///
/// *Examples*:
/// - matrix multiplication: `einsum("ij,jk->ik",A,B,C)`
/// - matrix multiplication, with result transposed: `einsum("ij,jk->ki",A,B,C)`
/// - Hadamard product: `einsum("ij,ij->ij",A,B,C)`
/// - standard tensor contraction: `einsum("ijk,jkl->li",A,B,C)`
/// - standard tensor contraction: `einsum("ijk,jkl->li",A,B,C)`
template <typename TA, typename TB, typename TC>
void einsum(std::string expr,
const Eigen::TensorBase<TA, Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors> &B, TC &C) {
static_assert((TA::NumDimensions + TB::NumDimensions) >= TC::NumDimensions);

using Index = TiledArray::Einsum::Index<char>;
using IndexDims = TiledArray::Einsum::IndexMap<char,size_t>;
using IndexDims = TiledArray::Einsum::IndexMap<char, size_t>;
using TiledArray::Einsum::string::split2;

auto permutation = [](auto src, auto dst) {
return TiledArray::Einsum::index::permutation(dst, src);
};

Index a,b,c;
Index a, b, c;
std::string ab;
std::tie(ab,c) = split2(expr, "->");
std::tie(a,b) = split2(ab, ",");
std::tie(ab, c) = split2(expr, "->");
std::tie(a, b) = split2(ab, ",");

// these are "Hadamard" (fused) indices
auto h = a & b & c;

auto e = (a ^ b);
auto he = h+e;
auto he = h + e;

// contracted indices
auto i = (a & b) - h;
Expand All @@ -52,46 +72,49 @@ void einsum(
eigen_assert(c.size() == C.NumDimensions);
eigen_assert(he.size() == C.NumDimensions);

IndexDims dimensions = (
IndexDims(a, derived(A).dimensions()) |
IndexDims(b, derived(B).dimensions())
);
IndexDims dimensions = (IndexDims(a, derived(A).dimensions()) |
IndexDims(b, derived(B).dimensions()));

auto product = [](auto &&dims) {
int64_t n = 1;
for (auto dim : dims) { n *= dim; }
for (auto dim : dims) {
n *= dim;
}
return n;
};

int64_t nh = product(dimensions[h]);
int64_t na = product(dimensions[a&e]);
int64_t nb = product(dimensions[b&e]);
int64_t na = product(dimensions[a & e]);
int64_t nb = product(dimensions[b & e]);
int64_t ni = product(dimensions[i]);

auto pA = A.shuffle(permutation(a, h+(e&a)+i)).reshape(std::array{nh,na,ni}).eval();
auto pB = B.shuffle(permutation(b, h+(e&b)+i)).reshape(std::array{nh,nb,ni}).eval();
Eigen::Tensor<typename TC::Scalar,3> C3(nh,na,nb);
auto pA = A.shuffle(permutation(a, h + (e & a) + i))
.reshape(std::array{nh, na, ni})
.eval();
auto pB = B.shuffle(permutation(b, h + (e & b) + i))
.reshape(std::array{nh, nb, ni})
.eval();
Eigen::Tensor<typename TC::Scalar, 3> C3(nh, na, nb);

for (int64_t h = 0; h < nh; ++h) {
//Eigen::array<Eigen::IndexPair<int>, 1> axis = { Eigen::IndexPair<int>(0, 0) };
C3.chip(h,0) = pA.chip(h,0).contract(pB.chip(h,0), std::array{ std::pair{1,1} });
// Eigen::array<Eigen::IndexPair<int>, 1> axis = { Eigen::IndexPair<int>(0,
// 0) };
C3.chip(h, 0) =
pA.chip(h, 0).contract(pB.chip(h, 0), std::array{std::pair{1, 1}});
}

std::array<int,TC::NumDimensions> permuted_shape;
std::array<int, TC::NumDimensions> permuted_shape;
for (int k = 0; k < permuted_shape.size(); ++k) {
permuted_shape[k] = dimensions[he[k]];
}

C = C3.reshape(permuted_shape).shuffle(permutation(he, c));

}

template<typename T, typename TA, typename TB>
T einsum(
std::string expr,
const Eigen::TensorBase<TA,Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB,Eigen::ReadOnlyAccessors> &B)
{
template <typename T, typename TA, typename TB>
T einsum(std::string expr,
const Eigen::TensorBase<TA, Eigen::ReadOnlyAccessors> &A,
const Eigen::TensorBase<TB, Eigen::ReadOnlyAccessors> &B) {
T AB;
einsum(expr, A, B, AB);
return AB;
Expand Down
14 changes: 7 additions & 7 deletions src/TiledArray/expressions/index_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ class IndexList {
/// \param str a string containing comma-separated index labels.
/// All whitespaces are discarded, i.e., "a c" will be converted to "ac"
/// and will be considered a single index.
explicit IndexList(const std::string& str) {
if (!str.empty()) init_(str);
}
/// \throw TiledArray::Exception if
/// `TiledArray::detail::is_valid_index(str)==false`
explicit IndexList(const std::string& str) { init_(str); }

/// constructs from a range of index labels

Expand Down Expand Up @@ -267,6 +267,8 @@ class IndexList {
private:
/// Initializes from a comma-separated sequence of indices
void init_(const std::string& str) {
if (!TiledArray::detail::is_valid_index(str))
TA_EXCEPTION_MESSAGE(__FILE__, __LINE__, "IndexList(str): invalid str");
std::string::const_iterator start = str.begin();
std::string::const_iterator finish = str.begin();
for (; finish != str.end(); ++finish) {
Expand Down Expand Up @@ -307,9 +309,7 @@ class IndexList {
}

static bool valid_char_(char c) {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') || (c == ' ') || (c == ',') || (c == '\0') ||
(c == '\'') || (c == '_');
return TiledArray::detail::is_valid_annotation_character(c);
}

friend void swap(IndexList&, IndexList&);
Expand Down Expand Up @@ -547,7 +547,7 @@ class BipartiteIndexList {
/// \return A read-only reference to the requested string index.
#ifdef BOOST_CONTAINER_USE_STD_EXCEPTIONS
/// \throw std::out_of_range
#else // BOOST_CONTAINER_USE_STD_EXCEPTIONS
#else // BOOST_CONTAINER_USE_STD_EXCEPTIONS
/// \throw boost::container::out_of_range
#endif
/// if \c n is not in the range [0, dim()). Strong throw guarantee.
Expand Down
4 changes: 3 additions & 1 deletion src/TiledArray/tile_op/tile_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,9 @@ inline Result& gemm(Result& result, const Left& left, const Right& right,
/// multiply op operation for tensor elements \return A tile whose element
/// <tt>result[i,j]</tt> obtained by executing
/// `foreach k: element_multiplyadd_op(result[i,j], left[i,k], right[k,j])`
/// \example For plain tensors GEMM can be implemented (very inefficiently)
///
/// _Example:_
/// For plain tensors GEMM can be implemented (very inefficiently)
/// using this method as follows:
/// \code
/// gemm(result, left, right, gemm_config,
Expand Down
62 changes: 30 additions & 32 deletions src/TiledArray/util/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
*
*/

#ifndef TILEDARRAY_ANNOTATION_H__INCLUDED
#define TILEDARRAY_ANNOTATION_H__INCLUDED
#ifndef TILEDARRAY_UTIL_ANNOTATION_H__INCLUDED
#define TILEDARRAY_UTIL_ANNOTATION_H__INCLUDED

#include "TiledArray/error.h"

Expand Down Expand Up @@ -95,48 +95,46 @@ inline auto tokenize_index(const std::string& s, char delim) {
return tokens;
}

/// Checks that the provided index is a valid TiledArray index
///
/// TiledArray defines a string as being a valid index if each character is one
/// of the following:
/// Checks that the provided character is a valid character in an
/// TiledArray index annotation
///
/// \param[in] ch a character
/// \return true if \p ch is any of the following
/// - Roman letters (`A..Z`, `a..z`)
/// - decimal digits (`0..9`)
/// - whitespace (` `)
/// - comma (`,`)
/// - semicolon (`;`)
/// - any of the following characters: `'`_~!@#$%^&*-+./?:|<>[]{}()`
///
/// Additionally the string can not:
///
/// - be only whitespace
/// - contain more than one semicolon
/// - have anonymous index name (i.e. can't have "i,,k" because the middle index
/// has no name).
///
/// \param[in] idx The index whose validity is being questioned.
/// \return True if the string corresponds to a valid index and false otherwise.
/// \note This function only tests that the characters making up the index are
/// valid. The index may still be invalid for a particular tensor. For
/// example if \c idx is an index for a matrix, but the actual tensor is
/// rank 3, then \c idx would be an invalid index for that tensor despite
/// being a valid index.
/// \throw std::bad_alloc if there is insufficient memory to copy \c idx. Strong
/// throw guarantee.
/// \throw std::bad_alloc if there is insufficient memory to split \c idx into
/// tokens. Strong throw guarantee.
inline bool is_valid_index(const std::string& idx) {
inline bool is_valid_annotation_character(char ch) {
const std::string valid_chars =
"abcdefghijklmnopqrstuvwxyz"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"1234567890"
",; '`_~!@#$%^&*-+./?:|<>[]{}()";
// Are valid characters
for (const auto& c : idx)
if (valid_chars.find(c) == std::string::npos) return false;
return valid_chars.find(ch) != std::string::npos;
}

// Is not only whitespace
auto no_ws = remove_whitespace(idx);
/// Checks that the provided string is a valid TiledArray index annotation.
///
/// Index annotations are used to annotate modes of tensors or tensor
/// expressions. This function only checks whether an annotation is
/// syntactically valid. A valid index annotation consists of one or more
/// sequences of one or more valid (\sa is_valid_annotation_character() )
/// non-separator nonwhitespace characters separated by separator characters
/// (`,` and `;`). Only one appearance of the `;` separator is permitted.
/// Whitespace characters are ignored and removed from \p str before the test.
///
/// \param[in] str string to be tested
/// \return true if \p str is a valid index annotation
inline bool is_valid_index(const std::string& str) {
// to be valid must contain only valid characters
for (const auto& c : str)
if (!is_valid_annotation_character(c)) return false;

auto no_ws = remove_whitespace(str);

// empty annotations are not permitted
if (no_ws.size() == 0) return false;

// At most one semicolon
Expand Down Expand Up @@ -208,4 +206,4 @@ inline auto split_index(const std::string& idx) {

} // namespace TiledArray::detail

#endif // TILEDARRAY_ANNOTATION_H__INCLUDED
#endif // TILEDARRAY_UTIL_ANNOTATION_H__INCLUDED
2 changes: 1 addition & 1 deletion tests/annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ BOOST_AUTO_TEST_CASE(multiple_semicolons) {
BOOST_CHECK(is_valid_index("i;j;k") == false);
}

BOOST_AUTO_TEST_CASE(only_whitespace) {
BOOST_AUTO_TEST_CASE(at_least_one_index) {
BOOST_CHECK(is_valid_index("") == false);
BOOST_CHECK(is_valid_index(" ") == false);
BOOST_CHECK(is_valid_index(" ") == false);
Expand Down
Loading