diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index 04d5042e4..da71e27db 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -1,14 +1,91 @@ +use std::marker::PhantomData; + use p3_air::BaseAir; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; pub use sphinx_derive::MachineAir; -use crate::{runtime::Program, stark::MachineRecord}; +use crate::{ + runtime::Program, + stark::{Indexed, MachineRecord}, +}; + +/// A description of the events related to this AIR. +pub trait WithEvents<'a>: Sized { + /// output of a functional lens from the Record to + /// refs of those events relative to the AIR. + type Events: 'a; +} + +/// A trait intended for implementation on Records that may store events related to Chips, +/// The purpose of this trait is to provide a way to access the events relative to a specific +/// Chip, as specified by its `WithEvents` trait implementation. +/// +/// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 ) +pub trait EventLens WithEvents<'b>>: Indexed { + fn events(&self) -> >::Events; +} + +//////////////// Derive macro shenanigans //////////////////////////////////////////////// +// The following is *only* useful for the derive macros, you should *not* use this directly. +// +/// Hereafter, Lens composition explained pedantically: all this is saying is that +/// if I have an EventLens to T::Events, and a way (F) to deduce U::Events from that, +/// I can compose them to get an EventLens to U::Events. +pub struct Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, +{ + record: &'a R, + projection: F, + _phantom: PhantomData, +} + +/// A constructor for the projection from T::Events to U::Events. +impl<'a, T, R, F> Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, +{ + pub fn new(record: &'a R, projection: F) -> Self { + Self { + record, + projection, + _phantom: PhantomData, + } + } +} + +impl<'a, T, R, U, F> EventLens for Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, + U: for<'b> WithEvents<'b>, + // see https://github.com/rust-lang/rust/issues/86702 for the empty parameter + F: for<'c> Fn(>::Events, &'c ()) -> >::Events, +{ + fn events<'c>(&'c self) -> >::Events { + let events: >::Events = self.record.events(); + (self.projection)(events, &()) + } +} + +impl<'a, T, R, F> Indexed for Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens + Indexed, +{ + fn index(&self) -> u32 { + self.record.index() + } +} +//////////////// end of shenanigans destined for the derive macros. //////////////// /// An AIR that is part of a multi table AIR arithmetization. -pub trait MachineAir: BaseAir { +pub trait MachineAir: BaseAir + for<'a> WithEvents<'a> { /// The execution record containing events for producing the air trace. - type Record: MachineRecord; + type Record: MachineRecord + EventLens; type Program: MachineProgram; @@ -20,10 +97,14 @@ pub trait MachineAir: BaseAir { /// - `input` is the execution record containing the events to be written to the trace. /// - `output` is the execution record containing events that the `MachineAir` can add to /// the record such as byte lookup requests. - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix; + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix; /// Generate the dependencies for a given execution record. - fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { + fn generate_dependencies>(&self, input: &EL, output: &mut Self::Record) { self.generate_trace(input, output); } diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 7ea030013..55e200a9c 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -12,13 +12,15 @@ use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; use crate::{ - air::{AluAirBuilder, MachineAir, Word}, + air::{AluAirBuilder, EventLens, MachineAir, WithEvents, Word}, operations::AddOperation, runtime::{ExecutionRecord, Opcode, Program}, stark::MachineRecord, utils::pad_to_power_of_two, }; +use super::AluEvent; + /// The number of main trace columns for `AddSubChip`. pub const NUM_ADD_SUB_COLS: usize = size_of::>(); @@ -55,6 +57,15 @@ pub struct AddSubCols { pub is_sub: T, } +impl<'a> WithEvents<'a> for AddSubChip { + type Events = ( + // add events + &'a [AluEvent], + // sub events + &'a [AluEvent], + ); +} + impl MachineAir for AddSubChip { type Record = ExecutionRecord; @@ -64,20 +75,17 @@ impl MachineAir for AddSubChip { "AddSub".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, - output: &mut ExecutionRecord, + input: &EL, + output: &mut Self::Record, ) -> RowMajorMatrix { + let (add_events, sub_events) = input.events(); // Generate the rows for the trace. - let chunk_size = std::cmp::max( - (input.add_events.len() + input.sub_events.len()) / num_cpus::get(), - 1, - ); - let merged_events = input - .add_events + let chunk_size = std::cmp::max((add_events.len() + sub_events.len()) / num_cpus::get(), 1); + let merged_events = add_events .iter() - .chain(input.sub_events.iter()) + .chain(sub_events.iter()) .collect::>(); let rows_and_records = merged_events diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 8cb0b1958..2b3e067b2 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -9,13 +9,15 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `BitwiseChip`. pub const NUM_BITWISE_COLS: usize = size_of::>(); @@ -49,6 +51,10 @@ pub struct BitwiseCols { pub is_and: T, } +impl<'a> WithEvents<'a> for BitwiseChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for BitwiseChip { type Record = ExecutionRecord; @@ -58,14 +64,14 @@ impl MachineAir for BitwiseChip { "Bitwise".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let rows = input - .bitwise_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_BITWISE_COLS]; diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index ff212b086..e12472be7 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -75,8 +75,8 @@ use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; use self::utils::eval_abs_value; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation}; use crate::alu::AluEvent; use crate::bytes::event::ByteRecord; @@ -187,6 +187,10 @@ pub struct DivRemCols { pub is_real: T, } +impl<'a> WithEvents<'a> for DivRemChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for DivRemChip { type Record = ExecutionRecord; @@ -196,13 +200,13 @@ impl MachineAir for DivRemChip { "DivRem".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. - let divrem_events = &input.divrem_events; + let divrem_events = input.events(); let mut rows: Vec<[F; NUM_DIVREM_COLS]> = Vec::with_capacity(divrem_events.len()); for event in divrem_events { assert!( @@ -405,7 +409,7 @@ impl MachineAir for DivRemChip { row }; debug_assert!(padded_row_template.len() == NUM_DIVREM_COLS); - for i in input.divrem_events.len() * NUM_DIVREM_COLS..trace.values.len() { + for i in input.events().len() * NUM_DIVREM_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS]; } diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 8d965153b..c70b84f93 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -12,13 +12,15 @@ use p3_matrix::Matrix; use p3_maybe_rayon::prelude::*; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, BaseAirBuilder, ByteAirBuilder, MachineAir}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `LtChip`. pub const NUM_LT_COLS: usize = size_of::>(); @@ -91,6 +93,10 @@ impl LtCols { } } +impl<'a> WithEvents<'a> for LtChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for LtChip { type Record = ExecutionRecord; @@ -100,14 +106,14 @@ impl MachineAir for LtChip { "Lt".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let (rows, new_byte_lookup_events): (Vec<_>, Vec<_>) = input - .lt_events + .events() .par_iter() .map(|event| { let mut row = [F::zero(); NUM_LT_COLS]; diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index bbdab712b..8681e5a4d 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -44,8 +44,8 @@ use p3_maybe_rayon::prelude::ParallelIterator; use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::mul::utils::get_msb; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -54,6 +54,8 @@ use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::stark::MachineRecord; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `MulChip`. pub const NUM_MUL_COLS: usize = size_of::>(); @@ -121,6 +123,10 @@ pub struct MulCols { pub is_real: T, } +impl<'a> WithEvents<'a> for MulChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for MulChip { type Record = ExecutionRecord; @@ -130,12 +136,12 @@ impl MachineAir for MulChip { "Mul".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mul_events = input.mul_events.clone(); + let mul_events = input.events(); // Compute the chunk size based on the number of events and the number of CPUs. let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1); diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index ecc84530d..39e3ad9f9 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -42,13 +42,15 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::disassembler::WORD_SIZE; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `ShiftLeft`. pub const NUM_SHIFT_LEFT_COLS: usize = size_of::>(); @@ -96,6 +98,10 @@ pub struct ShiftLeftCols { pub is_real: T, } +impl<'a> WithEvents<'a> for ShiftLeft { + type Events = &'a [AluEvent]; +} + impl MachineAir for ShiftLeft { type Record = ExecutionRecord; @@ -105,14 +111,14 @@ impl MachineAir for ShiftLeft { "ShiftLeft".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_LEFT_COLS]> = vec![]; - let shift_left_events = input.shift_left_events.clone(); + let shift_left_events = input.events(); for event in shift_left_events.iter() { let mut row = [F::zero(); NUM_SHIFT_LEFT_COLS]; let cols: &mut ShiftLeftCols = row.as_mut_slice().borrow_mut(); @@ -193,7 +199,7 @@ impl MachineAir for ShiftLeft { row }; debug_assert!(padded_row_template.len() == NUM_SHIFT_LEFT_COLS); - for i in input.shift_left_events.len() * NUM_SHIFT_LEFT_COLS..trace.values.len() { + for i in input.events().len() * NUM_SHIFT_LEFT_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_SHIFT_LEFT_COLS]; } diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 2b52df3f3..a9d85e9e9 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -55,8 +55,8 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::sr::utils::{nb_bits_to_shift, nb_bytes_to_shift}; use crate::bytes::event::ByteRecord; use crate::bytes::utils::shr_carry; @@ -65,6 +65,8 @@ use crate::disassembler::WORD_SIZE; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `ShiftRightChip`. pub const NUM_SHIFT_RIGHT_COLS: usize = size_of::>(); @@ -128,6 +130,10 @@ pub struct ShiftRightCols { pub is_real: T, } +impl<'a> WithEvents<'a> for ShiftRightChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for ShiftRightChip { type Record = ExecutionRecord; @@ -137,14 +143,14 @@ impl MachineAir for ShiftRightChip { "ShiftRight".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_RIGHT_COLS]> = Vec::new(); - let sr_events = input.shift_right_events.clone(); + let sr_events = input.events(); for event in sr_events.iter() { assert!(event.opcode == Opcode::SRL || event.opcode == Opcode::SRA); let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS]; @@ -272,7 +278,7 @@ impl MachineAir for ShiftRightChip { row }; debug_assert!(padded_row_template.len() == NUM_SHIFT_RIGHT_COLS); - for i in input.shift_right_events.len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() { + for i in input.events().len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; } diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 22b820420..a792b0202 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -1,19 +1,24 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, collections::BTreeMap}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use super::{ columns::{ByteMultCols, NUM_BYTE_MULT_COLS, NUM_BYTE_PREPROCESSED_COLS}, - ByteChip, + ByteChip, ByteLookupEvent, }; use crate::{ - air::MachineAir, + air::{EventLens, MachineAir, WithEvents}, runtime::{ExecutionRecord, Program}, }; pub const NUM_ROWS: usize = 1 << 16; +impl<'a, F: Field> WithEvents<'a> for ByteChip { + // the byte lookups + type Events = &'a BTreeMap>; +} + impl MachineAir for ByteChip { type Record = ExecutionRecord; @@ -35,16 +40,20 @@ impl MachineAir for ByteChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let shard = input.index; + let shard = input.index(); let (_, event_map) = Self::trace_and_map(shard); let mut trace = RowMajorMatrix::new( @@ -52,7 +61,7 @@ impl MachineAir for ByteChip { NUM_BYTE_MULT_COLS, ); - for (lookup, mult) in input.byte_lookups[&shard].iter() { + for (lookup, mult) in input.events()[&shard].iter() { let (row, index) = event_map[lookup]; let cols: &mut ByteMultCols = trace.row_mut(row).borrow_mut(); diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index 5d9955489..5aa4e02a1 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -7,7 +7,7 @@ use tracing::instrument; use super::columns::{CPU_COL_MAP, NUM_CPU_COLS}; use super::{CpuChip, CpuEvent}; -use crate::air::MachineAir; +use crate::air::{EventLens, MachineAir, WithEvents}; use crate::alu::AluEvent; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -18,6 +18,10 @@ use crate::memory::MemoryCols; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::runtime::{MemoryRecordEnum, SyscallCode}; +impl<'a> WithEvents<'a> for CpuChip { + type Events = &'a [CpuEvent]; +} + impl MachineAir for CpuChip { type Record = ExecutionRecord; @@ -27,9 +31,9 @@ impl MachineAir for CpuChip { "CPU".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut new_alu_events = HashMap::new(); @@ -37,7 +41,7 @@ impl MachineAir for CpuChip { // Generate the trace rows for each event. let mut rows_with_events = input - .cpu_events + .events() .par_iter() .map(|op: &CpuEvent| self.event_to_row::(*op)) .collect::>(); @@ -78,11 +82,11 @@ impl MachineAir for CpuChip { } #[instrument(name = "generate cpu dependencies", level = "debug", skip_all)] - fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) { + fn generate_dependencies>(&self, input: &EL, output: &mut ExecutionRecord) { // Generate the trace rows for each event. - let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1); + let chunk_size = std::cmp::max(input.events().len() / num_cpus::get(), 1); let events = input - .cpu_events + .events() .par_chunks(chunk_size) .map(|ops: &[CpuEvent]| { let mut alu = HashMap::new(); diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index 5f0b83821..8ed13d7d1 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -56,11 +56,13 @@ pub fn debug_interactions>>( let mut key_to_vec_data = BTreeMap::new(); let mut key_to_count = BTreeMap::new(); - let trace = chip.generate_trace(record, &mut A::Record::default()); + let trace = chip + .as_ref() + .generate_trace(record, &mut A::Record::default()); let mut pre_traces = pkey.traces.clone(); let mut preprocessed_trace = pkey .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|&index| pre_traces.get_mut(index).unwrap()); let mut main = trace.clone(); let height = trace.clone().height(); @@ -102,7 +104,7 @@ pub fn debug_interactions>>( .entry(key.clone()) .or_insert_with(Vec::new) .push(InteractionData { - chip_name: chip.name(), + chip_name: chip.as_ref().name(), kind: interaction.kind, row, interaction_number: m, @@ -150,10 +152,17 @@ where .or_insert((SC::Val::zero(), BTreeMap::new())); entry.0 += *value; total += *value; - *entry.1.entry(chip.name()).or_insert(SC::Val::zero()) += *value; + *entry + .1 + .entry(chip.as_ref().name()) + .or_insert(SC::Val::zero()) += *value; } } - tracing::info!("{} chip has {} distinct events", chip.name(), total_events); + tracing::info!( + "{} chip has {} distinct events", + chip.as_ref().name(), + total_events + ); } tracing::info!("Final counts below."); diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 9a2ac38c3..0003d5d02 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -10,7 +10,9 @@ use sphinx_derive::AlignedBorrow; use super::MemoryInitializeFinalizeEvent; use crate::{ - air::{AirInteraction, BaseAirBuilder, MachineAir, Word, WordAirBuilder}, + air::{ + AirInteraction, BaseAirBuilder, EventLens, MachineAir, WithEvents, Word, WordAirBuilder, + }, runtime::{ExecutionRecord, Program}, utils::pad_to_power_of_two, }; @@ -40,6 +42,15 @@ impl BaseAir for MemoryChip { } } +impl<'a> WithEvents<'a> for MemoryChip { + type Events = ( + // initialize events + &'a [MemoryInitializeFinalizeEvent], + // finalize events + &'a [MemoryInitializeFinalizeEvent], + ); +} + impl MachineAir for MemoryChip { type Record = ExecutionRecord; @@ -52,15 +63,16 @@ impl MachineAir for MemoryChip { } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut memory_events = match self.kind { - MemoryChipType::Initialize => input.memory_initialize_events.clone(), - MemoryChipType::Finalize => input.memory_finalize_events.clone(), - }; + let mut memory_events: Vec = match self.kind { + MemoryChipType::Initialize => input.events().0, + MemoryChipType::Finalize => input.events().1, + } + .to_vec(); memory_events.sort_by_key(|event| event.addr); let rows: Vec<[F; 8]> = (0..memory_events.len()) // TODO: change this back to par_iter .map(|i| { @@ -93,7 +105,11 @@ impl MachineAir for MemoryChip { trace } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index f311c2da9..074603b72 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -5,10 +5,11 @@ use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; +use std::collections::BTreeMap; use sphinx_derive::AlignedBorrow; -use crate::air::{AirInteraction, BaseAirBuilder, PublicValues}; +use crate::air::{AirInteraction, BaseAirBuilder, EventLens, PublicValues, WithEvents}; use crate::air::{MachineAir, Word}; use crate::operations::IsZeroOperation; use crate::runtime::{ExecutionRecord, Program}; @@ -49,6 +50,10 @@ impl MemoryProgramChip { } } +impl<'a> WithEvents<'a> for MemoryProgramChip { + type Events = &'a BTreeMap; +} + impl MachineAir for MemoryProgramChip { type Record = ExecutionRecord; @@ -91,23 +96,22 @@ impl MachineAir for MemoryProgramChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let program_memory_addrs = input - .program - .memory_image - .keys() - .copied() - .collect::>(); + let program_memory_addrs = input.events().keys().copied().collect::>(); - let mult = if input.index == 1 { + let mult = if input.index() == 1 { F::one() } else { F::zero() @@ -120,7 +124,7 @@ impl MachineAir for MemoryProgramChip { let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS]; let cols: &mut MemoryProgramMultCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; - IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1); + IsZeroOperation::populate(&mut cols.is_first_shard, input.index() - 1); row }) diff --git a/core/src/operations/field/extensions/quadratic/mod.rs b/core/src/operations/field/extensions/quadratic/mod.rs index a9b04765b..b91674b0f 100644 --- a/core/src/operations/field/extensions/quadratic/mod.rs +++ b/core/src/operations/field/extensions/quadratic/mod.rs @@ -365,7 +365,7 @@ impl QuadFieldOpCols { #[cfg(test)] mod tests { - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; @@ -411,6 +411,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for QuadFieldOpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl MachineAir for QuadFieldOpChip

{ type Record = ExecutionRecord; @@ -420,9 +430,9 @@ mod tests { format!("QuadFieldOp{:?}", self.operation) } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/extensions/quadratic/sqrt.rs b/core/src/operations/field/extensions/quadratic/sqrt.rs index 94cfd2eee..cfcd6c07b 100644 --- a/core/src/operations/field/extensions/quadratic/sqrt.rs +++ b/core/src/operations/field/extensions/quadratic/sqrt.rs @@ -111,8 +111,7 @@ mod tests { use super::QuadFieldSqrtCols; use crate::air::MachineAir; - use crate::air::WordAirBuilder; - + use crate::air::{EventLens, WordAirBuilder}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::{FieldParameters, Limbs}; use crate::runtime::ExecutionRecord; @@ -151,6 +150,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for QuadSqrtChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl MachineAir for QuadSqrtChip

{ type Record = ExecutionRecord; @@ -160,9 +169,9 @@ mod tests { "QuadSqrtChip".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let num_test_cols = size_of::>(); diff --git a/core/src/operations/field/field_den.rs b/core/src/operations/field/field_den.rs index 1a7580818..ff96412f1 100644 --- a/core/src/operations/field/field_den.rs +++ b/core/src/operations/field/field_den.rs @@ -152,7 +152,7 @@ mod tests { mem::size_of, }; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use num::{bigint::RandBigInt, BigUint}; use p3_air::{Air, BaseAir}; use p3_baby_bear::BabyBear; @@ -197,6 +197,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldDenChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl MachineAir for FieldDenChip

{ type Record = ExecutionRecord; @@ -206,9 +216,9 @@ mod tests { "FieldDen".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_inner_product.rs b/core/src/operations/field/field_inner_product.rs index aec17555e..0d1ca42a4 100644 --- a/core/src/operations/field/field_inner_product.rs +++ b/core/src/operations/field/field_inner_product.rs @@ -152,7 +152,7 @@ mod tests { use super::{FieldInnerProductCols, Limbs}; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use crate::{ air::MachineAir, utils::ec::weierstrass::{bls12_381::Bls12381BaseField, secp256k1::Secp256k1BaseField}, @@ -185,6 +185,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldIpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl MachineAir for FieldIpChip

{ type Record = ExecutionRecord; @@ -194,9 +204,9 @@ mod tests { "FieldInnerProduct".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index cd9cb5fc5..654a8d10e 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -243,7 +243,7 @@ mod tests { use crate::{air::MachineAir, utils::ec::weierstrass::bls12_381::Bls12381BaseField}; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::FieldParameters; use crate::runtime::ExecutionRecord; @@ -278,6 +278,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldOpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl MachineAir for FieldOpChip

{ type Record = ExecutionRecord; @@ -287,9 +297,9 @@ mod tests { format!("FieldOp{:?}", self.operation) } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index 47e9f78ec..cfd77c19e 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -119,8 +119,8 @@ mod tests { use super::{FieldSqrtCols, Limbs}; - use crate::air::MachineAir; use crate::air::WordAirBuilder; + use crate::air::{EventLens, MachineAir}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::{FieldParameters, DEFAULT_NUM_LIMBS_T}; use crate::runtime::ExecutionRecord; @@ -151,6 +151,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for EdSqrtChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &() + } + } + impl> MachineAir for EdSqrtChip

{ @@ -162,9 +172,9 @@ mod tests { "EdSqrtChip".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/params.rs b/core/src/operations/field/params.rs index 0d72d4222..a56da381b 100644 --- a/core/src/operations/field/params.rs +++ b/core/src/operations/field/params.rs @@ -14,8 +14,6 @@ use num::BigUint; use p3_field::Field; use crate::air::Polynomial; -use crate::runtime::ExecutionRecord; -use crate::syscall::precompiles; use crate::utils::ec::utils::biguint_from_limbs; pub const NB_BITS_PER_LIMB: usize = 8; @@ -179,36 +177,6 @@ pub trait FieldParameters: } } -pub trait WithFieldAddition: FieldParameters { - fn add_events(record: &ExecutionRecord) -> &[precompiles::field::add::FieldAddEvent]; -} - -pub trait WithFieldSubtraction: FieldParameters { - fn sub_events(record: &ExecutionRecord) -> &[precompiles::field::sub::FieldSubEvent]; -} - -pub trait WithFieldMultiplication: FieldParameters { - fn mul_events(record: &ExecutionRecord) -> &[precompiles::field::mul::FieldMulEvent]; -} - -pub trait WithQuadFieldAddition: FieldParameters { - fn add_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::add::QuadFieldAddEvent]; -} - -pub trait WithQuadFieldSubtraction: FieldParameters { - fn sub_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::sub::QuadFieldSubEvent]; -} - -pub trait WithQuadFieldMultiplication: FieldParameters { - fn mul_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::mul::QuadFieldMulEvent]; -} - #[cfg(test)] mod tests { use crate::operations::field::params::FieldParameters; diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs index 64b16d6d2..623bb2dde 100644 --- a/core/src/program/mod.rs +++ b/core/src/program/mod.rs @@ -10,8 +10,11 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_derive::AlignedBorrow; use crate::{ - air::{MachineAir, ProgramAirBuilder}, - cpu::columns::{InstructionCols, OpcodeSelectorCols}, + air::{EventLens, MachineAir, ProgramAirBuilder, WithEvents}, + cpu::{ + columns::{InstructionCols, OpcodeSelectorCols}, + CpuEvent, + }, runtime::{ExecutionRecord, Program}, utils::pad_to_power_of_two, }; @@ -49,6 +52,15 @@ impl ProgramChip { } } +impl<'a> WithEvents<'a> for ProgramChip { + type Events = ( + // CPU events + &'a [CpuEvent], + // the Program + &'a Program, + ); +} + impl MachineAir for ProgramChip { type Record = ExecutionRecord; @@ -92,39 +104,43 @@ impl MachineAir for ProgramChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. + let (cpu_events, program) = input.events(); // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); - input.cpu_events.iter().for_each(|event| { + for event in cpu_events.iter() { let pc = event.pc; instruction_counts .entry(pc) .and_modify(|count| *count += 1) .or_insert(1); - }); + } - let rows = input - .program + let rows = program .instructions .clone() .into_iter() .enumerate() .map(|(i, _)| { - let pc = input.program.pc_base + (i as u32 * 4); + let pc = program.pc_base + (i as u32 * 4); let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS]; let cols: &mut ProgramMultiplicityCols = row.as_mut_slice().borrow_mut(); - cols.shard = F::from_canonical_u32(input.index); + cols.shard = F::from_canonical_u32(input.index()); cols.multiplicity = F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0)); row diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 2e95b6c42..180b6e22f 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -5,12 +5,10 @@ use std::{ }; use itertools::Itertools; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use serde::{Deserialize, Serialize}; use super::{program::Program, Opcode}; -use crate::alu::AluEvent; -use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::cpu::CpuEvent; use crate::runtime::MemoryInitializeFinalizeEvent; @@ -24,6 +22,29 @@ use crate::syscall::precompiles::keccak256::KeccakPermuteEvent; use crate::syscall::precompiles::sha256::{ShaCompressEvent, ShaExtendEvent}; use crate::syscall::precompiles::{ECAddEvent, ECDoubleEvent}; use crate::utils::env; +use crate::{ + air::EventLens, + alu::AluEvent, + memory::MemoryProgramChip, + stark::{ + AddSubChip, BitwiseChip, Blake3CompressInnerChip, ByteChip, CpuChip, DivRemChip, + Ed25519Parameters, EdAddAssignChip, EdDecompressChip, FieldAddChip, FieldMulChip, + FieldSubChip, KeccakPermuteChip, LtChip, MemoryChip, MulChip, ProgramChip, + QuadFieldAddChip, QuadFieldMulChip, QuadFieldSubChip, ShaCompressChip, ShaExtendChip, + ShiftLeft, ShiftRightChip, WeierstrassAddAssignChip, WeierstrassDoubleAssignChip, + }, + syscall::precompiles::{ + bls12_381::{ + g1_decompress::Bls12381G1DecompressChip, g2_add::Bls12381G2AffineAddChip, + g2_double::Bls12381G2AffineDoubleChip, + }, + secp256k1::decompress::Secp256k1DecompressChip, + }, + utils::ec::{ + edwards::ed25519::Ed25519, + weierstrass::{bls12_381::Bls12381, bn254::Bn254, secp256k1::Secp256k1}, + }, +}; use crate::{ air::PublicValues, operations::field::params::FieldParameters, @@ -35,6 +56,7 @@ use crate::{ }, utils::ec::weierstrass::bls12_381::Bls12381BaseField, }; +use crate::{bytes::event::ByteRecord, stark::Indexed}; /// A record of the execution of a program. Contains event data for everything that happened during /// the execution of the shard. @@ -122,6 +144,225 @@ pub struct ExecutionRecord { pub public_values: PublicValues, } +// Event lenses connect the record to the events relative to a particular chip +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + (&self.add_events, &self.sub_events) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.bitwise_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.divrem_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.lt_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.mul_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.shift_left_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.shift_right_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.byte_lookups + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.cpu_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + (&self.memory_initialize_events, &self.memory_finalize_events) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.program.memory_image + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + (&self.cpu_events, &self.program) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.sha_extend_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.sha_compress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.blake3_compress_inner_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.keccak_permute_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.bls12381_g1_decompress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.secp256k1_decompress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.bls12381_g2_add_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> >::Events { + &self.bls12381_g2_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp_sub_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp_mul_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp2_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp2_sub_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_fp2_mul_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.secp256k1_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_g1_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bn254_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.secp256k1_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.bls12381_g1_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.bn254_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + &self.ed_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { + &self.ed_decompress_events + } +} + pub struct ShardingConfig { pub shard_size: usize, pub add_len: usize, @@ -185,12 +426,14 @@ impl Default for ShardingConfig { } } -impl MachineRecord for ExecutionRecord { - type Config = ShardingConfig; - +impl Indexed for ExecutionRecord { fn index(&self) -> u32 { self.index } +} + +impl MachineRecord for ExecutionRecord { + type Config = ShardingConfig; fn set_index(&mut self, index: u32) { self.index = index; diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index 7f7f7b93b..e536a9631 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -11,6 +11,7 @@ use crate::utils::ec::weierstrass::bls12_381::{Bls12381BaseField, Bls12381Parame use crate::StarkGenericConfig; use p3_field::PrimeField32; pub use riscv_chips::*; +use sphinx_derive::{EventLens, WithEvents}; use tracing::instrument; /// A module for importing all the different RISC-V chips. @@ -42,7 +43,8 @@ pub(crate) mod riscv_chips { /// This enum contains all the different AIRs that are used in the Sp1 RISC-V IOP. Each variant is /// a different AIR that is used to encode a different part of the RISC-V execution, and the /// different AIR variants have a joint lookup argument. -#[derive(MachineAir)] +#[derive(WithEvents, EventLens, MachineAir)] +#[record_type = "crate::runtime::ExecutionRecord"] pub enum RiscvAir { /// An AIR that contains a preprocessed program table and a lookup for the instructions. Program(ProgramChip), diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index 39349b395..8805f5e5e 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -28,6 +28,12 @@ pub struct Chip { log_quotient_degree: usize, } +impl AsRef for Chip { + fn as_ref(&self) -> &A { + &self.air + } +} + impl Chip { /// The send interactions of the chip. pub fn sends(&self) -> &[Interaction] { @@ -151,40 +157,6 @@ where } } -impl MachineAir for Chip -where - F: Field, - A: MachineAir, -{ - type Record = A::Record; - - type Program = A::Program; - - fn name(&self) -> String { - self.air.name() - } - - fn preprocessed_width(&self) -> usize { - >::preprocessed_width(&self.air) - } - - fn generate_preprocessed_trace(&self, program: &A::Program) -> Option> { - >::generate_preprocessed_trace(&self.air, program) - } - - fn generate_trace(&self, input: &A::Record, output: &mut A::Record) -> RowMajorMatrix { - self.air.generate_trace(input, output) - } - - fn generate_dependencies(&self, input: &A::Record, output: &mut A::Record) { - self.air.generate_dependencies(input, output) - } - - fn included(&self, shard: &Self::Record) -> bool { - self.air.included(shard) - } -} - // Implement AIR directly on Chip, evaluating both execution and permutation constraints. impl Air for Chip where diff --git a/core/src/stark/debug.rs b/core/src/stark/debug.rs index e45cc48bf..e7aa51797 100644 --- a/core/src/stark/debug.rs +++ b/core/src/stark/debug.rs @@ -101,7 +101,7 @@ pub fn debug_constraints( if result.is_err() { eprintln!("local: {:?}", main_local); eprintln!("next: {:?}", main_next); - eprintln!("failed at row {} of chip {}", i, chip.name()); + eprintln!("failed at row {} of chip {}", i, chip.as_ref().name()); exit(1); } }); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 38554a685..917221679 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -25,7 +25,7 @@ use crate::air::MachineProgram; use crate::lookup::debug_interactions_with_all_chips; use crate::lookup::InteractionBuilder; use crate::lookup::InteractionKind; -use crate::stark::record::MachineRecord; +use crate::stark::record::{Indexed, MachineRecord}; use crate::stark::DebugConstraintBuilder; use crate::stark::ProverConstraintFolder; use crate::stark::ShardProof; @@ -120,7 +120,7 @@ impl>> StarkMachine { self.chips .iter() .enumerate() - .filter(|(_, chip)| chip.preprocessed_width() > 0) + .filter(|(_, chip)| chip.as_ref().preprocessed_width() > 0) .map(|(i, _)| i) .collect() } @@ -132,7 +132,9 @@ impl>> StarkMachine { where 'a: 'b, { - self.chips.iter().filter(|chip| chip.included(shard)) + self.chips + .iter() + .filter(|chip| chip.as_ref().included(shard)) } pub fn shard_chips_ordered<'a, 'b>( @@ -144,14 +146,14 @@ impl>> StarkMachine { { self.chips .iter() - .filter(|chip| chip_ordering.contains_key(&chip.name())) - .sorted_by_key(|chip| chip_ordering.get(&chip.name())) + .filter(|chip| chip_ordering.contains_key(&chip.as_ref().name())) + .sorted_by_key(|chip| chip_ordering.get(&chip.as_ref().name())) } pub fn chips_sorted_indices(&self, proof: &ShardProof) -> Vec> { self.chips() .iter() - .map(|chip| proof.chip_ordering.get(&chip.name()).cloned()) + .map(|chip| proof.chip_ordering.get(&chip.as_ref().name()).cloned()) .collect() } @@ -166,17 +168,17 @@ impl>> StarkMachine { self.chips() .iter() .map(|chip| { - let prep_trace = chip.generate_preprocessed_trace(program); + let prep_trace = chip.as_ref().generate_preprocessed_trace(program); // Assert that the chip width data is correct. let expected_width = prep_trace.as_ref().map_or(0, |t| t.width()); assert_eq!( expected_width, - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), "Incorrect number of preprocessed columns for chip {}", - chip.name() + chip.as_ref().name() ); - (chip.name(), prep_trace) + (chip.as_ref().name(), prep_trace) }) .filter(|(_, prep_trace)| prep_trace.is_some()) .map(|(name, prep_trace)| { @@ -251,7 +253,7 @@ impl>> StarkMachine { chips.iter().for_each(|chip| { let mut output = A::Record::default(); output.set_index(record.index()); - chip.generate_dependencies(&record, &mut output); + chip.as_ref().generate_dependencies(&record, &mut output); record.append(&mut output); }) }); @@ -375,13 +377,16 @@ impl>> StarkMachine { .iter() .map(|chip| { pk.chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|index| &pk.traces[*index]) }) .collect::>(); let mut traces = chips .par_iter() - .map(|chip| chip.generate_trace(shard, &mut A::Record::default())) + .map(|chip| { + chip.as_ref() + .generate_trace(shard, &mut A::Record::default()) + }) .zip(pre_traces) .collect::>(); @@ -424,7 +429,7 @@ impl>> StarkMachine { let total_width = trace_width + permutation_width; tracing::debug!( "{:<11} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<10} | Cells = {:<10}", - chips[i].name(), + chips[i].as_ref().name(), trace_width, permutation_width, traces[i].0.height(), @@ -436,7 +441,7 @@ impl>> StarkMachine { for i in 0..chips.len() { let permutation_trace = pk .chip_ordering - .get(&chips[i].name()) + .get(&chips[i].as_ref().name()) .map(|index| &pk.traces[*index]); debug_constraints::( chips[i], diff --git a/core/src/stark/prover.rs b/core/src/stark/prover.rs index 2252693eb..94bef3a86 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -26,6 +26,7 @@ use super::{StarkProvingKey, VerifierConstraintFolder}; use crate::air::MachineAir; use crate::lookup::InteractionBuilder; use crate::stark::record::MachineRecord; +use crate::stark::Indexed; use crate::stark::MachineChip; use crate::stark::PackedChallenge; use crate::stark::ProverConstraintFolder; @@ -171,14 +172,14 @@ where shard_chips .par_iter() .map(|chip| { - let chip_name = chip.name(); + let chip_name = chip.as_ref().name(); // We need to create an outer span here because, for some reason, // the #[instrument] macro on the chip impl isn't attaching its span to `parent_span` // to avoid the unnecessary span, remove the #[instrument] macro. let trace = tracing::debug_span!(parent: &parent_span, "generate trace for chip", %chip_name) - .in_scope(|| chip.generate_trace(shard, &mut A::Record::default())); + .in_scope(|| chip.as_ref().generate_trace(shard, &mut A::Record::default())); (chip_name, trace) }) .collect::>() @@ -282,7 +283,7 @@ where .map(|(chip, main_trace)| { let preprocessed_trace = pk .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|&index| &pk.traces[index]); let perm_trace = chip.generate_permutation_trace( preprocessed_trace, @@ -307,7 +308,7 @@ where + permutation_width * >::D; tracing::debug!( "{:<15} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}", - chips[i].name(), + chips[i].as_ref().name(), trace_width, permutation_width * >::D, traces[i].height(), @@ -358,7 +359,7 @@ where .in_scope(|| { let preprocessed_trace_on_quotient_domains = pk .chip_ordering - .get(&chips[i].name()) + .get(&chips[i].as_ref().name()) .map_or_else(|| { RowMajorMatrix::new_col(vec![ SC::Val::zero(); @@ -506,7 +507,7 @@ where .enumerate() .map( |(i, ((((main, permutation), quotient), cumulative_sum), log_degree))| { - let preprocessed = pk.chip_ordering.get(&chips[i].name()).map_or( + let preprocessed = pk.chip_ordering.get(&chips[i].as_ref().name()).map_or( AirOpenedValues { local: vec![], next: vec![], diff --git a/core/src/stark/record.rs b/core/src/stark/record.rs index 226188514..10e035eaa 100644 --- a/core/src/stark/record.rs +++ b/core/src/stark/record.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use p3_field::AbstractField; -pub trait MachineRecord: Default + Sized + Send + Sync + Clone { - type Config: Default; - +pub trait Indexed { fn index(&self) -> u32; +} + +pub trait MachineRecord: Default + Sized + Send + Sync + Clone + Indexed { + type Config: Default; fn set_index(&mut self, index: u32); diff --git a/core/src/stark/verifier.rs b/core/src/stark/verifier.rs index be766413d..0c90ef216 100644 --- a/core/src/stark/verifier.rs +++ b/core/src/stark/verifier.rs @@ -185,7 +185,7 @@ impl>> Verifier { ) { // Verify the shape of the opening arguments matches the expected values. Self::verify_opening_shape(chip, values) - .map_err(|e| VerificationError::OpeningShapeError(chip.name(), e))?; + .map_err(|e| VerificationError::OpeningShapeError(chip.as_ref().name(), e))?; // Verify the constraint evaluation. Self::verify_constraints( chip, @@ -197,7 +197,7 @@ impl>> Verifier { &permutation_challenges, public_values, ) - .map_err(|_e| VerificationError::OodEvaluationMismatch(chip.name()))?; + .map_err(|_e| VerificationError::OodEvaluationMismatch(chip.as_ref().name()))?; } Ok(()) } @@ -207,15 +207,15 @@ impl>> Verifier { opening: &ChipOpenedValues, ) -> Result<(), OpeningShapeError> { // Verify that the preprocessed width matches the expected value for the chip. - if opening.preprocessed.local.len() != chip.preprocessed_width() { + if opening.preprocessed.local.len() != chip.as_ref().preprocessed_width() { return Err(OpeningShapeError::PreprocessedWidthMismatch( - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), opening.preprocessed.local.len(), )); } - if opening.preprocessed.next.len() != chip.preprocessed_width() { + if opening.preprocessed.next.len() != chip.as_ref().preprocessed_width() { return Err(OpeningShapeError::PreprocessedWidthMismatch( - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), opening.preprocessed.next.len(), )); } diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs index d151d9c20..166bcd79d 100644 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ b/core/src/syscall/precompiles/blake3/compress/trace.rs @@ -3,10 +3,12 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use super::Blake3CompressInnerEvent; use super::{ columns::Blake3CompressInnerCols, G_INDEX, G_INPUT_SIZE, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, }; +use crate::air::{EventLens, WithEvents}; use crate::bytes::event::ByteRecord; use crate::{ air::MachineAir, @@ -17,6 +19,10 @@ use crate::{ utils::pad_rows, }; +impl<'a> WithEvents<'a> for Blake3CompressInnerChip { + type Events = &'a [Blake3CompressInnerEvent]; +} + impl MachineAir for Blake3CompressInnerChip { type Record = ExecutionRecord; type Program = Program; @@ -25,17 +31,17 @@ impl MachineAir for Blake3CompressInnerChip { "Blake3CompressInner".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.blake3_compress_inner_events.len() { - let event = input.blake3_compress_inner_events[i].clone(); + for i in 0..input.events().len() { + let event = input.events()[i].clone(); let shard = event.shard; let mut clk = event.clk; for round in 0..ROUND_COUNT { diff --git a/core/src/syscall/precompiles/bls12_381/g1_decompress.rs b/core/src/syscall/precompiles/bls12_381/g1_decompress.rs index a27d44059..b01c18c46 100644 --- a/core/src/syscall/precompiles/bls12_381/g1_decompress.rs +++ b/core/src/syscall/precompiles/bls12_381/g1_decompress.rs @@ -11,7 +11,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, ByteAirBuilder, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, ByteAirBuilder, EventLens, MemoryAirBuilder, WithEvents}; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::operations::field::params::FieldParameters; use crate::operations::field::range::FieldRangeCols; @@ -229,6 +229,10 @@ impl Bls12381G1DecompressChip { } } +impl<'a> WithEvents<'a> for Bls12381G1DecompressChip { + type Events = &'a [Bls12381G1DecompressEvent]; +} + impl MachineAir for Bls12381G1DecompressChip { type Record = ExecutionRecord; type Program = Program; @@ -237,16 +241,16 @@ impl MachineAir for Bls12381G1DecompressChip { "Bls12381G1Decompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for event in input.bls12381_g1_decompress_events.iter() { + for event in input.events().iter() { let mut row = [F::zero(); size_of::>()]; let cols: &mut Bls12381G1DecompressCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/bls12_381/g2_add.rs b/core/src/syscall/precompiles/bls12_381/g2_add.rs index 0a1711058..a5e4ba019 100644 --- a/core/src/syscall/precompiles/bls12_381/g2_add.rs +++ b/core/src/syscall/precompiles/bls12_381/g2_add.rs @@ -1,4 +1,4 @@ -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}; use crate::bytes::event::ByteRecord; use crate::memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}; use crate::operations::field::extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}; @@ -276,6 +276,10 @@ impl BaseAir for Bls12381G2AffineAddChip { } } +impl<'a> WithEvents<'a> for Bls12381G2AffineAddChip { + type Events = &'a [Bls12381G2AffineAddEvent]; +} + impl MachineAir for Bls12381G2AffineAddChip { type Record = ExecutionRecord; type Program = Program; @@ -284,13 +288,17 @@ impl MachineAir for Bls12381G2AffineAddChip { "G2AffineAdd".to_string() } - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix { let mut rows = vec![]; let mut new_byte_lookup_events = Vec::new(); let width = >::width(self); - for event in &input.bls12381_g2_add_events { + for event in input.events() { let mut row = vec![F::zero(); width]; let cols: &mut Bls12381G2AffineAddCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/bls12_381/g2_double.rs b/core/src/syscall/precompiles/bls12_381/g2_double.rs index 29715c5dc..b43253add 100644 --- a/core/src/syscall/precompiles/bls12_381/g2_double.rs +++ b/core/src/syscall/precompiles/bls12_381/g2_double.rs @@ -1,4 +1,4 @@ -use crate::air::MachineAir; +use crate::air::{EventLens, MachineAir, WithEvents}; use crate::bytes::event::ByteRecord; use crate::memory::{MemoryCols, MemoryWriteCols}; use crate::operations::field::extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}; @@ -225,6 +225,10 @@ impl BaseAir for Bls12381G2AffineDoubleChip { } } +impl<'a> WithEvents<'a> for Bls12381G2AffineDoubleChip { + type Events = &'a [Bls12381G2AffineDoubleEvent]; +} + impl MachineAir for Bls12381G2AffineDoubleChip { type Record = ExecutionRecord; type Program = Program; @@ -233,14 +237,18 @@ impl MachineAir for Bls12381G2AffineDoubleChip { "Bls12381G2AffineDoubleChip".to_string() } - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix { let mut rows: Vec> = vec![]; let width = >::width(self); let mut new_byte_lookup_events = Vec::new(); - for event in &input.bls12381_g2_double_events { + for event in input.events() { let mut row = vec![F::zero(); width]; let cols: &mut Bls12381G2AffineDoubleCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index cbba5d5ac..e8cf48cf3 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -17,7 +17,6 @@ use p3_maybe_rayon::prelude::IntoParallelRefIterator; use p3_maybe_rayon::prelude::ParallelIterator; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -44,6 +43,10 @@ use crate::utils::ec::EllipticCurve; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; use crate::{air::MachineAir, utils::ec::EllipticCurveParameters}; +use crate::{ + air::{AluAirBuilder, EventLens, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECAddEvent, +}; pub const NUM_ED_ADD_COLS: usize = size_of::>(); @@ -140,7 +143,14 @@ impl< } } -impl MachineAir for EdAddAssignChip { +impl<'a, E: EllipticCurve + EdwardsParameters> WithEvents<'a> for EdAddAssignChip { + type Events = &'a [ECAddEvent]; +} + +impl MachineAir for EdAddAssignChip +where + ExecutionRecord: EventLens>, +{ type Record = ExecutionRecord; type Program = Program; @@ -149,13 +159,13 @@ impl MachineAir for Ed "EdAddAssign".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let (mut rows, new_byte_lookup_events): (Vec>, Vec>) = input - .ed_add_events + .events() .par_iter() .map(|event| { let mut row = vec![F::zero(); size_of::>()]; diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index 2d65816fc..5d6054226 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -13,8 +13,10 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; +use crate::air::EventLens; use crate::air::MachineAir; use crate::air::MachineAirBuilder; +use crate::air::WithEvents; use crate::air::{AluAirBuilder, BaseAirBuilder, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; @@ -331,7 +333,14 @@ impl EdDecompressChip { } } -impl MachineAir for EdDecompressChip { +impl<'a, E: EdwardsParameters> WithEvents<'a> for EdDecompressChip { + type Events = &'a [EdDecompressEvent]; +} + +impl MachineAir for EdDecompressChip +where + ExecutionRecord: EventLens>, +{ type Record = ExecutionRecord; type Program = Program; @@ -340,15 +349,15 @@ impl MachineAir for EdDecompressChip>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); - for i in 0..input.ed_decompress_events.len() { - let event = &input.ed_decompress_events[i]; + for i in 0..input.events().len() { + let event = &input.events()[i]; let mut row = vec![F::zero(); size_of::>()]; let cols: &mut EdDecompressCols = row.as_mut_slice().borrow_mut(); cols.populate::(event, output); diff --git a/core/src/syscall/precompiles/field/add.rs b/core/src/syscall/precompiles/field/add.rs index 82b25ac77..542e02739 100644 --- a/core/src/syscall/precompiles/field/add.rs +++ b/core/src/syscall/precompiles/field/add.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldAddition, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,25 +113,32 @@ pub fn create_fp_add_event( } } -impl MachineAir for FieldAddChip { +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldAddChip { + type Events = &'a [FieldAddEvent]; +} + +impl MachineAir for FieldAddChip +where + ExecutionRecord: EventLens>, +{ type Record = ExecutionRecord; type Program = Program; fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldAdd".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field add trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::add_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/field/mul.rs b/core/src/syscall/precompiles/field/mul.rs index 32e13f59a..2d1fd8d48 100644 --- a/core/src/syscall/precompiles/field/mul.rs +++ b/core/src/syscall/precompiles/field/mul.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldMultiplication, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,8 +113,13 @@ pub fn create_fp_mul_event( } } -impl MachineAir - for FieldMulChip +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldMulChip { + type Events = &'a [FieldMulEvent]; +} + +impl MachineAir for FieldMulChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -122,18 +127,18 @@ impl MachineAir< fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldMul".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field mul trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::mul_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/field/sub.rs b/core/src/syscall/precompiles/field/sub.rs index 15725343b..65e772d0b 100644 --- a/core/src/syscall/precompiles/field/sub.rs +++ b/core/src/syscall/precompiles/field/sub.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldSubtraction, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,8 +113,13 @@ pub fn create_fp_sub_event( } } -impl MachineAir - for FieldSubChip +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldSubChip { + type Events = &'a [FieldSubEvent]; +} + +impl MachineAir for FieldSubChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -122,18 +127,18 @@ impl MachineAir fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldSub".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field sub trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::sub_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index 5c2c4e8fa..dd814d1c0 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -5,16 +5,22 @@ use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use crate::air::{EventLens, WithEvents}; use crate::bytes::event::ByteRecord; use crate::{runtime::Program, stark::MachineRecord}; use crate::{air::MachineAir, runtime::ExecutionRecord}; +use super::KeccakPermuteEvent; use super::{ columns::{KeccakMemCols, NUM_KECCAK_MEM_COLS}, KeccakPermuteChip, STATE_SIZE, }; +impl<'a> WithEvents<'a> for KeccakPermuteChip { + type Events = &'a [KeccakPermuteEvent]; +} + impl MachineAir for KeccakPermuteChip { type Record = ExecutionRecord; type Program = Program; @@ -23,36 +29,35 @@ impl MachineAir for KeccakPermuteChip { "KeccakPermute".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let num_events = input.keccak_permute_events.len(); + let num_events = input.events().len(); let chunk_size = std::cmp::max(num_events / num_cpus::get(), 1); // Use par_chunks to generate the trace in parallel. - let rows_and_records = (0..num_events) - .collect::>() + let rows_and_records = input + .events() .par_chunks(chunk_size) - .map(|chunk| { + .map(|chunk_events| { let mut record = ExecutionRecord::default(); let mut new_byte_lookup_events = Vec::new(); // First generate all the p3_keccak_air traces at once. - let perm_inputs = chunk + let perm_inputs = chunk_events .iter() - .map(|event_index| input.keccak_permute_events[*event_index].pre_state) + .map(|event| event.pre_state) .collect::>(); let p3_keccak_trace = generate_trace_rows::(perm_inputs); - let rows = chunk + let rows = chunk_events .iter() .enumerate() - .flat_map(|(index_in_chunk, event_index)| { + .flat_map(|(index_in_chunk, event)| { let mut rows = Vec::new(); - let event = &input.keccak_permute_events[*event_index]; let start_clk = event.clk; let shard = event.shard; diff --git a/core/src/syscall/precompiles/quad_field/add.rs b/core/src/syscall/precompiles/quad_field/add.rs index 3ee56f9b9..c96547783 100644 --- a/core/src/syscall/precompiles/quad_field/add.rs +++ b/core/src/syscall/precompiles/quad_field/add.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldAddition, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -140,8 +139,13 @@ pub fn create_fp2_add_event( } } -impl MachineAir - for QuadFieldAddChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldAddChip { + type Events = &'a [QuadFieldAddEvent]; +} + +impl MachineAir for QuadFieldAddChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -154,13 +158,13 @@ impl MachineAir } #[instrument(name = "generate bls12381 fp2 add trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::add_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/quad_field/mul.rs b/core/src/syscall/precompiles/quad_field/mul.rs index 06a05b040..265ba1dd0 100644 --- a/core/src/syscall/precompiles/quad_field/mul.rs +++ b/core/src/syscall/precompiles/quad_field/mul.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldMultiplication, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -149,8 +148,13 @@ pub fn create_fp2_mul_event( } } -impl MachineAir - for QuadFieldMulChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldMulChip { + type Events = &'a [QuadFieldMulEvent]; +} + +impl MachineAir for QuadFieldMulChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -163,13 +167,13 @@ impl Machine } #[instrument(name = "generate bls12381 fp2 mul trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::mul_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/quad_field/sub.rs b/core/src/syscall/precompiles/quad_field/sub.rs index e21390c18..ae8780a3f 100644 --- a/core/src/syscall/precompiles/quad_field/sub.rs +++ b/core/src/syscall/precompiles/quad_field/sub.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldSubtraction, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -140,8 +139,13 @@ pub fn create_fp2_sub_event( } } -impl MachineAir - for QuadFieldSubChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldSubChip { + type Events = &'a [QuadFieldSubEvent]; +} + +impl MachineAir for QuadFieldSubChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -154,13 +158,13 @@ impl MachineAir } #[instrument(name = "generate bls12381 fp2 sub trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::sub_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/secp256k1/decompress.rs b/core/src/syscall/precompiles/secp256k1/decompress.rs index 5eff9bafc..723a4047b 100644 --- a/core/src/syscall/precompiles/secp256k1/decompress.rs +++ b/core/src/syscall/precompiles/secp256k1/decompress.rs @@ -11,7 +11,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, ByteAirBuilder, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, ByteAirBuilder, EventLens, MemoryAirBuilder, WithEvents}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::operations::field::range::FieldRangeCols; @@ -214,6 +214,10 @@ impl Secp256k1DecompressChip { } } +impl<'a> WithEvents<'a> for Secp256k1DecompressChip { + type Events = &'a [Secp256k1DecompressEvent]; +} + impl MachineAir for Secp256k1DecompressChip { type Record = ExecutionRecord; type Program = Program; @@ -222,16 +226,16 @@ impl MachineAir for Secp256k1DecompressChip { "Secp256k1Decompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for event in input.secp256k1_decompress_events.iter() { + for event in input.events().iter() { let mut row = [F::zero(); size_of::>()]; let cols: &mut Secp256k1DecompressCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index f0c2409db..059d14bf8 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -5,15 +5,19 @@ use p3_matrix::dense::RowMajorMatrix; use super::{ columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS}, - ShaCompressChip, SHA_COMPRESS_K, + ShaCompressChip, ShaCompressEvent, SHA_COMPRESS_K, }; use crate::{ - air::{MachineAir, Word}, + air::{EventLens, MachineAir, WithEvents, Word}, bytes::event::ByteRecord, runtime::{ExecutionRecord, Program}, utils::pad_rows, }; +impl<'a> WithEvents<'a> for ShaCompressChip { + type Events = &'a [ShaCompressEvent]; +} + impl MachineAir for ShaCompressChip { type Record = ExecutionRecord; @@ -23,16 +27,16 @@ impl MachineAir for ShaCompressChip { "ShaCompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.sha_compress_events.len() { - let mut event = input.sha_compress_events[i].clone(); + for i in 0..input.events().len() { + let mut event = input.events()[i].clone(); let shard = event.shard; let og_h = event.h; diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index faadfa889..719ba0af6 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -3,13 +3,17 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use super::{ShaExtendChip, ShaExtendCols, NUM_SHA_EXTEND_COLS}; +use super::{ShaExtendChip, ShaExtendCols, ShaExtendEvent, NUM_SHA_EXTEND_COLS}; use crate::{ - air::MachineAir, + air::{EventLens, MachineAir, WithEvents}, bytes::event::ByteRecord, runtime::{ExecutionRecord, Program}, }; +impl<'a> WithEvents<'a> for ShaExtendChip { + type Events = &'a [ShaExtendEvent]; +} + impl MachineAir for ShaExtendChip { type Record = ExecutionRecord; @@ -19,16 +23,16 @@ impl MachineAir for ShaExtendChip { "ShaExtend".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.sha_extend_events.len() { - let event = input.sha_extend_events[i].clone(); + for i in 0..input.events().len() { + let event = input.events()[i].clone(); let shard = event.shard; for j in 0..48usize { let mut row = [F::zero(); NUM_SHA_EXTEND_COLS]; diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index 35b6158b5..d527d6e8e 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -16,7 +16,6 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -34,9 +33,12 @@ use crate::utils::ec::AffinePoint; use crate::utils::ec::BaseLimbWidth; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurve; -use crate::utils::ec::WithAddition; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; +use crate::{ + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECAddEvent, +}; pub const fn num_weierstrass_add_cols() -> usize { size_of::>() @@ -147,8 +149,14 @@ impl WeierstrassAddAssignChip { } } -impl MachineAir +impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> for WeierstrassAddAssignChip { + type Events = &'a [ECAddEvent<::NB_LIMBS>]; +} + +impl MachineAir for WeierstrassAddAssignChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -158,17 +166,17 @@ impl M CurveType::Secp256k1 => "Secp256k1AddAssign".to_string(), CurveType::Bn254 => "Bn254AddAssign".to_string(), CurveType::Bls12381 => "Bls12381AddAssign".to_string(), - _ => panic!("Unsupported curve"), + _ => unreachable!("Unsupported curve"), } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the curve type. - let events = E::add_events(input); + let events = input.events(); let mut rows = Vec::new(); diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index c12c5b2cd..eff920ac7 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -17,7 +17,6 @@ use p3_maybe_rayon::prelude::ParallelIterator; use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -37,9 +36,12 @@ use crate::utils::ec::AffinePoint; use crate::utils::ec::BaseLimbWidth; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurve; -use crate::utils::ec::WithDoubling; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; +use crate::{ + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECDoubleEvent, +}; pub const fn num_weierstrass_double_cols() -> usize { size_of::>() @@ -171,8 +173,16 @@ impl WeierstrassDoubleAssignChip { } } -impl MachineAir +impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> + for WeierstrassDoubleAssignChip +{ + type Events = &'a [ECDoubleEvent<::NB_LIMBS>]; +} + +impl MachineAir for WeierstrassDoubleAssignChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -186,13 +196,13 @@ impl M } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the curve type. - let events = E::double_events(input); + let events = input.events(); let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1); diff --git a/core/src/utils/ec/edwards/ed25519.rs b/core/src/utils/ec/edwards/ed25519.rs index 3f11847a0..2fd5877bb 100644 --- a/core/src/utils/ec/edwards/ed25519.rs +++ b/core/src/utils/ec/edwards/ed25519.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::operations::field::params::{FieldParameters, FieldType, DEFAULT_NUM_LIMBS_T}; use crate::utils::ec::edwards::{EdwardsCurve, EdwardsParameters}; -use crate::utils::ec::{AffinePoint, CurveType, EllipticCurveParameters, WithAddition}; +use crate::utils::ec::{AffinePoint, CurveType, EllipticCurveParameters}; pub type Ed25519 = EdwardsCurve; @@ -39,15 +39,6 @@ impl EllipticCurveParameters for Ed25519Parameters { const CURVE_TYPE: CurveType = CurveType::Ed25519; } -impl WithAddition for Ed25519Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.ed_add_events - } -} - impl EdwardsParameters for Ed25519Parameters { const D: Array::NB_LIMBS> = Array([ 30883, 4953, 19914, 30187, 55467, 16705, 2637, 112, 59544, 30585, 16505, 36039, 65139, diff --git a/core/src/utils/ec/mod.rs b/core/src/utils/ec/mod.rs index 0ead9b6e8..ceaf4a92a 100644 --- a/core/src/utils/ec/mod.rs +++ b/core/src/utils/ec/mod.rs @@ -14,8 +14,6 @@ use crate::air::WORD_SIZE; use crate::operations::field::params::FieldParameters; use crate::operations::field::params::WORDS_CURVEPOINT; use crate::operations::field::params::WORDS_FIELD_ELEMENT; -use crate::runtime::ExecutionRecord; -use crate::syscall::precompiles::{ECAddEvent, ECDoubleEvent}; pub const DEFAULT_NUM_WORDS_FIELD_ELEMENT: usize = 8; pub const DEFAULT_NUM_BYTES_FIELD_ELEMENT: usize = DEFAULT_NUM_WORDS_FIELD_ELEMENT * WORD_SIZE; @@ -122,18 +120,6 @@ pub trait EllipticCurveParameters: const CURVE_TYPE: CurveType; } -pub trait WithAddition: EllipticCurveParameters { - fn add_events( - record: &ExecutionRecord, - ) -> &[ECAddEvent<::NB_LIMBS>]; -} - -pub trait WithDoubling: EllipticCurveParameters { - fn double_events( - record: &ExecutionRecord, - ) -> &[ECDoubleEvent<::NB_LIMBS>]; -} - /// An interface for elliptic curve groups. pub trait EllipticCurve: EllipticCurveParameters { /// Adds two different points on the curve. diff --git a/core/src/utils/ec/weierstrass/bls12_381.rs b/core/src/utils/ec/weierstrass/bls12_381.rs index 30689f2fb..4606edb6e 100644 --- a/core/src/utils/ec/weierstrass/bls12_381.rs +++ b/core/src/utils/ec/weierstrass/bls12_381.rs @@ -10,12 +10,6 @@ use std::ops::Neg; use super::{SwCurve, WeierstrassParameters}; use crate::operations::field::params::FieldParameters; use crate::operations::field::params::FieldType; -use crate::operations::field::params::WithFieldAddition; -use crate::operations::field::params::WithFieldMultiplication; -use crate::operations::field::params::WithFieldSubtraction; -use crate::operations::field::params::WithQuadFieldAddition; -use crate::operations::field::params::WithQuadFieldMultiplication; -use crate::operations::field::params::WithQuadFieldSubtraction; use crate::runtime::Syscall; use crate::stark::WeierstrassAddAssignChip; use crate::stark::WeierstrassDoubleAssignChip; @@ -23,8 +17,6 @@ use crate::syscall::precompiles::create_ec_add_event; use crate::syscall::precompiles::create_ec_double_event; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurveParameters; -use crate::utils::ec::WithAddition; -use crate::utils::ec::WithDoubling; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] /// Bls12381 curve parameter @@ -188,78 +180,11 @@ pub fn bls12381_double(p: &[BigUint; 4]) -> [BigUint; 4] { ] } -impl WithFieldAddition for Bls12381BaseField { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::add::FieldAddEvent] { - &record.bls12381_fp_add_events - } -} - -impl WithFieldSubtraction for Bls12381BaseField { - fn sub_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::sub::FieldSubEvent] { - &record.bls12381_fp_sub_events - } -} - -impl WithFieldMultiplication for Bls12381BaseField { - fn mul_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::mul::FieldMulEvent] { - &record.bls12381_fp_mul_events - } -} - -impl WithQuadFieldAddition for Bls12381BaseField { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::add::QuadFieldAddEvent] { - &record.bls12381_fp2_add_events - } -} - -impl WithQuadFieldSubtraction for Bls12381BaseField { - fn sub_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::sub::QuadFieldSubEvent] { - &record.bls12381_fp2_sub_events - } -} - -impl WithQuadFieldMultiplication for Bls12381BaseField { - fn mul_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::mul::QuadFieldMulEvent] { - &record.bls12381_fp2_mul_events - } -} - impl EllipticCurveParameters for Bls12381Parameters { type BaseField = Bls12381BaseField; const CURVE_TYPE: CurveType = CurveType::Bls12381; } -impl WithAddition for Bls12381Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.bls12381_g1_add_events - } -} - -impl WithDoubling for Bls12381Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.bls12381_g1_double_events - } -} - /// The WeierstrassParameters for BLS12-381 G1 impl WeierstrassParameters for Bls12381Parameters { const A: Array::NB_LIMBS> = Array([ diff --git a/core/src/utils/ec/weierstrass/bn254.rs b/core/src/utils/ec/weierstrass/bn254.rs index 4c414da86..3550da138 100644 --- a/core/src/utils/ec/weierstrass/bn254.rs +++ b/core/src/utils/ec/weierstrass/bn254.rs @@ -13,8 +13,6 @@ use crate::syscall::precompiles::create_ec_add_event; use crate::syscall::precompiles::create_ec_double_event; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurveParameters; -use crate::utils::ec::WithAddition; -use crate::utils::ec::WithDoubling; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] /// Bn254 curve parameter @@ -58,25 +56,6 @@ impl EllipticCurveParameters for Bn254Parameters { const CURVE_TYPE: CurveType = CurveType::Bn254; } -impl WithAddition for Bn254Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.bn254_add_events - } -} - -impl WithDoubling for Bn254Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.bn254_double_events - } -} - impl Syscall for WeierstrassAddAssignChip { fn execute( &self, diff --git a/core/src/utils/ec/weierstrass/mod.rs b/core/src/utils/ec/weierstrass/mod.rs index c86ee396c..78a620486 100644 --- a/core/src/utils/ec/weierstrass/mod.rs +++ b/core/src/utils/ec/weierstrass/mod.rs @@ -5,9 +5,7 @@ use serde::{Deserialize, Serialize}; use super::CurveType; use crate::operations::field::params::FieldParameters; use crate::utils::ec::utils::biguint_to_bits_le; -use crate::utils::ec::{ - AffinePoint, EllipticCurve, EllipticCurveParameters, WithAddition, WithDoubling, -}; +use crate::utils::ec::{AffinePoint, EllipticCurve, EllipticCurveParameters}; pub mod bls12_381; pub mod bn254; @@ -77,25 +75,6 @@ impl EllipticCurveParameters for SwCurve { const CURVE_TYPE: CurveType = E::CURVE_TYPE; } -impl WithAddition for SwCurve { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - E::add_events(record) - } -} - -impl WithDoubling for SwCurve { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - E::double_events(record) - } -} - impl EllipticCurve for SwCurve { fn ec_add(p: &AffinePoint, q: &AffinePoint) -> AffinePoint { p.sw_add(q) diff --git a/core/src/utils/ec/weierstrass/secp256k1.rs b/core/src/utils/ec/weierstrass/secp256k1.rs index c12dc14d4..3f3d9087d 100644 --- a/core/src/utils/ec/weierstrass/secp256k1.rs +++ b/core/src/utils/ec/weierstrass/secp256k1.rs @@ -19,7 +19,6 @@ use crate::{ runtime::Syscall, stark::{WeierstrassAddAssignChip, WeierstrassDoubleAssignChip}, syscall::precompiles::{create_ec_add_event, create_ec_double_event}, - utils::ec::{WithAddition, WithDoubling}, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -56,25 +55,6 @@ impl EllipticCurveParameters for Secp256k1Parameters { const CURVE_TYPE: CurveType = CurveType::Secp256k1; } -impl WithAddition for Secp256k1Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.secp256k1_add_events - } -} - -impl WithDoubling for Secp256k1Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.secp256k1_double_events - } -} - impl WeierstrassParameters for Secp256k1Parameters { const A: Array::NB_LIMBS> = Array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index e11636af2..d0a526adf 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -24,7 +24,7 @@ use crate::stark::StarkVerifyingKey; use crate::stark::Val; use crate::stark::VerifierConstraintFolder; use crate::stark::{Com, PcsProverData, RiscvAir, ShardProof, StarkProvingKey, UniConfig}; -use crate::stark::{MachineRecord, StarkMachine}; +use crate::stark::{Indexed, MachineRecord, StarkMachine}; use crate::utils::env; use crate::{ runtime::{Program, Runtime}, diff --git a/derive/src/lib.rs b/derive/src/lib.rs index a8f57b225..c1bff2797 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -85,6 +85,141 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { TokenStream::from(methods) } +/// Derives WithEvents for an enum which every variant has one field, +/// each of which implements WithEvents. +/// +/// The derived implementation is a tuple of the Events of each variant, +/// in the variant declaration order. That is, because the chip could be *any* variant, +/// it requires being able to provide for *all* event types consumable by each chip. +#[proc_macro_derive(WithEvents, attributes(sphinx_core_path))] +pub fn with_events_air_derive(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let type_params = generics.type_params(); + let const_params = generics.const_params(); + + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(_) | Data::Union(_) => unimplemented!("Only Enums are supported yet"), + Data::Enum(e) => { + let fields = e + .variants + .iter() + .map(|variant| { + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); + field + }) + .collect::>(); + + let field_ty_events = fields.iter().map(|field| { + let field_ty = &field.ty; + quote! { + <#field_ty as #sphinx_core_path::air::WithEvents<'a>>::Events + } + }); + quote!{ + impl <'a, #(#type_params),*, #(#const_params),*> #sphinx_core_path::air::WithEvents<'a> for #name #ty_generics #where_clause { + type Events = (#(#field_ty_events,)*); + } + }.into() + } + } +} + +fn get_type_from_attrs(attrs: &[syn::Attribute], attr_name: &str) -> syn::Result { + attrs + .iter() + .find(|attr| attr.path.is_ident(attr_name)) + .map_or_else( + || { + Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!("Could not find attribute {}", attr_name), + )) + }, + |attr| match attr.parse_meta()? { + syn::Meta::NameValue(meta) => { + if let syn::Lit::Str(lit) = &meta.lit { + Ok(lit.clone()) + } else { + Err(syn::Error::new_spanned( + meta, + &format!("Could not parse {} attribute", attr_name)[..], + )) + } + } + bad => Err(syn::Error::new_spanned( + bad, + &format!("Could not parse {} attribute", attr_name)[..], + )), + }, + ) +} + +/// Derives EventLens for an enum which every variant has one field, +/// each of which the input record implements EventLens for. +/// +/// The derived implementation is a delegation to the events of each variant, +/// in the variant declaration order. +/// +/// This makes the strong assumption that the `WithEvents` trait is implemented on `Self` +/// as a tuple of the `WithEvents` of each variant. This is what the `WithEvents` derive macro does. +#[proc_macro_derive(EventLens, attributes(sphinx_core_path, record_type))] +pub fn event_lens_air_derive(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let record_type = get_type_from_attrs(&ast.attrs, "record_type").unwrap(); + let rec_ty: syn::Type = record_type.parse().unwrap(); + + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(_) | Data::Union(_) => unimplemented!("Only Enums are supported yet"), + Data::Enum(e) => { + let fields = e + .variants + .iter() + .map(|variant| { + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); + field + }) + .collect::>(); + + let field_events = fields.iter().map(|field| { + let field_ty = &field.ty; + quote! { + #sphinx_core_path::air::EventLens::<#field_ty>::events(self) + } + }); + let res = quote! { + impl #impl_generics #sphinx_core_path::air::EventLens<#name #ty_generics> for #rec_ty { + fn events(&self) -> <#name #ty_generics as #sphinx_core_path::air::WithEvents>::Events { + (#(#field_events,)*) + } + } + }; + res.into() + } + } +} + #[proc_macro_derive( MachineAir, attributes(sphinx_core_path, execution_record_path, program_path, builder_path) @@ -94,11 +229,17 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; let generics = &ast.generics; + let type_params = generics.type_params(); + let ty_params = quote! { #(#type_params),* }; + let const_params = generics.const_params(); + let co_params = quote! { #(#const_params),* }; + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); let execution_record_path = find_execution_record_path(&ast.attrs); let program_path = find_program_path(&ast.attrs); let builder_path = find_builder_path(&ast.attrs); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let turbo_ty = ty_generics.as_turbofish(); match &ast.data { Data::Struct(_) => unimplemented!("Structs are not supported yet"), @@ -111,7 +252,10 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let mut fields = variant.fields.iter(); let field = fields.next().unwrap(); - assert!(fields.next().is_none(), "Only one field is supported"); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); (variant_name, field) }) .collect::>(); @@ -158,17 +302,29 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } }); - let generate_trace_arms = variants.iter().map(|(variant_name, field)| { + let generate_trace_arms = variants.iter().enumerate().map(|(i, (variant_name, field))| { let field_ty = &field.ty; + + let idx = syn::Index::from(i); quote! { - #name::#variant_name(x) => <#field_ty as #sphinx_core_path::air::MachineAir>::generate_trace(x, input, output) + #name::#variant_name(x) => { + fn f <'c, #ty_params, #co_params> (evs: <#name #ty_generics as #sphinx_core_path::air::WithEvents<'c>>::Events, _v: &'c ()) -> <#field_ty as #sphinx_core_path::air::WithEvents<'c>>::Events { evs.#idx } + + <#field_ty as #sphinx_core_path::air::MachineAir>::generate_trace(x, &#sphinx_core_path::air::Proj::new(input, f #turbo_ty), output) + } } }); - let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| { + let generate_dependencies_arms = variants.iter().enumerate().map(|(i, (variant_name, field))| { let field_ty = &field.ty; + let idx = syn::Index::from(i); + quote! { - #name::#variant_name(x) => <#field_ty as #sphinx_core_path::air::MachineAir>::generate_dependencies(x, input, output) + #name::#variant_name(x) => { + fn f <'c, #ty_params, #co_params> (evs: <#name #ty_generics as #sphinx_core_path::air::WithEvents<'c>>::Events, _v: &'c ()) -> <#field_ty as #sphinx_core_path::air::WithEvents<'c>>::Events { evs.#idx } + + <#field_ty as #sphinx_core_path::air::MachineAir>::generate_dependencies(x, &#sphinx_core_path::air::Proj::new(input, f #turbo_ty), output) + } } }); @@ -206,9 +362,9 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } } - fn generate_trace( + fn generate_trace>( &self, - input: &#execution_record_path, + input: &EL, output: &mut #execution_record_path, ) -> p3_matrix::dense::RowMajorMatrix { match self { @@ -216,9 +372,9 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } } - fn generate_dependencies( + fn generate_dependencies>( &self, - input: &#execution_record_path, + input: &EL, output: &mut #execution_record_path, ) { match self { diff --git a/examples/is-prime/program/Cargo.lock b/examples/is-prime/program/Cargo.lock index 1974a3c1c..792100ff2 100644 --- a/examples/is-prime/program/Cargo.lock +++ b/examples/is-prime/program/Cargo.lock @@ -234,7 +234,7 @@ dependencies = [ name = "is-prime-program" version = "0.1.0" dependencies = [ - "wp1-zkvm", + "sphinx-zkvm", ] [[package]] @@ -407,6 +407,36 @@ dependencies = [ "rand_core", ] +[[package]] +name = "sphinx-precompiles" +version = "0.1.0" +source = "git+ssh://git@github.com/lurk-lab/sphinx.git#eeea8c7319aa065b198119ddc4aa3acbd921d143" +dependencies = [ + "anyhow", + "bincode", + "bls12_381", + "cfg-if", + "getrandom", + "hybrid-array", + "k256", + "serde", +] + +[[package]] +name = "sphinx-zkvm" +version = "0.1.0" +source = "git+ssh://git@github.com/lurk-lab/sphinx.git#eeea8c7319aa065b198119ddc4aa3acbd921d143" +dependencies = [ + "bincode", + "cfg-if", + "getrandom", + "k256", + "once_cell", + "rand", + "sha2", + "sphinx-precompiles", +] + [[package]] name = "spki" version = "0.7.3" @@ -464,36 +494,6 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" -[[package]] -name = "wp1-precompiles" -version = "0.1.0" -source = "git+ssh://git@github.com/wormhole-foundation/wp1.git#e6718735bba932d7e4820f37bea5028459948e18" -dependencies = [ - "anyhow", - "bincode", - "bls12_381", - "cfg-if", - "getrandom", - "hybrid-array", - "k256", - "serde", -] - -[[package]] -name = "wp1-zkvm" -version = "0.1.0" -source = "git+ssh://git@github.com/wormhole-foundation/wp1.git#e6718735bba932d7e4820f37bea5028459948e18" -dependencies = [ - "bincode", - "cfg-if", - "getrandom", - "k256", - "once_cell", - "rand", - "sha2", - "wp1-precompiles", -] - [[package]] name = "wyz" version = "0.5.1" diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index b8467b8ce..e0781389e 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -101,7 +101,7 @@ where let chip_idx = machine .chips() .iter() - .rposition(|chip| &chip.name() == name) + .rposition(|chip| &chip.as_ref().name() == name) .unwrap(); let index = sorted_indices[chip_idx]; let opening = &opened_values.chips[index]; @@ -211,7 +211,7 @@ where for (i, sorted_chip) in sorted_chips.iter().enumerate() { for chip in machine.chips() { - if chip.name() == *sorted_chip { + if chip.as_ref().name() == *sorted_chip { let values = &opened_values.chips[i]; let trace_domain = &trace_domains[i]; let quotient_domain = "ient_domains[i]; @@ -294,7 +294,7 @@ pub fn build_wrap_circuit( let chips = outer_machine .shard_chips_ordered(&template_proof.chip_ordering) - .map(|chip| chip.name()) + .map(|chip| chip.as_ref().name()) .collect::>(); let sorted_indices = outer_machine @@ -303,7 +303,7 @@ pub fn build_wrap_circuit( .map(|chip| { template_proof .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .copied() .unwrap_or(usize::MAX) }) diff --git a/recursion/circuit/src/types.rs b/recursion/circuit/src/types.rs index 012a4bb05..d8ce02e7b 100644 --- a/recursion/circuit/src/types.rs +++ b/recursion/circuit/src/types.rs @@ -149,7 +149,7 @@ impl ChipOpening { local: vec![], next: vec![], }; - let preprocess_width = chip.preprocessed_width(); + let preprocess_width = chip.as_ref().preprocessed_width(); for i in 0..preprocess_width { preprocessed.local.push(opening.preprocessed.local[i]); preprocessed.next.push(opening.preprocessed.next[i]); diff --git a/recursion/core/src/cpu/trace.rs b/recursion/core/src/cpu/trace.rs index 8b80d8976..995969156 100644 --- a/recursion/core/src/cpu/trace.rs +++ b/recursion/core/src/cpu/trace.rs @@ -1,9 +1,9 @@ use std::borrow::BorrowMut; -use p3_field::{extension::BinomiallyExtendable, PrimeField32}; +use p3_field::{extension::BinomiallyExtendable, Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_core::{ - air::{BinomialExtension, MachineAir}, + air::{BinomialExtension, EventLens, MachineAir, WithEvents}, utils::pad_rows_fixed, }; use tracing::instrument; @@ -14,9 +14,16 @@ use crate::{ runtime::{ExecutionRecord, Opcode, RecursionProgram, D}, }; -use super::{CpuChip, CpuCols, CPU_COL_MAP, NUM_CPU_COLS}; +use super::{CpuChip, CpuCols, CpuEvent, CPU_COL_MAP, NUM_CPU_COLS}; -impl> MachineAir for CpuChip { +impl<'a, F: Field> WithEvents<'a> for CpuChip { + type Events = &'a [CpuEvent]; +} + +impl> MachineAir for CpuChip +where + ExecutionRecord: EventLens>, +{ type Record = ExecutionRecord; type Program = RecursionProgram; @@ -24,18 +31,18 @@ impl> MachineAir for CpuChip { "CPU".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // There are no dependencies, since we do it all in the runtime. This is just a placeholder. } - #[instrument(name = "generate cpu trace", level = "debug", skip_all, fields(rows = input.cpu_events.len()))] - fn generate_trace( + #[instrument(name = "generate cpu trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = input - .cpu_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_CPU_COLS]; @@ -124,7 +131,7 @@ impl> MachineAir for CpuChip { let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_CPU_COLS); - for i in input.cpu_events.len()..trace.height() { + for i in input.events().len()..trace.height() { trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.clk] = F::from_canonical_u32(4) * F::from_canonical_usize(i); trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.instruction.imm_b] = diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index 943ba8d14..3f163399c 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -6,14 +6,17 @@ use crate::runtime::Opcode; use core::borrow::Borrow; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; use p3_field::PrimeField32; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir}; +use sphinx_core::air::{ + BaseAirBuilder, BinomialExtension, EventLens, ExtensionAirBuilder, MachineAir, WithEvents, +}; use sphinx_core::utils::pad_rows_fixed; use sphinx_derive::AlignedBorrow; use std::borrow::BorrowMut; +use std::marker::PhantomData; use tracing::instrument; use crate::air::SphinxRecursionAirBuilder; @@ -23,8 +26,9 @@ use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::>(); #[derive(Default)] -pub struct FriFoldChip { +pub struct FriFoldChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } #[derive(Debug, Clone)] @@ -83,13 +87,17 @@ pub struct FriFoldCols { pub is_real: T, } -impl BaseAir for FriFoldChip { +impl BaseAir for FriFoldChip { fn width(&self) -> usize { NUM_FRI_FOLD_COLS } } -impl MachineAir for FriFoldChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for FriFoldChip { + type Events = &'a [FriFoldEvent]; +} + +impl MachineAir for FriFoldChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -98,18 +106,18 @@ impl MachineAir for FriFoldChip "FriFold".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.fri_fold_events.len()))] - fn generate_trace( + #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = input - .fri_fold_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_FRI_FOLD_COLS]; @@ -167,7 +175,7 @@ impl MachineAir for FriFoldChip } } -impl FriFoldChip { +impl FriFoldChip { pub fn eval_fri_fold( &self, builder: &mut AB, @@ -368,7 +376,7 @@ impl FriFoldChip { } } -impl Air for FriFoldChip +impl Air for FriFoldChip where AB: SphinxRecursionAirBuilder, { diff --git a/recursion/core/src/memory/air.rs b/recursion/core/src/memory/air.rs index b5d968653..bd39961bf 100644 --- a/recursion/core/src/memory/air.rs +++ b/recursion/core/src/memory/air.rs @@ -1,32 +1,45 @@ use core::mem::size_of; -use std::borrow::{Borrow, BorrowMut}; +use std::{ + borrow::{Borrow, BorrowMut}, + marker::PhantomData, +}; use p3_air::{Air, BaseAir}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_core::{ - air::{AirInteraction, MachineAir, MemoryAirBuilder}, + air::{AirInteraction, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, lookup::InteractionKind, utils::pad_rows_fixed, }; use tracing::instrument; use super::columns::MemoryInitCols; -use crate::memory::MemoryGlobalChip; use crate::runtime::{ExecutionRecord, RecursionProgram}; +use crate::{air::Block, memory::MemoryGlobalChip}; pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::>(); #[allow(dead_code)] -impl MemoryGlobalChip { +impl MemoryGlobalChip { pub fn new() -> Self { Self { fixed_log2_rows: None, + _phantom: PhantomData, } } } -impl MachineAir for MemoryGlobalChip { +impl<'a, F: Field> WithEvents<'a> for MemoryGlobalChip { + type Events = ( + // first memory event + &'a [(F, Block)], + // last memory event + &'a [(F, F, Block)], + ); +} + +impl MachineAir for MemoryGlobalChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -34,22 +47,22 @@ impl MachineAir for MemoryGlobalChip { "MemoryGlobalChip".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate memory trace", level = "debug", skip_all, fields(first_rows = input.first_memory_record.len(), last_rows = input.last_memory_record.len()))] - fn generate_trace( + #[instrument(name = "generate memory trace", level = "debug", skip_all, fields(first_rows = input.events().0.len(), last_rows = input.events().1.len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); + let (first_memory_events, last_memory_events) = input.events(); // Fill in the initial memory records. rows.extend( - input - .first_memory_record + first_memory_events .iter() .map(|(addr, value)| { let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; @@ -65,8 +78,7 @@ impl MachineAir for MemoryGlobalChip { // Fill in the finalize memory records. rows.extend( - input - .last_memory_record + last_memory_events .iter() .map(|(addr, timestamp, value)| { let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; @@ -98,13 +110,13 @@ impl MachineAir for MemoryGlobalChip { } } -impl BaseAir for MemoryGlobalChip { +impl BaseAir for MemoryGlobalChip { fn width(&self) -> usize { NUM_MEMORY_INIT_COLS } } -impl Air for MemoryGlobalChip +impl Air for MemoryGlobalChip where AB: MemoryAirBuilder, { diff --git a/recursion/core/src/memory/mod.rs b/recursion/core/src/memory/mod.rs index 834de2e7a..7d97cc912 100644 --- a/recursion/core/src/memory/mod.rs +++ b/recursion/core/src/memory/mod.rs @@ -1,6 +1,8 @@ mod air; mod columns; +use std::marker::PhantomData; + use p3_field::PrimeField32; use crate::air::Block; @@ -109,6 +111,7 @@ impl MemoryAccessCols { } #[derive(Default)] -pub struct MemoryGlobalChip { +pub struct MemoryGlobalChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 1b548ef4e..132ba0229 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -1,11 +1,12 @@ use std::borrow::{Borrow, BorrowMut}; +use std::marker::PhantomData; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, MachineAir}; +use sphinx_core::air::{BaseAirBuilder, EventLens, MachineAir, Proj, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_derive::AlignedBorrow; @@ -17,8 +18,9 @@ use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_MULTI_COLS: usize = core::mem::size_of::>(); #[derive(Default)] -pub struct MultiChip { +pub struct MultiChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } #[derive(AlignedBorrow, Clone, Copy)] @@ -42,13 +44,20 @@ pub union InstructionSpecificCols { poseidon2: Poseidon2Cols, } -impl BaseAir for MultiChip { +impl BaseAir for MultiChip { fn width(&self) -> usize { NUM_MULTI_COLS } } -impl MachineAir for MultiChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for MultiChip { + type Events = ( + as WithEvents<'a>>::Events, + as WithEvents<'a>>::Events, + ); +} + +impl MachineAir for MultiChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -57,19 +66,36 @@ impl MachineAir for MultiChip { "Multi".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let fri_fold_chip = FriFoldChip::<3>::default(); + let fri_fold_chip = FriFoldChip::::default(); let poseidon2 = Poseidon2Chip::default(); - let fri_fold_trace = fri_fold_chip.generate_trace(input, output); - let mut poseidon2_trace = poseidon2.generate_trace(input, output); + + fn to_fri<'c, F: PrimeField32, const DEGREE: usize>( + evs: as WithEvents<'c>>::Events, + _v: &'c (), + ) -> as WithEvents<'c>>::Events { + evs.0 + } + + fn to_poseidon<'c, F: PrimeField32, const DEGREE: usize>( + evs: as WithEvents<'c>>::Events, + _v: &'c (), + ) -> as WithEvents<'c>>::Events { + evs.1 + } + + let fri_fold_trace = + fri_fold_chip.generate_trace(&Proj::new(input, to_fri::), output); + let mut poseidon2_trace = + poseidon2.generate_trace(&Proj::new(input, to_poseidon::), output); let mut rows = fri_fold_trace .clone() @@ -85,15 +111,17 @@ impl MachineAir for MultiChip { let fri_fold_cols = *cols.fri_fold(); cols.fri_fold_receive_table = - FriFoldChip::<3>::do_receive_table(&fri_fold_cols); + FriFoldChip::::do_receive_table(&fri_fold_cols); cols.fri_fold_memory_access = - FriFoldChip::<3>::do_memory_access(&fri_fold_cols); + FriFoldChip::::do_memory_access(&fri_fold_cols); } else { cols.is_poseidon2 = F::one(); let poseidon2_cols = *cols.poseidon2(); - cols.poseidon2_receive_table = Poseidon2Chip::do_receive_table(&poseidon2_cols); - cols.poseidon2_memory_access = Poseidon2Chip::do_memory_access(&poseidon2_cols); + cols.poseidon2_receive_table = + Poseidon2Chip::::do_receive_table(&poseidon2_cols); + cols.poseidon2_memory_access = + Poseidon2Chip::::do_memory_access(&poseidon2_cols); } row }) @@ -115,7 +143,7 @@ impl MachineAir for MultiChip { } } -impl Air for MultiChip +impl Air for MultiChip where AB: SphinxRecursionAirBuilder, { @@ -157,15 +185,17 @@ where let fri_columns_local = local.fri_fold(); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_memory_access::(fri_columns_local), + local.is_fri_fold + * FriFoldChip::::do_memory_access::(fri_columns_local), local.fri_fold_memory_access, ); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_receive_table::(fri_columns_local), + local.is_fri_fold + * FriFoldChip::::do_receive_table::(fri_columns_local), local.fri_fold_receive_table, ); - let fri_fold_chip = FriFoldChip::<3>::default(); + let fri_fold_chip = FriFoldChip::::default(); fri_fold_chip.eval_fri_fold( &mut sub_builder, local.fri_fold(), @@ -182,16 +212,17 @@ where let poseidon2_columns = local.poseidon2(); sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::do_receive_table::(poseidon2_columns), + local.is_poseidon2 + * Poseidon2Chip::::do_receive_table::(poseidon2_columns), local.poseidon2_receive_table, ); sub_builder.assert_eq( local.is_poseidon2 - * Poseidon2Chip::do_memory_access::(poseidon2_columns), + * Poseidon2Chip::::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); - let poseidon2_chip = Poseidon2Chip::default(); + let poseidon2_chip = Poseidon2Chip::::default(); poseidon2_chip.eval_poseidon2( &mut sub_builder, local.poseidon2(), @@ -215,6 +246,7 @@ impl MultiCols { #[cfg(test)] mod tests { use itertools::Itertools; + use std::marker::PhantomData; use std::time::Instant; use p3_baby_bear::BabyBear; @@ -239,8 +271,9 @@ mod tests { let config = BabyBearPoseidon2::compressed(); let mut challenger = config.challenger(); - let chip = MultiChip::<5> { + let chip = MultiChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let test_inputs = (0..16) diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index 98812828e..e431c0115 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -2,10 +2,11 @@ use core::borrow::Borrow; use core::mem::size_of; use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; use sphinx_core::air::{BaseAirBuilder, ExtensionAirBuilder}; use sphinx_primitives::RC_16_30_U32; +use std::marker::PhantomData; use std::ops::Add; use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; @@ -23,17 +24,18 @@ pub(crate) const WIDTH: usize = 16; /// A chip that implements addition for the opcode ADD. #[derive(Default)] -pub struct Poseidon2Chip { +pub struct Poseidon2Chip { pub fixed_log2_rows: Option, + _phantom: PhantomData, } -impl BaseAir for Poseidon2Chip { +impl BaseAir for Poseidon2Chip { fn width(&self) -> usize { NUM_POSEIDON2_COLS } } -impl Poseidon2Chip { +impl Poseidon2Chip { pub fn eval_poseidon2( &self, builder: &mut AB, @@ -318,7 +320,7 @@ impl Poseidon2Chip { } } -impl Air for Poseidon2Chip +impl Air for Poseidon2Chip where AB: BaseAirBuilder, { @@ -343,6 +345,7 @@ where mod tests { use itertools::Itertools; use std::borrow::Borrow; + use std::marker::PhantomData; use std::time::Instant; use zkhash::ark_ff::UniformRand; @@ -372,6 +375,7 @@ mod tests { fn generate_trace() { let chip = Poseidon2Chip { fixed_log2_rows: None, + _phantom: PhantomData, }; let rng = &mut rand::thread_rng(); @@ -421,6 +425,7 @@ mod tests { let chip = Poseidon2Chip { fixed_log2_rows: None, + _phantom: PhantomData, }; let trace: RowMajorMatrix = chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index 8030f6bbc..b872be2ae 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -1,8 +1,11 @@ use std::borrow::BorrowMut; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use sphinx_core::{air::MachineAir, utils::pad_rows_fixed}; +use sphinx_core::{ + air::{EventLens, MachineAir, WithEvents}, + utils::pad_rows_fixed, +}; use sphinx_primitives::RC_16_30_U32; use tracing::instrument; @@ -13,10 +16,14 @@ use crate::{ use super::{ external::{NUM_POSEIDON2_COLS, WIDTH}, - Poseidon2Chip, Poseidon2Cols, + Poseidon2Chip, Poseidon2Cols, Poseidon2Event, }; -impl MachineAir for Poseidon2Chip { +impl<'a, F: Field> WithEvents<'a> for Poseidon2Chip { + type Events = &'a [Poseidon2Event]; +} + +impl MachineAir for Poseidon2Chip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -25,14 +32,14 @@ impl MachineAir for Poseidon2Chip { "Poseidon2".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( + #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); @@ -44,7 +51,7 @@ impl MachineAir for Poseidon2Chip { let rounds_p_beginning = 2 + rounds_f / 2; let p_end = rounds_p_beginning + rounds_p; - for poseidon2_event in input.poseidon2_events.iter() { + for poseidon2_event in input.events().iter() { let mut round_input = Default::default(); for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs index cbcedb28f..d0ffc30b1 100644 --- a/recursion/core/src/poseidon2_wide/external.rs +++ b/recursion/core/src/poseidon2_wide/external.rs @@ -1,3 +1,4 @@ +use crate::poseidon2::Poseidon2Event; use crate::poseidon2_wide::columns::{ Poseidon2ColType, Poseidon2ColTypeMut, Poseidon2Cols, Poseidon2SBoxCols, NUM_POSEIDON2_COLS, NUM_POSEIDON2_SBOX_COLS, @@ -5,13 +6,14 @@ use crate::poseidon2_wide::columns::{ use crate::runtime::Opcode; use core::borrow::Borrow; use p3_air::{Air, BaseAir}; -use p3_field::{AbstractField, PrimeField32}; +use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, MachineAir}; +use sphinx_core::air::{BaseAirBuilder, EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_primitives::RC_16_30_U32; use std::borrow::BorrowMut; +use std::marker::PhantomData; use tracing::instrument; use crate::air::SphinxRecursionAirBuilder; @@ -31,11 +33,16 @@ pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; /// A chip that implements addition for the opcode ADD. #[derive(Default)] -pub struct Poseidon2WideChip { +pub struct Poseidon2WideChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } -impl MachineAir for Poseidon2WideChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for Poseidon2WideChip { + type Events = &'a [Poseidon2Event]; +} + +impl MachineAir for Poseidon2WideChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -44,14 +51,14 @@ impl MachineAir for Poseidon2WideChip>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( + #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); @@ -60,7 +67,7 @@ impl MachineAir for Poseidon2WideChip>::width(self); - for event in &input.poseidon2_events { + for event in input.events() { let mut row = Vec::new(); row.resize(num_columns, F::zero()); @@ -333,7 +340,7 @@ fn eval_internal_rounds( } } -impl BaseAir for Poseidon2WideChip { +impl BaseAir for Poseidon2WideChip { fn width(&self) -> usize { match DEGREE { d if d < 7 => NUM_POSEIDON2_SBOX_COLS, @@ -381,7 +388,7 @@ fn eval_mem(builder: &mut AB, local: &Poseidon2Me ); } -impl Air for Poseidon2WideChip +impl Air for Poseidon2WideChip where AB: SphinxRecursionAirBuilder, { @@ -445,6 +452,7 @@ where #[cfg(test)] mod tests { + use std::marker::PhantomData; use std::time::Instant; use crate::poseidon2::Poseidon2Event; @@ -463,8 +471,9 @@ mod tests { /// A test generating a trace for a single permutation that checks that the output is correct fn generate_trace_degree() { - let chip = Poseidon2WideChip:: { + let chip = Poseidon2WideChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let test_inputs = vec![ @@ -509,8 +518,9 @@ mod tests { inputs: Vec<[BabyBear; 16]>, outputs: Vec<[BabyBear; 16]>, ) { - let chip = Poseidon2WideChip:: { + let chip = Poseidon2WideChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let mut input_exec = ExecutionRecord::::default(); for (input, output) in inputs.into_iter().zip_eq(outputs) { diff --git a/recursion/core/src/program/mod.rs b/recursion/core/src/program/mod.rs index 4bfb43d2f..c2ff61158 100644 --- a/recursion/core/src/program/mod.rs +++ b/recursion/core/src/program/mod.rs @@ -1,13 +1,15 @@ use crate::air::SphinxRecursionAirBuilder; +use crate::cpu::{CpuEvent, Instruction}; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; use p3_air::{Air, BaseAir, PairBuilder}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::MachineAir; +use sphinx_core::air::{EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use std::collections::HashMap; +use std::marker::PhantomData; use tracing::instrument; use sphinx_derive::AlignedBorrow; @@ -38,15 +40,24 @@ pub struct ProgramMultiplicityCols { /// A chip that implements addition for the opcodes ADD and ADDI. #[derive(Default)] -pub struct ProgramChip; +pub struct ProgramChip(pub PhantomData); -impl ProgramChip { +impl ProgramChip { pub fn new() -> Self { - Self {} + Self(PhantomData) } } -impl MachineAir for ProgramChip { +impl<'a, F: Field> WithEvents<'a> for ProgramChip { + type Events = ( + // program.instructions + &'a [Instruction], + // cpu_events + &'a [CpuEvent], + ); +} + +impl MachineAir for ProgramChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -92,20 +103,20 @@ impl MachineAir for ProgramChip { )) } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate program trace", level = "debug", skip_all, fields(rows = input.program.instructions.len()))] - fn generate_trace( + #[instrument(name = "generate program trace", level = "debug", skip_all, fields(rows = input.events().0.len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); - input.cpu_events.iter().for_each(|event| { + input.events().1.iter().for_each(|event| { let pc = event.pc; instruction_counts .entry(pc.as_canonical_u32()) @@ -115,9 +126,9 @@ impl MachineAir for ProgramChip { let max_program_size = match std::env::var("MAX_RECURSION_PROGRAM_SIZE") { Ok(value) => value.parse().unwrap(), - Err(_) => std::cmp::min(1048576, input.program.instructions.len()), + Err(_) => std::cmp::min(1048576, input.events().0.len()), }; - let mut rows = input.program.instructions[0..max_program_size] + let mut rows = input.events().0[0..max_program_size] .iter() .enumerate() .map(|(i, _)| { @@ -145,13 +156,13 @@ impl MachineAir for ProgramChip { } } -impl BaseAir for ProgramChip { +impl BaseAir for ProgramChip { fn width(&self) -> usize { NUM_PROGRAM_MULT_COLS } } -impl Air for ProgramChip +impl Air for ProgramChip where AB: SphinxRecursionAirBuilder + PairBuilder, { diff --git a/recursion/core/src/range_check/trace.rs b/recursion/core/src/range_check/trace.rs index c5ee4c715..21005d212 100644 --- a/recursion/core/src/range_check/trace.rs +++ b/recursion/core/src/range_check/trace.rs @@ -1,17 +1,21 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, collections::BTreeMap}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use sphinx_core::air::MachineAir; +use sphinx_core::air::{EventLens, MachineAir, WithEvents}; use super::{ columns::{RangeCheckMultCols, NUM_RANGE_CHECK_MULT_COLS, NUM_RANGE_CHECK_PREPROCESSED_COLS}, - RangeCheckChip, + RangeCheckChip, RangeCheckEvent, }; use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_ROWS: usize = 1 << 16; +impl<'a, F: Field> WithEvents<'a> for RangeCheckChip { + type Events = &'a BTreeMap; +} + impl MachineAir for RangeCheckChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -30,13 +34,13 @@ impl MachineAir for RangeCheckChip { Some(trace) } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let (_, event_map) = Self::trace_and_map(); @@ -46,7 +50,7 @@ impl MachineAir for RangeCheckChip { NUM_RANGE_CHECK_MULT_COLS, ); - for (lookup, mult) in input.range_check_events.iter() { + for (lookup, mult) in input.events().iter() { let (row, index) = event_map[lookup]; let cols: &mut RangeCheckMultCols = trace.row_mut(row).borrow_mut(); diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 7509a4073..88149bbac 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -2,14 +2,19 @@ use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use p3_field::{AbstractField, PrimeField32}; -use sphinx_core::stark::{MachineRecord, PROOF_MAX_NUM_PVS}; +use sphinx_core::air::EventLens; +use sphinx_core::stark::{Indexed, MachineRecord, PROOF_MAX_NUM_PVS}; use super::RecursionProgram; use crate::air::Block; -use crate::cpu::CpuEvent; -use crate::fri_fold::FriFoldEvent; -use crate::poseidon2::Poseidon2Event; -use crate::range_check::RangeCheckEvent; +use crate::cpu::{CpuChip, CpuEvent}; +use crate::fri_fold::{FriFoldChip, FriFoldEvent}; +use crate::memory::MemoryGlobalChip; +use crate::multi::MultiChip; +use crate::poseidon2::{Poseidon2Chip, Poseidon2Event}; +use crate::poseidon2_wide::Poseidon2WideChip; +use crate::program::ProgramChip; +use crate::range_check::{RangeCheckChip, RangeCheckEvent}; #[derive(Default, Debug, Clone)] pub struct ExecutionRecord { @@ -37,12 +42,14 @@ impl ExecutionRecord { } } -impl MachineRecord for ExecutionRecord { - type Config = (); - +impl Indexed for ExecutionRecord { fn index(&self) -> u32 { 0 } +} + +impl MachineRecord for ExecutionRecord { + type Config = (); fn set_index(&mut self, _: u32) {} @@ -92,3 +99,58 @@ impl MachineRecord for ExecutionRecord { ret } } + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + &self.cpu_events + } +} + +impl EventLens> + for ExecutionRecord +{ + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + &self.fri_fold_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + &self.poseidon2_events + } +} + +impl EventLens> + for ExecutionRecord +{ + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + &self.poseidon2_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + (&self.first_memory_record, &self.last_memory_record) + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + (&self.program.instructions, &self.cpu_events) + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + &self.range_check_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + ( + >>::events(self), + >>::events(self), + ) + } +} diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 4ab9b3de4..b658c6a71 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -4,7 +4,7 @@ pub mod utils; use p3_field::{extension::BinomiallyExtendable, PrimeField32}; use sphinx_core::stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS}; -use sphinx_derive::MachineAir; +use sphinx_derive::{EventLens, MachineAir, WithEvents}; use crate::runtime::D; use crate::{ @@ -18,20 +18,21 @@ use std::marker::PhantomData; pub type RecursionAirWideDeg3 = RecursionAir; pub type RecursionAirSkinnyDeg7 = RecursionAir; -#[derive(MachineAir)] +#[derive(WithEvents, EventLens, MachineAir)] #[sphinx_core_path = "sphinx_core"] #[execution_record_path = "crate::runtime::ExecutionRecord"] #[program_path = "crate::runtime::RecursionProgram"] #[builder_path = "crate::air::SphinxRecursionAirBuilder"] +#[record_type = "crate::runtime::ExecutionRecord"] pub enum RecursionAir, const DEGREE: usize> { - Program(ProgramChip), + Program(ProgramChip), Cpu(CpuChip), - MemoryGlobal(MemoryGlobalChip), - Poseidon2Wide(Poseidon2WideChip), - Poseidon2Skinny(Poseidon2Chip), - FriFold(FriFoldChip), + MemoryGlobal(MemoryGlobalChip), + Poseidon2Wide(Poseidon2WideChip), + Poseidon2Skinny(Poseidon2Chip), + FriFold(FriFoldChip), RangeCheck(RangeCheckChip), - Multi(MultiChip), + Multi(MultiChip), } impl, const DEGREE: usize> RecursionAir { @@ -54,37 +55,43 @@ impl, const DEGREE: usize> RecursionAi } pub fn get_all() -> Vec { - once(RecursionAir::Program(ProgramChip)) + once(RecursionAir::Program(ProgramChip(PhantomData))) .chain(once(RecursionAir::Cpu(CpuChip { fixed_log2_rows: None, _phantom: PhantomData, }))) .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: None, + _phantom: PhantomData, }))) .chain(once(RecursionAir::Poseidon2Wide(Poseidon2WideChip::< + F, DEGREE, > { fixed_log2_rows: None, + _phantom: PhantomData, }))) - .chain(once(RecursionAir::FriFold(FriFoldChip:: { + .chain(once(RecursionAir::FriFold(FriFoldChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() } pub fn get_wrap_all() -> Vec { - once(RecursionAir::Program(ProgramChip)) + once(RecursionAir::Program(ProgramChip(PhantomData))) .chain(once(RecursionAir::Cpu(CpuChip { fixed_log2_rows: Some(20), _phantom: PhantomData, }))) .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: Some(20), + _phantom: PhantomData, }))) .chain(once(RecursionAir::Multi(MultiChip { fixed_log2_rows: Some(20), + _phantom: PhantomData, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index a33e60825..a87200a64 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -346,11 +346,11 @@ where // TODO CONSTRAIN: that the preprocessed chips get called with verify_constraints. builder.cycle_tracker("stage-e-verify-constraints"); for (i, chip) in machine.chips().iter().enumerate() { - let chip_name = chip.name(); + let chip_name = chip.as_ref().name(); tracing::debug!("verifying constraints for chip: {}", chip_name); let index = builder.get(&proof.sorted_idxs, i); - if chip.preprocessed_width() > 0 { + if chip.as_ref().preprocessed_width() > 0 { builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } diff --git a/recursion/program/src/types.rs b/recursion/program/src/types.rs index 5380c1c74..5dcae7a0b 100644 --- a/recursion/program/src/types.rs +++ b/recursion/program/src/types.rs @@ -124,7 +124,7 @@ impl ChipOpening { local: vec![], next: vec![], }; - let preprocessed_width = chip.preprocessed_width(); + let preprocessed_width = chip.as_ref().preprocessed_width(); // Assert that the length of the dynamic arrays match the expected length of the vectors. builder.assert_usize_eq(preprocessed_width, opening.preprocessed.local.len()); builder.assert_usize_eq(preprocessed_width, opening.preprocessed.next.len()); diff --git a/recursion/program/src/utils.rs b/recursion/program/src/utils.rs index f7fa673f1..46f74b27d 100644 --- a/recursion/program/src/utils.rs +++ b/recursion/program/src/utils.rs @@ -225,7 +225,7 @@ pub(crate) fn get_preprocessed_data