diff --git a/yarn-project/foundation/src/collection/array.test.ts b/yarn-project/foundation/src/collection/array.test.ts index f71c739354d4..9348a90d2191 100644 --- a/yarn-project/foundation/src/collection/array.test.ts +++ b/yarn-project/foundation/src/collection/array.test.ts @@ -8,6 +8,7 @@ import { mean, median, partition, + partitionAsync, removeArrayPaddingEnd, stdDev, times, @@ -380,3 +381,40 @@ describe('partition', () => { expect(odd).toEqual([{ a: 1 }, { a: 3 }]); }); }); + +describe('partitionAsync', () => { + it('partitions an array into pass and fail arrays based on the predicate', async () => { + const input = [1, 2, 3, 4, 5]; + const [even, odd] = await partitionAsync(input, x => Promise.resolve(x % 2 === 0)); + expect(even).toEqual([2, 4]); + expect(odd).toEqual([1, 3, 5]); + }); + + it('returns all items in the first array if all pass the predicate', async () => { + const input = [2, 4, 6]; + const [pass, fail] = await partitionAsync(input, x => Promise.resolve(x % 2 === 0)); + expect(pass).toEqual([2, 4, 6]); + expect(fail).toEqual([]); + }); + + it('returns all items in the second array if none pass the predicate', async () => { + const input = [1, 3, 5]; + const [pass, fail] = await partitionAsync(input, x => Promise.resolve(x % 2 === 0)); + expect(pass).toEqual([]); + expect(fail).toEqual([1, 3, 5]); + }); + + it('handles an empty array', async () => { + const input: number[] = []; + const [pass, fail] = await partitionAsync(input, x => Promise.resolve(x > 0)); + expect(pass).toEqual([]); + expect(fail).toEqual([]); + }); + + it('works with objects and custom predicates', async () => { + const input = [{ a: 1 }, { a: 2 }, { a: 3 }]; + const [even, odd] = await partitionAsync(input, obj => Promise.resolve(obj.a % 2 === 0)); + expect(even).toEqual([{ a: 2 }]); + expect(odd).toEqual([{ a: 1 }, { a: 3 }]); + }); +}); diff --git a/yarn-project/foundation/src/collection/array.ts b/yarn-project/foundation/src/collection/array.ts index 69dde35053a5..e2636c27d1fc 100644 --- a/yarn-project/foundation/src/collection/array.ts +++ b/yarn-project/foundation/src/collection/array.ts @@ -315,3 +315,17 @@ export function partition(items: T[], predicate: (item: T) => boolean): [T[], } return [pass, fail]; } + +/** Partitions the given iterable into two arrays based on the predicate. */ +export async function partitionAsync(items: T[], predicate: (item: T) => Promise): Promise<[T[], T[]]> { + const pass: T[] = []; + const fail: T[] = []; + for (const item of items) { + if (await predicate(item)) { + pass.push(item); + } else { + fail.push(item); + } + } + return [pass, fail]; +} diff --git a/yarn-project/p2p/src/client/factory.ts b/yarn-project/p2p/src/client/factory.ts index eb277343c5a6..061181f3ab7e 100644 --- a/yarn-project/p2p/src/client/factory.ts +++ b/yarn-project/p2p/src/client/factory.ts @@ -17,7 +17,10 @@ import { AttestationPool, type AttestationPoolApi } from '../mem_pools/attestati import type { MemPools } from '../mem_pools/interface.js'; import type { TxPoolV2 } from '../mem_pools/tx_pool_v2/interfaces.js'; import { AztecKVTxPoolV2 } from '../mem_pools/tx_pool_v2/tx_pool_v2.js'; -import { createTxValidatorForTransactionsEnteringPendingTxPool } from '../msg_validators/index.js'; +import { + createTxValidatorForReqResponseReceivedTxs, + createTxValidatorForTransactionsEnteringPendingTxPool, +} from '../msg_validators/index.js'; import { DummyP2PService } from '../services/dummy_service.js'; import { LibP2PService } from '../services/index.js'; import { createFileStoreTxSources } from '../services/tx_collection/file_store_tx_source.js'; @@ -124,9 +127,12 @@ export async function createP2PClient( telemetry, ); + const txValidatorForTxCollection = createTxValidatorForReqResponseReceivedTxs(proofVerifier, config); const nodeSources = [ - ...createNodeRpcTxSources(config.txCollectionNodeRpcUrls, config), - ...(deps.rpcTxProviders ?? []).map((node, i) => new NodeRpcTxSource(node, `node-rpc-provider-${i}`)), + ...createNodeRpcTxSources(config.txCollectionNodeRpcUrls, txValidatorForTxCollection, config), + ...(deps.rpcTxProviders ?? []).map( + (node, i) => new NodeRpcTxSource(node, txValidatorForTxCollection, `node-rpc-provider-${i}`), + ), ...(deps.txCollectionNodeSources ?? []), ]; if (nodeSources.length > 0) { @@ -138,6 +144,7 @@ export async function createP2PClient( const fileStoreSources = await createFileStoreTxSources( config.txCollectionFileStoreUrls, txFileStoreBasePath, + txValidatorForTxCollection, logger.createChild('file-store-tx-source'), telemetry, ); diff --git a/yarn-project/p2p/src/services/tx_collection/file_store_tx_source.ts b/yarn-project/p2p/src/services/tx_collection/file_store_tx_source.ts index d682e303a0cf..b47927e8cb44 100644 --- a/yarn-project/p2p/src/services/tx_collection/file_store_tx_source.ts +++ b/yarn-project/p2p/src/services/tx_collection/file_store_tx_source.ts @@ -1,7 +1,8 @@ +import { partitionAsync } from '@aztec/foundation/collection'; import { type Logger, createLogger } from '@aztec/foundation/log'; import { Timer } from '@aztec/foundation/timer'; import { type ReadOnlyFileStore, createReadOnlyFileStore } from '@aztec/stdlib/file-store'; -import { Tx, type TxHash } from '@aztec/stdlib/tx'; +import { Tx, type TxHash, type TxValidator } from '@aztec/stdlib/tx'; import { type Histogram, Metrics, @@ -23,6 +24,7 @@ export class FileStoreTxSource implements TxSource { private readonly fileStore: ReadOnlyFileStore, private readonly baseUrl: string, private readonly basePath: string, + private readonly txValidator: TxValidator, private readonly log: Logger, telemetry: TelemetryClient, ) { @@ -44,6 +46,7 @@ export class FileStoreTxSource implements TxSource { public static async create( url: string, basePath: string, + txValidator: TxValidator, log: Logger = createLogger('p2p:file_store_tx_source'), telemetry: TelemetryClient = getTelemetryClient(), ): Promise { @@ -53,7 +56,7 @@ export class FileStoreTxSource implements TxSource { log.warn(`Failed to create file store for URL: ${url}`); return undefined; } - return new FileStoreTxSource(fileStore, url, basePath, log, telemetry); + return new FileStoreTxSource(fileStore, url, basePath, txValidator, log, telemetry); } catch (err) { log.warn(`Error creating file store for URL: ${url}`, { error: err }); return undefined; @@ -65,35 +68,41 @@ export class FileStoreTxSource implements TxSource { } public async getTxsByHash(txHashes: TxHash[]): Promise { - const invalidTxHashes: string[] = []; + const results = await Promise.all( + txHashes.map(async txHash => { + const path = `${this.basePath}/txs/${txHash.toString()}.bin`; + const timer = new Timer(); + try { + const buffer = await this.fileStore.read(path); + const tx = Tx.fromBuffer(buffer); + return { tx, downloadDuration: timer.ms(), downloadSize: buffer.length }; + } catch { + this.downloadsFailed.add(1); + return undefined; + } + }), + ); + + const txs = results.filter(tx => tx !== undefined); + const [validTxs, invalidTxs] = await partitionAsync( + txs, + async ({ tx, downloadDuration, downloadSize }): Promise => { + const valid = await this.txValidator.validateTx(tx); + if (valid.result === 'valid') { + this.downloadsSuccess.add(1); + this.downloadDuration.record(Math.ceil(downloadDuration)); + this.downloadSize.record(downloadSize); + return true; + } else { + this.downloadsFailed.add(1); + return false; + } + }, + ); + return { - validTxs: ( - await Promise.all( - txHashes.map(async txHash => { - const path = `${this.basePath}/txs/${txHash.toString()}.bin`; - const timer = new Timer(); - try { - const buffer = await this.fileStore.read(path); - const tx = Tx.fromBuffer(buffer); - if ((await tx.validateTxHash()) && txHash.equals(tx.txHash)) { - this.downloadsSuccess.add(1); - this.downloadDuration.record(Math.ceil(timer.ms())); - this.downloadSize.record(buffer.length); - return tx; - } else { - invalidTxHashes.push(tx.txHash.toString()); - this.downloadsFailed.add(1); - return undefined; - } - } catch { - // Tx not found or error reading - return undefined - this.downloadsFailed.add(1); - return undefined; - } - }), - ) - ).filter(tx => tx !== undefined), - invalidTxHashes: invalidTxHashes, + validTxs: validTxs.map(({ tx }) => tx), + invalidTxHashes: invalidTxs.map(({ tx }) => tx.getTxHash().toString()), }; } } @@ -109,9 +118,12 @@ export class FileStoreTxSource implements TxSource { export async function createFileStoreTxSources( urls: string[], basePath: string, + txValidator: TxValidator, log: Logger = createLogger('p2p:file_store_tx_source'), telemetry: TelemetryClient = getTelemetryClient(), ): Promise { - const sources = await Promise.all(urls.map(url => FileStoreTxSource.create(url, basePath, log, telemetry))); + const sources = await Promise.all( + urls.map(url => FileStoreTxSource.create(url, basePath, txValidator, log, telemetry)), + ); return sources.filter((s): s is FileStoreTxSource => s !== undefined); } diff --git a/yarn-project/p2p/src/services/tx_collection/tx_source.test.ts b/yarn-project/p2p/src/services/tx_collection/tx_source.test.ts new file mode 100644 index 000000000000..4767b9fc2481 --- /dev/null +++ b/yarn-project/p2p/src/services/tx_collection/tx_source.test.ts @@ -0,0 +1,62 @@ +import type { AztecNode } from '@aztec/stdlib/interfaces/client'; +import { Tx, type TxValidator } from '@aztec/stdlib/tx'; + +import { type MockProxy, mock } from 'jest-mock-extended'; + +import { NodeRpcTxSource } from './tx_source.js'; + +describe('NodeRpcTxSource', () => { + let mockClient: MockProxy>; + let mockValidator: MockProxy; + + const makeTx = async () => { + const tx = Tx.random(); + await tx.recomputeHash(); + return tx; + }; + + beforeEach(() => { + mockClient = mock>(); + mockValidator = mock(); + mockValidator.validateTx.mockResolvedValue({ result: 'valid' }); + }); + + const createSource = () => new NodeRpcTxSource(mockClient, mockValidator, 'test'); + + it('returns valid txs when validator accepts', async () => { + const tx1 = await makeTx(); + const tx2 = await makeTx(); + mockClient.getTxsByHash.mockResolvedValue([tx1, tx2]); + + const result = await createSource().getTxsByHash([tx1.getTxHash(), tx2.getTxHash()]); + + expect(result.validTxs).toHaveLength(2); + expect(result.invalidTxHashes).toHaveLength(0); + }); + + it('returns invalid tx hashes when validator rejects', async () => { + const tx1 = await makeTx(); + const tx2 = await makeTx(); + mockClient.getTxsByHash.mockResolvedValue([tx1, tx2]); + mockValidator.validateTx.mockResolvedValue({ result: 'invalid', reason: ['bad'] }); + + const result = await createSource().getTxsByHash([tx1.getTxHash(), tx2.getTxHash()]); + + expect(result.validTxs).toHaveLength(0); + expect(result.invalidTxHashes).toEqual([tx1.getTxHash().toString(), tx2.getTxHash().toString()]); + }); + + it('partitions txs based on validator result', async () => { + const tx1 = await makeTx(); + const tx2 = await makeTx(); + mockClient.getTxsByHash.mockResolvedValue([tx1, tx2]); + mockValidator.validateTx + .mockResolvedValueOnce({ result: 'valid' }) + .mockResolvedValueOnce({ result: 'invalid', reason: ['bad'] }); + + const result = await createSource().getTxsByHash([tx1.getTxHash(), tx2.getTxHash()]); + + expect(result.validTxs).toEqual([tx1]); + expect(result.invalidTxHashes).toEqual([tx2.getTxHash().toString()]); + }); +}); diff --git a/yarn-project/p2p/src/services/tx_collection/tx_source.ts b/yarn-project/p2p/src/services/tx_collection/tx_source.ts index 68e2ad8c1c43..20ef0dc923e4 100644 --- a/yarn-project/p2p/src/services/tx_collection/tx_source.ts +++ b/yarn-project/p2p/src/services/tx_collection/tx_source.ts @@ -2,7 +2,7 @@ import { getVKTreeRoot } from '@aztec/noir-protocol-circuits-types/vk-tree'; import { protocolContractsHash } from '@aztec/protocol-contracts'; import type { ChainConfig } from '@aztec/stdlib/config'; import { type AztecNode, createAztecNodeClient } from '@aztec/stdlib/interfaces/client'; -import type { Tx, TxHash } from '@aztec/stdlib/tx'; +import type { Tx, TxHash, TxValidator } from '@aztec/stdlib/tx'; import { type ComponentsVersions, getComponentsVersionsFromConfig } from '@aztec/stdlib/versioning'; import { makeTracedFetch } from '@aztec/telemetry-client'; @@ -16,12 +16,13 @@ export interface TxSource { export class NodeRpcTxSource implements TxSource { constructor( private readonly client: Pick, + private readonly txValidator: TxValidator, private readonly info: string, ) {} - public static fromUrl(nodeUrl: string, versions: ComponentsVersions): NodeRpcTxSource { + public static fromUrl(nodeUrl: string, txValidator: TxValidator, versions: ComponentsVersions): NodeRpcTxSource { const client = createAztecNodeClient(nodeUrl, versions, makeTracedFetch([1, 2, 3], false)); - return new NodeRpcTxSource(client, nodeUrl); + return new NodeRpcTxSource(client, txValidator, nodeUrl); } public getInfo() { @@ -38,8 +39,8 @@ export class NodeRpcTxSource implements TxSource { const invalidTxHashes: string[] = []; await Promise.all( txs.map(async tx => { - const isValid = await tx.validateTxHash(); - if (isValid) { + const validation = await this.txValidator.validateTx(tx); + if (validation.result === 'valid') { validTxs.push(tx); } else { invalidTxHashes.push(tx.getTxHash().toString()); @@ -50,7 +51,7 @@ export class NodeRpcTxSource implements TxSource { } } -export function createNodeRpcTxSources(urls: string[], chainConfig: ChainConfig) { +export function createNodeRpcTxSources(urls: string[], txValidator: TxValidator, chainConfig: ChainConfig) { const versions = getComponentsVersionsFromConfig(chainConfig, protocolContractsHash, getVKTreeRoot()); - return urls.map(url => NodeRpcTxSource.fromUrl(url, versions)); + return urls.map(url => NodeRpcTxSource.fromUrl(url, txValidator, versions)); } diff --git a/yarn-project/p2p/src/services/tx_file_store/tx_file_store.test.ts b/yarn-project/p2p/src/services/tx_file_store/tx_file_store.test.ts index 893a0b54aa9d..2fbefb183222 100644 --- a/yarn-project/p2p/src/services/tx_file_store/tx_file_store.test.ts +++ b/yarn-project/p2p/src/services/tx_file_store/tx_file_store.test.ts @@ -1,10 +1,11 @@ import { createLogger } from '@aztec/foundation/log'; import { sleep } from '@aztec/foundation/sleep'; import { type FileStore, createFileStore } from '@aztec/stdlib/file-store'; -import { Tx } from '@aztec/stdlib/tx'; +import { Tx, type TxValidator } from '@aztec/stdlib/tx'; import { jest } from '@jest/globals'; import { mkdtemp, readdir, rm } from 'fs/promises'; +import { type MockProxy, mock } from 'jest-mock-extended'; import { tmpdir } from 'os'; import { join } from 'path'; @@ -19,6 +20,7 @@ describe('TxFileStore', () => { let txPool: InMemoryTxPool; let config: TxFileStoreConfig; let txFileStore: TxFileStore | undefined; + let mockValidator: MockProxy; const log = createLogger('test:tx_file_store'); const basePath = 'aztec-1-1-0x1234'; @@ -52,6 +54,8 @@ describe('TxFileStore', () => { fileStore = await createFileStore(`file://${tmpDir}`); txPool = new InMemoryTxPool(); + mockValidator = mock(); + mockValidator.validateTx.mockResolvedValue({ result: 'valid' }); config = { txFileStoreEnabled: true, @@ -310,50 +314,44 @@ describe('TxFileStore', () => { }); describe('tx download validation', () => { - it('rejects tx with invalid hash when reading from file store', async () => { - // Write a tx with a mismatched hash directly to the file store - const invalidTx = Tx.random(); // random hash does not match computed hash - await fileStore.save(`${basePath}/txs/${invalidTx.txHash.toString()}.bin`, invalidTx.toBuffer(), { - compress: false, - }); + it('rejects tx when validator returns invalid', async () => { + const tx = await makeTx(); + await fileStore.save(`${basePath}/txs/${tx.txHash.toString()}.bin`, tx.toBuffer(), { compress: false }); - // Read it back via FileStoreTxSource - const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, log))!; - const result = await source.getTxsByHash([invalidTx.txHash]); + mockValidator.validateTx.mockResolvedValueOnce({ result: 'invalid', reason: ['invalid'] }); + const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, mockValidator, log))!; + const result = await source.getTxsByHash([tx.txHash]); expect(result.validTxs).toHaveLength(0); - expect(result.invalidTxHashes).toEqual([invalidTx.txHash.toString()]); + expect(result.invalidTxHashes).toEqual([tx.txHash.toString()]); }); - it('rejects tx when tx with wrong hash is returned', async () => { - // Write a tx with a mismatched hash directly to the file store - const invalidTx = Tx.random(); // random hash does not match computed hash - const validTx = await makeTx(); - await fileStore.save(`${basePath}/txs/${invalidTx.txHash.toString()}.bin`, validTx.toBuffer(), { - compress: false, - }); + it('accepts tx when validator returns valid', async () => { + const tx = await makeTx(); + await fileStore.save(`${basePath}/txs/${tx.txHash.toString()}.bin`, tx.toBuffer(), { compress: false }); - // Read it back via FileStoreTxSource - const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, log))!; - const result = await source.getTxsByHash([invalidTx.txHash]); + const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, mockValidator, log))!; + const result = await source.getTxsByHash([tx.txHash]); - expect(result.validTxs).toHaveLength(0); - expect(result.invalidTxHashes).toEqual([validTx.txHash.toString()]); + expect(result.validTxs).toHaveLength(1); + expect(result.invalidTxHashes).toHaveLength(0); }); - it('accepts correct tx', async () => { - // Write a tx with a correct hash directly to the file store - const validTx = await makeTx(); - await fileStore.save(`${basePath}/txs/${validTx.txHash.toString()}.bin`, validTx.toBuffer(), { - compress: false, - }); + it('partitions txs based on validator result', async () => { + const tx1 = await makeTx(); + const tx2 = await makeTx(); + await fileStore.save(`${basePath}/txs/${tx1.txHash.toString()}.bin`, tx1.toBuffer(), { compress: false }); + await fileStore.save(`${basePath}/txs/${tx2.txHash.toString()}.bin`, tx2.toBuffer(), { compress: false }); + + mockValidator.validateTx + .mockResolvedValueOnce({ result: 'valid' }) + .mockResolvedValueOnce({ result: 'invalid', reason: ['bad'] }); - // Read it back via FileStoreTxSource - const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, log))!; - const result = await source.getTxsByHash([validTx.txHash]); + const source = (await FileStoreTxSource.create(`file://${tmpDir}`, basePath, mockValidator, log))!; + const result = await source.getTxsByHash([tx1.txHash, tx2.txHash]); expect(result.validTxs).toHaveLength(1); - expect(result.invalidTxHashes).toHaveLength(0); + expect(result.invalidTxHashes).toHaveLength(1); }); }); @@ -388,7 +386,7 @@ describe('TxFileStore', () => { await txFileStore!.flush(); // Read back via FileStoreTxSource using the same local file store - const txSource = await FileStoreTxSource.create(`file://${tmpDir}`, basePath, log); + const txSource = await FileStoreTxSource.create(`file://${tmpDir}`, basePath, mockValidator, log); expect(txSource).toBeDefined(); const results = await txSource!.getTxsByHash([tx.getTxHash()]);