diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp index e2448004deb9..38fc7faf0f99 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.cpp @@ -2,6 +2,490 @@ namespace smt_circuit { +/** + * @brief Construct a new Circuit::Circuit object + * + * @param circuit_info CircuitShema object + * @param solver pointer to the global solver + * @param tag tag of the circuit. Empty by default. + */ +Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, TermType type, const std::string& tag, bool optimizations) + : variables(circuit_info.variables) + , public_inps(circuit_info.public_inps) + , variable_names(circuit_info.vars_of_interest) + , selectors(circuit_info.selectors) + , wires_idxs(circuit_info.wires) + , real_variable_index(circuit_info.real_variable_index) + , optimizations(optimizations) + , solver(solver) + , type(type) + , tag(tag) +{ + if (!this->tag.empty()) { + if (this->tag[0] != '_') { + this->tag = "_" + this->tag; + } + } + + for (auto& x : variable_names) { + variable_names_inverse.insert({ x.second, x.first }); + } + + variable_names.insert({ 0, "zero" }); + variable_names.insert({ 1, "one" }); + variable_names_inverse.insert({ "zero", 0 }); + variable_names_inverse.insert({ "one", 1 }); + optimized.insert({ 0, false }); + optimized.insert({ 1, false }); + + this->init(); + + // Perform all relaxation for gates or + // add gate in its normal state to solver + size_t i = 0; + while (i < this->get_num_gates()) { + i = this->prepare_gates(i); + } + + for (const auto& i : this->public_inps) { + this->symbolic_vars[this->real_variable_index[i]] == this->variables[i]; + } +} + +/** + * Creates all the needed symbolic variables and constants + * which are used in circuit. + * + */ +void Circuit::init() +{ + size_t num_vars = variables.size(); + symbolic_vars.insert({ 0, STerm::Var("zero" + this->tag, this->solver, this->type) }); + symbolic_vars.insert({ 1, STerm::Var("one" + this->tag, this->solver, this->type) }); + + for (uint32_t i = 2; i < num_vars; i++) { + uint32_t real_idx = this->real_variable_index[i]; + if (this->symbolic_vars.contains(real_idx)) { + continue; + } + + std::string name = variable_names.contains(real_idx) ? variable_names[real_idx] : "var_" + std::to_string(i); + name += this->tag; + symbolic_vars.insert({ real_idx, STerm::Var(name, this->solver, this->type) }); + + optimized.insert({ real_idx, true }); + } + + symbolic_vars[0] == bb::fr(0); + symbolic_vars[1] == bb::fr(1); +} + +/** + * @brief Relaxes univariate polynomial constraints. + * TODO(alex): probably won't be necessary in the nearest future + * because of new solver + * + * @param q_m multiplication selector + * @param q_1 l selector + * @param q_2 r selector + * @param q_3 o selector + * @param q_c constant + * @param w witness index + */ +void Circuit::handle_univariate_constraint(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w) +{ + bb::fr b = q_1 + q_2 + q_3; + + if (q_m == 0) { + symbolic_vars[w] == -q_c / b; + return; + } + + std::pair d = (b * b - bb::fr(4) * q_m * q_c).sqrt(); + if (!d.first) { + throw std::invalid_argument("There're no roots of quadratic polynomial"); + } + bb::fr x1 = (-b + d.second) / (bb::fr(2) * q_m); + bb::fr x2 = (-b - d.second) / (bb::fr(2) * q_m); + + if (d.second == 0) { + symbolic_vars[w] == STerm(x1, this->solver, type); + } else { + ((Bool(symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) | + (Bool(symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type)))) + .assert_term(); + } +} + +/** + * @brief Relaxes logic constraints(AND/XOR). + * @details This function is needed when we use bitwise compatible + * symbolic terms. + * It compares the chunk of selectors of the current circuit + * with pure create_logic_constraint from circuit_builder. + * It uses binary search to find a bit length of the constraint, + * since we don't know it in general. + * After a match is found, it updates the cursor to skip all the + * redundant constraints and adds a pure a ^ b = c or a & b = c + * constraint to solver. + * If there's no match, it will return -1 + * + * @param cursor current position + * @return next position or -1 + */ +size_t Circuit::handle_logic_constraint(size_t cursor) +{ + // Initialize binary search. Logic gate can only accept even bit lengths + // So we need to find a match among [1, 127] and then multiply the result by 2 + size_t beg = 1; + size_t end = 127; + size_t mid = 0; + auto res = static_cast(-1); + + // Indicates that current bit length is a match for XOR + bool xor_flag = true; + // Indicates that current bit length is a match for AND + bool and_flag = true; + // Indicates the logic operation(true - XOR, false - AND) if the match is found. + bool logic_flag = true; + CircuitProps xor_props; + CircuitProps and_props; + + bool stop_flag = false; + + while (beg <= end) { + mid = (end + beg) / 2; + + // Take a pure logic circuit for the current bit length(2 * mid) + // and compare it's selectors to selectors of the global circuit + // at current position(cursor). + // If they are equal, we can apply an optimization + // However, if we have a match at bit length 2, it is possible + // to have a match at higher bit lengths. That's why we store + // the current match as `res` and proceed with ordinary binary search. + // `stop_flag` simply indicates that the first selector doesn't match + // and we can skip this whole section. + + if (!this->cached_subcircuits[SubcircuitType::XOR].contains(mid * 2)) { + this->cached_subcircuits[SubcircuitType::XOR].insert( + { mid * 2, get_standard_logic_circuit(mid * 2, true) }); + } + xor_props = this->cached_subcircuits[SubcircuitType::XOR][mid * 2]; + + if (!this->cached_subcircuits[SubcircuitType::AND].contains(mid * 2)) { + this->cached_subcircuits[SubcircuitType::AND].insert( + { mid * 2, get_standard_logic_circuit(mid * 2, false) }); + } + and_props = this->cached_subcircuits[SubcircuitType::AND][mid * 2]; + + CircuitSchema xor_circuit = xor_props.circuit; + CircuitSchema and_circuit = and_props.circuit; + + xor_flag = cursor + xor_props.num_gates <= this->selectors.size(); + and_flag = cursor + xor_props.num_gates <= this->selectors.size(); + if (xor_flag || and_flag) { + for (size_t j = 0; j < xor_props.num_gates; j++) { + // It is possible for gates to be equal but wires to be not, but I think it's very + // unlikely to happen + xor_flag &= xor_circuit.selectors[j + xor_props.start_gate] == this->selectors[cursor + j]; + and_flag &= and_circuit.selectors[j + and_props.start_gate] == this->selectors[cursor + j]; + + if (!xor_flag && !and_flag) { + // Won't match at any bit length + if (j == 0) { + stop_flag = true; + } + break; + } + } + } + if (stop_flag) { + break; + } + + if (!xor_flag && !and_flag) { + end = mid - 1; + } else { + res = 2 * mid; + logic_flag = xor_flag; + + beg = mid + 1; + } + } + + // TODO(alex): Figure out if I need to create range constraint here too or it'll be + // created anyway in any circuit + if (res != static_cast(-1)) { + xor_props = get_standard_logic_circuit(res, true); + and_props = get_standard_logic_circuit(res, false); + + info("Logic constraint optimization: ", std::to_string(res), " bits. is_xor: ", logic_flag); + size_t left_gate = xor_props.gate_idxs[0]; + uint32_t left_gate_idx = xor_props.idxs[0]; + size_t right_gate = xor_props.gate_idxs[1]; + uint32_t right_gate_idx = xor_props.idxs[1]; + size_t out_gate = xor_props.gate_idxs[2]; + uint32_t out_gate_idx = xor_props.idxs[2]; + + uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; + uint32_t right_idx = this->real_variable_index[this->wires_idxs[cursor + right_gate][right_gate_idx]]; + uint32_t out_idx = this->real_variable_index[this->wires_idxs[cursor + out_gate][out_gate_idx]]; + + STerm left = this->symbolic_vars[left_idx]; + STerm right = this->symbolic_vars[right_idx]; + STerm out = this->symbolic_vars[out_idx]; + + if (logic_flag) { + (left ^ right) == out; + } else { + (left & right) == out; + } + + // You have to mark these arguments so they won't be optimized out + optimized[left_idx] = false; + optimized[right_idx] = false; + optimized[out_idx] = false; + return cursor + xor_props.num_gates; + } + return res; +} + +/** + * @brief Relaxes range constraints. + * @details This function is needed when we use range compatible + * symbolic terms. + * It compares the chunk of selectors of the current circuit + * with pure create_range_constraint from circuit_builder. + * It uses binary search to find a bit length of the constraint, + * since we don't know it in general. + * After match is found, it updates the cursor to skip all the + * redundant constraints and adds a pure a < 2^bit_length + * constraint to solver. + * If there's no match, it will return -1 + * + * @param cursor current position + * @return next position or -1 + */ +size_t Circuit::handle_range_constraint(size_t cursor) +{ + // Indicates that current bit length is a match + bool range_flag = true; + size_t mid = 0; + auto res = static_cast(-1); + + CircuitProps range_props; + // Range constraints differ depending on oddness of bit_length + // That's why we need to handle these cases separately + for (size_t odd = 0; odd < 2; odd++) { + // Initialize binary search. + // We need to find a match among [1, 127] and then set the result to 2 * mid, or 2 * mid + 1 + size_t beg = 1; + size_t end = 127; + + bool stop_flag = false; + while (beg <= end) { + mid = (end + beg) / 2; + + // Take a pure logic circuit for the current bit length(2 * mid + odd) + // and compare it's selectors to selectors of the global circuit + // at current positin(cursor). + // If they are equal, we can apply an optimization + // However, if we have a match at bit length 2, it is possible + // to have a match at higher bit lengths. That's why we store + // the current match as `res` and proceed with ordinary binary search. + // `stop_flag` simply indicates that the first selector doesn't match + // and we can skip this whole section. + + if (!this->cached_subcircuits[SubcircuitType::RANGE].contains(2 * mid + odd)) { + this->cached_subcircuits[SubcircuitType::RANGE].insert( + { 2 * mid + odd, get_standard_range_constraint_circuit(2 * mid + odd) }); + } + range_props = this->cached_subcircuits[SubcircuitType::RANGE][2 * mid + odd]; + CircuitSchema range_circuit = range_props.circuit; + + range_flag = cursor + range_props.num_gates <= this->get_num_gates(); + if (range_flag) { + for (size_t j = 0; j < range_props.num_gates; j++) { + // It is possible for gates to be equal but wires to be not, but I think it's very + // unlikely to happen + range_flag &= range_circuit.selectors[j + range_props.start_gate] == this->selectors[cursor + j]; + + if (!range_flag) { + // Won't match at any bit length + if (j <= 2) { + stop_flag = true; + } + break; + } + } + } + if (stop_flag) { + break; + } + + if (!range_flag) { + end = mid - 1; + } else { + res = 2 * mid + odd; + beg = mid + 1; + } + } + + if (res != static_cast(-1)) { + range_flag = true; + break; + } + } + + if (range_flag) { + info("Range constraint optimization: ", std::to_string(res), " bits"); + range_props = get_standard_range_constraint_circuit(res); + + size_t left_gate = range_props.gate_idxs[0]; + uint32_t left_gate_idx = range_props.idxs[0]; + uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; + + STerm left = this->symbolic_vars[left_idx]; + left <= (bb::fr(2).pow(res) - 1); + + // You have to mark these arguments so they won't be optimized out + optimized[left_idx] = false; + return cursor + range_props.num_gates; + } + return res; +} + +/** + * @brief Adds all the gate constraints to the solver. + * Relaxes constraint system for non-ff solver engines + * via removing subcircuits that were already proved being correct. + * + */ +size_t Circuit::prepare_gates(size_t cursor) +{ + // TODO(alex): implement bitvector class and compute offsets + if (this->type == TermType::BVTerm && this->optimizations) { + size_t res = handle_logic_constraint(cursor); + if (res != static_cast(-1)) { + return res; + } + } + + if ((this->type == TermType::BVTerm || this->type == TermType::FFITerm) && this->optimizations) { + size_t res = handle_range_constraint(cursor); + if (res != static_cast(-1)) { + return res; + } + } + + bb::fr q_m = this->selectors[cursor][0]; + bb::fr q_1 = this->selectors[cursor][1]; + bb::fr q_2 = this->selectors[cursor][2]; + bb::fr q_3 = this->selectors[cursor][3]; + bb::fr q_c = this->selectors[cursor][4]; + + uint32_t w_l = this->wires_idxs[cursor][0]; + uint32_t w_r = this->wires_idxs[cursor][1]; + uint32_t w_o = this->wires_idxs[cursor][2]; + optimized[w_l] = false; + optimized[w_r] = false; + optimized[w_o] = false; + + // Handles the case when we have univariate polynomial as constraint + // by simply finding the roots via quadratic formula(or linear) + // There're 7 possibilities of that, which are present below + bool univariate_flag = false; + univariate_flag |= (w_l == w_r) && (w_r == w_o); + univariate_flag |= (w_l == w_r) && (q_3 == 0); + univariate_flag |= (w_l == w_o) && (q_2 == 0) && (q_m == 0); + univariate_flag |= (w_r == w_o) && (q_1 == 0) && (q_m == 0); + univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_3 == 0); + univariate_flag |= (q_m == 0) && (q_2 == 0) && (q_3 == 0); + univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_2 == 0); + + // Univariate gate. Relaxes the solver. Or is it? + // TODO(alex): Test the effect of this relaxation after the tests are merged. + if (univariate_flag) { + if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) { + (Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[0]) | Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[1])) + .assert_term(); + } else { + this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l); + } + } else { + STerm eq = symbolic_vars[0]; + + // mul selector + if (q_m != 0) { + eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; + } + // left selector + if (q_1 != 0) { + eq += symbolic_vars[w_l] * q_1; + } + // right selector + if (q_2 != 0) { + eq += symbolic_vars[w_r] * q_2; + } + // out selector + if (q_3 != 0) { + eq += symbolic_vars[w_o] * q_3; + } + // constant selector + if (q_c != 0) { + eq += q_c; + } + eq == symbolic_vars[0]; + } + return cursor + 1; +} + +/** + * @brief Returns a previously named symbolic variable. + * + * @param name + * @return STerm + */ +STerm Circuit::operator[](const std::string& name) +{ + if (!this->variable_names_inverse.contains(name)) { + throw std::invalid_argument("No such an item `" + name + "` in vars or it vas not declared as interesting"); + } + uint32_t idx = this->variable_names_inverse[name]; + return this->symbolic_vars[idx]; +} + +/** + * @brief Similar functionality to old .check_circuit() method + * in standard circuit builder. + * + * @param witness + * @return true + * @return false + */ +bool Circuit::simulate_circuit_eval(std::vector& witness) const +{ + if (witness.size() != this->get_num_vars()) { + throw std::invalid_argument("Witness size should be " + std::to_string(this->get_num_vars()) + ", not " + + std::to_string(witness.size())); + } + for (size_t i = 0; i < this->selectors.size(); i++) { + bb::fr res = 0; + bb::fr x = witness[this->wires_idxs[i][0]]; + bb::fr y = witness[this->wires_idxs[i][1]]; + bb::fr o = witness[this->wires_idxs[i][2]]; + res += this->selectors[i][0] * x * y; + res += this->selectors[i][1] * x; + res += this->selectors[i][2] * y; + res += this->selectors[i][3] * o; + res += this->selectors[i][4]; + if (res != 0) { + return false; + } + } + return true; +} + /** * @brief Check your circuit for witness uniqueness * @@ -17,17 +501,17 @@ namespace smt_circuit { * @param not_equal_at_the_same_time The list of variables, where at least one pair has to be distinct * @return std::pair */ -template -std::pair, Circuit> unique_witness_ext(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal, - const std::vector& not_equal, - const std::vector& equal_at_the_same_time, - const std::vector& not_equal_at_the_same_time) +std::pair unique_witness_ext(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal, + const std::vector& not_equal, + const std::vector& equal_at_the_same_time, + const std::vector& not_equal_at_the_same_time) { // TODO(alex): set optimizations to be true once they are confirmed - Circuit c1(circuit_info, s, "circuit1", false); - Circuit c2(circuit_info, s, "circuit2", false); + Circuit c1(circuit_info, s, type, "circuit1", false); + Circuit c2(circuit_info, s, type, "circuit2", false); for (const auto& term : equal) { c1[term] == c2[term]; @@ -62,22 +546,6 @@ std::pair, Circuit> unique_witness_ext(CircuitSchema& circuit_in return { c1, c2 }; } -template std::pair, Circuit> unique_witness_ext( - CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}, - const std::vector& not_equal = {}, - const std::vector& equal_at_the_same_time = {}, - const std::vector& not_eqaul_at_the_same_time = {}); - -template std::pair, Circuit> unique_witness_ext( - CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}, - const std::vector& not_equal = {}, - const std::vector& equal_at_the_same_time = {}, - const std::vector& not_eqaul_at_the_same_time = {}); - /** * @brief Check your circuit for witness uniqueness * @@ -91,14 +559,14 @@ template std::pair, Circuit> unique_witness_ext( * @param equal The list of names of variables which should be equal in both circuits(each is equal) * @return std::pair */ -template -std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal) +std::pair unique_witness(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal) { // TODO(alex): set optimizations to be true once they are confirmed - Circuit c1(circuit_info, s, "circuit1", false); - Circuit c2(circuit_info, s, "circuit2", false); + Circuit c1(circuit_info, s, type, "circuit1", false); + Circuit c2(circuit_info, s, type, "circuit2", false); for (const auto& term : equal) { c1[term] == c2[term]; @@ -124,12 +592,4 @@ std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, } return { c1, c2 }; } - -template std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}); - -template std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}); }; // namespace smt_circuit \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp index 910310a6684f..0c956acd83dc 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.hpp @@ -5,8 +5,7 @@ #include #include "barretenberg/smt_verification/terms/bool.hpp" -#include "barretenberg/smt_verification/terms/ffiterm.hpp" -#include "barretenberg/smt_verification/terms/ffterm.hpp" +#include "barretenberg/smt_verification/terms/term.hpp" #include "subcircuits.hpp" @@ -23,14 +22,12 @@ enum class SubcircuitType { XOR, AND, RANGE }; * * @details Contains all the information about the circuit: gates, variables, * symbolic variables, specified names and global solver. - * - * @tparam FF FFTerm or FFITerm */ -template class Circuit { +class Circuit { private: void init(); - size_t prepare_gates(size_t cursor); + void handle_univariate_constraint(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w); size_t handle_logic_constraint(size_t cursor); size_t handle_range_constraint(size_t cursor); @@ -42,7 +39,7 @@ template class Circuit { std::unordered_map variable_names_inverse; // inverse map of the previous memeber std::vector> selectors; // selectors from the circuit std::vector> wires_idxs; // values of the gates' wires - std::unordered_map symbolic_vars; // all the symbolic variables from the circuit + std::unordered_map symbolic_vars; // all the symbolic variables from the circuit std::vector real_variable_index; // indexes for assert_equal'd wires std::unordered_map optimized; // keeps track of the variables that were excluded from symbolic // circuit during optimizations @@ -51,18 +48,21 @@ template class Circuit { cached_subcircuits; // caches subcircuits during optimization // No need to recompute them each time - Solver* solver; // pointer to the solver + Solver* solver; // pointer to the solver + TermType type; // Type of the underlying Symbolic Terms + std::string tag; // tag of the symbolic circuit. // If not empty, will be added to the names // of symbolic variables to prevent collisions. explicit Circuit(CircuitSchema& circuit_info, Solver* solver, + TermType type = TermType::FFTerm, const std::string& tag = "", bool optimizations = true); - FF operator[](const std::string& name); - FF operator[](const uint32_t& idx) { return symbolic_vars[this->real_variable_index[idx]]; }; + STerm operator[](const std::string& name); + STerm operator[](const uint32_t& idx) { return this->symbolic_vars[this->real_variable_index[idx]]; }; inline size_t get_num_gates() const { return selectors.size(); }; inline size_t get_num_real_vars() const { return symbolic_vars.size(); }; inline size_t get_num_vars() const { return variables.size(); }; @@ -70,530 +70,17 @@ template class Circuit { bool simulate_circuit_eval(std::vector& witness) const; }; -/** - * @brief Construct a new Circuit::Circuit object - * - * @param circuit_info CircuitShema object - * @param solver pointer to the global solver - * @param tag tag of the circuit. Empty by default. - */ -template -Circuit::Circuit(CircuitSchema& circuit_info, Solver* solver, const std::string& tag, bool optimizations) - : variables(circuit_info.variables) - , public_inps(circuit_info.public_inps) - , variable_names(circuit_info.vars_of_interest) - , selectors(circuit_info.selectors) - , wires_idxs(circuit_info.wires) - , real_variable_index(circuit_info.real_variable_index) - , optimizations(optimizations) - , solver(solver) - , tag(tag) -{ - if (!this->tag.empty()) { - if (this->tag[0] != '_') { - this->tag = "_" + this->tag; - } - } - - for (auto& x : variable_names) { - variable_names_inverse.insert({ x.second, x.first }); - } - - variable_names.insert({ 0, "zero" }); - variable_names.insert({ 1, "one" }); - variable_names_inverse.insert({ "zero", 0 }); - variable_names_inverse.insert({ "one", 1 }); - optimized.insert({ 0, false }); - optimized.insert({ 1, false }); - - this->init(); - - // Perform all relaxation for gates or - // add gate in its normal state to solver - size_t i = 0; - while (i < this->get_num_gates()) { - i = this->prepare_gates(i); - } - - for (auto& opt : optimized) { - if (opt.second) { - this->symbolic_vars[opt.first] == 0; - } - } -} - -/** - * Creates all the needed symbolic variables and constants - * which are used in circuit. - * - */ -template void Circuit::init() -{ - size_t num_vars = variables.size(); - - symbolic_vars.insert({ 0, FF::Var("zero" + this->tag, this->solver) }); - symbolic_vars.insert({ 1, FF::Var("one" + this->tag, this->solver) }); - - for (uint32_t i = 2; i < num_vars; i++) { - uint32_t real_idx = this->real_variable_index[i]; - if (this->symbolic_vars.contains(real_idx)) { - continue; - } - - if (variable_names.contains(real_idx)) { - std::string name = variable_names[real_idx]; - symbolic_vars.insert({ real_idx, FF::Var(name + this->tag, this->solver) }); - } else { - symbolic_vars.insert({ real_idx, FF::Var("var_" + std::to_string(i) + this->tag, this->solver) }); - } - optimized.insert({ real_idx, true }); - } - - symbolic_vars[0] == bb::fr(0); - symbolic_vars[1] == bb::fr(1); -} - -/** - * @brief Relaxes univariate polynomial constraints. - * TODO(alex): probably won't be necessary in the nearest future - * because of new solver - * - * @param q_m multiplication selector - * @param q_1 l selector - * @param q_2 r selector - * @param q_3 o selector - * @param q_c constant - * @param w witness index - */ -template -void Circuit::handle_univariate_constraint(bb::fr q_m, bb::fr q_1, bb::fr q_2, bb::fr q_3, bb::fr q_c, uint32_t w) -{ - bb::fr b = q_1 + q_2 + q_3; - - if (q_m == 0) { - symbolic_vars[w] == -q_c / b; - return; - } - - std::pair d = (b * b - bb::fr(4) * q_m * q_c).sqrt(); - if (!d.first) { - throw std::invalid_argument("There're no roots of quadratic polynomial"); - } - bb::fr x1 = (-b + d.second) / (bb::fr(2) * q_m); - bb::fr x2 = (-b - d.second) / (bb::fr(2) * q_m); - - if (d.second == 0) { - symbolic_vars[w] == FF(x1, this->solver); - } else { - ((Bool(symbolic_vars[w]) == Bool(FF(x1, this->solver))) | - (Bool(symbolic_vars[w]) == Bool(FF(x2, this->solver)))) - .assert_term(); - } -} - -/** - * @brief Relaxes logic constraints(AND/XOR). - * @details This function is needed when we use bitwise compatible - * symbolic terms. - * It compares the chunk of selectors of the current circuit - * with pure create_logic_constraint from circuit_builder. - * It uses binary search to find a bit length of the constraint, - * since we don't know it in general. - * After a match is found, it updates the cursor to skip all the - * redundant constraints and adds a pure a ^ b = c or a & b = c - * constraint to solver. - * If there's no match, it will return -1 - * - * @param cursor current position - * @return next position or -1 - */ -template size_t Circuit::handle_logic_constraint(size_t cursor) -{ - // Initialize binary search. Logic gate can only accept even bit lengths - // So we need to find a match among [1, 127] and then multiply the result by 2 - size_t beg = 1; - size_t end = 127; - size_t mid = 0; - auto res = static_cast(-1); - - // Indicates that current bit length is a match for XOR - bool xor_flag = true; - // Indicates that current bit length is a match for AND - bool and_flag = true; - // Indicates the logic operation(true - XOR, false - AND) if the match is found. - bool logic_flag = true; - CircuitProps xor_props; - CircuitProps and_props; - - bool stop_flag = false; - - while (beg <= end) { - mid = (end + beg) / 2; - - // Take a pure logic circuit for the current bit length(2 * mid) - // and compare it's selectors to selectors of the global circuit - // at current position(cursor). - // If they are equal, we can apply an optimization - // However, if we have a match at bit length 2, it is possible - // to have a match at higher bit lengths. That's why we store - // the current match as `res` and proceed with ordinary binary search. - // `stop_flag` simply indicates that the first selector doesn't match - // and we can skip this whole section. - - if (!this->cached_subcircuits[SubcircuitType::XOR].contains(mid * 2)) { - this->cached_subcircuits[SubcircuitType::XOR].insert( - { mid * 2, get_standard_logic_circuit(mid * 2, true) }); - } - xor_props = this->cached_subcircuits[SubcircuitType::XOR][mid * 2]; - - if (!this->cached_subcircuits[SubcircuitType::AND].contains(mid * 2)) { - this->cached_subcircuits[SubcircuitType::AND].insert( - { mid * 2, get_standard_logic_circuit(mid * 2, false) }); - } - and_props = this->cached_subcircuits[SubcircuitType::AND][mid * 2]; - - CircuitSchema xor_circuit = xor_props.circuit; - CircuitSchema and_circuit = and_props.circuit; - - xor_flag = cursor + xor_props.num_gates <= this->selectors.size(); - and_flag = cursor + xor_props.num_gates <= this->selectors.size(); - if (xor_flag || and_flag) { - for (size_t j = 0; j < xor_props.num_gates; j++) { - // It is possible for gates to be equal but wires to be not, but I think it's very - // unlikely to happen - xor_flag &= xor_circuit.selectors[j + xor_props.start_gate] == this->selectors[cursor + j]; - and_flag &= and_circuit.selectors[j + and_props.start_gate] == this->selectors[cursor + j]; - - if (!xor_flag && !and_flag) { - // Won't match at any bit length - if (j == 0) { - stop_flag = true; - } - break; - } - } - } - if (stop_flag) { - break; - } - - if (!xor_flag && !and_flag) { - end = mid - 1; - } else { - res = 2 * mid; - logic_flag = xor_flag; - - beg = mid + 1; - } - } - - // TODO(alex): Figure out if I need to create range constraint here too or it'll be - // created anyway in any circuit - if (res != static_cast(-1)) { - xor_props = get_standard_logic_circuit(res, true); - and_props = get_standard_logic_circuit(res, false); - - info("Logic constraint optimization: ", std::to_string(res), " bits. is_xor: ", xor_flag); - size_t left_gate = xor_props.gate_idxs[0]; - uint32_t left_gate_idx = xor_props.idxs[0]; - size_t right_gate = xor_props.gate_idxs[1]; - uint32_t right_gate_idx = xor_props.idxs[1]; - size_t out_gate = xor_props.gate_idxs[2]; - uint32_t out_gate_idx = xor_props.idxs[2]; - - uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; - uint32_t right_idx = this->real_variable_index[this->wires_idxs[cursor + right_gate][right_gate_idx]]; - uint32_t out_idx = this->real_variable_index[this->wires_idxs[cursor + out_gate][out_gate_idx]]; - - FF left = this->symbolic_vars[left_idx]; - FF right = this->symbolic_vars[right_idx]; - FF out = this->symbolic_vars[out_idx]; - - if (logic_flag) { - (left ^ right) == out; - } else { - (left ^ right) == out; // TODO(alex): implement & method - } - - // You have to mark these arguments so they won't be optimized out - optimized[left_idx] = false; - optimized[right_idx] = false; - optimized[out_idx] = false; - return cursor + xor_props.num_gates; - } - return res; -} - -/** - * @brief Relaxes range constraints. - * @details This function is needed when we use range compatible - * symbolic terms. - * It compares the chunk of selectors of the current circuit - * with pure create_range_constraint from circuit_builder. - * It uses binary search to find a bit length of the constraint, - * since we don't know it in general. - * After match is found, it updates the cursor to skip all the - * redundant constraints and adds a pure a < 2^bit_length - * constraint to solver. - * If there's no match, it will return -1 - * - * @param cursor current position - * @return next position or -1 - */ -template size_t Circuit::handle_range_constraint(size_t cursor) -{ - // Indicates that current bit length is a match - bool range_flag = true; - size_t mid = 0; - auto res = static_cast(-1); - - CircuitProps range_props; - // Range constraints differ depending on oddness of bit_length - // That's why we need to handle these cases separately - for (size_t odd = 0; odd < 2; odd++) { - // Initialize binary search. - // We need to find a match among [1, 127] and then set the result to 2 * mid, or 2 * mid + 1 - size_t beg = 1; - size_t end = 127; - - bool stop_flag = false; - while (beg <= end) { - mid = (end + beg) / 2; - - // Take a pure logic circuit for the current bit length(2 * mid + odd) - // and compare it's selectors to selectors of the global circuit - // at current positin(cursor). - // If they are equal, we can apply an optimization - // However, if we have a match at bit length 2, it is possible - // to have a match at higher bit lengths. That's why we store - // the current match as `res` and proceed with ordinary binary search. - // `stop_flag` simply indicates that the first selector doesn't match - // and we can skip this whole section. - - if (!this->cached_subcircuits[SubcircuitType::RANGE].contains(2 * mid + odd)) { - this->cached_subcircuits[SubcircuitType::RANGE].insert( - { 2 * mid + odd, get_standard_range_constraint_circuit(2 * mid + odd) }); - } - range_props = this->cached_subcircuits[SubcircuitType::RANGE][2 * mid + odd]; - CircuitSchema range_circuit = range_props.circuit; - - range_flag = cursor + range_props.num_gates <= this->get_num_gates(); - if (range_flag) { - for (size_t j = 0; j < range_props.num_gates; j++) { - // It is possible for gates to be equal but wires to be not, but I think it's very - // unlikely to happen - range_flag &= range_circuit.selectors[j + range_props.start_gate] == this->selectors[cursor + j]; - - if (!range_flag) { - // Won't match at any bit length - if (j <= 2) { - stop_flag = true; - } - break; - } - } - } - if (stop_flag) { - break; - } - - if (!range_flag) { - end = mid - 1; - } else { - res = 2 * mid + odd; - beg = mid + 1; - } - } - - if (res != static_cast(-1)) { - range_flag = true; - break; - } - } - - if (range_flag) { - info("Range constraint optimization: ", std::to_string(res), " bits"); - range_props = get_standard_range_constraint_circuit(res); - - size_t left_gate = range_props.gate_idxs[0]; - uint32_t left_gate_idx = range_props.idxs[0]; - uint32_t left_idx = this->real_variable_index[this->wires_idxs[cursor + left_gate][left_gate_idx]]; - - FF left = this->symbolic_vars[left_idx]; - left < bb::fr(2).pow(res); - - // You have to mark these arguments so they won't be optimized out - optimized[left_idx] = false; - return cursor + range_props.num_gates; - } - return res; -} - -/** - * @brief Adds all the gate constraints to the solver. - * Relaxes constraint system for non-ff solver engines - * via removing subcircuits that were already proved being correct. - * - */ -template size_t Circuit::prepare_gates(size_t cursor) -{ - // TODO(alex): implement bitvector class and compute offsets - if (FF::isBitVector() && this->optimizations) { - size_t res = handle_logic_constraint(cursor); - if (res != static_cast(-1)) { - return res; - } - } - - if ((FF::isBitVector() || FF::isInteger()) && this->optimizations) { - size_t res = handle_range_constraint(cursor); - if (res != static_cast(-1)) { - return res; - } - } - - bb::fr q_m = this->selectors[cursor][0]; - bb::fr q_1 = this->selectors[cursor][1]; - bb::fr q_2 = this->selectors[cursor][2]; - bb::fr q_3 = this->selectors[cursor][3]; - bb::fr q_c = this->selectors[cursor][4]; - - uint32_t w_l = this->wires_idxs[cursor][0]; - uint32_t w_r = this->wires_idxs[cursor][1]; - uint32_t w_o = this->wires_idxs[cursor][2]; - optimized[w_l] = false; - optimized[w_r] = false; - optimized[w_o] = false; - - // Handles the case when we have univariate polynomial as constraint - // by simply finding the roots via quadratic formula(or linear) - // There're 7 possibilities of that, which are present below - bool univariate_flag = false; - univariate_flag |= (w_l == w_r) && (w_r == w_o); - univariate_flag |= (w_l == w_r) && (q_3 == 0); - univariate_flag |= (w_l == w_o) && (q_2 == 0) && (q_m == 0); - univariate_flag |= (w_r == w_o) && (q_1 == 0) && (q_m == 0); - univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_3 == 0); - univariate_flag |= (q_m == 0) && (q_2 == 0) && (q_3 == 0); - univariate_flag |= (q_m == 0) && (q_1 == 0) && (q_2 == 0); - - // Univariate gate. Relaxes the solver. Or is it? - // TODO(alex): Test the effect of this relaxation after the tests are merged. - if (univariate_flag) { - if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) { - (Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[0]) | Bool(symbolic_vars[w_l]) == Bool(symbolic_vars[1])) - .assert_term(); - } else { - this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l); - } - } else { - FF eq = symbolic_vars[0]; - - // mul selector - if (q_m != 0) { - eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m; - } - // left selector - if (q_1 != 0) { - eq += symbolic_vars[w_l] * q_1; - } - // right selector - if (q_2 != 0) { - eq += symbolic_vars[w_r] * q_2; - } - // out selector - if (q_3 != 0) { - eq += symbolic_vars[w_o] * q_3; - } - // constant selector - if (q_c != 0) { - eq += q_c; - } - eq == symbolic_vars[0]; - } - return cursor + 1; -} - -/** - * @brief Returns a previously named symbolic variable. - * - * @param name - * @return FF - */ -template FF Circuit::operator[](const std::string& name) -{ - if (!this->variable_names_inverse.contains(name)) { - throw std::invalid_argument("No such an item `" + name + "` in vars or it vas not declared as interesting"); - } - uint32_t idx = this->variable_names_inverse[name]; - return this->symbolic_vars[idx]; -} - -/** - * @brief Similar functionality to old .check_circuit() method - * in standard circuit builder. - * - * @param witness - * @return true - * @return false - */ -template bool Circuit::simulate_circuit_eval(std::vector& witness) const -{ - if (witness.size() != this->get_num_vars()) { - throw std::invalid_argument("Witness size should be " + std::to_string(this->get_num_vars()) + ", not " + - std::to_string(witness.size())); - } - for (size_t i = 0; i < this->selectors.size(); i++) { - bb::fr res = 0; - bb::fr x = witness[this->wires_idxs[i][0]]; - bb::fr y = witness[this->wires_idxs[i][1]]; - bb::fr o = witness[this->wires_idxs[i][2]]; - res += this->selectors[i][0] * x * y; - res += this->selectors[i][1] * x; - res += this->selectors[i][2] * y; - res += this->selectors[i][3] * o; - res += this->selectors[i][4]; - if (res != 0) { - return false; - } - } - return true; -} - -template -std::pair, Circuit> unique_witness_ext(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}, - const std::vector& not_equal = {}, - const std::vector& equal_at_the_same_time = {}, - const std::vector& not_equal_at_the_same_time = {}); - -extern template std::pair, Circuit> unique_witness_ext( - CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}, - const std::vector& not_equal = {}, - const std::vector& equal_at_the_same_time = {}, - const std::vector& not_equal_at_the_same_time = {}); - -extern template std::pair, Circuit> unique_witness_ext( - CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}, - const std::vector& not_equal = {}, - const std::vector& equal_at_the_same_time = {}, - const std::vector& not_equal_at_the_same_time = {}); - -template -std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}); - -extern template std::pair, Circuit> unique_witness(CircuitSchema& circuit_info, - Solver* s, - const std::vector& equal = {}); - -extern template std::pair, Circuit> unique_witness( - CircuitSchema& circuit_info, Solver* s, const std::vector& equal = {}); +std::pair unique_witness_ext(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal = {}, + const std::vector& not_equal = {}, + const std::vector& equal_at_the_same_time = {}, + const std::vector& not_equal_at_the_same_time = {}); + +std::pair unique_witness(CircuitSchema& circuit_info, + Solver* s, + TermType type, + const std::vector& equal = {}); }; // namespace smt_circuit diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp index 2660f8628b6b..9530188a5970 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/circuit/circuit.test.cpp @@ -3,9 +3,11 @@ #include #include "barretenberg/proof_system/circuit_builder/standard_circuit_builder.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" +#include "barretenberg/stdlib/primitives/uint/uint.hpp" + #include "barretenberg/smt_verification/circuit/circuit.hpp" #include "barretenberg/smt_verification/util/smt_util.hpp" -#include "barretenberg/stdlib/primitives/field/field.hpp" #include @@ -19,6 +21,7 @@ auto& engine = numeric::get_debug_randomness(); using field_t = stdlib::field_t; using witness_t = stdlib::witness_t; using pub_witness_t = stdlib::public_witness_t; +using uint_ct = stdlib::uint32; TEST(circuit, assert_equal) { @@ -49,7 +52,7 @@ TEST(circuit, assert_equal) auto buf = builder.export_circuit(); CircuitSchema circuit_info = unpack_from_buffer(buf); Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); + Circuit circuit(circuit_info, &s, TermType::FFTerm); ASSERT_EQ(circuit[k.get_witness_index()].term, circuit["c"].term); ASSERT_EQ(circuit[d.get_witness_index()].term, circuit["a"].term); @@ -59,6 +62,23 @@ TEST(circuit, assert_equal) ASSERT_EQ(circuit[i.get_witness_index()].term, circuit[j.get_witness_index()].term); } +TEST(circuit, cached_subcircuits) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + field_t a(witness_t(&builder, fr::zero())); + builder.set_variable_name(a.get_witness_index(), "a"); + a.create_range_constraint(5); + field_t b(witness_t(&builder, fr::zero())); + b.create_range_constraint(5); + builder.set_variable_name(b.get_witness_index(), "b"); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus); + Circuit circuit(circuit_info, &s, TermType::FFITerm); + s.print_assertions(); +} + TEST(circuit, range_relaxation_assertions) { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -75,7 +95,7 @@ TEST(circuit, range_relaxation_assertions) auto buf = builder.export_circuit(); CircuitSchema circuit_info = unpack_from_buffer(buf); Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); + Circuit circuit(circuit_info, &s, TermType::FFITerm); s.print_assertions(); } @@ -90,25 +110,42 @@ TEST(circuit, range_relaxation) auto buf = builder.export_circuit(); CircuitSchema circuit_info = unpack_from_buffer(buf); Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); + Circuit circuit(circuit_info, &s, TermType::FFITerm); } } -TEST(circuit, cached_subcircuits) +TEST(circuit, xor_relaxation_assertions) { StandardCircuitBuilder builder = StandardCircuitBuilder(); - field_t a(witness_t(&builder, fr::zero())); + uint_ct a(witness_t(&builder, static_cast(fr(120)))); + uint_ct b(witness_t(&builder, static_cast(fr(120)))); + uint_ct c = a ^ b; builder.set_variable_name(a.get_witness_index(), "a"); - a.create_range_constraint(5); - field_t b(witness_t(&builder, fr::zero())); - b.create_range_constraint(5); builder.set_variable_name(b.get_witness_index(), "b"); + builder.set_variable_name(c.get_witness_index(), "c"); auto buf = builder.export_circuit(); CircuitSchema circuit_info = unpack_from_buffer(buf); - Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); + Solver s(circuit_info.modulus, default_solver_config, 16, 32); + Circuit circuit(circuit_info, &s, TermType::BVTerm); + s.print_assertions(); } -// TODO(alex): check xor relaxations after bivector is here \ No newline at end of file +TEST(circuit, and_relaxation_assertions) +{ + StandardCircuitBuilder builder = StandardCircuitBuilder(); + uint_ct a(witness_t(&builder, static_cast(fr(120)))); + uint_ct b(witness_t(&builder, static_cast(fr(120)))); + uint_ct c = a & b; + builder.set_variable_name(a.get_witness_index(), "a"); + builder.set_variable_name(b.get_witness_index(), "b"); + builder.set_variable_name(c.get_witness_index(), "c"); + + auto buf = builder.export_circuit(); + CircuitSchema circuit_info = unpack_from_buffer(buf); + Solver s(circuit_info.modulus, default_solver_config, 16, 32); + Circuit circuit(circuit_info, &s, TermType::BVTerm); + + s.print_assertions(); +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/smt_bigfield.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/smt_bigfield.test.cpp deleted file mode 100644 index 1578c2474b11..000000000000 --- a/barretenberg/cpp/src/barretenberg/smt_verification/smt_bigfield.test.cpp +++ /dev/null @@ -1,231 +0,0 @@ -#include "barretenberg/numeric/random/engine.hpp" - -#include "barretenberg/ecc/curves/bn254/fq.hpp" -#include "barretenberg/ecc/curves/bn254/fr.hpp" - -#include "barretenberg/stdlib/primitives/bigfield/bigfield.hpp" -#include "barretenberg/stdlib/primitives/bool/bool.hpp" -#include "barretenberg/stdlib/primitives/byte_array/byte_array.hpp" -#include "barretenberg/stdlib/primitives/field/field.hpp" - -#include "barretenberg/plonk/proof_system/constants.hpp" -#include "barretenberg/plonk/proof_system/prover/prover.hpp" -#include "barretenberg/plonk/proof_system/verifier/verifier.hpp" - -#include "barretenberg/stdlib/primitives/circuit_builders/circuit_builders.hpp" -#include "barretenberg/stdlib/primitives/curves/bn254.hpp" - -#include "barretenberg/polynomials/polynomial_arithmetic.hpp" -#include -#include -#include - -#include -#include -#include - -#include "barretenberg/smt_verification/circuit/circuit.hpp" - -using namespace smt_circuit; -using namespace bb; -using namespace bb::plonk; - -using field_ct = stdlib::field_t; -using witness_t = stdlib::witness_t; -using pub_witness_t = stdlib::public_witness_t; - -using bn254 = stdlib::bn254; - -using fr_ct = bn254::ScalarField; -using fq_ct = bn254::BaseField; -using public_witness_ct = bn254::public_witness_ct; -using witness_ct = bn254::witness_ct; - -msgpack::sbuffer create_circuit(bool pub_ab, bool ab) -{ - StandardCircuitBuilder builder = StandardCircuitBuilder(); - fq inputs[2]{ fq::random_element(), fq::random_element() }; - fq_ct a, b; - if (pub_ab) { - a = fq_ct(public_witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - public_witness_ct( - &builder, fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - b = fq_ct(public_witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - public_witness_ct( - &builder, fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - } else { - a = fq_ct( - witness_ct(&builder, fr(uint256_t(inputs[0]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, fr(uint256_t(inputs[0]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - b = fq_ct( - witness_ct(&builder, fr(uint256_t(inputs[1]).slice(0, fq_ct::NUM_LIMB_BITS * 2))), - witness_ct(&builder, fr(uint256_t(inputs[1]).slice(fq_ct::NUM_LIMB_BITS * 2, fq_ct::NUM_LIMB_BITS * 4)))); - } - builder.set_variable_name(a.binary_basis_limbs[0].element.witness_index, "a_limb_0"); - builder.set_variable_name(a.binary_basis_limbs[1].element.witness_index, "a_limb_1"); - builder.set_variable_name(a.binary_basis_limbs[2].element.witness_index, "a_limb_2"); - builder.set_variable_name(a.binary_basis_limbs[3].element.witness_index, "a_limb_3"); - - if (ab) { - builder.set_variable_name(b.binary_basis_limbs[0].element.witness_index, "b_limb_0"); - builder.set_variable_name(b.binary_basis_limbs[1].element.witness_index, "b_limb_1"); - builder.set_variable_name(b.binary_basis_limbs[2].element.witness_index, "b_limb_2"); - builder.set_variable_name(b.binary_basis_limbs[3].element.witness_index, "b_limb_3"); - } - - fq_ct c; - if (ab) { - c = a * b; - } else { - c = a * a; - } - builder.set_variable_name(c.binary_basis_limbs[0].element.witness_index, "c_limb_0"); - builder.set_variable_name(c.binary_basis_limbs[1].element.witness_index, "c_limb_1"); - builder.set_variable_name(c.binary_basis_limbs[2].element.witness_index, "c_limb_2"); - builder.set_variable_name(c.binary_basis_limbs[3].element.witness_index, "c_limb_3"); - return builder.export_circuit(); -} - -const std::string q = "21888242871839275222246405745257275088696311157297823662689037894645226208583"; - -std::vector correct_result(Circuit& c, Solver* s) -{ - FFTerm a_limb0 = c["a_limb_0"]; - FFTerm a_limb1 = c["a_limb_1"]; - FFTerm a_limb2 = c["a_limb_2"]; - FFTerm a_limb3 = c["a_limb_3"]; - - FFTerm b_limb0 = c["b_limb_0"]; - FFTerm b_limb1 = c["b_limb_1"]; - FFTerm b_limb2 = c["b_limb_2"]; - FFTerm b_limb3 = c["b_limb_3"]; - - FFTerm c_limb0 = c["c_limb_0"]; - FFTerm c_limb1 = c["c_limb_1"]; - FFTerm c_limb2 = c["c_limb_2"]; - FFTerm c_limb3 = c["c_limb_3"]; - - FFTerm two68 = FFTerm::Const("100000000000000000", s); - FFTerm two136 = two68 * two68; - FFTerm two204 = two136 * two68; - - FFTerm a = a_limb0 + two68 * a_limb1 + two136 * a_limb2 + two204 * a_limb3; - FFTerm b = b_limb0 + two68 * b_limb1 + two136 * b_limb2 + two204 * b_limb3; - FFTerm cr = c_limb0 + two68 * c_limb1 + two136 * c_limb2 + two204 * c_limb3; - FFTerm n = FFTerm::Var("n", s); - FFTerm q_ = FFTerm::Const(q, s, 10); // Const(q_hex, s) - a* b != cr + n* q_; - return { cr, n }; -} - -void model_variables(Circuit& c, Solver* s, std::vector& evaluation) -{ - std::unordered_map terms; - for (size_t i = 0; i < 4; i++) { - terms.insert({ "a_limb_" + std::to_string(i), c["a_limb_" + std::to_string(i)] }); - terms.insert({ "b_limb_" + std::to_string(i), c["b_limb_" + std::to_string(i)] }); - terms.insert({ "c_limb_" + std::to_string(i), c["c_limb_" + std::to_string(i)] }); - } - terms.insert({ "cr", evaluation[0] }); - terms.insert({ "n", evaluation[1] }); - - auto values = s->model(terms); - - for (size_t i = 0; i < 4; i++) { - std::string tmp = "a_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } - for (size_t i = 0; i < 4; i++) { - std::string tmp = "b_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } - for (size_t i = 0; i < 4; i++) { - std::string tmp = "c_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } - info("cr = ", values["cr"]); - info("n = ", values["n"]); -} - -void model_variables1(Circuit& c1, Circuit& c2, Solver* s) -{ - std::unordered_map terms; - for (size_t i = 0; i < 4; i++) { - terms.insert({ "a_limb_" + std::to_string(i), c1["a_limb_" + std::to_string(i)] }); - terms.insert({ "c1_limb_" + std::to_string(i), c1["c_limb_" + std::to_string(i)] }); - terms.insert({ "c2_limb_" + std::to_string(i), c2["c_limb_" + std::to_string(i)] }); - } - auto values = s->model(terms); - - for (size_t i = 0; i < 4; i++) { - std::string tmp = "a_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } - - for (size_t i = 0; i < 4; i++) { - std::string tmp = "c1_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } - - for (size_t i = 0; i < 4; i++) { - std::string tmp = "c2_limb_" + std::to_string(i); - info(tmp, " = ", values[tmp]); - } -} - -TEST(bigfield, multiplication_equal) -{ - bool public_a_b = true; - bool a_neq_b = true; - auto buf = create_circuit(public_a_b, a_neq_b); - - CircuitSchema circuit_info = unpack_from_buffer(buf); - Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); - std::vector ev = correct_result(circuit, &s); - - auto start = std::chrono::high_resolution_clock::now(); - bool res = s.check(); - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(stop - start); - - info(); - info("Gates: ", circuit.get_num_gates()); - info("Result: ", s.getResult()); - info("Time elapsed: ", static_cast(duration.count()) / 1e6, " sec"); - - if (res) { - model_variables(circuit, &s, ev); - } -} - -TEST(bigfield, unique_square) -{ - auto buf = create_circuit(true, false); - - CircuitSchema circuit_info = unpack_from_buffer(buf); - - Solver s(circuit_info.modulus); - - std::pair, Circuit> cs = - unique_witness_ext(circuit_info, - &s, - { "a_limb_0", "a_limb_1", "a_limb_2", "a_limb_3" }, - { "c_limb_0", "c_limb_1", "c_limb_2", "c_limb_3" }); - - auto start = std::chrono::high_resolution_clock::now(); - bool res = s.check(); - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(stop - start); - - ASSERT_FALSE(res); - - info(); - info("Gates: ", cs.first.get_num_gates()); - info("Result: ", s.getResult()); - info("Time elapsed: ", static_cast(duration.count()) / 1e6, " sec"); - - if (res) { - model_variables1(cs.first, cs.second, &s); - } -} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp index 1ac73a89e90c..2a934d7fc8a3 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/smt_examples.test.cpp @@ -36,13 +36,13 @@ TEST(SMT_Example, multiplication_true) smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); smt_solver::Solver s(circuit_info.modulus); - smt_circuit::Circuit circuit(circuit_info, &s); - smt_terms::FFTerm a1 = circuit["a"]; - smt_terms::FFTerm b1 = circuit["b"]; - smt_terms::FFTerm c1 = circuit["c"]; - smt_terms::FFTerm two = smt_terms::FFTerm::Const("2", &s, 10); - smt_terms::FFTerm thr = smt_terms::FFTerm::Const("3", &s, 10); - smt_terms::FFTerm cr = smt_terms::FFTerm::Var("cr", &s); + smt_circuit::Circuit circuit(circuit_info, &s, smt_terms::TermType::FFTerm); + smt_terms::STerm a1 = circuit["a"]; + smt_terms::STerm b1 = circuit["b"]; + smt_terms::STerm c1 = circuit["c"]; + smt_terms::STerm two = smt_terms::FFConst("2", &s, 10); + smt_terms::STerm thr = smt_terms::FFConst("3", &s, 10); + smt_terms::STerm cr = smt_terms::FFVar("cr", &s); cr = (two * a1) / (thr * b1); c1 != cr; @@ -67,13 +67,13 @@ TEST(SMT_Example, multiplication_true_kind) smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); smt_solver::Solver s(circuit_info.modulus); - smt_circuit::Circuit circuit(circuit_info, &s); - smt_terms::FFTerm a1 = circuit["a"]; - smt_terms::FFTerm b1 = circuit["b"]; - smt_terms::FFTerm c1 = circuit["c"]; - smt_terms::FFTerm two = smt_terms::FFTerm::Const("2", &s, 10); - smt_terms::FFTerm thr = smt_terms::FFTerm::Const("3", &s, 10); - smt_terms::FFTerm cr = smt_terms::FFTerm::Var("cr", &s); + smt_circuit::Circuit circuit(circuit_info, &s, smt_terms::TermType::FFTerm); + smt_terms::STerm a1 = circuit["a"]; + smt_terms::STerm b1 = circuit["b"]; + smt_terms::STerm c1 = circuit["c"]; + smt_terms::STerm two = smt_terms::FFConst("2", &s, 10); + smt_terms::STerm thr = smt_terms::FFConst("3", &s, 10); + smt_terms::STerm cr = smt_terms::FFVar("cr", &s); cr* thr* b1 == two* a1; c1 != cr; @@ -98,15 +98,15 @@ TEST(SMT_Example, multiplication_false) smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); smt_solver::Solver s(circuit_info.modulus); - smt_circuit::Circuit circuit(circuit_info, &s); + smt_circuit::Circuit circuit(circuit_info, &s, smt_terms::TermType::FFTerm); - smt_terms::FFTerm a1 = circuit["a"]; - smt_terms::FFTerm b1 = circuit["b"]; - smt_terms::FFTerm c1 = circuit["c"]; + smt_terms::STerm a1 = circuit["a"]; + smt_terms::STerm b1 = circuit["b"]; + smt_terms::STerm c1 = circuit["c"]; - smt_terms::FFTerm two = smt_terms::FFTerm::Const("2", &s, 10); - smt_terms::FFTerm thr = smt_terms::FFTerm::Const("3", &s, 10); - smt_terms::FFTerm cr = smt_terms::FFTerm::Var("cr", &s); + smt_terms::STerm two = smt_terms::FFConst("2", &s, 10); + smt_terms::STerm thr = smt_terms::FFConst("3", &s, 10); + smt_terms::STerm cr = smt_terms::FFVar("cr", &s); cr = (two * a1) / (thr * b1); c1 != cr; @@ -123,8 +123,10 @@ TEST(SMT_Example, multiplication_false) info("c_res = ", vals["cr"]); } +// Make sure that quadratic polynomial evaluation doesn't have unique +// witness using unique_witness_ext function +// Find both roots of a quadratic equation x^2 + a * x + b = s TEST(SMT_Example, unique_witness_ext) -// two roots of a quadratic eq x^2 + a * x + b = s { StandardCircuitBuilder builder = StandardCircuitBuilder(); @@ -142,8 +144,8 @@ TEST(SMT_Example, unique_witness_ext) smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); smt_solver::Solver s(circuit_info.modulus); - std::pair, smt_circuit::Circuit> cirs = - smt_circuit::unique_witness_ext(circuit_info, &s, { "ev" }, { "z" }); + std::pair cirs = + smt_circuit::unique_witness_ext(circuit_info, &s, smt_terms::TermType::FFTerm, { "ev" }, { "z" }); bool res = s.check(); ASSERT_TRUE(res); @@ -154,7 +156,7 @@ TEST(SMT_Example, unique_witness_ext) } // Make sure that quadratic polynomial evaluation doesn't have unique -// witness. +// witness using unique_witness function // Finds both roots of a quadratic eq x^2 + a * x + b = s TEST(SMT_Example, unique_witness) { @@ -174,40 +176,8 @@ TEST(SMT_Example, unique_witness) smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); smt_solver::Solver s(circuit_info.modulus); - std::pair, smt_circuit::Circuit> cirs = - smt_circuit::unique_witness(circuit_info, &s, { "ev" }); - - bool res = s.check(); - ASSERT_TRUE(res); - - std::unordered_map terms = { { "z_c1", cirs.first["z"] }, { "z_c2", cirs.second["z"] } }; - std::unordered_map vals = s.model(terms); - ASSERT_NE(vals["z_c1"], vals["z_c2"]); -} - -// Make sure that quadratic polynomial evaluation doesn't have unique -// witness. Also coefficients are private. -// Finds both roots of a quadratic eq x^2 + a * x + b = s -TEST(SMT_Example, unique_witness_private_coefficients) -{ - StandardCircuitBuilder builder = StandardCircuitBuilder(); - - field_t a(witness_t(&builder, fr::random_element())); - field_t b(witness_t(&builder, fr::random_element())); - builder.set_variable_name(a.witness_index, "a"); - builder.set_variable_name(b.witness_index, "b"); - field_t z(witness_t(&builder, fr::random_element())); - field_t ev = z * z + a * z + b; - builder.set_variable_name(z.witness_index, "z"); - builder.set_variable_name(ev.witness_index, "ev"); - - auto buf = builder.export_circuit(); - - smt_circuit::CircuitSchema circuit_info = smt_circuit::unpack_from_buffer(buf); - smt_solver::Solver s(circuit_info.modulus); - - std::pair, smt_circuit::Circuit> cirs = - smt_circuit::unique_witness(circuit_info, &s, { "ev", "a", "b" }); + std::pair cirs = + smt_circuit::unique_witness(circuit_info, &s, smt_terms::TermType::FFTerm, { "ev" }); bool res = s.check(); ASSERT_TRUE(res); diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/smt_polynomials.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/smt_polynomials.test.cpp index b19e52089385..416c95e3f8e9 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/smt_polynomials.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/smt_polynomials.test.cpp @@ -1,76 +1,50 @@ -#include "barretenberg/crypto/generators/generator_data.hpp" -#include "barretenberg/crypto/pedersen_commitment/pedersen.hpp" -#include "barretenberg/proof_system/circuit_builder/standard_circuit_builder.hpp" - #include #include #include #include #include +#include "barretenberg/proof_system/circuit_builder/standard_circuit_builder.hpp" +#include "barretenberg/serialize/cbind.hpp" #include "barretenberg/stdlib/primitives/field/field.hpp" -#include "barretenberg/serialize/cbind.hpp" #include "barretenberg/smt_verification/circuit/circuit.hpp" +#include "barretenberg/smt_verification/util/smt_util.hpp" using namespace bb; using namespace smt_circuit; -using field_ct = stdlib::field_t; +using field_t = stdlib::field_t; using witness_t = stdlib::witness_t; using pub_witness_t = stdlib::public_witness_t; -// TODO(alex): z1 = z2, s1=s2, but coefficients are not public namespace { auto& engine = numeric::get_debug_randomness(); } -msgpack::sbuffer create_circuit(size_t n, bool pub_coeffs) +msgpack::sbuffer create_polynomial_evaluation_circuit(size_t n, bool pub_coeffs) { StandardCircuitBuilder builder = StandardCircuitBuilder(); - std::vector coeffs; - std::vector idxs; + + std::vector coeffs; for (size_t i = 0; i < n; i++) { - fr tmp_coeff = fr::random_element(); - uint32_t idx; if (pub_coeffs) { - idx = builder.add_public_variable(tmp_coeff); + coeffs.emplace_back(pub_witness_t(&builder, fr::random_element())); } else { - idx = builder.add_variable(tmp_coeff); + coeffs.emplace_back(witness_t(&builder, fr::random_element())); } - idxs.push_back(idx); - coeffs.push_back(tmp_coeff); - builder.set_variable_name(idx, "coeff_" + std::to_string(i)); + builder.set_variable_name(coeffs.back().get_witness_index(), "coeff_" + std::to_string(i)); } - fr z(10); - uint32_t z_idx = builder.add_variable(z); - builder.set_variable_name(z_idx, "point"); - fr res = fr::zero(); - uint32_t res_idx = builder.zero_idx; // i think assert_equal was needed for zero initialization - builder.assert_equal(res_idx, 0); + field_t z(witness_t(&builder, 10)); + builder.set_variable_name(z.get_witness_index(), "point"); + + field_t res = field_t::from_witness_index(&builder, 0); for (size_t i = 0; i < n; i++) { - res = res * z; - uint32_t mul_idx = builder.add_variable(res); - // builder.set_variable_name(mul_idx, "mul_" + std::to_string(i)); - builder.create_mul_gate({ res_idx, z_idx, mul_idx, fr::one(), fr::neg_one(), fr::zero() }); - - res = res + coeffs[i]; - uint32_t add_idx = builder.add_variable(res); - builder.create_add_gate({ - mul_idx, - idxs[i], - add_idx, - fr::one(), - fr::one(), - fr::neg_one(), - fr::zero(), - }); - - res_idx = add_idx; + res = res * z + coeffs[i]; } - builder.set_variable_name(res_idx, "result"); + builder.set_variable_name(res.get_witness_index(), "result"); info("evaluation at point ", z, ": ", res); info("gates: ", builder.num_gates); @@ -80,26 +54,18 @@ msgpack::sbuffer create_circuit(size_t n, bool pub_coeffs) return builder.export_circuit(); } -FFTerm polynomial_evaluation(Circuit& c, size_t n, bool is_correct = true) +STerm direct_polynomial_evaluation(Circuit& c, size_t n) { - std::vector coeffs(n); - for (size_t i = 0; i < n; i++) { - coeffs[i] = c["coeff_" + std::to_string(i)]; - } - - FFTerm point = c["point"]; - FFTerm result = c["result"]; - - FFTerm ev = is_correct ? c["zero"] : c["one"]; + STerm point = c["point"]; + STerm result = c["result"]; + STerm ev = c["zero"]; for (size_t i = 0; i < n; i++) { - ev = ev * point + coeffs[i]; + ev = ev * point + c["coeff_" + std::to_string(i)]; } - - result != ev; return ev; } -void model_variables(Circuit& c, Solver* s, FFTerm& evaluation) +void model_variables(Circuit& c, Solver* s, STerm& evaluation) { std::unordered_map terms; terms.insert({ "point", c["point"] }); @@ -113,54 +79,34 @@ void model_variables(Circuit& c, Solver* s, FFTerm& evaluatio info("function_evaluation = ", values["evaluation"]); } -TEST(polynomial_evaluation, correct) +TEST(polynomial_evaluation, public) { - size_t n = 30; - auto buf = create_circuit(n, true); + size_t n = 40; + auto buf = create_polynomial_evaluation_circuit(n, true); CircuitSchema circuit_info = unpack_from_buffer(buf); - Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); - FFTerm ev = polynomial_evaluation(circuit, n, true); - - auto start = std::chrono::high_resolution_clock::now(); - bool res = s.check(); - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(stop - start); + Circuit circuit(circuit_info, &s, TermType::FFTerm); + STerm ev = direct_polynomial_evaluation(circuit, n); + ev != circuit["result"]; + bool res = smt_timer(&s, false); ASSERT_FALSE(res); - info(); - info("Gates: ", circuit.get_num_gates()); - info("Result: ", s.getResult()); - info("Time elapsed: ", static_cast(duration.count()) / 1e6, " sec"); } -TEST(polynomial_evaluation, incorrect) +TEST(polynomial_evaluation, private) { - size_t n = 30; - auto buf = create_circuit(n, true); + size_t n = 40; + auto buf = create_polynomial_evaluation_circuit(n, false); CircuitSchema circuit_info = unpack_from_buffer(buf); - Solver s(circuit_info.modulus); - Circuit circuit(circuit_info, &s); - FFTerm ev = polynomial_evaluation(circuit, n, false); - - auto start = std::chrono::high_resolution_clock::now(); - bool res = s.check(); - auto stop = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(stop - start); + Circuit circuit(circuit_info, &s, TermType::FFTerm); + STerm ev = direct_polynomial_evaluation(circuit, n); + ev != circuit["result"]; - ASSERT_TRUE(res); - info(); + bool res = smt_timer(&s, false); + ASSERT_FALSE(res); info("Gates: ", circuit.get_num_gates()); info("Result: ", s.getResult()); - info("Time elapsed: ", static_cast(duration.count()) / 1e6, " sec"); - - if (res) { - model_variables(circuit, &s, ev); - } -} - -// TODO(alex) try with arbitrary coefficients \ No newline at end of file +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp index 4090cce93524..e3924f0774e9 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.cpp @@ -112,6 +112,9 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis) if (term.getKind() == cvc5::Kind::CONST_INTEGER) { return term.getIntegerValue(); } + if (term.getKind() == cvc5::Kind::CONST_BITVECTOR) { + return term.getBitVectorValue(); + } if (term.getKind() == cvc5::Kind::CONST_BOOLEAN) { std::vector bool_res = { "false", "true" }; return bool_res[static_cast(term.getBooleanValue())]; @@ -120,22 +123,27 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis) std::string res; std::string op; bool child_parenthesis = true; + bool back = false; switch (term.getKind()) { case cvc5::Kind::ADD: case cvc5::Kind::FINITE_FIELD_ADD: + case cvc5::Kind::BITVECTOR_ADD: op = " + "; child_parenthesis = false; break; case cvc5::Kind::SUB: + case cvc5::Kind::BITVECTOR_SUB: op = " - "; child_parenthesis = false; break; case cvc5::Kind::NEG: case cvc5::Kind::FINITE_FIELD_NEG: + case cvc5::Kind::BITVECTOR_NEG: res = "-"; break; case cvc5::Kind::MULT: case cvc5::Kind::FINITE_FIELD_MULT: + case cvc5::Kind::BITVECTOR_MULT: op = " * "; break; case cvc5::Kind::EQUAL: @@ -143,23 +151,50 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis) child_parenthesis = false; break; case cvc5::Kind::LT: + case cvc5::Kind::BITVECTOR_ULT: op = " < "; break; case cvc5::Kind::GT: + case cvc5::Kind::BITVECTOR_UGT: op = " > "; break; case cvc5::Kind::LEQ: + case cvc5::Kind::BITVECTOR_ULE: op = " <= "; break; case cvc5::Kind::GEQ: + case cvc5::Kind::BITVECTOR_UGE: op = " >= "; break; case cvc5::Kind::XOR: + case cvc5::Kind::BITVECTOR_XOR: op = " ^ "; break; + case cvc5::Kind::BITVECTOR_OR: + op = " | "; + break; case cvc5::Kind::OR: op = " || "; break; + case cvc5::Kind::BITVECTOR_AND: + op = " & "; + break; + case cvc5::Kind::BITVECTOR_SHL: + back = true; + op = " << " + term.getOp()[0].toString(); + break; + case cvc5::Kind::BITVECTOR_LSHR: + back = true; + op = " >> " + term.getOp()[0].toString(); + break; + case cvc5::Kind::BITVECTOR_ROTATE_LEFT: + back = true; + op = " ><< " + term.getOp()[0].toString(); + break; + case cvc5::Kind::BITVECTOR_ROTATE_RIGHT: + back = true; + op = " >>< " + term.getOp()[0].toString(); + break; case cvc5::Kind::AND: op = " && "; break; @@ -187,6 +222,9 @@ std::string stringify_term(const cvc5::Term& term, bool parenthesis) } res = res + stringify_term(child, child_parenthesis); + if (back) { + res += op; + } if (parenthesis) { return "(" + res + ")"; } diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.hpp index 21e668266097..d7593663d048 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.hpp @@ -40,6 +40,7 @@ class Solver { cvc5::TermManager term_manager; cvc5::Solver solver; cvc5::Sort ff_sort; + cvc5::Sort bv_sort; std::string modulus; // modulus in base 10 bool res = false; cvc5::Result cvc_result; @@ -47,11 +48,13 @@ class Solver { explicit Solver(const std::string& modulus, const SolverConfiguration& config = default_solver_config, - uint32_t base = 16) + uint32_t base = 16, + uint32_t bvsize = 254) : solver(term_manager) { this->ff_sort = term_manager.mkFiniteFieldSort(modulus, base); this->modulus = ff_sort.getFiniteFieldSize(); + this->bv_sort = term_manager.mkBitVectorSort(bvsize); if (config.produce_models) { solver.setOption("produce-models", "true"); } diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.test.cpp index 78e5a8833af8..415047f2b441 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/solver/solver.test.cpp @@ -1,6 +1,5 @@ #include "solver.hpp" -#include "barretenberg/smt_verification/terms/ffiterm.hpp" -#include "barretenberg/smt_verification/terms/ffterm.hpp" +#include "barretenberg/smt_verification/terms/term.hpp" #include @@ -11,13 +10,13 @@ using namespace smt_terms; TEST(Solver, FFTerm_use_case) { Solver s("101", default_solver_config, 10); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); y* y == x* x* x + bb::fr(2); - FFTerm l = (3 * x * x) / (bb::fr(2) * y); - FFTerm xr = l * l - x - x; - FFTerm yr = l * (x - xr) - y; + STerm l = (3 * x * x) / (bb::fr(2) * y); + STerm xr = l * l - x - x; + STerm yr = l * (x - xr) - y; x == xr; y == -yr; bool res = s.check(); @@ -35,8 +34,8 @@ TEST(Solver, FFTerm_use_case) TEST(Solver, FFITerm_use_case) { Solver s("bce4e33b636e0cf38d13a55c3"); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); bb::fr a = bb::fr::random_element(); x <= bb::fr(2).pow(32); @@ -51,18 +50,18 @@ TEST(Solver, FFITerm_use_case) info("+"); info(vvars["y"]); info("="); - info(s.getValue(FFITerm(a, &s).term)); + info(s.getValue(STerm(a, &s, TermType::FFITerm).term)); } TEST(Solver, human_readable_constraints_FFTerm) { Solver s("101", default_solver_config, 10); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); y* y == x* x* x + bb::fr(2); - FFTerm l = (3 * x * x) / (bb::fr(2) * y); - FFTerm xr = l * l - x - x; - FFTerm yr = l * (x - xr) - y; + STerm l = (3 * x * x) / (bb::fr(2) * y); + STerm xr = l * l - x - x; + STerm yr = l * (x - xr) - y; x == xr; y == -yr; s.print_assertions(); @@ -71,12 +70,12 @@ TEST(Solver, human_readable_constraints_FFTerm) TEST(Solver, human_readable_constraints_FFITerm) { Solver s("101", default_solver_config, 10); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); y* y == x* x* x + bb::fr(2); - FFITerm l = (3 * x * x) / (bb::fr(2) * y); - FFITerm xr = l * l - x - x; - FFITerm yr = l * (x - xr) - y; + STerm l = (3 * x * x) / (bb::fr(2) * y); + STerm xr = l * l - x - x; + STerm yr = l * (x - xr) - y; x == xr; y == -yr; s.print_assertions(); diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bool.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bool.hpp index 306b4359fad7..44014b15ade9 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bool.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bool.hpp @@ -1,6 +1,5 @@ #pragma once -#include "ffiterm.hpp" -#include "ffterm.hpp" +#include "term.hpp" namespace smt_terms { using namespace smt_solver; @@ -21,11 +20,8 @@ class Bool { Bool(const cvc5::Term& t, Solver* slv) : solver(slv) , term(t){}; - explicit Bool(const FFTerm& t) - : solver(t.solver) - , term(t.term){}; - explicit Bool(const FFITerm& t) + explicit Bool(const STerm& t) : solver(t.solver) , term(t.term){}; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp new file mode 100644 index 000000000000..c393fd5297f2 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/bvterm.test.cpp @@ -0,0 +1,177 @@ +#include + +#include "barretenberg/stdlib/primitives/uint/uint.hpp" +#include "term.hpp" + +#include + +namespace { +auto& engine = bb::numeric::get_debug_randomness(); +} + +using namespace bb; +using witness_ct = stdlib::witness_t; +using uint_ct = stdlib::uint32; + +using namespace smt_terms; + +TEST(BVTerm, addition) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct c = a + b; + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = x + y; + + z == c.get_value(); + x == a.get_value(); + ASSERT_TRUE(s.check()); + + std::string yvals = s.getValue(y.term).getBitVectorValue(); + + STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, yvals); +} + +TEST(BVTerm, subtraction) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct c = a - b; + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = x - y; + + z == c.get_value(); + x == a.get_value(); + ASSERT_TRUE(s.check()); + + std::string yvals = s.getValue(y.term).getBitVectorValue(); + + STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, yvals); +} + +TEST(BVTerm, xor) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct b = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct c = a ^ b; + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + STerm z = x ^ y; + + z == c.get_value(); + x == a.get_value(); + ASSERT_TRUE(s.check()); + + std::string yvals = s.getValue(y).getBitVectorValue(); + + STerm bval = STerm(b.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, yvals); +} + +TEST(BVTerm, rotr) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct b = a.ror(10); + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = x.rotr(10); + + y == b.get_value(); + ASSERT_TRUE(s.check()); + + std::string xvals = s.getValue(x.term).getBitVectorValue(); + + STerm bval = STerm(a.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, xvals); +} + +TEST(BVTerm, rotl) +{ + StandardCircuitBuilder builder; + uint_ct a = witness_ct(&builder, static_cast(fr::random_element())); + uint_ct b = a.rol(10); + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = x.rotl(10); + + y == b.get_value(); + ASSERT_TRUE(s.check()); + + std::string xvals = s.getValue(x.term).getBitVectorValue(); + + STerm bval = STerm(a.get_value(), &s, TermType::BVTerm); + std::string bvals = s.getValue(bval.term).getBitVectorValue(); + ASSERT_EQ(bvals, xvals); +} + +// MUL, LSH, RSH, AND and OR are not tested, since they are not bijective + +// This test aims to check for the absence of unintended +// behavior. If an unsupported operator is called, an info message appears in stderr +// and the value is supposed to remain unchanged. +TEST(BVTerm, unsupported_operations) +{ + + uint32_t modulus_base = 16; + uint32_t bitvector_size = 32; + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001", + default_solver_config, + modulus_base, + bitvector_size); + + STerm x = BVVar("x", &s); + STerm y = BVVar("y", &s); + + STerm z = x / y; + ASSERT_EQ(z.term, x.term); +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.cpp deleted file mode 100644 index f70802397e4c..000000000000 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.cpp +++ /dev/null @@ -1,215 +0,0 @@ -#include "ffiterm.hpp" - -namespace smt_terms { - -/** - * Create an integer mod symbolic variable. - * - * @param name Name of the variable. Should be unique per variable. - * @param slv Pointer to the global solver. - * @return Finite field symbolic variable. - * */ -FFITerm FFITerm::Var(const std::string& name, Solver* slv) -{ - return FFITerm(name, slv); -}; - -/** - * Create an integer mod numeric member. - * - * @param val String representation of the value. - * @param slv Pointer to the global solver. - * @param base Base of the string representation. 16 by default. - * @return Finite field constant. - * */ -FFITerm FFITerm::Const(const std::string& val, Solver* slv, uint32_t base) -{ - return FFITerm(val, slv, true, base); -}; - -FFITerm::FFITerm(const std::string& t, Solver* slv, bool isconst, uint32_t base) - : solver(slv) - , modulus(slv->term_manager.mkInteger(slv->modulus)) -{ - if (!isconst) { - this->term = slv->term_manager.mkConst(slv->term_manager.getIntegerSort(), t); - cvc5::Term ge = slv->term_manager.mkTerm(cvc5::Kind::GEQ, { this->term, slv->term_manager.mkInteger(0) }); - cvc5::Term lt = slv->term_manager.mkTerm(cvc5::Kind::LT, { this->term, this->modulus }); - slv->assertFormula(ge); - slv->assertFormula(lt); - } else { - // TODO(alex): CVC5 doesn't provide integer initialization from hex. Yet. - std::string strvalue = slv->term_manager.mkFiniteFieldElem(t, slv->ff_sort, base).getFiniteFieldValue(); - this->term = slv->term_manager.mkInteger(strvalue); - this->mod(); - } -} - -void FFITerm::mod() -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::INTS_MODULUS, { this->term, this->modulus }); -} - -FFITerm FFITerm::operator+(const FFITerm& other) const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::ADD, { this->term, other.term }); - return { res, this->solver }; -} - -void FFITerm::operator+=(const FFITerm& other) -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::ADD, { this->term, other.term }); -} - -FFITerm FFITerm::operator-(const FFITerm& other) const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::SUB, { this->term, other.term }); - return { res, this->solver }; -} - -void FFITerm::operator-=(const FFITerm& other) -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::SUB, { this->term, other.term }); -} - -FFITerm FFITerm::operator-() const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::NEG, { this->term }); - return { res, this->solver }; -} - -FFITerm FFITerm::operator*(const FFITerm& other) const -{ - cvc5::Term res = solver->term_manager.mkTerm(cvc5::Kind::MULT, { this->term, other.term }); - return { res, this->solver }; -} - -void FFITerm::operator*=(const FFITerm& other) -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::MULT, { this->term, other.term }); -} - -/** - * @brief Division operation - * - * @details Returns a result of the division by - * creating a new symbolic variable and adding a new constraint - * to the solver. - * - * @param other - * @return FFITerm - */ -FFITerm FFITerm::operator/(const FFITerm& other) const -{ - other != bb::fr(0); - FFITerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + - static_cast(other), - this->solver); - res* other == *this; - return res; -} - -void FFITerm::operator/=(const FFITerm& other) -{ - other != bb::fr(0); - FFITerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + - static_cast(other), - this->solver); - res* other == *this; - this->term = res.term; -} - -/** - * Create an equality constraint between two integer mod elements. - * - */ -void FFITerm::operator==(const FFITerm& other) const -{ - FFITerm tmp1 = *this; - if (tmp1.term.getNumChildren() > 1) { - tmp1.mod(); - } - FFITerm tmp2 = other; - if (tmp2.term.getNumChildren() > 1) { - tmp2.mod(); - } - cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { tmp1.term, tmp2.term }); - this->solver->assertFormula(eq); -} - -/** - * Create an inequality constraint between two integer mod elements. - * - */ -void FFITerm::operator!=(const FFITerm& other) const -{ - FFITerm tmp1 = *this; - if (tmp1.term.getNumChildren() > 1) { - tmp1.mod(); - } - FFITerm tmp2 = other; - if (tmp2.term.getNumChildren() > 1) { - tmp2.mod(); - } - cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { tmp1.term, tmp2.term }); - eq = this->solver->term_manager.mkTerm(cvc5::Kind::NOT, { eq }); - this->solver->assertFormula(eq); -} - -FFITerm operator+(const bb::fr& lhs, const FFITerm& rhs) -{ - return rhs + lhs; -} - -FFITerm operator-(const bb::fr& lhs, const FFITerm& rhs) -{ - return (-rhs) + lhs; -} - -FFITerm operator*(const bb::fr& lhs, const FFITerm& rhs) -{ - return rhs * lhs; -} - -FFITerm operator/(const bb::fr& lhs, const FFITerm& rhs) -{ - return FFITerm(lhs, rhs.solver) / rhs; -} - -FFITerm operator^(__attribute__((unused)) const bb::fr& lhs, __attribute__((unused)) const FFITerm& rhs) -{ - info("Not compatible with Integers"); - return {}; -} -void operator==(const bb::fr& lhs, const FFITerm& rhs) -{ - rhs == lhs; -} - -void operator!=(const bb::fr& lhs, const FFITerm& rhs) -{ - rhs != lhs; -} - -void FFITerm::operator<(const bb::fr& other) const -{ - cvc5::Term lt = this->solver->term_manager.mkTerm(cvc5::Kind::LT, { this->term, FFITerm(other, this->solver) }); - this->solver->assertFormula(lt); -} -void FFITerm::operator<=(const bb::fr& other) const -{ - cvc5::Term le = this->solver->term_manager.mkTerm(cvc5::Kind::LEQ, { this->term, FFITerm(other, this->solver) }); - this->solver->assertFormula(le); -} -void FFITerm::operator>(const bb::fr& other) const -{ - cvc5::Term gt = this->solver->term_manager.mkTerm(cvc5::Kind::GT, { this->term, FFITerm(other, this->solver) }); - this->solver->assertFormula(gt); -} -void FFITerm::operator>=(const bb::fr& other) const -{ - cvc5::Term ge = this->solver->term_manager.mkTerm(cvc5::Kind::GEQ, { this->term, FFITerm(other, this->solver) }); - this->solver->assertFormula(ge); -} - -} // namespace smt_terms diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp deleted file mode 100644 index e7c6c1afc5c7..000000000000 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.hpp +++ /dev/null @@ -1,143 +0,0 @@ -#pragma once -#include "barretenberg/smt_verification/solver/solver.hpp" - -namespace smt_terms { -using namespace smt_solver; - -/** - * @brief Integer Modulo element class. - * - * @details Can be a symbolic variable or a constant. - * Both of them support basic arithmetic operations: +, -, *, /. - * Check the satisfability of a system and get it's model. - * - */ -class FFITerm { - public: - Solver* solver; - cvc5::Term term; - cvc5::Term modulus; - - static bool isFiniteField() { return false; }; - static bool isInteger() { return true; }; - static bool isBitVector() { return false; }; - - FFITerm() - : solver(nullptr) - , term(cvc5::Term()) - , modulus(cvc5::Term()){}; - - FFITerm(cvc5::Term& term, Solver* s) - : solver(s) - , term(term) - , modulus(s->term_manager.mkInteger(s->modulus)) - {} - - explicit FFITerm(const std::string& t, Solver* slv, bool isconst = false, uint32_t base = 16); - - FFITerm(const FFITerm& other) = default; - FFITerm(FFITerm&& other) = default; - - static FFITerm Var(const std::string& name, Solver* slv); - static FFITerm Const(const std::string& val, Solver* slv, uint32_t base = 16); - - explicit FFITerm(bb::fr value, Solver* s) - { - std::stringstream buf; // TODO(#893) - buf << value; - std::string tmp = buf.str(); - tmp[1] = '0'; // avoiding `x` in 0x prefix - - *this = Const(tmp, s); - } - - FFITerm& operator=(const FFITerm& right) = default; - FFITerm& operator=(FFITerm&& right) = default; - - FFITerm operator+(const FFITerm& other) const; - void operator+=(const FFITerm& other); - FFITerm operator-(const FFITerm& other) const; - void operator-=(const FFITerm& other); - FFITerm operator-() const; - - FFITerm operator*(const FFITerm& other) const; - void operator*=(const FFITerm& other); - FFITerm operator/(const FFITerm& other) const; - void operator/=(const FFITerm& other); - - void operator==(const FFITerm& other) const; - void operator!=(const FFITerm& other) const; - - FFITerm operator^(__attribute__((unused)) const FFITerm& other) const - { - info("Not compatible with Integers"); - return {}; - } - void operator^=(__attribute__((unused)) const FFITerm& other) { info("Not compatible with Integers"); }; - - void mod(); - - operator std::string() const { return smt_solver::stringify_term(term); }; - operator cvc5::Term() const { return term; }; - - ~FFITerm() = default; - - friend std::ostream& operator<<(std::ostream& out, const FFITerm& term) - { - return out << static_cast(term); - } - - friend FFITerm batch_add(const std::vector& children) - { - Solver* slv = children[0].solver; - std::vector terms(children.begin(), children.end()); - cvc5::Term res = slv->term_manager.mkTerm(cvc5::Kind::ADD, terms); - res = slv->term_manager.mkTerm(cvc5::Kind::INTS_MODULUS, { res, children[0].modulus }); - return { res, slv }; - } - - friend FFITerm batch_mul(const std::vector& children) - { - Solver* slv = children[0].solver; - std::vector terms(children.begin(), children.end()); - cvc5::Term res = slv->term_manager.mkTerm(cvc5::Kind::MULT, terms); - res = slv->term_manager.mkTerm(cvc5::Kind::INTS_MODULUS, { res, children[0].modulus }); - return { res, slv }; - } - - // arithmetic compatibility with Fr - - FFITerm operator+(const bb::fr& other) const { return *this + FFITerm(other, this->solver); } - void operator+=(const bb::fr& other) { *this += FFITerm(other, this->solver); } - FFITerm operator-(const bb::fr& other) const { return *this - FFITerm(other, this->solver); } - void operator-=(const bb::fr& other) { *this -= FFITerm(other, this->solver); } - FFITerm operator*(const bb::fr& other) const { return *this * FFITerm(other, this->solver); } - void operator*=(const bb::fr& other) { *this *= FFITerm(other, this->solver); } - FFITerm operator/(const bb::fr& other) const { return *this / FFITerm(other, this->solver); } - void operator/=(const bb::fr& other) { *this /= FFITerm(other, this->solver); } - - void operator==(const bb::fr& other) const { *this == FFITerm(other, this->solver); } - void operator!=(const bb::fr& other) const { *this != FFITerm(other, this->solver); } - - FFITerm operator^(__attribute__((unused)) const bb::fr& other) const - { - info("Not compatible with Integers"); - return {}; - } - void operator^=(__attribute__((unused)) const bb::fr& other) { info("Not compatible with Finite Field"); } - - void operator<(const bb::fr& other) const; - void operator<=(const bb::fr& other) const; - void operator>(const bb::fr& other) const; - void operator>=(const bb::fr& other) const; -}; - -FFITerm operator+(const bb::fr& lhs, const FFITerm& rhs); -FFITerm operator-(const bb::fr& lhs, const FFITerm& rhs); -FFITerm operator*(const bb::fr& lhs, const FFITerm& rhs); -FFITerm operator^(const bb::fr& lhs, const FFITerm& rhs); -FFITerm operator/(const bb::fr& lhs, const FFITerm& rhs); -void operator==(const bb::fr& lhs, const FFITerm& rhs); -void operator!=(const bb::fr& lhs, const FFITerm& rhs); - -} // namespace smt_terms diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.test.cpp index 7949ffde76ce..1c2efa1cf93e 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffiterm.test.cpp @@ -1,6 +1,6 @@ #include -#include "ffiterm.hpp" +#include "term.hpp" #include @@ -17,16 +17,17 @@ TEST(FFITerm, addition) bb::fr c = a + b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); - FFITerm bval = FFITerm(b, &s); - FFITerm z = x + y; + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); + STerm z = x + y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getIntegerValue(); + + STerm bval = STerm(b, &s, TermType::FFITerm); std::string bvals = s.getValue(bval.term).getIntegerValue(); ASSERT_EQ(bvals, yvals); } @@ -38,10 +39,10 @@ TEST(FFITerm, subtraction) bb::fr c = a - b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); - FFITerm bval = FFITerm(b, &s); - FFITerm z = x - y; + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); + STerm bval = STerm(b, &s, TermType::FFITerm); + STerm z = x - y; z == c; x == a; @@ -59,16 +60,17 @@ TEST(FFITerm, multiplication) bb::fr c = a * b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); - FFITerm bval = FFITerm(b, &s); - FFITerm z = x * y; + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); + STerm z = x * y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getIntegerValue(); + + STerm bval = STerm(b, &s, TermType::FFITerm); std::string bvals = s.getValue(bval.term).getIntegerValue(); ASSERT_EQ(bvals, yvals); } @@ -80,16 +82,55 @@ TEST(FFITerm, division) bb::fr c = a / b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFITerm x = FFITerm::Var("x", &s); - FFITerm y = FFITerm::Var("y", &s); - FFITerm bval = FFITerm(b, &s); - FFITerm z = x / y; + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); + STerm z = x / y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getIntegerValue(); + + STerm bval = STerm(b, &s, TermType::FFITerm); std::string bvals = s.getValue(bval.term).getIntegerValue(); ASSERT_EQ(bvals, yvals); } + +// This test aims to check for the absence of unintended +// behavior. If an unsupported operator is called, an info message appears in stderr +// and the value is supposed to remain unchanged. +TEST(FFITerm, unsupported_operations) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); + + STerm x = FFIVar("x", &s); + STerm y = FFIVar("y", &s); + + STerm z = x ^ y; + ASSERT_EQ(z.term, x.term); + z = x & y; + ASSERT_EQ(z.term, x.term); + z = x | y; + ASSERT_EQ(z.term, x.term); + z = x >> 10; + ASSERT_EQ(z.term, x.term); + z = x << 10; + ASSERT_EQ(z.term, x.term); + z = x.rotr(10); + ASSERT_EQ(z.term, x.term); + z = x.rotl(10); + ASSERT_EQ(z.term, x.term); + + cvc5::Term before_term = x.term; + x ^= y; + ASSERT_EQ(x.term, before_term); + x &= y; + ASSERT_EQ(x.term, before_term); + x |= y; + ASSERT_EQ(x.term, before_term); + x >>= 10; + ASSERT_EQ(x.term, before_term); + x <<= 10; + ASSERT_EQ(x.term, before_term); +} \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.cpp deleted file mode 100644 index 0761954680f2..000000000000 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.cpp +++ /dev/null @@ -1,167 +0,0 @@ -#include "ffterm.hpp" - -namespace smt_terms { - -/** - * Create a finite field symbolic variable. - * - * @param name Name of the variable. Should be unique per variable. - * @param slv Pointer to the global solver. - * @return Finite field symbolic variable. - * */ -FFTerm FFTerm::Var(const std::string& name, Solver* slv) -{ - return FFTerm(name, slv); -}; - -/** - * Create a finite field numeric member. - * - * @param val String representation of the value. - * @param slv Pointer to the global solver. - * @param base Base of the string representation. 16 by default. - * @return Finite field constant. - * */ -FFTerm FFTerm::Const(const std::string& val, Solver* slv, uint32_t base) -{ - return FFTerm(val, slv, true, base); -}; - -FFTerm::FFTerm(const std::string& t, Solver* slv, bool isconst, uint32_t base) - : solver(slv) -{ - if (!isconst) { - this->term = slv->term_manager.mkConst(slv->ff_sort, t); - } else { - this->term = slv->term_manager.mkFiniteFieldElem(t, slv->ff_sort, base); - } -} - -FFTerm FFTerm::operator+(const FFTerm& other) const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, { this->term, other.term }); - return { res, this->solver }; -} - -void FFTerm::operator+=(const FFTerm& other) -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, { this->term, other.term }); -} - -FFTerm FFTerm::operator-(const FFTerm& other) const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_NEG, { other.term }); - res = solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, { this->term, res }); - return { res, this->solver }; -} - -FFTerm FFTerm::operator-() const -{ - cvc5::Term res = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_NEG, { this->term }); - return { res, this->solver }; -} - -void FFTerm::operator-=(const FFTerm& other) -{ - cvc5::Term tmp_term = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_NEG, { other.term }); - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, { this->term, tmp_term }); -} - -FFTerm FFTerm::operator*(const FFTerm& other) const -{ - cvc5::Term res = solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_MULT, { this->term, other.term }); - return { res, this->solver }; -} - -void FFTerm::operator*=(const FFTerm& other) -{ - this->term = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_MULT, { this->term, other.term }); -} - -/** - * @brief Division operation - * - * @details Returns a result of the division by - * creating a new symbolic variable and adding a new constraint - * to the solver. - * - * @param other - * @return FFTerm - */ -FFTerm FFTerm::operator/(const FFTerm& other) const -{ - other != bb::fr(0); - FFTerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + - static_cast(other), - this->solver); - res* other == *this; - return res; -} - -void FFTerm::operator/=(const FFTerm& other) -{ - other != bb::fr(0); - FFTerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + - static_cast(other), - this->solver); - res* other == *this; - this->term = res.term; -} - -/** - * Create an equality constraint between two finite field elements. - * - */ -void FFTerm::operator==(const FFTerm& other) const -{ - cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { this->term, other.term }); - this->solver->assertFormula(eq); -} - -/** - * Create an inequality constraint between two finite field elements. - * - */ -void FFTerm::operator!=(const FFTerm& other) const -{ - cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { this->term, other.term }); - eq = this->solver->term_manager.mkTerm(cvc5::Kind::NOT, { eq }); - this->solver->assertFormula(eq); -} - -FFTerm operator+(const bb::fr& lhs, const FFTerm& rhs) -{ - return rhs + lhs; -} - -FFTerm operator-(const bb::fr& lhs, const FFTerm& rhs) -{ - return (-rhs) + lhs; -} - -FFTerm operator*(const bb::fr& lhs, const FFTerm& rhs) -{ - return rhs * lhs; -} - -FFTerm operator^(__attribute__((unused)) const bb::fr& lhs, __attribute__((unused)) const FFTerm& rhs) -{ - info("Not compatible with Finite Field"); - return {}; -} - -FFTerm operator/(const bb::fr& lhs, const FFTerm& rhs) -{ - return FFTerm(lhs, rhs.solver) / rhs; -} - -void operator==(const bb::fr& lhs, const FFTerm& rhs) -{ - rhs == lhs; -} - -void operator!=(const bb::fr& lhs, const FFTerm& rhs) -{ - rhs != lhs; -} -} // namespace smt_terms \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp deleted file mode 100644 index 4451c4befb52..000000000000 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.hpp +++ /dev/null @@ -1,136 +0,0 @@ -#pragma once -#include "barretenberg/smt_verification/solver/solver.hpp" - -namespace smt_terms { -using namespace smt_solver; - -/** - * @brief Finite Field element class. - * - * @details Can be a finite field symbolic variable or a constant. - * Both of them support basic arithmetic operations: +, -, *, /. - * Check the satisfability of a system and get it's model. - * - */ -class FFTerm { - public: - Solver* solver; - cvc5::Term term; - - static bool isFiniteField() { return true; }; - static bool isInteger() { return false; }; - static bool isBitVector() { return false; }; - - FFTerm() - : solver(nullptr) - , term(cvc5::Term()){}; - - FFTerm(cvc5::Term& term, Solver* s) - : solver(s) - , term(term){}; - - explicit FFTerm(const std::string& t, Solver* slv, bool isconst = false, uint32_t base = 16); - - FFTerm(const FFTerm& other) = default; - FFTerm(FFTerm&& other) = default; - - static FFTerm Var(const std::string& name, Solver* slv); - static FFTerm Const(const std::string& val, Solver* slv, uint32_t base = 16); - - FFTerm(bb::fr value, Solver* s) - { - std::stringstream buf; // TODO(#893) - buf << value; - std::string tmp = buf.str(); - tmp[1] = '0'; // avoiding `x` in 0x prefix - - *this = Const(tmp, s); - } - - FFTerm& operator=(const FFTerm& right) = default; - FFTerm& operator=(FFTerm&& right) = default; - - FFTerm operator+(const FFTerm& other) const; - void operator+=(const FFTerm& other); - FFTerm operator-(const FFTerm& other) const; - void operator-=(const FFTerm& other); - FFTerm operator-() const; - - FFTerm operator*(const FFTerm& other) const; - void operator*=(const FFTerm& other); - FFTerm operator/(const FFTerm& other) const; - void operator/=(const FFTerm& other); - - void operator==(const FFTerm& other) const; - void operator!=(const FFTerm& other) const; - - FFTerm operator^(__attribute__((unused)) const FFTerm& other) const - { - info("Not compatible with Finite Field"); - return {}; - } - void operator^=(__attribute__((unused)) const FFTerm& other) { info("Not compatible with Finite Field"); }; - - void mod(){}; - - operator std::string() const { return smt_solver::stringify_term(term); }; - operator cvc5::Term() const { return term; }; - - ~FFTerm() = default; - - friend std::ostream& operator<<(std::ostream& out, const FFTerm& term) - { - return out << static_cast(term); - }; - - friend FFTerm batch_add(const std::vector& children) - { - Solver* slv = children[0].solver; - std::vector terms(children.begin(), children.end()); - cvc5::Term res = slv->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, terms); - return { res, slv }; - } - - friend FFTerm batch_mul(const std::vector& children) - { - Solver* slv = children[0].solver; - std::vector terms(children.begin(), children.end()); - cvc5::Term res = slv->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_MULT, terms); - return { res, slv }; - } - - // arithmetic compatibility with Fr - - FFTerm operator+(const bb::fr& rhs) const { return *this + FFTerm(rhs, this->solver); } - void operator+=(const bb::fr& other) { *this += FFTerm(other, this->solver); } - FFTerm operator-(const bb::fr& other) const { return *this - FFTerm(other, this->solver); } - void operator-=(const bb::fr& other) { *this -= FFTerm(other, this->solver); } - FFTerm operator*(const bb::fr& other) const { return *this * FFTerm(other, this->solver); } - void operator*=(const bb::fr& other) { *this *= FFTerm(other, this->solver); } - FFTerm operator/(const bb::fr& other) const { return *this / FFTerm(other, this->solver); } - void operator/=(const bb::fr& other) { *this /= FFTerm(other, this->solver); } - - void operator==(const bb::fr& other) const { *this == FFTerm(other, this->solver); } - void operator!=(const bb::fr& other) const { *this != FFTerm(other, this->solver); } - - FFTerm operator^(__attribute__((unused)) const bb::fr& other) const - { - info("Not compatible with Finite Field"); - return {}; - } - void operator^=(__attribute__((unused)) const bb::fr& other) { info("Not compatible with Finite Field"); } - void operator<(__attribute__((unused)) const bb::fr& other) const { info("Not compatible with Finite Field"); } - void operator<=(__attribute__((unused)) const bb::fr& other) const { info("Not compatible with Finite Field"); } - void operator>(__attribute__((unused)) const bb::fr& other) const { info("Not compatible with Finite Field"); } - void operator>=(__attribute__((unused)) const bb::fr& other) const { info("Not compatible with Finite Field"); } -}; - -FFTerm operator+(const bb::fr& lhs, const FFTerm& rhs); -FFTerm operator-(const bb::fr& lhs, const FFTerm& rhs); -FFTerm operator*(const bb::fr& lhs, const FFTerm& rhs); -FFTerm operator^(__attribute__((unused)) const bb::fr& lhs, __attribute__((unused)) const FFTerm& rhs); -FFTerm operator/(const bb::fr& lhs, const FFTerm& rhs); -void operator==(const bb::fr& lhs, const FFTerm& rhs); -void operator!=(const bb::fr& lhs, const FFTerm& rhs); - -} // namespace smt_terms \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.test.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.test.cpp index 633f40a6465a..ee2b118f37e4 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.test.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/ffterm.test.cpp @@ -1,6 +1,6 @@ #include -#include "ffterm.hpp" +#include "term.hpp" #include @@ -17,16 +17,17 @@ TEST(FFTerm, addition) bb::fr c = a + b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); - FFTerm bval = FFTerm(b, &s); - FFTerm z = x + y; + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); + STerm z = x + y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getFiniteFieldValue(); + + STerm bval = STerm(b, &s, TermType::FFTerm); std::string bvals = s.getValue(bval.term).getFiniteFieldValue(); ASSERT_EQ(bvals, yvals); } @@ -38,16 +39,17 @@ TEST(FFTerm, subtraction) bb::fr c = a - b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); - FFTerm bval = FFTerm(b, &s); - FFTerm z = x - y; + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); + STerm z = x - y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getFiniteFieldValue(); + + STerm bval = STerm(b, &s, TermType::FFTerm); std::string bvals = s.getValue(bval.term).getFiniteFieldValue(); ASSERT_EQ(bvals, yvals); } @@ -59,16 +61,17 @@ TEST(FFTerm, multiplication) bb::fr c = a * b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); - FFTerm bval = FFTerm(b, &s); - FFTerm z = x * y; + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); + STerm z = x * y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getFiniteFieldValue(); + + STerm bval = STerm(b, &s, TermType::FFTerm); std::string bvals = s.getValue(bval.term).getFiniteFieldValue(); ASSERT_EQ(bvals, yvals); } @@ -80,16 +83,65 @@ TEST(FFTerm, division) bb::fr c = a / b; Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); - FFTerm x = FFTerm::Var("x", &s); - FFTerm y = FFTerm::Var("y", &s); - FFTerm bval = FFTerm(b, &s); - FFTerm z = x / y; + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); + STerm z = x / y; z == c; x == a; ASSERT_TRUE(s.check()); std::string yvals = s.getValue(y.term).getFiniteFieldValue(); + + STerm bval = STerm(b, &s, TermType::FFTerm); std::string bvals = s.getValue(bval.term).getFiniteFieldValue(); ASSERT_EQ(bvals, yvals); +} + +// This test aims to check for the absence of unintended +// behavior. If an unsupported operator is called, an info message appears in stderr +// and the value is supposed to remain unchanged. +TEST(FFTerm, unsupported_operations) +{ + Solver s("30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"); + + STerm x = FFVar("x", &s); + STerm y = FFVar("y", &s); + + STerm z = x ^ y; + ASSERT_EQ(z.term, x.term); + z = x & y; + ASSERT_EQ(z.term, x.term); + z = x | y; + ASSERT_EQ(z.term, x.term); + z = x >> 10; + ASSERT_EQ(z.term, x.term); + z = x << 10; + ASSERT_EQ(z.term, x.term); + z = x.rotr(10); + ASSERT_EQ(z.term, x.term); + z = x.rotl(10); + ASSERT_EQ(z.term, x.term); + + cvc5::Term before_term = x.term; + x ^= y; + ASSERT_EQ(x.term, before_term); + x &= y; + ASSERT_EQ(x.term, before_term); + x |= y; + ASSERT_EQ(x.term, before_term); + x >>= 10; + ASSERT_EQ(x.term, before_term); + x <<= 10; + ASSERT_EQ(x.term, before_term); + + size_t n = s.solver.getAssertions().size(); + z <= bb::fr(10); + ASSERT_EQ(n, s.solver.getAssertions().size()); + z < bb::fr(10); + ASSERT_EQ(n, s.solver.getAssertions().size()); + z > bb::fr(10); + ASSERT_EQ(n, s.solver.getAssertions().size()); + z >= bb::fr(10); + ASSERT_EQ(n, s.solver.getAssertions().size()); } \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp new file mode 100644 index 000000000000..ae16d144a9ba --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.cpp @@ -0,0 +1,468 @@ +#include "barretenberg/smt_verification/terms/term.hpp" + +namespace smt_terms { + +/** + * Create a symbolic variable. + * + * @param name Name of the variable. Should be unique per variable + * @param slv Pointer to the global solver + * @param type FFTerm, FFITerm or BVTerm + * @return symbolic variable + * */ +STerm STerm::Var(const std::string& name, Solver* slv, TermType type) +{ + return STerm(name, slv, false, 16, type); +}; + +/** + * Create constant symbolic variable. + * i.e. term with predefined constant value + * + * @param val String representation of the value. + * @param slv Pointer to the global solver. + * @param base Base of the string representation. 16 by default. + * @param type FFTerm, FFITerm or BVTerm + * @return numeric constant + * */ +STerm STerm::Const(const std::string& val, Solver* slv, uint32_t base, TermType type) +{ + return STerm(val, slv, true, base, type); +}; + +STerm::STerm(const std::string& t, Solver* slv, bool isconst, uint32_t base, TermType type) + : solver(slv) + , isFiniteField(type == TermType::FFTerm) + , isInteger(type == TermType::FFITerm) + , isBitVector(type == TermType::BVTerm) + , type(type) + , operations(typed_operations.at(type)) +{ + if (!isconst) { + cvc5::Term ge; + cvc5::Term lt; + cvc5::Term modulus; + switch (type) { + case TermType::FFTerm: + this->term = slv->term_manager.mkConst(slv->ff_sort, t); + break; + case TermType::FFITerm: + this->term = slv->term_manager.mkConst(slv->term_manager.getIntegerSort(), t); + ge = slv->term_manager.mkTerm(cvc5::Kind::GEQ, { this->term, slv->term_manager.mkInteger(0) }); + modulus = slv->term_manager.mkInteger(slv->modulus); + lt = slv->term_manager.mkTerm(cvc5::Kind::LT, { this->term, modulus }); + slv->assertFormula(ge); + slv->assertFormula(lt); + break; + case TermType::BVTerm: + this->term = slv->term_manager.mkConst(slv->bv_sort, t); + break; + } + } else { + std::string strvalue; + switch (type) { + case TermType::FFTerm: + this->term = slv->term_manager.mkFiniteFieldElem(t, slv->ff_sort, base); + break; + case TermType::FFITerm: + // TODO(alex): CVC5 doesn't provide integer initialization from hex. Yet. + strvalue = slv->term_manager.mkFiniteFieldElem(t, slv->ff_sort, base).getFiniteFieldValue(); + this->term = slv->term_manager.mkInteger(strvalue); + this->mod(); + break; + case TermType::BVTerm: + // it's better to have (-p/2, p/2) representation due to overflows + strvalue = slv->term_manager.mkFiniteFieldElem(t, slv->ff_sort, base).getFiniteFieldValue(); + this->term = slv->term_manager.mkBitVector(slv->bv_sort.getBitVectorSize(), strvalue, 10); + break; + } + } +} + +void STerm::mod() +{ + if (this->type == TermType::FFITerm) { + cvc5::Term modulus = this->solver->term_manager.mkInteger(solver->modulus); + this->term = this->solver->term_manager.mkTerm(this->operations.at(OpType::MOD), { this->term, modulus }); + } +} + +STerm STerm::operator+(const STerm& other) const +{ + cvc5::Term res = this->solver->term_manager.mkTerm(this->operations.at(OpType::ADD), { this->term, other.term }); + return { res, this->solver, this->type }; +} + +void STerm::operator+=(const STerm& other) +{ + this->term = this->solver->term_manager.mkTerm(this->operations.at(OpType::ADD), { this->term, other.term }); +} + +STerm STerm::operator-(const STerm& other) const +{ + cvc5::Term res = this->solver->term_manager.mkTerm(this->operations.at(OpType::NEG), { other.term }); + res = solver->term_manager.mkTerm(this->operations.at(OpType::ADD), { this->term, res }); + return { res, this->solver, this->type }; +} + +void STerm::operator-=(const STerm& other) +{ + cvc5::Term tmp_term = this->solver->term_manager.mkTerm(this->operations.at(OpType::NEG), { other.term }); + this->term = this->solver->term_manager.mkTerm(cvc5::Kind::FINITE_FIELD_ADD, { this->term, tmp_term }); +} + +STerm STerm::operator-() const +{ + cvc5::Term res = this->solver->term_manager.mkTerm(this->operations.at(OpType::NEG), { this->term }); + return { res, this->solver, this->type }; +} + +STerm STerm::operator*(const STerm& other) const +{ + cvc5::Term res = solver->term_manager.mkTerm(this->operations.at(OpType::MUL), { this->term, other.term }); + return { res, this->solver, this->type }; +} + +void STerm::operator*=(const STerm& other) +{ + this->term = this->solver->term_manager.mkTerm(this->operations.at(OpType::MUL), { this->term, other.term }); +} + +/** + * @brief Division operation + * + * @details Returns a result of the division by + * creating a new symbolic variable and adding a new constraint + * to the solver. + * + * @param other + * @return STerm + */ +STerm STerm::operator/(const STerm& other) const +{ + if (!this->operations.contains(OpType::DIV)) { + info("Division is not compatible with ", this->type); + return *this; + } + other != bb::fr(0); + STerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + + static_cast(other), + this->solver, + this->type); + res* other == *this; + return res; +} + +void STerm::operator/=(const STerm& other) +{ + if (!this->operations.contains(OpType::DIV)) { + info("Division is not compatible with ", this->type); + return; + } + other != bb::fr(0); + STerm res = Var("df8b586e3fa7a1224ec95a886e17a7da_div_" + static_cast(*this) + "_" + + static_cast(other), + this->solver, + this->type); + res* other == *this; + this->term = res.term; +} + +/** + * Create an equality constraint between two symbolic variables of the same type + * + */ +void STerm::operator==(const STerm& other) const +{ + STerm tmp1 = *this; + if (tmp1.term.getNumChildren() > 1) { + tmp1.mod(); + } + STerm tmp2 = other; + if (tmp2.term.getNumChildren() > 1) { + tmp2.mod(); + } + cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { tmp1.term, tmp2.term }); + this->solver->assertFormula(eq); +} + +/** + * Create an inequality constraint between two symbolic variables of the same type + * + */ +void STerm::operator!=(const STerm& other) const +{ + STerm tmp1 = *this; + if (tmp1.term.getNumChildren() > 1) { + tmp1.mod(); + } + STerm tmp2 = other; + if (tmp2.term.getNumChildren() > 1) { + tmp2.mod(); + } + cvc5::Term eq = this->solver->term_manager.mkTerm(cvc5::Kind::EQUAL, { tmp1.term, tmp2.term }); + eq = this->solver->term_manager.mkTerm(cvc5::Kind::NOT, { eq }); + this->solver->assertFormula(eq); +} + +void STerm::operator<(const bb::fr& other) const +{ + if (!this->operations.contains(OpType::LT)) { + info("LT is not compatible with ", this->type); + return; + } + + cvc5::Term lt = this->solver->term_manager.mkTerm(this->operations.at(OpType::LT), + { this->term, STerm(other, this->solver, this->type) }); + this->solver->assertFormula(lt); +} + +void STerm::operator<=(const bb::fr& other) const +{ + if (!this->operations.contains(OpType::LE)) { + info("LE is not compatible with ", this->type); + return; + } + cvc5::Term le = this->solver->term_manager.mkTerm(this->operations.at(OpType::LE), + { this->term, STerm(other, this->solver, this->type) }); + this->solver->assertFormula(le); +} + +void STerm::operator>(const bb::fr& other) const +{ + if (!this->operations.contains(OpType::GT)) { + info("GT is not compatible with ", this->type); + return; + } + cvc5::Term gt = this->solver->term_manager.mkTerm(this->operations.at(OpType::GT), + { this->term, STerm(other, this->solver, this->type) }); + this->solver->assertFormula(gt); +} + +void STerm::operator>=(const bb::fr& other) const +{ + if (!this->operations.contains(OpType::GE)) { + info("GE is not compatible with ", this->type); + return; + } + cvc5::Term ge = this->solver->term_manager.mkTerm(this->operations.at(OpType::GE), + { this->term, STerm(other, this->solver, this->type) }); + this->solver->assertFormula(ge); +} + +STerm STerm::operator^(const STerm& other) const +{ + if (!this->operations.contains(OpType::XOR)) { + info("XOR is not compatible with ", this->type); + return *this; + } + cvc5::Term res = solver->term_manager.mkTerm(this->operations.at(OpType::XOR), { this->term, other.term }); + return { res, this->solver, this->type }; +} + +void STerm::operator^=(const STerm& other) +{ + if (!this->operations.contains(OpType::XOR)) { + info("XOR is not compatible with ", this->type); + return; + } + this->term = solver->term_manager.mkTerm(this->operations.at(OpType::XOR), { this->term, other.term }); +} + +STerm STerm::operator&(const STerm& other) const +{ + if (!this->operations.contains(OpType::AND)) { + info("AND is not compatible with ", this->type); + return *this; + } + cvc5::Term res = solver->term_manager.mkTerm(this->operations.at(OpType::AND), { this->term, other.term }); + return { res, this->solver, this->type }; +} + +void STerm::operator&=(const STerm& other) +{ + if (!this->operations.contains(OpType::AND)) { + info("AND is not compatible with ", this->type); + return; + } + this->term = solver->term_manager.mkTerm(this->operations.at(OpType::AND), { this->term, other.term }); +} + +STerm STerm::operator|(const STerm& other) const +{ + if (!this->operations.contains(OpType::OR)) { + info("OR is not compatible with ", this->type); + return *this; + } + cvc5::Term res = solver->term_manager.mkTerm(this->operations.at(OpType::OR), { this->term, other.term }); + return { res, this->solver, this->type }; +} + +void STerm::operator|=(const STerm& other) +{ + if (!this->operations.contains(OpType::OR)) { + info("OR is not compatible with ", this->type); + return; + } + this->term = solver->term_manager.mkTerm(this->operations.at(OpType::OR), { this->term, other.term }); +} + +STerm STerm::operator<<(const uint32_t& n) const +{ + if (!this->operations.contains(OpType::LSH)) { + info("SHIFT LEFT is not compatible with ", this->type); + return *this; + } + cvc5::Op lsh = solver->term_manager.mkOp(this->operations.at(OpType::LSH), { n }); + cvc5::Term res = solver->term_manager.mkTerm(lsh, { this->term }); + return { res, this->solver, this->type }; +} + +void STerm::operator<<=(const uint32_t& n) +{ + if (!this->operations.contains(OpType::LSH)) { + info("SHIFT LEFT is not compatible with ", this->type); + return; + } + cvc5::Op lsh = solver->term_manager.mkOp(this->operations.at(OpType::LSH), { n }); + this->term = solver->term_manager.mkTerm(lsh, { this->term }); +} + +STerm STerm::operator>>(const uint32_t& n) const +{ + if (!this->operations.contains(OpType::RSH)) { + info("RIGHT LEFT is not compatible with ", this->type); + return *this; + } + cvc5::Op rsh = solver->term_manager.mkOp(this->operations.at(OpType::RSH), { n }); + cvc5::Term res = solver->term_manager.mkTerm(rsh, { this->term }); + return { res, this->solver, this->type }; +} + +void STerm::operator>>=(const uint32_t& n) +{ + if (!this->operations.contains(OpType::RSH)) { + info("RIGHT LEFT is not compatible with ", this->type); + return; + } + cvc5::Op rsh = solver->term_manager.mkOp(this->operations.at(OpType::RSH), { n }); + this->term = solver->term_manager.mkTerm(rsh, { this->term }); +} + +STerm STerm::rotr(const uint32_t& n) const +{ + if (!this->operations.contains(OpType::ROTR)) { + info("ROTR is not compatible with ", this->type); + return *this; + } + cvc5::Op rotr = solver->term_manager.mkOp(this->operations.at(OpType::ROTR), { n }); + cvc5::Term res = solver->term_manager.mkTerm(rotr, { this->term }); + return { res, this->solver, this->type }; +} + +STerm STerm::rotl(const uint32_t& n) const +{ + if (!this->operations.contains(OpType::ROTL)) { + info("ROTL is not compatible with ", this->type); + return *this; + } + cvc5::Op rotl = solver->term_manager.mkOp(this->operations.at(OpType::ROTL), { n }); + cvc5::Term res = solver->term_manager.mkTerm(rotl, { this->term }); + return { res, this->solver, this->type }; +} + +STerm operator+(const bb::fr& lhs, const STerm& rhs) +{ + return rhs + lhs; +} + +STerm operator-(const bb::fr& lhs, const STerm& rhs) +{ + return (-rhs) + lhs; +} + +STerm operator*(const bb::fr& lhs, const STerm& rhs) +{ + return rhs * lhs; +} + +STerm operator^(const bb::fr& lhs, const STerm& rhs) +{ + return rhs ^ lhs; +} + +STerm operator&(const bb::fr& lhs, const STerm& rhs) +{ + return rhs & lhs; +} + +STerm operator|(const bb::fr& lhs, const STerm& rhs) +{ + return rhs | lhs; +} + +STerm operator/(const bb::fr& lhs, const STerm& rhs) +{ + return STerm(lhs, rhs.solver, rhs.type) / rhs; +} + +void operator==(const bb::fr& lhs, const STerm& rhs) +{ + rhs == lhs; +} + +void operator!=(const bb::fr& lhs, const STerm& rhs) +{ + rhs != lhs; +} + +std::ostream& operator<<(std::ostream& os, const TermType type) +{ + switch (type) { + case TermType::FFTerm: + os << "FFTerm"; + break; + case TermType::FFITerm: + os << "FFITerm"; + break; + case TermType::BVTerm: + os << "BVTerm"; + break; + }; + return os; +} + +// Parametrized calls to STerm constructor +// so you won't need to use TermType each time to define a new variable. + +STerm FFVar(const std::string& name, Solver* slv) +{ + return STerm::Var(name, slv, TermType::FFTerm); +} + +STerm FFConst(const std::string& val, Solver* slv, uint32_t base) +{ + return STerm::Const(val, slv, base, TermType::FFTerm); +} + +STerm FFIVar(const std::string& name, Solver* slv) +{ + return STerm::Var(name, slv, TermType::FFITerm); +} + +STerm FFIConst(const std::string& val, Solver* slv, uint32_t base) +{ + return STerm::Const(val, slv, base, TermType::FFITerm); +} + +STerm BVVar(const std::string& name, Solver* slv) +{ + return STerm::Var(name, slv, TermType::BVTerm); +} + +STerm BVConst(const std::string& val, Solver* slv, uint32_t base) +{ + return STerm::Const(val, slv, base, TermType::BVTerm); +} + +} // namespace smt_terms diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp new file mode 100644 index 000000000000..685bf446462a --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/smt_verification/terms/term.hpp @@ -0,0 +1,226 @@ +#pragma once +#include "barretenberg/smt_verification/solver/solver.hpp" + +namespace smt_terms { +using namespace smt_solver; + +/** + * @brief Allows to define three types of symbolic terms + * STerm - Symbolic Variables acting like a Finte Field elements + * FFITerm - Symbolic Variables acting like integers modulo prime + * BVTerm - Symbolic Variables acting like bitvectors modulo prime + * + */ +enum class TermType { FFTerm, FFITerm, BVTerm }; +std::ostream& operator<<(std::ostream& os, TermType type); + +enum class OpType : int32_t { ADD, SUB, MUL, DIV, NEG, XOR, AND, OR, GT, GE, LT, LE, MOD, RSH, LSH, ROTR, ROTL }; + +/** + * @brief precomputed map that contains allowed + * operations for each of three symbolic types + * + */ +const std::unordered_map> typed_operations = { + { TermType::FFTerm, + { { OpType::ADD, cvc5::Kind::FINITE_FIELD_ADD }, + { OpType::MUL, cvc5::Kind::FINITE_FIELD_MULT }, + { OpType::NEG, cvc5::Kind::FINITE_FIELD_NEG }, + // Just a placeholder that marks it supports division + { OpType::DIV, cvc5::Kind::FINITE_FIELD_MULT } } }, + { TermType::FFITerm, + { + + { OpType::ADD, cvc5::Kind::ADD }, + { OpType::SUB, cvc5::Kind::SUB }, + { OpType::MUL, cvc5::Kind::MULT }, + { OpType::NEG, cvc5::Kind::NEG }, + { OpType::GT, cvc5::Kind::GT }, + { OpType::GE, cvc5::Kind::GEQ }, + { OpType::LT, cvc5::Kind::LT }, + { OpType::LE, cvc5::Kind::LEQ }, + { OpType::MOD, cvc5::Kind::INTS_MODULUS }, + // Just a placeholder that marks it supports division + { OpType::DIV, cvc5::Kind::MULT } } }, + { TermType::BVTerm, + { + + { OpType::ADD, cvc5::Kind::BITVECTOR_ADD }, + { OpType::SUB, cvc5::Kind::BITVECTOR_SUB }, + { OpType::MUL, cvc5::Kind::BITVECTOR_MULT }, + { OpType::NEG, cvc5::Kind::BITVECTOR_NEG }, + { OpType::GT, cvc5::Kind::BITVECTOR_UGT }, + { OpType::GE, cvc5::Kind::BITVECTOR_UGE }, + { OpType::LT, cvc5::Kind::BITVECTOR_ULT }, + { OpType::LE, cvc5::Kind::BITVECTOR_ULE }, + { OpType::XOR, cvc5::Kind::BITVECTOR_XOR }, + { OpType::AND, cvc5::Kind::BITVECTOR_AND }, + { OpType::OR, cvc5::Kind::BITVECTOR_OR }, + { OpType::RSH, cvc5::Kind::BITVECTOR_LSHR }, + { OpType::LSH, cvc5::Kind::BITVECTOR_SHL }, + { OpType::ROTL, cvc5::Kind::BITVECTOR_ROTATE_LEFT }, + { OpType::ROTR, cvc5::Kind::BITVECTOR_ROTATE_RIGHT } } } +}; + +/** + * @brief Symbolic term element class. + * + * @details Can be a Finite Field/Integer Mod/BitVector symbolic variable or a constant. + * Supports basic arithmetic operations: +, -, *, /. + * Additionally + * FFITerm supports inequalities: <, <=, >, >= + * BVTerm supports inequalities and bitwise operations: ^, &, |, >>, <<, ror, rol + * + */ +class STerm { + private: + STerm(cvc5::Term& term, Solver* s, TermType type = TermType::FFTerm) + : solver(s) + , term(term) + , isFiniteField(type == TermType::FFTerm) + , isInteger(type == TermType::FFITerm) + , isBitVector(type == TermType::BVTerm) + , type(type) + , operations(typed_operations.at(type)){}; + void mod(); + + public: + Solver* solver; + cvc5::Term term; + + bool isFiniteField; + bool isInteger; + bool isBitVector; + + TermType type; + std::unordered_map operations; + + STerm() + : solver(nullptr) + , term(cvc5::Term()) + , isFiniteField(false) + , isInteger(false) + , isBitVector(false) + , type(TermType::FFTerm){}; + + explicit STerm( + const std::string& t, Solver* slv, bool isconst = false, uint32_t base = 16, TermType type = TermType::FFTerm); + + STerm(const STerm& other) = default; + STerm(STerm&& other) = default; + + static STerm Var(const std::string& name, Solver* slv, TermType type = TermType::FFTerm); + static STerm Const(const std::string& val, Solver* slv, uint32_t base = 16, TermType type = TermType::FFTerm); + + STerm(bb::fr value, Solver* s, TermType type = TermType::FFTerm) + { + std::stringstream buf; // TODO(#893) + buf << value; + std::string tmp = buf.str(); + tmp[1] = '0'; // avoiding `x` in 0x prefix + + *this = Const(tmp, s, 16, type); + } + + STerm& operator=(const STerm& right) = default; + STerm& operator=(STerm&& right) = default; + + STerm operator+(const STerm& other) const; + void operator+=(const STerm& other); + STerm operator-(const STerm& other) const; + void operator-=(const STerm& other); + STerm operator-() const; + + STerm operator*(const STerm& other) const; + void operator*=(const STerm& other); + STerm operator/(const STerm& other) const; + void operator/=(const STerm& other); + + void operator==(const STerm& other) const; + void operator!=(const STerm& other) const; + + STerm operator^(const STerm& other) const; + void operator^=(const STerm& other); + STerm operator&(const STerm& other) const; + void operator&=(const STerm& other); + STerm operator|(const STerm& other) const; + void operator|=(const STerm& other); + STerm operator<<(const uint32_t& n) const; + void operator<<=(const uint32_t& n); + STerm operator>>(const uint32_t& n) const; + void operator>>=(const uint32_t& n); + + STerm rotr(const uint32_t& n) const; + STerm rotl(const uint32_t& n) const; + + operator std::string() const { return smt_solver::stringify_term(term); }; + operator cvc5::Term() const { return term; }; + + ~STerm() = default; + + friend std::ostream& operator<<(std::ostream& out, const STerm& term) + { + return out << static_cast(term); + }; + + friend STerm batch_add(const std::vector& children) + { + Solver* slv = children[0].solver; + std::vector terms(children.begin(), children.end()); + cvc5::Term res = slv->term_manager.mkTerm(children[0].operations.at(OpType::ADD), terms); + return { res, slv }; + } + + friend STerm batch_mul(const std::vector& children) + { + Solver* slv = children[0].solver; + std::vector terms(children.begin(), children.end()); + cvc5::Term res = slv->term_manager.mkTerm(children[0].operations.at(OpType::MUL), terms); + return { res, slv }; + } + + // arithmetic compatibility with Fr + + STerm operator+(const bb::fr& other) const { return *this + STerm(other, this->solver, this->type); } + void operator+=(const bb::fr& other) { *this += STerm(other, this->solver, this->type); } + STerm operator-(const bb::fr& other) const { return *this - STerm(other, this->solver, this->type); } + void operator-=(const bb::fr& other) { *this -= STerm(other, this->solver, this->type); } + STerm operator*(const bb::fr& other) const { return *this * STerm(other, this->solver, this->type); } + void operator*=(const bb::fr& other) { *this *= STerm(other, this->solver, this->type); } + STerm operator/(const bb::fr& other) const { return *this / STerm(other, this->solver, this->type); } + void operator/=(const bb::fr& other) { *this /= STerm(other, this->solver, this->type); } + + void operator==(const bb::fr& other) const { *this == STerm(other, this->solver, this->type); } + void operator!=(const bb::fr& other) const { *this != STerm(other, this->solver, this->type); } + + STerm operator^(const bb::fr& other) const { return *this ^ STerm(other, this->solver, this->type); }; + void operator^=(const bb::fr& other) { *this ^= STerm(other, this->solver, this->type); }; + STerm operator&(const bb::fr& other) const { return *this & STerm(other, this->solver, this->type); }; + void operator&=(const bb::fr& other) { *this &= STerm(other, this->solver, this->type); }; + STerm operator|(const bb::fr& other) const { return *this | STerm(other, this->solver, this->type); }; + void operator|=(const bb::fr& other) { *this |= STerm(other, this->solver, this->type); }; + + void operator<(const bb::fr& other) const; + void operator<=(const bb::fr& other) const; + void operator>(const bb::fr& other) const; + void operator>=(const bb::fr& other) const; +}; + +STerm operator+(const bb::fr& lhs, const STerm& rhs); +STerm operator-(const bb::fr& lhs, const STerm& rhs); +STerm operator*(const bb::fr& lhs, const STerm& rhs); +STerm operator/(const bb::fr& lhs, const STerm& rhs); +void operator==(const bb::fr& lhs, const STerm& rhs); +void operator!=(const bb::fr& lhs, const STerm& rhs); +STerm operator^(const bb::fr& lhs, const STerm& rhs); +STerm operator&(const bb::fr& lhs, const STerm& rhs); +STerm operator|(const bb::fr& lhs, const STerm& rhs); + +STerm FFVar(const std::string& name, Solver* slv); +STerm FFConst(const std::string& val, Solver* slv, uint32_t base = 16); +STerm FFIVar(const std::string& name, Solver* slv); +STerm FFIConst(const std::string& val, Solver* slv, uint32_t base = 16); +STerm BVVar(const std::string& name, Solver* slv); +STerm BVConst(const std::string& val, Solver* slv, uint32_t base = 16); + +} // namespace smt_terms \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp index d67fb098d858..3e4b7eb1c661 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.cpp @@ -1,5 +1,119 @@ #include "smt_util.hpp" +/** + * @brief Get pretty formatted result of the solver work + * + * @details Having two circuits and defined constraint system + * inside the solver get the pretty formatted output. + * The whole witness will be saved in c-like array format. + * Special variables will be printed to stdout. e.g. `a_1, a_2 = val_a_1, val_a_2;` + * + * @param special The list of variables that you need to see in stdout + * @param c1 first circuit + * @param c2 the copy of the first circuit with changed tag + * @param s solver + * @param fname file to store the resulting witness if succeded + */ +void default_model(const std::vector& special, + smt_circuit::Circuit& c1, + smt_circuit::Circuit& c2, + smt_solver::Solver* s, + const std::string& fname) +{ + std::vector vterms1; + std::vector vterms2; + vterms1.reserve(c1.get_num_real_vars()); + vterms2.reserve(c1.get_num_real_vars()); + + for (uint32_t i = 0; i < c1.get_num_vars(); i++) { + vterms1.push_back(c1.symbolic_vars[c1.real_variable_index[i]]); + vterms2.push_back(c2.symbolic_vars[c2.real_variable_index[i]]); + } + + std::unordered_map mmap1 = s->model(vterms1); + std::unordered_map mmap2 = s->model(vterms2); + + std::fstream myfile; + myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); + myfile << "w12 = {" << std::endl; + for (uint32_t i = 0; i < c1.get_num_vars(); i++) { + std::string vname1 = vterms1[i].toString(); + std::string vname2 = vterms2[i].toString(); + if (c1.real_variable_index[i] == i) { + myfile << "{" << mmap1[vname1] << ", " << mmap2[vname2] << "}"; + myfile << ", // " << vname1 << ", " << vname2 << std::endl; + } else { + myfile << "{" << mmap1[vname1] << ", " + mmap2[vname2] << "}"; + myfile << ", // " << vname1 << " ," << vname2 << " -> " << c1.real_variable_index[i] << std::endl; + } + } + myfile << "};"; + myfile.close(); + + std::unordered_map vterms; + for (const auto& vname : special) { + vterms.insert({ vname + "_1", c1[vname] }); + vterms.insert({ vname + "_2", c2[vname] }); + } + + std::unordered_map mmap = s->model(vterms); + for (const auto& vname : special) { + info(vname, "_1, ", vname, "_2 = ", mmap[vname + "_1"], ", ", mmap[vname + "_2"]); + } +} + +/** + * @brief Get pretty formatted result of the solver work + * + * @details Having a circuit and defined constraint system + * inside the solver get the pretty formatted output. + * The whole witness will be saved in c-like array format. + * Special variables will be printed to stdout. e.g. `a = val_a;` + * + * @param special The list of variables that you need to see in stdout + * @param c first circuit + * @param s solver + * @param fname file to store the resulting witness if succeded + */ +void default_model_single(const std::vector& special, + smt_circuit::Circuit& c, + smt_solver::Solver* s, + const std::string& fname) +{ + std::vector vterms; + vterms.reserve(c.get_num_real_vars()); + + for (uint32_t i = 0; i < c.get_num_vars(); i++) { + vterms.push_back(c.symbolic_vars[i]); + } + + std::unordered_map mmap = s->model(vterms); + + std::fstream myfile; + myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); + myfile << "w = {" << std::endl; + for (size_t i = 0; i < c.get_num_vars(); i++) { + std::string vname = vterms[i].toString(); + if (c.real_variable_index[i] == i) { + myfile << mmap[vname] << ", // " << vname << std::endl; + } else { + myfile << mmap[vname] << ", // " << vname << " -> " << c.real_variable_index[i] << std::endl; + } + } + myfile << "};"; + myfile.close(); + + std::unordered_map vterms1; + for (const auto& vname : special) { + vterms1.insert({ vname, c[vname] }); + } + + std::unordered_map mmap1 = s->model(vterms1); + for (const auto& vname : special) { + info(vname, " = ", mmap1[vname]); + } +} + /** * @brief Get the solver result and amount of time * that it took to solve. @@ -7,13 +121,19 @@ * @param s * @return bool is system satisfiable? */ -bool smt_timer(smt_solver::Solver* s) +bool smt_timer(smt_solver::Solver* s, bool mins) { auto start = std::chrono::high_resolution_clock::now(); bool res = s->check(); auto stop = std::chrono::high_resolution_clock::now(); - double duration = static_cast(duration_cast(stop - start).count()); - info("Time passed: ", duration); + double duration = 0.0; + if (mins) { + duration = static_cast(duration_cast(stop - start).count()); + info("Time elapsed: ", duration, " min"); + } else { + duration = static_cast(duration_cast(stop - start).count()); + info("Time elapsed: ", duration, " sec"); + } info(s->cvc_result); return res; diff --git a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp index a0aceeeadf61..3b0211272c8d 100644 --- a/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp +++ b/barretenberg/cpp/src/barretenberg/smt_verification/util/smt_util.hpp @@ -3,121 +3,15 @@ #include "barretenberg/smt_verification/circuit/circuit.hpp" -/** - * @brief Get pretty formatted result of the solver work - * - * @details Having two circuits and defined constraint system - * inside the solver get the pretty formatted output. - * The whole witness will be saved in c-like array format. - * Special variables will be printed to stdout. e.g. `a_1, a_2 = val_a_1, val_a_2;` - * - * @param special The list of variables that you need to see in stdout - * @param c1 first circuit - * @param c2 the copy of the first circuit with changed tag - * @param s solver - * @param fname file to store the resulting witness if succeded - */ -template -void default_model(std::vector special, - smt_circuit::Circuit& c1, - smt_circuit::Circuit& c2, +void default_model(const std::vector& special, + smt_circuit::Circuit& c1, + smt_circuit::Circuit& c2, smt_solver::Solver* s, - const std::string& fname = "witness.out") -{ - std::vector vterms1; - std::vector vterms2; - vterms1.reserve(c1.get_num_real_vars()); - vterms2.reserve(c1.get_num_real_vars()); - - for (uint32_t i = 0; i < c1.get_num_vars(); i++) { - vterms1.push_back(c1.symbolic_vars[c1.real_variable_index[i]]); - vterms2.push_back(c2.symbolic_vars[c2.real_variable_index[i]]); - } - - std::unordered_map mmap1 = s->model(vterms1); - std::unordered_map mmap2 = s->model(vterms2); - - std::fstream myfile; - myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); - myfile << "w12 = {" << std::endl; - for (uint32_t i = 0; i < c1.get_num_vars(); i++) { - std::string vname1 = vterms1[i].toString(); - std::string vname2 = vterms2[i].toString(); - if (c1.real_variable_index[i] == i) { - myfile << "{" << mmap1[vname1] << ", " << mmap2[vname2] << "}"; - myfile << ", // " << vname1 << ", " << vname2 << std::endl; - } else { - myfile << "{" << mmap1[vname1] << ", " + mmap2[vname2] << "}"; - myfile << ", // " << vname1 << " ," << vname2 << " -> " << c1.real_variable_index[i] << std::endl; - } - } - myfile << "};"; - myfile.close(); - - std::unordered_map vterms; - for (auto& vname : special) { - vterms.insert({ vname + "_1", c1[vname] }); - vterms.insert({ vname + "_2", c2[vname] }); - } - - std::unordered_map mmap = s->model(vterms); - for (auto& vname : special) { - info(vname, "_1, ", vname, "_2 = ", mmap[vname + "_1"], ", ", mmap[vname + "_2"]); - } -} - -/** - * @brief Get pretty formatted result of the solver work - * - * @details Having a circuit and defined constraint system - * inside the solver get the pretty formatted output. - * The whole witness will be saved in c-like array format. - * Special variables will be printed to stdout. e.g. `a = val_a;` - * - * @param special The list of variables that you need to see in stdout - * @param c first circuit - * @param s solver - * @param fname file to store the resulting witness if succeded - */ -template -void default_model_single(std::vector special, - smt_circuit::Circuit& c, + const std::string& fname = "witness.out"); +void default_model_single(const std::vector& special, + smt_circuit::Circuit& c, smt_solver::Solver* s, - const std::string& fname = "witness.out") -{ - std::vector vterms; - vterms.reserve(c.get_num_real_vars()); - - for (uint32_t i = 0; i < c.get_num_vars(); i++) { - vterms.push_back(c.symbolic_vars[i]); - } - - std::unordered_map mmap = s->model(vterms); - - std::fstream myfile; - myfile.open(fname, std::ios::out | std::ios::trunc | std::ios::binary); - myfile << "w = {" << std::endl; - for (size_t i = 0; i < c.get_num_vars(); i++) { - std::string vname = vterms[i].toString(); - if (c.real_variable_index[i] == i) { - myfile << mmap[vname] << ", // " << vname << std::endl; - } else { - myfile << mmap[vname] << ", // " << vname << " -> " << c.real_variable_index[i] << std::endl; - } - } - myfile << "};"; - myfile.close(); - - std::unordered_map vterms1; - for (auto& vname : special) { - vterms1.insert({ vname, c[vname] }); - } - - std::unordered_map mmap1 = s->model(vterms1); - for (auto& vname : special) { - info(vname, " = ", mmap1[vname]); - } -} + const std::string& fname = "witness.out"); -bool smt_timer(smt_solver::Solver* s); +bool smt_timer(smt_solver::Solver* s, bool mins = true); std::pair, std::vector> base4(uint32_t el); \ No newline at end of file