Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 64 additions & 43 deletions src/TiledArray/sparse_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ class SparseShape {
size_vectors_; ///< Tile size information; size_vectors_.get()[d][i]
///< reports the size of i-th tile in dimension d
size_type zero_tile_count_; ///< Number of zero tiles
static value_type threshold_; ///< The zero threshold
static value_type threshold_; ///< The current default threshold
value_type my_threshold_ =
threshold_; ///< The threshold used to initialize this

template <typename Op>
static vector_type recursive_outer_product(
Expand Down Expand Up @@ -144,9 +146,9 @@ class SparseShape {
template <ScaleBy ScaleBy_, bool Screen = true>
static size_type scale_tile_norms(
Tensor<T>& tile_norms,
const vector_type* MADNESS_RESTRICT const size_vectors) {
const vector_type* MADNESS_RESTRICT const size_vectors,
const value_type threshold = threshold_) {
const unsigned int dim = tile_norms.range().rank();
const value_type threshold = threshold_;
madness::AtomicInt zero_tile_count;
zero_tile_count = 0;

Expand Down Expand Up @@ -251,12 +253,12 @@ class SparseShape {
}

/// @brief screens out zero tiles by zeroing out the norms of tiles below
/// the threshold
/// `my_threshold_`
/// @return the number of zero tiles
auto screen_out_zero_tiles() {
decltype(zero_tile_count_) zero_tile_count = 0;
for (auto& n : tile_norms_) {
if (n < threshold()) {
if (n < my_threshold_) {
n = 0;
++zero_tile_count;
}
Expand All @@ -266,10 +268,12 @@ class SparseShape {

SparseShape(const Tensor<T>& tile_norms,
const std::shared_ptr<vector_type>& size_vectors,
const size_type zero_tile_count)
const size_type zero_tile_count,
const value_type my_threshold = threshold_)
: tile_norms_(tile_norms),
size_vectors_(size_vectors),
zero_tile_count_(zero_tile_count) {}
zero_tile_count_(zero_tile_count),
my_threshold_(my_threshold) {}

public:
/// Default constructor
Expand Down Expand Up @@ -321,9 +325,10 @@ class SparseShape {
/// This constructor uses tile norms given as a sparse tensor,
/// represented as a sequence of {index,value_type} data.
/// The tile norms are scaled by the inverse of the corresponding tile's
/// volumes. \tparam SparseNormSequence the sequence of \c
/// std::pair<index,value_type> objects,
/// where \c index is a directly-addressable sequence indices.
/// volumes.
/// \tparam SparseNormSequence the sequence of
/// `std::pair<index,value_type>` objects,
/// where `index` is a directly-addressable sequence indices.
/// \param tile_norms The Frobenius norm of tiles
/// \param trange The tiled range of the tensor
/// \param do_not_scale if true, assume that the tile norms in \c tile_norms
Expand All @@ -350,7 +355,7 @@ class SparseShape {
auto norm_per_element =
do_not_scale ? pair_idx_norm.second
: (pair_idx_norm.second / compute_tile_volume());
if (norm_per_element >= threshold()) {
if (norm_per_element >= my_threshold_) {
tile_norms_[pair_idx_norm.first] = norm_per_element;
--zero_tile_count_;
}
Expand Down Expand Up @@ -426,7 +431,8 @@ class SparseShape {
other.tile_norms_unscaled_.get()->clone())
: nullptr),
size_vectors_(other.size_vectors_),
zero_tile_count_(other.zero_tile_count_) {}
zero_tile_count_(other.zero_tile_count_),
my_threshold_(other.my_threshold_) {}

/// Copy assignment operator

Expand All @@ -441,6 +447,7 @@ class SparseShape {
: nullptr;
size_vectors_ = other.size_vectors_;
zero_tile_count_ = other.zero_tile_count_;
my_threshold_ = other.my_threshold_;
return *this;
}

Expand All @@ -459,7 +466,7 @@ class SparseShape {
template <typename Index>
bool is_zero(const Index& i) const {
TA_ASSERT(!tile_norms_.empty());
return tile_norms_[i] < threshold_;
return tile_norms_[i] < my_threshold_;
}

/// Check density
Expand Down Expand Up @@ -541,8 +548,8 @@ class SparseShape {
tile_norms_unscaled_ =
std::make_unique<decltype(tile_norms_)>(tile_norms_.clone());
[[maybe_unused]] auto should_be_zero =
scale_tile_norms<ScaleBy::Volume, false>(*tile_norms_unscaled_,
size_vectors_.get());
scale_tile_norms<ScaleBy::Volume, false>(
*tile_norms_unscaled_, size_vectors_.get(), my_threshold_);
TA_ASSERT(should_be_zero == 0);
}
return *(tile_norms_unscaled_.get());
Expand Down Expand Up @@ -578,7 +585,8 @@ class SparseShape {
Tensor<value_type> result_tile_norms =
tile_norms_.binary(mask_shape.tile_norms_, op);

return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count);
return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count,
my_threshold_);
}

// clang-format off
Expand Down Expand Up @@ -617,7 +625,8 @@ class SparseShape {
l = r;
});

return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count);
return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count,
my_threshold_);
}

// clang-format off
Expand Down Expand Up @@ -674,7 +683,8 @@ class SparseShape {
l = r;
});

return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count);
return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count,
my_threshold_);
}

// clang-format off
Expand Down Expand Up @@ -799,9 +809,9 @@ class SparseShape {
template <typename Op>
static SparseShape_ make_block(
const std::shared_ptr<vector_type>& size_vectors,
const TensorConstView<value_type>& block_view, const Op& op) {
const TensorConstView<value_type>& block_view, const Op& op,
const value_type threshold = threshold_) {
// Copy the data from arg to result
const value_type threshold = threshold_;
madness::AtomicInt zero_tile_count;
zero_tile_count = 0;
auto copy_op = [threshold, &zero_tile_count, &op](
Expand All @@ -818,7 +828,7 @@ class SparseShape {
Tensor<value_type> result_norms(Range(block_view.range().extent()));
result_norms.inplace_binary(shift(block_view), copy_op);

return SparseShape(result_norms, size_vectors, zero_tile_count);
return SparseShape(result_norms, size_vectors, zero_tile_count, threshold);
}

public:
Expand All @@ -833,9 +843,10 @@ class SparseShape {
detail::is_integral_range_v<Index2>>>
SparseShape block(const Index1& lower_bound,
const Index2& upper_bound) const {
return make_block(block_range(lower_bound, upper_bound),
tile_norms_.block(lower_bound, upper_bound),
[](auto&& arg) { return arg; });
return make_block(
block_range(lower_bound, upper_bound),
tile_norms_.block(lower_bound, upper_bound),
[](auto&& arg) { return arg; }, my_threshold_);
}

/// Create a copy of a sub-block of the shape
Expand All @@ -862,8 +873,9 @@ class SparseShape {
template <typename PairRange,
typename = std::enable_if_t<detail::is_gpair_range_v<PairRange>>>
SparseShape block(const PairRange& bounds) const {
return make_block(block_range(bounds), tile_norms_.block(bounds),
[](auto&& arg) { return arg; });
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[](auto&& arg) { return arg; }, my_threshold_);
}

/// Create a copy of a sub-block of the shape
Expand All @@ -874,8 +886,9 @@ class SparseShape {
typename = std::enable_if_t<std::is_integral_v<Index>>>
SparseShape block(
const std::initializer_list<std::initializer_list<Index>>& bounds) const {
return make_block(block_range(bounds), tile_norms_.block(bounds),
[](auto&& arg) { return arg; });
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[](auto&& arg) { return arg; }, my_threshold_);
}

/// Create a scaled sub-block of the shape
Expand All @@ -895,9 +908,10 @@ class SparseShape {
SparseShape block(const Index1& lower_bound, const Index2& upper_bound,
const Scalar factor) const {
const value_type abs_factor = to_abs_factor(factor);
return make_block(block_range(lower_bound, upper_bound),
tile_norms_.block(lower_bound, upper_bound),
[&abs_factor](auto&& arg) { return abs_factor * arg; });
return make_block(
block_range(lower_bound, upper_bound),
tile_norms_.block(lower_bound, upper_bound),
[&abs_factor](auto&& arg) { return abs_factor * arg; }, my_threshold_);
}

/// Create a scaled sub-block of the shape
Expand Down Expand Up @@ -933,8 +947,9 @@ class SparseShape {
detail::is_gpair_range_v<PairRange>>>
SparseShape block(const PairRange& bounds, const Scalar factor) const {
const value_type abs_factor = to_abs_factor(factor);
return make_block(block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; });
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; }, my_threshold_);
}

/// Create a scaled sub-block of the shape
Expand All @@ -952,8 +967,9 @@ class SparseShape {
const std::initializer_list<std::initializer_list<Index>>& bounds,
const Scalar factor) const {
const value_type abs_factor = to_abs_factor(factor);
return make_block(block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; });
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; }, my_threshold_);
}

/// Create a permuted sub-block of the shape
Expand Down Expand Up @@ -1073,8 +1089,10 @@ class SparseShape {
SparseShape block(const PairRange& bounds, const Scalar factor,
const Permutation& perm) const {
const value_type abs_factor = to_abs_factor(factor);
return make_block(block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; })
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; },
my_threshold_)
.perm(perm);
}

Expand All @@ -1095,18 +1113,20 @@ class SparseShape {
const std::initializer_list<std::initializer_list<Index>>& bounds,
const Scalar factor, const Permutation& perm) const {
const value_type abs_factor = to_abs_factor(factor);
return make_block(block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; })
return make_block(
block_range(bounds), tile_norms_.block(bounds),
[&abs_factor](auto&& arg) { return abs_factor * arg; },
my_threshold_)
.perm(perm);
}

/// Create a permuted shape of this shape

/// \param perm The permutation to be applied
/// \return A new, permuted shape
/// \return A new, permuted shape using the same threshold as this object
SparseShape_ perm(const Permutation& perm) const {
return SparseShape_(tile_norms_.permute(perm), perm_size_vectors(perm),
zero_tile_count_);
zero_tile_count_, my_threshold_);
}

/// Scale shape
Expand Down Expand Up @@ -1139,7 +1159,8 @@ class SparseShape {

Tensor<value_type> result_tile_norms = tile_norms_.unary(op);

return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count);
return SparseShape_(result_tile_norms, size_vectors_, zero_tile_count,
my_threshold_);
}

/// Scale and permute shape
Expand Down Expand Up @@ -1170,7 +1191,7 @@ class SparseShape {
Tensor<value_type> result_tile_norms = tile_norms_.unary(op, perm);

return SparseShape_(result_tile_norms, perm_size_vectors(perm),
zero_tile_count);
zero_tile_count, my_threshold_);
}

/// Add shapes
Expand Down