Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@ TEST(ECDSASecp256r1, TestECDSAConstraintSucceed)
.blake2s_constraints = {},
.blake3_constraints = {},
.keccak_constraints = {},
.keccak_permutations = {},
.pedersen_constraints = {},
.pedersen_hash_constraints = {},
.poseidon2_constraints = {},
.multi_scalar_mul_constraints = {},
.ec_add_constraints = {},
.recursion_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ struct HashInput {
friend bool operator==(HashInput const& lhs, HashInput const& rhs) = default;
};

struct Keccakf1600 {
std::vector<uint32_t> state;
std::vector<uint32_t> result;

// For serialization, update with any new fields
MSGPACK_FIELDS(state, result);
friend bool operator==(Keccakf1600 const& lhs, Keccakf1600 const& rhs) = default;
};

struct Keccakf1600 {
std::array<uint32_t, 25> state;
std::array<uint32_t, 25> result;
Expand Down
63 changes: 62 additions & 1 deletion barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ struct BlackBoxFuncCall {
static Keccakf1600 bincodeDeserialize(std::vector<uint8_t>);
};

struct Keccakf1600 {
std::vector<Circuit::FunctionInput> inputs;
std::vector<Circuit::Witness> outputs;

friend bool operator==(const Keccakf1600&, const Keccakf1600&);
std::vector<uint8_t> bincodeSerialize() const;
static Keccakf1600 bincodeDeserialize(std::vector<uint8_t>);
};

struct RecursiveAggregation {
std::vector<Program::FunctionInput> verification_key;
std::vector<Program::FunctionInput> proof;
Expand Down Expand Up @@ -3359,6 +3368,58 @@ Program::BlackBoxFuncCall::Keccakf1600 serde::Deserializable<Program::BlackBoxFu

namespace Program {

inline bool operator==(const BlackBoxFuncCall::Keccakf1600& lhs, const BlackBoxFuncCall::Keccakf1600& rhs)
{
if (!(lhs.inputs == rhs.inputs)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Keccakf1600::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Keccakf1600>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Keccakf1600 BlackBoxFuncCall::Keccakf1600::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Keccakf1600>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxFuncCall::Keccakf1600>::serialize(
const Circuit::BlackBoxFuncCall::Keccakf1600& obj, Serializer& serializer)
{
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxFuncCall::Keccakf1600 serde::Deserializable<Circuit::BlackBoxFuncCall::Keccakf1600>::deserialize(
Deserializer& deserializer)
{
Circuit::BlackBoxFuncCall::Keccakf1600 obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::RecursiveAggregation& lhs,
const BlackBoxFuncCall::RecursiveAggregation& rhs)
{
Expand Down Expand Up @@ -3395,7 +3456,7 @@ inline BlackBoxFuncCall::RecursiveAggregation BlackBoxFuncCall::RecursiveAggrega
return value;
}

} // end of namespace Program
} // namespace Circuit

template <>
template <typename Serializer>
Expand Down
86 changes: 86 additions & 0 deletions barretenberg/cpp/src/barretenberg/stdlib/hash/keccak/keccak.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,92 @@ stdlib::byte_array<Builder> keccak<Builder>::hash_using_permutation_opcode(byte_
return result;
}

// Returns the keccak f1600 permutation of the input state
// We first convert the state into 'extended' representation, along with the 'twisted' state
// and then we call keccakf1600() with this keccak 'internal state'
// Finally, we convert back the state from the extented representation
template <typename Builder>
std::array<field_t<Builder>, keccak<Builder>::NUM_KECCAK_LANES> keccak<Builder>::permutation_opcode(
std::array<field_t<Builder>, NUM_KECCAK_LANES> state, Builder* ctx)
{
std::vector<field_t<Builder>> converted_buffer(NUM_KECCAK_LANES);
std::vector<field_t<Builder>> msb_buffer(NUM_KECCAK_LANES);
// populate keccak_state, convert our 64-bit lanes into an extended base-11 representation
keccak_state internal;
internal.context = ctx;
for (size_t i = 0; i < state.size(); ++i) {
const auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_FORMAT_INPUT, state[i]);
internal.state[i] = accumulators[ColumnIdx::C2][0];
internal.state_msb[i] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
}
compute_twisted_state(internal);
keccakf1600(internal);
// we convert back to the normal lanes
return extended_2_normal(internal);
}

// This function is similar to sponge_absorb()
// but it uses permutation_opcode() instead of calling directly keccakf1600().
// As a result, this function is less efficient and should only be used to test permutation_opcode()
template <typename Builder>
void keccak<Builder>::sponge_absorb_with_permutation_opcode(keccak_state& internal,
std::vector<field_t<Builder>>& input_buffer,
const size_t input_size)
{
// populate keccak_state
const size_t num_blocks = input_size / (BLOCK_SIZE / 8);
for (size_t i = 0; i < num_blocks; ++i) {
if (i == 0) {
for (size_t j = 0; j < LIMBS_PER_BLOCK; ++j) {
internal.state[j] = input_buffer[j];
}
for (size_t j = LIMBS_PER_BLOCK; j < NUM_KECCAK_LANES; ++j) {
internal.state[j] = witness_ct::create_constant_witness(internal.context, 0);
}
} else {
for (size_t j = 0; j < LIMBS_PER_BLOCK; ++j) {
internal.state[j] = stdlib::logic<Builder>::create_logic_constraint(
internal.state[j], input_buffer[i * LIMBS_PER_BLOCK + j], 64, true);
}
}
internal.state = permutation_opcode(internal.state, internal.context);
}
}

// This function computes the keccak hash, like the hash() function
// but it uses permutation_opcode() instead of calling directly keccakf1600().
// As a result, this function is less efficient and should only be used to test permutation_opcode()
template <typename Builder>
stdlib::byte_array<Builder> keccak<Builder>::hash_using_permutation_opcode(byte_array_ct& input,
const uint32_ct& num_bytes)
{
auto ctx = input.get_context();

ASSERT(uint256_t(num_bytes.get_value()) == input.size());

if (ctx == nullptr) {
// if buffer is constant compute hash and return w/o creating constraints
byte_array_ct output(nullptr, 32);
const std::vector<uint8_t> result = hash_native(input.get_value());
for (size_t i = 0; i < 32; ++i) {
output.set_byte(i, result[i]);
}
return output;
}

// convert the input byte array into 64-bit keccak lanes (+ apply padding)
auto formatted_slices = format_input_lanes(input, num_bytes);

keccak_state internal;
internal.context = ctx;
uint32_ct num_blocks_with_data = (num_bytes + BLOCK_SIZE) / BLOCK_SIZE;
sponge_absorb_with_permutation_opcode(internal, formatted_slices, formatted_slices.size());

auto result = sponge_squeeze_for_permutation_opcode(internal.state, ctx);

return result;
}

template <typename Builder>
stdlib::byte_array<Builder> keccak<Builder>::hash(byte_array_ct& input, const uint32_ct& num_bytes)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,46 @@ TEST(stdlib_keccak, test_permutation_opcode_double_block)
bool proof_result = CircuitChecker::check(builder);
EXPECT_EQ(proof_result, true);
}

TEST(stdlib_keccak, test_permutation_opcode_single_block)
{
Builder builder = Builder();
std::string input = "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz01";
std::vector<uint8_t> input_v(input.begin(), input.end());

byte_array input_arr(&builder, input_v);
byte_array output =
stdlib::keccak<Builder>::hash_using_permutation_opcode(input_arr, static_cast<uint32_t>(input.size()));

std::vector<uint8_t> expected = stdlib::keccak<Builder>::hash_native(input_v);

EXPECT_EQ(output.get_value(), expected);

builder.print_num_gates();

bool proof_result = builder.check_circuit();
EXPECT_EQ(proof_result, true);
}

TEST(stdlib_keccak, test_permutation_opcode_double_block)
{
Builder builder = Builder();
std::string input = "";
for (size_t i = 0; i < 200; ++i) {
input += "a";
}
std::vector<uint8_t> input_v(input.begin(), input.end());

byte_array input_arr(&builder, input_v);
byte_array output =
stdlib::keccak<Builder>::hash_using_permutation_opcode(input_arr, static_cast<uint32_t>(input.size()));

std::vector<uint8_t> expected = stdlib::keccak<Builder>::hash_native(input_v);

EXPECT_EQ(output.get_value(), expected);

builder.print_num_gates();

bool proof_result = builder.check_circuit();
EXPECT_EQ(proof_result, true);
}
50 changes: 50 additions & 0 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ namespace Program {
static Keccakf1600 bincodeDeserialize(std::vector<uint8_t>);
};

struct Keccakf1600 {
std::vector<Circuit::FunctionInput> inputs;
std::vector<Circuit::Witness> outputs;

friend bool operator==(const Keccakf1600&, const Keccakf1600&);
std::vector<uint8_t> bincodeSerialize() const;
static Keccakf1600 bincodeDeserialize(std::vector<uint8_t>);
};

struct RecursiveAggregation {
std::vector<Program::FunctionInput> verification_key;
std::vector<Program::FunctionInput> proof;
Expand Down Expand Up @@ -2928,6 +2937,47 @@ Program::BlackBoxFuncCall::Keccakf1600 serde::Deserializable<Program::BlackBoxFu

namespace Program {

inline bool operator==(const BlackBoxFuncCall::Keccakf1600 &lhs, const BlackBoxFuncCall::Keccakf1600 &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.outputs == rhs.outputs)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxFuncCall::Keccakf1600::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxFuncCall::Keccakf1600>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxFuncCall::Keccakf1600 BlackBoxFuncCall::Keccakf1600::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxFuncCall::Keccakf1600>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BlackBoxFuncCall::Keccakf1600>::serialize(const Circuit::BlackBoxFuncCall::Keccakf1600 &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.outputs)>::serialize(obj.outputs, serializer);
}

template <>
template <typename Deserializer>
Circuit::BlackBoxFuncCall::Keccakf1600 serde::Deserializable<Circuit::BlackBoxFuncCall::Keccakf1600>::deserialize(Deserializer &deserializer) {
Circuit::BlackBoxFuncCall::Keccakf1600 obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.outputs = serde::Deserializable<decltype(obj.outputs)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const BlackBoxFuncCall::RecursiveAggregation &lhs, const BlackBoxFuncCall::RecursiveAggregation &rhs) {
if (!(lhs.verification_key == rhs.verification_key)) { return false; }
if (!(lhs.proof == rhs.proof)) { return false; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ pub enum BlackBoxFuncCall {
inputs: Box<[FunctionInput; 25]>,
outputs: Box<[Witness; 25]>,
},
Keccakf1600 {
inputs: Vec<FunctionInput>,
outputs: Vec<Witness>,
},
RecursiveAggregation {
verification_key: Vec<FunctionInput>,
proof: Vec<FunctionInput>,
Expand Down
Loading