Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,7 @@ benchmark.csv
# Build Artifacts
recursion/gnark-ffi/build
prover/build
prover/*.tar.gz
prover/*.tar.gz

# IDE Conf
.idea
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ ff = "0.13"
futures = "0.3.30"
futures-util = "0.3.14"
getrandom = "=0.2.14" # 0.2.15 depends on yanked libc 0.2.154
hashbrown = "0.14.5"
hashbrown = { version = "0.14.5", features = ["serde"] }
hex = "0.4.3"
home = "0.5.9"
hybrid-array = "0.2.0-rc"
Expand Down
1 change: 1 addition & 0 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ bls12_381 = { workspace = true }
cfg-if = { workspace = true }
curve25519-dalek = { workspace = true }
elliptic-curve = { workspace = true }
hashbrown = { workspace = true }
hex = { workspace = true }
hybrid-array = { workspace = true }
k256 = { workspace = true, features = ["expose-field"] }
Expand Down
2 changes: 1 addition & 1 deletion core/benches/fibonacci.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use hashbrown::HashMap;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use sphinx_core::{
Expand Down
4 changes: 2 additions & 2 deletions core/src/alu/divrem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ mod utils;

use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use std::collections::HashMap;
use hashbrown::HashMap;

use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::AbstractField;
Expand Down Expand Up @@ -394,7 +394,7 @@ impl<F: PrimeField> MachineAir<F> for DivRemChip {
}
let mut alu_events = HashMap::new();
alu_events.insert(Opcode::ADD, add_events);
output.add_alu_events(&alu_events);
output.add_alu_events(&mut alu_events);
}

let mut lower_word = 0;
Expand Down
20 changes: 5 additions & 15 deletions core/src/bytes/event.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::collections::BTreeMap;

use p3_field::PrimeField32;
use serde::{Deserialize, Serialize};

use super::ByteOpcode;

/// A byte lookup event.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
Comment on lines -9 to +7
Copy link
Copy Markdown
Contributor Author

@wwared wwared Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not strictly necessary for any of the optimizations, but I removed the `Copy` derive to minimize accidental/implicit copies of byte lookups. Only one place, in `core/src/runtime/record.rs`, required adding a `.clone()` as a result of removing the derive.

pub struct ByteLookupEvent {
/// The shard number, used for byte lookup table.
pub shard: u32,
Expand Down Expand Up @@ -36,9 +34,10 @@ pub trait ByteRecord {
fn add_byte_lookup_event(&mut self, blu_event: ByteLookupEvent);

/// Adds a list of `ByteLookupEvent`s to the record.
#[inline]
fn add_byte_lookup_events(&mut self, blu_events: Vec<ByteLookupEvent>) {
for blu_event in blu_events.iter() {
self.add_byte_lookup_event(*blu_event);
for blu_event in blu_events {
self.add_byte_lookup_event(blu_event);
}
}

Expand Down Expand Up @@ -121,6 +120,7 @@ pub trait ByteRecord {

impl ByteLookupEvent {
/// Creates a new `ByteLookupEvent`.
#[inline(always)]
pub fn new(
shard: u32,
channel: u32,
Expand All @@ -147,13 +147,3 @@ impl ByteRecord for Vec<ByteLookupEvent> {
self.push(blu_event);
}
}

impl ByteRecord for BTreeMap<u32, BTreeMap<ByteLookupEvent, usize>> {
fn add_byte_lookup_event(&mut self, blu_event: ByteLookupEvent) {
*self
.entry(blu_event.shard)
.or_default()
.entry(blu_event)
.or_insert(0) += 1
}
}
39 changes: 22 additions & 17 deletions core/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ pub mod opcode;
pub mod trace;
pub mod utils;

use alloc::collections::BTreeMap;
use core::borrow::BorrowMut;
use std::marker::PhantomData;

Expand Down Expand Up @@ -35,19 +34,10 @@ pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 16;
pub struct ByteChip<F>(PhantomData<F>);

impl<F: Field> ByteChip<F> {
/// Creates the preprocessed byte trace and event map.
/// Creates the preprocessed byte trace.
///
/// This function returns a pair `(trace, map)`, where:
/// - `trace` is a matrix containing all possible byte operations.
/// - `map` is a map from a byte lookup to the corresponding row it appears in the table and
/// the index of the result in the array of multiplicities.
pub fn trace_and_map(
shard: u32,
) -> (RowMajorMatrix<F>, BTreeMap<ByteLookupEvent, (usize, usize)>) {
// A map from a byte lookup to its corresponding row in the table and index in the array of
// multiplicities.
let mut event_map = BTreeMap::new();

/// This function returns a `trace` which is a matrix containing all possible byte operations.
pub fn trace() -> RowMajorMatrix<F> {
// The trace containing all values, with all multiplicities set to zero.
let mut initial_trace = RowMajorMatrix::new(
vec![F::zero(); NUM_ROWS * NUM_BYTE_PREPROCESSED_COLS],
Expand All @@ -65,10 +55,11 @@ impl<F: Field> ByteChip<F> {
col.b = F::from_canonical_u8(b);
col.c = F::from_canonical_u8(c);

let shard = 0;
// Iterate over all operations for results and updating the table map.
for channel in 0..NUM_BYTE_LOOKUP_CHANNELS {
for (i, opcode) in opcodes.iter().enumerate() {
let event = match opcode {
for opcode in opcodes.iter() {
match opcode {
ByteOpcode::AND => {
let and = b & c;
col.and = F::from_canonical_u8(and);
Expand Down Expand Up @@ -176,11 +167,25 @@ impl<F: Field> ByteChip<F> {
ByteLookupEvent::new(shard, channel, *opcode, v, 0, 0, 0)
}
};
event_map.insert(event, (row_index, i));
}
}
}

(initial_trace, event_map)
initial_trace
}
}

#[cfg(test)]
mod tests {
use p3_baby_bear::BabyBear;
use std::time::Instant;

use super::*;

#[test]
fn test_trace_and_map() {
let start = Instant::now();
ByteChip::<BabyBear>::trace();
println!("trace and map: {:?}", start.elapsed());
}
Comment on lines +186 to 190
Copy link
Copy Markdown
Contributor Author

@wwared wwared Jun 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does not seem particularly useful to me; I left it in since it was added upstream, but I'd be okay with removing it.

}
27 changes: 13 additions & 14 deletions core/src/bytes/trace.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{borrow::BorrowMut, collections::BTreeMap};
use std::borrow::BorrowMut;

use hashbrown::HashMap;
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;

Expand All @@ -9,14 +10,15 @@ use super::{
};
use crate::{
air::{EventLens, MachineAir, WithEvents},
bytes::ByteOpcode,
runtime::{ExecutionRecord, Program},
};

pub const NUM_ROWS: usize = 1 << 16;

impl<'a, F: Field> WithEvents<'a> for ByteChip<F> {
// the byte lookups
type Events = &'a BTreeMap<u32, BTreeMap<ByteLookupEvent, usize>>;
type Events = &'a HashMap<u32, HashMap<ByteLookupEvent, usize>>;
}

impl<F: Field> MachineAir<F> for ByteChip<F> {
Expand All @@ -33,10 +35,7 @@ impl<F: Field> MachineAir<F> for ByteChip<F> {
}

fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
// OPT: We should be able to make this a constant. Also, trace / map should be separate.
// Since we only need the trace and not the map, we can just pass 0 as the shard.
let (trace, _) = Self::trace_and_map(0);

let trace = Self::trace();
Some(trace)
}

Expand All @@ -53,23 +52,23 @@ impl<F: Field> MachineAir<F> for ByteChip<F> {
input: &EL,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let shard = input.index();
let (_, event_map) = Self::trace_and_map(shard);

let mut trace = RowMajorMatrix::new(
vec![F::zero(); NUM_BYTE_MULT_COLS * NUM_ROWS],
NUM_BYTE_MULT_COLS,
);

let shard = input.index();
for (lookup, mult) in input.events()[&shard].iter() {
let (row, index) = event_map[lookup];
let row = if lookup.opcode != ByteOpcode::U16Range {
((lookup.b << 8) + lookup.c) as usize
} else {
lookup.a1 as usize
};
let index = lookup.opcode as usize;
let channel = lookup.channel as usize;
let cols: &mut ByteMultCols<F> = trace.row_mut(row).borrow_mut();

// Update the trace multiplicity
let cols: &mut ByteMultCols<F> = trace.row_mut(row).borrow_mut();
cols.mult_channels[channel].multiplicities[index] += F::from_canonical_usize(*mult);

// Set the shard column as the current shard.
cols.shard = F::from_canonical_u32(shard);
}

Expand Down
1 change: 1 addition & 0 deletions core/src/cpu/columns/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub struct ChannelSelectorCols<T> {
}

impl<F: Field> ChannelSelectorCols<F> {
#[inline(always)]
pub fn populate(&mut self, channel: u32) {
self.channel_selectors = [F::zero(); NUM_BYTE_LOOKUP_CHANNELS as usize];
self.channel_selectors[channel as usize] = F::one();
Expand Down
Loading