diff --git a/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs b/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs index bcf736cd926e..81e752d5656e 100644 --- a/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs +++ b/noir/noir-repo/acvm-repo/acvm/src/pwg/brillig.rs @@ -206,13 +206,13 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { for output in brillig.outputs.iter() { match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, memory[current_ret_data_idx].value, witness_map)?; + insert_value(witness, memory[current_ret_data_idx].to_field(), witness_map)?; current_ret_data_idx += 1; } BrilligOutputs::Array(witness_arr) => { for witness in witness_arr.iter() { - let value = memory[current_ret_data_idx]; - insert_value(witness, value.value, witness_map)?; + let value = &memory[current_ret_data_idx]; + insert_value(witness, value.to_field(), witness_map)?; current_ret_data_idx += 1; } } diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs index 3d77982ffb13..2107d10c093d 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs @@ -1,9 +1,10 @@ use acir::brillig::{BinaryFieldOp, BinaryIntOp}; use acir::FieldElement; use num_bigint::BigUint; -use num_traits::{One, ToPrimitive, Zero}; +use num_traits::ToPrimitive; +use num_traits::{One, Zero}; -use crate::memory::MemoryValue; +use crate::memory::{MemoryTypeError, MemoryValue}; #[derive(Debug, thiserror::Error)] pub(crate) enum BrilligArithmeticError { @@ -11,6 +12,8 @@ pub(crate) enum BrilligArithmeticError { MismatchedLhsBitSize { lhs_bit_size: u32, op_bit_size: u32 }, #[error("Bit size for rhs {rhs_bit_size} does not match op bit size {op_bit_size}")] MismatchedRhsBitSize { rhs_bit_size: u32, op_bit_size: u32 }, + #[error("Integer operation BinaryIntOp::{op:?} is not supported on FieldElement")] + IntegerOperationOnField { op: BinaryIntOp }, #[error("Shift with bit size {op_bit_size} is invalid")] InvalidShift { op_bit_size: u32 }, } @@ -21,21 +24,19 @@ pub(crate) fn evaluate_binary_field_op( lhs: MemoryValue, rhs: MemoryValue, ) -> Result { - if lhs.bit_size != FieldElement::max_num_bits() { + let MemoryValue::Field(a) = lhs else { return Err(BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: lhs.bit_size, + lhs_bit_size: lhs.bit_size(), op_bit_size: FieldElement::max_num_bits(), }); - } - if rhs.bit_size != FieldElement::max_num_bits() { - return Err(BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: rhs.bit_size, + }; + let MemoryValue::Field(b) = rhs else { + return Err(BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: rhs.bit_size(), op_bit_size: FieldElement::max_num_bits(), }); - } + }; - let a = lhs.value; - let b = rhs.value; Ok(match op { // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant. BinaryFieldOp::Add => (a + b).into(), @@ -62,21 +63,26 @@ pub(crate) fn evaluate_binary_int_op( rhs: MemoryValue, bit_size: u32, ) -> Result { - if lhs.bit_size != bit_size { - return Err(BrilligArithmeticError::MismatchedLhsBitSize { - lhs_bit_size: lhs.bit_size, - op_bit_size: bit_size, - }); - } - if rhs.bit_size != bit_size { - return Err(BrilligArithmeticError::MismatchedRhsBitSize { - rhs_bit_size: rhs.bit_size, - op_bit_size: bit_size, - }); - } + let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { + MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { + BrilligArithmeticError::MismatchedLhsBitSize { + lhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, + } + } + })?; + let rhs = rhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { + MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { + BrilligArithmeticError::MismatchedRhsBitSize { + rhs_bit_size: value_bit_size, + op_bit_size: expected_bit_size, + } + } + })?; - let lhs = BigUint::from_bytes_be(&lhs.value.to_be_bytes()); - let rhs = BigUint::from_bytes_be(&rhs.value.to_be_bytes()); + if bit_size == FieldElement::max_num_bits() { + return Err(BrilligArithmeticError::IntegerOperationOnField { op: *op }); + } let bit_modulo = &(BigUint::one() << bit_size); let result = match op { @@ -136,13 +142,11 @@ pub(crate) fn evaluate_binary_int_op( } }; - let result_as_field = FieldElement::from_be_bytes_reduce(&result.to_bytes_be()); - Ok(match op { BinaryIntOp::Equals | BinaryIntOp::LessThan | BinaryIntOp::LessThanEquals => { - MemoryValue::new(result_as_field, 1) + MemoryValue::new_integer(result, 1) } - _ => MemoryValue::new(result_as_field, bit_size), + _ => MemoryValue::new_integer(result, bit_size), }) } @@ -159,13 +163,13 @@ mod tests { fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: u32) -> u128 { let result_value = evaluate_binary_int_op( op, - MemoryValue::new(a.into(), bit_size), - MemoryValue::new(b.into(), bit_size), + MemoryValue::new_integer(a.into(), bit_size), + MemoryValue::new_integer(b.into(), bit_size), bit_size, ) .unwrap(); // Convert back to u128 - result_value.value.to_u128() + result_value.to_field().to_u128() } fn to_negative(a: u128, bit_size: u32) -> u128 { diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs index bd33b5ee8fc0..73981fb06259 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs @@ -20,7 +20,7 @@ fn read_heap_array<'a>(memory: &'a Memory, array: &HeapArray) -> &'a [MemoryValu /// Extracts the last byte of every value fn to_u8_vec(inputs: &[MemoryValue]) -> Vec { let mut result = Vec::with_capacity(inputs.len()); - for &input in inputs { + for input in inputs { result.push(input.try_into().unwrap()); } result @@ -63,7 +63,7 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::Keccakf1600 { message, output } => { let state_vec: Vec = read_heap_vector(memory, message) .iter() - .map(|&memory_value| memory_value.try_into().unwrap()) + .map(|memory_value| memory_value.try_into().unwrap()) .collect(); let state: [u64; 25] = state_vec.try_into().unwrap(); @@ -151,7 +151,7 @@ pub(crate) fn evaluate_black_box( } BlackBoxOp::PedersenCommitment { inputs, domain_separator, output } => { let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|&x| x.try_into().unwrap()).collect(); + read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -165,7 +165,7 @@ pub(crate) fn evaluate_black_box( } BlackBoxOp::PedersenHash { inputs, domain_separator, output } => { let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|&x| x.try_into().unwrap()).collect(); + read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -185,7 +185,7 @@ pub(crate) fn evaluate_black_box( BlackBoxOp::BigIntToLeBytes { .. } => todo!(), BlackBoxOp::Poseidon2Permutation { message, output, len } => { let input = read_heap_vector(memory, message); - let input: Vec = input.iter().map(|&x| x.try_into().unwrap()).collect(); + let input: Vec = input.iter().map(|x| x.try_into().unwrap()).collect(); let len = memory.read(*len).try_into().unwrap(); let result = solver.poseidon2_permutation(&input, len)?; let mut values = Vec::new(); @@ -204,7 +204,7 @@ pub(crate) fn evaluate_black_box( format!("Expected 16 inputs but encountered {}", &inputs.len()), )); } - for (i, &input) in inputs.iter().enumerate() { + for (i, input) in inputs.iter().enumerate() { message[i] = input.try_into().unwrap(); } let mut state = [0; 8]; @@ -215,7 +215,7 @@ pub(crate) fn evaluate_black_box( format!("Expected 8 values but encountered {}", &values.len()), )); } - for (i, &value) in values.iter().enumerate() { + for (i, value) in values.iter().enumerate() { state[i] = value.try_into().unwrap(); } diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs index 65654e24720c..26d5da675769 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/lib.rs @@ -289,8 +289,8 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { // Convert our source_pointer to an address let source = self.memory.read_ref(*source_pointer); // Use our usize source index to lookup the value in memory - let value = &self.memory.read(source); - self.memory.write(*destination_address, *value); + let value = self.memory.read(source); + self.memory.write(*destination_address, value); self.increment_program_counter() } Opcode::Store { destination_pointer, source: source_address } => { @@ -307,7 +307,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } Opcode::Const { destination, value, bit_size } => { // Consts are not checked in runtime to fit in the bit size, since they can safely be checked statically. - self.memory.write(*destination, MemoryValue::new(*value, *bit_size)); + self.memory.write(*destination, MemoryValue::new_from_field(*value, *bit_size)); self.increment_program_counter() } Opcode::BlackBox(black_box_op) => { @@ -348,7 +348,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { ) -> ForeignCallParam { match (input, value_type) { (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(_)) => { - self.memory.read(value_index).value.into() + self.memory.read(value_index).to_field().into() } ( ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), @@ -357,7 +357,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { let start = self.memory.read_ref(pointer_index); self.read_slice_of_values_from_memory(start, size, value_types) .into_iter() - .map(|mem_value| mem_value.value) + .map(|mem_value| mem_value.to_field()) .collect::>() .into() } @@ -369,7 +369,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { let size = self.memory.read(size_index).to_usize(); self.read_slice_of_values_from_memory(start, size, value_types) .into_iter() - .map(|mem_value| mem_value.value) + .map(|mem_value| mem_value.to_field()) .collect::>() .into() } @@ -584,12 +584,9 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { /// Casts a value to a different bit size. fn cast(&self, bit_size: u32, source_value: MemoryValue) -> MemoryValue { - let lhs_big = BigUint::from_bytes_be(&source_value.value.to_be_bytes()); + let lhs_big = source_value.to_integer(); let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32; - MemoryValue { - value: FieldElement::from_be_bytes_reduce(&(lhs_big & mask).to_bytes_be()), - bit_size, - } + MemoryValue::new_from_integer(lhs_big & mask, bit_size) } } @@ -627,7 +624,7 @@ mod tests { let VM { memory, .. } = vm; let output_value = memory.read(MemoryAddress::from(0)); - assert_eq!(output_value.value, FieldElement::from(27u128)); + assert_eq!(output_value.to_field(), FieldElement::from(27u128)); } #[test] @@ -666,7 +663,7 @@ mod tests { assert_eq!(status, VMStatus::InProgress); let output_cmp_value = vm.memory.read(destination); - assert_eq!(output_cmp_value.value, true.into()); + assert_eq!(output_cmp_value.to_field(), true.into()); let status = vm.process_opcode(); assert_eq!(status, VMStatus::InProgress); @@ -725,7 +722,7 @@ mod tests { assert_eq!(status, VMStatus::InProgress); let output_cmp_value = vm.memory.read(MemoryAddress::from(2)); - assert_eq!(output_cmp_value.value, false.into()); + assert_eq!(output_cmp_value.to_field(), false.into()); let status = vm.process_opcode(); assert_eq!(status, VMStatus::InProgress); @@ -742,7 +739,7 @@ mod tests { // The address at index `2` should have not changed as we jumped over the add opcode let VM { memory, .. } = vm; let output_value = memory.read(MemoryAddress::from(2)); - assert_eq!(output_value.value, false.into()); + assert_eq!(output_value.to_field(), false.into()); } #[test] @@ -776,7 +773,7 @@ mod tests { let VM { memory, .. } = vm; let casted_value = memory.read(MemoryAddress::from(1)); - assert_eq!(casted_value.value, (2_u128.pow(8) - 1).into()); + assert_eq!(casted_value.to_field(), (2_u128.pow(8) - 1).into()); } #[test] @@ -804,10 +801,10 @@ mod tests { let VM { memory, .. } = vm; let destination_value = memory.read(MemoryAddress::from(2)); - assert_eq!(destination_value.value, (1u128).into()); + assert_eq!(destination_value.to_field(), (1u128).into()); let source_value = memory.read(MemoryAddress::from(0)); - assert_eq!(source_value.value, (1u128).into()); + assert_eq!(source_value.to_field(), (1u128).into()); } #[test] @@ -869,10 +866,10 @@ mod tests { let VM { memory, .. } = vm; let destination_value = memory.read(MemoryAddress::from(4)); - assert_eq!(destination_value.value, (3_u128).into()); + assert_eq!(destination_value.to_field(), (3_u128).into()); let source_value = memory.read(MemoryAddress::from(5)); - assert_eq!(source_value.value, (2_u128).into()); + assert_eq!(source_value.to_field(), (2_u128).into()); } #[test] @@ -1120,7 +1117,7 @@ mod tests { let opcodes = [&start[..], &loop_body[..]].concat(); let vm = brillig_execute_and_get_vm(memory, &opcodes); - vm.memory.read(r_sum).value + vm.memory.read(r_sum).to_field() } assert_eq!( @@ -1359,7 +1356,7 @@ mod tests { // Check result in memory let result_values = vm.memory.read_slice(MemoryAddress(2), 4).to_vec(); assert_eq!( - result_values.into_iter().map(|mem_value| mem_value.value).collect::>(), + result_values.into_iter().map(|mem_value| mem_value.to_field()).collect::>(), expected_result ); @@ -1459,7 +1456,7 @@ mod tests { .memory .read_slice(MemoryAddress(4 + input_string.len()), output_string.len()) .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.clone().to_field()) .collect(); assert_eq!(result_values, output_string); @@ -1532,13 +1529,21 @@ mod tests { assert_eq!(vm.status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 }); // Check initial memory still in place - let initial_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(2), 4).iter().map(|mem_val| mem_val.value).collect(); + let initial_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(2), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(initial_values, initial_matrix); // Check result in memory - let result_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(6), 4).iter().map(|mem_val| mem_val.value).collect(); + let result_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(6), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(result_values, expected_result); // Ensure the foreign call counter has been incremented @@ -1622,8 +1627,12 @@ mod tests { assert_eq!(vm.status, VMStatus::Finished { return_data_offset: 0, return_data_size: 0 }); // Check result in memory - let result_values: Vec<_> = - vm.memory.read_slice(MemoryAddress(0), 4).iter().map(|mem_val| mem_val.value).collect(); + let result_values: Vec<_> = vm + .memory + .read_slice(MemoryAddress(0), 4) + .iter() + .map(|mem_val| mem_val.clone().to_field()) + .collect(); assert_eq!(result_values, expected_result); // Ensure the foreign call counter has been incremented @@ -1698,7 +1707,7 @@ mod tests { .chain(memory.iter().enumerate().map(|(index, mem_value)| Opcode::Cast { destination: MemoryAddress(index), source: MemoryAddress(index), - bit_size: mem_value.bit_size, + bit_size: mem_value.bit_size(), })) .chain(vec![ // input = 0 @@ -1721,7 +1730,7 @@ mod tests { .collect(); let mut vm = brillig_execute_and_get_vm( - memory.into_iter().map(|mem_value| mem_value.value).collect(), + memory.into_iter().map(|mem_value| mem_value.to_field()).collect(), &program, ); diff --git a/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs b/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs index d563e13be2e9..feeb3706bde3 100644 --- a/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs +++ b/noir/noir-repo/acvm-repo/brillig_vm/src/memory.rs @@ -1,11 +1,13 @@ use acir::{brillig::MemoryAddress, FieldElement}; +use num_bigint::BigUint; +use num_traits::{One, Zero}; pub const MEMORY_ADDRESSING_BIT_SIZE: u32 = 64; -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub struct MemoryValue { - pub value: FieldElement, - pub bit_size: u32, +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum MemoryValue { + Field(FieldElement), + Integer(BigUint, u32), } #[derive(Debug, thiserror::Error)] @@ -15,53 +17,147 @@ pub enum MemoryTypeError { } impl MemoryValue { - pub fn new(value: FieldElement, bit_size: u32) -> Self { - MemoryValue { value, bit_size } + /// Builds a memory value from a field element. + pub fn new_from_field(value: FieldElement, bit_size: u32) -> Self { + if bit_size == FieldElement::max_num_bits() { + MemoryValue::new_field(value) + } else { + MemoryValue::new_integer(BigUint::from_bytes_be(&value.to_be_bytes()), bit_size) + } + } + + /// Builds a memory value from an integer + pub fn new_from_integer(value: BigUint, bit_size: u32) -> Self { + if bit_size == FieldElement::max_num_bits() { + MemoryValue::new_field(FieldElement::from_be_bytes_reduce(&value.to_bytes_be())) + } else { + MemoryValue::new_integer(value, bit_size) + } } + /// Builds a memory value from a field element, checking that the value is within the bit size. pub fn new_checked(value: FieldElement, bit_size: u32) -> Option { - if value.num_bits() > bit_size { + if bit_size < FieldElement::max_num_bits() && value.num_bits() > bit_size { return None; } - Some(MemoryValue::new(value, bit_size)) + Some(MemoryValue::new_from_field(value, bit_size)) } + /// Builds a field-typed memory value. pub fn new_field(value: FieldElement) -> Self { - MemoryValue { value, bit_size: FieldElement::max_num_bits() } + MemoryValue::Field(value) + } + + /// Builds an integer-typed memory value. + pub fn new_integer(value: BigUint, bit_size: u32) -> Self { + assert!( + bit_size != FieldElement::max_num_bits(), + "Tried to build a field memory value via new_integer" + ); + MemoryValue::Integer(value, bit_size) + } + + /// Extracts the field element from the memory value, if it is typed as field element. + pub fn extract_field(&self) -> Option<&FieldElement> { + match self { + MemoryValue::Field(value) => Some(value), + _ => None, + } + } + + /// Extracts the integer from the memory value, if it is typed as integer. + pub fn extract_integer(&self) -> Option<(&BigUint, u32)> { + match self { + MemoryValue::Integer(value, bit_size) => Some((value, *bit_size)), + _ => None, + } + } + + /// Converts the memory value to a field element, independent of its type. + pub fn to_field(&self) -> FieldElement { + match self { + MemoryValue::Field(value) => *value, + MemoryValue::Integer(value, _) => { + FieldElement::from_be_bytes_reduce(&value.to_bytes_be()) + } + } + } + + /// Converts the memory value to an integer, independent of its type. + pub fn to_integer(self) -> BigUint { + match self { + MemoryValue::Field(value) => BigUint::from_bytes_be(&value.to_be_bytes()), + MemoryValue::Integer(value, _) => value, + } + } + + pub fn bit_size(&self) -> u32 { + match self { + MemoryValue::Field(_) => FieldElement::max_num_bits(), + MemoryValue::Integer(_, bit_size) => *bit_size, + } } pub fn to_usize(&self) -> usize { - assert!(self.bit_size == MEMORY_ADDRESSING_BIT_SIZE, "value is not typed as brillig usize"); - self.value.to_u128() as usize + assert!( + self.bit_size() == MEMORY_ADDRESSING_BIT_SIZE, + "value is not typed as brillig usize" + ); + self.extract_integer().unwrap().0.try_into().unwrap() } - pub fn expect_bit_size(&self, expected_bit_size: u32) -> Result<(), MemoryTypeError> { - if self.bit_size != expected_bit_size { - return Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: self.bit_size, + pub fn expect_field(&self) -> Result<&FieldElement, MemoryTypeError> { + match self { + MemoryValue::Integer(_, bit_size) => Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: *bit_size, + expected_bit_size: FieldElement::max_num_bits(), + }), + MemoryValue::Field(field) => Ok(field), + } + } + + pub fn expect_integer_with_bit_size( + &self, + expected_bit_size: u32, + ) -> Result<&BigUint, MemoryTypeError> { + match self { + MemoryValue::Integer(value, bit_size) => { + if *bit_size != expected_bit_size { + return Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: *bit_size, + expected_bit_size, + }); + } + Ok(value) + } + MemoryValue::Field(_) => Err(MemoryTypeError::MismatchedBitSize { + value_bit_size: FieldElement::max_num_bits(), expected_bit_size, - }); + }), } - Ok(()) } } impl std::fmt::Display for MemoryValue { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> { - let typ = match self.bit_size { - 0 => "null".to_string(), - 1 => "bool".to_string(), - _ if self.bit_size == FieldElement::max_num_bits() => "field".to_string(), - _ => format!("u{}", self.bit_size), - }; - f.write_str(format!("{}: {}", self.value, typ).as_str()) + match self { + MemoryValue::Field(value) => write!(f, "{}: field", value), + MemoryValue::Integer(value, bit_size) => { + let typ = match bit_size { + 0 => "null".to_string(), + 1 => "bool".to_string(), + _ => format!("u{}", bit_size), + }; + write!(f, "{}: {}", value, typ) + } + } } } impl Default for MemoryValue { fn default() -> Self { - MemoryValue::new(FieldElement::zero(), 0) + MemoryValue::new_integer(BigUint::zero(), 0) } } @@ -73,31 +169,32 @@ impl From for MemoryValue { impl From for MemoryValue { fn from(value: usize) -> Self { - MemoryValue::new(value.into(), MEMORY_ADDRESSING_BIT_SIZE) + MemoryValue::new_integer(value.into(), MEMORY_ADDRESSING_BIT_SIZE) } } impl From for MemoryValue { fn from(value: u64) -> Self { - MemoryValue::new((value as u128).into(), 64) + MemoryValue::new_integer(value.into(), 64) } } impl From for MemoryValue { fn from(value: u32) -> Self { - MemoryValue::new((value as u128).into(), 32) + MemoryValue::new_integer(value.into(), 32) } } impl From for MemoryValue { fn from(value: u8) -> Self { - MemoryValue::new((value as u128).into(), 8) + MemoryValue::new_integer(value.into(), 8) } } impl From for MemoryValue { fn from(value: bool) -> Self { - MemoryValue::new(value.into(), 1) + let value = if value { BigUint::one() } else { BigUint::zero() }; + MemoryValue::new_integer(value, 1) } } @@ -105,8 +202,7 @@ impl TryFrom for FieldElement { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(FieldElement::max_num_bits())?; - Ok(memory_value.value) + memory_value.expect_field().copied() } } @@ -114,8 +210,7 @@ impl TryFrom for u64 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(64)?; - Ok(memory_value.value.to_u128() as u64) + memory_value.expect_integer_with_bit_size(64).map(|value| value.try_into().unwrap()) } } @@ -123,8 +218,7 @@ impl TryFrom for u32 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(32)?; - Ok(memory_value.value.to_u128() as u32) + memory_value.expect_integer_with_bit_size(32).map(|value| value.try_into().unwrap()) } } @@ -132,9 +226,7 @@ impl TryFrom for u8 { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(8)?; - - Ok(memory_value.value.to_u128() as u8) + memory_value.expect_integer_with_bit_size(8).map(|value| value.try_into().unwrap()) } } @@ -142,11 +234,65 @@ impl TryFrom for bool { type Error = MemoryTypeError; fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_bit_size(1)?; + let as_integer = memory_value.expect_integer_with_bit_size(1)?; + + if as_integer.is_zero() { + Ok(false) + } else if as_integer.is_one() { + Ok(true) + } else { + unreachable!("value typed as bool is greater than one") + } + } +} + +impl TryFrom<&MemoryValue> for FieldElement { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_field().copied() + } +} + +impl TryFrom<&MemoryValue> for u64 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(64).map(|value| { + value.try_into().expect("memory_value has been asserted to contain a 64 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for u32 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(32).map(|value| { + value.try_into().expect("memory_value has been asserted to contain a 32 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for u8 { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + memory_value.expect_integer_with_bit_size(8).map(|value| { + value.try_into().expect("memory_value has been asserted to contain an 8 bit integer") + }) + } +} + +impl TryFrom<&MemoryValue> for bool { + type Error = MemoryTypeError; + + fn try_from(memory_value: &MemoryValue) -> Result { + let as_integer = memory_value.expect_integer_with_bit_size(1)?; - if memory_value.value == FieldElement::zero() { + if as_integer.is_zero() { Ok(false) - } else if memory_value.value == FieldElement::one() { + } else if as_integer.is_one() { Ok(true) } else { unreachable!("value typed as bool is greater than one") @@ -164,7 +310,7 @@ pub struct Memory { impl Memory { /// Gets the value at pointer pub fn read(&self, ptr: MemoryAddress) -> MemoryValue { - self.inner.get(ptr.to_usize()).copied().unwrap_or_default() + self.inner.get(ptr.to_usize()).cloned().unwrap_or_default() } pub fn read_ref(&self, ptr: MemoryAddress) -> MemoryAddress { @@ -191,7 +337,7 @@ impl Memory { /// Sets the values after pointer `ptr` to `values` pub fn write_slice(&mut self, ptr: MemoryAddress, values: &[MemoryValue]) { self.resize_to_fit(ptr.to_usize() + values.len()); - self.inner[ptr.to_usize()..(ptr.to_usize() + values.len())].copy_from_slice(values); + self.inner[ptr.to_usize()..(ptr.to_usize() + values.len())].clone_from_slice(values); } /// Returns the values of the memory diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs index 15a2a531e786..4b97a61491d0 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs @@ -67,61 +67,105 @@ pub(crate) fn directive_invert() -> GeneratedBrillig { /// (a/b, a-a/b*b) /// } /// ``` -pub(crate) fn directive_quotient(mut bit_size: u32) -> GeneratedBrillig { +pub(crate) fn directive_quotient(bit_size: u32) -> GeneratedBrillig { // `a` is (0) (i.e register index 0) // `b` is (1) - if bit_size > FieldElement::max_num_bits() { - bit_size = FieldElement::max_num_bits(); - } - GeneratedBrillig { - byte_code: vec![ - BrilligOpcode::CalldataCopy { - destination_address: MemoryAddress::from(0), - size: 2, - offset: 0, - }, - BrilligOpcode::Cast { - destination: MemoryAddress(0), - source: MemoryAddress(0), - bit_size, - }, - BrilligOpcode::Cast { - destination: MemoryAddress(1), - source: MemoryAddress(1), - bit_size, - }, - //q = a/b is set into register (2) - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Div, - lhs: MemoryAddress::from(0), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(2), - bit_size, - }, - //(1)= q*b - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Mul, - lhs: MemoryAddress::from(2), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(1), - bit_size, - }, - //(1) = a-q*b - BrilligOpcode::BinaryIntOp { - op: BinaryIntOp::Sub, - lhs: MemoryAddress::from(0), - rhs: MemoryAddress::from(1), - destination: MemoryAddress::from(1), - bit_size, - }, - //(0) = q - BrilligOpcode::Mov { - destination: MemoryAddress::from(0), - source: MemoryAddress::from(2), - }, - BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, - ], - assert_messages: Default::default(), - locations: Default::default(), + + // TODO: The only difference between these implementations is the integer version will truncate the input to the `bit_size` via cast. + // Once we deduplicate brillig functions then we can modify this so that fields and integers share the same quotient function. + if bit_size >= FieldElement::max_num_bits() { + // Field version + GeneratedBrillig { + byte_code: vec![ + BrilligOpcode::CalldataCopy { + destination_address: MemoryAddress::from(0), + size: 2, + offset: 0, + }, + // No cast, since calldata is typed as field by default + //q = a/b is set into register (2) + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::IntegerDiv, // We want integer division, not field division! + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(2), + }, + //(1)= q*b + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::Mul, + lhs: MemoryAddress::from(2), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + }, + //(1) = a-q*b + BrilligOpcode::BinaryFieldOp { + op: BinaryFieldOp::Sub, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + }, + //(0) = q + BrilligOpcode::Mov { + destination: MemoryAddress::from(0), + source: MemoryAddress::from(2), + }, + BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, + ], + assert_messages: Default::default(), + locations: Default::default(), + } + } else { + // Integer version + GeneratedBrillig { + byte_code: vec![ + BrilligOpcode::CalldataCopy { + destination_address: MemoryAddress::from(0), + size: 2, + offset: 0, + }, + BrilligOpcode::Cast { + destination: MemoryAddress(0), + source: MemoryAddress(0), + bit_size, + }, + BrilligOpcode::Cast { + destination: MemoryAddress(1), + source: MemoryAddress(1), + bit_size, + }, + //q = a/b is set into register (2) + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Div, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(2), + bit_size, + }, + //(1)= q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Mul, + lhs: MemoryAddress::from(2), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + bit_size, + }, + //(1) = a-q*b + BrilligOpcode::BinaryIntOp { + op: BinaryIntOp::Sub, + lhs: MemoryAddress::from(0), + rhs: MemoryAddress::from(1), + destination: MemoryAddress::from(1), + bit_size, + }, + //(0) = q + BrilligOpcode::Mov { + destination: MemoryAddress::from(0), + source: MemoryAddress::from(2), + }, + BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 2 }, + ], + assert_messages: Default::default(), + locations: Default::default(), + } } } diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs index b93693d9c797..3e1515b1eed4 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_slice_ops.rs @@ -465,7 +465,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -590,7 +590,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -686,7 +686,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); @@ -838,7 +838,7 @@ mod tests { assert_eq!( vm.get_memory()[return_data_offset..(return_data_offset + expected_return.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), expected_return ); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index db872487fcc4..fe3c5e0bb9c1 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -527,7 +527,7 @@ mod tests { let (vm, return_data_offset, return_data_size) = create_and_run_vm(calldata.clone(), &bytecode); assert_eq!(return_data_size, 1, "Return data size is incorrect"); - assert_eq!(vm.get_memory()[return_data_offset].value, FieldElement::from(1_usize)); + assert_eq!(vm.get_memory()[return_data_offset].to_field(), FieldElement::from(1_usize)); } #[test] @@ -569,7 +569,7 @@ mod tests { assert_eq!( memory[return_data_pointer..(return_data_pointer + flattened_array.len())] .iter() - .map(|mem_val| mem_val.value) + .map(|mem_val| mem_val.to_field()) .collect::>(), flattened_array ); diff --git a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 53d9e2530cc4..775571f4a413 100644 --- a/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/noir/noir-repo/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -1623,7 +1623,7 @@ impl AcirContext { let outputs_var = vecmap(outputs_types.iter(), |output| match output { AcirType::NumericType(_) => { let var = self.add_data(AcirVarData::Const( - memory.next().expect("Missing return data").value, + memory.next().expect("Missing return data").to_field(), )); AcirValue::Var(var, output.clone()) } @@ -1657,7 +1657,7 @@ impl AcirContext { AcirType::NumericType(_) => { let memory_value = memory_iter.next().expect("ICE: Unexpected end of memory"); - let var = self.add_data(AcirVarData::Const(memory_value.value)); + let var = self.add_data(AcirVarData::Const(memory_value.to_field())); array_values.push_back(AcirValue::Var(var, element_type.clone())); } } diff --git a/noir/noir-repo/tooling/debugger/src/context.rs b/noir/noir-repo/tooling/debugger/src/context.rs index 1acd581b2be1..b211832518d4 100644 --- a/noir/noir-repo/tooling/debugger/src/context.rs +++ b/noir/noir-repo/tooling/debugger/src/context.rs @@ -513,7 +513,11 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { pub(super) fn write_brillig_memory(&mut self, ptr: usize, value: FieldElement, bit_size: u32) { if let Some(solver) = self.brillig_solver.as_mut() { - solver.write_memory_at(ptr, MemoryValue::new(value, bit_size)); + solver.write_memory_at( + ptr, + MemoryValue::new_checked(value, bit_size) + .expect("Invalid value for the given bit size"), + ); } } diff --git a/noir/noir-repo/tooling/debugger/src/repl.rs b/noir/noir-repo/tooling/debugger/src/repl.rs index 1c077c6ee9bf..e30d519b62e8 100644 --- a/noir/noir-repo/tooling/debugger/src/repl.rs +++ b/noir/noir-repo/tooling/debugger/src/repl.rs @@ -319,7 +319,7 @@ impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { return; }; - for (index, value) in memory.iter().enumerate().filter(|(_, value)| value.bit_size > 0) { + for (index, value) in memory.iter().enumerate().filter(|(_, value)| value.bit_size() > 0) { println!("{index} = {}", value); } }