Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 86 additions & 5 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,91 @@
use std::marker::PhantomData;

use p3_air::BaseAir;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
pub use sphinx_derive::MachineAir;

use crate::{runtime::Program, stark::MachineRecord};
use crate::{
runtime::Program,
stark::{Indexed, MachineRecord},
};

/// A description of the events related to this AIR.
pub trait WithEvents<'a>: Sized {
/// output of a functional lens from the Record to
/// refs of those events relative to the AIR.
type Events: 'a;
}

/// A trait intended for implementation on Records that may store events related to Chips,
/// The purpose of this trait is to provide a way to access the events relative to a specific
/// Chip, as specified by its `WithEvents` trait implementation.
///
/// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 )
pub trait EventLens<T: for<'b> WithEvents<'b>>: Indexed {
fn events(&self) -> <T as WithEvents<'_>>::Events;
}

//////////////// Derive macro shenanigans ////////////////////////////////////////////////
// The following is *only* useful for the derive macros, you should *not* use this directly.
//
/// Hereafter, Lens composition explained pedantically: all this is saying is that
/// if I have an EventLens to T::Events, and a way (F) to deduce U::Events from that,
/// I can compose them to get an EventLens to U::Events.
pub struct Proj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventLens<T>,
{
record: &'a R,
projection: F,
_phantom: PhantomData<T>,
}

/// A constructor for the projection from T::Events to U::Events.
impl<'a, T, R, F> Proj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventLens<T>,
{
pub fn new(record: &'a R, projection: F) -> Self {
Self {
record,
projection,
_phantom: PhantomData,
}
}
}

impl<'a, T, R, U, F> EventLens<U> for Proj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventLens<T>,
U: for<'b> WithEvents<'b>,
// see https://github.com/rust-lang/rust/issues/86702 for the empty parameter
F: for<'c> Fn(<T as WithEvents<'c>>::Events, &'c ()) -> <U as WithEvents<'c>>::Events,
{
fn events<'c>(&'c self) -> <U as WithEvents<'c>>::Events {
let events: <T as WithEvents<'c>>::Events = self.record.events();
(self.projection)(events, &())
}
}

impl<'a, T, R, F> Indexed for Proj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventLens<T> + Indexed,
{
fn index(&self) -> u32 {
self.record.index()
}
}
//////////////// end of shenanigans destined for the derive macros. ////////////////

/// An AIR that is part of a multi table AIR arithmetization.
pub trait MachineAir<F: Field>: BaseAir<F> {
pub trait MachineAir<F: Field>: BaseAir<F> + for<'a> WithEvents<'a> {
/// The execution record containing events for producing the air trace.
type Record: MachineRecord;
type Record: MachineRecord + EventLens<Self>;

type Program: MachineProgram<F>;

Expand All @@ -20,10 +97,14 @@ pub trait MachineAir<F: Field>: BaseAir<F> {
/// - `input` is the execution record containing the events to be written to the trace.
/// - `output` is the execution record containing events that the `MachineAir` can add to
/// the record such as byte lookup requests.
fn generate_trace(&self, input: &Self::Record, output: &mut Self::Record) -> RowMajorMatrix<F>;
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &EL,
output: &mut Self::Record,
) -> RowMajorMatrix<F>;

/// Generate the dependencies for a given execution record.
fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
fn generate_dependencies<EL: EventLens<Self>>(&self, input: &EL, output: &mut Self::Record) {
self.generate_trace(input, output);
}

Expand Down
30 changes: 19 additions & 11 deletions core/src/alu/add_sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use p3_maybe_rayon::prelude::ParallelSlice;
use sphinx_derive::AlignedBorrow;

use crate::{
air::{AluAirBuilder, MachineAir, Word},
air::{AluAirBuilder, EventLens, MachineAir, WithEvents, Word},
operations::AddOperation,
runtime::{ExecutionRecord, Opcode, Program},
stark::MachineRecord,
utils::pad_to_power_of_two,
};

use super::AluEvent;

/// The number of main trace columns for `AddSubChip`.
pub const NUM_ADD_SUB_COLS: usize = size_of::<AddSubCols<u8>>();

Expand Down Expand Up @@ -55,6 +57,15 @@ pub struct AddSubCols<T> {
pub is_sub: T,
}

impl<'a> WithEvents<'a> for AddSubChip {
type Events = (
// add events
&'a [AluEvent],
// sub events
&'a [AluEvent],
);
}

impl<F: PrimeField> MachineAir<F> for AddSubChip {
type Record = ExecutionRecord;

Expand All @@ -64,20 +75,17 @@ impl<F: PrimeField> MachineAir<F> for AddSubChip {
"AddSub".to_string()
}

fn generate_trace(
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &ExecutionRecord,
output: &mut ExecutionRecord,
input: &EL,
output: &mut Self::Record,
) -> RowMajorMatrix<F> {
let (add_events, sub_events) = input.events();
// Generate the rows for the trace.
let chunk_size = std::cmp::max(
(input.add_events.len() + input.sub_events.len()) / num_cpus::get(),
1,
);
let merged_events = input
.add_events
let chunk_size = std::cmp::max((add_events.len() + sub_events.len()) / num_cpus::get(), 1);
let merged_events = add_events
.iter()
.chain(input.sub_events.iter())
.chain(sub_events.iter())
.collect::<Vec<_>>();

let rows_and_records = merged_events
Expand Down
14 changes: 10 additions & 4 deletions core/src/alu/bitwise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sphinx_derive::AlignedBorrow;

use crate::air::Word;
use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir};
use crate::air::{EventLens, WithEvents, Word};
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::utils::pad_to_power_of_two;

use super::AluEvent;

/// The number of main trace columns for `BitwiseChip`.
pub const NUM_BITWISE_COLS: usize = size_of::<BitwiseCols<u8>>();

Expand Down Expand Up @@ -49,6 +51,10 @@ pub struct BitwiseCols<T> {
pub is_and: T,
}

impl<'a> WithEvents<'a> for BitwiseChip {
type Events = &'a [AluEvent];
}

impl<F: PrimeField> MachineAir<F> for BitwiseChip {
type Record = ExecutionRecord;

Expand All @@ -58,14 +64,14 @@ impl<F: PrimeField> MachineAir<F> for BitwiseChip {
"Bitwise".to_string()
}

fn generate_trace(
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &ExecutionRecord,
input: &EL,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let rows = input
.bitwise_events
.events()
.iter()
.map(|event| {
let mut row = [F::zero(); NUM_BITWISE_COLS];
Expand Down
14 changes: 9 additions & 5 deletions core/src/alu/divrem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ use p3_matrix::Matrix;
use sphinx_derive::AlignedBorrow;

use self::utils::eval_abs_value;
use crate::air::Word;
use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder};
use crate::air::{EventLens, WithEvents, Word};
use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation};
use crate::alu::AluEvent;
use crate::bytes::event::ByteRecord;
Expand Down Expand Up @@ -187,6 +187,10 @@ pub struct DivRemCols<T> {
pub is_real: T,
}

impl<'a> WithEvents<'a> for DivRemChip {
type Events = &'a [AluEvent];
}

impl<F: PrimeField> MachineAir<F> for DivRemChip {
type Record = ExecutionRecord;

Expand All @@ -196,13 +200,13 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
"DivRem".to_string()
}

fn generate_trace(
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &ExecutionRecord,
input: &EL,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let divrem_events = &input.divrem_events;
let divrem_events = input.events();
let mut rows: Vec<[F; NUM_DIVREM_COLS]> = Vec::with_capacity(divrem_events.len());
for event in divrem_events {
assert!(
Expand Down Expand Up @@ -405,7 +409,7 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
row
};
debug_assert!(padded_row_template.len() == NUM_DIVREM_COLS);
for i in input.divrem_events.len() * NUM_DIVREM_COLS..trace.values.len() {
for i in input.events().len() * NUM_DIVREM_COLS..trace.values.len() {
trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS];
}

Expand Down
14 changes: 10 additions & 4 deletions core/src/alu/lt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use sphinx_derive::AlignedBorrow;

use crate::air::Word;
use crate::air::{AluAirBuilder, BaseAirBuilder, ByteAirBuilder, MachineAir};
use crate::air::{EventLens, WithEvents, Word};
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::utils::pad_to_power_of_two;

use super::AluEvent;

/// The number of main trace columns for `LtChip`.
pub const NUM_LT_COLS: usize = size_of::<LtCols<u8>>();

Expand Down Expand Up @@ -91,6 +93,10 @@ impl LtCols<u32> {
}
}

impl<'a> WithEvents<'a> for LtChip {
type Events = &'a [AluEvent];
}

impl<F: PrimeField32> MachineAir<F> for LtChip {
type Record = ExecutionRecord;

Expand All @@ -100,14 +106,14 @@ impl<F: PrimeField32> MachineAir<F> for LtChip {
"Lt".to_string()
}

fn generate_trace(
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &ExecutionRecord,
input: &EL,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.
let (rows, new_byte_lookup_events): (Vec<_>, Vec<_>) = input
.lt_events
.events()
.par_iter()
.map(|event| {
let mut row = [F::zero(); NUM_LT_COLS];
Expand Down
14 changes: 10 additions & 4 deletions core/src/alu/mul/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ use p3_maybe_rayon::prelude::ParallelIterator;
use p3_maybe_rayon::prelude::ParallelSlice;
use sphinx_derive::AlignedBorrow;

use crate::air::Word;
use crate::air::{AluAirBuilder, ByteAirBuilder, MachineAir, WordAirBuilder};
use crate::air::{EventLens, WithEvents, Word};
use crate::alu::mul::utils::get_msb;
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
Expand All @@ -54,6 +54,8 @@ use crate::runtime::{ExecutionRecord, Opcode, Program};
use crate::stark::MachineRecord;
use crate::utils::pad_to_power_of_two;

use super::AluEvent;

/// The number of main trace columns for `MulChip`.
pub const NUM_MUL_COLS: usize = size_of::<MulCols<u8>>();

Expand Down Expand Up @@ -121,6 +123,10 @@ pub struct MulCols<T> {
pub is_real: T,
}

impl<'a> WithEvents<'a> for MulChip {
type Events = &'a [AluEvent];
}

impl<F: PrimeField> MachineAir<F> for MulChip {
type Record = ExecutionRecord;

Expand All @@ -130,12 +136,12 @@ impl<F: PrimeField> MachineAir<F> for MulChip {
"Mul".to_string()
}

fn generate_trace(
fn generate_trace<EL: EventLens<Self>>(
&self,
input: &ExecutionRecord,
input: &EL,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let mul_events = input.mul_events.clone();
let mul_events = input.events();
// Compute the chunk size based on the number of events and the number of CPUs.
let chunk_size = std::cmp::max(mul_events.len() / num_cpus::get(), 1);

Expand Down
Loading