diff --git a/circuits/cpp/src/aztec3/circuits/kernel/private/.test.cpp b/circuits/cpp/src/aztec3/circuits/kernel/private/.test.cpp index ff441857bedc..534dde73a4bc 100644 --- a/circuits/cpp/src/aztec3/circuits/kernel/private/.test.cpp +++ b/circuits/cpp/src/aztec3/circuits/kernel/private/.test.cpp @@ -651,12 +651,10 @@ TEST(private_kernel_tests, circuit_create_proof_cbinds) uint8_t const* public_inputs_buf = nullptr; size_t public_inputs_size = 0; // info("Simulating to generate public inputs..."); - uint8_t* const circuit_failure_ptr = private_kernel__sim(signed_constructor_tx_request_vec.data(), - nullptr, // no previous kernel on first iteration - private_constructor_call_vec.data(), - true, // first iteration - &public_inputs_size, - &public_inputs_buf); + uint8_t* const circuit_failure_ptr = private_kernel__sim_init(signed_constructor_tx_request_vec.data(), + private_constructor_call_vec.data(), + &public_inputs_size, + &public_inputs_buf); ASSERT_TRUE(circuit_failure_ptr == nullptr); // TODO better equality check diff --git a/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.cpp b/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.cpp index 4d3876417d77..e6578c45be15 100644 --- a/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.cpp +++ b/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.cpp @@ -65,43 +65,55 @@ CBIND(private_kernel__dummy_previous_kernel, []() { return dummy_previous_kernel // TODO(dbanks12): comment about how public_inputs is a confusing name // returns size of public inputs -WASM_EXPORT uint8_t* private_kernel__sim(uint8_t const* signed_tx_request_buf, - uint8_t const* previous_kernel_buf, - uint8_t const* private_call_buf, - bool first_iteration, - size_t* private_kernel_public_inputs_size_out, - uint8_t const** private_kernel_public_inputs_buf) +WASM_EXPORT uint8_t* private_kernel__sim_init(uint8_t const* signed_tx_request_buf, + uint8_t const* private_call_buf, + size_t* private_kernel_public_inputs_size_out, + uint8_t const** private_kernel_public_inputs_buf) { - DummyComposer composer = DummyComposer("private_kernel__sim"); + DummyComposer composer = DummyComposer("private_kernel__sim_init"); + PrivateCallData private_call_data; read(private_call_buf, private_call_data); - KernelCircuitPublicInputs public_inputs = KernelCircuitPublicInputs{}; + SignedTxRequest signed_tx_request; + read(signed_tx_request_buf, signed_tx_request); - if (first_iteration) { - SignedTxRequest signed_tx_request; - read(signed_tx_request_buf, signed_tx_request); + PrivateKernelInputsInit const private_inputs = PrivateKernelInputsInit{ + .signed_tx_request = signed_tx_request, + .private_call = private_call_data, + }; - // Assert that previous_kernel_buf is empty (i.e. nullptr) - ASSERT(previous_kernel_buf == nullptr); + auto public_inputs = native_private_kernel_circuit_initial(composer, private_inputs); - PrivateKernelInputsInit const private_inputs = PrivateKernelInputsInit{ - .signed_tx_request = signed_tx_request, - .private_call = private_call_data, - }; + // serialize public inputs to bytes vec + std::vector public_inputs_vec; + write(public_inputs_vec, public_inputs); + // copy public inputs to output buffer + auto* raw_public_inputs_buf = (uint8_t*)malloc(public_inputs_vec.size()); + memcpy(raw_public_inputs_buf, (void*)public_inputs_vec.data(), public_inputs_vec.size()); + *private_kernel_public_inputs_buf = raw_public_inputs_buf; + *private_kernel_public_inputs_size_out = public_inputs_vec.size(); + return composer.alloc_and_serialize_first_failure(); +} - public_inputs = native_private_kernel_circuit_initial(composer, private_inputs); - } else { - PreviousKernelData previous_kernel; - read(previous_kernel_buf, previous_kernel); +WASM_EXPORT uint8_t* private_kernel__sim_inner(uint8_t const* previous_kernel_buf, + uint8_t const* private_call_buf, + size_t* private_kernel_public_inputs_size_out, + uint8_t const** private_kernel_public_inputs_buf) +{ + DummyComposer composer = DummyComposer("private_kernel__sim_inner"); + PrivateCallData private_call_data; + read(private_call_buf, private_call_data); - PrivateKernelInputsInner const private_inputs = PrivateKernelInputsInner{ - .previous_kernel = previous_kernel, - .private_call = private_call_data, - }; + PreviousKernelData previous_kernel; + read(previous_kernel_buf, previous_kernel); - public_inputs = native_private_kernel_circuit_inner(composer, private_inputs); - } + PrivateKernelInputsInner const private_inputs = PrivateKernelInputsInner{ + .previous_kernel = previous_kernel, + .private_call = private_call_data, + }; + + auto public_inputs = native_private_kernel_circuit_inner(composer, private_inputs); // serialize public inputs to bytes vec std::vector public_inputs_vec; diff --git a/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.h b/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.h index 9dc5a7f5841a..c0a095727f27 100644 --- a/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.h +++ b/circuits/cpp/src/aztec3/circuits/kernel/private/c_bind.h @@ -7,12 +7,14 @@ WASM_EXPORT size_t private_kernel__init_proving_key(uint8_t const** pk_buf); WASM_EXPORT size_t private_kernel__init_verification_key(uint8_t const* pk_buf, uint8_t const** vk_buf); CBIND_DECL(private_kernel__dummy_previous_kernel); -WASM_EXPORT uint8_t* private_kernel__sim(uint8_t const* signed_tx_request_buf, - uint8_t const* previous_kernel_buf, - uint8_t const* private_call_buf, - bool first_iteration, - size_t* private_kernel_public_inputs_size_out, - uint8_t const** private_kernel_public_inputs_buf); +WASM_EXPORT uint8_t* private_kernel__sim_init(uint8_t const* signed_tx_request_buf, + uint8_t const* private_call_buf, + size_t* private_kernel_public_inputs_size_out, + uint8_t const** private_kernel_public_inputs_buf); +WASM_EXPORT uint8_t* private_kernel__sim_inner(uint8_t const* previous_kernel_buf, + uint8_t const* private_call_buf, + size_t* private_kernel_public_inputs_size_out, + uint8_t const** private_kernel_public_inputs_buf); WASM_EXPORT size_t private_kernel__prove(uint8_t const* signed_tx_request_buf, uint8_t const* previous_kernel_buf, uint8_t const* private_call_buf, diff --git a/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.test.ts b/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.test.ts index 67dfab99c80e..aa5436de25f7 100644 --- a/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.test.ts +++ b/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.test.ts @@ -69,9 +69,16 @@ describe('Kernel Prover', () => { }; const expectExecution = (fns: string[]) => { - const callStackItems = proofCreator.createProof.mock.calls.map(args => args[2].callStackItem.functionData); - expect(callStackItems).toEqual(fns); - proofCreator.createProof.mockClear(); + const callStackItemsInit = proofCreator.createProofInit.mock.calls.map(args => args[1].callStackItem.functionData); + const callStackItemsInner = proofCreator.createProofInner.mock.calls.map( + args => args[1].callStackItem.functionData, + ); + + expect(proofCreator.createProofInit).toHaveBeenCalledTimes(Math.min(1, fns.length)); + expect(proofCreator.createProofInner).toHaveBeenCalledTimes(Math.max(0, fns.length - 1)); + expect(callStackItemsInit.concat(callStackItemsInner)).toEqual(fns); + proofCreator.createProofInner.mockClear(); + proofCreator.createProofInit.mockClear(); }; const expectOutputNotes = (outputNotes: OutputNoteData[], expectedNoteIndices: number[]) => { @@ -94,7 +101,8 @@ describe('Kernel Prover', () => { proofCreator.getSiloedCommitments.mockImplementation(publicInputs => Promise.resolve(publicInputs.newCommitments.map(createFakeSiloedCommitment)), ); - proofCreator.createProof.mockResolvedValue(createProofOutput([])); + proofCreator.createProofInit.mockResolvedValue(createProofOutput([])); + proofCreator.createProofInner.mockResolvedValue(createProofOutput([])); prover = new KernelProver(oracle, proofCreator); }); @@ -141,9 +149,9 @@ describe('Kernel Prover', () => { const resultA = createExecutionResult('a', [1, 2, 3]); const resultB = createExecutionResult('b', [4]); const resultC = createExecutionResult('c', [5, 6]); - proofCreator.createProof.mockResolvedValueOnce(createProofOutput([1, 2, 3])); - proofCreator.createProof.mockResolvedValueOnce(createProofOutput([1, 3, 4])); - proofCreator.createProof.mockResolvedValueOnce(createProofOutput([1, 3, 5, 6])); + proofCreator.createProofInit.mockResolvedValueOnce(createProofOutput([1, 2, 3])); + proofCreator.createProofInner.mockResolvedValueOnce(createProofOutput([1, 3, 4])); + proofCreator.createProofInner.mockResolvedValueOnce(createProofOutput([1, 3, 5, 6])); const executionResult = { ...resultA, diff --git a/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.ts b/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.ts index 9cfd08d440cc..df4f1965732d 100644 --- a/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.ts +++ b/yarn-project/aztec-rpc/src/kernel_prover/kernel_prover.ts @@ -87,17 +87,6 @@ export class KernelProver { proof: makeEmptyProof(), }; while (executionStack.length) { - const previousVkMembershipWitness = firstIteration - ? MembershipWitness.random(VK_TREE_HEIGHT) - : await this.oracle.getVkMembershipWitness(previousVerificationKey); - const previousKernelData = new PreviousKernelData( - output.publicInputs, - output.proof, - previousVerificationKey, - Number(previousVkMembershipWitness.leafIndex), - assertLength(previousVkMembershipWitness.siblingPath, VK_TREE_HEIGHT), - ); - const currentExecution = executionStack.pop()!; executionStack.push(...currentExecution.nestedExecutions); const privateCallStackPreimages = currentExecution.nestedExecutions.map(result => result.callStackItem); @@ -114,12 +103,19 @@ export class KernelProver { const privateCallData = await this.createPrivateCallData(currentExecution, privateCallStackPreimages); - output = await this.proofCreator.createProof( - signedTxRequest, - previousKernelData, - privateCallData, - firstIteration, - ); + if (firstIteration) { + output = await this.proofCreator.createProofInit(signedTxRequest, privateCallData); + } else { + const previousVkMembershipWitness = await this.oracle.getVkMembershipWitness(previousVerificationKey); + const previousKernelData = new PreviousKernelData( + output.publicInputs, + output.proof, + previousVerificationKey, + Number(previousVkMembershipWitness.leafIndex), + assertLength(previousVkMembershipWitness.siblingPath, VK_TREE_HEIGHT), + ); + output = await this.proofCreator.createProofInner(previousKernelData, privateCallData); + } (await this.getNewNotes(currentExecution)).forEach(n => { newNotes[n.commitment.toString()] = n; }); diff --git a/yarn-project/aztec-rpc/src/kernel_prover/proof_creator.ts b/yarn-project/aztec-rpc/src/kernel_prover/proof_creator.ts index 48c3bb729efd..d189fe46e1e4 100644 --- a/yarn-project/aztec-rpc/src/kernel_prover/proof_creator.ts +++ b/yarn-project/aztec-rpc/src/kernel_prover/proof_creator.ts @@ -9,7 +9,8 @@ import { Proof, SignedTxRequest, makeEmptyProof, - privateKernelSim, + privateKernelSimInner, + privateKernelSimInit, } from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; @@ -35,12 +36,8 @@ export interface ProofOutput { */ export interface ProofCreator { getSiloedCommitments(publicInputs: PrivateCircuitPublicInputs): Promise; - createProof( - signedTxRequest: SignedTxRequest, - previousKernelData: PreviousKernelData, - privateCallData: PrivateCallData, - firstIteration: boolean, - ): Promise; + createProofInit(signedTxRequest: SignedTxRequest, privateCallData: PrivateCallData): Promise; + createProofInner(previousKernelData: PreviousKernelData, privateCallData: PrivateCallData): Promise; } const OUTER_COMMITMENT = 3; @@ -73,29 +70,44 @@ export class KernelProofCreator { } /** - * Creates a proof output for a given signed transaction request, previous kernel data, private call data, and first iteration flag. + * Creates a proof output for a given signed transaction request and private call data for the first iteration. * * @param signedTxRequest - The signed transaction request object. - * @param previousKernelData - The previous kernel data object. * @param privateCallData - The private call data object. - * @param firstIteration - A boolean flag indicating if it's the first iteration of the kernel proof creation process. * @returns A Promise resolving to a ProofOutput object containing public inputs and the kernel proof. */ - public async createProof( + public async createProofInit( signedTxRequest: SignedTxRequest, + privateCallData: PrivateCallData, + ): Promise { + const wasm = await CircuitsWasm.get(); + this.log('Executing private kernel simulation init...'); + const publicInputs = await privateKernelSimInit(wasm, signedTxRequest, privateCallData); + this.log('Skipping private kernel proving...'); + // TODO + const proof = makeEmptyProof(); + this.log('Kernel Prover Completed!'); + + return { + publicInputs, + proof, + }; + } + + /** + * Creates a proof output for a given previous kernel data and private call data for an inner iteration. + * + * @param previousKernelData - The previous kernel data object. + * @param privateCallData - The private call data object. + * @returns A Promise resolving to a ProofOutput object containing public inputs and the kernel proof. + */ + public async createProofInner( previousKernelData: PreviousKernelData, privateCallData: PrivateCallData, - firstIteration: boolean, ): Promise { const wasm = await CircuitsWasm.get(); - this.log('Executing private kernel simulation...'); - const publicInputs = await privateKernelSim( - wasm, - signedTxRequest, - previousKernelData, - privateCallData, - firstIteration, - ); + this.log('Executing private kernel simulation inner...'); + const publicInputs = await privateKernelSimInner(wasm, previousKernelData, privateCallData); this.log('Skipping private kernel proving...'); // TODO const proof = makeEmptyProof(); diff --git a/yarn-project/circuits.js/src/kernel/private_kernel.ts b/yarn-project/circuits.js/src/kernel/private_kernel.ts index 2f4eec6826e6..a3b2ee28d432 100644 --- a/yarn-project/circuits.js/src/kernel/private_kernel.ts +++ b/yarn-project/circuits.js/src/kernel/private_kernel.ts @@ -92,39 +92,75 @@ export async function privateKernelProve( } /** - * Computes the public inputs of the private kernel without computing the proof. + * Computes the public inputs of the private kernel first iteration without computing the proof. * @param wasm - The circuits wasm instance. * @param signedTxRequest - The signed transaction request. - * @param previousKernel - The previous kernel data (dummy if this is the first kernel in the chain). * @param privateCallData - The private call data. - * @param firstIteration - Whether this is the first iteration of the private kernel. * @returns The public inputs of the private kernel. */ -export async function privateKernelSim( +export async function privateKernelSimInit( wasm: CircuitsWasm, signedTxRequest: SignedTxRequest, - previousKernel: PreviousKernelData, privateCallData: PrivateCallData, - firstIteration: boolean, ): Promise { wasm.call('pedersen__init'); const signedTxRequestBuffer = signedTxRequest.toBuffer(); - const previousKernelBuffer = previousKernel.toBuffer(); const privateCallDataBuffer = privateCallData.toBuffer(); - const previousKernelBufferOffset = signedTxRequestBuffer.length; - const privateCallDataOffset = previousKernelBufferOffset + previousKernelBuffer.length; + const privateCallDataOffset = signedTxRequestBuffer.length; wasm.writeMemory(0, signedTxRequestBuffer); - wasm.writeMemory(previousKernelBufferOffset, previousKernelBuffer); wasm.writeMemory(privateCallDataOffset, privateCallDataBuffer); const outputBufSizePtr = wasm.call('bbmalloc', 4); const outputBufPtrPtr = wasm.call('bbmalloc', 4); // Run and read outputs const circuitFailureBufPtr = await wasm.asyncCall( - 'private_kernel__sim', + 'private_kernel__sim_init', + 0, + privateCallDataOffset, + outputBufSizePtr, + outputBufPtrPtr, + ); + try { + // Try deserializing the output to `KernelCircuitPublicInputs` and throw if it fails + return handleCircuitOutput( + wasm, + outputBufSizePtr, + outputBufPtrPtr, + circuitFailureBufPtr, + KernelCircuitPublicInputs, + ); + } finally { + // Free memory + wasm.call('bbfree', outputBufSizePtr); + wasm.call('bbfree', outputBufPtrPtr); + wasm.call('bbfree', circuitFailureBufPtr); + } +} + +/** + * Computes the public inputs of a private kernel inner iteration without computing the proof. + * @param wasm - The circuits wasm instance. + * @param previousKernel - The previous kernel data (dummy if this is the first kernel in the chain). + * @param privateCallData - The private call data. + * @returns The public inputs of the private kernel. + */ +export async function privateKernelSimInner( + wasm: CircuitsWasm, + previousKernel: PreviousKernelData, + privateCallData: PrivateCallData, +): Promise { + wasm.call('pedersen__init'); + const previousKernelBuffer = previousKernel.toBuffer(); + const privateCallDataBuffer = privateCallData.toBuffer(); + const privateCallDataOffset = previousKernelBuffer.length; + wasm.writeMemory(0, previousKernelBuffer); + wasm.writeMemory(privateCallDataOffset, privateCallDataBuffer); + const outputBufSizePtr = wasm.call('bbmalloc', 4); + const outputBufPtrPtr = wasm.call('bbmalloc', 4); + // Run and read outputs + const circuitFailureBufPtr = await wasm.asyncCall( + 'private_kernel__sim_inner', 0, - previousKernelBufferOffset, privateCallDataOffset, - firstIteration, outputBufSizePtr, outputBufPtrPtr, );