diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp index d8c2c666a6c3..e2448004deb9 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp @@ -25,8 +25,9 @@ std::pair, Circuit> unique_witness_ext(CircuitSchema& circuit_in const std::vector& equal_at_the_same_time, const std::vector& not_equal_at_the_same_time) { - Circuit c1(circuit_info, s, "circuit1"); - Circuit c2(circuit_info, s, "circuit2"); + // TODO(alex): set optimizations to be true once they are confirmed + Circuit c1(circuit_info, s, "circuit1", false); + Circuit c2(circuit_info, s, "circuit2", false); for (const auto& term : equal) { c1[term] == c2[term]; @@ -95,8 +96,9 @@ std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, Solver* s, const std::vector& equal) { - Circuit c1(circuit_info, s, "circuit1"); - Circuit c2(circuit_info, s, "circuit2"); + // TODO(alex): set optimizations to be true once they are confirmed + Circuit c1(circuit_info, s, "circuit1", false); + Circuit c2(circuit_info, s, "circuit2", false); for (const auto& term : equal) { c1[term] == c2[term]; @@ -108,6 +110,9 @@ std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, if (std::find(equal.begin(), equal.end(), std::string(c1.variable_names[i])) != equal.end()) { continue; } + if (c1.optimized[i]) { + continue; + } Bool tmp = Bool(c1[i]) != Bool(c2[i]); neqs.push_back(tmp); } diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp index d8f7f5f3777b..d2bce23aab64 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp @@ -16,6 +16,8 @@ using namespace smt_terms; using namespace smt_circuit_schema; using namespace smt_subcircuits; +enum class SubcircuitType { XOR, AND, RANGE }; + /** * @brief Symbolic Circuit class. * @@ -29,7 +31,9 @@ template class Circuit { void init(); size_t prepare_gates(size_t cursor); - void univariate_handler(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w); + void handle_univariate_constraint(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w); + size_t handle_logic_constraint(size_t cursor); + size_t handle_range_constraint(size_t cursor); public: std::vector variables; // circuit witness @@ -40,13 +44,22 @@ template class Circuit { std::vector> wires_idxs; // values of the gates' wires std::unordered_map symbolic_vars; // all the symbolic variables from the circuit std::vector real_variable_index; // indexes for assert_equal'd wires + std::unordered_map optimized; // keeps track of the variables that were excluded from symbolic + // circuit during optimizations + bool optimizations; // flags to turn on circuit optimizations + std::unordered_map> + cached_subcircuits; // caches subcircuits during optimization + // No need to recompute them each time Solver* solver; // pointer to the solver std::string tag; // tag of the symbolic circuit. // If not empty, will be added to the names // of symbolic variables to prevent collisions. - explicit Circuit(CircuitSchema& circuit_info, Solver* solver, const std::string& tag = ""); + explicit Circuit(CircuitSchema& circuit_info, + Solver* solver, + const std::string& tag = "", + bool optimizations = true); FF operator[](const std::string& name); FF operator[](const uint32_t& idx) { return symbolic_vars[this->real_variable_index[idx]]; }; @@ -65,13 +78,14 @@ template class Circuit { * @param tag tag of the circuit. Empty by default. */ template -Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, const std::string& tag) +Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, const std::string& tag, bool optimizations) : variables(circuit_info.variables) , public_inps(circuit_info.public_inps) , variable_names(circuit_info.vars_of_interest) , selectors(circuit_info.selectors) , wires_idxs(circuit_info.wires) , real_variable_index(circuit_info.real_variable_index) + , optimizations(optimizations) , solver(solver) , tag(tag) { @@ -89,6 +103,8 @@ Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, const std::str variable_names.insert({ 1, "one" }); variable_names_inverse.insert({ "zero", 0 }); variable_names_inverse.insert({ "one", 1 }); + optimized.insert({ 0, false }); + optimized.insert({ 1, false }); this->init(); @@ -98,6 +114,12 @@ Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, const std::str while (i < this->get_num_gates()) { i = this->prepare_gates(i); } + + for (auto& opt : optimized) { + if (opt.second) { + this->symbolic_vars[opt.first] == 0; + } + } } /** @@ -124,6 +146,7 @@ template void Circuit::init() } else { symbolic_vars.insert({ real_idx, FF::Var("var_" + std::to_string(i) + this->tag, this->solver) }); } + optimized.insert({ real_idx, true }); } symbolic_vars[0] == bb::fr(0); @@ -143,7 +166,7 @@ template void Circuit::init() * @param w witness index */ template -void Circuit::univariate_handler(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w) +void Circuit::handle_univariate_constraint(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w) { bb::fr b = q_1 + q_2 + q_3; @@ -168,6 +191,244 @@ void Circuit::univariate_handler(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr } } +/** + * @brief Relaxes logic constraints(AND/XOR). + * @details This function is needed when we use bitwise compatible + * symbolic terms. + * It compares the chunk of selectors of the current circuit + * with pure create_logic_constraint from circuit_builder. + * It uses binary search to find a bit length of the constraint, + * since we don't know it in general. + * After a match is found, it updates the cursor to skip all the + * redundant constraints and adds a pure a ^ b = c or a & b = c + * constraint to solver. + * If there's no match, it will return -1 + * + * @param cursor current position + * @return next position or -1 + */ +template size_t Circuit::handle_logic_constraint(size_t cursor) +{ + // Initialize binary search. Logic gate can only accept even bit lengths + // So we need to find a match among [1, 127] and then multiply the result by 2 + size_t beg = 1; + size_t end = 127; + size_t mid = 0; + auto res = static_cast(-1); + + // Indicates that current bit length is a match for XOR + bool xor_flag = true; + // Indicates that current bit length is a match for AND + bool and_flag = true; + // Indicates the logic operation(true - XOR, false - AND) if the match is found. + bool logic_flag = true; + CircuitProps xor_props; + CircuitProps and_props; + + bool stop_flag = false; + + while (beg <= end) { + mid = (end + beg) / 2; + + // Take a pure logic circuit for the current bit length(2 * mid) + // and compare it's selectors to selectors of the global circuit + // at current position(cursor). + // If they are equal, we can apply an optimization + // However, if we have a match at bit length 2, it is possible + // to have a match at higher bit lengths. That's why we store + // the current match as `res` and proceed with ordinary binary search. + // `stop_flag` simply indicates that the first selector doesn't match + // and we can skip this whole section. + + if (!this->cached_subcircuits[SubcircuitType::XOR].contains(mid * 2)) { + this->cached_subcircuits[SubcircuitType::XOR].insert( + { mid * 2, get_standard_logic_circuit(mid * 2, true) }); + } + xor_props = this->cached_subcircuits[SubcircuitType::XOR][mid * 2]; + + if (!this->cached_subcircuits[SubcircuitType::AND].contains(mid * 2)) { + this->cached_subcircuits[SubcircuitType::AND].insert( + { mid * 2, get_standard_logic_circuit(mid * 2, false) }); + } + and_props = this->cached_subcircuits[SubcircuitType::AND][mid * 2]; + + CircuitSchema xor_circuit = xor_props.circuit; + CircuitSchema and_circuit = and_props.circuit; + + xor_flag = cursor + xor_props.num_gates <= this->selectors.size(); + and_flag = cursor + xor_props.num_gates <= this->selectors.size(); + if (xor_flag || and_flag) { + for (size_t j = 0; j < xor_props.num_gates; j++) { + // It is possible for gates to be equal but wires to be not, but I think it's very + // unlikely to happen + xor_flag &= xor_circuit.selectors[j + xor_props.start_gate] == this->selectors[cursor + j]; + and_flag &= and_circuit.selectors[j + and_props.start_gate] == this->selectors[cursor + j]; + + if (!xor_flag && !and_flag) { + // Won't match at any bit length + if (j == 0) { + stop_flag = true; + } + break; + } + } + } + if (stop_flag) { + break; + } + + if (!xor_flag && !and_flag) { + end = mid - 1; + } else { + res = 2 * mid; + logic_flag = xor_flag; + + beg = mid + 1; + } + } + + // TODO(alex): Figure out if I need to create range constraint here too or it'll be + // created anyway in any circuit + if (res != static_cast(-1)) { + xor_props = get_standard_logic_circuit(res, true); + and_props = get_standard_logic_circuit(res, false); + + info("Logic constraint optimization: ", std::to_string(res), " bits. is_xor: ", xor_flag); + size_t left_gate = xor_props.gate_idxs[0]; + uint32_t left_gate_idx = xor_props.idxs[0]; + size_t right_gate = xor_props.gate_idxs[1]; + uint32_t right_gate_idx = xor_props.idxs[1]; + size_t out_gate = xor_props.gate_idxs[2]; + uint32_t out_gate_idx = xor_props.idxs[2]; + + uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; + uint32_t right_idx = this->real_variable_index[this->wires_idxs[cursor + right_gate][right_gate_idx]]; + uint32_t out_idx = this->real_variable_index[this->wires_idxs[cursor + out_gate][out_gate_idx]]; + + FF left = this->symbolic_vars[left_idx]; + FF right = this->symbolic_vars[right_idx]; + FF out = this->symbolic_vars[out_idx]; + + if (logic_flag) { + (left ^ right) == out; + } else { + (left ^ right) == out; // TODO(alex): implement & method + } + + // You have to mark these arguments so they won't be optimized out + optimized[left_idx] = false; + optimized[right_idx] = false; + optimized[out_idx] = false; + return cursor + xor_props.num_gates; + } + return res; +} + +/** + * @brief Relaxes range constraints. + * @details This function is needed when we use range compatible + * symbolic terms. + * It compares the chunk of selectors of the current circuit + * with pure create_range_constraint from circuit_builder. + * It uses binary search to find a bit length of the constraint, + * since we don't know it in general. + * After match is found, it updates the cursor to skip all the + * redundant constraints and adds a pure a < 2^bit_length + * constraint to solver. + * If there's no match, it will return -1 + * + * @param cursor current position + * @return next position or -1 + */ +template size_t Circuit::handle_range_constraint(size_t cursor) +{ + // Indicates that current bit length is a match + bool range_flag = true; + size_t mid = 0; + auto res = static_cast(-1); + + CircuitProps range_props; + // Range constraints differ depending on oddness of bit_length + // That's why we need to handle these cases separately + for (size_t odd = 0; odd < 2; odd++) { + // Initialize binary search. + // We need to find a match among [1, 127] and then set the result to 2 * mid, or 2 * mid + 1 + size_t beg = 1; + size_t end = 127; + + bool stop_flag = false; + while (beg <= end) { + mid = (end + beg) / 2; + + // Take a pure logic circuit for the current bit length(2 * mid + odd) + // and compare it's selectors to selectors of the global circuit + // at current positin(cursor). + // If they are equal, we can apply an optimization + // However, if we have a match at bit length 2, it is possible + // to have a match at higher bit lengths. That's why we store + // the current match as `res` and proceed with ordinary binary search. + // `stop_flag` simply indicates that the first selector doesn't match + // and we can skip this whole section. + + if (!this->cached_subcircuits[SubcircuitType::RANGE].contains(2 * mid + odd)) { + this->cached_subcircuits[SubcircuitType::RANGE].insert( + { 2 * mid + odd, get_standard_range_constraint_circuit(2 * mid + odd) }); + } + range_props = this->cached_subcircuits[SubcircuitType::RANGE][2 * mid + odd]; + CircuitSchema range_circuit = range_props.circuit; + + range_flag = cursor + range_props.num_gates <= this->get_num_gates(); + if (range_flag) { + for (size_t j = 0; j < range_props.num_gates; j++) { + // It is possible for gates to be equal but wires to be not, but I think it's very + // unlikely to happen + range_flag &= range_circuit.selectors[j + range_props.start_gate] == this->selectors[cursor + j]; + + if (!range_flag) { + // Won't match at any bit length + if (j <= 2) { + stop_flag = true; + } + break; + } + } + } + if (stop_flag) { + break; + } + + if (!range_flag) { + end = mid - 1; + } else { + res = 2 * mid + odd; + beg = mid + 1; + } + } + + if (res != static_cast(-1)) { + range_flag = true; + break; + } + } + + if (range_flag) { + info("Range constraint optimization: ", std::to_string(res), " bits"); + range_props = get_standard_range_constraint_circuit(res); + + size_t left_gate = range_props.gate_idxs[0]; + uint32_t left_gate_idx = range_props.idxs[0]; + uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; + + FF left = this->symbolic_vars[left_idx]; + left < bb::fr(2).pow(res); + + // You have to mark these arguments so they won't be optimized out + optimized[left_idx] = false; + return cursor + range_props.num_gates; + } + return res; +} + /** * @brief Adds all the gate constraints to the solver. * Relaxes constraint system for non-ff solver engines @@ -176,8 +437,20 @@ void Circuit::univariate_handler(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr */ template size_t Circuit::prepare_gates(size_t cursor) { - // TODO(alex): Here'll be the operator relaxation that is coming - // in the next pr + // TODO(alex): implement bitvector class and compute offsets + if (FF::isBitVector() && this->optimizations) { + size_t res = handle_logic_constraint(cursor); + if (res != static_cast(-1)) { + return res; + } + } + + if ((FF::isBitVector() || FF::isInteger()) && this->optimizations) { + size_t res = handle_range_constraint(cursor); + if (res != static_cast(-1)) { + return res; + } + } bb::fr q_m = this->selectors[cursor][0]; bb::fr q_1 = this->selectors[cursor][1]; @@ -188,30 +461,37 @@ template size_t Circuit::prepare_gates(size_t cursor) uint32_t w_l = this->wires_idxs[cursor][0]; uint32_t w_r = this->wires_idxs[cursor][1]; uint32_t w_o = this->wires_idxs[cursor][2]; - - bool univariate_flag = w_l == w_r && w_r == w_o; - univariate_flag |= w_l == w_r && q_3 == 0; - univariate_flag |= w_l == w_o && q_2 == 0 && q_m == 0; - univariate_flag |= w_r == w_o && q_1 == 0 && q_m == 0; - univariate_flag |= q_m == 0 && q_1 == 0 && q_3 == 0; - univariate_flag |= q_m == 0 && q_2 == 0 && q_3 == 0; - univariate_flag |= q_m == 0 && q_1 == 0 && q_2 == 0; + optimized[w_l] = false; + optimized[w_r] = false; + optimized[w_o] = false; + + // Handles the case when we have univariate polynomial as constraint + // by simply finding the roots via quadratic formula(or linear) + // There're 7 possibilities of that, which are present below + bool univariate_flag = true; + univariate_flag |= (w_l == w_r) && (w_r == w_o); + univariate_flag |= (w_l == w_r) && (q_3 == 0); + univariate_flag |= (w_l == w_o) && (q_2 == 0) && (q_m == 0); + univariate_flag |= (w_r == w_o) && (q_1 == 0) && (q_m == 0); + univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_3 == 0); + univariate_flag |= (q_m == 0) && (q_2 == 0) && (q_3 == 0); + univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_2 == 0); // Univariate gate. Relaxes the solver. Or is it? // TODO(alex): Test the effect of this relaxation after the tests are merged. if (univariate_flag) { - if (q_m == 1 && q_1 == 0 && q_2 == 0 && q_3 == -1 && q_c == 0) { + if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) { (Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[0]) | Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[1])) .assert_term(); } else { - this->univariate_handler(q_m, q_1, q_2, q_3, q_c, w_l); + this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l); } } else { FF eq = symbolic_vars[0]; // mul selector if (q_m != 0) { - eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; // TODO(alex): Is there a way to do lmul? + eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; } // left selector if (q_1 != 0) { diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp new file mode 100644 index 000000000000..2660f8628b6b --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp @@ -0,0 +1,114 @@ +#include +#include +#include + +#include "barretenberg/proof_system/circuit_builder/standard_circuit_builder.hpp" +#include "barretenberg/smt_verification/circuit/circuit.hpp" +#include "barretenberg/smt_verification/util/smt_util.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +#include + +using namespace bb; +using namespace smt_circuit; + +namespace { +auto& engine = numeric::get_debug_randomness(); +} + +using field_t = stdlib::field_t; +using witness_t = stdlib::witness_t; +using pub_witness_t = stdlib::public_witness_t; + +TEST(circuit, assert_equal) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + + field_t a(witness_t(&builder, fr::random_element())); + field_t b(witness_t(&builder, fr::random_element())); + builder.set_variable_name(a.witness_index, "a"); + builder.set_variable_name(b.witness_index, "b"); + field_t c = (a + a) / (b + b + b); + builder.set_variable_name(c.witness_index, "c"); + + field_t d(witness_t(&builder, a.get_value())); + field_t e(witness_t(&builder, b.get_value())); + field_t f(witness_t(&builder, c.get_value())); + builder.assert_equal(d.get_witness_index(), a.get_witness_index()); + builder.assert_equal(e.get_witness_index(), b.get_witness_index()); + + field_t g = d + d; + field_t h = e + e + e; + field_t i = g / h; + builder.assert_equal(i.get_witness_index(), c.get_witness_index()); + field_t j(witness_t(&builder, i.get_value())); + field_t k(witness_t(&builder, j.get_value())); + builder.assert_equal(i.get_witness_index(), j.get_witness_index()); + builder.assert_equal(i.get_witness_index(), k.get_witness_index()); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus); + Circuit circuit(circuit_info, &s); + + ASSERT_EQ(circuit[k.get_witness_index()].term, circuit["c"].term); + ASSERT_EQ(circuit[d.get_witness_index()].term, circuit["a"].term); + ASSERT_EQ(circuit[e.get_witness_index()].term, circuit["b"].term); + + ASSERT_EQ(circuit[i.get_witness_index()].term, circuit[k.get_witness_index()].term); + ASSERT_EQ(circuit[i.get_witness_index()].term, circuit[j.get_witness_index()].term); +} + +TEST(circuit, range_relaxation_assertions) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + field_t a(witness_t(&builder, fr(120))); + a.create_range_constraint(10); + + field_t b(witness_t(&builder, fr(65567))); + field_t c = a * b; + c.create_range_constraint(27); + builder.set_variable_name(a.get_witness_index(), "a"); + builder.set_variable_name(b.get_witness_index(), "b"); + builder.set_variable_name(c.get_witness_index(), "c"); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus); + Circuit circuit(circuit_info, &s); + + s.print_assertions(); +} + +TEST(circuit, range_relaxation) +{ + for (size_t i = 2; i < 256; i++) { + StandardCircuitBuilder builder = StandardCircuitBuilder(); + field_t a(witness_t(&builder, fr::zero())); + a.create_range_constraint(i); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus); + Circuit circuit(circuit_info, &s); + } +} + +TEST(circuit, cached_subcircuits) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + field_t a(witness_t(&builder, fr::zero())); + builder.set_variable_name(a.get_witness_index(), "a"); + a.create_range_constraint(5); + field_t b(witness_t(&builder, fr::zero())); + b.create_range_constraint(5); + builder.set_variable_name(b.get_witness_index(), "b"); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus); + Circuit circuit(circuit_info, &s); + s.print_assertions(); +} + +// TODO(alex): check xor relaxations after bivector is here \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp index 930ff36eaff8..0111614115a1 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit_schema.hpp @@ -16,10 +16,10 @@ namespace smt_circuit_schema { * * @param modulus Modulus of the field we are working with * @param public_inps Public inputs to the current circuit - * @param vars_of_interes Map wires indicies to their given names + * @param vars_of_interest Map wires indices to their given names * @param variables List of wires values in the current circuit * @param selectors List of selectors in the current circuit - * @param wires List of wires indicies for each selector + * @param wires List of wires indices for each selector * @param real_variable_index Encoded copy constraints */ struct CircuitSchema { @@ -36,4 +36,4 @@ struct CircuitSchema { CircuitSchema unpack_from_buffer(const msgpack::sbuffer& buf); CircuitSchema unpack_from_file(const std::string& filename); void print_schema_for_use_in_python(CircuitSchema& cir); -} // namespace smt_circuit_schema \ No newline at end of file +} // namespace smt_circuit_schema diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.cpp index 2870a623a536..792f2d2d8bbb 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.cpp @@ -2,23 +2,57 @@ namespace smt_subcircuits { -CircuitSchema get_standard_range_constraint_circuit(size_t n) +CircuitProps get_standard_range_constraint_circuit(size_t n) { bb::StandardCircuitBuilder builder = bb::StandardCircuitBuilder(); - uint32_t a_idx = builder.add_variable(bb::fr::random_element()); + uint32_t a_idx = builder.add_variable(bb::fr(0)); builder.set_variable_name(a_idx, "a"); - builder.create_range_constraint(a_idx, n); - return unpack_from_buffer(builder.export_circuit()); + + size_t start_gate = builder.get_num_gates(); + builder.decompose_into_base4_accumulators(a_idx, n); + size_t num_gates = builder.get_num_gates() - start_gate; + + CircuitSchema exported = unpack_from_buffer(builder.export_circuit()); + + // relative offstes in the circuit are calculated manually, according to decompose_into_base4_accumulators method + // lhs position in the gate + uint32_t lhs_position = 2; + // number of the gate that contains lhs + size_t gate_number = num_gates - 1; + + return { exported, start_gate, num_gates, { lhs_position }, { gate_number } }; } -CircuitSchema get_standard_logic_circuit(size_t n, bool is_xor) +CircuitProps get_standard_logic_circuit(size_t n, bool is_xor) { bb::StandardCircuitBuilder builder = bb::StandardCircuitBuilder(); - uint32_t a_idx = builder.add_variable(bb::fr::random_element()); - uint32_t b_idx = builder.add_variable(bb::fr::random_element()); + uint32_t a_idx = builder.add_variable(bb::fr(0)); + uint32_t b_idx = builder.add_variable(bb::fr(0)); builder.set_variable_name(a_idx, "a"); builder.set_variable_name(b_idx, "b"); - builder.create_logic_constraint(a_idx, b_idx, n, is_xor); - return unpack_from_buffer(builder.export_circuit()); + + size_t start_gate = builder.get_num_gates(); + auto acc = builder.create_logic_constraint(a_idx, b_idx, n, is_xor); + size_t num_gates = builder.get_num_gates() - start_gate; + + builder.set_variable_name(acc.out.back(), "c"); + + CircuitSchema exported = unpack_from_buffer(builder.export_circuit()); + + // relative offstes in the circuit are calculated manually, according to create_logic_constraint method + // lhs, rhs and out positions in the corresponding gates + uint32_t lhs_position = 2; + uint32_t rhs_position = 2; + uint32_t out_position = 2; + // numbers of the gates that contain lhs, rhs and out + size_t lhs_gate_number = num_gates - 3; + size_t rhs_gate_number = num_gates - 2; + size_t out_gate_number = num_gates - 1; + + return { exported, + start_gate, + num_gates, + { lhs_position, rhs_position, out_position }, + { lhs_gate_number, rhs_gate_number, out_gate_number } }; } } // namespace smt_subcircuits \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.hpp index 4e49113887a9..53d6a9aeb85b 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.hpp @@ -6,6 +6,23 @@ namespace smt_subcircuits { using namespace smt_circuit_schema; -CircuitSchema get_standard_range_constraint_circuit(size_t n); -CircuitSchema get_standard_logic_circuit(size_t n, bool is_xor); -} // namespace smt_subcircuits \ No newline at end of file +/** + * @brief Circuit stats to identify subcircuit + * + * @param circuit Schema of the whole subcircuit + * @param start_gate Start of the needed subcircuit + * @param num_gates The number of gates in the needed subcircuit + * @param idxs Indices of the needed variables to calculate offset + * @param gate_idxs Indices of the gates that use needed variables + */ +struct CircuitProps { + CircuitSchema circuit; + size_t start_gate; + size_t num_gates; + std::vector idxs; + std::vector gate_idxs; +}; + +CircuitProps get_standard_range_constraint_circuit(size_t n); +CircuitProps get_standard_logic_circuit(size_t n, bool is_xor); +} // namespace smt_subcircuits diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.test.cpp new file mode 100644 index 000000000000..e46800aa39b6 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/subcircuits.test.cpp @@ -0,0 +1,57 @@ +#include +#include + +#include "barretenberg/proof_system/circuit_builder/standard_circuit_builder.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" + +#include "barretenberg/smt_verification/circuit/subcircuits.hpp" + +#include + +using namespace bb; + +namespace { +auto& engine = numeric::get_debug_randomness(); +} + +// Check that all the relative offsets are calculated correctly. +// I.e. I can find an operand at the index, given by get_standard_range_constraint_circuit +TEST(subcircuits, range_circuit) +{ + for (size_t i = 1; i < 256; i++) { + smt_subcircuits::CircuitProps range_props = smt_subcircuits::get_standard_range_constraint_circuit(i); + smt_circuit_schema::CircuitSchema circuit = range_props.circuit; + + size_t a_gate = range_props.gate_idxs[0]; + uint32_t a_gate_idx = range_props.idxs[0]; + size_t start_gate = range_props.start_gate; + + ASSERT_EQ( + "a", circuit.vars_of_interest[circuit.real_variable_index[circuit.wires[start_gate + a_gate][a_gate_idx]]]); + } +} +// Check that all the relative offsets are calculated correctly. +// I.e. I can find all three operands at the indices, given by get_standard_logic_circuit +TEST(subcircuits, logic_circuit) +{ + for (size_t i = 2; i < 256; i += 2) { + smt_subcircuits::CircuitProps logic_props = smt_subcircuits::get_standard_logic_circuit(i, true); + smt_circuit_schema::CircuitSchema circuit = logic_props.circuit; + + size_t a_gate = logic_props.gate_idxs[0]; + uint32_t a_gate_idx = logic_props.idxs[0]; + size_t start_gate = logic_props.start_gate; + ASSERT_EQ( + "a", circuit.vars_of_interest[circuit.real_variable_index[circuit.wires[start_gate + a_gate][a_gate_idx]]]); + + size_t b_gate = logic_props.gate_idxs[1]; + uint32_t b_gate_idx = logic_props.idxs[1]; + ASSERT_EQ( + "b", circuit.vars_of_interest[circuit.real_variable_index[circuit.wires[start_gate + b_gate][b_gate_idx]]]); + + size_t c_gate = logic_props.gate_idxs[2]; + uint32_t c_gate_idx = logic_props.idxs[2]; + ASSERT_EQ( + "c", circuit.vars_of_interest[circuit.real_variable_index[circuit.wires[start_gate + c_gate][c_gate_idx]]]); + } +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp index b8a284180233..1ac73a89e90c 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp @@ -19,7 +19,7 @@ using field_t = stdlib::field_t; using witness_t = stdlib::witness_t; using pub_witness_t = stdlib::public_witness_t; -TEST(circuit_verification, multiplication_true) +TEST(SMT_Example, multiplication_true) { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -50,7 +50,7 @@ TEST(circuit_verification, multiplication_true) ASSERT_FALSE(res); } -TEST(circuit_verification, multiplication_true_kind) +TEST(SMT_Example, multiplication_true_kind) { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -81,7 +81,7 @@ TEST(circuit_verification, multiplication_true_kind) ASSERT_FALSE(res); } -TEST(circuit_verification, multiplication_false) +TEST(SMT_Example, multiplication_false) { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -123,7 +123,7 @@ TEST(circuit_verification, multiplication_false) info("c_res = ", vals["cr"]); } -TEST(circuit_verifiaction, unique_witness) +TEST(SMT_Example, unique_witness_ext) // two roots of a quadratic eq x^2 + a * x + b = s { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -153,29 +153,66 @@ TEST(circuit_verifiaction, unique_witness) ASSERT_NE(vals["z_c1"], vals["z_c2"]); } -using namespace smt_terms; - -TEST(solver_use_case, solver) +// Make sure that quadratic polynomial evaluation doesn't have unique +// witness. +// Finds both roots of a quadratic eq x^2 + a * x + b = s +TEST(SMT_Example, unique_witness) { - Solver s("11", default_solver_config, 10); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); + StandardCircuitBuilder builder = StandardCircuitBuilder(); + + field_t a(pub_witness_t(&builder, fr::random_element())); + field_t b(pub_witness_t(&builder, fr::random_element())); + builder.set_variable_name(a.witness_index, "a"); + builder.set_variable_name(b.witness_index, "b"); + field_t z(witness_t(&builder, fr::random_element())); + field_t ev = z * z + a * z + b; + builder.set_variable_name(z.witness_index, "z"); + builder.set_variable_name(ev.witness_index, "ev"); - FFTerm z = x * y + x * x; - z == FFTerm::Const("15", &s, 10); - x != y; - x != FFTerm::Const("0", &s); - y != FFTerm::Const("0", &s); + auto buf = builder.export_circuit(); + + smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); + smt_solver::Solver s(circuit_info.modulus); + + std::pair, smt_circuit::Circuit> cirs = + smt_circuit::unique_witness(circuit_info, &s, { "ev" }); bool res = s.check(); ASSERT_TRUE(res); - std::unordered_map vars = { { "x", x }, { "y", y } }; - std::unordered_map mvars = s.model(vars); - - info("x = ", mvars["x"]); - info("y = ", mvars["y"]); + std::unordered_map terms = { { "z_c1", cirs.first["z"] }, { "z_c2", cirs.second["z"] } }; + std::unordered_map vals = s.model(terms); + ASSERT_NE(vals["z_c1"], vals["z_c2"]); } -// TODO(alex): Try setting the whole witness to be not equal at the same time, while setting inputs and outputs to be -// equal \ No newline at end of file +// Make sure that quadratic polynomial evaluation doesn't have unique +// witness. Also coefficients are private. +// Finds both roots of a quadratic eq x^2 + a * x + b = s +TEST(SMT_Example, unique_witness_private_coefficients) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + + field_t a(witness_t(&builder, fr::random_element())); + field_t b(witness_t(&builder, fr::random_element())); + builder.set_variable_name(a.witness_index, "a"); + builder.set_variable_name(b.witness_index, "b"); + field_t z(witness_t(&builder, fr::random_element())); + field_t ev = z * z + a * z + b; + builder.set_variable_name(z.witness_index, "z"); + builder.set_variable_name(ev.witness_index, "ev"); + + auto buf = builder.export_circuit(); + + smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); + smt_solver::Solver s(circuit_info.modulus); + + std::pair, smt_circuit::Circuit> cirs = + smt_circuit::unique_witness(circuit_info, &s, { "ev", "a", "b" }); + + bool res = s.check(); + ASSERT_TRUE(res); + + std::unordered_map terms = { { "z_c1", cirs.first["z"] }, { "z_c2", cirs.second["z"] } }; + std::unordered_map vals = s.model(terms); + ASSERT_NE(vals["z_c1"], vals["z_c2"]); +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp index e37310e3ea67..6b7a53a4531c 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp @@ -20,6 +20,7 @@ class FFITerm { static bool isFiniteField() { return false; }; static bool isInteger() { return true; }; + static bool isBitVector() { return false; }; FFITerm() : solver(nullptr) diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp index c777fd6bbefa..2be142d49608 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp @@ -19,6 +19,7 @@ class FFTerm { static bool isFiniteField() { return true; }; static bool isInteger() { return false; }; + static bool isBitVector() { return false; }; FFTerm() : solver(nullptr) diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp new file mode 100644 index 000000000000..d67fb098d858 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp @@ -0,0 +1,47 @@ +#include "smt_util.hpp" + +/** + * @brief Get the solver result and amount of time + * that it took to solve. + * + * @param s + * @return bool is system satisfiable? + */ +bool smt_timer(smt_solver::Solver* s) +{ + auto start = std::chrono::high_resolution_clock::now(); + bool res = s->check(); + auto stop = std::chrono::high_resolution_clock::now(); + double duration = static_cast(duration_cast(stop - start).count()); + info("Time passed: ", duration); + + info(s->cvc_result); + return res; +} + +/** + * @brief base4 decomposition with accumulators + * + * @param el + * @return base decomposition, accumulators + */ +std::pair, std::vector> base4(uint32_t el) +{ + std::vector limbs; + limbs.reserve(16); + for (size_t i = 0; i < 16; i++) { + limbs.emplace_back(el % 4); + el /= 4; + } + std::reverse(limbs.begin(), limbs.end()); + std::vector accumulators; + accumulators.reserve(16); + bb::fr accumulator = 0; + for (size_t i = 0; i < 16; i++) { + accumulator = accumulator * 4 + limbs[i]; + accumulators.emplace_back(accumulator); + } + std::reverse(limbs.begin(), limbs.end()); + std::reverse(accumulators.begin(), accumulators.end()); + return { limbs, accumulators }; +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp new file mode 100644 index 000000000000..a0aceeeadf61 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp @@ -0,0 +1,123 @@ +#pragma once +#include + +#include "barretenberg/smt_verification/circuit/circuit.hpp" + +/** + * @brief Get pretty formatted result of the solver work + * + * @details Having two circuits and defined constraint system + * inside the solver get the pretty formatted output. + * The whole witness will be saved in c-like array format. + * Special variables will be printed to stdout. e.g. `a_1, a_2 = val_a_1, val_a_2;` + * + * @param special The list of variables that you need to see in stdout + * @param c1 first circuit + * @param c2 the copy of the first circuit with changed tag + * @param s solver + * @param fname file to store the resulting witness if succeded + */ +template +void default_model(std::vector special, + smt_circuit::Circuit& c1, + smt_circuit::Circuit& c2, + smt_solver::Solver* s, + const std::string& fname = "witness.out") +{ + std::vector vterms1; + std::vector vterms2; + vterms1.reserve(c1.get_num_real_vars()); + vterms2.reserve(c1.get_num_real_vars()); + + for (uint32_t i = 0; i < c1.get_num_vars(); i++) { + vterms1.push_back(c1.symbolic_vars[c1.real_variable_index[i]]); + vterms2.push_back(c2.symbolic_vars[c2.real_variable_index[i]]); + } + + std::unordered_map mmap1 = s->model(vterms1); + std::unordered_map mmap2 = s->model(vterms2); + + std::fstream myfile; + myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); + myfile << "w12 = {" << std::endl; + for (uint32_t i = 0; i < c1.get_num_vars(); i++) { + std::string vname1 = vterms1[i].toString(); + std::string vname2 = vterms2[i].toString(); + if (c1.real_variable_index[i] == i) { + myfile << "{" << mmap1[vname1] << ", " << mmap2[vname2] << "}"; + myfile << ", // " << vname1 << ", " << vname2 << std::endl; + } else { + myfile << "{" << mmap1[vname1] << ", " + mmap2[vname2] << "}"; + myfile << ", // " << vname1 << " ," << vname2 << " -> " << c1.real_variable_index[i] << std::endl; + } + } + myfile << "};"; + myfile.close(); + + std::unordered_map vterms; + for (auto& vname : special) { + vterms.insert({ vname + "_1", c1[vname] }); + vterms.insert({ vname + "_2", c2[vname] }); + } + + std::unordered_map mmap = s->model(vterms); + for (auto& vname : special) { + info(vname, "_1, ", vname, "_2 = ", mmap[vname + "_1"], ", ", mmap[vname + "_2"]); + } +} + +/** + * @brief Get pretty formatted result of the solver work + * + * @details Having a circuit and defined constraint system + * inside the solver get the pretty formatted output. + * The whole witness will be saved in c-like array format. + * Special variables will be printed to stdout. e.g. `a = val_a;` + * + * @param special The list of variables that you need to see in stdout + * @param c first circuit + * @param s solver + * @param fname file to store the resulting witness if succeded + */ +template +void default_model_single(std::vector special, + smt_circuit::Circuit& c, + smt_solver::Solver* s, + const std::string& fname = "witness.out") +{ + std::vector vterms; + vterms.reserve(c.get_num_real_vars()); + + for (uint32_t i = 0; i < c.get_num_vars(); i++) { + vterms.push_back(c.symbolic_vars[i]); + } + + std::unordered_map mmap = s->model(vterms); + + std::fstream myfile; + myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); + myfile << "w = {" << std::endl; + for (size_t i = 0; i < c.get_num_vars(); i++) { + std::string vname = vterms[i].toString(); + if (c.real_variable_index[i] == i) { + myfile << mmap[vname] << ", // " << vname << std::endl; + } else { + myfile << mmap[vname] << ", // " << vname << " -> " << c.real_variable_index[i] << std::endl; + } + } + myfile << "};"; + myfile.close(); + + std::unordered_map vterms1; + for (auto& vname : special) { + vterms1.insert({ vname, c[vname] }); + } + + std::unordered_map mmap1 = s->model(vterms1); + for (auto& vname : special) { + info(vname, " = ", mmap1[vname]); + } +} + +bool smt_timer(smt_solver::Solver* s); +std::pair, std::vector> base4(uint32_t el); \ No newline at end of file