From 79392250384f5e95200f87ae792e28de327b6f9e Mon Sep 17 00:00:00 2001 From: suyash67 Date: Wed, 2 Jul 2025 10:03:12 +0200 Subject: [PATCH 1/6] consistent arg names, fix basic warnings. --- .../stdlib/primitives/uint/uint.cpp | 32 +++++++++++-------- .../stdlib/primitives/uint/uint.hpp | 8 +++-- .../stdlib/primitives/witness/witness.hpp | 2 ++ 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp index 2fbd81f9d756..080168f3bfd2 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp @@ -13,35 +13,35 @@ template std::vector uint::constrain_accumulators(Builder* context, const uint32_t witness_index) const { - const auto res = context->decompose_into_default_range(witness_index, width, bits_per_limb); + std::vector res = context->decompose_into_default_range(witness_index, width, bits_per_limb); return res; } template -uint::uint(const witness_t& witness) - : context(witness.context) +uint::uint(const witness_t& other) + : context(other.context) , witness_status(WitnessStatus::OK) { - if (witness.witness_index == IS_CONSTANT) { - additive_constant = witness.witness; + if (other.is_constant()) { + additive_constant = other.witness; witness_index = IS_CONSTANT; } else { - accumulators = constrain_accumulators(context, witness.witness_index); - witness_index = witness.witness_index; + accumulators = constrain_accumulators(context, other.witness_index); + witness_index = other.witness_index; } } template -uint::uint(const field_t& value) - : context(value.context) +uint::uint(const field_t& other) + : context(other.context) , additive_constant(0) , witness_status(WitnessStatus::OK) { - if (value.witness_index == IS_CONSTANT) { - additive_constant = value.additive_constant; + if (other.is_constant()) { + additive_constant = other.additive_constant; witness_index = IS_CONSTANT; } else { - field_t norm = value.normalize(); + field_t norm = other.normalize(); accumulators = constrain_accumulators(context, norm.get_witness_index()); witness_index = norm.get_witness_index(); } @@ -129,7 +129,7 @@ uint::uint(const uint& other) {} template -uint::uint(uint&& other) +uint::uint(uint&& other) noexcept : context(other.context) , additive_constant(other.additive_constant) , witness_status(other.witness_status) @@ -139,6 +139,9 @@ uint::uint(uint&& other) template uint& uint::operator=(const uint& other) { + if (this == &other) { + return *this; + } context = other.context; additive_constant = other.additive_constant; witness_status = other.witness_status; @@ -147,7 +150,8 @@ template uint& uint uint& uint::operator=(uint&& other) +template +uint& uint::operator=(uint&& other) noexcept { context = other.context; additive_constant = other.additive_constant; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp index 7c68f3de436a..eb49961bd207 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp @@ -30,16 +30,18 @@ template class uint { : uint(static_cast(v)) {} - std::vector constrain_accumulators(Builder* ctx, const uint32_t witness_index) const; + std::vector constrain_accumulators(Builder* context, const uint32_t witness_index) const; static constexpr size_t bits_per_limb = 12; static constexpr size_t num_accumulators() { return (width + bits_per_limb - 1) / bits_per_limb; } uint(const uint& other); - uint(uint&& other); + uint(uint&& other) noexcept; + + ~uint() = default; uint& operator=(const uint& other); - uint& operator=(uint&& other); + uint& operator=(uint&& other) noexcept; explicit operator byte_array() const; explicit operator field_t() const; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/witness/witness.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/witness/witness.hpp index 3c42571feecd..902023130b93 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/witness/witness.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/witness/witness.hpp @@ -49,6 +49,8 @@ template class witness_t { return out; } + bool is_constant() const { return witness_index == IS_CONSTANT; } + bb::fr witness; uint32_t witness_index = IS_CONSTANT; Builder* context = nullptr; From b168a9ced3d1b28e370cd5483dc1f8825092849d Mon Sep 17 00:00:00 2001 From: suyash67 Date: Wed, 2 Jul 2025 11:02:55 +0200 Subject: [PATCH 2/6] make byte_array's constructor more efficient. --- .../stdlib/primitives/uint/uint.cpp | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp index 080168f3bfd2..3ab7c701acc1 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp @@ -75,15 +75,28 @@ uint::uint(const byte_array& other) { field_t accumulator(context, fr::zero()); field_t scaling_factor(context, fr::one()); - const auto bytes = other.bytes(); + const auto& bytes = other.bytes(); + const size_t num_bytes = bytes.size(); - // TODO JUMP IN STEPS OF TWO - for (size_t i = 0; i < bytes.size(); ++i) { - accumulator = accumulator + scaling_factor * bytes[bytes.size() - 1 - i]; - scaling_factor = scaling_factor * fr(256); + // Process pairs of bytes from the end of the byte array. + for (size_t i = 0; i < (num_bytes / 2); ++i) { + const field_t even_scaling_factor = scaling_factor; + const field_t odd_scaling_factor = scaling_factor * fr(256); + accumulator = accumulator.add_two(bytes[num_bytes - 1 - (2 * i)] * even_scaling_factor, + bytes[num_bytes - 1 - (2 * i + 1)] * odd_scaling_factor); + + scaling_factor = scaling_factor * fr(256 * 256); // Scale by (2^8 * 2^8). + } + + // Process the last byte if the byte array has an odd number of bytes. + if (num_bytes % 2 == 1) { + const field_t& last_byte = bytes[0]; + accumulator = accumulator + (scaling_factor * last_byte); } + + // Normalize the accumulator and set the witness index or additive constant. accumulator = accumulator.normalize(); - if (accumulator.witness_index == IS_CONSTANT) { + if (accumulator.is_constant()) { additive_constant = uint256_t(accumulator.additive_constant); } else { witness_index = accumulator.witness_index; From 1dde319fbdb639dd1c23ee121989ecf0a802c0da Mon Sep 17 00:00:00 2001 From: suyash67 Date: Wed, 2 Jul 2025 11:40:23 +0200 Subject: [PATCH 3/6] make bit_array's constructor more efficient. --- .../stdlib/primitives/uint/uint.cpp | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp index 3ab7c701acc1..88dce0319c5f 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp @@ -118,14 +118,27 @@ uint::uint(Builder* parent_context, const std::vector accumulator(context, fr::zero()); field_t scaling_factor(context, fr::one()); + const size_t num_wires = wires.size(); - // TODO JUMP IN STEPS OF TWO - for (size_t i = 0; i < wires.size(); ++i) { - accumulator = accumulator + scaling_factor * field_t(wires[i]); - scaling_factor = scaling_factor + scaling_factor; + for (size_t i = 0; i < (num_wires / 2); ++i) { + const field_t even_scaling_factor = scaling_factor; + const field_t odd_scaling_factor = scaling_factor * fr(2); + + accumulator = accumulator.add_two(field_t(wires[2 * i]) * even_scaling_factor, + field_t(wires[2 * i + 1]) * odd_scaling_factor); + + scaling_factor = scaling_factor * fr(4); // Scale by (2^1 * 2^1). + } + + // Process the last wire if the number of wires is odd. + if (num_wires % 2 == 1) { + const field_t& last_wire = field_t(wires[num_wires - 1]); + accumulator = accumulator + (scaling_factor * last_wire); } + + // Normalize the accumulator and set the witness index or additive constant. accumulator = accumulator.normalize(); - if (accumulator.witness_index == IS_CONSTANT) { + if (accumulator.is_constant()) { additive_constant = uint256_t(accumulator.additive_constant); } else { witness_index = accumulator.witness_index; From 70a4dab76f7cfb28ba77c3d7abffbb5b847f4498 Mon Sep 17 00:00:00 2001 From: suyash67 Date: Fri, 18 Jul 2025 13:45:57 +0000 Subject: [PATCH 4/6] add new function split_at in field_t. use that in bitwise ops in uint. next pr should remove slice function from field_t. --- .../stdlib/primitives/field/field.cpp | 47 +++++ .../stdlib/primitives/field/field.hpp | 3 + .../stdlib/primitives/field/field.test.cpp | 61 ++++++ .../stdlib/primitives/uint/logic.cpp | 196 ++++++++++-------- .../stdlib/primitives/uint/uint.hpp | 4 + 5 files changed, 221 insertions(+), 90 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp index 30047c9505b1..e9fc9ec9a516 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp @@ -1290,6 +1290,53 @@ std::array, 3> field_t::slice(const uint8_t msb, const return result; } +template +std::pair, field_t> field_t::split_at(const size_t lsb_index, + const size_t num_bits) const +{ + ASSERT(lsb_index < num_bits); + + const uint256_t value = get_value(); + const uint256_t hi = value >> lsb_index; + const uint256_t lo = value % (uint256_t(1) << lsb_index); + + if (is_constant()) { + // If `*this` is constant, we can return the split values directly + ASSERT(lo + (hi << lsb_index) == value); + return std::make_pair(field_t(lo), field_t(hi)); + } + + // Handle edge case when lsb_index == 0 + if (lsb_index == 0) { + ASSERT(hi == value); + ASSERT(lo == 0); + create_range_constraint(num_bits, "split_at: hi value too large."); + return std::make_pair(field_t(0), *this); + } + + Builder* ctx = get_context(); + ASSERT(ctx != nullptr); + + field_t lo_wit(witness_t(ctx, lo)); + field_t hi_wit(witness_t(ctx, hi)); + + // Ensure that `lo_wit` is in the range [0, 2^lsb_index - 1] + lo_wit.create_range_constraint(lsb_index, "split_at: lo value too large."); + + // Ensure that `hi_wit` is in the range [0, 2^(num_bits - lsb_index) - 1] + hi_wit.create_range_constraint(num_bits - lsb_index, "split_at: hi value too large."); + + // Check that *this = lo_wit + hi_wit * 2^{lsb_index} + const field_t decomposed = lo_wit + (hi_wit * field_t(uint256_t(1) << lsb_index)); + assert_equal(decomposed, "split_at: decomposition failed"); + + // Set the origin tag for both witnesses + lo_wit.set_origin_tag(tag); + hi_wit.set_origin_tag(tag); + + return std::make_pair(lo_wit, hi_wit); +} + /** * @brief Build constraints establishing the decomposition of `*this` into bits. * diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp index 295dfe6997cc..9d64a4ee228d 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.hpp @@ -339,6 +339,9 @@ template class field_t { std::array slice(uint8_t msb, uint8_t lsb) const; + std::pair, field_t> split_at( + const size_t lsb_index, const size_t num_bits = grumpkin::MAX_NO_WRAP_INTEGER_BIT_LENGTH) const; + bool_t is_zero() const; void create_range_constraint(size_t num_bits, std::string const& msg = "field_t::range_constraint") const; diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp index 4ab429965782..413d9e246823 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.test.cpp @@ -4,6 +4,7 @@ #include "barretenberg/circuit_checker/circuit_checker.hpp" #include "barretenberg/common/streams.hpp" #include "barretenberg/numeric/random/engine.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" #include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp" #include #include @@ -859,6 +860,56 @@ template class stdlib_field : public testing::Test { EXPECT_TRUE(builder.err() == "slice: hi value too large."); } + static void test_split_at() + { + Builder builder = Builder(); + + // Test different bit sizes + std::vector test_bit_sizes = { 8, 16, 32, 100, 252 }; + + // Lambda to check split_at functionality + auto check_split_at = [&](const field_ct& a, size_t start, size_t num_bits) { + const uint256_t a_native = a.get_value(); + auto split_data = a.split_at(start, num_bits); + EXPECT_EQ(split_data.first.get_value(), a_native & ((uint256_t(1) << start) - 1)); + EXPECT_EQ(split_data.second.get_value(), (a_native >> start) & ((uint256_t(1) << num_bits) - 1)); + + if (a.is_constant()) { + EXPECT_TRUE(split_data.first.is_constant()); + EXPECT_TRUE(split_data.second.is_constant()); + } + + if (start == 0) { + EXPECT_TRUE(split_data.first.is_constant()); + EXPECT_TRUE(split_data.first.get_value() == 0); + EXPECT_EQ(split_data.second.get_value(), a.get_value()); + } + }; + + for (size_t num_bits : test_bit_sizes) { + uint256_t a_native = engine.get_random_uint256() & ((uint256_t(1) << num_bits) - 1); + + // check split_at for a constant + field_ct a_constant(a_native); + check_split_at(a_constant, 0, num_bits); + check_split_at(a_constant, num_bits / 4, num_bits); + check_split_at(a_constant, num_bits / 3, num_bits); + check_split_at(a_constant, num_bits / 2, num_bits); + check_split_at(a_constant, num_bits - 1, num_bits); + + // check split_at for a witness + field_ct a_witness(witness_ct(&builder, a_native)); + check_split_at(a_witness, 0, num_bits); + check_split_at(a_witness, num_bits / 4, num_bits); + check_split_at(a_witness, num_bits / 3, num_bits); + check_split_at(a_witness, num_bits / 2, num_bits); + check_split_at(a_witness, num_bits - 1, num_bits); + } + + bool result = CircuitChecker::check(builder); + EXPECT_EQ(result, true); + } + static void test_three_bit_table() { Builder builder = Builder(); @@ -1362,6 +1413,12 @@ template class stdlib_field : public testing::Test { EXPECT_EQ(element.get_origin_tag(), submitted_value_origin_tag); } + // Split preserves tags + const size_t num_bits = uint256_t(a.get_value()).get_msb() + 1; + auto split_data = a.split_at(num_bits / 2, num_bits); + EXPECT_EQ(split_data.first.get_origin_tag(), submitted_value_origin_tag); + EXPECT_EQ(split_data.second.get_origin_tag(), submitted_value_origin_tag); + // Decomposition preserves tags auto decomposed_bits = a.decompose_into_bits(); @@ -1555,6 +1612,10 @@ TYPED_TEST(stdlib_field, test_slice_random) { TestFixture::test_slice_random(); } +TYPED_TEST(stdlib_field, test_split_at) +{ + TestFixture::test_split_at(); +} TYPED_TEST(stdlib_field, test_three_bit_table) { TestFixture::test_three_bit_table(); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp index 9d6ef527eb02..c1ca8945c150 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp @@ -55,40 +55,50 @@ uint uint::operator>>(const size_t shift) cons return *this; } - uint64_t bits_per_hi_limb; - // last limb will not likely bit `bits_per_limb`. Need to be careful with our range check - if (shift >= ((width / bits_per_limb) * bits_per_limb)) { - bits_per_hi_limb = width % bits_per_limb; - } else { - bits_per_hi_limb = bits_per_limb; - } - const uint64_t slice_bit_position = shift % bits_per_limb; + // Example for uint32_t: + // + // |<------ acc[2] ------>||<----------- acc[1] ----------->||<------- acc[0] ------->| + // val: [ 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 ] + // ↑ + // [<--------------- keep --------------->][<-------- discard -------->] + // + // out: [ 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 ] + // |<------ acc[2] ------>||<----- acc[1].hi ----->| + // + // Suppose the shift is 15, then we must discard the 15 least significant bits of the accumulator. + // The accumulator is split into three parts, so we clearly need to split acc[1]. On splitting, we must + // discard the lower slice of acc[1] and keep the upper slice. Thus, the updated uint value will be: + // + // acc[1].hi + (acc[2] << (24 - 15)) + // + // Let us first fetch the accumlator that needs to be sliced. const size_t accumulator_index = shift / bits_per_limb; - const uint32_t slice_index = accumulators[accumulator_index]; - const uint64_t slice_value = uint256_t(context->get_variable(slice_index)).data[0]; + const uint32_t slice_witness_index = accumulators[accumulator_index]; + const field_t acc_to_be_sliced = field_t::from_witness_index(context, slice_witness_index); - const uint64_t slice_lo = slice_value % (1ULL << slice_bit_position); - const uint64_t slice_hi = slice_value >> slice_bit_position; - const uint32_t slice_lo_idx = slice_bit_position ? context->add_variable(slice_lo) : context->zero_idx; - const uint32_t slice_hi_idx = - (slice_bit_position != bits_per_limb) ? context->add_variable(slice_hi) : context->zero_idx; + // Now, lets calculate: + // (i) bit position (from lsb) at which we need to slice. + // (ii) number of bits in the slice based on whether it is the highest slice or not. + const uint64_t slice_bit_position = shift % bits_per_limb; + const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); + const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb; - context->create_big_add_gate( - { slice_index, slice_lo_idx, context->zero_idx, slice_hi_idx, -1, 1, 0, (1 << slice_bit_position), 0 }); + // Finally, we can slice the accumulator. + // The slice_hi will be the upper slice of the accumulator, which we will keep. + // The slice_lo will be the lower slice of the accumulator, which we will discard. + // Its important to note that although slice_lo is not used here, it is still created and properly constrained + // in the split_at function. + field_t slice_hi = acc_to_be_sliced.split_at(slice_bit_position, num_bits_per_limb).second; - if (slice_bit_position != 0) { - context->create_new_range_constraint(slice_lo_idx, (1ULL << slice_bit_position) - 1); - } - context->create_new_range_constraint(slice_hi_idx, (1ULL << (bits_per_hi_limb - slice_bit_position)) - 1); + // Now we reconstruct the shifted uint value. std::vector> sublimbs; - sublimbs.emplace_back(field_t::from_witness_index(context, slice_hi_idx)); + sublimbs.emplace_back(slice_hi); const size_t start = accumulator_index + 1; field_t coefficient(context, uint64_t(1ULL << (start * bits_per_limb - shift))); field_t shifter(context, uint64_t(1ULL << bits_per_limb)); for (size_t i = accumulator_index + 1; i < num_accumulators(); ++i) { - sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * - field_t(coefficient)); + sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * coefficient); coefficient *= shifter; } @@ -117,49 +127,48 @@ uint uint::operator<<(const size_t shift) cons return *this; } - uint64_t slice_bit_position; - size_t accumulator_index; - size_t bits_per_hi_limb; - // most significant limb is only 2 bits long (for u32), need to be careful about which slice we index, - // and how large the range check is on our hi limb - if (shift < (width - ((width / bits_per_limb) * bits_per_limb))) { - bits_per_hi_limb = width % bits_per_limb; - slice_bit_position = bits_per_hi_limb - (shift % bits_per_hi_limb); - accumulator_index = num_accumulators() - 1; - } else { - const size_t offset = width % bits_per_limb; - slice_bit_position = bits_per_limb - ((shift - offset) % bits_per_limb); - accumulator_index = num_accumulators() - 2 - ((shift - offset) / bits_per_limb); - bits_per_hi_limb = bits_per_limb; - } - - const uint32_t slice_index = accumulators[accumulator_index]; - const uint64_t slice_value = uint256_t(context->get_variable(slice_index)).data[0]; - - const uint64_t slice_lo = slice_value % (1ULL << slice_bit_position); - const uint64_t slice_hi = slice_value >> slice_bit_position; - const uint32_t slice_lo_idx = slice_bit_position ? context->add_variable(slice_lo) : context->zero_idx; - const uint32_t slice_hi_idx = - (slice_bit_position != bits_per_hi_limb) ? context->add_variable(slice_hi) : context->zero_idx; - - context->create_big_add_gate( - { slice_index, slice_lo_idx, context->zero_idx, slice_hi_idx, -1, 1, 0, (1 << slice_bit_position), 0 }); - - context->create_new_range_constraint(slice_lo_idx, (1ULL << slice_bit_position) - 1); - - if (slice_bit_position != bits_per_limb) { - context->create_new_range_constraint(slice_hi_idx, (1ULL << (bits_per_hi_limb - slice_bit_position)) - 1); - } - + // Example for uint32_t: + // + // |<------ acc[2] ------->||<----------- acc[1] ----------->||<------- acc[0] ------->| + // val: [ 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 ] + // ↑ + // [<---- discard ---->][<---------------------- keep ----------------------->] + // + // out: [ 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 ] + // [<- acc[2].lo ->||<----------- acc[1] ----------->||<------- acc[0] ------->| + // + // Suppose the shift is 7, then we must discard the 7 most significant bits of the accumulator, and move + // the remaining bits to the left. The accumulator is split into three parts, so in this case we clearly + // need to split acc[2]. On splitting, we must discard the higher slice of acc[2] and keep the lower slice. + // Thus, the updated uint value will be: + // + // (acc[2].lo << (24 + 7)) + (acc[1] << (12 + 7)) + (acc[0] << 7) + // + // Let us first fetch the accumulator that needs to be sliced. + // We will do so by adjusting the shift from the most-significant bit. + size_t adjusted_shift = width - shift; + const size_t accumulator_index = adjusted_shift / bits_per_limb; + const uint32_t slice_witness_index = accumulators[accumulator_index]; + const field_t acc_to_be_sliced = field_t::from_witness_index(context, slice_witness_index); + + // Now, lets calculate: + // (i) bit position (from lsb) at which we need to slice. + // (ii) number of bits in the slice based on whether it is the highest slice or not. + const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); + const uint64_t slice_bit_position = adjusted_shift % bits_per_limb; + const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb; + + // We can now slice the accumulator. + field_t slice_lo = acc_to_be_sliced.split_at(slice_bit_position, num_bits_per_limb).first; + + // Now we reconstruct the shifted uint value. std::vector> sublimbs; - sublimbs.emplace_back(field_t::from_witness_index(context, slice_lo_idx) * - field_t(context, 1ULL << ((accumulator_index)*bits_per_limb + shift))); + sublimbs.emplace_back(slice_lo * field_t(context, 1ULL << ((accumulator_index * bits_per_limb) + shift))); field_t coefficient(context, uint64_t(1ULL << shift)); field_t shifter(context, uint64_t(1ULL << bits_per_limb)); for (size_t i = 0; i < accumulator_index; ++i) { - sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * - field_t(coefficient)); + sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * coefficient); coefficient *= shifter; } @@ -173,6 +182,7 @@ uint uint::operator<<(const size_t shift) cons template uint uint::ror(const size_t target_rotation) const { + // Note: width is always a power of two, so we can use bitwise AND. const size_t rotation = target_rotation & (width - 1); const auto rotate = [](const uint256_t input, const uint64_t rot) { @@ -193,51 +203,57 @@ uint uint::ror(const size_t target_rotation) c return *this; } - const size_t shift = rotation; - uint64_t bits_per_hi_limb; - // last limb will not likely bit `bits_per_limb`. Need to be careful with our range check - if (shift >= ((width / bits_per_limb) * bits_per_limb)) { - bits_per_hi_limb = width % bits_per_limb; - } else { - bits_per_hi_limb = bits_per_limb; - } - const uint64_t slice_bit_position = shift % bits_per_limb; + // Example for uint32_t: + // + // |<------ acc[2] ------>||<----------- acc[1] ----------->||<------- acc[0] ------->| + // val: [ 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 ] + // ↑ + // [<--------------- keep --------------->][<-------- right rotate -------->] + // + // out: [ 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 ] [ 31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 ] + // [<- acc[1].lo ->||<------- acc[0] ------->| |<------ acc[2] ------>||<----- acc[1].hi ------>] + // + // Suppose the right-rotation is 15 (i.e., rotate = 15), then we must right-rotate the 15 least + // significant bits of the accumulator. The accumulator is split into three parts, so in this case we need to split + // acc[1]. On splitting, we must "rotate" the lower slice of acc[1] and keep the upper slice. Thus, the updated uint + // value will be: + // + // acc[1].hi + (acc[2] << (24 - 15)) + (acc[0] >> (32 - 15)) + (acc[1].lo >> (32 - 15 + 12)) + // + // Let us first fetch the accumlator that needs to be sliced. + size_t shift = rotation; const size_t accumulator_index = shift / bits_per_limb; - const uint32_t slice_index = accumulators[accumulator_index]; - const uint64_t slice_value = uint256_t(context->get_variable(slice_index)).data[0]; + const uint32_t slice_witness_index = accumulators[accumulator_index]; + const field_t acc_to_be_sliced = field_t::from_witness_index(context, slice_witness_index); - const uint64_t slice_lo = slice_value % (1ULL << slice_bit_position); - const uint64_t slice_hi = slice_value >> slice_bit_position; - const uint32_t slice_lo_idx = slice_bit_position ? context->add_variable(slice_lo) : context->zero_idx; - const uint32_t slice_hi_idx = - (slice_bit_position != bits_per_limb) ? context->add_variable(slice_hi) : context->zero_idx; + // Now, lets calculate: + // (i) bit position (from lsb) at which we need to slice. + // (ii) number of bits in the slice based on whether it is the highest slice or not. + const uint64_t slice_bit_position = shift % bits_per_limb; + const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); + const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb; - context->create_big_add_gate( - { slice_index, slice_lo_idx, context->zero_idx, slice_hi_idx, -1, 1, 0, (1 << slice_bit_position), 0 }); + // Finally, we can slice the accumulator. + auto [slice_lo, slice_hi] = acc_to_be_sliced.split_at(slice_bit_position, num_bits_per_limb); - if (slice_bit_position != 0) { - context->create_new_range_constraint(slice_lo_idx, (1ULL << slice_bit_position) - 1); - } - context->create_new_range_constraint(slice_hi_idx, (1ULL << (bits_per_hi_limb - slice_bit_position)) - 1); + // Now we reconstruct the shifted uint value. std::vector> sublimbs; - sublimbs.emplace_back(field_t::from_witness_index(context, slice_hi_idx)); + sublimbs.emplace_back(slice_hi); const size_t start = accumulator_index + 1; field_t coefficient(context, uint64_t(1ULL << (start * bits_per_limb - shift))); field_t shifter(context, uint64_t(1ULL << bits_per_limb)); for (size_t i = accumulator_index + 1; i < num_accumulators(); ++i) { - sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * - field_t(coefficient)); + sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * coefficient); coefficient *= shifter; } coefficient = field_t(context, uint64_t(1ULL << (width - shift))); for (size_t i = 0; i < accumulator_index; ++i) { - sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * - field_t(coefficient)); + sublimbs.emplace_back(field_t::from_witness_index(context, accumulators[i]) * coefficient); coefficient *= shifter; } - sublimbs.emplace_back(field_t::from_witness_index(context, slice_lo_idx) * field_t(coefficient)); + sublimbs.emplace_back(slice_lo * field_t(coefficient)); uint32_t result_index = field_t::accumulate(sublimbs).get_witness_index(); uint result(context); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp index eb49961bd207..5e77b9fa4738 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.hpp @@ -18,6 +18,9 @@ template class uint { using FF = typename Builder::FF; static constexpr size_t width = sizeof(Native) * 8; + static_assert(width == 8 || width == 16 || width == 32 || width == 64, + "unsupported uint width, supported uint widths are: 8, 16, 32, and 64 bits."); + uint(const witness_t& other); uint(const field_t& other); uint(const uint256_t& value = 0); @@ -33,6 +36,7 @@ template class uint { std::vector constrain_accumulators(Builder* context, const uint32_t witness_index) const; static constexpr size_t bits_per_limb = 12; + static constexpr size_t bits_in_high_limb = width % bits_per_limb == 0 ? bits_per_limb : width % bits_per_limb; static constexpr size_t num_accumulators() { return (width + bits_per_limb - 1) / bits_per_limb; } uint(const uint& other); From 652615a11a4267a09d7ddf7c8a1ac5e4f0b01e45 Mon Sep 17 00:00:00 2001 From: suyash67 Date: Tue, 22 Jul 2025 09:42:15 +0000 Subject: [PATCH 5/6] resolve comments. --- .../stdlib/primitives/field/field.cpp | 4 +- .../stdlib/primitives/uint/uint.cpp | 54 +++++++------------ 2 files changed, 21 insertions(+), 37 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp index e9fc9ec9a516..e586d20ec31c 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/field/field.cpp @@ -1327,8 +1327,8 @@ std::pair, field_t> field_t::split_at(const s hi_wit.create_range_constraint(num_bits - lsb_index, "split_at: hi value too large."); // Check that *this = lo_wit + hi_wit * 2^{lsb_index} - const field_t decomposed = lo_wit + (hi_wit * field_t(uint256_t(1) << lsb_index)); - assert_equal(decomposed, "split_at: decomposition failed"); + const field_t reconstructed = lo_wit + (hi_wit * field_t(uint256_t(1) << lsb_index)); + assert_equal(reconstructed, "split_at: decomposition failed"); // Set the origin tag for both witnesses lo_wit.set_origin_tag(tag); diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp index 88dce0319c5f..bf52e37242ce 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/uint.cpp @@ -73,29 +73,21 @@ uint::uint(const byte_array& other) , accumulators() , witness_index(IS_CONSTANT) { - field_t accumulator(context, fr::zero()); - field_t scaling_factor(context, fr::one()); const auto& bytes = other.bytes(); const size_t num_bytes = bytes.size(); + field_t scaling_factor(context, fr::one()); - // Process pairs of bytes from the end of the byte array. - for (size_t i = 0; i < (num_bytes / 2); ++i) { - const field_t even_scaling_factor = scaling_factor; - const field_t odd_scaling_factor = scaling_factor * fr(256); - accumulator = accumulator.add_two(bytes[num_bytes - 1 - (2 * i)] * even_scaling_factor, - bytes[num_bytes - 1 - (2 * i + 1)] * odd_scaling_factor); - - scaling_factor = scaling_factor * fr(256 * 256); // Scale by (2^8 * 2^8). - } - - // Process the last byte if the byte array has an odd number of bytes. - if (num_bytes % 2 == 1) { - const field_t& last_byte = bytes[0]; - accumulator = accumulator + (scaling_factor * last_byte); + // Collect the bytes in reverse order and scale them appropriately. + std::vector> scaled_bytes; + scaled_bytes.reserve(num_bytes); + for (size_t i = 0; i < num_bytes; ++i) { + scaled_bytes.push_back(bytes[num_bytes - 1 - i] * scaling_factor); + scaling_factor = scaling_factor * fr(256); // Scale by 2^8. } + field_t accumulator = field_t::accumulate(scaled_bytes); - // Normalize the accumulator and set the witness index or additive constant. - accumulator = accumulator.normalize(); + // If the accumulator is constant, we set the additive constant. + // Otherwise, we set the witness index. if (accumulator.is_constant()) { additive_constant = uint256_t(accumulator.additive_constant); } else { @@ -116,28 +108,20 @@ uint::uint(Builder* parent_context, const std::vector accumulator(context, fr::zero()); field_t scaling_factor(context, fr::one()); const size_t num_wires = wires.size(); - for (size_t i = 0; i < (num_wires / 2); ++i) { - const field_t even_scaling_factor = scaling_factor; - const field_t odd_scaling_factor = scaling_factor * fr(2); - - accumulator = accumulator.add_two(field_t(wires[2 * i]) * even_scaling_factor, - field_t(wires[2 * i + 1]) * odd_scaling_factor); - - scaling_factor = scaling_factor * fr(4); // Scale by (2^1 * 2^1). - } - - // Process the last wire if the number of wires is odd. - if (num_wires % 2 == 1) { - const field_t& last_wire = field_t(wires[num_wires - 1]); - accumulator = accumulator + (scaling_factor * last_wire); + // Collect the bits and scale them appropriately. + std::vector> scaled_bits; + scaled_bits.reserve(num_wires); + for (size_t i = 0; i < num_wires; ++i) { + scaled_bits.push_back(field_t(wires[i]) * scaling_factor); + scaling_factor = scaling_factor * fr(2); // Scale by 2^1. } + field_t accumulator = field_t::accumulate(scaled_bits); - // Normalize the accumulator and set the witness index or additive constant. - accumulator = accumulator.normalize(); + // If the accumulator is constant, we set the additive constant. + // Otherwise, we set the witness index. if (accumulator.is_constant()) { additive_constant = uint256_t(accumulator.additive_constant); } else { From c5a2e0e134be47ccadfbcde48ad769a2e3c2278e Mon Sep 17 00:00:00 2001 From: suyash67 Date: Tue, 22 Jul 2025 10:30:31 +0000 Subject: [PATCH 6/6] fix type err. --- .../cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp index c1ca8945c150..534741ea3b28 100644 --- a/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp +++ b/barretenberg/cpp/src/barretenberg/stdlib/primitives/uint/logic.cpp @@ -79,7 +79,7 @@ uint uint::operator>>(const size_t shift) cons // Now, lets calculate: // (i) bit position (from lsb) at which we need to slice. // (ii) number of bits in the slice based on whether it is the highest slice or not. - const uint64_t slice_bit_position = shift % bits_per_limb; + const size_t slice_bit_position = shift % bits_per_limb; const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb; @@ -155,7 +155,7 @@ uint uint::operator<<(const size_t shift) cons // (i) bit position (from lsb) at which we need to slice. // (ii) number of bits in the slice based on whether it is the highest slice or not. const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); - const uint64_t slice_bit_position = adjusted_shift % bits_per_limb; + const size_t slice_bit_position = adjusted_shift % bits_per_limb; const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb; // We can now slice the accumulator. @@ -229,7 +229,7 @@ uint uint::ror(const size_t target_rotation) c // Now, lets calculate: // (i) bit position (from lsb) at which we need to slice. // (ii) number of bits in the slice based on whether it is the highest slice or not. - const uint64_t slice_bit_position = shift % bits_per_limb; + const size_t slice_bit_position = shift % bits_per_limb; const bool is_slice_hi = (accumulator_index == num_accumulators() - 1); const uint8_t num_bits_per_limb = is_slice_hi ? bits_in_high_limb : bits_per_limb;