From a42d75663f12d482bae25954f18c7ebdc506483b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Wed, 29 May 2024 21:18:58 -0400 Subject: [PATCH 1/4] refactor: Implement AsRef in chip accesses across modules This avoids relying on an instance of `MachineAir` on Chip. - Modified the way `chip` object is accessed across different files in the core and recursion directories by using the `as_ref()` method. - Implemented `AsRef` trait for `Chip` struct for returning an Air reference, and removed `MachineAir` for `Chip` implementation. - Changed the method of accessing functions like `chip.name()`, `generate_trace()`, `preprocessed_width()` through `as_ref()` on `chip`. - Updated error handling in the `Verifier` struct in `stark/verifier.rs` to use `as_ref()`. - Made changes in the recursion program to access `preprocessed_data` using the `as_ref()` method. - Updated the reference of `chip` object in several parts of the `prove_shard` function in `stark/prover.rs`. - Revised accessing `chip` methods with the use of `as_ref()` across different functions in `stark/machine.rs`. --- core/src/lookup/debug.rs | 10 ++++----- core/src/stark/chip.rs | 40 +++++----------------------------- core/src/stark/debug.rs | 2 +- core/src/stark/machine.rs | 28 ++++++++++++------------ core/src/stark/prover.rs | 12 +++++----- core/src/stark/verifier.rs | 12 +++++----- recursion/circuit/src/stark.rs | 8 +++---- recursion/circuit/src/types.rs | 2 +- recursion/program/src/stark.rs | 4 ++-- recursion/program/src/types.rs | 2 +- recursion/program/src/utils.rs | 2 +- 11 files changed, 47 insertions(+), 75 deletions(-) diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index 5f0b83821..fa6108437 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -56,11 +56,11 @@ pub fn debug_interactions>>( let mut key_to_vec_data = BTreeMap::new(); let mut key_to_count = BTreeMap::new(); - let trace = chip.generate_trace(record, &mut A::Record::default()); + let trace = chip.as_ref().generate_trace(record, &mut A::Record::default()); let mut pre_traces = pkey.traces.clone(); let mut preprocessed_trace = pkey .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|&index| pre_traces.get_mut(index).unwrap()); let mut main = trace.clone(); let height = trace.clone().height(); @@ -102,7 +102,7 @@ pub fn debug_interactions>>( .entry(key.clone()) .or_insert_with(Vec::new) .push(InteractionData { - chip_name: chip.name(), + chip_name: chip.as_ref().name(), kind: interaction.kind, row, interaction_number: m, @@ -150,10 +150,10 @@ where .or_insert((SC::Val::zero(), BTreeMap::new())); entry.0 += *value; total += *value; - *entry.1.entry(chip.name()).or_insert(SC::Val::zero()) += *value; + *entry.1.entry(chip.as_ref().name()).or_insert(SC::Val::zero()) += *value; } } - tracing::info!("{} chip has {} distinct events", chip.name(), total_events); + tracing::info!("{} chip has {} distinct events", chip.as_ref().name(), total_events); } tracing::info!("Final counts below."); diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index 39349b395..8805f5e5e 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -28,6 +28,12 @@ pub struct Chip { log_quotient_degree: usize, } +impl AsRef for Chip { + fn as_ref(&self) -> &A { + &self.air + } +} + impl Chip { /// The send interactions of the chip. pub fn sends(&self) -> &[Interaction] { @@ -151,40 +157,6 @@ where } } -impl MachineAir for Chip -where - F: Field, - A: MachineAir, -{ - type Record = A::Record; - - type Program = A::Program; - - fn name(&self) -> String { - self.air.name() - } - - fn preprocessed_width(&self) -> usize { - >::preprocessed_width(&self.air) - } - - fn generate_preprocessed_trace(&self, program: &A::Program) -> Option> { - >::generate_preprocessed_trace(&self.air, program) - } - - fn generate_trace(&self, input: &A::Record, output: &mut A::Record) -> RowMajorMatrix { - self.air.generate_trace(input, output) - } - - fn generate_dependencies(&self, input: &A::Record, output: &mut A::Record) { - self.air.generate_dependencies(input, output) - } - - fn included(&self, shard: &Self::Record) -> bool { - self.air.included(shard) - } -} - // Implement AIR directly on Chip, evaluating both execution and permutation constraints. impl Air for Chip where diff --git a/core/src/stark/debug.rs b/core/src/stark/debug.rs index e45cc48bf..e7aa51797 100644 --- a/core/src/stark/debug.rs +++ b/core/src/stark/debug.rs @@ -101,7 +101,7 @@ pub fn debug_constraints( if result.is_err() { eprintln!("local: {:?}", main_local); eprintln!("next: {:?}", main_next); - eprintln!("failed at row {} of chip {}", i, chip.name()); + eprintln!("failed at row {} of chip {}", i, chip.as_ref().name()); exit(1); } }); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 38554a685..a9ae34d40 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -120,7 +120,7 @@ impl>> StarkMachine { self.chips .iter() .enumerate() - .filter(|(_, chip)| chip.preprocessed_width() > 0) + .filter(|(_, chip)| chip.as_ref().preprocessed_width() > 0) .map(|(i, _)| i) .collect() } @@ -132,7 +132,7 @@ impl>> StarkMachine { where 'a: 'b, { - self.chips.iter().filter(|chip| chip.included(shard)) + self.chips.iter().filter(|chip| chip.as_ref().included(shard)) } pub fn shard_chips_ordered<'a, 'b>( @@ -144,14 +144,14 @@ impl>> StarkMachine { { self.chips .iter() - .filter(|chip| chip_ordering.contains_key(&chip.name())) - .sorted_by_key(|chip| chip_ordering.get(&chip.name())) + .filter(|chip| chip_ordering.contains_key(&chip.as_ref().name())) + .sorted_by_key(|chip| chip_ordering.get(&chip.as_ref().name())) } pub fn chips_sorted_indices(&self, proof: &ShardProof) -> Vec> { self.chips() .iter() - .map(|chip| proof.chip_ordering.get(&chip.name()).cloned()) + .map(|chip| proof.chip_ordering.get(&chip.as_ref().name()).cloned()) .collect() } @@ -166,17 +166,17 @@ impl>> StarkMachine { self.chips() .iter() .map(|chip| { - let prep_trace = chip.generate_preprocessed_trace(program); + let prep_trace = chip.as_ref().generate_preprocessed_trace(program); // Assert that the chip width data is correct. let expected_width = prep_trace.as_ref().map_or(0, |t| t.width()); assert_eq!( expected_width, - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), "Incorrect number of preprocessed columns for chip {}", - chip.name() + chip.as_ref().name() ); - (chip.name(), prep_trace) + (chip.as_ref().name(), prep_trace) }) .filter(|(_, prep_trace)| prep_trace.is_some()) .map(|(name, prep_trace)| { @@ -251,7 +251,7 @@ impl>> StarkMachine { chips.iter().for_each(|chip| { let mut output = A::Record::default(); output.set_index(record.index()); - chip.generate_dependencies(&record, &mut output); + chip.as_ref().generate_dependencies(&record, &mut output); record.append(&mut output); }) }); @@ -375,13 +375,13 @@ impl>> StarkMachine { .iter() .map(|chip| { pk.chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|index| &pk.traces[*index]) }) .collect::>(); let mut traces = chips .par_iter() - .map(|chip| chip.generate_trace(shard, &mut A::Record::default())) + .map(|chip| chip.as_ref().generate_trace(shard, &mut A::Record::default())) .zip(pre_traces) .collect::>(); @@ -424,7 +424,7 @@ impl>> StarkMachine { let total_width = trace_width + permutation_width; tracing::debug!( "{:<11} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<10} | Cells = {:<10}", - chips[i].name(), + chips[i].as_ref().name(), trace_width, permutation_width, traces[i].0.height(), @@ -436,7 +436,7 @@ impl>> StarkMachine { for i in 0..chips.len() { let permutation_trace = pk .chip_ordering - .get(&chips[i].name()) + .get(&chips[i].as_ref().name()) .map(|index| &pk.traces[*index]); debug_constraints::( chips[i], diff --git a/core/src/stark/prover.rs b/core/src/stark/prover.rs index 2252693eb..96b82a566 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -171,14 +171,14 @@ where shard_chips .par_iter() .map(|chip| { - let chip_name = chip.name(); + let chip_name = chip.as_ref().name(); // We need to create an outer span here because, for some reason, // the #[instrument] macro on the chip impl isn't attaching its span to `parent_span` // to avoid the unnecessary span, remove the #[instrument] macro. let trace = tracing::debug_span!(parent: &parent_span, "generate trace for chip", %chip_name) - .in_scope(|| chip.generate_trace(shard, &mut A::Record::default())); + .in_scope(|| chip.as_ref().generate_trace(shard, &mut A::Record::default())); (chip_name, trace) }) .collect::>() @@ -282,7 +282,7 @@ where .map(|(chip, main_trace)| { let preprocessed_trace = pk .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .map(|&index| &pk.traces[index]); let perm_trace = chip.generate_permutation_trace( preprocessed_trace, @@ -307,7 +307,7 @@ where + permutation_width * >::D; tracing::debug!( "{:<15} | Main Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}", - chips[i].name(), + chips[i].as_ref().name(), trace_width, permutation_width * >::D, traces[i].height(), @@ -358,7 +358,7 @@ where .in_scope(|| { let preprocessed_trace_on_quotient_domains = pk .chip_ordering - .get(&chips[i].name()) + .get(&chips[i].as_ref().name()) .map_or_else(|| { RowMajorMatrix::new_col(vec![ SC::Val::zero(); @@ -506,7 +506,7 @@ where .enumerate() .map( |(i, ((((main, permutation), quotient), cumulative_sum), log_degree))| { - let preprocessed = pk.chip_ordering.get(&chips[i].name()).map_or( + let preprocessed = pk.chip_ordering.get(&chips[i].as_ref().name()).map_or( AirOpenedValues { local: vec![], next: vec![], diff --git a/core/src/stark/verifier.rs b/core/src/stark/verifier.rs index be766413d..0c90ef216 100644 --- a/core/src/stark/verifier.rs +++ b/core/src/stark/verifier.rs @@ -185,7 +185,7 @@ impl>> Verifier { ) { // Verify the shape of the opening arguments matches the expected values. Self::verify_opening_shape(chip, values) - .map_err(|e| VerificationError::OpeningShapeError(chip.name(), e))?; + .map_err(|e| VerificationError::OpeningShapeError(chip.as_ref().name(), e))?; // Verify the constraint evaluation. Self::verify_constraints( chip, @@ -197,7 +197,7 @@ impl>> Verifier { &permutation_challenges, public_values, ) - .map_err(|_e| VerificationError::OodEvaluationMismatch(chip.name()))?; + .map_err(|_e| VerificationError::OodEvaluationMismatch(chip.as_ref().name()))?; } Ok(()) } @@ -207,15 +207,15 @@ impl>> Verifier { opening: &ChipOpenedValues, ) -> Result<(), OpeningShapeError> { // Verify that the preprocessed width matches the expected value for the chip. - if opening.preprocessed.local.len() != chip.preprocessed_width() { + if opening.preprocessed.local.len() != chip.as_ref().preprocessed_width() { return Err(OpeningShapeError::PreprocessedWidthMismatch( - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), opening.preprocessed.local.len(), )); } - if opening.preprocessed.next.len() != chip.preprocessed_width() { + if opening.preprocessed.next.len() != chip.as_ref().preprocessed_width() { return Err(OpeningShapeError::PreprocessedWidthMismatch( - chip.preprocessed_width(), + chip.as_ref().preprocessed_width(), opening.preprocessed.next.len(), )); } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index b8467b8ce..e0781389e 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -101,7 +101,7 @@ where let chip_idx = machine .chips() .iter() - .rposition(|chip| &chip.name() == name) + .rposition(|chip| &chip.as_ref().name() == name) .unwrap(); let index = sorted_indices[chip_idx]; let opening = &opened_values.chips[index]; @@ -211,7 +211,7 @@ where for (i, sorted_chip) in sorted_chips.iter().enumerate() { for chip in machine.chips() { - if chip.name() == *sorted_chip { + if chip.as_ref().name() == *sorted_chip { let values = &opened_values.chips[i]; let trace_domain = &trace_domains[i]; let quotient_domain = "ient_domains[i]; @@ -294,7 +294,7 @@ pub fn build_wrap_circuit( let chips = outer_machine .shard_chips_ordered(&template_proof.chip_ordering) - .map(|chip| chip.name()) + .map(|chip| chip.as_ref().name()) .collect::>(); let sorted_indices = outer_machine @@ -303,7 +303,7 @@ pub fn build_wrap_circuit( .map(|chip| { template_proof .chip_ordering - .get(&chip.name()) + .get(&chip.as_ref().name()) .copied() .unwrap_or(usize::MAX) }) diff --git a/recursion/circuit/src/types.rs b/recursion/circuit/src/types.rs index 012a4bb05..d8ce02e7b 100644 --- a/recursion/circuit/src/types.rs +++ b/recursion/circuit/src/types.rs @@ -149,7 +149,7 @@ impl ChipOpening { local: vec![], next: vec![], }; - let preprocess_width = chip.preprocessed_width(); + let preprocess_width = chip.as_ref().preprocessed_width(); for i in 0..preprocess_width { preprocessed.local.push(opening.preprocessed.local[i]); preprocessed.next.push(opening.preprocessed.next[i]); diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index a33e60825..a87200a64 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -346,11 +346,11 @@ where // TODO CONSTRAIN: that the preprocessed chips get called with verify_constraints. builder.cycle_tracker("stage-e-verify-constraints"); for (i, chip) in machine.chips().iter().enumerate() { - let chip_name = chip.name(); + let chip_name = chip.as_ref().name(); tracing::debug!("verifying constraints for chip: {}", chip_name); let index = builder.get(&proof.sorted_idxs, i); - if chip.preprocessed_width() > 0 { + if chip.as_ref().preprocessed_width() > 0 { builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } diff --git a/recursion/program/src/types.rs b/recursion/program/src/types.rs index 5380c1c74..5dcae7a0b 100644 --- a/recursion/program/src/types.rs +++ b/recursion/program/src/types.rs @@ -124,7 +124,7 @@ impl ChipOpening { local: vec![], next: vec![], }; - let preprocessed_width = chip.preprocessed_width(); + let preprocessed_width = chip.as_ref().preprocessed_width(); // Assert that the length of the dynamic arrays match the expected length of the vectors. builder.assert_usize_eq(preprocessed_width, opening.preprocessed.local.len()); builder.assert_usize_eq(preprocessed_width, opening.preprocessed.next.len()); diff --git a/recursion/program/src/utils.rs b/recursion/program/src/utils.rs index f7fa673f1..46f74b27d 100644 --- a/recursion/program/src/utils.rs +++ b/recursion/program/src/utils.rs @@ -225,7 +225,7 @@ pub(crate) fn get_preprocessed_data Date: Thu, 30 May 2024 19:59:53 -0400 Subject: [PATCH 2/4] refactor: Make Chips responsible for what they need from Records Currently the `MachineAir` trait requires every chip to define how it interacts with an `ExecutionRecord` associated type. This forces any chip that implements `MachineAir` to only interact with one specific implementation of an `ExecutionRecord` (as fixed at the moment of choosing that associated type). We would like those chips to be reused in more varied ways, and the following starts the changes towards accomplishing that. We set up the general framework with: ```rust /// A description of the events related to this AIR. pub trait WithEvents<'a>: Sized { /// output of a functional lens from the Record to /// refs of those events relative to the AIR. type Events: 'a; } pub trait EventLens WithEvents<'a>> { fn events(&self) -> ::Events; } pub trait MachineAir: BaseAir + for<'a> WithEvents<'a> { ... fn generate_trace>(&self, input: &EL, output: &mut ExecutionRecord) -> RowMajorMatrix; ... } ``` (the change to output is similar and pending) then in `AddSubChip`: ```rust impl<'a> WithEvents<'a> for AddSubChip { type Events = ( // add events &'a [AluEvent], // sub events &'a [AluEvent], ); } ``` In the `ExecutionRecord`: ```rust impl EventLens for ExecutionRecord { fn events(&self) -> ::Events { (&self.add_events, &self.sub_events) } } ``` In `generate_trace`: ```rust fn generate_trace>( &self, input: &EL, output: &mut EL, ) -> RowMajorMatrix { let (add_events, sub_events) = input.events(); // Generate the rows for the trace. let chunk_size = std::cmp::max( (add_events.len() + sub_events.len()) / num_cpus::get(), 1, ); let merged_events = add_events .iter() .chain(sub_events.iter()) .collect::>(); ... ``` --- core/src/air/machine.rs | 91 ++++++- core/src/alu/add_sub/mod.rs | 30 ++- core/src/alu/bitwise/mod.rs | 14 +- core/src/alu/divrem/mod.rs | 14 +- core/src/alu/lt/mod.rs | 14 +- core/src/alu/mul/mod.rs | 14 +- core/src/alu/sll/mod.rs | 16 +- core/src/alu/sr/mod.rs | 16 +- core/src/bytes/trace.rs | 25 +- core/src/cpu/trace.rs | 18 +- core/src/lookup/debug.rs | 24 +- core/src/memory/global.rs | 32 ++- core/src/memory/program.rs | 28 +- .../field/extensions/quadratic/mod.rs | 16 +- .../field/extensions/quadratic/sqrt.rs | 17 +- core/src/operations/field/field_den.rs | 16 +- .../operations/field/field_inner_product.rs | 16 +- core/src/operations/field/field_op.rs | 16 +- core/src/operations/field/field_sqrt.rs | 16 +- core/src/operations/field/params.rs | 32 --- core/src/program/mod.rs | 36 ++- core/src/runtime/record.rs | 241 +++++++++++++++++- core/src/stark/air.rs | 4 +- core/src/stark/machine.rs | 14 +- core/src/stark/prover.rs | 1 + core/src/stark/record.rs | 8 +- .../precompiles/blake3/compress/trace.rs | 14 +- .../precompiles/bls12_381/g1_decompress.rs | 12 +- .../syscall/precompiles/bls12_381/g2_add.rs | 14 +- .../precompiles/bls12_381/g2_double.rs | 14 +- .../src/syscall/precompiles/edwards/ed_add.rs | 19 +- .../precompiles/edwards/ed_decompress.rs | 18 +- core/src/syscall/precompiles/field/add.rs | 20 +- core/src/syscall/precompiles/field/mul.rs | 20 +- core/src/syscall/precompiles/field/sub.rs | 20 +- .../syscall/precompiles/keccak256/trace.rs | 27 +- .../src/syscall/precompiles/quad_field/add.rs | 19 +- .../src/syscall/precompiles/quad_field/mul.rs | 19 +- .../src/syscall/precompiles/quad_field/sub.rs | 19 +- .../precompiles/secp256k1/decompress.rs | 12 +- .../precompiles/sha256/compress/trace.rs | 16 +- .../precompiles/sha256/extend/trace.rs | 16 +- .../weierstrass/weierstrass_add.rs | 21 +- .../weierstrass/weierstrass_double.rs | 21 +- core/src/utils/ec/edwards/ed25519.rs | 11 +- core/src/utils/ec/mod.rs | 14 - core/src/utils/ec/weierstrass/bls12_381.rs | 75 ------ core/src/utils/ec/weierstrass/bn254.rs | 21 -- core/src/utils/ec/weierstrass/mod.rs | 23 +- core/src/utils/ec/weierstrass/secp256k1.rs | 20 -- core/src/utils/prove.rs | 2 +- derive/src/lib.rs | 174 ++++++++++++- recursion/core/src/cpu/trace.rs | 26 +- recursion/core/src/fri_fold/mod.rs | 32 ++- recursion/core/src/memory/air.rs | 43 ++-- recursion/core/src/memory/mod.rs | 5 +- recursion/core/src/multi/mod.rs | 71 ++++-- recursion/core/src/poseidon2/external.rs | 15 +- recursion/core/src/poseidon2/trace.rs | 24 +- recursion/core/src/poseidon2_wide/external.rs | 39 +-- recursion/core/src/program/mod.rs | 43 ++-- recursion/core/src/range_check/trace.rs | 23 +- recursion/core/src/runtime/record.rs | 71 +++++- recursion/core/src/stark/mod.rs | 32 ++- 64 files changed, 1267 insertions(+), 587 deletions(-) diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index 04d5042e4..939c63125 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -1,14 +1,91 @@ +use std::marker::PhantomData; + use p3_air::BaseAir; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; pub use sphinx_derive::MachineAir; -use crate::{runtime::Program, stark::MachineRecord}; +use crate::{ + runtime::Program, + stark::{Indexable, MachineRecord}, +}; + +/// A description of the events related to this AIR. +pub trait WithEvents<'a>: Sized { + /// output of a functional lens from the Record to + /// refs of those events relative to the AIR. + type Events: 'a; +} + +/// A trait intended for implementation on Records that may store events related to Chips, +/// The purpose of this trait is to provide a way to access the events relative to a specific +/// Chip, as specified by its `WithEvents` trait implementation. +/// +/// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 ) +pub trait EventLens WithEvents<'b>>: Indexable { + fn events<'a>(&'a self) -> >::Events; +} + +//////////////// Derive macro shaneanigans //////////////////////////////////////////////// +// This is *only* useful for the derive macros, you should *not* use this directly. +// +/// Hereafter, Lens composition explained pedantically: all this is saying is that +/// if I have an EventLens to T::Events, and a way (F) to deduce U::Events from that, +/// I can compose them to get an EventLens to U::Events. +pub struct Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, +{ + record: &'a R, + projection: F, + _phantom: PhantomData, +} + +/// A constructor for the projection from T::Events to U::Events. +impl<'a, T, R, F> Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, +{ + pub fn new(record: &'a R, projection: F) -> Self { + Self { + record, + projection, + _phantom: PhantomData, + } + } +} + +impl<'a, T, R, U, F> EventLens for Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens, + U: for<'b> WithEvents<'b>, + // see https://github.com/rust-lang/rust/issues/86702 for the empty parameter + F: for<'c> Fn(>::Events, &'c ()) -> >::Events, +{ + fn events<'c>(&'c self) -> >::Events { + let events: >::Events = self.record.events(); + (self.projection)(events, &()) + } +} + +impl<'a, T, R, F> Indexable for Proj<'a, T, R, F> +where + T: for<'b> WithEvents<'b>, + R: EventLens + Indexable, +{ + fn index(&self) -> u32 { + self.record.index() + } +} +//////////////// end of shenanigans destined for the derive macros. //////////////// /// An AIR that is part of a multi table AIR arithmetization. -pub trait MachineAir: BaseAir { +pub trait MachineAir: BaseAir + for<'a> WithEvents<'a> { /// The execution record containing events for producing the air trace. - type Record: MachineRecord; + type Record: MachineRecord + EventLens; type Program: MachineProgram; @@ -20,10 +97,14 @@ pub trait MachineAir: BaseAir { /// - `input` is the execution record containing the events to be written to the trace. /// - `output` is the execution record containing events that the `MachineAir` can add to /// the record such as byte lookup requests. - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix; + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix; /// Generate the dependencies for a given execution record. - fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) { + fn generate_dependencies>(&self, input: &EL, output: &mut Self::Record) { self.generate_trace(input, output); } diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 7ea030013..55e200a9c 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -12,13 +12,15 @@ use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; use crate::{ - air::{AluAirBuilder, MachineAir, Word}, + air::{AluAirBuilder, EventLens, MachineAir, WithEvents, Word}, operations::AddOperation, runtime::{ExecutionRecord, Opcode, Program}, stark::MachineRecord, utils::pad_to_power_of_two, }; +use super::AluEvent; + /// The number of main trace columns for `AddSubChip`. pub const NUM_ADD_SUB_COLS: usize = size_of::>(); @@ -55,6 +57,15 @@ pub struct AddSubCols { pub is_sub: T, } +impl<'a> WithEvents<'a> for AddSubChip { + type Events = ( + // add events + &'a [AluEvent], + // sub events + &'a [AluEvent], + ); +} + impl MachineAir for AddSubChip { type Record = ExecutionRecord; @@ -64,20 +75,17 @@ impl MachineAir for AddSubChip { "AddSub".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, - output: &mut ExecutionRecord, + input: &EL, + output: &mut Self::Record, ) -> RowMajorMatrix { + let (add_events, sub_events) = input.events(); // Generate the rows for the trace. - let chunk_size = std::cmp::max( - (input.add_events.len() + input.sub_events.len()) / num_cpus::get(), - 1, - ); - let merged_events = input - .add_events + let chunk_size = std::cmp::max((add_events.len() + sub_events.len()) / num_cpus::get(), 1); + let merged_events = add_events .iter() - .chain(input.sub_events.iter()) + .chain(sub_events.iter()) .collect::>(); let rows_and_records = merged_events diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 8cb0b1958..2b3e067b2 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -9,13 +9,15 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `BitwiseChip`. pub const NUM_BITWISE_COLS: usize = size_of::>(); @@ -49,6 +51,10 @@ pub struct BitwiseCols { pub is_and: T, } +impl<'a> WithEvents<'a> for BitwiseChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for BitwiseChip { type Record = ExecutionRecord; @@ -58,14 +64,14 @@ impl MachineAir for BitwiseChip { "Bitwise".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let rows = input - .bitwise_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_BITWISE_COLS]; diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index ff212b086..e12472be7 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -75,8 +75,8 @@ use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; use self::utils::eval_abs_value; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation}; use crate::alu::AluEvent; use crate::bytes::event::ByteRecord; @@ -187,6 +187,10 @@ pub struct DivRemCols { pub is_real: T, } +impl<'a> WithEvents<'a> for DivRemChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for DivRemChip { type Record = ExecutionRecord; @@ -196,13 +200,13 @@ impl MachineAir for DivRemChip { "DivRem".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. - let divrem_events = &input.divrem_events; + let divrem_events = input.events(); let mut rows: Vec<[F; NUM_DIVREM_COLS]> = Vec::with_capacity(divrem_events.len()); for event in divrem_events { assert!( @@ -405,7 +409,7 @@ impl MachineAir for DivRemChip { row }; debug_assert!(padded_row_template.len() == NUM_DIVREM_COLS); - for i in input.divrem_events.len() * NUM_DIVREM_COLS..trace.values.len() { + for i in input.events().len() * NUM_DIVREM_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS]; } diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 8d965153b..c70b84f93 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -12,13 +12,15 @@ use p3_matrix::Matrix; use p3_maybe_rayon::prelude::*; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, BaseAirBuilder, ByteAirBuilder, MachineAir}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `LtChip`. pub const NUM_LT_COLS: usize = size_of::>(); @@ -91,6 +93,10 @@ impl LtCols { } } +impl<'a> WithEvents<'a> for LtChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for LtChip { type Record = ExecutionRecord; @@ -100,14 +106,14 @@ impl MachineAir for LtChip { "Lt".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let (rows, new_byte_lookup_events): (Vec<_>, Vec<_>) = input - .lt_events + .events() .par_iter() .map(|event| { let mut row = [F::zero(); NUM_LT_COLS]; diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index bbdab712b..00ef5b946 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -44,8 +44,8 @@ use p3_maybe_rayon::prelude::ParallelIterator; use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::mul::utils::get_msb; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -54,6 +54,8 @@ use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::stark::MachineRecord; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `MulChip`. pub const NUM_MUL_COLS: usize = size_of::>(); @@ -121,6 +123,10 @@ pub struct MulCols { pub is_real: T, } +impl<'a> WithEvents<'a> for MulChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for MulChip { type Record = ExecutionRecord; @@ -130,12 +136,12 @@ impl MachineAir for MulChip { "Mul".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mul_events = input.mul_events.clone(); + let mul_events = input.events().clone(); // Compute the chunk size based on the number of events and the number of CPUs. let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1); diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index ecc84530d..812810f35 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -42,13 +42,15 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::bytes::event::ByteRecord; use crate::disassembler::WORD_SIZE; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `ShiftLeft`. pub const NUM_SHIFT_LEFT_COLS: usize = size_of::>(); @@ -96,6 +98,10 @@ pub struct ShiftLeftCols { pub is_real: T, } +impl<'a> WithEvents<'a> for ShiftLeft { + type Events = &'a [AluEvent]; +} + impl MachineAir for ShiftLeft { type Record = ExecutionRecord; @@ -105,14 +111,14 @@ impl MachineAir for ShiftLeft { "ShiftLeft".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_LEFT_COLS]> = vec![]; - let shift_left_events = input.shift_left_events.clone(); + let shift_left_events = input.events().clone(); for event in shift_left_events.iter() { let mut row = [F::zero(); NUM_SHIFT_LEFT_COLS]; let cols: &mut ShiftLeftCols = row.as_mut_slice().borrow_mut(); @@ -193,7 +199,7 @@ impl MachineAir for ShiftLeft { row }; debug_assert!(padded_row_template.len() == NUM_SHIFT_LEFT_COLS); - for i in input.shift_left_events.len() * NUM_SHIFT_LEFT_COLS..trace.values.len() { + for i in input.events().len() * NUM_SHIFT_LEFT_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_SHIFT_LEFT_COLS]; } diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 2b52df3f3..331a49049 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -55,8 +55,8 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::Word; use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder}; +use crate::air::{EventLens, WithEvents, Word}; use crate::alu::sr::utils::{nb_bits_to_shift, nb_bytes_to_shift}; use crate::bytes::event::ByteRecord; use crate::bytes::utils::shr_carry; @@ -65,6 +65,8 @@ use crate::disassembler::WORD_SIZE; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::utils::pad_to_power_of_two; +use super::AluEvent; + /// The number of main trace columns for `ShiftRightChip`. pub const NUM_SHIFT_RIGHT_COLS: usize = size_of::>(); @@ -128,6 +130,10 @@ pub struct ShiftRightCols { pub is_real: T, } +impl<'a> WithEvents<'a> for ShiftRightChip { + type Events = &'a [AluEvent]; +} + impl MachineAir for ShiftRightChip { type Record = ExecutionRecord; @@ -137,14 +143,14 @@ impl MachineAir for ShiftRightChip { "ShiftRight".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_RIGHT_COLS]> = Vec::new(); - let sr_events = input.shift_right_events.clone(); + let sr_events = input.events().clone(); for event in sr_events.iter() { assert!(event.opcode == Opcode::SRL || event.opcode == Opcode::SRA); let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS]; @@ -272,7 +278,7 @@ impl MachineAir for ShiftRightChip { row }; debug_assert!(padded_row_template.len() == NUM_SHIFT_RIGHT_COLS); - for i in input.shift_right_events.len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() { + for i in input.events().len() * NUM_SHIFT_RIGHT_COLS..trace.values.len() { trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; } diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 22b820420..a792b0202 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -1,19 +1,24 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, collections::BTreeMap}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use super::{ columns::{ByteMultCols, NUM_BYTE_MULT_COLS, NUM_BYTE_PREPROCESSED_COLS}, - ByteChip, + ByteChip, ByteLookupEvent, }; use crate::{ - air::MachineAir, + air::{EventLens, MachineAir, WithEvents}, runtime::{ExecutionRecord, Program}, }; pub const NUM_ROWS: usize = 1 << 16; +impl<'a, F: Field> WithEvents<'a> for ByteChip { + // the byte lookups + type Events = &'a BTreeMap>; +} + impl MachineAir for ByteChip { type Record = ExecutionRecord; @@ -35,16 +40,20 @@ impl MachineAir for ByteChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let shard = input.index; + let shard = input.index(); let (_, event_map) = Self::trace_and_map(shard); let mut trace = RowMajorMatrix::new( @@ -52,7 +61,7 @@ impl MachineAir for ByteChip { NUM_BYTE_MULT_COLS, ); - for (lookup, mult) in input.byte_lookups[&shard].iter() { + for (lookup, mult) in input.events()[&shard].iter() { let (row, index) = event_map[lookup]; let cols: &mut ByteMultCols = trace.row_mut(row).borrow_mut(); diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index 5d9955489..5aa4e02a1 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -7,7 +7,7 @@ use tracing::instrument; use super::columns::{CPU_COL_MAP, NUM_CPU_COLS}; use super::{CpuChip, CpuEvent}; -use crate::air::MachineAir; +use crate::air::{EventLens, MachineAir, WithEvents}; use crate::alu::AluEvent; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -18,6 +18,10 @@ use crate::memory::MemoryCols; use crate::runtime::{ExecutionRecord, Opcode, Program}; use crate::runtime::{MemoryRecordEnum, SyscallCode}; +impl<'a> WithEvents<'a> for CpuChip { + type Events = &'a [CpuEvent]; +} + impl MachineAir for CpuChip { type Record = ExecutionRecord; @@ -27,9 +31,9 @@ impl MachineAir for CpuChip { "CPU".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut new_alu_events = HashMap::new(); @@ -37,7 +41,7 @@ impl MachineAir for CpuChip { // Generate the trace rows for each event. let mut rows_with_events = input - .cpu_events + .events() .par_iter() .map(|op: &CpuEvent| self.event_to_row::(*op)) .collect::>(); @@ -78,11 +82,11 @@ impl MachineAir for CpuChip { } #[instrument(name = "generate cpu dependencies", level = "debug", skip_all)] - fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) { + fn generate_dependencies>(&self, input: &EL, output: &mut ExecutionRecord) { // Generate the trace rows for each event. - let chunk_size = std::cmp::max(input.cpu_events.len() / num_cpus::get(), 1); + let chunk_size = std::cmp::max(input.events().len() / num_cpus::get(), 1); let events = input - .cpu_events + .events() .par_chunks(chunk_size) .map(|ops: &[CpuEvent]| { let mut alu = HashMap::new(); diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index fa6108437..937f889c2 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -6,7 +6,9 @@ use p3_matrix::Matrix; use super::InteractionKind; use crate::air::MachineAir; -use crate::stark::{MachineChip, StarkGenericConfig, StarkMachine, StarkProvingKey, Val}; +use crate::stark::{ + MachineChip, StarkGenericConfig, StarkMachine, StarkProvingKey, Val, +}; #[derive(Debug)] pub struct InteractionData { @@ -44,7 +46,10 @@ fn field_to_int(x: F) -> i32 { } } -pub fn debug_interactions>>( +pub fn debug_interactions< + SC: StarkGenericConfig, + A: MachineAir>, +>( chip: &MachineChip, pkey: &StarkProvingKey, record: &A::Record, @@ -56,7 +61,9 @@ pub fn debug_interactions>>( let mut key_to_vec_data = BTreeMap::new(); let mut key_to_count = BTreeMap::new(); - let trace = chip.as_ref().generate_trace(record, &mut A::Record::default()); + let trace = chip + .as_ref() + .generate_trace(record, &mut A::Record::default()); let mut pre_traces = pkey.traces.clone(); let mut preprocessed_trace = pkey .chip_ordering @@ -150,10 +157,17 @@ where .or_insert((SC::Val::zero(), BTreeMap::new())); entry.0 += *value; total += *value; - *entry.1.entry(chip.as_ref().name()).or_insert(SC::Val::zero()) += *value; + *entry + .1 + .entry(chip.as_ref().name()) + .or_insert(SC::Val::zero()) += *value; } } - tracing::info!("{} chip has {} distinct events", chip.as_ref().name(), total_events); + tracing::info!( + "{} chip has {} distinct events", + chip.as_ref().name(), + total_events + ); } tracing::info!("Final counts below."); diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 9a2ac38c3..0003d5d02 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -10,7 +10,9 @@ use sphinx_derive::AlignedBorrow; use super::MemoryInitializeFinalizeEvent; use crate::{ - air::{AirInteraction, BaseAirBuilder, MachineAir, Word, WordAirBuilder}, + air::{ + AirInteraction, BaseAirBuilder, EventLens, MachineAir, WithEvents, Word, WordAirBuilder, + }, runtime::{ExecutionRecord, Program}, utils::pad_to_power_of_two, }; @@ -40,6 +42,15 @@ impl BaseAir for MemoryChip { } } +impl<'a> WithEvents<'a> for MemoryChip { + type Events = ( + // initialize events + &'a [MemoryInitializeFinalizeEvent], + // finalize events + &'a [MemoryInitializeFinalizeEvent], + ); +} + impl MachineAir for MemoryChip { type Record = ExecutionRecord; @@ -52,15 +63,16 @@ impl MachineAir for MemoryChip { } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mut memory_events = match self.kind { - MemoryChipType::Initialize => input.memory_initialize_events.clone(), - MemoryChipType::Finalize => input.memory_finalize_events.clone(), - }; + let mut memory_events: Vec = match self.kind { + MemoryChipType::Initialize => input.events().0, + MemoryChipType::Finalize => input.events().1, + } + .to_vec(); memory_events.sort_by_key(|event| event.addr); let rows: Vec<[F; 8]> = (0..memory_events.len()) // TODO: change this back to par_iter .map(|i| { @@ -93,7 +105,11 @@ impl MachineAir for MemoryChip { trace } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index f311c2da9..074603b72 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -5,10 +5,11 @@ use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; +use std::collections::BTreeMap; use sphinx_derive::AlignedBorrow; -use crate::air::{AirInteraction, BaseAirBuilder, PublicValues}; +use crate::air::{AirInteraction, BaseAirBuilder, EventLens, PublicValues, WithEvents}; use crate::air::{MachineAir, Word}; use crate::operations::IsZeroOperation; use crate::runtime::{ExecutionRecord, Program}; @@ -49,6 +50,10 @@ impl MemoryProgramChip { } } +impl<'a> WithEvents<'a> for MemoryProgramChip { + type Events = &'a BTreeMap; +} + impl MachineAir for MemoryProgramChip { type Record = ExecutionRecord; @@ -91,23 +96,22 @@ impl MachineAir for MemoryProgramChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let program_memory_addrs = input - .program - .memory_image - .keys() - .copied() - .collect::>(); + let program_memory_addrs = input.events().keys().copied().collect::>(); - let mult = if input.index == 1 { + let mult = if input.index() == 1 { F::one() } else { F::zero() @@ -120,7 +124,7 @@ impl MachineAir for MemoryProgramChip { let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS]; let cols: &mut MemoryProgramMultCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; - IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1); + IsZeroOperation::populate(&mut cols.is_first_shard, input.index() - 1); row }) diff --git a/core/src/operations/field/extensions/quadratic/mod.rs b/core/src/operations/field/extensions/quadratic/mod.rs index a9b04765b..8dd54cd25 100644 --- a/core/src/operations/field/extensions/quadratic/mod.rs +++ b/core/src/operations/field/extensions/quadratic/mod.rs @@ -365,7 +365,7 @@ impl QuadFieldOpCols { #[cfg(test)] mod tests { - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; @@ -411,6 +411,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for QuadFieldOpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl MachineAir for QuadFieldOpChip

{ type Record = ExecutionRecord; @@ -420,9 +430,9 @@ mod tests { format!("QuadFieldOp{:?}", self.operation) } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/extensions/quadratic/sqrt.rs b/core/src/operations/field/extensions/quadratic/sqrt.rs index 94cfd2eee..13071eb53 100644 --- a/core/src/operations/field/extensions/quadratic/sqrt.rs +++ b/core/src/operations/field/extensions/quadratic/sqrt.rs @@ -111,8 +111,7 @@ mod tests { use super::QuadFieldSqrtCols; use crate::air::MachineAir; - use crate::air::WordAirBuilder; - + use crate::air::{EventLens, WordAirBuilder}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::{FieldParameters, Limbs}; use crate::runtime::ExecutionRecord; @@ -151,6 +150,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for QuadSqrtChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl MachineAir for QuadSqrtChip

{ type Record = ExecutionRecord; @@ -160,9 +169,9 @@ mod tests { "QuadSqrtChip".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let num_test_cols = size_of::>(); diff --git a/core/src/operations/field/field_den.rs b/core/src/operations/field/field_den.rs index 1a7580818..5ce71a573 100644 --- a/core/src/operations/field/field_den.rs +++ b/core/src/operations/field/field_den.rs @@ -152,7 +152,7 @@ mod tests { mem::size_of, }; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use num::{bigint::RandBigInt, BigUint}; use p3_air::{Air, BaseAir}; use p3_baby_bear::BabyBear; @@ -197,6 +197,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldDenChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl MachineAir for FieldDenChip

{ type Record = ExecutionRecord; @@ -206,9 +216,9 @@ mod tests { "FieldDen".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_inner_product.rs b/core/src/operations/field/field_inner_product.rs index aec17555e..45a12ef30 100644 --- a/core/src/operations/field/field_inner_product.rs +++ b/core/src/operations/field/field_inner_product.rs @@ -152,7 +152,7 @@ mod tests { use super::{FieldInnerProductCols, Limbs}; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use crate::{ air::MachineAir, utils::ec::weierstrass::{bls12_381::Bls12381BaseField, secp256k1::Secp256k1BaseField}, @@ -185,6 +185,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldIpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl MachineAir for FieldIpChip

{ type Record = ExecutionRecord; @@ -194,9 +204,9 @@ mod tests { "FieldInnerProduct".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index cd9cb5fc5..41ac4209a 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -243,7 +243,7 @@ mod tests { use crate::{air::MachineAir, utils::ec::weierstrass::bls12_381::Bls12381BaseField}; - use crate::air::WordAirBuilder; + use crate::air::{EventLens, WordAirBuilder}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::FieldParameters; use crate::runtime::ExecutionRecord; @@ -278,6 +278,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for FieldOpChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl MachineAir for FieldOpChip

{ type Record = ExecutionRecord; @@ -287,9 +297,9 @@ mod tests { format!("FieldOp{:?}", self.operation) } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index 47e9f78ec..324307246 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -119,8 +119,8 @@ mod tests { use super::{FieldSqrtCols, Limbs}; - use crate::air::MachineAir; use crate::air::WordAirBuilder; + use crate::air::{EventLens, MachineAir}; use crate::bytes::event::ByteRecord; use crate::operations::field::params::{FieldParameters, DEFAULT_NUM_LIMBS_T}; use crate::runtime::ExecutionRecord; @@ -151,6 +151,16 @@ mod tests { } } + impl<'a, P: FieldParameters> crate::air::WithEvents<'a> for EdSqrtChip

{ + type Events = &'a (); + } + + impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &() + } + } + impl> MachineAir for EdSqrtChip

{ @@ -162,9 +172,9 @@ mod tests { "EdSqrtChip".to_string() } - fn generate_trace( + fn generate_trace>( &self, - _: &ExecutionRecord, + _: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rng = thread_rng(); diff --git a/core/src/operations/field/params.rs b/core/src/operations/field/params.rs index 0d72d4222..a56da381b 100644 --- a/core/src/operations/field/params.rs +++ b/core/src/operations/field/params.rs @@ -14,8 +14,6 @@ use num::BigUint; use p3_field::Field; use crate::air::Polynomial; -use crate::runtime::ExecutionRecord; -use crate::syscall::precompiles; use crate::utils::ec::utils::biguint_from_limbs; pub const NB_BITS_PER_LIMB: usize = 8; @@ -179,36 +177,6 @@ pub trait FieldParameters: } } -pub trait WithFieldAddition: FieldParameters { - fn add_events(record: &ExecutionRecord) -> &[precompiles::field::add::FieldAddEvent]; -} - -pub trait WithFieldSubtraction: FieldParameters { - fn sub_events(record: &ExecutionRecord) -> &[precompiles::field::sub::FieldSubEvent]; -} - -pub trait WithFieldMultiplication: FieldParameters { - fn mul_events(record: &ExecutionRecord) -> &[precompiles::field::mul::FieldMulEvent]; -} - -pub trait WithQuadFieldAddition: FieldParameters { - fn add_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::add::QuadFieldAddEvent]; -} - -pub trait WithQuadFieldSubtraction: FieldParameters { - fn sub_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::sub::QuadFieldSubEvent]; -} - -pub trait WithQuadFieldMultiplication: FieldParameters { - fn mul_events( - record: &ExecutionRecord, - ) -> &[precompiles::quad_field::mul::QuadFieldMulEvent]; -} - #[cfg(test)] mod tests { use crate::operations::field::params::FieldParameters; diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs index 64b16d6d2..0caad196b 100644 --- a/core/src/program/mod.rs +++ b/core/src/program/mod.rs @@ -10,8 +10,11 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_derive::AlignedBorrow; use crate::{ - air::{MachineAir, ProgramAirBuilder}, - cpu::columns::{InstructionCols, OpcodeSelectorCols}, + air::{EventLens, MachineAir, ProgramAirBuilder, WithEvents}, + cpu::{ + columns::{InstructionCols, OpcodeSelectorCols}, + CpuEvent, + }, runtime::{ExecutionRecord, Program}, utils::pad_to_power_of_two, }; @@ -49,6 +52,15 @@ impl ProgramChip { } } +impl<'a> WithEvents<'a> for ProgramChip { + type Events = ( + // CPU events + &'a [CpuEvent], + // the Program + &'a Program, + ); +} + impl MachineAir for ProgramChip { type Record = ExecutionRecord; @@ -92,21 +104,26 @@ impl MachineAir for ProgramChip { Some(trace) } - fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) { + fn generate_dependencies>( + &self, + _input: &EL, + _output: &mut ExecutionRecord, + ) { // Do nothing since this chip has no dependencies. } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Generate the trace rows for each event. + let (cpu_events, program) = input.events(); // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); - input.cpu_events.iter().for_each(|event| { + cpu_events.iter().for_each(|event| { let pc = event.pc; instruction_counts .entry(pc) @@ -114,17 +131,16 @@ impl MachineAir for ProgramChip { .or_insert(1); }); - let rows = input - .program + let rows = program .instructions .clone() .into_iter() .enumerate() .map(|(i, _)| { - let pc = input.program.pc_base + (i as u32 * 4); + let pc = program.pc_base + (i as u32 * 4); let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS]; let cols: &mut ProgramMultiplicityCols = row.as_mut_slice().borrow_mut(); - cols.shard = F::from_canonical_u32(input.index); + cols.shard = F::from_canonical_u32(input.index()); cols.multiplicity = F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0)); row diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 2e95b6c42..66d308e79 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -5,12 +5,10 @@ use std::{ }; use itertools::Itertools; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use serde::{Deserialize, Serialize}; use super::{program::Program, Opcode}; -use crate::alu::AluEvent; -use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::cpu::CpuEvent; use crate::runtime::MemoryInitializeFinalizeEvent; @@ -24,6 +22,29 @@ use crate::syscall::precompiles::keccak256::KeccakPermuteEvent; use crate::syscall::precompiles::sha256::{ShaCompressEvent, ShaExtendEvent}; use crate::syscall::precompiles::{ECAddEvent, ECDoubleEvent}; use crate::utils::env; +use crate::{ + air::EventLens, + alu::AluEvent, + memory::MemoryProgramChip, + stark::{ + AddSubChip, BitwiseChip, Blake3CompressInnerChip, ByteChip, CpuChip, DivRemChip, + Ed25519Parameters, EdAddAssignChip, EdDecompressChip, FieldAddChip, FieldMulChip, + FieldSubChip, KeccakPermuteChip, LtChip, MemoryChip, MulChip, ProgramChip, + QuadFieldAddChip, QuadFieldMulChip, QuadFieldSubChip, ShaCompressChip, ShaExtendChip, + ShiftLeft, ShiftRightChip, WeierstrassAddAssignChip, WeierstrassDoubleAssignChip, + }, + syscall::precompiles::{ + bls12_381::{ + g1_decompress::Bls12381G1DecompressChip, g2_add::Bls12381G2AffineAddChip, + g2_double::Bls12381G2AffineDoubleChip, + }, + secp256k1::decompress::Secp256k1DecompressChip, + }, + utils::ec::{ + edwards::ed25519::Ed25519, + weierstrass::{bls12_381::Bls12381, bn254::Bn254, secp256k1::Secp256k1}, + }, +}; use crate::{ air::PublicValues, operations::field::params::FieldParameters, @@ -35,6 +56,7 @@ use crate::{ }, utils::ec::weierstrass::bls12_381::Bls12381BaseField, }; +use crate::{bytes::event::ByteRecord, stark::Indexable}; /// A record of the execution of a program. Contains event data for everything that happened during /// the execution of the shard. @@ -122,6 +144,211 @@ pub struct ExecutionRecord { pub public_values: PublicValues, } +// Event lenses connect the record to the events relative to a particular chip +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + (&self.add_events, &self.sub_events) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.bitwise_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.divrem_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.lt_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.mul_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.shift_left_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.shift_right_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.byte_lookups + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.cpu_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + (&self.memory_initialize_events, &self.memory_finalize_events) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.program.memory_image + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + (&self.cpu_events, &self.program) + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.sha_extend_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.sha_compress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.blake3_compress_inner_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.keccak_permute_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.bls12381_g1_decompress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.secp256k1_decompress_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.bls12381_g2_add_events + } +} + +impl EventLens for ExecutionRecord { + fn events(&self) -> ::Events { + &self.bls12381_g2_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp_sub_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp_mul_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp2_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp2_sub_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_fp2_mul_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.secp256k1_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_g1_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bn254_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.secp256k1_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bls12381_g1_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.bn254_double_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.ed_add_events + } +} + +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents>::Events { + &self.ed_decompress_events + } +} + pub struct ShardingConfig { pub shard_size: usize, pub add_len: usize, @@ -185,12 +412,14 @@ impl Default for ShardingConfig { } } -impl MachineRecord for ExecutionRecord { - type Config = ShardingConfig; - +impl Indexable for ExecutionRecord { fn index(&self) -> u32 { self.index } +} + +impl MachineRecord for ExecutionRecord { + type Config = ShardingConfig; fn set_index(&mut self, index: u32) { self.index = index; diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index 7f7f7b93b..d4f0b55fb 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -12,6 +12,7 @@ use crate::StarkGenericConfig; use p3_field::PrimeField32; pub use riscv_chips::*; use tracing::instrument; +use sphinx_derive::{EventLens, WithEvents}; /// A module for importing all the different RISC-V chips. pub(crate) mod riscv_chips { @@ -42,7 +43,8 @@ pub(crate) mod riscv_chips { /// This enum contains all the different AIRs that are used in the Sp1 RISC-V IOP. Each variant is /// a different AIR that is used to encode a different part of the RISC-V execution, and the /// different AIR variants have a joint lookup argument. -#[derive(MachineAir)] +#[derive(WithEvents, EventLens, MachineAir)] +#[record_type = "crate::runtime::ExecutionRecord"] pub enum RiscvAir { /// An AIR that contains a preprocessed program table and a lookup for the instructions. Program(ProgramChip), diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index a9ae34d40..f265266d5 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -25,7 +25,7 @@ use crate::air::MachineProgram; use crate::lookup::debug_interactions_with_all_chips; use crate::lookup::InteractionBuilder; use crate::lookup::InteractionKind; -use crate::stark::record::MachineRecord; +use crate::stark::record::{Indexable, MachineRecord}; use crate::stark::DebugConstraintBuilder; use crate::stark::ProverConstraintFolder; use crate::stark::ShardProof; @@ -105,7 +105,8 @@ impl Debug for StarkVerifyingKey { } } -impl>> StarkMachine { +impl>> StarkMachine +{ /// Get an array containing a `ChipRef` for all the chips of this RISC-V STARK machine. pub fn chips(&self) -> &[MachineChip] { &self.chips @@ -132,7 +133,9 @@ impl>> StarkMachine { where 'a: 'b, { - self.chips.iter().filter(|chip| chip.as_ref().included(shard)) + self.chips + .iter() + .filter(|chip| chip.as_ref().included(shard)) } pub fn shard_chips_ordered<'a, 'b>( @@ -381,7 +384,10 @@ impl>> StarkMachine { .collect::>(); let mut traces = chips .par_iter() - .map(|chip| chip.as_ref().generate_trace(shard, &mut A::Record::default())) + .map(|chip| { + chip.as_ref() + .generate_trace(shard, &mut A::Record::default()) + }) .zip(pre_traces) .collect::>(); diff --git a/core/src/stark/prover.rs b/core/src/stark/prover.rs index 96b82a566..ff131e397 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -26,6 +26,7 @@ use super::{StarkProvingKey, VerifierConstraintFolder}; use crate::air::MachineAir; use crate::lookup::InteractionBuilder; use crate::stark::record::MachineRecord; +use crate::stark::Indexable; use crate::stark::MachineChip; use crate::stark::PackedChallenge; use crate::stark::ProverConstraintFolder; diff --git a/core/src/stark/record.rs b/core/src/stark/record.rs index 226188514..c4daa1b31 100644 --- a/core/src/stark/record.rs +++ b/core/src/stark/record.rs @@ -2,10 +2,12 @@ use std::collections::HashMap; use p3_field::AbstractField; -pub trait MachineRecord: Default + Sized + Send + Sync + Clone { - type Config: Default; - +pub trait Indexable { fn index(&self) -> u32; +} + +pub trait MachineRecord: Default + Sized + Send + Sync + Clone + Indexable { + type Config: Default; fn set_index(&mut self, index: u32); diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs index d151d9c20..166bcd79d 100644 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ b/core/src/syscall/precompiles/blake3/compress/trace.rs @@ -3,10 +3,12 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use super::Blake3CompressInnerEvent; use super::{ columns::Blake3CompressInnerCols, G_INDEX, G_INPUT_SIZE, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, }; +use crate::air::{EventLens, WithEvents}; use crate::bytes::event::ByteRecord; use crate::{ air::MachineAir, @@ -17,6 +19,10 @@ use crate::{ utils::pad_rows, }; +impl<'a> WithEvents<'a> for Blake3CompressInnerChip { + type Events = &'a [Blake3CompressInnerEvent]; +} + impl MachineAir for Blake3CompressInnerChip { type Record = ExecutionRecord; type Program = Program; @@ -25,17 +31,17 @@ impl MachineAir for Blake3CompressInnerChip { "Blake3CompressInner".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.blake3_compress_inner_events.len() { - let event = input.blake3_compress_inner_events[i].clone(); + for i in 0..input.events().len() { + let event = input.events()[i].clone(); let shard = event.shard; let mut clk = event.clk; for round in 0..ROUND_COUNT { diff --git a/core/src/syscall/precompiles/bls12_381/g1_decompress.rs b/core/src/syscall/precompiles/bls12_381/g1_decompress.rs index a27d44059..b01c18c46 100644 --- a/core/src/syscall/precompiles/bls12_381/g1_decompress.rs +++ b/core/src/syscall/precompiles/bls12_381/g1_decompress.rs @@ -11,7 +11,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, ByteAirBuilder, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, ByteAirBuilder, EventLens, MemoryAirBuilder, WithEvents}; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::operations::field::params::FieldParameters; use crate::operations::field::range::FieldRangeCols; @@ -229,6 +229,10 @@ impl Bls12381G1DecompressChip { } } +impl<'a> WithEvents<'a> for Bls12381G1DecompressChip { + type Events = &'a [Bls12381G1DecompressEvent]; +} + impl MachineAir for Bls12381G1DecompressChip { type Record = ExecutionRecord; type Program = Program; @@ -237,16 +241,16 @@ impl MachineAir for Bls12381G1DecompressChip { "Bls12381G1Decompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for event in input.bls12381_g1_decompress_events.iter() { + for event in input.events().iter() { let mut row = [F::zero(); size_of::>()]; let cols: &mut Bls12381G1DecompressCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/bls12_381/g2_add.rs b/core/src/syscall/precompiles/bls12_381/g2_add.rs index 0a1711058..a5e4ba019 100644 --- a/core/src/syscall/precompiles/bls12_381/g2_add.rs +++ b/core/src/syscall/precompiles/bls12_381/g2_add.rs @@ -1,4 +1,4 @@ -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}; use crate::bytes::event::ByteRecord; use crate::memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}; use crate::operations::field::extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}; @@ -276,6 +276,10 @@ impl BaseAir for Bls12381G2AffineAddChip { } } +impl<'a> WithEvents<'a> for Bls12381G2AffineAddChip { + type Events = &'a [Bls12381G2AffineAddEvent]; +} + impl MachineAir for Bls12381G2AffineAddChip { type Record = ExecutionRecord; type Program = Program; @@ -284,13 +288,17 @@ impl MachineAir for Bls12381G2AffineAddChip { "G2AffineAdd".to_string() } - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix { let mut rows = vec![]; let mut new_byte_lookup_events = Vec::new(); let width = >::width(self); - for event in &input.bls12381_g2_add_events { + for event in input.events() { let mut row = vec![F::zero(); width]; let cols: &mut Bls12381G2AffineAddCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/bls12_381/g2_double.rs b/core/src/syscall/precompiles/bls12_381/g2_double.rs index 29715c5dc..b43253add 100644 --- a/core/src/syscall/precompiles/bls12_381/g2_double.rs +++ b/core/src/syscall/precompiles/bls12_381/g2_double.rs @@ -1,4 +1,4 @@ -use crate::air::MachineAir; +use crate::air::{EventLens, MachineAir, WithEvents}; use crate::bytes::event::ByteRecord; use crate::memory::{MemoryCols, MemoryWriteCols}; use crate::operations::field::extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}; @@ -225,6 +225,10 @@ impl BaseAir for Bls12381G2AffineDoubleChip { } } +impl<'a> WithEvents<'a> for Bls12381G2AffineDoubleChip { + type Events = &'a [Bls12381G2AffineDoubleEvent]; +} + impl MachineAir for Bls12381G2AffineDoubleChip { type Record = ExecutionRecord; type Program = Program; @@ -233,14 +237,18 @@ impl MachineAir for Bls12381G2AffineDoubleChip { "Bls12381G2AffineDoubleChip".to_string() } - fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix { + fn generate_trace>( + &self, + input: &EL, + output: &mut Self::Record, + ) -> RowMajorMatrix { let mut rows: Vec> = vec![]; let width = >::width(self); let mut new_byte_lookup_events = Vec::new(); - for event in &input.bls12381_g2_double_events { + for event in input.events() { let mut row = vec![F::zero(); width]; let cols: &mut Bls12381G2AffineDoubleCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index cbba5d5ac..9d4739fc3 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -17,7 +17,6 @@ use p3_maybe_rayon::prelude::IntoParallelRefIterator; use p3_maybe_rayon::prelude::ParallelIterator; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -44,6 +43,10 @@ use crate::utils::ec::EllipticCurve; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; use crate::{air::MachineAir, utils::ec::EllipticCurveParameters}; +use crate::{ + air::{AluAirBuilder, EventLens, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECAddEvent, +}; pub const NUM_ED_ADD_COLS: usize = size_of::>(); @@ -140,7 +143,13 @@ impl< } } -impl MachineAir for EdAddAssignChip { +impl<'a, E: EllipticCurve + EdwardsParameters> WithEvents<'a> for EdAddAssignChip { + type Events = &'a [ECAddEvent]; +} + +impl MachineAir for EdAddAssignChip + where ExecutionRecord: EventLens>, +{ type Record = ExecutionRecord; type Program = Program; @@ -149,13 +158,13 @@ impl MachineAir for Ed "EdAddAssign".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let (mut rows, new_byte_lookup_events): (Vec>, Vec>) = input - .ed_add_events + .events() .par_iter() .map(|event| { let mut row = vec![F::zero(); size_of::>()]; diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index 2d65816fc..cc7d75a9e 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -13,8 +13,10 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; +use crate::air::EventLens; use crate::air::MachineAir; use crate::air::MachineAirBuilder; +use crate::air::WithEvents; use crate::air::{AluAirBuilder, BaseAirBuilder, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; @@ -331,7 +333,13 @@ impl EdDecompressChip { } } -impl MachineAir for EdDecompressChip { +impl<'a, E: EdwardsParameters> WithEvents<'a> for EdDecompressChip { + type Events = &'a [EdDecompressEvent]; +} + +impl MachineAir for EdDecompressChip + where ExecutionRecord: EventLens> +{ type Record = ExecutionRecord; type Program = Program; @@ -340,15 +348,15 @@ impl MachineAir for EdDecompressChip>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); - for i in 0..input.ed_decompress_events.len() { - let event = &input.ed_decompress_events[i]; + for i in 0..input.events().len() { + let event = &input.events()[i]; let mut row = vec![F::zero(); size_of::>()]; let cols: &mut EdDecompressCols = row.as_mut_slice().borrow_mut(); cols.populate::(event, output); diff --git a/core/src/syscall/precompiles/field/add.rs b/core/src/syscall/precompiles/field/add.rs index 82b25ac77..3e8d09d04 100644 --- a/core/src/syscall/precompiles/field/add.rs +++ b/core/src/syscall/precompiles/field/add.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldAddition, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,25 +113,31 @@ pub fn create_fp_add_event( } } -impl MachineAir for FieldAddChip { +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldAddChip { + type Events = &'a [FieldAddEvent]; +} + +impl MachineAir for FieldAddChip + where ExecutionRecord: EventLens> +{ type Record = ExecutionRecord; type Program = Program; fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldAdd".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field add trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::add_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/field/mul.rs b/core/src/syscall/precompiles/field/mul.rs index 32e13f59a..85912f3be 100644 --- a/core/src/syscall/precompiles/field/mul.rs +++ b/core/src/syscall/precompiles/field/mul.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldMultiplication, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,8 +113,12 @@ pub fn create_fp_mul_event( } } -impl MachineAir - for FieldMulChip +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldMulChip { + type Events = &'a [FieldMulEvent]; +} + +impl MachineAir for FieldMulChip +where ExecutionRecord: EventLens> { type Record = ExecutionRecord; type Program = Program; @@ -122,18 +126,18 @@ impl MachineAir< fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldMul".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field mul trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::mul_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/field/sub.rs b/core/src/syscall/precompiles/field/sub.rs index 15725343b..d4c804afc 100644 --- a/core/src/syscall/precompiles/field/sub.rs +++ b/core/src/syscall/precompiles/field/sub.rs @@ -15,12 +15,12 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ field_op::{FieldOpCols, FieldOperation}, - params::{FieldParameters, FieldType, Limbs, WithFieldSubtraction, WORDS_FIELD_ELEMENT}, + params::{FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT}, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, syscall::precompiles::SyscallContext, @@ -113,8 +113,12 @@ pub fn create_fp_sub_event( } } -impl MachineAir - for FieldSubChip +impl<'a, FP: FieldParameters> WithEvents<'a> for FieldSubChip { + type Events = &'a [FieldSubEvent]; +} + +impl MachineAir for FieldSubChip + where ExecutionRecord: EventLens> { type Record = ExecutionRecord; type Program = Program; @@ -122,18 +126,18 @@ impl MachineAir fn name(&self) -> String { match FP::FIELD_TYPE { FieldType::Bls12381 => "Bls12381FieldSub".to_string(), - _ => panic!("Unsupported field"), + _ => unreachable!("Unsupported field"), } } #[instrument(name = "generate field sub trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::sub_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index 5c2c4e8fa..dd814d1c0 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -5,16 +5,22 @@ use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}; +use crate::air::{EventLens, WithEvents}; use crate::bytes::event::ByteRecord; use crate::{runtime::Program, stark::MachineRecord}; use crate::{air::MachineAir, runtime::ExecutionRecord}; +use super::KeccakPermuteEvent; use super::{ columns::{KeccakMemCols, NUM_KECCAK_MEM_COLS}, KeccakPermuteChip, STATE_SIZE, }; +impl<'a> WithEvents<'a> for KeccakPermuteChip { + type Events = &'a [KeccakPermuteEvent]; +} + impl MachineAir for KeccakPermuteChip { type Record = ExecutionRecord; type Program = Program; @@ -23,36 +29,35 @@ impl MachineAir for KeccakPermuteChip { "KeccakPermute".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let num_events = input.keccak_permute_events.len(); + let num_events = input.events().len(); let chunk_size = std::cmp::max(num_events / num_cpus::get(), 1); // Use par_chunks to generate the trace in parallel. - let rows_and_records = (0..num_events) - .collect::>() + let rows_and_records = input + .events() .par_chunks(chunk_size) - .map(|chunk| { + .map(|chunk_events| { let mut record = ExecutionRecord::default(); let mut new_byte_lookup_events = Vec::new(); // First generate all the p3_keccak_air traces at once. - let perm_inputs = chunk + let perm_inputs = chunk_events .iter() - .map(|event_index| input.keccak_permute_events[*event_index].pre_state) + .map(|event| event.pre_state) .collect::>(); let p3_keccak_trace = generate_trace_rows::(perm_inputs); - let rows = chunk + let rows = chunk_events .iter() .enumerate() - .flat_map(|(index_in_chunk, event_index)| { + .flat_map(|(index_in_chunk, event)| { let mut rows = Vec::new(); - let event = &input.keccak_permute_events[*event_index]; let start_clk = event.clk; let shard = event.shard; diff --git a/core/src/syscall/precompiles/quad_field/add.rs b/core/src/syscall/precompiles/quad_field/add.rs index 3ee56f9b9..d909f045f 100644 --- a/core/src/syscall/precompiles/quad_field/add.rs +++ b/core/src/syscall/precompiles/quad_field/add.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldAddition, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -140,8 +139,12 @@ pub fn create_fp2_add_event( } } -impl MachineAir - for QuadFieldAddChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldAddChip { + type Events = &'a [QuadFieldAddEvent]; +} + +impl MachineAir for QuadFieldAddChip + where ExecutionRecord: EventLens> { type Record = ExecutionRecord; type Program = Program; @@ -154,13 +157,13 @@ impl MachineAir } #[instrument(name = "generate bls12381 fp2 add trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::add_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/quad_field/mul.rs b/core/src/syscall/precompiles/quad_field/mul.rs index 06a05b040..b4e683406 100644 --- a/core/src/syscall/precompiles/quad_field/mul.rs +++ b/core/src/syscall/precompiles/quad_field/mul.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldMultiplication, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -149,8 +148,12 @@ pub fn create_fp2_mul_event( } } -impl MachineAir - for QuadFieldMulChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldMulChip { + type Events = &'a [QuadFieldMulEvent]; +} + +impl MachineAir for QuadFieldMulChip + where ExecutionRecord: EventLens> { type Record = ExecutionRecord; type Program = Program; @@ -163,13 +166,13 @@ impl Machine } #[instrument(name = "generate bls12381 fp2 mul trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::mul_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/quad_field/sub.rs b/core/src/syscall/precompiles/quad_field/sub.rs index e21390c18..05bbd4943 100644 --- a/core/src/syscall/precompiles/quad_field/sub.rs +++ b/core/src/syscall/precompiles/quad_field/sub.rs @@ -15,14 +15,13 @@ use sphinx_derive::AlignedBorrow; use tracing::instrument; use crate::{ - air::{AluAirBuilder, MachineAir, MemoryAirBuilder}, + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, bytes::{event::ByteRecord, ByteLookupEvent}, memory::{MemoryCols, MemoryReadCols, MemoryWriteCols}, operations::field::{ extensions::quadratic::{QuadFieldOpCols, QuadFieldOperation}, params::{ - FieldParameters, FieldType, Limbs, WithQuadFieldSubtraction, WORDS_FIELD_ELEMENT, - WORDS_QUAD_EXT_FIELD_ELEMENT, + FieldParameters, FieldType, Limbs, WORDS_FIELD_ELEMENT, WORDS_QUAD_EXT_FIELD_ELEMENT, }, }, runtime::{ExecutionRecord, MemoryReadRecord, MemoryWriteRecord, Program, SyscallCode}, @@ -140,8 +139,12 @@ pub fn create_fp2_sub_event( } } -impl MachineAir - for QuadFieldSubChip +impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldSubChip { + type Events = &'a [QuadFieldSubEvent]; +} + +impl MachineAir for QuadFieldSubChip + where ExecutionRecord: EventLens> { type Record = ExecutionRecord; type Program = Program; @@ -154,13 +157,13 @@ impl MachineAir } #[instrument(name = "generate bls12381 fp2 sub trace", level = "debug", skip_all)] - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the field type. - let events = FP::sub_events(input); + let events = input.events(); let (mut rows, new_byte_lookup_events): (Vec<_>, Vec>) = events .par_iter() diff --git a/core/src/syscall/precompiles/secp256k1/decompress.rs b/core/src/syscall/precompiles/secp256k1/decompress.rs index 5eff9bafc..723a4047b 100644 --- a/core/src/syscall/precompiles/secp256k1/decompress.rs +++ b/core/src/syscall/precompiles/secp256k1/decompress.rs @@ -11,7 +11,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use serde::{Deserialize, Serialize}; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, ByteAirBuilder, MemoryAirBuilder}; +use crate::air::{AluAirBuilder, ByteAirBuilder, EventLens, MemoryAirBuilder, WithEvents}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::operations::field::range::FieldRangeCols; @@ -214,6 +214,10 @@ impl Secp256k1DecompressChip { } } +impl<'a> WithEvents<'a> for Secp256k1DecompressChip { + type Events = &'a [Secp256k1DecompressEvent]; +} + impl MachineAir for Secp256k1DecompressChip { type Record = ExecutionRecord; type Program = Program; @@ -222,16 +226,16 @@ impl MachineAir for Secp256k1DecompressChip { "Secp256k1Decompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for event in input.secp256k1_decompress_events.iter() { + for event in input.events().iter() { let mut row = [F::zero(); size_of::>()]; let cols: &mut Secp256k1DecompressCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index f0c2409db..059d14bf8 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -5,15 +5,19 @@ use p3_matrix::dense::RowMajorMatrix; use super::{ columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS}, - ShaCompressChip, SHA_COMPRESS_K, + ShaCompressChip, ShaCompressEvent, SHA_COMPRESS_K, }; use crate::{ - air::{MachineAir, Word}, + air::{EventLens, MachineAir, WithEvents, Word}, bytes::event::ByteRecord, runtime::{ExecutionRecord, Program}, utils::pad_rows, }; +impl<'a> WithEvents<'a> for ShaCompressChip { + type Events = &'a [ShaCompressEvent]; +} + impl MachineAir for ShaCompressChip { type Record = ExecutionRecord; @@ -23,16 +27,16 @@ impl MachineAir for ShaCompressChip { "ShaCompress".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.sha_compress_events.len() { - let mut event = input.sha_compress_events[i].clone(); + for i in 0..input.events().len() { + let mut event = input.events()[i].clone(); let shard = event.shard; let og_h = event.h; diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index faadfa889..719ba0af6 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -3,13 +3,17 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; -use super::{ShaExtendChip, ShaExtendCols, NUM_SHA_EXTEND_COLS}; +use super::{ShaExtendChip, ShaExtendCols, ShaExtendEvent, NUM_SHA_EXTEND_COLS}; use crate::{ - air::MachineAir, + air::{EventLens, MachineAir, WithEvents}, bytes::event::ByteRecord, runtime::{ExecutionRecord, Program}, }; +impl<'a> WithEvents<'a> for ShaExtendChip { + type Events = &'a [ShaExtendEvent]; +} + impl MachineAir for ShaExtendChip { type Record = ExecutionRecord; @@ -19,16 +23,16 @@ impl MachineAir for ShaExtendChip { "ShaExtend".to_string() } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let mut new_byte_lookup_events = Vec::new(); - for i in 0..input.sha_extend_events.len() { - let event = input.sha_extend_events[i].clone(); + for i in 0..input.events().len() { + let event = input.events()[i].clone(); let shard = event.shard; for j in 0..48usize { let mut row = [F::zero(); NUM_SHA_EXTEND_COLS]; diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index 35b6158b5..f17e07f56 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -16,7 +16,6 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -34,9 +33,12 @@ use crate::utils::ec::AffinePoint; use crate::utils::ec::BaseLimbWidth; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurve; -use crate::utils::ec::WithAddition; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; +use crate::{ + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECAddEvent, +}; pub const fn num_weierstrass_add_cols() -> usize { size_of::>() @@ -147,8 +149,13 @@ impl WeierstrassAddAssignChip { } } -impl MachineAir +impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> for WeierstrassAddAssignChip { + type Events = &'a [ECAddEvent<::NB_LIMBS>]; +} + +impl MachineAir for WeierstrassAddAssignChip + where ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -158,17 +165,17 @@ impl M CurveType::Secp256k1 => "Secp256k1AddAssign".to_string(), CurveType::Bn254 => "Bn254AddAssign".to_string(), CurveType::Bls12381 => "Bls12381AddAssign".to_string(), - _ => panic!("Unsupported curve"), + _ => unreachable!("Unsupported curve"), } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the curve type. - let events = E::add_events(input); + let events = input.events(); let mut rows = Vec::new(); diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index c12c5b2cd..4b7098843 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -17,7 +17,6 @@ use p3_maybe_rayon::prelude::ParallelIterator; use p3_maybe_rayon::prelude::ParallelSlice; use sphinx_derive::AlignedBorrow; -use crate::air::{AluAirBuilder, MachineAir, MemoryAirBuilder}; use crate::bytes::event::ByteRecord; use crate::bytes::ByteLookupEvent; use crate::memory::MemoryCols; @@ -37,9 +36,12 @@ use crate::utils::ec::AffinePoint; use crate::utils::ec::BaseLimbWidth; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurve; -use crate::utils::ec::WithDoubling; use crate::utils::limbs_from_prev_access; use crate::utils::pad_vec_rows; +use crate::{ + air::{AluAirBuilder, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, + syscall::precompiles::ECDoubleEvent, +}; pub const fn num_weierstrass_double_cols() -> usize { size_of::>() @@ -171,8 +173,15 @@ impl WeierstrassDoubleAssignChip { } } -impl MachineAir +impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> + for WeierstrassDoubleAssignChip +{ + type Events = &'a [ECDoubleEvent<::NB_LIMBS>]; +} + +impl MachineAir for WeierstrassDoubleAssignChip + where ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; @@ -186,13 +195,13 @@ impl M } } - fn generate_trace( + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { // collects the events based on the curve type. - let events = E::double_events(input); + let events = input.events(); let chunk_size = std::cmp::max(events.len() / num_cpus::get(), 1); diff --git a/core/src/utils/ec/edwards/ed25519.rs b/core/src/utils/ec/edwards/ed25519.rs index 3f11847a0..2fd5877bb 100644 --- a/core/src/utils/ec/edwards/ed25519.rs +++ b/core/src/utils/ec/edwards/ed25519.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use crate::operations::field::params::{FieldParameters, FieldType, DEFAULT_NUM_LIMBS_T}; use crate::utils::ec::edwards::{EdwardsCurve, EdwardsParameters}; -use crate::utils::ec::{AffinePoint, CurveType, EllipticCurveParameters, WithAddition}; +use crate::utils::ec::{AffinePoint, CurveType, EllipticCurveParameters}; pub type Ed25519 = EdwardsCurve; @@ -39,15 +39,6 @@ impl EllipticCurveParameters for Ed25519Parameters { const CURVE_TYPE: CurveType = CurveType::Ed25519; } -impl WithAddition for Ed25519Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.ed_add_events - } -} - impl EdwardsParameters for Ed25519Parameters { const D: Array::NB_LIMBS> = Array([ 30883, 4953, 19914, 30187, 55467, 16705, 2637, 112, 59544, 30585, 16505, 36039, 65139, diff --git a/core/src/utils/ec/mod.rs b/core/src/utils/ec/mod.rs index 0ead9b6e8..ceaf4a92a 100644 --- a/core/src/utils/ec/mod.rs +++ b/core/src/utils/ec/mod.rs @@ -14,8 +14,6 @@ use crate::air::WORD_SIZE; use crate::operations::field::params::FieldParameters; use crate::operations::field::params::WORDS_CURVEPOINT; use crate::operations::field::params::WORDS_FIELD_ELEMENT; -use crate::runtime::ExecutionRecord; -use crate::syscall::precompiles::{ECAddEvent, ECDoubleEvent}; pub const DEFAULT_NUM_WORDS_FIELD_ELEMENT: usize = 8; pub const DEFAULT_NUM_BYTES_FIELD_ELEMENT: usize = DEFAULT_NUM_WORDS_FIELD_ELEMENT * WORD_SIZE; @@ -122,18 +120,6 @@ pub trait EllipticCurveParameters: const CURVE_TYPE: CurveType; } -pub trait WithAddition: EllipticCurveParameters { - fn add_events( - record: &ExecutionRecord, - ) -> &[ECAddEvent<::NB_LIMBS>]; -} - -pub trait WithDoubling: EllipticCurveParameters { - fn double_events( - record: &ExecutionRecord, - ) -> &[ECDoubleEvent<::NB_LIMBS>]; -} - /// An interface for elliptic curve groups. pub trait EllipticCurve: EllipticCurveParameters { /// Adds two different points on the curve. diff --git a/core/src/utils/ec/weierstrass/bls12_381.rs b/core/src/utils/ec/weierstrass/bls12_381.rs index 30689f2fb..4606edb6e 100644 --- a/core/src/utils/ec/weierstrass/bls12_381.rs +++ b/core/src/utils/ec/weierstrass/bls12_381.rs @@ -10,12 +10,6 @@ use std::ops::Neg; use super::{SwCurve, WeierstrassParameters}; use crate::operations::field::params::FieldParameters; use crate::operations::field::params::FieldType; -use crate::operations::field::params::WithFieldAddition; -use crate::operations::field::params::WithFieldMultiplication; -use crate::operations::field::params::WithFieldSubtraction; -use crate::operations::field::params::WithQuadFieldAddition; -use crate::operations::field::params::WithQuadFieldMultiplication; -use crate::operations::field::params::WithQuadFieldSubtraction; use crate::runtime::Syscall; use crate::stark::WeierstrassAddAssignChip; use crate::stark::WeierstrassDoubleAssignChip; @@ -23,8 +17,6 @@ use crate::syscall::precompiles::create_ec_add_event; use crate::syscall::precompiles::create_ec_double_event; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurveParameters; -use crate::utils::ec::WithAddition; -use crate::utils::ec::WithDoubling; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] /// Bls12381 curve parameter @@ -188,78 +180,11 @@ pub fn bls12381_double(p: &[BigUint; 4]) -> [BigUint; 4] { ] } -impl WithFieldAddition for Bls12381BaseField { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::add::FieldAddEvent] { - &record.bls12381_fp_add_events - } -} - -impl WithFieldSubtraction for Bls12381BaseField { - fn sub_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::sub::FieldSubEvent] { - &record.bls12381_fp_sub_events - } -} - -impl WithFieldMultiplication for Bls12381BaseField { - fn mul_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::field::mul::FieldMulEvent] { - &record.bls12381_fp_mul_events - } -} - -impl WithQuadFieldAddition for Bls12381BaseField { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::add::QuadFieldAddEvent] { - &record.bls12381_fp2_add_events - } -} - -impl WithQuadFieldSubtraction for Bls12381BaseField { - fn sub_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::sub::QuadFieldSubEvent] { - &record.bls12381_fp2_sub_events - } -} - -impl WithQuadFieldMultiplication for Bls12381BaseField { - fn mul_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::quad_field::mul::QuadFieldMulEvent] { - &record.bls12381_fp2_mul_events - } -} - impl EllipticCurveParameters for Bls12381Parameters { type BaseField = Bls12381BaseField; const CURVE_TYPE: CurveType = CurveType::Bls12381; } -impl WithAddition for Bls12381Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.bls12381_g1_add_events - } -} - -impl WithDoubling for Bls12381Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.bls12381_g1_double_events - } -} - /// The WeierstrassParameters for BLS12-381 G1 impl WeierstrassParameters for Bls12381Parameters { const A: Array::NB_LIMBS> = Array([ diff --git a/core/src/utils/ec/weierstrass/bn254.rs b/core/src/utils/ec/weierstrass/bn254.rs index 4c414da86..3550da138 100644 --- a/core/src/utils/ec/weierstrass/bn254.rs +++ b/core/src/utils/ec/weierstrass/bn254.rs @@ -13,8 +13,6 @@ use crate::syscall::precompiles::create_ec_add_event; use crate::syscall::precompiles::create_ec_double_event; use crate::utils::ec::CurveType; use crate::utils::ec::EllipticCurveParameters; -use crate::utils::ec::WithAddition; -use crate::utils::ec::WithDoubling; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] /// Bn254 curve parameter @@ -58,25 +56,6 @@ impl EllipticCurveParameters for Bn254Parameters { const CURVE_TYPE: CurveType = CurveType::Bn254; } -impl WithAddition for Bn254Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.bn254_add_events - } -} - -impl WithDoubling for Bn254Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.bn254_double_events - } -} - impl Syscall for WeierstrassAddAssignChip { fn execute( &self, diff --git a/core/src/utils/ec/weierstrass/mod.rs b/core/src/utils/ec/weierstrass/mod.rs index c86ee396c..78a620486 100644 --- a/core/src/utils/ec/weierstrass/mod.rs +++ b/core/src/utils/ec/weierstrass/mod.rs @@ -5,9 +5,7 @@ use serde::{Deserialize, Serialize}; use super::CurveType; use crate::operations::field::params::FieldParameters; use crate::utils::ec::utils::biguint_to_bits_le; -use crate::utils::ec::{ - AffinePoint, EllipticCurve, EllipticCurveParameters, WithAddition, WithDoubling, -}; +use crate::utils::ec::{AffinePoint, EllipticCurve, EllipticCurveParameters}; pub mod bls12_381; pub mod bn254; @@ -77,25 +75,6 @@ impl EllipticCurveParameters for SwCurve { const CURVE_TYPE: CurveType = E::CURVE_TYPE; } -impl WithAddition for SwCurve { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - E::add_events(record) - } -} - -impl WithDoubling for SwCurve { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - E::double_events(record) - } -} - impl EllipticCurve for SwCurve { fn ec_add(p: &AffinePoint, q: &AffinePoint) -> AffinePoint { p.sw_add(q) diff --git a/core/src/utils/ec/weierstrass/secp256k1.rs b/core/src/utils/ec/weierstrass/secp256k1.rs index c12dc14d4..3f3d9087d 100644 --- a/core/src/utils/ec/weierstrass/secp256k1.rs +++ b/core/src/utils/ec/weierstrass/secp256k1.rs @@ -19,7 +19,6 @@ use crate::{ runtime::Syscall, stark::{WeierstrassAddAssignChip, WeierstrassDoubleAssignChip}, syscall::precompiles::{create_ec_add_event, create_ec_double_event}, - utils::ec::{WithAddition, WithDoubling}, }; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -56,25 +55,6 @@ impl EllipticCurveParameters for Secp256k1Parameters { const CURVE_TYPE: CurveType = CurveType::Secp256k1; } -impl WithAddition for Secp256k1Parameters { - fn add_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECAddEvent<::NB_LIMBS>] - { - &record.secp256k1_add_events - } -} - -impl WithDoubling for Secp256k1Parameters { - fn double_events( - record: &crate::runtime::ExecutionRecord, - ) -> &[crate::syscall::precompiles::ECDoubleEvent< - ::NB_LIMBS, - >] { - &record.secp256k1_double_events - } -} - impl WeierstrassParameters for Secp256k1Parameters { const A: Array::NB_LIMBS> = Array([ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index e11636af2..3ad430a7d 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -24,7 +24,7 @@ use crate::stark::StarkVerifyingKey; use crate::stark::Val; use crate::stark::VerifierConstraintFolder; use crate::stark::{Com, PcsProverData, RiscvAir, ShardProof, StarkProvingKey, UniConfig}; -use crate::stark::{MachineRecord, StarkMachine}; +use crate::stark::{Indexable, MachineRecord, StarkMachine}; use crate::utils::env; use crate::{ runtime::{Program, Runtime}, diff --git a/derive/src/lib.rs b/derive/src/lib.rs index a8f57b225..37bcf8c8d 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -85,6 +85,141 @@ pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { TokenStream::from(methods) } +/// Derives WithEvents for an enum which every variant has one field, +/// each of which implements WithEvents. +/// +/// The derived implementation is a tuple of the Events of each variant, +/// in the variant declaration order. That is, because the chip could be *any* variant, +/// it requires being able to provide for *all* event types consumable by each chip. +#[proc_macro_derive(WithEvents, attributes(sphinx_core_path))] +pub fn with_events_air_derive(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let type_params = generics.type_params(); + let const_params = generics.const_params(); + + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); + let (_, ty_generics, where_clause) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(_) | Data::Union(_) => unimplemented!("Only Enums are supported yet"), + Data::Enum(e) => { + let fields = e + .variants + .iter() + .map(|variant| { + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); + field + }) + .collect::>(); + + let field_ty_events = fields.iter().map(|field| { + let field_ty = &field.ty; + quote! { + <#field_ty as #sphinx_core_path::air::WithEvents<'a>>::Events + } + }); + quote!{ + impl <'a, #(#type_params),*, #(#const_params),*> #sphinx_core_path::air::WithEvents<'a> for #name #ty_generics #where_clause { + type Events = (#(#field_ty_events,)*); + } + }.into() + } + } +} + +fn get_type_from_attrs(attrs: &[syn::Attribute], attr_name: &str) -> syn::Result { + attrs + .iter() + .find(|attr| attr.path.is_ident(attr_name)) + .map_or_else( + || { + Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!("Could not find attribute {}", attr_name), + )) + }, + |attr| match attr.parse_meta()? { + syn::Meta::NameValue(meta) => { + if let syn::Lit::Str(lit) = &meta.lit { + Ok(lit.clone()) + } else { + Err(syn::Error::new_spanned( + meta, + &format!("Could not parse {} attribute", attr_name)[..], + )) + } + } + bad => Err(syn::Error::new_spanned( + bad, + &format!("Could not parse {} attribute", attr_name)[..], + )), + }, + ) +} + +/// Derives EventLens for an enum which every variant has one field, +/// each of which the input record implements EventLens for. +/// +/// The derived implementation is a delegation to the events of each variant, +/// in the variant declaration order. +/// +/// This makes the strong assumption that the `WithEvents` trait is implemented on `Self` +/// as a tuple of the `WithEvents` of each variant. This is what the `WithEvents` derive macro does. +#[proc_macro_derive(EventLens, attributes(sphinx_core_path, record_type))] +pub fn event_lens_air_derive(input: TokenStream) -> TokenStream { + let ast: DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let record_type = get_type_from_attrs(&ast.attrs, "record_type").unwrap(); + let rec_ty: syn::Type = record_type.parse().unwrap(); + + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(_) | Data::Union(_) => unimplemented!("Only Enums are supported yet"), + Data::Enum(e) => { + let fields = e + .variants + .iter() + .map(|variant| { + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); + field + }) + .collect::>(); + + let field_events = fields.iter().map(|field| { + let field_ty = &field.ty; + quote! { + #sphinx_core_path::air::EventLens::<#field_ty>::events(self) + } + }); + let res = quote! { + impl #impl_generics #sphinx_core_path::air::EventLens<#name #ty_generics> for #rec_ty { + fn events(&self) -> <#name #ty_generics as #sphinx_core_path::air::WithEvents>::Events { + (#(#field_events,)*) + } + } + }; + res.into() + } + } +} + #[proc_macro_derive( MachineAir, attributes(sphinx_core_path, execution_record_path, program_path, builder_path) @@ -94,11 +229,17 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; let generics = &ast.generics; + let type_params = generics.type_params(); + let ty_params = quote!{ #(#type_params),* }; + let const_params = generics.const_params(); + let co_params = quote! { #(#const_params),* }; + let sphinx_core_path = find_sphinx_core_path(&ast.attrs); let execution_record_path = find_execution_record_path(&ast.attrs); let program_path = find_program_path(&ast.attrs); let builder_path = find_builder_path(&ast.attrs); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let turbo_ty = ty_generics.as_turbofish(); match &ast.data { Data::Struct(_) => unimplemented!("Structs are not supported yet"), @@ -111,7 +252,10 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let mut fields = variant.fields.iter(); let field = fields.next().unwrap(); - assert!(fields.next().is_none(), "Only one field is supported"); + assert!( + fields.next().is_none(), + "Only one field per variant is supported" + ); (variant_name, field) }) .collect::>(); @@ -158,17 +302,29 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } }); - let generate_trace_arms = variants.iter().map(|(variant_name, field)| { + let generate_trace_arms = variants.iter().enumerate().map(|(i, (variant_name, field))| { let field_ty = &field.ty; + + let idx = syn::Index::from(i); quote! { - #name::#variant_name(x) => <#field_ty as #sphinx_core_path::air::MachineAir>::generate_trace(x, input, output) + #name::#variant_name(x) => { + fn f <'c, #ty_params, #co_params> (evs: <#name #ty_generics as #sphinx_core_path::air::WithEvents<'c>>::Events, _v: &'c ()) -> <#field_ty as #sphinx_core_path::air::WithEvents<'c>>::Events { evs.#idx } + + <#field_ty as #sphinx_core_path::air::MachineAir>::generate_trace(x, &#sphinx_core_path::air::Proj::new(input, f #turbo_ty), output) + } } }); - let generate_dependencies_arms = variants.iter().map(|(variant_name, field)| { + let generate_dependencies_arms = variants.iter().enumerate().map(|(i, (variant_name, field))| { let field_ty = &field.ty; + let idx = syn::Index::from(i); + quote! { - #name::#variant_name(x) => <#field_ty as #sphinx_core_path::air::MachineAir>::generate_dependencies(x, input, output) + #name::#variant_name(x) => { + fn f <'c, #ty_params, #co_params> (evs: <#name #ty_generics as #sphinx_core_path::air::WithEvents<'c>>::Events, _v: &'c ()) -> <#field_ty as #sphinx_core_path::air::WithEvents<'c>>::Events { evs.#idx } + + <#field_ty as #sphinx_core_path::air::MachineAir>::generate_dependencies(x, &#sphinx_core_path::air::Proj::new(input, f #turbo_ty), output) + } } }); @@ -206,9 +362,9 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } } - fn generate_trace( + fn generate_trace>( &self, - input: &#execution_record_path, + input: &EL, output: &mut #execution_record_path, ) -> p3_matrix::dense::RowMajorMatrix { match self { @@ -216,9 +372,9 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { } } - fn generate_dependencies( + fn generate_dependencies>( &self, - input: &#execution_record_path, + input: &EL, output: &mut #execution_record_path, ) { match self { diff --git a/recursion/core/src/cpu/trace.rs b/recursion/core/src/cpu/trace.rs index 8b80d8976..197bdff7b 100644 --- a/recursion/core/src/cpu/trace.rs +++ b/recursion/core/src/cpu/trace.rs @@ -1,9 +1,9 @@ use std::borrow::BorrowMut; -use p3_field::{extension::BinomiallyExtendable, PrimeField32}; +use p3_field::{extension::BinomiallyExtendable, Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_core::{ - air::{BinomialExtension, MachineAir}, + air::{BinomialExtension, EventLens, MachineAir, WithEvents}, utils::pad_rows_fixed, }; use tracing::instrument; @@ -14,9 +14,15 @@ use crate::{ runtime::{ExecutionRecord, Opcode, RecursionProgram, D}, }; -use super::{CpuChip, CpuCols, CPU_COL_MAP, NUM_CPU_COLS}; +use super::{CpuChip, CpuCols, CpuEvent, CPU_COL_MAP, NUM_CPU_COLS}; -impl> MachineAir for CpuChip { +impl<'a, F: Field> WithEvents<'a> for CpuChip { + type Events = &'a [CpuEvent]; +} + +impl> MachineAir for CpuChip + where ExecutionRecord: EventLens> +{ type Record = ExecutionRecord; type Program = RecursionProgram; @@ -24,18 +30,18 @@ impl> MachineAir for CpuChip { "CPU".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // There are no dependencies, since we do it all in the runtime. This is just a placeholder. } - #[instrument(name = "generate cpu trace", level = "debug", skip_all, fields(rows = input.cpu_events.len()))] - fn generate_trace( + #[instrument(name = "generate cpu trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( &self, - input: &ExecutionRecord, + input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = input - .cpu_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_CPU_COLS]; @@ -124,7 +130,7 @@ impl> MachineAir for CpuChip { let mut trace = RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_CPU_COLS); - for i in input.cpu_events.len()..trace.height() { + for i in input.events().len()..trace.height() { trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.clk] = F::from_canonical_u32(4) * F::from_canonical_usize(i); trace.values[i * NUM_CPU_COLS + CPU_COL_MAP.instruction.imm_b] = diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index 943ba8d14..389280a08 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -4,13 +4,14 @@ use crate::air::RecursionMemoryAirBuilder; use crate::memory::{MemoryReadCols, MemoryReadSingleCols, MemoryReadWriteCols}; use crate::runtime::Opcode; use core::borrow::Borrow; +use std::marker::PhantomData; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, BinomialExtension, ExtensionAirBuilder, MachineAir}; +use sphinx_core::air::{BaseAirBuilder, BinomialExtension, EventLens, ExtensionAirBuilder, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_derive::AlignedBorrow; use std::borrow::BorrowMut; @@ -23,8 +24,9 @@ use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::>(); #[derive(Default)] -pub struct FriFoldChip { +pub struct FriFoldChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } #[derive(Debug, Clone)] @@ -83,13 +85,17 @@ pub struct FriFoldCols { pub is_real: T, } -impl BaseAir for FriFoldChip { +impl BaseAir for FriFoldChip { fn width(&self) -> usize { NUM_FRI_FOLD_COLS } } -impl MachineAir for FriFoldChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for FriFoldChip { + type Events = &'a [FriFoldEvent]; +} + +impl MachineAir for FriFoldChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -98,18 +104,16 @@ impl MachineAir for FriFoldChip "FriFold".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.fri_fold_events.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _: &mut ExecutionRecord, + #[instrument(name = "generate fri fold trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( + &self, input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = input - .fri_fold_events + .events() .iter() .map(|event| { let mut row = [F::zero(); NUM_FRI_FOLD_COLS]; @@ -167,7 +171,7 @@ impl MachineAir for FriFoldChip } } -impl FriFoldChip { +impl FriFoldChip { pub fn eval_fri_fold( &self, builder: &mut AB, @@ -368,7 +372,7 @@ impl FriFoldChip { } } -impl Air for FriFoldChip +impl Air for FriFoldChip where AB: SphinxRecursionAirBuilder, { diff --git a/recursion/core/src/memory/air.rs b/recursion/core/src/memory/air.rs index b5d968653..3fdd96fbd 100644 --- a/recursion/core/src/memory/air.rs +++ b/recursion/core/src/memory/air.rs @@ -1,32 +1,42 @@ use core::mem::size_of; -use std::borrow::{Borrow, BorrowMut}; +use std::{borrow::{Borrow, BorrowMut}, marker::PhantomData}; use p3_air::{Air, BaseAir}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_core::{ - air::{AirInteraction, MachineAir, MemoryAirBuilder}, + air::{AirInteraction, MachineAir, EventLens, WithEvents, MemoryAirBuilder}, lookup::InteractionKind, utils::pad_rows_fixed, }; use tracing::instrument; use super::columns::MemoryInitCols; -use crate::memory::MemoryGlobalChip; +use crate::{air::Block, memory::MemoryGlobalChip}; use crate::runtime::{ExecutionRecord, RecursionProgram}; pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::>(); #[allow(dead_code)] -impl MemoryGlobalChip { +impl MemoryGlobalChip { pub fn new() -> Self { Self { fixed_log2_rows: None, + _phantom: PhantomData, } } } -impl MachineAir for MemoryGlobalChip { +impl<'a, F: Field> WithEvents<'a> for MemoryGlobalChip { + type Events = ( + // first memory event + &'a [(F, Block)], + // last memory event + &'a [(F, F, Block)], + ); +} + +impl MachineAir for MemoryGlobalChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -34,22 +44,20 @@ impl MachineAir for MemoryGlobalChip { "MemoryGlobalChip".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate memory trace", level = "debug", skip_all, fields(first_rows = input.first_memory_record.len(), last_rows = input.last_memory_record.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _output: &mut ExecutionRecord, + #[instrument(name = "generate memory trace", level = "debug", skip_all, fields(first_rows = input.events().0.len(), last_rows = input.events().1.len()))] + fn generate_trace>( + &self, input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); + let (first_memory_events, last_memory_events) = input.events(); // Fill in the initial memory records. rows.extend( - input - .first_memory_record + first_memory_events .iter() .map(|(addr, value)| { let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; @@ -65,8 +73,7 @@ impl MachineAir for MemoryGlobalChip { // Fill in the finalize memory records. rows.extend( - input - .last_memory_record + last_memory_events .iter() .map(|(addr, timestamp, value)| { let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; @@ -98,13 +105,13 @@ impl MachineAir for MemoryGlobalChip { } } -impl BaseAir for MemoryGlobalChip { +impl BaseAir for MemoryGlobalChip { fn width(&self) -> usize { NUM_MEMORY_INIT_COLS } } -impl Air for MemoryGlobalChip +impl Air for MemoryGlobalChip where AB: MemoryAirBuilder, { diff --git a/recursion/core/src/memory/mod.rs b/recursion/core/src/memory/mod.rs index 834de2e7a..7d97cc912 100644 --- a/recursion/core/src/memory/mod.rs +++ b/recursion/core/src/memory/mod.rs @@ -1,6 +1,8 @@ mod air; mod columns; +use std::marker::PhantomData; + use p3_field::PrimeField32; use crate::air::Block; @@ -109,6 +111,7 @@ impl MemoryAccessCols { } #[derive(Default)] -pub struct MemoryGlobalChip { +pub struct MemoryGlobalChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 1b548ef4e..74e7b9f9b 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -1,11 +1,12 @@ use std::borrow::{Borrow, BorrowMut}; +use std::marker::PhantomData; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, MachineAir}; +use sphinx_core::air::{BaseAirBuilder, EventLens, MachineAir, Proj, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_derive::AlignedBorrow; @@ -17,8 +18,9 @@ use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_MULTI_COLS: usize = core::mem::size_of::>(); #[derive(Default)] -pub struct MultiChip { +pub struct MultiChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } #[derive(AlignedBorrow, Clone, Copy)] @@ -42,13 +44,17 @@ pub union InstructionSpecificCols { poseidon2: Poseidon2Cols, } -impl BaseAir for MultiChip { +impl BaseAir for MultiChip { fn width(&self) -> usize { NUM_MULTI_COLS } } -impl MachineAir for MultiChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for MultiChip { + type Events = ( as WithEvents<'a>>::Events, as WithEvents<'a>>::Events); +} + +impl MachineAir for MultiChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -57,19 +63,32 @@ impl MachineAir for MultiChip { "Multi".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - fn generate_trace( - &self, - input: &ExecutionRecord, - output: &mut ExecutionRecord, + fn generate_trace>( + &self, input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let fri_fold_chip = FriFoldChip::<3>::default(); + let fri_fold_chip = FriFoldChip::::default(); let poseidon2 = Poseidon2Chip::default(); - let fri_fold_trace = fri_fold_chip.generate_trace(input, output); - let mut poseidon2_trace = poseidon2.generate_trace(input, output); + + fn to_fri<'c, F: PrimeField32, const DEGREE: usize>( + evs: as WithEvents<'c>>::Events, + _v: &'c (), + ) -> as WithEvents<'c>>::Events { + evs.0 + } + + fn to_poseidon<'c, F: PrimeField32, const DEGREE: usize>( + evs: as WithEvents<'c>>::Events, + _v: &'c (), + ) -> as WithEvents<'c>>::Events { + evs.1 + } + + let fri_fold_trace = fri_fold_chip.generate_trace(&Proj::new(input, to_fri::), output); + let mut poseidon2_trace = poseidon2.generate_trace(&Proj::new(input, to_poseidon::), output); let mut rows = fri_fold_trace .clone() @@ -85,15 +104,15 @@ impl MachineAir for MultiChip { let fri_fold_cols = *cols.fri_fold(); cols.fri_fold_receive_table = - FriFoldChip::<3>::do_receive_table(&fri_fold_cols); + FriFoldChip::::do_receive_table(&fri_fold_cols); cols.fri_fold_memory_access = - FriFoldChip::<3>::do_memory_access(&fri_fold_cols); + FriFoldChip::::do_memory_access(&fri_fold_cols); } else { cols.is_poseidon2 = F::one(); let poseidon2_cols = *cols.poseidon2(); - cols.poseidon2_receive_table = Poseidon2Chip::do_receive_table(&poseidon2_cols); - cols.poseidon2_memory_access = Poseidon2Chip::do_memory_access(&poseidon2_cols); + cols.poseidon2_receive_table = Poseidon2Chip::::do_receive_table(&poseidon2_cols); + cols.poseidon2_memory_access = Poseidon2Chip::::do_memory_access(&poseidon2_cols); } row }) @@ -115,7 +134,7 @@ impl MachineAir for MultiChip { } } -impl Air for MultiChip +impl Air for MultiChip where AB: SphinxRecursionAirBuilder, { @@ -157,15 +176,15 @@ where let fri_columns_local = local.fri_fold(); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_memory_access::(fri_columns_local), + local.is_fri_fold * FriFoldChip::::do_memory_access::(fri_columns_local), local.fri_fold_memory_access, ); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::<3>::do_receive_table::(fri_columns_local), + local.is_fri_fold * FriFoldChip::::do_receive_table::(fri_columns_local), local.fri_fold_receive_table, ); - let fri_fold_chip = FriFoldChip::<3>::default(); + let fri_fold_chip = FriFoldChip::::default(); fri_fold_chip.eval_fri_fold( &mut sub_builder, local.fri_fold(), @@ -182,16 +201,16 @@ where let poseidon2_columns = local.poseidon2(); sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::do_receive_table::(poseidon2_columns), + local.is_poseidon2 * Poseidon2Chip::::do_receive_table::(poseidon2_columns), local.poseidon2_receive_table, ); sub_builder.assert_eq( local.is_poseidon2 - * Poseidon2Chip::do_memory_access::(poseidon2_columns), + * Poseidon2Chip::::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); - let poseidon2_chip = Poseidon2Chip::default(); + let poseidon2_chip = Poseidon2Chip::::default(); poseidon2_chip.eval_poseidon2( &mut sub_builder, local.poseidon2(), @@ -215,6 +234,7 @@ impl MultiCols { #[cfg(test)] mod tests { use itertools::Itertools; + use std::marker::PhantomData; use std::time::Instant; use p3_baby_bear::BabyBear; @@ -239,8 +259,9 @@ mod tests { let config = BabyBearPoseidon2::compressed(); let mut challenger = config.challenger(); - let chip = MultiChip::<5> { + let chip = MultiChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let test_inputs = (0..16) diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index 98812828e..6e2abf508 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -1,8 +1,9 @@ use core::borrow::Borrow; use core::mem::size_of; +use std::marker::PhantomData; use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; -use p3_field::AbstractField; +use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; use sphinx_core::air::{BaseAirBuilder, ExtensionAirBuilder}; use sphinx_primitives::RC_16_30_U32; @@ -23,17 +24,18 @@ pub(crate) const WIDTH: usize = 16; /// A chip that implements addition for the opcode ADD. #[derive(Default)] -pub struct Poseidon2Chip { +pub struct Poseidon2Chip { pub fixed_log2_rows: Option, + _phantom: PhantomData, } -impl BaseAir for Poseidon2Chip { +impl BaseAir for Poseidon2Chip { fn width(&self) -> usize { NUM_POSEIDON2_COLS } } -impl Poseidon2Chip { +impl Poseidon2Chip { pub fn eval_poseidon2( &self, builder: &mut AB, @@ -318,7 +320,7 @@ impl Poseidon2Chip { } } -impl Air for Poseidon2Chip +impl Air for Poseidon2Chip where AB: BaseAirBuilder, { @@ -343,6 +345,7 @@ where mod tests { use itertools::Itertools; use std::borrow::Borrow; + use std::marker::PhantomData; use std::time::Instant; use zkhash::ark_ff::UniformRand; @@ -372,6 +375,7 @@ mod tests { fn generate_trace() { let chip = Poseidon2Chip { fixed_log2_rows: None, + _phantom: PhantomData, }; let rng = &mut rand::thread_rng(); @@ -421,6 +425,7 @@ mod tests { let chip = Poseidon2Chip { fixed_log2_rows: None, + _phantom: PhantomData, }; let trace: RowMajorMatrix = chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index 8030f6bbc..6ecc94fcb 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -1,8 +1,8 @@ use std::borrow::BorrowMut; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use sphinx_core::{air::MachineAir, utils::pad_rows_fixed}; +use sphinx_core::{air::{EventLens, MachineAir, WithEvents}, utils::pad_rows_fixed}; use sphinx_primitives::RC_16_30_U32; use tracing::instrument; @@ -13,10 +13,14 @@ use crate::{ use super::{ external::{NUM_POSEIDON2_COLS, WIDTH}, - Poseidon2Chip, Poseidon2Cols, + Poseidon2Chip, Poseidon2Cols, Poseidon2Event, }; -impl MachineAir for Poseidon2Chip { +impl<'a, F: Field> WithEvents<'a> for Poseidon2Chip { + type Events = &'a [Poseidon2Event]; +} + +impl MachineAir for Poseidon2Chip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -25,15 +29,13 @@ impl MachineAir for Poseidon2Chip { "Poseidon2".to_string() } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _: &mut ExecutionRecord, + #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( + &self, input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); @@ -44,7 +46,7 @@ impl MachineAir for Poseidon2Chip { let rounds_p_beginning = 2 + rounds_f / 2; let p_end = rounds_p_beginning + rounds_p; - for poseidon2_event in input.poseidon2_events.iter() { + for poseidon2_event in input.events().iter() { let mut round_input = Default::default(); for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs index cbcedb28f..694b8b226 100644 --- a/recursion/core/src/poseidon2_wide/external.rs +++ b/recursion/core/src/poseidon2_wide/external.rs @@ -1,14 +1,16 @@ +use crate::poseidon2::Poseidon2Event; use crate::poseidon2_wide::columns::{ Poseidon2ColType, Poseidon2ColTypeMut, Poseidon2Cols, Poseidon2SBoxCols, NUM_POSEIDON2_COLS, NUM_POSEIDON2_SBOX_COLS, }; use crate::runtime::Opcode; use core::borrow::Borrow; +use std::marker::PhantomData; use p3_air::{Air, BaseAir}; -use p3_field::{AbstractField, PrimeField32}; +use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, MachineAir}; +use sphinx_core::air::{BaseAirBuilder, EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_primitives::RC_16_30_U32; use std::borrow::BorrowMut; @@ -31,11 +33,17 @@ pub const NUM_ROUNDS: usize = NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS; /// A chip that implements addition for the opcode ADD. #[derive(Default)] -pub struct Poseidon2WideChip { +pub struct Poseidon2WideChip { pub fixed_log2_rows: Option, + pub _phantom: PhantomData, } -impl MachineAir for Poseidon2WideChip { +impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for Poseidon2WideChip { + type Events = &'a [Poseidon2Event]; +} + +impl MachineAir for Poseidon2WideChip +{ type Record = ExecutionRecord; type Program = RecursionProgram; @@ -44,15 +52,13 @@ impl MachineAir for Poseidon2WideChip>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.poseidon2_events.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _: &mut ExecutionRecord, + #[instrument(name = "generate poseidon2 wide trace", level = "debug", skip_all, fields(rows = input.events().len()))] + fn generate_trace>( + &self, input: &EL, _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); @@ -60,7 +66,7 @@ impl MachineAir for Poseidon2WideChip>::width(self); - for event in &input.poseidon2_events { + for event in input.events() { let mut row = Vec::new(); row.resize(num_columns, F::zero()); @@ -333,7 +339,7 @@ fn eval_internal_rounds( } } -impl BaseAir for Poseidon2WideChip { +impl BaseAir for Poseidon2WideChip { fn width(&self) -> usize { match DEGREE { d if d < 7 => NUM_POSEIDON2_SBOX_COLS, @@ -381,7 +387,7 @@ fn eval_mem(builder: &mut AB, local: &Poseidon2Me ); } -impl Air for Poseidon2WideChip +impl Air for Poseidon2WideChip where AB: SphinxRecursionAirBuilder, { @@ -445,6 +451,7 @@ where #[cfg(test)] mod tests { + use std::marker::PhantomData; use std::time::Instant; use crate::poseidon2::Poseidon2Event; @@ -463,8 +470,9 @@ mod tests { /// A test generating a trace for a single permutation that checks that the output is correct fn generate_trace_degree() { - let chip = Poseidon2WideChip:: { + let chip = Poseidon2WideChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let test_inputs = vec![ @@ -509,8 +517,9 @@ mod tests { inputs: Vec<[BabyBear; 16]>, outputs: Vec<[BabyBear; 16]>, ) { - let chip = Poseidon2WideChip:: { + let chip = Poseidon2WideChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }; let mut input_exec = ExecutionRecord::::default(); for (input, output) in inputs.into_iter().zip_eq(outputs) { diff --git a/recursion/core/src/program/mod.rs b/recursion/core/src/program/mod.rs index 4bfb43d2f..9376a61a2 100644 --- a/recursion/core/src/program/mod.rs +++ b/recursion/core/src/program/mod.rs @@ -1,11 +1,13 @@ use crate::air::SphinxRecursionAirBuilder; +use crate::cpu::{CpuEvent, Instruction}; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use std::marker::PhantomData; use p3_air::{Air, BaseAir, PairBuilder}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::MachineAir; +use sphinx_core::air::{EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use std::collections::HashMap; use tracing::instrument; @@ -38,15 +40,24 @@ pub struct ProgramMultiplicityCols { /// A chip that implements addition for the opcodes ADD and ADDI. #[derive(Default)] -pub struct ProgramChip; +pub struct ProgramChip(pub PhantomData); -impl ProgramChip { +impl ProgramChip { pub fn new() -> Self { - Self {} + Self (PhantomData) } } -impl MachineAir for ProgramChip { +impl<'a, F: Field> WithEvents<'a> for ProgramChip { + type Events = ( + // program.instructions + &'a [Instruction], + // cpu_events + &'a [CpuEvent], + ); +} + +impl MachineAir for ProgramChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -92,20 +103,18 @@ impl MachineAir for ProgramChip { )) } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - #[instrument(name = "generate program trace", level = "debug", skip_all, fields(rows = input.program.instructions.len()))] - fn generate_trace( - &self, - input: &ExecutionRecord, - _output: &mut ExecutionRecord, + #[instrument(name = "generate program trace", level = "debug", skip_all, fields(rows = input.events().0.len()))] + fn generate_trace>( + &self, input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); - input.cpu_events.iter().for_each(|event| { + input.events().1.iter().for_each(|event| { let pc = event.pc; instruction_counts .entry(pc.as_canonical_u32()) @@ -115,9 +124,9 @@ impl MachineAir for ProgramChip { let max_program_size = match std::env::var("MAX_RECURSION_PROGRAM_SIZE") { Ok(value) => value.parse().unwrap(), - Err(_) => std::cmp::min(1048576, input.program.instructions.len()), + Err(_) => std::cmp::min(1048576, input.events().0.len()), }; - let mut rows = input.program.instructions[0..max_program_size] + let mut rows = input.events().0[0..max_program_size] .iter() .enumerate() .map(|(i, _)| { @@ -145,13 +154,13 @@ impl MachineAir for ProgramChip { } } -impl BaseAir for ProgramChip { +impl BaseAir for ProgramChip { fn width(&self) -> usize { NUM_PROGRAM_MULT_COLS } } -impl Air for ProgramChip +impl Air for ProgramChip where AB: SphinxRecursionAirBuilder + PairBuilder, { diff --git a/recursion/core/src/range_check/trace.rs b/recursion/core/src/range_check/trace.rs index c5ee4c715..3044498a7 100644 --- a/recursion/core/src/range_check/trace.rs +++ b/recursion/core/src/range_check/trace.rs @@ -1,17 +1,22 @@ -use std::borrow::BorrowMut; +use std::{borrow::BorrowMut, collections::BTreeMap}; -use p3_field::PrimeField32; +use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use sphinx_core::air::MachineAir; +use sphinx_core::air::{EventLens, MachineAir, WithEvents}; use super::{ columns::{RangeCheckMultCols, NUM_RANGE_CHECK_MULT_COLS, NUM_RANGE_CHECK_PREPROCESSED_COLS}, - RangeCheckChip, + RangeCheckChip, RangeCheckEvent, }; use crate::runtime::{ExecutionRecord, RecursionProgram}; pub const NUM_ROWS: usize = 1 << 16; +impl<'a, F: Field> WithEvents<'a> for RangeCheckChip { + type Events = &'a BTreeMap; +} + + impl MachineAir for RangeCheckChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -30,14 +35,12 @@ impl MachineAir for RangeCheckChip { Some(trace) } - fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + fn generate_dependencies>(&self, _: &EL, _: &mut Self::Record) { // This is a no-op. } - fn generate_trace( - &self, - input: &ExecutionRecord, - _output: &mut ExecutionRecord, + fn generate_trace>( + &self, input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let (_, event_map) = Self::trace_and_map(); @@ -46,7 +49,7 @@ impl MachineAir for RangeCheckChip { NUM_RANGE_CHECK_MULT_COLS, ); - for (lookup, mult) in input.range_check_events.iter() { + for (lookup, mult) in input.events().iter() { let (row, index) = event_map[lookup]; let cols: &mut RangeCheckMultCols = trace.row_mut(row).borrow_mut(); diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 7509a4073..09dae9cf0 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -2,14 +2,19 @@ use std::collections::{BTreeMap, HashMap}; use std::sync::Arc; use p3_field::{AbstractField, PrimeField32}; -use sphinx_core::stark::{MachineRecord, PROOF_MAX_NUM_PVS}; +use sphinx_core::air::EventLens; +use sphinx_core::stark::{Indexable, MachineRecord, PROOF_MAX_NUM_PVS}; use super::RecursionProgram; use crate::air::Block; -use crate::cpu::CpuEvent; -use crate::fri_fold::FriFoldEvent; -use crate::poseidon2::Poseidon2Event; -use crate::range_check::RangeCheckEvent; +use crate::cpu::{CpuChip, CpuEvent}; +use crate::fri_fold::{FriFoldChip, FriFoldEvent}; +use crate::memory::MemoryGlobalChip; +use crate::multi::MultiChip; +use crate::poseidon2::{Poseidon2Chip, Poseidon2Event}; +use crate::poseidon2_wide::Poseidon2WideChip; +use crate::program::ProgramChip; +use crate::range_check::{RangeCheckChip, RangeCheckEvent}; #[derive(Default, Debug, Clone)] pub struct ExecutionRecord { @@ -37,12 +42,14 @@ impl ExecutionRecord { } } -impl MachineRecord for ExecutionRecord { - type Config = (); - +impl Indexable for ExecutionRecord { fn index(&self) -> u32 { 0 } +} + +impl MachineRecord for ExecutionRecord { + type Config = (); fn set_index(&mut self, _: u32) {} @@ -92,3 +99,51 @@ impl MachineRecord for ExecutionRecord { ret } } + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + &self.cpu_events + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + &self.fri_fold_events + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + &self.poseidon2_events + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + &self.poseidon2_events + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + (&self.first_memory_record, &self.last_memory_record) + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + (&self.program.instructions, &self.cpu_events) + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + &self.range_check_events + } +} + +impl EventLens> for ExecutionRecord { + fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + (>>::events(self), >>::events(self)) + } +} \ No newline at end of file diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 4ab9b3de4..8857bf290 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -4,7 +4,7 @@ pub mod utils; use p3_field::{extension::BinomiallyExtendable, PrimeField32}; use sphinx_core::stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS}; -use sphinx_derive::MachineAir; +use sphinx_derive::{EventLens, MachineAir, WithEvents}; use crate::runtime::D; use crate::{ @@ -18,20 +18,21 @@ use std::marker::PhantomData; pub type RecursionAirWideDeg3 = RecursionAir; pub type RecursionAirSkinnyDeg7 = RecursionAir; -#[derive(MachineAir)] +#[derive(WithEvents, EventLens, MachineAir)] #[sphinx_core_path = "sphinx_core"] #[execution_record_path = "crate::runtime::ExecutionRecord"] #[program_path = "crate::runtime::RecursionProgram"] #[builder_path = "crate::air::SphinxRecursionAirBuilder"] +#[record_type = "crate::runtime::ExecutionRecord"] pub enum RecursionAir, const DEGREE: usize> { - Program(ProgramChip), + Program(ProgramChip), Cpu(CpuChip), - MemoryGlobal(MemoryGlobalChip), - Poseidon2Wide(Poseidon2WideChip), - Poseidon2Skinny(Poseidon2Chip), - FriFold(FriFoldChip), + MemoryGlobal(MemoryGlobalChip), + Poseidon2Wide(Poseidon2WideChip), + Poseidon2Skinny(Poseidon2Chip), + FriFold(FriFoldChip), RangeCheck(RangeCheckChip), - Multi(MultiChip), + Multi(MultiChip), } impl, const DEGREE: usize> RecursionAir { @@ -54,37 +55,44 @@ impl, const DEGREE: usize> RecursionAi } pub fn get_all() -> Vec { - once(RecursionAir::Program(ProgramChip)) + once(RecursionAir::Program(ProgramChip(PhantomData))) .chain(once(RecursionAir::Cpu(CpuChip { fixed_log2_rows: None, _phantom: PhantomData, }))) .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: None, + _phantom: PhantomData, }))) .chain(once(RecursionAir::Poseidon2Wide(Poseidon2WideChip::< - DEGREE, + F, DEGREE, > { fixed_log2_rows: None, + _phantom: PhantomData, }))) - .chain(once(RecursionAir::FriFold(FriFoldChip:: { + .chain(once(RecursionAir::FriFold(FriFoldChip:: { fixed_log2_rows: None, + _phantom: PhantomData, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() } pub fn get_wrap_all() -> Vec { - once(RecursionAir::Program(ProgramChip)) + once(RecursionAir::Program(ProgramChip ( + PhantomData, + ))) .chain(once(RecursionAir::Cpu(CpuChip { fixed_log2_rows: Some(20), _phantom: PhantomData, }))) .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: Some(20), + _phantom: PhantomData, }))) .chain(once(RecursionAir::Multi(MultiChip { fixed_log2_rows: Some(20), + _phantom: PhantomData, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) .collect() From 0a7bb9fa86bc4d8377c69e6a93682571621c8f02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Fri, 31 May 2024 13:20:25 -0400 Subject: [PATCH 3/4] chore: clippy --- core/src/air/machine.rs | 6 +- core/src/alu/mul/mod.rs | 2 +- core/src/alu/sll/mod.rs | 2 +- core/src/alu/sr/mod.rs | 2 +- core/src/lookup/debug.rs | 9 +- .../field/extensions/quadratic/mod.rs | 2 +- .../field/extensions/quadratic/sqrt.rs | 2 +- core/src/operations/field/field_den.rs | 2 +- .../operations/field/field_inner_product.rs | 2 +- core/src/operations/field/field_op.rs | 2 +- core/src/operations/field/field_sqrt.rs | 2 +- core/src/program/mod.rs | 4 +- core/src/runtime/record.rs | 82 +++++++++++-------- core/src/stark/air.rs | 2 +- core/src/stark/machine.rs | 3 +- .../src/syscall/precompiles/edwards/ed_add.rs | 5 +- .../precompiles/edwards/ed_decompress.rs | 5 +- core/src/syscall/precompiles/field/add.rs | 5 +- core/src/syscall/precompiles/field/mul.rs | 5 +- core/src/syscall/precompiles/field/sub.rs | 5 +- .../src/syscall/precompiles/quad_field/add.rs | 5 +- .../src/syscall/precompiles/quad_field/mul.rs | 5 +- .../src/syscall/precompiles/quad_field/sub.rs | 5 +- .../weierstrass/weierstrass_add.rs | 3 +- .../weierstrass/weierstrass_double.rs | 3 +- derive/src/lib.rs | 4 +- examples/is-prime/program/Cargo.lock | 62 +++++++------- recursion/core/src/cpu/trace.rs | 5 +- recursion/core/src/fri_fold/mod.rs | 12 ++- recursion/core/src/memory/air.rs | 15 ++-- recursion/core/src/multi/mod.rs | 30 +++++-- recursion/core/src/poseidon2/external.rs | 2 +- recursion/core/src/poseidon2/trace.rs | 9 +- recursion/core/src/poseidon2_wide/external.rs | 9 +- recursion/core/src/program/mod.rs | 8 +- recursion/core/src/range_check/trace.rs | 5 +- recursion/core/src/runtime/record.rs | 31 ++++--- recursion/core/src/stark/mod.rs | 7 +- 38 files changed, 212 insertions(+), 157 deletions(-) diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index 939c63125..9f023a30d 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -23,11 +23,11 @@ pub trait WithEvents<'a>: Sized { /// /// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 ) pub trait EventLens WithEvents<'b>>: Indexable { - fn events<'a>(&'a self) -> >::Events; + fn events(&self) -> >::Events; } -//////////////// Derive macro shaneanigans //////////////////////////////////////////////// -// This is *only* useful for the derive macros, you should *not* use this directly. +//////////////// Derive macro shenanigans //////////////////////////////////////////////// +// The following is *only* useful for the derive macros, you should *not* use this directly. // /// Hereafter, Lens composition explained pedantically: all this is saying is that /// if I have an EventLens to T::Events, and a way (F) to deduce U::Events from that, diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index 00ef5b946..8681e5a4d 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -141,7 +141,7 @@ impl MachineAir for MulChip { input: &EL, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let mul_events = input.events().clone(); + let mul_events = input.events(); // Compute the chunk size based on the number of events and the number of CPUs. let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1); diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index 812810f35..39e3ad9f9 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -118,7 +118,7 @@ impl MachineAir for ShiftLeft { ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_LEFT_COLS]> = vec![]; - let shift_left_events = input.events().clone(); + let shift_left_events = input.events(); for event in shift_left_events.iter() { let mut row = [F::zero(); NUM_SHIFT_LEFT_COLS]; let cols: &mut ShiftLeftCols = row.as_mut_slice().borrow_mut(); diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 331a49049..a9d85e9e9 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -150,7 +150,7 @@ impl MachineAir for ShiftRightChip { ) -> RowMajorMatrix { // Generate the trace rows for each event. let mut rows: Vec<[F; NUM_SHIFT_RIGHT_COLS]> = Vec::new(); - let sr_events = input.events().clone(); + let sr_events = input.events(); for event in sr_events.iter() { assert!(event.opcode == Opcode::SRL || event.opcode == Opcode::SRA); let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS]; diff --git a/core/src/lookup/debug.rs b/core/src/lookup/debug.rs index 937f889c2..8ed13d7d1 100644 --- a/core/src/lookup/debug.rs +++ b/core/src/lookup/debug.rs @@ -6,9 +6,7 @@ use p3_matrix::Matrix; use super::InteractionKind; use crate::air::MachineAir; -use crate::stark::{ - MachineChip, StarkGenericConfig, StarkMachine, StarkProvingKey, Val, -}; +use crate::stark::{MachineChip, StarkGenericConfig, StarkMachine, StarkProvingKey, Val}; #[derive(Debug)] pub struct InteractionData { @@ -46,10 +44,7 @@ fn field_to_int(x: F) -> i32 { } } -pub fn debug_interactions< - SC: StarkGenericConfig, - A: MachineAir>, ->( +pub fn debug_interactions>>( chip: &MachineChip, pkey: &StarkProvingKey, record: &A::Record, diff --git a/core/src/operations/field/extensions/quadratic/mod.rs b/core/src/operations/field/extensions/quadratic/mod.rs index 8dd54cd25..b91674b0f 100644 --- a/core/src/operations/field/extensions/quadratic/mod.rs +++ b/core/src/operations/field/extensions/quadratic/mod.rs @@ -416,7 +416,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/operations/field/extensions/quadratic/sqrt.rs b/core/src/operations/field/extensions/quadratic/sqrt.rs index 13071eb53..cfcd6c07b 100644 --- a/core/src/operations/field/extensions/quadratic/sqrt.rs +++ b/core/src/operations/field/extensions/quadratic/sqrt.rs @@ -155,7 +155,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/operations/field/field_den.rs b/core/src/operations/field/field_den.rs index 5ce71a573..ff96412f1 100644 --- a/core/src/operations/field/field_den.rs +++ b/core/src/operations/field/field_den.rs @@ -202,7 +202,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/operations/field/field_inner_product.rs b/core/src/operations/field/field_inner_product.rs index 45a12ef30..0d1ca42a4 100644 --- a/core/src/operations/field/field_inner_product.rs +++ b/core/src/operations/field/field_inner_product.rs @@ -190,7 +190,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index 41ac4209a..654a8d10e 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -283,7 +283,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index 324307246..cfd77c19e 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -156,7 +156,7 @@ mod tests { } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &() } } diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs index 0caad196b..623bb2dde 100644 --- a/core/src/program/mod.rs +++ b/core/src/program/mod.rs @@ -123,13 +123,13 @@ impl MachineAir for ProgramChip { // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); - cpu_events.iter().for_each(|event| { + for event in cpu_events.iter() { let pc = event.pc; instruction_counts .entry(pc) .and_modify(|count| *count += 1) .or_insert(1); - }); + } let rows = program .instructions diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 66d308e79..9b3ddb593 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -146,205 +146,219 @@ pub struct ExecutionRecord { // Event lenses connect the record to the events relative to a particular chip impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { (&self.add_events, &self.sub_events) } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.bitwise_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.divrem_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.lt_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.mul_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.shift_left_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.shift_right_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.byte_lookups } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.cpu_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { (&self.memory_initialize_events, &self.memory_finalize_events) } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.program.memory_image } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { (&self.cpu_events, &self.program) } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.sha_extend_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.sha_compress_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.blake3_compress_inner_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.keccak_permute_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.bls12381_g1_decompress_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.secp256k1_decompress_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.bls12381_g2_add_events } } impl EventLens for ExecutionRecord { - fn events(&self) -> ::Events { + fn events(&self) -> >::Events { &self.bls12381_g2_double_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp_sub_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp_mul_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp2_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp2_sub_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_fp2_mul_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.secp256k1_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_g1_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bn254_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.secp256k1_double_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.bls12381_g1_double_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.bn254_double_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { &self.ed_add_events } } impl EventLens> for ExecutionRecord { - fn events(&self) -> as crate::air::WithEvents>::Events { + fn events( + &self, + ) -> as crate::air::WithEvents<'_>>::Events { &self.ed_decompress_events } } diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index d4f0b55fb..e536a9631 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -11,8 +11,8 @@ use crate::utils::ec::weierstrass::bls12_381::{Bls12381BaseField, Bls12381Parame use crate::StarkGenericConfig; use p3_field::PrimeField32; pub use riscv_chips::*; -use tracing::instrument; use sphinx_derive::{EventLens, WithEvents}; +use tracing::instrument; /// A module for importing all the different RISC-V chips. pub(crate) mod riscv_chips { diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index f265266d5..ec485d442 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -105,8 +105,7 @@ impl Debug for StarkVerifyingKey { } } -impl>> StarkMachine -{ +impl>> StarkMachine { /// Get an array containing a `ChipRef` for all the chips of this RISC-V STARK machine. pub fn chips(&self) -> &[MachineChip] { &self.chips diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index 9d4739fc3..e8cf48cf3 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -147,8 +147,9 @@ impl<'a, E: EllipticCurve + EdwardsParameters> WithEvents<'a> for EdAddAssignChi type Events = &'a [ECAddEvent]; } -impl MachineAir for EdAddAssignChip - where ExecutionRecord: EventLens>, +impl MachineAir for EdAddAssignChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index cc7d75a9e..5d6054226 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -337,8 +337,9 @@ impl<'a, E: EdwardsParameters> WithEvents<'a> for EdDecompressChip { type Events = &'a [EdDecompressEvent]; } -impl MachineAir for EdDecompressChip - where ExecutionRecord: EventLens> +impl MachineAir for EdDecompressChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; diff --git a/core/src/syscall/precompiles/field/add.rs b/core/src/syscall/precompiles/field/add.rs index 3e8d09d04..542e02739 100644 --- a/core/src/syscall/precompiles/field/add.rs +++ b/core/src/syscall/precompiles/field/add.rs @@ -117,8 +117,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for FieldAddChip { type Events = &'a [FieldAddEvent]; } -impl MachineAir for FieldAddChip - where ExecutionRecord: EventLens> +impl MachineAir for FieldAddChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/field/mul.rs b/core/src/syscall/precompiles/field/mul.rs index 85912f3be..2d1fd8d48 100644 --- a/core/src/syscall/precompiles/field/mul.rs +++ b/core/src/syscall/precompiles/field/mul.rs @@ -117,8 +117,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for FieldMulChip { type Events = &'a [FieldMulEvent]; } -impl MachineAir for FieldMulChip -where ExecutionRecord: EventLens> +impl MachineAir for FieldMulChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/field/sub.rs b/core/src/syscall/precompiles/field/sub.rs index d4c804afc..65e772d0b 100644 --- a/core/src/syscall/precompiles/field/sub.rs +++ b/core/src/syscall/precompiles/field/sub.rs @@ -117,8 +117,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for FieldSubChip { type Events = &'a [FieldSubEvent]; } -impl MachineAir for FieldSubChip - where ExecutionRecord: EventLens> +impl MachineAir for FieldSubChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/quad_field/add.rs b/core/src/syscall/precompiles/quad_field/add.rs index d909f045f..c96547783 100644 --- a/core/src/syscall/precompiles/quad_field/add.rs +++ b/core/src/syscall/precompiles/quad_field/add.rs @@ -143,8 +143,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldAddChip { type Events = &'a [QuadFieldAddEvent]; } -impl MachineAir for QuadFieldAddChip - where ExecutionRecord: EventLens> +impl MachineAir for QuadFieldAddChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/quad_field/mul.rs b/core/src/syscall/precompiles/quad_field/mul.rs index b4e683406..265ba1dd0 100644 --- a/core/src/syscall/precompiles/quad_field/mul.rs +++ b/core/src/syscall/precompiles/quad_field/mul.rs @@ -152,8 +152,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldMulChip { type Events = &'a [QuadFieldMulEvent]; } -impl MachineAir for QuadFieldMulChip - where ExecutionRecord: EventLens> +impl MachineAir for QuadFieldMulChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/quad_field/sub.rs b/core/src/syscall/precompiles/quad_field/sub.rs index 05bbd4943..ae8780a3f 100644 --- a/core/src/syscall/precompiles/quad_field/sub.rs +++ b/core/src/syscall/precompiles/quad_field/sub.rs @@ -143,8 +143,9 @@ impl<'a, FP: FieldParameters> WithEvents<'a> for QuadFieldSubChip { type Events = &'a [QuadFieldSubEvent]; } -impl MachineAir for QuadFieldSubChip - where ExecutionRecord: EventLens> +impl MachineAir for QuadFieldSubChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index f17e07f56..d527d6e8e 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -155,7 +155,8 @@ impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> for Weierstras impl MachineAir for WeierstrassAddAssignChip - where ExecutionRecord: EventLens>, +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index 4b7098843..eff920ac7 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -181,7 +181,8 @@ impl<'a, E: EllipticCurve + WeierstrassParameters> WithEvents<'a> impl MachineAir for WeierstrassDoubleAssignChip - where ExecutionRecord: EventLens>, +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = Program; diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 37bcf8c8d..c1bff2797 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -230,7 +230,7 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let name = &ast.ident; let generics = &ast.generics; let type_params = generics.type_params(); - let ty_params = quote!{ #(#type_params),* }; + let ty_params = quote! { #(#type_params),* }; let const_params = generics.const_params(); let co_params = quote! { #(#const_params),* }; @@ -304,7 +304,7 @@ pub fn machine_air_derive(input: TokenStream) -> TokenStream { let generate_trace_arms = variants.iter().enumerate().map(|(i, (variant_name, field))| { let field_ty = &field.ty; - + let idx = syn::Index::from(i); quote! { #name::#variant_name(x) => { diff --git a/examples/is-prime/program/Cargo.lock b/examples/is-prime/program/Cargo.lock index 1974a3c1c..792100ff2 100644 --- a/examples/is-prime/program/Cargo.lock +++ b/examples/is-prime/program/Cargo.lock @@ -234,7 +234,7 @@ dependencies = [ name = "is-prime-program" version = "0.1.0" dependencies = [ - "wp1-zkvm", + "sphinx-zkvm", ] [[package]] @@ -407,6 +407,36 @@ dependencies = [ "rand_core", ] +[[package]] +name = "sphinx-precompiles" +version = "0.1.0" +source = "git+ssh://git@github.com/lurk-lab/sphinx.git#eeea8c7319aa065b198119ddc4aa3acbd921d143" +dependencies = [ + "anyhow", + "bincode", + "bls12_381", + "cfg-if", + "getrandom", + "hybrid-array", + "k256", + "serde", +] + +[[package]] +name = "sphinx-zkvm" +version = "0.1.0" +source = "git+ssh://git@github.com/lurk-lab/sphinx.git#eeea8c7319aa065b198119ddc4aa3acbd921d143" +dependencies = [ + "bincode", + "cfg-if", + "getrandom", + "k256", + "once_cell", + "rand", + "sha2", + "sphinx-precompiles", +] + [[package]] name = "spki" version = "0.7.3" @@ -464,36 +494,6 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" -[[package]] -name = "wp1-precompiles" -version = "0.1.0" -source = "git+ssh://git@github.com/wormhole-foundation/wp1.git#e6718735bba932d7e4820f37bea5028459948e18" -dependencies = [ - "anyhow", - "bincode", - "bls12_381", - "cfg-if", - "getrandom", - "hybrid-array", - "k256", - "serde", -] - -[[package]] -name = "wp1-zkvm" -version = "0.1.0" -source = "git+ssh://git@github.com/wormhole-foundation/wp1.git#e6718735bba932d7e4820f37bea5028459948e18" -dependencies = [ - "bincode", - "cfg-if", - "getrandom", - "k256", - "once_cell", - "rand", - "sha2", - "wp1-precompiles", -] - [[package]] name = "wyz" version = "0.5.1" diff --git a/recursion/core/src/cpu/trace.rs b/recursion/core/src/cpu/trace.rs index 197bdff7b..995969156 100644 --- a/recursion/core/src/cpu/trace.rs +++ b/recursion/core/src/cpu/trace.rs @@ -20,8 +20,9 @@ impl<'a, F: Field> WithEvents<'a> for CpuChip { type Events = &'a [CpuEvent]; } -impl> MachineAir for CpuChip - where ExecutionRecord: EventLens> +impl> MachineAir for CpuChip +where + ExecutionRecord: EventLens>, { type Record = ExecutionRecord; type Program = RecursionProgram; diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index 389280a08..3f163399c 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -4,17 +4,19 @@ use crate::air::RecursionMemoryAirBuilder; use crate::memory::{MemoryReadCols, MemoryReadSingleCols, MemoryReadWriteCols}; use crate::runtime::Opcode; use core::borrow::Borrow; -use std::marker::PhantomData; use itertools::Itertools; use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::{AbstractField, Field}; use p3_field::PrimeField32; +use p3_field::{AbstractField, Field}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; -use sphinx_core::air::{BaseAirBuilder, BinomialExtension, EventLens, ExtensionAirBuilder, MachineAir, WithEvents}; +use sphinx_core::air::{ + BaseAirBuilder, BinomialExtension, EventLens, ExtensionAirBuilder, MachineAir, WithEvents, +}; use sphinx_core::utils::pad_rows_fixed; use sphinx_derive::AlignedBorrow; use std::borrow::BorrowMut; +use std::marker::PhantomData; use tracing::instrument; use crate::air::SphinxRecursionAirBuilder; @@ -110,7 +112,9 @@ impl MachineAir for FriFoldChip>( - &self, input: &EL, _: &mut ExecutionRecord, + &self, + input: &EL, + _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = input .events() diff --git a/recursion/core/src/memory/air.rs b/recursion/core/src/memory/air.rs index 3fdd96fbd..bd39961bf 100644 --- a/recursion/core/src/memory/air.rs +++ b/recursion/core/src/memory/air.rs @@ -1,19 +1,22 @@ use core::mem::size_of; -use std::{borrow::{Borrow, BorrowMut}, marker::PhantomData}; +use std::{ + borrow::{Borrow, BorrowMut}, + marker::PhantomData, +}; use p3_air::{Air, BaseAir}; use p3_field::{Field, PrimeField32}; use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_core::{ - air::{AirInteraction, MachineAir, EventLens, WithEvents, MemoryAirBuilder}, + air::{AirInteraction, EventLens, MachineAir, MemoryAirBuilder, WithEvents}, lookup::InteractionKind, utils::pad_rows_fixed, }; use tracing::instrument; use super::columns::MemoryInitCols; -use crate::{air::Block, memory::MemoryGlobalChip}; use crate::runtime::{ExecutionRecord, RecursionProgram}; +use crate::{air::Block, memory::MemoryGlobalChip}; pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::>(); @@ -30,7 +33,7 @@ impl MemoryGlobalChip { impl<'a, F: Field> WithEvents<'a> for MemoryGlobalChip { type Events = ( // first memory event - &'a [(F, Block)], + &'a [(F, Block)], // last memory event &'a [(F, F, Block)], ); @@ -50,7 +53,9 @@ impl MachineAir for MemoryGlobalChip { #[instrument(name = "generate memory trace", level = "debug", skip_all, fields(first_rows = input.events().0.len(), last_rows = input.events().1.len()))] fn generate_trace>( - &self, input: &EL, _output: &mut ExecutionRecord, + &self, + input: &EL, + _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); let (first_memory_events, last_memory_events) = input.events(); diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 74e7b9f9b..132ba0229 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -51,7 +51,10 @@ impl BaseAir for MultiChip { } impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for MultiChip { - type Events = ( as WithEvents<'a>>::Events, as WithEvents<'a>>::Events); + type Events = ( + as WithEvents<'a>>::Events, + as WithEvents<'a>>::Events, + ); } impl MachineAir for MultiChip { @@ -68,7 +71,9 @@ impl MachineAir for MultiChip>( - &self, input: &EL, output: &mut ExecutionRecord, + &self, + input: &EL, + output: &mut ExecutionRecord, ) -> RowMajorMatrix { let fri_fold_chip = FriFoldChip::::default(); let poseidon2 = Poseidon2Chip::default(); @@ -87,8 +92,10 @@ impl MachineAir for MultiChip), output); - let mut poseidon2_trace = poseidon2.generate_trace(&Proj::new(input, to_poseidon::), output); + let fri_fold_trace = + fri_fold_chip.generate_trace(&Proj::new(input, to_fri::), output); + let mut poseidon2_trace = + poseidon2.generate_trace(&Proj::new(input, to_poseidon::), output); let mut rows = fri_fold_trace .clone() @@ -111,8 +118,10 @@ impl MachineAir for MultiChip::do_receive_table(&poseidon2_cols); - cols.poseidon2_memory_access = Poseidon2Chip::::do_memory_access(&poseidon2_cols); + cols.poseidon2_receive_table = + Poseidon2Chip::::do_receive_table(&poseidon2_cols); + cols.poseidon2_memory_access = + Poseidon2Chip::::do_memory_access(&poseidon2_cols); } row }) @@ -176,11 +185,13 @@ where let fri_columns_local = local.fri_fold(); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::::do_memory_access::(fri_columns_local), + local.is_fri_fold + * FriFoldChip::::do_memory_access::(fri_columns_local), local.fri_fold_memory_access, ); sub_builder.assert_eq( - local.is_fri_fold * FriFoldChip::::do_receive_table::(fri_columns_local), + local.is_fri_fold + * FriFoldChip::::do_receive_table::(fri_columns_local), local.fri_fold_receive_table, ); @@ -201,7 +212,8 @@ where let poseidon2_columns = local.poseidon2(); sub_builder.assert_eq( - local.is_poseidon2 * Poseidon2Chip::::do_receive_table::(poseidon2_columns), + local.is_poseidon2 + * Poseidon2Chip::::do_receive_table::(poseidon2_columns), local.poseidon2_receive_table, ); sub_builder.assert_eq( diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index 6e2abf508..e431c0115 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -1,12 +1,12 @@ use core::borrow::Borrow; use core::mem::size_of; -use std::marker::PhantomData; use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::{AbstractField, Field}; use p3_matrix::Matrix; use sphinx_core::air::{BaseAirBuilder, ExtensionAirBuilder}; use sphinx_primitives::RC_16_30_U32; +use std::marker::PhantomData; use std::ops::Add; use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index 6ecc94fcb..b872be2ae 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -2,7 +2,10 @@ use std::borrow::BorrowMut; use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; -use sphinx_core::{air::{EventLens, MachineAir, WithEvents}, utils::pad_rows_fixed}; +use sphinx_core::{ + air::{EventLens, MachineAir, WithEvents}, + utils::pad_rows_fixed, +}; use sphinx_primitives::RC_16_30_U32; use tracing::instrument; @@ -35,7 +38,9 @@ impl MachineAir for Poseidon2Chip { #[instrument(name = "generate poseidon2 trace", level = "debug", skip_all, fields(rows = input.events().len()))] fn generate_trace>( - &self, input: &EL, _: &mut ExecutionRecord, + &self, + input: &EL, + _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); diff --git a/recursion/core/src/poseidon2_wide/external.rs b/recursion/core/src/poseidon2_wide/external.rs index 694b8b226..d0ffc30b1 100644 --- a/recursion/core/src/poseidon2_wide/external.rs +++ b/recursion/core/src/poseidon2_wide/external.rs @@ -5,7 +5,6 @@ use crate::poseidon2_wide::columns::{ }; use crate::runtime::Opcode; use core::borrow::Borrow; -use std::marker::PhantomData; use p3_air::{Air, BaseAir}; use p3_field::{AbstractField, Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; @@ -14,6 +13,7 @@ use sphinx_core::air::{BaseAirBuilder, EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use sphinx_primitives::RC_16_30_U32; use std::borrow::BorrowMut; +use std::marker::PhantomData; use tracing::instrument; use crate::air::SphinxRecursionAirBuilder; @@ -42,8 +42,7 @@ impl<'a, F: Field, const DEGREE: usize> WithEvents<'a> for Poseidon2WideChip]; } -impl MachineAir for Poseidon2WideChip -{ +impl MachineAir for Poseidon2WideChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -58,7 +57,9 @@ impl MachineAir for Poseidon2WideChip>( - &self, input: &EL, _: &mut ExecutionRecord, + &self, + input: &EL, + _: &mut ExecutionRecord, ) -> RowMajorMatrix { let mut rows = Vec::new(); diff --git a/recursion/core/src/program/mod.rs b/recursion/core/src/program/mod.rs index 9376a61a2..c2ff61158 100644 --- a/recursion/core/src/program/mod.rs +++ b/recursion/core/src/program/mod.rs @@ -2,7 +2,6 @@ use crate::air::SphinxRecursionAirBuilder; use crate::cpu::{CpuEvent, Instruction}; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use std::marker::PhantomData; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::{Field, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; @@ -10,6 +9,7 @@ use p3_matrix::Matrix; use sphinx_core::air::{EventLens, MachineAir, WithEvents}; use sphinx_core::utils::pad_rows_fixed; use std::collections::HashMap; +use std::marker::PhantomData; use tracing::instrument; use sphinx_derive::AlignedBorrow; @@ -44,7 +44,7 @@ pub struct ProgramChip(pub PhantomData); impl ProgramChip { pub fn new() -> Self { - Self (PhantomData) + Self(PhantomData) } } @@ -109,7 +109,9 @@ impl MachineAir for ProgramChip { #[instrument(name = "generate program trace", level = "debug", skip_all, fields(rows = input.events().0.len()))] fn generate_trace>( - &self, input: &EL, _output: &mut ExecutionRecord, + &self, + input: &EL, + _output: &mut ExecutionRecord, ) -> RowMajorMatrix { // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. diff --git a/recursion/core/src/range_check/trace.rs b/recursion/core/src/range_check/trace.rs index 3044498a7..21005d212 100644 --- a/recursion/core/src/range_check/trace.rs +++ b/recursion/core/src/range_check/trace.rs @@ -16,7 +16,6 @@ impl<'a, F: Field> WithEvents<'a> for RangeCheckChip { type Events = &'a BTreeMap; } - impl MachineAir for RangeCheckChip { type Record = ExecutionRecord; type Program = RecursionProgram; @@ -40,7 +39,9 @@ impl MachineAir for RangeCheckChip { } fn generate_trace>( - &self, input: &EL, _output: &mut ExecutionRecord, + &self, + input: &EL, + _output: &mut ExecutionRecord, ) -> RowMajorMatrix { let (_, event_map) = Self::trace_and_map(); diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 09dae9cf0..120ccc527 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -101,49 +101,56 @@ impl MachineRecord for ExecutionRecord { } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { &self.cpu_events } } -impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { +impl EventLens> + for ExecutionRecord +{ + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { &self.fri_fold_events } } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { &self.poseidon2_events } } -impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { +impl EventLens> + for ExecutionRecord +{ + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { &self.poseidon2_events } } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { (&self.first_memory_record, &self.last_memory_record) } } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { (&self.program.instructions, &self.cpu_events) } } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { &self.range_check_events } } impl EventLens> for ExecutionRecord { - fn events<'a>(&'a self) -> as sphinx_core::air::WithEvents<'a>>::Events { - (>>::events(self), >>::events(self)) + fn events(&self) -> as sphinx_core::air::WithEvents<'_>>::Events { + ( + >>::events(self), + >>::events(self), + ) } -} \ No newline at end of file +} diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 8857bf290..b658c6a71 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -65,7 +65,8 @@ impl, const DEGREE: usize> RecursionAi _phantom: PhantomData, }))) .chain(once(RecursionAir::Poseidon2Wide(Poseidon2WideChip::< - F, DEGREE, + F, + DEGREE, > { fixed_log2_rows: None, _phantom: PhantomData, @@ -79,9 +80,7 @@ impl, const DEGREE: usize> RecursionAi } pub fn get_wrap_all() -> Vec { - once(RecursionAir::Program(ProgramChip ( - PhantomData, - ))) + once(RecursionAir::Program(ProgramChip(PhantomData))) .chain(once(RecursionAir::Cpu(CpuChip { fixed_log2_rows: Some(20), _phantom: PhantomData, From 434aec77f65bc4f1118a13a51d2f2c38713060e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Tue, 4 Jun 2024 14:09:29 -0400 Subject: [PATCH 4/4] chore: rename Indexable -> Indexed --- core/src/air/machine.rs | 8 ++++---- core/src/runtime/record.rs | 4 ++-- core/src/stark/machine.rs | 2 +- core/src/stark/prover.rs | 2 +- core/src/stark/record.rs | 4 ++-- core/src/utils/prove.rs | 2 +- recursion/core/src/runtime/record.rs | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index 9f023a30d..da71e27db 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -7,7 +7,7 @@ pub use sphinx_derive::MachineAir; use crate::{ runtime::Program, - stark::{Indexable, MachineRecord}, + stark::{Indexed, MachineRecord}, }; /// A description of the events related to this AIR. @@ -22,7 +22,7 @@ pub trait WithEvents<'a>: Sized { /// Chip, as specified by its `WithEvents` trait implementation. /// /// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 ) -pub trait EventLens WithEvents<'b>>: Indexable { +pub trait EventLens WithEvents<'b>>: Indexed { fn events(&self) -> >::Events; } @@ -71,10 +71,10 @@ where } } -impl<'a, T, R, F> Indexable for Proj<'a, T, R, F> +impl<'a, T, R, F> Indexed for Proj<'a, T, R, F> where T: for<'b> WithEvents<'b>, - R: EventLens + Indexable, + R: EventLens + Indexed, { fn index(&self) -> u32 { self.record.index() diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 9b3ddb593..180b6e22f 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -56,7 +56,7 @@ use crate::{ }, utils::ec::weierstrass::bls12_381::Bls12381BaseField, }; -use crate::{bytes::event::ByteRecord, stark::Indexable}; +use crate::{bytes::event::ByteRecord, stark::Indexed}; /// A record of the execution of a program. Contains event data for everything that happened during /// the execution of the shard. @@ -426,7 +426,7 @@ impl Default for ShardingConfig { } } -impl Indexable for ExecutionRecord { +impl Indexed for ExecutionRecord { fn index(&self) -> u32 { self.index } diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index ec485d442..917221679 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -25,7 +25,7 @@ use crate::air::MachineProgram; use crate::lookup::debug_interactions_with_all_chips; use crate::lookup::InteractionBuilder; use crate::lookup::InteractionKind; -use crate::stark::record::{Indexable, MachineRecord}; +use crate::stark::record::{Indexed, MachineRecord}; use crate::stark::DebugConstraintBuilder; use crate::stark::ProverConstraintFolder; use crate::stark::ShardProof; diff --git a/core/src/stark/prover.rs b/core/src/stark/prover.rs index ff131e397..94bef3a86 100644 --- a/core/src/stark/prover.rs +++ b/core/src/stark/prover.rs @@ -26,7 +26,7 @@ use super::{StarkProvingKey, VerifierConstraintFolder}; use crate::air::MachineAir; use crate::lookup::InteractionBuilder; use crate::stark::record::MachineRecord; -use crate::stark::Indexable; +use crate::stark::Indexed; use crate::stark::MachineChip; use crate::stark::PackedChallenge; use crate::stark::ProverConstraintFolder; diff --git a/core/src/stark/record.rs b/core/src/stark/record.rs index c4daa1b31..10e035eaa 100644 --- a/core/src/stark/record.rs +++ b/core/src/stark/record.rs @@ -2,11 +2,11 @@ use std::collections::HashMap; use p3_field::AbstractField; -pub trait Indexable { +pub trait Indexed { fn index(&self) -> u32; } -pub trait MachineRecord: Default + Sized + Send + Sync + Clone + Indexable { +pub trait MachineRecord: Default + Sized + Send + Sync + Clone + Indexed { type Config: Default; fn set_index(&mut self, index: u32); diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index 3ad430a7d..d0a526adf 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -24,7 +24,7 @@ use crate::stark::StarkVerifyingKey; use crate::stark::Val; use crate::stark::VerifierConstraintFolder; use crate::stark::{Com, PcsProverData, RiscvAir, ShardProof, StarkProvingKey, UniConfig}; -use crate::stark::{Indexable, MachineRecord, StarkMachine}; +use crate::stark::{Indexed, MachineRecord, StarkMachine}; use crate::utils::env; use crate::{ runtime::{Program, Runtime}, diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 120ccc527..88149bbac 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use p3_field::{AbstractField, PrimeField32}; use sphinx_core::air::EventLens; -use sphinx_core::stark::{Indexable, MachineRecord, PROOF_MAX_NUM_PVS}; +use sphinx_core::stark::{Indexed, MachineRecord, PROOF_MAX_NUM_PVS}; use super::RecursionProgram; use crate::air::Block; @@ -42,7 +42,7 @@ impl ExecutionRecord { } } -impl Indexable for ExecutionRecord { +impl Indexed for ExecutionRecord { fn index(&self) -> u32 { 0 }