diff --git a/barretenberg/cpp/scripts/merkle_tree_tests.sh b/barretenberg/cpp/scripts/merkle_tree_tests.sh index 9e5e0f0b3c91..2b7719b0cdb2 100755 --- a/barretenberg/cpp/scripts/merkle_tree_tests.sh +++ b/barretenberg/cpp/scripts/merkle_tree_tests.sh @@ -5,7 +5,7 @@ set -e # run commands relative to parent directory cd $(dirname $0)/.. -DEFAULT_TESTS=PersistedIndexedTreeTest.*:PersistedAppendOnlyTreeTest.*:LMDBTreeStoreTest.*:PersistedContentAddressedIndexedTreeTest.*:PersistedContentAddressedAppendOnlyTreeTest.* +DEFAULT_TESTS=PersistedIndexedTreeTest.*:PersistedAppendOnlyTreeTest.*:LMDBTreeStoreTest.*:PersistedContentAddressedIndexedTreeTest.*:PersistedContentAddressedAppendOnlyTreeTest.*:ContentAddressedCacheTest.* TEST=${1:-$DEFAULT_TESTS} PRESET=${PRESET:-clang16} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp index 6b7bfb17a3b9..616a3d4324c8 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.hpp @@ -39,17 +39,21 @@ template class ContentAddressedAppendOn using StoreType = Store; // Asynchronous methods accept these callback function types as arguments + using EmptyResponseCallback = std::function; using AppendCompletionCallback = std::function&)>; using MetaDataCallback = std::function&)>; using HashPathCallback = std::function&)>; using FindLeafCallback = std::function&)>; using GetLeafCallback = std::function&)>; using CommitCallback = std::function&)>; - using RollbackCallback = std::function; + using RollbackCallback = EmptyResponseCallback; using RemoveHistoricBlockCallback = std::function&)>; using UnwindBlockCallback = std::function&)>; - using FinaliseBlockCallback = std::function; + using 
FinaliseBlockCallback = EmptyResponseCallback; using GetBlockForIndexCallback = std::function&)>; + using CheckpointCallback = EmptyResponseCallback; + using CheckpointCommitCallback = EmptyResponseCallback; + using CheckpointRevertCallback = EmptyResponseCallback; // Only construct from provided store and thread pool, no copies or moves ContentAddressedAppendOnlyTree(std::unique_ptr store, @@ -223,6 +227,10 @@ template class ContentAddressedAppendOn void finalise_block(const block_number_t& blockNumber, const FinaliseBlockCallback& on_completion); + void checkpoint(const CheckpointCallback& on_completion); + void commit_checkpoint(const CheckpointCommitCallback& on_completion); + void revert_checkpoint(const CheckpointRevertCallback& on_completion); + protected: using ReadTransaction = typename Store::ReadTransaction; using ReadTransactionPtr = typename Store::ReadTransactionPtr; @@ -843,6 +851,34 @@ void ContentAddressedAppendOnlyTree::rollback(const Rollba workers_->enqueue(job); } +// TODO(PhilWindle): One possible optimisation is for the following 3 functions +// checkpoint, commit_checkpoint and revert_checkpoint to not use the thread pool +// It is not strictly necessary for these operations to use it. 
The balance is whether +// the cost of using it outweighs the benefit of checkpointing/reverting all trees concurrently + +template +void ContentAddressedAppendOnlyTree::checkpoint(const CheckpointCallback& on_completion) +{ + auto job = [=, this]() { execute_and_report([=, this]() { store_->checkpoint(); }, on_completion); }; + workers_->enqueue(job); +} + +template +void ContentAddressedAppendOnlyTree::commit_checkpoint( + const CheckpointCommitCallback& on_completion) +{ + auto job = [=, this]() { execute_and_report([=, this]() { store_->commit_checkpoint(); }, on_completion); }; + workers_->enqueue(job); +} + +template +void ContentAddressedAppendOnlyTree::revert_checkpoint( + const CheckpointRevertCallback& on_completion) +{ + auto job = [=, this]() { execute_and_report([=, this]() { store_->revert_checkpoint(); }, on_completion); }; + workers_->enqueue(job); +} + template void ContentAddressedAppendOnlyTree::remove_historic_block( const block_number_t& blockNumber, const RemoveHistoricBlockCallback& on_completion) diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp index f484907329ff..4c01c53a4853 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/append_only_tree/content_addressed_append_only_tree.test.cpp @@ -20,8 +20,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -151,17 +151,6 @@ void commit_tree(TreeType& tree, bool expected_success = true) signal.wait_for_level(); } -void rollback_tree(TreeType& tree) -{ - Signal signal; - auto completion = [&](const Response& response) -> void { - EXPECT_EQ(response.success, true); - signal.signal_level(); - }; - tree.rollback(completion); - 
signal.wait_for_level(); -} - void remove_historic_block(TreeType& tree, const block_number_t& blockNumber, bool expected_success = true) { Signal signal; @@ -1949,3 +1938,112 @@ TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_not_historically_remove_ } remove_historic_block(tree, blockToFinalise, false); } + +TEST_F(PersistedContentAddressedAppendOnlyTreeTest, can_checkpoint_and_revert_forks) +{ + constexpr size_t depth = 10; + uint32_t blockSize = 16; + std::string name = random_string(); + ThreadPoolPtr pool = make_thread_pool(1); + LMDBTreeStore::SharedPtr db = std::make_shared(_directory, name, _mapSize, _maxReaders); + MemoryTree memdb(depth); + + { + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + + std::vector values = create_values(blockSize); + add_values(tree, values); + + commit_tree(tree); + } + + std::unique_ptr store = std::make_unique(name, depth, db); + TreeType tree(std::move(store), pool); + + // We apply a number of updates and checkpoint the tree each time + + uint32_t stackDepth = 20; + + std::vector paths(stackDepth); + uint32_t index = 0; + for (; index < stackDepth - 1; index++) { + std::vector values = create_values(blockSize); + add_values(tree, values); + + paths[index] = get_sibling_path(tree, 3); + + try { + checkpoint_tree(tree); + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + } + } + + // Now add one more depth, this will be un-checkpointed + { + std::vector values = create_values(blockSize); + add_values(tree, values); + paths[index] = get_sibling_path(tree, 3); + } + + index_t checkpointIndex = index; + + // The tree is currently at the state of index 19 + EXPECT_EQ(get_sibling_path(tree, 3), paths[checkpointIndex]); + + // We now alternate committing and reverting the checkpoints half way up the stack + + for (; index > stackDepth / 2; index--) { + if (index % 2 == 0) { + revert_checkpoint_tree(tree, true); + checkpointIndex = index - 1; + } else 
{ + commit_checkpoint_tree(tree, true); + } + + EXPECT_EQ(get_sibling_path(tree, 3), paths[checkpointIndex]); + } + + // Now apply another set of updates and checkpoints back to the original stack depth + for (; index < stackDepth - 1; index++) { + std::vector values = create_values(blockSize); + add_values(tree, values); + + paths[index] = get_sibling_path(tree, 3); + + try { + checkpoint_tree(tree); + } catch (std::exception& e) { + std::cout << e.what() << std::endl; + } + } + + // Now add one more depth, this will be un-checkpointed + { + std::vector values = create_values(blockSize); + add_values(tree, values); + paths[index] = get_sibling_path(tree, 3); + } + + // We now alternatively commit and revert all the way back to the start + checkpointIndex = index; + + // The tree is currently at the state of index 19 + EXPECT_EQ(get_sibling_path(tree, 3), paths[checkpointIndex]); + + for (; index > 0; index--) { + if (index % 2 == 0) { + revert_checkpoint_tree(tree, true); + checkpointIndex = index - 1; + } else { + commit_checkpoint_tree(tree, true); + } + + EXPECT_EQ(get_sibling_path(tree, 3), paths[checkpointIndex]); + } + + // Should not be able to commit or revert where there is no active checkpoint + revert_checkpoint_tree(tree, false); + commit_checkpoint_tree(tree, false); +} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.hpp index 21c2acb532f3..229122050bfb 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.hpp @@ -965,8 +965,10 @@ void ContentAddressedIndexedTree::generate_insertions( // std::cout << "Failed to find low leaf" << std::endl; throw std::runtime_error(format("Unable to insert values into tree ", meta.name, - " failed 
to find low leaf at index ", - low_leaf_index)); + ", failed to find low leaf at index ", + low_leaf_index, + ", current size: ", + meta.size)); } // std::cout << "Low leaf hash " << low_leaf_hash.value() << std::endl; @@ -1454,7 +1456,7 @@ void ContentAddressedIndexedTree::generate_sequential_inse if (!low_leaf_hash.has_value()) { throw std::runtime_error(format("Unable to insert values into tree ", meta.name, - " failed to find low leaf at index ", + ", failed to find low leaf at index ", low_leaf_index)); } diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.test.cpp index b42d31755749..7fa27ecbfc85 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.test.cpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/indexed_tree/content_addressed_indexed_tree.test.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -125,26 +124,6 @@ fr_sibling_path get_historic_sibling_path(TypeOfTree& tree, return h; } -template -fr_sibling_path get_sibling_path(TypeOfTree& tree, - index_t index, - bool includeUncommitted = true, - bool expected_success = true) -{ - fr_sibling_path h; - Signal signal; - auto completion = [&](const TypedResponse& response) -> void { - EXPECT_EQ(response.success, expected_success); - if (response.success) { - h = response.inner.path; - } - signal.signal_level(); - }; - tree.get_sibling_path(index, completion, includeUncommitted); - signal.wait_for_level(); - return h; -} - template IndexedLeaf get_leaf(TypeOfTree& tree, index_t index, @@ -2795,3 +2774,343 @@ TEST_F(PersistedContentAddressedIndexedTreeTest, can_sync_and_unwind_empty_block _directory, ss.str(), _mapSize, _maxReaders, 20, actualSize, numBlocks, numBlocksToUnwind, values); } } + 
+TEST_F(PersistedContentAddressedIndexedTreeTest, test_can_commit_and_revert_checkpoints) +{ + index_t initial_size = 2; + index_t current_size = initial_size; + ThreadPoolPtr workers = make_thread_pool(8); + // Create a depth-3 indexed merkle tree + constexpr size_t depth = 3; + std::string name = random_string(); + LMDBTreeStore::SharedPtr db = std::make_shared(_directory, name, _mapSize, _maxReaders); + std::unique_ptr> store = + std::make_unique>(name, depth, db); + auto tree = ContentAddressedIndexedTree, Poseidon2HashPolicy>( + std::move(store), workers, current_size); + + /** + * Initial state: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 0 0 0 0 0 0 + * val 0 0 0 0 0 0 0 0 + * nextIdx 1 0 0 0 0 0 0 0 + * nextVal 1 0 0 0 0 0 0 0 + */ + + /** + * Add new slot:value 30:5: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 0 0 0 0 0 + * val 0 0 5 0 0 0 0 0 + * nextIdx 1 2 0 0 0 0 0 0 + * nextVal 1 30 0 0 0 0 0 0 + */ + add_value_sequentially(tree, PublicDataLeafValue(30, 5)); + check_size(tree, ++current_size); + + /** + * Add new slot:value 10:20: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 0 0 0 0 + * val 0 0 5 20 0 0 0 0 + * nextIdx 1 3 0 2 0 0 0 0 + * nextVal 1 10 0 30 0 0 0 0 + */ + add_value_sequentially(tree, PublicDataLeafValue(10, 20)); + check_size(tree, ++current_size); + + /** + * Update value at slot 30 to 6: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 0 0 0 0 + * val 0 0 6 20 0 0 0 0 + * nextIdx 1 3 0 2 0 0 0 0 + * nextVal 1 10 0 30 0 0 0 0 + */ + add_value_sequentially(tree, PublicDataLeafValue(30, 6)); + // The size does not increase since sequential insertion doesn't pad + check_size(tree, current_size); + commit_tree(tree); + + { + index_t fork_size = 
current_size; + std::unique_ptr> forkStore = + std::make_unique>(name, depth, db); + auto forkTree = + ContentAddressedIndexedTree, Poseidon2HashPolicy>( + std::move(forkStore), workers, initial_size); + + // Find the low leaf of slot 60 + auto predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be at index 2 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 2); + + // checkpoint the fork + checkpoint_tree(forkTree); + + /** + * Add new value slot:value 50:8: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 50 0 0 0 + * val 0 0 6 20 8 0 0 0 + * nextIdx 1 3 4 2 0 0 0 0 + * nextVal 1 10 50 30 0 0 0 0 + */ + add_value_sequentially(forkTree, PublicDataLeafValue(50, 8)); + check_size(forkTree, ++fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 6, 4, 50)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 4); + + // Now revert the fork and see that it is rolled back to the checkpoint + revert_checkpoint_tree(forkTree); + check_size(forkTree, --fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 6, 0, 0)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + + // Find the low leaf of slot 60 + 
predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 2 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 2); + + // checkpoint the fork again + checkpoint_tree(forkTree); + + // We now advance the fork again by a few checkpoints + + /** + * Add new value slot:value 50:8: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 50 0 0 0 + * val 0 0 6 20 8 0 0 0 + * nextIdx 1 3 4 2 0 0 0 0 + * nextVal 1 10 50 30 0 0 0 0 + */ + + // Make the same change again, commit the checkpoint and see that the changes remain + add_value_sequentially(forkTree, PublicDataLeafValue(50, 8)); + check_size(forkTree, ++fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 6, 4, 50)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 4); + + // Checkpoint again + checkpoint_tree(forkTree); + + /** + * Update the value in slot 30 to 12: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 50 0 0 0 + * val 0 0 12 20 8 0 0 0 + * nextIdx 1 3 4 2 0 0 0 0 + * nextVal 1 10 50 30 0 0 0 0 + */ + add_value_sequentially(forkTree, PublicDataLeafValue(30, 12)); + check_size(forkTree, fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + 
EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 12, 4, 50)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 4); + + // Checkpoint again + checkpoint_tree(forkTree); + + /** + * Add a value at slot 45:15 + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 50 45 0 0 + * val 0 0 12 20 8 15 0 0 + * nextIdx 1 3 5 2 0 4 0 0 + * nextVal 1 10 45 30 0 50 0 0 + */ + add_value_sequentially(forkTree, PublicDataLeafValue(45, 15)); + + check_size(forkTree, ++fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 12, 5, 45)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + EXPECT_EQ(get_leaf(forkTree, 5), create_indexed_public_data_leaf(45, 15, 4, 50)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 4); + + // Find the low leaf of slot 46 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(46, 5)); + + // It should be at index 5 (the newly inserted slot 45) + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 5); + + // Now commit the last checkpoint + commit_checkpoint_tree(forkTree); + + // The state should be identical + check_size(forkTree, fork_size); + 
EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 12, 5, 45)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + EXPECT_EQ(get_leaf(forkTree, 5), create_indexed_public_data_leaf(45, 15, 4, 50)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 4); + + // Find the low leaf of slot 46 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(46, 5)); + + // It should still be at index 5 (slot 45 survives the checkpoint commit) + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 5); + + // Now revert the fork and we should remove both the new slot 45 and the update to slot 30 + + /** + * We should revert to this state: + * + * index 0 1 2 3 4 5 6 7 + * --------------------------------------------------------------------- + * slot 0 1 30 10 50 0 0 0 + * val 0 0 6 20 8 0 0 0 + * nextIdx 1 3 4 2 0 0 0 0 + * nextVal 1 10 50 30 0 0 0 0 + */ + + revert_checkpoint_tree(forkTree); + + check_size(forkTree, --fork_size); + EXPECT_EQ(get_leaf(forkTree, 0), create_indexed_public_data_leaf(0, 0, 1, 1)); + EXPECT_EQ(get_leaf(forkTree, 1), create_indexed_public_data_leaf(1, 0, 3, 10)); + EXPECT_EQ(get_leaf(forkTree, 2), create_indexed_public_data_leaf(30, 6, 4, 50)); + EXPECT_EQ(get_leaf(forkTree, 3), create_indexed_public_data_leaf(10, 20, 2, 30)); + EXPECT_EQ(get_leaf(forkTree, 4), create_indexed_public_data_leaf(50, 8, 0, 0)); + + // Find the low leaf of slot 60 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(60, 5)); + + // It should be back at index 4 + EXPECT_EQ(predecessor.is_already_present, false); + 
EXPECT_EQ(predecessor.index, 4); + + // Find the low leaf of slot 46 + predecessor = get_low_leaf(forkTree, PublicDataLeafValue(46, 5)); + + // It should be back at index 2 (slot 45 was removed by the revert) + EXPECT_EQ(predecessor.is_already_present, false); + EXPECT_EQ(predecessor.index, 2); + } +} + +void advance_state(TreeType& fork, uint32_t size) +{ + std::vector values = create_values(size); + std::vector leaves; + for (uint32_t j = 0; j < size; j++) { + leaves.emplace_back(values[j]); + } + add_values(fork, leaves); +} + +TEST_F(PersistedContentAddressedIndexedTreeTest, nullifiers_can_be_inserted_after_revert) +{ + index_t current_size = 2; + ThreadPoolPtr workers = make_thread_pool(1); + constexpr size_t depth = 10; + std::string name = "Nullifier Tree"; + LMDBTreeStore::SharedPtr db = std::make_shared(_directory, name, _mapSize, _maxReaders); + std::unique_ptr store = std::make_unique(name, depth, db); + auto tree = TreeType(std::move(store), workers, current_size); + + { + std::unique_ptr forkStore = std::make_unique(name, depth, db); + auto forkTree = TreeType(std::move(forkStore), workers, current_size); + + check_size(tree, current_size); + + uint32_t size_to_insert = 8; + uint32_t num_insertions = 5; + + for (uint32_t i = 0; i < num_insertions - 1; i++) { + advance_state(forkTree, size_to_insert); + current_size += size_to_insert; + check_size(forkTree, current_size); + checkpoint_tree(forkTree); + } + + advance_state(forkTree, size_to_insert); + current_size += size_to_insert; + check_size(forkTree, current_size); + revert_checkpoint_tree(forkTree); + + current_size -= size_to_insert; + check_size(forkTree, current_size); + + commit_checkpoint_tree(forkTree); + + check_size(forkTree, current_size); + + advance_state(forkTree, size_to_insert); + + current_size += size_to_insert; + check_size(forkTree, current_size); + } +} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp 
b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp index da95d3857966..32e40faed882 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/cached_content_addressed_tree_store.hpp @@ -1,7 +1,9 @@ #pragma once #include "./tree_meta.hpp" +#include "barretenberg/common/log.hpp" #include "barretenberg/crypto/merkle_tree/indexed_tree/indexed_leaf.hpp" #include "barretenberg/crypto/merkle_tree/lmdb_store/lmdb_tree_store.hpp" +#include "barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp" #include "barretenberg/crypto/merkle_tree/types.hpp" #include "barretenberg/ecc/curves/bn254/fr.hpp" #include "barretenberg/lmdblib/lmdb_helpers.hpp" @@ -22,17 +24,6 @@ #include #include -template <> struct std::hash { - std::size_t operator()(const uint256_t& k) const { return k.data[0]; } -}; -template <> struct std::hash { - std::size_t operator()(const bb::fr& k) const - { - bb::numeric::uint256_t val(k); - return val.data[0]; - } -}; - namespace bb::crypto::merkle_tree { /** @@ -190,39 +181,28 @@ template class ContentAddressedCachedTreeStore { void unwind_block(const block_number_t& blockNumber, TreeMeta& finalMeta, TreeDBStats& dbStats); - std::optional get_fork_block() const; - void advance_finalised_block(const block_number_t& blockNumber); std::optional find_block_for_index(const index_t& index, ReadTransaction& tx) const; + void checkpoint(); + void revert_checkpoint(); + void commit_checkpoint(); + private: + using Cache = ContentAddressedCache; + struct ForkConstantData { std::string name_; uint32_t depth_; std::optional initialised_from_block_; }; ForkConstantData forkConstantData_; + mutable std::mutex mtx_; - // This is a mapping between the node hash and it's payload (children and ref count) for every node in the tree, - // including leaves. 
As indexed trees are updated, this will end up containing many nodes that are not part of the - // final tree so they need to be omitted from what is committed. - std::unordered_map nodes_; - - // This is a store mapping the leaf key (e.g. slot for public data or nullifier value for nullifier tree) to the - // index in the tree - std::map indices_; - - // This is a mapping from leaf hash to leaf pre-image. This will contain entries that need to be omitted when - // commiting updates - std::unordered_map leaves_; PersistedStoreType::SharedPtr dataStore_; - TreeMeta meta_; - mutable std::mutex mtx_; - // The following stores are not persisted, just cached until commit - std::vector> nodes_by_index_; - std::unordered_map leaf_pre_image_by_index_; + Cache cache_; void initialise(); @@ -234,10 +214,6 @@ template class ContentAddressedCachedTreeStore { void persist_meta(TreeMeta& m, WriteTransaction& tx); - void persist_leaf_indices(WriteTransaction& tx); - - void persist_leaf_pre_image(const fr& hash, WriteTransaction& tx); - void persist_node(const std::optional& optional_hash, uint32_t level, WriteTransaction& tx); void remove_node(const std::optional& optional_hash, @@ -253,6 +229,8 @@ template class ContentAddressedCachedTreeStore { void persist_block_for_index(const block_number_t& blockNumber, const index_t& index, WriteTransaction& tx); + void persist_leaf_indices(WriteTransaction& tx); + void delete_block_for_index(const block_number_t& blockNumber, const index_t& index, WriteTransaction& tx); index_t constrain_tree_size_to_only_committed(const RequestContext& requestContext, ReadTransaction& tx) const; @@ -266,7 +244,7 @@ ContentAddressedCachedTreeStore::ContentAddressedCachedTreeStore( PersistedStoreType::SharedPtr dataStore) : forkConstantData_{ .name_ = (std::move(name)), .depth_ = levels } , dataStore_(dataStore) - , nodes_by_index_(std::vector>(levels + 1, std::unordered_map())) + , cache_(levels) { initialise(); } @@ -278,11 +256,30 @@ 
ContentAddressedCachedTreeStore::ContentAddressedCachedTreeStore( PersistedStoreType::SharedPtr dataStore) : forkConstantData_{ .name_ = (std::move(name)), .depth_ = levels } , dataStore_(dataStore) - , nodes_by_index_(std::vector>(levels + 1, std::unordered_map())) + , cache_(levels) { initialise_from_block(referenceBlockNumber); } +// Much like the commit/rollback/set finalised/remove historic blocks apis +// These 3 apis (checkpoint/revert_checkpoint/commit_checkpoint) all assume they are not called +// during the process of reading/writing uncommitted state +// This is reasonable, they are intended for use by forks at the point of starting/ending a function call +template void ContentAddressedCachedTreeStore::checkpoint() +{ + cache_.checkpoint(); +} + +template void ContentAddressedCachedTreeStore::revert_checkpoint() +{ + cache_.revert(); +} + +template void ContentAddressedCachedTreeStore::commit_checkpoint() +{ + cache_.commit(); +} + template index_t ContentAddressedCachedTreeStore::constrain_tree_size_to_only_committed( const RequestContext& requestContext, ReadTransaction& tx) const @@ -353,38 +350,20 @@ std::pair ContentAddressedCachedTreeStore::find_lo index_t db_index = committed; uint256_t retrieved_value = found_key; - // Accessing indices_ from here under a lock - std::unique_lock lock(mtx_); - if (!requestContext.includeUncommitted || retrieved_value == new_value_as_number || indices_.empty()) { - return std::make_pair(new_value_as_number == retrieved_value, db_index); + // If we already found the leaf then return it. + bool already_present = retrieved_value == new_value_as_number; + if (already_present) { + return std::make_pair(true, db_index); } - // At this stage, we have been asked to include uncommitted and the value was not exactly found in the db - auto it = indices_.lower_bound(new_value_as_number); - if (it == indices_.end()) { - // there is no element >= the requested value. 
- // decrement the iterator to get the value preceeding the requested value - --it; - // we need to return the larger of the db value or the cached value - - return std::make_pair(false, it->first > retrieved_value ? it->second : db_index); - } - - if (it->first == uint256_t(new_value_as_number)) { - // the value is already present and the iterator points to it - return std::make_pair(true, it->second); - } - // the iterator points to the element immediately larger than the requested value - // We need to return the highest value from - // 1. The next lowest cached value, if there is one - // 2. The value retrieved from the db - if (it == indices_.begin()) { - // No cached lower value, return the db index + // If we were asked not to include uncommitted then return what we have + if (!requestContext.includeUncommitted) { return std::make_pair(false, db_index); } - --it; - // it now points to the value less than that requested - return std::make_pair(false, it->first > retrieved_value ? it->second : db_index); + + // Accessing the cache from here under a lock + std::unique_lock lock(mtx_); + return cache_.find_low_value(new_leaf_key, retrieved_value, db_index); } template @@ -393,53 +372,49 @@ ContentAddressedCachedTreeStore::get_leaf_by_hash(const fr& leaf_ ReadTransaction& tx, bool includeUncommitted) const { - std::optional::IndexedLeafValueType> leaf = std::nullopt; + IndexedLeafValueType leafData; if (includeUncommitted) { - // Accessing leaves_ here under a lock + // Accessing the cache here under a lock std::unique_lock lock(mtx_); - typename std::unordered_map::const_iterator it = leaves_.find(leaf_hash); - if (it != leaves_.end()) { - leaf = it->second; - return leaf; + if (cache_.get_leaf_preimage_by_hash(leaf_hash, leafData)) { + return leafData; } } - IndexedLeafValueType leafData; - bool success = dataStore_->read_leaf_by_hash(leaf_hash, leafData, tx); - if (success) { - leaf = leafData; + if (dataStore_->read_leaf_by_hash(leaf_hash, leafData, tx)) { + 
return leafData; } - return leaf; + return std::nullopt; } template void ContentAddressedCachedTreeStore::put_leaf_by_hash(const fr& leaf_hash, const IndexedLeafValueType& leafPreImage) { - // Accessing leaves_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - leaves_[leaf_hash] = leafPreImage; + cache_.put_leaf_preimage_by_hash(leaf_hash, leafPreImage); } template std::optional::IndexedLeafValueType> ContentAddressedCachedTreeStore::get_cached_leaf_by_index(const index_t& index) const { - // Accessing leaf_pre_image_by_index_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - auto it = leaf_pre_image_by_index_.find(index); - if (it == leaf_pre_image_by_index_.end()) { - return std::nullopt; + IndexedLeafValueType leafPreImage; + if (cache_.get_leaf_by_index(index, leafPreImage)) { + return leafPreImage; } - return it->second; + return std::nullopt; } template void ContentAddressedCachedTreeStore::put_cached_leaf_by_index(const index_t& index, const IndexedLeafValueType& leafPreImage) { - // Accessing leaf_pre_image_by_index_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - leaf_pre_image_by_index_[index] = leafPreImage; + cache_.put_leaf_by_index(index, leafPreImage); } template @@ -454,9 +429,9 @@ template void ContentAddressedCachedTreeStore::update_index(const index_t& index, const fr& leaf) { // std::cout << "update_index at index " << index << " leaf " << leaf << std::endl; - // Accessing indices_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - indices_.insert({ uint256_t(leaf), index }); + cache_.update_leaf_key_index(index, leaf); } template @@ -474,14 +449,14 @@ std::optional ContentAddressedCachedTreeStore::find_leaf ReadTransaction& tx) const { if (requestContext.includeUncommitted) { - // Accessing indices_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - auto it = 
indices_.find(uint256_t(leaf)); - if (it != indices_.end()) { - // we have an uncommitted value, we will return from here - if (it->second >= start_index) { - // we have a qualifying value - return std::make_optional(it->second); + std::optional cached = cache_.get_leaf_key_index(preimage_to_key(leaf)); + if (cached.has_value()) { + // The is a cached value for the leaf + // We will return from here regardless + if (cached.value() >= start_index) { + return cached; } return std::nullopt; } @@ -512,7 +487,7 @@ void ContentAddressedCachedTreeStore::put_node_by_hash(const fr& { // Accessing nodes_ under a lock std::unique_lock lock(mtx_); - nodes_[nodeHash] = payload; + cache_.put_node(nodeHash, payload); } template @@ -524,9 +499,7 @@ bool ContentAddressedCachedTreeStore::get_node_by_hash(const fr& if (includeUncommitted) { // Accessing nodes_ under a lock std::unique_lock lock(mtx_); - auto it = nodes_.find(nodeHash); - if (it != nodes_.end()) { - payload = it->second; + if (cache_.get_node(nodeHash, payload)) { return true; } } @@ -539,16 +512,15 @@ void ContentAddressedCachedTreeStore::put_cached_node_by_index(ui const fr& data, bool overwriteIfPresent) { - // Accessing nodes_by_index_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); if (!overwriteIfPresent) { - const auto& level_map = nodes_by_index_[level]; - auto it = level_map.find(index); - if (it != level_map.end()) { + std::optional cached = cache_.get_node_by_index(level, index); + if (cached.has_value()) { return; } } - nodes_by_index_[level][index] = data; + cache_.put_node_by_index(level, index, data); } template @@ -556,22 +528,21 @@ bool ContentAddressedCachedTreeStore::get_cached_node_by_index(ui const index_t& index, fr& data) const { - // Accessing nodes_by_index_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - const auto& level_map = nodes_by_index_[level]; - auto it = level_map.find(index); - if (it == level_map.end()) { - return 
false; + std::optional cached = cache_.get_node_by_index(level, index); + if (cached.has_value()) { + data = cached.value(); + return true; } - data = it->second; - return true; + return false; } template void ContentAddressedCachedTreeStore::put_meta(const TreeMeta& m) { - // Accessing meta_ under a lock + // Accessing the cache under a lock std::unique_lock lock(mtx_); - meta_ = m; + cache_.put_meta(m); } template @@ -590,7 +561,7 @@ template void ContentAddressedCachedTreeStore @@ -641,9 +612,19 @@ fr ContentAddressedCachedTreeStore::get_current_root(ReadTransact } // The following functions are related to either initialisation or committing data -// It is assumed that when these operations are being executed that no other state accessing operations +// It is assumed that when these operations are being executed, no other state accessing operations // are in progress, hence no data synchronisation is used. +template +void ContentAddressedCachedTreeStore::persist_leaf_indices(WriteTransaction& tx) +{ + const std::map& indices = cache_.get_indices(); + for (const auto& idx : indices) { + FrKeyType key = idx.first; + dataStore_->write_leaf_index(key, idx.second, tx); + } +} + template void ContentAddressedCachedTreeStore::commit_genesis_state() { // In this call, we will store any node/leaf data that has been created so far @@ -654,9 +635,8 @@ template void ContentAddressedCachedTreeStore void ContentAddressedCachedTreeStore::commit_block(TreeMeta& finalMeta, TreeDBStats& dbStats) { bool dataPresent = false; - TreeMeta uncommittedMeta; - TreeMeta committedMeta; + TreeMeta meta; + // We don't allow commits using images/forks if (forkConstantData_.initialised_from_block_.has_value()) { throw std::runtime_error("Committing a fork is forbidden"); } - { - ReadTransactionPtr tx = create_read_transaction(); - // read both committed and uncommitted meta data - get_meta(uncommittedMeta); - get_meta(committedMeta, *tx, false); - - auto currentRootIter = 
nodes_.find(uncommittedMeta.root); - dataPresent = currentRootIter != nodes_.end(); - } + get_meta(meta); + NodePayload rootPayload; + dataPresent = cache_.get_node(meta.root, rootPayload); { WriteTransactionPtr tx = create_write_transaction(); try { if (dataPresent) { // std::cout << "Persisting data for block " << uncommittedMeta.unfinalisedBlockHeight + 1 << std::endl; + // Persist the leaf indices persist_leaf_indices(*tx); } // If we are commiting a block, we need to persist the root, since the new block "references" this root @@ -710,22 +685,20 @@ void ContentAddressedCachedTreeStore::commit_block(TreeMeta& fina // absence of a real tree elsewhere. So, if the tree is completely empty we do not store any node data, the // only issue is this needs to be recognised when we unwind or remove historic blocks i.e. there will be no // node date to remove for these blocks - if (dataPresent || uncommittedMeta.size > 0) { - persist_node(std::optional(uncommittedMeta.root), 0, *tx); + if (dataPresent || meta.size > 0) { + persist_node(std::optional(meta.root), 0, *tx); } - ++uncommittedMeta.unfinalisedBlockHeight; - if (uncommittedMeta.oldestHistoricBlock == 0) { - uncommittedMeta.oldestHistoricBlock = 1; + ++meta.unfinalisedBlockHeight; + if (meta.oldestHistoricBlock == 0) { + meta.oldestHistoricBlock = 1; } // std::cout << "New root " << uncommittedMeta.root << std::endl; - BlockPayload block{ .size = uncommittedMeta.size, - .blockNumber = uncommittedMeta.unfinalisedBlockHeight, - .root = uncommittedMeta.root }; - dataStore_->write_block_data(uncommittedMeta.unfinalisedBlockHeight, block, *tx); + BlockPayload block{ .size = meta.size, .blockNumber = meta.unfinalisedBlockHeight, .root = meta.root }; + dataStore_->write_block_data(meta.unfinalisedBlockHeight, block, *tx); dataStore_->write_block_index_data(block.blockNumber, block.size, *tx); - uncommittedMeta.committedSize = uncommittedMeta.size; - persist_meta(uncommittedMeta, *tx); + meta.committedSize = meta.size; + 
persist_meta(meta, *tx); tx->commit(); } catch (std::exception& e) { tx->try_abort(); @@ -733,7 +706,7 @@ void ContentAddressedCachedTreeStore::commit_block(TreeMeta& fina format("Unable to commit data to tree: ", forkConstantData_.name_, " Error: ", e.what())); } } - finalMeta = uncommittedMeta; + finalMeta = meta; // rolling back destroys all cache stores and also refreshes the cached meta_ from persisted state rollback(); @@ -751,26 +724,6 @@ void ContentAddressedCachedTreeStore::extract_db_stats(TreeDBStat } } -template -void ContentAddressedCachedTreeStore::persist_leaf_indices(WriteTransaction& tx) -{ - for (auto& idx : indices_) { - FrKeyType key = idx.first; - dataStore_->write_leaf_index(key, idx.second, tx); - } -} - -template -void ContentAddressedCachedTreeStore::persist_leaf_pre_image(const fr& hash, WriteTransaction& tx) -{ - // Now persist the leaf pre-image - auto leafPreImageIter = leaves_.find(hash); - if (leafPreImageIter == leaves_.end()) { - return; - } - dataStore_->write_leaf_by_hash(hash, leafPreImageIter->second, tx); -} - template void ContentAddressedCachedTreeStore::persist_node(const std::optional& optional_hash, uint32_t level, @@ -796,43 +749,42 @@ void ContentAddressedCachedTreeStore::persist_node(const std::opt fr hash = so.opHash.value(); if (so.lvl == forkConstantData_.depth_) { - // this is a leaf - persist_leaf_pre_image(hash, tx); + // this is a leaf, we need to persist the pre-image + IndexedLeafValueType leafPreImage; + if (cache_.get_leaf_preimage_by_hash(hash, leafPreImage)) { + dataStore_->write_leaf_by_hash(hash, leafPreImage, tx); + } } // std::cout << "Persisting node hash " << hash << " at level " << so.lvl << std::endl; - auto nodePayloadIter = nodes_.find(hash); - if (nodePayloadIter == nodes_.end()) { + NodePayload nodePayload; + if (!cache_.get_node(hash, nodePayload)) { // need to increase the stored node's reference count here dataStore_->increment_node_reference_count(hash, tx); continue; } - NodePayload 
nodeData = nodePayloadIter->second; - dataStore_->set_or_increment_node_reference_count(hash, nodeData, tx); - if (nodeData.ref != 1) { + dataStore_->set_or_increment_node_reference_count(hash, nodePayload, tx); + if (nodePayload.ref != 1) { // If the node now has a ref count greater then 1, we don't continue. // It means that the entire sub-tree underneath already exists continue; } - stack.push_back({ .opHash = nodePayloadIter->second.left, .lvl = so.lvl + 1 }); - stack.push_back({ .opHash = nodePayloadIter->second.right, .lvl = so.lvl + 1 }); + stack.push_back({ .opHash = nodePayload.left, .lvl = so.lvl + 1 }); + stack.push_back({ .opHash = nodePayload.right, .lvl = so.lvl + 1 }); } } template void ContentAddressedCachedTreeStore::rollback() { // Extract the committed meta data and destroy the cache + cache_.reset(forkConstantData_.depth_); { ReadTransactionPtr tx = create_read_transaction(); - read_persisted_meta(meta_, *tx); + TreeMeta committedMeta; + read_persisted_meta(committedMeta, *tx); + cache_.put_meta(committedMeta); } - nodes_ = std::unordered_map(); - indices_ = std::map(); - leaves_ = std::unordered_map(); - nodes_by_index_ = - std::vector>(forkConstantData_.depth_ + 1, std::unordered_map()); - leaf_pre_image_by_index_ = std::unordered_map(); } template @@ -1179,12 +1131,13 @@ template void ContentAddressedCachedTreeStore data; + TreeMeta meta; { ReadTransactionPtr tx = create_read_transaction(); - bool success = read_persisted_meta(meta_, *tx); + bool success = read_persisted_meta(meta, *tx); if (success) { - if (forkConstantData_.name_ == meta_.name && forkConstantData_.depth_ == meta_.depth) { + if (forkConstantData_.name_ == meta.name && forkConstantData_.depth_ == meta.depth) { + cache_.put_meta(meta); return; } throw std::runtime_error( @@ -1193,24 +1146,25 @@ template void ContentAddressedCachedTreeStorecommit(); } catch (std::exception& e) { tx->try_abort(); throw e; } + cache_.put_meta(meta); } template @@ -1218,12 +1172,12 @@ void 
ContentAddressedCachedTreeStore::initialise_from_block(const { // Read the persisted meta data, if the name or depth of the tree is not consistent with what was provided during // construction then we throw - std::vector data; { ReadTransactionPtr tx = create_read_transaction(); - bool success = read_persisted_meta(meta_, *tx); + TreeMeta meta; + bool success = read_persisted_meta(meta, *tx); if (success) { - if (forkConstantData_.name_ != meta_.name || forkConstantData_.depth_ != meta_.depth) { + if (forkConstantData_.name_ != meta.name || forkConstantData_.depth_ != meta.depth) { throw std::runtime_error(format("Inconsistent tree meta data when initialising ", forkConstantData_.name_, " with depth ", @@ -1231,9 +1185,9 @@ void ContentAddressedCachedTreeStore::initialise_from_block(const " from block ", blockNumber, " stored name: ", - meta_.name, + meta.name, "stored depth: ", - meta_.depth)); + meta.depth)); } } else { @@ -1243,44 +1197,36 @@ void ContentAddressedCachedTreeStore::initialise_from_block(const blockNumber)); } - if (meta_.unfinalisedBlockHeight < blockNumber) { + if (meta.unfinalisedBlockHeight < blockNumber) { throw std::runtime_error(format("Unable to initialise from future block: ", blockNumber, " unfinalisedBlockHeight: ", - meta_.unfinalisedBlockHeight, + meta.unfinalisedBlockHeight, ". Tree name: ", forkConstantData_.name_)); } - if (meta_.oldestHistoricBlock > blockNumber && blockNumber != 0) { + if (meta.oldestHistoricBlock > blockNumber && blockNumber != 0) { throw std::runtime_error(format("Unable to fork from expired historical block: ", blockNumber, " unfinalisedBlockHeight: ", - meta_.oldestHistoricBlock, + meta.oldestHistoricBlock, ". 
Tree name: ", forkConstantData_.name_)); } BlockPayload blockData; if (blockNumber == 0) { blockData.blockNumber = 0; - blockData.root = meta_.initialRoot; - blockData.size = meta_.initialSize; + blockData.root = meta.initialRoot; + blockData.size = meta.initialSize; } else if (get_block_data(blockNumber, blockData, *tx) == false) { throw std::runtime_error( format("Failed to retrieve block data: ", blockNumber, ". Tree name: ", forkConstantData_.name_)); } forkConstantData_.initialised_from_block_ = blockData; // Ensure the meta reflects the fork constant data - enrich_meta_from_fork_constant_data(meta_); + enrich_meta_from_fork_constant_data(meta); + cache_.put_meta(meta); } } -template -std::optional ContentAddressedCachedTreeStore::get_fork_block() const -{ - if (forkConstantData_.initialised_from_block_.has_value()) { - return forkConstantData_.initialised_from_block_->blockNumber; - } - return std::nullopt; -} - } // namespace bb::crypto::merkle_tree diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp new file mode 100644 index 000000000000..7e2479208c8e --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp @@ -0,0 +1,460 @@ +#pragma once +#include "./tree_meta.hpp" +#include "barretenberg/crypto/merkle_tree/indexed_tree/indexed_leaf.hpp" +#include "barretenberg/crypto/merkle_tree/lmdb_store/lmdb_tree_store.hpp" +#include "barretenberg/crypto/merkle_tree/types.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include "barretenberg/numeric/uint256/uint256.hpp" +#include "barretenberg/serialize/msgpack.hpp" +#include "barretenberg/stdlib/primitives/field/field.hpp" +#include "msgpack/assert.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template <> struct std::hash { + std::size_t 
operator()(const uint256_t& k) const { return k.data[0]; } +}; +template <> struct std::hash { + std::size_t operator()(const bb::fr& k) const + { + bb::numeric::uint256_t val(k); + return val.data[0]; + } +}; + +namespace bb::crypto::merkle_tree { + +// Stores all of the penidng updates to a mekle tree indexed for optimal retrieval +// Also stores a journal of inverse changes to the cache, enabling checkpoints and +// and subsequent commit/revert operations +template class ContentAddressedCache { + public: + using LeafType = LeafValueType; + using IndexedLeafValueType = IndexedLeaf; + using SharedPtr = std::shared_ptr; + using UniquePtr = std::unique_ptr; + + ContentAddressedCache() = delete; + ContentAddressedCache(uint32_t depth); + ~ContentAddressedCache() = default; + ContentAddressedCache(const ContentAddressedCache& other) = default; + ContentAddressedCache& operator=(const ContentAddressedCache& other) = default; + ContentAddressedCache(ContentAddressedCache&& other) noexcept = default; + ContentAddressedCache& operator=(ContentAddressedCache&& other) noexcept = default; + bool operator==(const ContentAddressedCache& other) const = default; + + void checkpoint(); + void revert(); + void commit(); + + void reset(uint32_t depth); + std::pair find_low_value(const uint256_t& new_leaf_key, + const uint256_t& retrieved_value, + const index_t& db_index) const; + + bool get_leaf_preimage_by_hash(const fr& leaf_hash, IndexedLeafValueType& leaf_pre_image) const; + void put_leaf_preimage_by_hash(const fr& leaf_hash, const IndexedLeafValueType& leaf_pre_image); + + bool get_leaf_by_index(const index_t& index, IndexedLeafValueType& leaf_pre_image) const; + void put_leaf_by_index(const index_t& index, const IndexedLeafValueType& leaf_pre_image); + + void update_leaf_key_index(const index_t& index, const fr& leaf_key); + std::optional get_leaf_key_index(const fr& leaf_key) const; + + void put_node(const fr& node_hash, const NodePayload& node); + bool get_node(const fr& 
node_hash, NodePayload& node) const; + + void put_meta(const TreeMeta& meta) { meta_ = meta; } + const TreeMeta& get_meta() const { return meta_; } + + std::optional get_node_by_index(uint32_t level, const index_t& index) const; + void put_node_by_index(uint32_t level, const index_t& index, const fr& node); + + const std::map& get_indices() const { return indices_; } + + bool is_equivalent_to(const ContentAddressedCache& other) const; + + private: + struct Journal { + // Captures the tree's metadata at the time of checkpoint + TreeMeta meta_; + // Captures the cache's node hashes at the time of checkpoint. If the node does not exist in the cache, the + // optional will == nullopt + // TODO (PhilWindle): Consider where a more optimal approach is a single unordered map, instead of 1 per level + std::vector>> nodes_by_index_; + // Captures the cache's leaf pre-images at the time of checkpoint. Again, if the leaf does not exist in the + // cache, the optional will == nullopt + std::unordered_map> leaf_pre_image_by_index_; + // Captures the addition of new leaf keys into the indices_ cache + std::vector new_leaf_keys_; + + Journal(TreeMeta meta) + : meta_(std::move(meta)) + , nodes_by_index_(meta_.depth + 1, std::unordered_map>()) + {} + }; + // This is a mapping between the node hash and it's payload (children and ref count) for every node in the tree, + // including leaves. As indexed trees are updated, this will end up containing many nodes that are not part of the + // final tree so they need to be omitted from what is committed. + std::unordered_map nodes_; + + // This is a store mapping the leaf key (e.g. slot for public data or nullifier value for nullifier tree) to the + // index in the tree + std::map indices_; + + // This is a mapping from leaf hash to leaf pre-image. 
This will contain entries that need to be omitted when + // commiting updates + std::unordered_map leaves_; + TreeMeta meta_; + + // The following stores are not persisted, just cached until commit + std::vector> nodes_by_index_; + std::unordered_map leaf_pre_image_by_index_; + + // The currently active journals + std::vector journals_; +}; + +template ContentAddressedCache::ContentAddressedCache(uint32_t depth) +{ + reset(depth); +} + +template void ContentAddressedCache::checkpoint() +{ + journals_.emplace_back(Journal(meta_)); +} + +template void ContentAddressedCache::revert() +{ + if (journals_.empty()) { + throw std::runtime_error("Cannot revert without a checkpoint"); + } + // We need to iterate over the nodes and leaves and + // 1. Remove any that were added since last checkpoint + // 2. Restore any that were updated since last checkpoint + // 3. Remove any new leaf keys that were added to the indices store + // 4. Restore the meta data + + Journal& journal = journals_.back(); + + for (uint32_t i = 0; i < journal.nodes_by_index_.size(); ++i) { + for (const auto& [index, optional_node_hash] : journal.nodes_by_index_[i]) { + // If the optional == nullopt then we remove it from the primary cache, it never existed before + if (!optional_node_hash.has_value()) { + nodes_by_index_[i].erase(index); + } else { + // The optional is not null, this means there is a vlue to be restored to the primary cache + nodes_by_index_[i][index] = optional_node_hash.value(); + } + } + } + + for (const auto& [index, optional_leaf] : journal.leaf_pre_image_by_index_) { + // If the option == nullopt then we remove it from the primary cache, it never existed before + // Also remove from the indices store + if (!optional_leaf.has_value()) { + leaf_pre_image_by_index_.erase(index); + } else { + // There was a leaf pre-image, restore it to the primary cache + // No need to update the indices store as the key has not changed + leaf_pre_image_by_index_[index] = optional_leaf.value(); + } + 
} + + // Remove any newly added leaf keys + for (const auto& key : journal.new_leaf_keys_) { + indices_.erase(key); + } + + // We need to restore the meta data + meta_ = std::move(journal.meta_); + journals_.pop_back(); +} + +template void ContentAddressedCache::commit() +{ + if (journals_.empty()) { + throw std::runtime_error("Cannot commit without a checkpoint"); + } + + // We need to iterate over the nodes and leaves and merge them into the previous checkpoint if there is one + // We also need to append any newly added leaf keys to the previous checkpoint + // If there is no previous checkpoint then we just destroy the journal as the cache will be correct + + if (journals_.size() == 1) { + journals_.clear(); + return; + } + + Journal& current_journal = journals_.back(); + Journal& previous_journal = journals_[journals_.size() - 2]; + + for (uint32_t i = 0; i < current_journal.nodes_by_index_.size(); ++i) { + for (const auto& [index, optional_node_hash] : current_journal.nodes_by_index_[i]) { + // There is an entry in the current journal, if it does not exist in the previous journal then we need to + // add it If it does exist in the previous journal then that journal already captured a value from the + // primary cache that existed no later + auto previousIter = previous_journal.nodes_by_index_[i].find(index); + if (previousIter == previous_journal.nodes_by_index_[i].end()) { + previous_journal.nodes_by_index_[i][index] = optional_node_hash; + } + } + } + + for (const auto& [index, optional_leaf] : current_journal.leaf_pre_image_by_index_) { + // There is an entry in the current journal, if it does not exist in the previous journal then we need to add it + // If it does exist in the previous journal then that journal already captured a value from the + // primary cache that existed no later + auto previousIter = previous_journal.leaf_pre_image_by_index_.find(index); + if (previousIter == previous_journal.leaf_pre_image_by_index_.end()) { + 
previous_journal.leaf_pre_image_by_index_[index] = optional_leaf; + } + } + + // Add our newly appended leaf keys to those of the previous journal + previous_journal.new_leaf_keys_.insert(previous_journal.new_leaf_keys_.end(), + current_journal.new_leaf_keys_.cbegin(), + current_journal.new_leaf_keys_.cend()); + + // We don't restore the meta here. We are committing, so the primary cached meta is correct + journals_.pop_back(); +} + +template void ContentAddressedCache::reset(uint32_t depth) +{ + nodes_ = std::unordered_map(); + indices_ = std::map(); + leaves_ = std::unordered_map(); + nodes_by_index_ = std::vector>(depth + 1, std::unordered_map()); + leaf_pre_image_by_index_ = std::unordered_map(); + journals_ = std::vector(); +} + +template +bool ContentAddressedCache::is_equivalent_to(const ContentAddressedCache& other) const +{ + // Meta should be identical + if (meta_ != other.meta_) { + return false; + } + + // Indices should be identical + if (indices_ != other.indices_) { + return false; + } + + // Nodes by index should be identical + if (nodes_by_index_ != other.nodes_by_index_) { + return false; + } + + // Leaf pre-images by index should be identical + if (leaf_pre_image_by_index_ != other.leaf_pre_image_by_index_) { + return false; + } + + // Our leaves should be a subset of the other leaves + for (const auto& [leaf_hash, leaf] : leaves_) { + auto it = other.leaves_.find(leaf_hash); + if (it == other.leaves_.end()) { + return false; + } + if (it->second != leaf) { + return false; + } + } + + // Our nodes should be a subset of the other nodes + for (const auto& [node_hash, node] : nodes_) { + auto it = other.nodes_.find(node_hash); + if (it == other.nodes_.end()) { + return false; + } + if (it->second != node) { + return false; + } + } + return true; +} + +template +std::pair ContentAddressedCache::find_low_value(const uint256_t& new_leaf_key, + const uint256_t& retrieved_value, + const index_t& db_index) const +{ + if (indices_.empty()) { + return 
std::make_pair(new_leaf_key == retrieved_value, db_index); + } + // At this stage, we have been asked to include uncommitted and the value was not exactly found in the db + auto it = indices_.lower_bound(new_leaf_key); + if (it == indices_.end()) { + // there is no element >= the requested value. + // decrement the iterator to get the value preceeding the requested value + --it; + // we need to return the larger of the db value or the cached value + + return std::make_pair(false, it->first > retrieved_value ? it->second : db_index); + } + + if (it->first == new_leaf_key) { + // the value is already present and the iterator points to it + return std::make_pair(true, it->second); + } + // the iterator points to the element immediately larger than the requested value + // We need to return the highest value from + // 1. The next lowest cached value, if there is one + // 2. The value retrieved from the db + if (it == indices_.begin()) { + // No cached lower value, return the db index + return std::make_pair(false, db_index); + } + --it; + // it now points to the value less than that requested + return std::make_pair(false, it->first > retrieved_value ? 
it->second : db_index); +} + +template +bool ContentAddressedCache::get_leaf_preimage_by_hash(const fr& leaf_hash, + IndexedLeafValueType& leaf_pre_image) const +{ + typename std::unordered_map::const_iterator it = leaves_.find(leaf_hash); + if (it != leaves_.end()) { + leaf_pre_image = it->second; + return true; + } + return false; +} + +template +void ContentAddressedCache::put_leaf_preimage_by_hash(const fr& leaf_hash, + const IndexedLeafValueType& leaf_pre_image) +{ + leaves_[leaf_hash] = leaf_pre_image; +} + +template +bool ContentAddressedCache::get_leaf_by_index(const index_t& index, + IndexedLeafValueType& leaf_pre_image) const +{ + typename std::unordered_map::const_iterator it = + leaf_pre_image_by_index_.find(index); + if (it != leaf_pre_image_by_index_.end()) { + leaf_pre_image = it->second; + return true; + } + return false; +} + +template +void ContentAddressedCache::put_leaf_by_index(const index_t& index, + const IndexedLeafValueType& leaf_pre_image) +{ + // If there is no current journal then we just update the cache and leave + if (journals_.empty()) { + leaf_pre_image_by_index_[index] = leaf_pre_image; + return; + } + + // There is a journal, grab it + Journal& journal = journals_.back(); + + // If there is no leaf pre-image at the given index then add the index location to the journal's collection of empty + // locations + auto cache_iter = leaf_pre_image_by_index_.find(index); + if (cache_iter == leaf_pre_image_by_index_.end()) { + journal.leaf_pre_image_by_index_[index] = std::nullopt; + } else { + // There is a leaf pre-image. 
If the journal does not have a pre-image at this index then add it to the journal + auto journalIter = journal.leaf_pre_image_by_index_.find(index); + if (journalIter == journal.leaf_pre_image_by_index_.end()) { + journal.leaf_pre_image_by_index_[index] = cache_iter->second; + } + } + leaf_pre_image_by_index_[index] = leaf_pre_image; +} + +template +void ContentAddressedCache::update_leaf_key_index(const index_t& index, const fr& leaf_key) +{ + uint256_t key = uint256_t(leaf_key); + auto result = indices_.insert({ key, index }); + if (result.second && !journals_.empty()) { + // The insertion took place, if we have a current journal then we need to add to the newly inserted leaf keys + Journal& journal = journals_.back(); + journal.new_leaf_keys_.emplace_back(key); + } +} + +template +std::optional ContentAddressedCache::get_leaf_key_index(const fr& leaf_key) const +{ + auto it = indices_.find(uint256_t(leaf_key)); + if (it == indices_.end()) { + return std::nullopt; + } + return it->second; +} + +template +void ContentAddressedCache::put_node(const fr& node_hash, const NodePayload& node) +{ + nodes_[node_hash] = node; +} + +template +bool ContentAddressedCache::get_node(const fr& node_hash, NodePayload& node) const +{ + auto it = nodes_.find(node_hash); + if (it == nodes_.end()) { + return false; + } + node = it->second; + return true; +} + +template +std::optional ContentAddressedCache::get_node_by_index(uint32_t level, const index_t& index) const +{ + auto it = nodes_by_index_[level].find(index); + if (it == nodes_by_index_[level].end()) { + return std::nullopt; + } + return it->second; +} + +template +void ContentAddressedCache::put_node_by_index(uint32_t level, const index_t& index, const fr& node) +{ + // If there is no current journal then we just update the cache and leave + if (journals_.empty()) { + nodes_by_index_[level][index] = node; + return; + } + + // There is a journal, grab it + Journal& journal = journals_.back(); + + // If there is no node at the 
given location then add a nullopt to the journal + auto cacheIter = nodes_by_index_[level].find(index); + if (cacheIter == nodes_by_index_[level].end()) { + journal.nodes_by_index_[level][index] = std::nullopt; + } else { + // There is a node. If the journal does not have a node at this index then add it to the journal + auto journalIter = journal.nodes_by_index_[level].find(index); + if (journalIter == journal.nodes_by_index_[level].end()) { + journal.nodes_by_index_[level][index] = cacheIter->second; + } + } + nodes_by_index_[level][index] = node; +} +} // namespace bb::crypto::merkle_tree \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp new file mode 100644 index 000000000000..76d8920b3b04 --- /dev/null +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.test.cpp @@ -0,0 +1,552 @@ +#include "barretenberg/crypto/merkle_tree/node_store/content_addressed_cache.hpp" +#include "barretenberg/common/test.hpp" +#include "barretenberg/crypto/merkle_tree/fixtures.hpp" +#include "barretenberg/crypto/merkle_tree/indexed_tree/indexed_leaf.hpp" +#include "barretenberg/crypto/merkle_tree/node_store/tree_meta.hpp" +#include "barretenberg/ecc/curves/bn254/fr.hpp" +#include +#include + +using namespace bb; +using namespace bb::crypto::merkle_tree; + +using LeafValueType = PublicDataLeafValue; +using IndexedLeafType = IndexedLeaf; +using CacheType = ContentAddressedCache; + +class ContentAddressedCacheTest : public testing::Test { + protected: + void SetUp() override {} + void TearDown() override{}; +}; + +uint64_t get_index(uint64_t max_index = 0) +{ + uint64_t result = random_engine.get_random_uint64(); + return max_index == 0 ? 
0 : result % max_index; +} + +void add_to_cache( + CacheType& cache, index_t leaf_offset, uint64_t num_leaves, uint64_t num_nodes, uint64_t max_index = 0) +{ + for (uint64_t i = 0; i < num_leaves; i++) { + fr slot = fr::random_element(); + fr value = fr::random_element(); + index_t next_index = get_index(max_index); + fr next_value = fr::random_element(); + IndexedLeafType leaf = IndexedLeafType(LeafValueType(slot, value), next_index, next_value); + fr leaf_hash = fr::random_element(); + cache.put_leaf_by_index(i + leaf_offset, leaf); + cache.put_leaf_preimage_by_hash(leaf_hash, leaf); + cache.update_leaf_key_index(i + leaf_offset, leaf.value.get_key()); + } + + for (uint64_t i = 0; i < num_nodes; i++) { + fr node_hash = fr::random_element(); + NodePayload node = { fr::random_element(), fr::random_element(), 0 }; + cache.put_node(node_hash, node); + + uint32_t level = uint32_t(i % uint64_t(cache.get_meta().depth)); + index_t max_index_at_level = 1; + max_index_at_level <<= level; + max_index_at_level--; + index_t index = get_index(max_index_at_level); + cache.put_node_by_index(level, index, node_hash); + } + + TreeMeta meta = cache.get_meta(); + meta.size += num_leaves; + cache.put_meta(meta); +} + +CacheType create_cache(uint32_t depth) +{ + TreeMeta meta; + meta.depth = depth; + meta.size = 0; + CacheType cache(depth); + cache.put_meta(meta); + return cache; +} + +TEST_F(ContentAddressedCacheTest, can_create_cache) +{ + constexpr uint32_t depth = 10; + EXPECT_NO_THROW(CacheType cache(depth)); +} + +TEST_F(ContentAddressedCacheTest, can_checkpoint_cache) +{ + CacheType cache = create_cache(10); + add_to_cache(cache, 0, 10, 100); + EXPECT_NO_THROW(cache.checkpoint()); +} + +TEST_F(ContentAddressedCacheTest, can_not_revert_cache_without_checkpoint) +{ + CacheType cache = create_cache(10); + EXPECT_THROW(cache.revert(), std::runtime_error); +} + +TEST_F(ContentAddressedCacheTest, can_not_commit_cache_without_checkpoint) +{ + CacheType cache = create_cache(10); + 
EXPECT_THROW(cache.commit(), std::runtime_error); +} + +// Adds 4 node hashes by the given level and index +// Returns the 4 hashes +std::vector setup_nodes_test(uint32_t level, uint64_t index, CacheType& cache) +{ + // Now add a new node + fr node_hash_1 = fr::random_element(); + cache.put_node_by_index(level, index, node_hash_1); + + // Now add a new value at the same location + fr node_hash_2 = fr::random_element(); + cache.put_node_by_index(level, index, node_hash_2); + + // Checkpoint again + cache.checkpoint(); + fr node_hash_3 = fr::random_element(); + cache.put_node_by_index(level, index, node_hash_3); + + // Now add a new value at the same location + fr node_hash_4 = fr::random_element(); + cache.put_node_by_index(level, index, node_hash_4); + return { node_hash_1, node_hash_2, node_hash_3, node_hash_4 }; +} + +TEST_F(ContentAddressedCacheTest, commit_then_revert_nodes) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + uint32_t level = 5; + uint64_t index = 15; + + std::vector hashes = setup_nodes_test(level, index, cache); + fr node_hash_4 = hashes[3]; + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Commit the last checkpoint + cache.commit(); + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Revert the next checkpoint, there should be no node at this location + cache.revert(); + EXPECT_FALSE(cache.get_node_by_index(level, index).has_value()); +} + +TEST_F(ContentAddressedCacheTest, commit_then_commit_nodes) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + uint32_t level = 5; + uint64_t index = 15; + std::vector hashes = setup_nodes_test(level, index, cache); + fr node_hash_4 = hashes[3]; + + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Commit the last checkpoint + cache.commit(); + // Check current node value + 
EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Commit again and we should still have the same node + cache.commit(); + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); +} + +TEST_F(ContentAddressedCacheTest, revert_then_commit_nodes) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + uint32_t level = 5; + uint64_t index = 15; + std::vector hashes = setup_nodes_test(level, index, cache); + fr node_hash_4 = hashes[3]; + fr node_hash_2 = hashes[1]; + + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Revert the last checkpoint + cache.revert(); + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_2); + + // Commit the next checkpoint + cache.commit(); + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_2); +} + +TEST_F(ContentAddressedCacheTest, revert_then_revert_nodes) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + uint32_t level = 5; + uint64_t index = 15; + std::vector hashes = setup_nodes_test(level, index, cache); + fr node_hash_4 = hashes[3]; + fr node_hash_2 = hashes[1]; + + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_4); + + // Revert the last checkpoint + cache.revert(); + // Check current node value + EXPECT_EQ(cache.get_node_by_index(level, index).value(), node_hash_2); + + // Revert the next checkpoint, should be no node at this location + cache.revert(); + EXPECT_FALSE(cache.get_node_by_index(level, index).has_value()); +} + +std::optional get_leaf_by_index(CacheType& cache, index_t index) +{ + IndexedLeafType leaf; + if (cache.get_leaf_by_index(index, leaf)) { + return leaf; + } + return std::nullopt; +} + +// Adds 4 leaf values at the given index +// Return all 4 leaves +std::vector setup_leaves_tests(uint32_t index, CacheType& cache) +{ + fr slot = fr::random_element(); + fr value1 = 
fr::random_element(); + index_t next_index = 15; + fr next_value = fr::random_element(); + // Now add a new node + IndexedLeafType leaf1 = IndexedLeafType(LeafValueType(slot, value1), next_index, next_value); + cache.put_leaf_by_index(index, leaf1); + cache.update_leaf_key_index(index, leaf1.value.get_key()); + + // Now add a new value at the same location + fr value2 = fr::random_element(); + IndexedLeafType leaf2 = IndexedLeafType(LeafValueType(slot, value2), next_index, next_value); + cache.put_leaf_by_index(index, leaf2); + cache.update_leaf_key_index(index, leaf2.value.get_key()); + + // Checkpoint again + cache.checkpoint(); + fr value3 = fr::random_element(); + IndexedLeafType leaf3 = IndexedLeafType(LeafValueType(slot, value3), next_index, next_value); + cache.put_leaf_by_index(index, leaf3); + cache.update_leaf_key_index(index, leaf3.value.get_key()); + + // Now add a new value at the same location + fr value4 = fr::random_element(); + IndexedLeafType leaf4 = IndexedLeafType(LeafValueType(slot, value4), next_index, next_value); + cache.put_leaf_by_index(index, leaf4); + cache.update_leaf_key_index(index, leaf4.value.get_key()); + return { leaf1, leaf2, leaf3, leaf4 }; +} + +TEST_F(ContentAddressedCacheTest, commit_then_revert_leaves) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + + uint32_t index = 67; + std::vector leaves = setup_leaves_tests(index, cache); + IndexedLeafType leaf4 = leaves[3]; + + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Commit the last checkpoint + cache.commit(); + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify 
the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Revert the next checkpoint, there should be no leaf at this location + cache.revert(); + EXPECT_FALSE(get_leaf_by_index(cache, index).has_value()); + EXPECT_FALSE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); +} + +TEST_F(ContentAddressedCacheTest, commit_then_commit_leaves) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + + uint32_t index = 67; + std::vector leaves = setup_leaves_tests(index, cache); + IndexedLeafType leaf4 = leaves[3]; + + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Commit the last checkpoint + cache.commit(); + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Commit the next checkpoint, should still have the same leaf + cache.commit(); + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); +} + +TEST_F(ContentAddressedCacheTest, revert_then_commit_leaves) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + + uint32_t index = 67; + std::vector leaves = setup_leaves_tests(index, cache); + IndexedLeafType 
leaf4 = leaves[3]; + IndexedLeafType leaf2 = leaves[1]; + + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Revert the last checkpoint + cache.revert(); + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf2); + + // Verify the indices store still has the key at the same index + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Commit the next checkpoint, should still have the same leaf + cache.commit(); + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf2); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); +} + +TEST_F(ContentAddressedCacheTest, revert_then_revert_leaves) +{ + CacheType cache = create_cache(10); + cache.checkpoint(); + + uint32_t index = 67; + std::vector leaves = setup_leaves_tests(index, cache); + IndexedLeafType leaf4 = leaves[3]; + IndexedLeafType leaf2 = leaves[1]; + + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf4); + + // Verify the indices store + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Revert the last checkpoint + cache.revert(); + // Check current leaf value + EXPECT_TRUE(get_leaf_by_index(cache, index).has_value()); + 
EXPECT_EQ(get_leaf_by_index(cache, index).value(), leaf2); + + // Verify the indices store still has the key at the same index + EXPECT_TRUE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); + EXPECT_EQ(cache.get_leaf_key_index(leaf4.value.get_key()).value(), index); + + // Revert the next checkpoint, there should be no leaf at this location + cache.revert(); + EXPECT_FALSE(get_leaf_by_index(cache, index).has_value()); + EXPECT_FALSE(cache.get_leaf_key_index(leaf4.value.get_key()).has_value()); +} + +TEST_F(ContentAddressedCacheTest, can_revert_cache) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 1000, 10000); + CacheType cache_copy = cache; + cache.checkpoint(); + add_to_cache(cache, 1000, 1000, 10000); + EXPECT_NO_THROW(cache.revert()); + // EXPECT_TRUE(cache_copy.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, can_commit_cache) +{ + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 1000, 10000); + CacheType cache_copy = cache; + cache.checkpoint(); + add_to_cache(cache, 1000, 1000, 10000); + CacheType cache_copy_2 = cache; + cache.checkpoint(); + add_to_cache(cache, 2000, 1000, 10000); + EXPECT_NO_THROW(cache.revert()); + EXPECT_TRUE(cache_copy_2.is_equivalent_to(cache)); + cache.commit(); + EXPECT_TRUE(cache_copy_2.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, can_revert_through_multiple_levels) +{ + uint64_t num_levels = 10; + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 1000, 10000); + + std::vector copies; + + for (uint64_t i = 0; i < num_levels; i++) { + copies.push_back(cache); + cache.checkpoint(); + add_to_cache(cache, (i + 1) * 1000, 1000, 10000); + } + + for (uint64_t i = 0; i < num_levels; i++) { + cache.revert(); + EXPECT_TRUE(copies[num_levels - i - 1].is_equivalent_to(cache)); + } +} + +TEST_F(ContentAddressedCacheTest, can_commit_through_multiple_levels) +{ + uint64_t num_levels = 10; + CacheType cache = create_cache(40); + add_to_cache(cache, 0, 1000, 
10000); + + for (uint64_t i = 0; i < num_levels; i++) { + cache.checkpoint(); + add_to_cache(cache, (i + 1) * 1000, 1000, 10000); + } + + CacheType cache_copy = cache; + + for (uint64_t i = 0; i < num_levels; i++) { + cache.commit(); + } + + EXPECT_TRUE(cache_copy.is_equivalent_to(cache)); +} + +void test_reverts_remove_all_deeper_commits(uint64_t max_index, uint32_t depth, uint64_t num_levels) +{ + CacheType cache = create_cache(depth); + add_to_cache(cache, 0, 1000, 10000, max_index); + + CacheType base_cache = cache; + + cache.checkpoint(); + add_to_cache(cache, 1000, 1000, 10000, max_index); + CacheType first_commit_cache = cache; + + // make lots more checkpoints and changes + for (uint64_t i = 1; i < num_levels; i++) { + cache.checkpoint(); + add_to_cache(cache, (i + 1) * 1000, 1000, 10000, max_index); + } + + CacheType final_cache = cache; + + // commit everything except the the first checkpoint + for (uint64_t i = 1; i < num_levels; i++) { + cache.commit(); + } + + // we should still be equivalent to the final commit cache + EXPECT_TRUE(final_cache.is_equivalent_to(cache)); + + // reverting this final checkpoint reverts eveything else + cache.revert(); + EXPECT_TRUE(base_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, reverts_remove_all_deeper_commits) +{ + // We execute this test using 2 different values for max index to produce slightly different behaviour + // A lower value will encourage more updates to existing nodes + // A higher value will mean more new nodes are added + uint32_t depth = 40; + std::array max_indices = { 100, 1000000 }; + uint64_t num_levels = 10; + + for (uint64_t max_index : max_indices) { + test_reverts_remove_all_deeper_commits(max_index, depth, num_levels); + } +} + +void reverts_remove_all_deeper_commits_2(uint64_t max_index, uint32_t depth, uint64_t num_levels) +{ + CacheType cache = create_cache(depth); + add_to_cache(cache, 0, 1000, 10000, max_index); + + CacheType base_cache = cache; + + 
cache.checkpoint(); + add_to_cache(cache, 1000, 1000, 10000, max_index); + CacheType first_commit_cache = cache; + + // make lots more checkpoints and changes + for (uint64_t i = 1; i < num_levels; i++) { + cache.checkpoint(); + add_to_cache(cache, (i + 1) * 1000, 1000, 10000, max_index); + } + + for (uint64_t i = 1; i < num_levels; i++) { + if (i % 2 != 0) { + cache.revert(); + } else { + cache.commit(); + } + } + + // we should still be equivalent to the first commit cache + EXPECT_TRUE(first_commit_cache.is_equivalent_to(cache)); + + // reverting this final checkpoint reverts eveything else + cache.revert(); + EXPECT_TRUE(base_cache.is_equivalent_to(cache)); +} + +TEST_F(ContentAddressedCacheTest, reverts_remove_all_deeper_commits_2) +{ + // We execute this test using 2 different values for max index to produce slightly different behaviour + // A lower value will encourage more updates to existing nodes + // A higher value will mean more new nodes are added + uint64_t num_levels = 10; + uint32_t depth = 40; + std::array max_indices = { 100, 1000000 }; + for (uint64_t max_index : max_indices) { + reverts_remove_all_deeper_commits_2(max_index, depth, num_levels); + } +} diff --git a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp index db52d26b003e..eb82bd9bab6b 100644 --- a/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp +++ b/barretenberg/cpp/src/barretenberg/crypto/merkle_tree/test_fixtures.hpp @@ -13,7 +13,7 @@ namespace bb::crypto::merkle_tree { -void inline check_block_and_root_data(LMDBTreeStore::SharedPtr db, +inline void check_block_and_root_data(LMDBTreeStore::SharedPtr db, block_number_t blockNumber, fr root, bool expectedSuccess) @@ -30,7 +30,7 @@ void inline check_block_and_root_data(LMDBTreeStore::SharedPtr db, EXPECT_EQ(success, expectedSuccess); } -void inline check_block_and_root_data( +inline void check_block_and_root_data( 
LMDBTreeStore::SharedPtr db, block_number_t blockNumber, fr root, bool expectedSuccess, bool expectedRootSuccess) { BlockPayload blockData; @@ -45,7 +45,7 @@ void inline check_block_and_root_data( EXPECT_EQ(success, expectedRootSuccess); } -void inline check_block_and_size_data(LMDBTreeStore::SharedPtr db, +inline void check_block_and_size_data(LMDBTreeStore::SharedPtr db, block_number_t blockNumber, index_t expectedSize, bool expectedSuccess) @@ -59,7 +59,7 @@ void inline check_block_and_size_data(LMDBTreeStore::SharedPtr db, } } -void inline check_indices_data( +inline void check_indices_data( LMDBTreeStore::SharedPtr db, fr leaf, index_t index, bool entryShouldBePresent, bool indexShouldBePresent) { index_t retrieved = 0; @@ -71,6 +71,18 @@ void inline check_indices_data( } } +inline void call_operation(std::function)> operation, + bool expected_success = true) +{ + Signal signal; + auto completion = [&](const Response& response) -> void { + EXPECT_EQ(response.success, expected_success); + signal.signal_level(); + }; + operation(completion); + signal.wait_for_level(); +} + template void check_leaf_by_hash(LMDBTreeStore::SharedPtr db, IndexedLeaf leaf, bool shouldBePresent) { @@ -220,4 +232,48 @@ void check_historic_find_leaf_index_from(TypeOfTree& tree, includeUncommitted); } +template +fr_sibling_path get_sibling_path(TypeOfTree& tree, + index_t index, + bool includeUncommitted = true, + bool expected_success = true) +{ + fr_sibling_path h; + Signal signal; + auto completion = [&](const TypedResponse& response) -> void { + EXPECT_EQ(response.success, expected_success); + if (response.success) { + h = response.inner.path; + } + signal.signal_level(); + }; + tree.get_sibling_path(index, completion, includeUncommitted); + signal.wait_for_level(); + return h; +} + +template void rollback_tree(TreeType& tree) +{ + auto completion = [&](auto completion) { tree.rollback(completion); }; + call_operation(completion); +} + +template void checkpoint_tree(TreeType& tree) 
+{ + auto completion = [&](auto completion) { tree.checkpoint(completion); }; + call_operation(completion); +} + +template void commit_checkpoint_tree(TreeType& tree, bool expected_success = true) + +{ + auto completion = [&](auto completion) { tree.commit_checkpoint(completion); }; + call_operation(completion, expected_success); +} + +template void revert_checkpoint_tree(TreeType& tree, bool expected_success = true) +{ + auto completion = [&](auto completion) { tree.revert_checkpoint(completion); }; + call_operation(completion, expected_success); +} } // namespace bb::crypto::merkle_tree \ No newline at end of file diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp index 31f2c66d97b0..34fbf97fff55 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.cpp @@ -216,6 +216,18 @@ WorldStateWrapper::WorldStateWrapper(const Napi::CallbackInfo& info) _dispatcher.registerTarget(WorldStateMessageType::CLOSE, [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return close(obj, buffer); }); + + _dispatcher.registerTarget( + WorldStateMessageType::CREATE_CHECKPOINT, + [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return checkpoint(obj, buffer); }); + + _dispatcher.registerTarget( + WorldStateMessageType::COMMIT_CHECKPOINT, + [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return commit_checkpoint(obj, buffer); }); + + _dispatcher.registerTarget( + WorldStateMessageType::REVERT_CHECKPOINT, + [this](msgpack::object& obj, msgpack::sbuffer& buffer) { return revert_checkpoint(obj, buffer); }); } Napi::Value WorldStateWrapper::call(const Napi::CallbackInfo& info) @@ -726,6 +738,48 @@ bool WorldStateWrapper::remove_historical(msgpack::object& obj, msgpack::sbuffer return true; } +bool WorldStateWrapper::checkpoint(msgpack::object& 
obj, msgpack::sbuffer& buffer) +{ + TypedMessage request; + obj.convert(request); + + _ws->checkpoint(request.value.forkId); + + MsgHeader header(request.header.messageId); + messaging::TypedMessage resp_msg(WorldStateMessageType::CREATE_CHECKPOINT, header, {}); + msgpack::pack(buffer, resp_msg); + + return true; +} + +bool WorldStateWrapper::commit_checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer) +{ + TypedMessage request; + obj.convert(request); + + _ws->commit_checkpoint(request.value.forkId); + + MsgHeader header(request.header.messageId); + messaging::TypedMessage resp_msg(WorldStateMessageType::COMMIT_CHECKPOINT, header, {}); + msgpack::pack(buffer, resp_msg); + + return true; +} + +bool WorldStateWrapper::revert_checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer) +{ + TypedMessage request; + obj.convert(request); + + _ws->revert_checkpoint(request.value.forkId); + + MsgHeader header(request.header.messageId); + messaging::TypedMessage resp_msg(WorldStateMessageType::REVERT_CHECKPOINT, header, {}); + msgpack::pack(buffer, resp_msg); + + return true; +} + bool WorldStateWrapper::get_status(msgpack::object& obj, msgpack::sbuffer& buf) const { HeaderOnlyMessage request; diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp index f6c070db92d2..0f6c8d8dfe1f 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state.hpp @@ -64,6 +64,10 @@ class WorldStateWrapper : public Napi::ObjectWrap { bool remove_historical(msgpack::object& obj, msgpack::sbuffer& buffer) const; bool get_status(msgpack::object& obj, msgpack::sbuffer& buffer) const; + + bool checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer); + bool commit_checkpoint(msgpack::object& obj, msgpack::sbuffer& buffer); + bool revert_checkpoint(msgpack::object& obj, 
msgpack::sbuffer& buffer); }; } // namespace bb::nodejs diff --git a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp index a207a0fe2753..2547bc85a7b3 100644 --- a/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp +++ b/barretenberg/cpp/src/barretenberg/nodejs_module/world_state/world_state_message.hpp @@ -48,6 +48,10 @@ enum WorldStateMessageType { GET_STATUS, + CREATE_CHECKPOINT, + COMMIT_CHECKPOINT, + REVERT_CHECKPOINT, + CLOSE = 999, }; @@ -72,6 +76,11 @@ struct DeleteForkRequest { MSGPACK_FIELDS(forkId); }; +struct ForkIdOnlyRequest { + uint64_t forkId; + MSGPACK_FIELDS(forkId); +}; + struct TreeIdAndRevisionRequest { MerkleTreeId treeId; WorldStateRevision revision; diff --git a/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp b/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp index f7fba9cc6c60..d08c87b4224b 100644 --- a/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp +++ b/barretenberg/cpp/src/barretenberg/world_state/world_state.cpp @@ -987,4 +987,85 @@ bool WorldState::determine_if_synched(std::array& metaRespo return true; } +void WorldState::checkpoint(const uint64_t& forkId) +{ + Fork::SharedPtr fork = retrieve_fork(forkId); + Signal signal(static_cast(fork->_trees.size())); + std::array local; + std::mutex mtx; + for (auto& [id, tree] : fork->_trees) { + std::visit( + [&signal, &local, id, &mtx](auto&& wrapper) { + wrapper.tree->checkpoint([&signal, &local, &mtx, id](Response& resp) { + { + std::lock_guard lock(mtx); + local[id] = std::move(resp); + } + signal.signal_decrement(); + }); + }, + tree); + } + signal.wait_for_level(); + for (auto& m : local) { + if (!m.success) { + throw std::runtime_error(m.message); + } + } +} + +void WorldState::commit_checkpoint(const uint64_t& forkId) +{ + Fork::SharedPtr fork = retrieve_fork(forkId); + Signal 
signal(static_cast(fork->_trees.size())); + std::array local; + std::mutex mtx; + for (auto& [id, tree] : fork->_trees) { + std::visit( + [&signal, &local, id, &mtx](auto&& wrapper) { + wrapper.tree->commit_checkpoint([&signal, &local, &mtx, id](Response& resp) { + { + std::lock_guard lock(mtx); + local[id] = std::move(resp); + } + signal.signal_decrement(); + }); + }, + tree); + } + signal.wait_for_level(); + for (auto& m : local) { + if (!m.success) { + throw std::runtime_error(m.message); + } + } +} + +void WorldState::revert_checkpoint(const uint64_t& forkId) +{ + Fork::SharedPtr fork = retrieve_fork(forkId); + Signal signal(static_cast(fork->_trees.size())); + std::array local; + std::mutex mtx; + for (auto& [id, tree] : fork->_trees) { + std::visit( + [&signal, &local, id, &mtx](auto&& wrapper) { + wrapper.tree->revert_checkpoint([&signal, &local, &mtx, id](Response& resp) { + { + std::lock_guard lock(mtx); + local[id] = std::move(resp); + } + signal.signal_decrement(); + }); + }, + tree); + } + signal.wait_for_level(); + for (auto& m : local) { + if (!m.success) { + throw std::runtime_error(m.message); + } + } +} + } // namespace bb::world_state diff --git a/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp b/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp index 20aeaa2bcfab..efc72c388948 100644 --- a/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp +++ b/barretenberg/cpp/src/barretenberg/world_state/world_state.hpp @@ -249,6 +249,10 @@ class WorldState { const std::vector& nullifiers, const std::vector& public_writes); + void checkpoint(const uint64_t& forkId); + void commit_checkpoint(const uint64_t& forkId); + void revert_checkpoint(const uint64_t& forkId); + private: std::shared_ptr _workers; WorldStateStores::Ptr _persistentStores; diff --git a/yarn-project/circuit-types/src/interfaces/merkle_tree_operations.ts b/yarn-project/circuit-types/src/interfaces/merkle_tree_operations.ts index 288c16ff3ae9..334951fb1c33 
100644 --- a/yarn-project/circuit-types/src/interfaces/merkle_tree_operations.ts +++ b/yarn-project/circuit-types/src/interfaces/merkle_tree_operations.ts @@ -258,6 +258,21 @@ export interface MerkleTreeWriteOperations extends MerkleTreeReadOperations { * Closes the database, discarding any uncommitted changes. */ close(): Promise; + + /** + * Checkpoints the current fork state + */ + createCheckpoint(): Promise; + + /** + * Commits the current checkpoint + */ + commitCheckpoint(): Promise; + + /** + * Reverts the current checkpoint + */ + revertCheckpoint(): Promise; } /** diff --git a/yarn-project/circuits.js/src/structs/public_data_write.ts b/yarn-project/circuits.js/src/structs/public_data_write.ts index 001d14a7878f..e052fd44b06b 100644 --- a/yarn-project/circuits.js/src/structs/public_data_write.ts +++ b/yarn-project/circuits.js/src/structs/public_data_write.ts @@ -66,6 +66,10 @@ export class PublicDataWrite { return new PublicDataWrite(Fr.ZERO, Fr.ZERO); } + static random() { + return new PublicDataWrite(Fr.random(), Fr.random()); + } + static isEmpty(data: PublicDataWrite): boolean { return data.isEmpty(); } diff --git a/yarn-project/world-state/src/native/merkle_trees_facade.ts b/yarn-project/world-state/src/native/merkle_trees_facade.ts index cf0f561d7511..8545382f27c2 100644 --- a/yarn-project/world-state/src/native/merkle_trees_facade.ts +++ b/yarn-project/world-state/src/native/merkle_trees_facade.ts @@ -189,7 +189,6 @@ export class MerkleTreesForkFacade extends MerkleTreesFacade implements MerkleTr assert.equal(revision.includeUncommitted, true, 'Fork must include uncommitted data'); super(instance, initialHeader, revision); } - async updateArchive(header: BlockHeader): Promise { await this.instance.call(WorldStateMessageType.UPDATE_ARCHIVE, { forkId: this.revision.forkId, @@ -266,6 +265,21 @@ export class MerkleTreesForkFacade extends MerkleTreesFacade implements MerkleTr assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); await 
this.instance.call(WorldStateMessageType.DELETE_FORK, { forkId: this.revision.forkId }); } + + public async createCheckpoint(): Promise { + assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); + await this.instance.call(WorldStateMessageType.CREATE_CHECKPOINT, { forkId: this.revision.forkId }); + } + + public async commitCheckpoint(): Promise { + assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); + await this.instance.call(WorldStateMessageType.COMMIT_CHECKPOINT, { forkId: this.revision.forkId }); + } + + public async revertCheckpoint(): Promise { + assert.notEqual(this.revision.forkId, 0, 'Fork ID must be set'); + await this.instance.call(WorldStateMessageType.REVERT_CHECKPOINT, { forkId: this.revision.forkId }); + } } function hydrateLeaf(treeId: ID, leaf: Fr | Buffer) { diff --git a/yarn-project/world-state/src/native/message.ts b/yarn-project/world-state/src/native/message.ts index 4b639f7b4b11..15af10b534b4 100644 --- a/yarn-project/world-state/src/native/message.ts +++ b/yarn-project/world-state/src/native/message.ts @@ -35,6 +35,10 @@ export enum WorldStateMessageType { GET_STATUS, + CREATE_CHECKPOINT, + COMMIT_CHECKPOINT, + REVERT_CHECKPOINT, + CLOSE = 999, } @@ -450,6 +454,10 @@ export type WorldStateRequest = { [WorldStateMessageType.GET_STATUS]: WithCanonicalForkId; + [WorldStateMessageType.CREATE_CHECKPOINT]: WithForkId; + [WorldStateMessageType.COMMIT_CHECKPOINT]: WithForkId; + [WorldStateMessageType.REVERT_CHECKPOINT]: WithForkId; + [WorldStateMessageType.CLOSE]: WithCanonicalForkId; }; @@ -486,6 +494,10 @@ export type WorldStateResponse = { [WorldStateMessageType.GET_STATUS]: WorldStateStatusSummary; + [WorldStateMessageType.CREATE_CHECKPOINT]: void; + [WorldStateMessageType.COMMIT_CHECKPOINT]: void; + [WorldStateMessageType.REVERT_CHECKPOINT]: void; + [WorldStateMessageType.CLOSE]: void; }; diff --git a/yarn-project/world-state/src/native/native_world_state.test.ts 
b/yarn-project/world-state/src/native/native_world_state.test.ts index ae099d3c661f..6387428cb272 100644 --- a/yarn-project/world-state/src/native/native_world_state.test.ts +++ b/yarn-project/world-state/src/native/native_world_state.test.ts @@ -1,4 +1,4 @@ -import { type L2Block, MerkleTreeId, type MerkleTreeWriteOperations } from '@aztec/circuit-types'; +import { type L2Block, MerkleTreeId, type MerkleTreeWriteOperations, type SiblingPath } from '@aztec/circuit-types'; import { ARCHIVE_HEIGHT, AppendOnlyTreeSnapshot, @@ -13,6 +13,7 @@ import { NOTE_HASH_TREE_HEIGHT, NULLIFIER_TREE_HEIGHT, PUBLIC_DATA_TREE_HEIGHT, + PublicDataWrite, } from '@aztec/circuits.js'; import { makeContentCommitment, makeGlobalVariables } from '@aztec/circuits.js/testing'; @@ -455,8 +456,8 @@ describe('NativeWorldState', () => { for (let i = 0; i < 16; i++) { const blockNumber = i + 1; const nonReorgSnapshot = nonReorgState.getSnapshot(blockNumber); - const reorgSnaphsot = ws.getSnapshot(blockNumber); - await compareChains(reorgSnaphsot, nonReorgSnapshot); + const reorgSnapshot = ws.getSnapshot(blockNumber); + await compareChains(reorgSnapshot, nonReorgSnapshot); } await compareChains(ws.getCommitted(), nonReorgState.getCommitted()); @@ -791,4 +792,364 @@ describe('NativeWorldState', () => { await Promise.all([setupFork.close(), testFork.close()]); }, 30_000); }); + + describe('Checkpoints', () => { + let ws: NativeWorldStateService; + + beforeEach(async () => { + ws = await NativeWorldStateService.tmp(); + const fork = await ws.fork(); + const { block, messages } = await mockBlock(1, 2, fork); + await fork.close(); + + await ws.handleL2BlockAndMessages(block, messages); + }); + + afterEach(async () => { + await ws.close(); + }); + + const getSiblingPaths = async (fork: MerkleTreeWriteOperations) => { + return await Promise.all( + [ + MerkleTreeId.L1_TO_L2_MESSAGE_TREE, + MerkleTreeId.NOTE_HASH_TREE, + MerkleTreeId.NULLIFIER_TREE, + MerkleTreeId.PUBLIC_DATA_TREE, + ].map(x => 
fork.getSiblingPath(x, 0n)), + ); + }; + + const advanceState = async (fork: MerkleTreeWriteOperations) => { + await Promise.all([ + fork.appendLeaves( + MerkleTreeId.L1_TO_L2_MESSAGE_TREE, + Array.from({ length: 8 }, () => Fr.random()), + ), + fork.appendLeaves( + MerkleTreeId.NOTE_HASH_TREE, + Array.from({ length: 8 }, () => Fr.random()), + ), + fork.sequentialInsert( + MerkleTreeId.PUBLIC_DATA_TREE, + Array.from({ length: 8 }, () => PublicDataWrite.random().toBuffer()), + ), + fork.batchInsert( + MerkleTreeId.NULLIFIER_TREE, + Array.from({ length: 8 }, () => Fr.random().toBuffer()), + 0, + ), + ]); + return getSiblingPaths(fork); + }; + + const compareState = async ( + fork: MerkleTreeWriteOperations, + pathsToCheck: SiblingPath[], + expectedEqual: boolean, + ) => { + const siblingPaths = await getSiblingPaths(fork); + + if (expectedEqual) { + expect(siblingPaths).toEqual(pathsToCheck); + } else { + expect(siblingPaths).not.toEqual(pathsToCheck); + } + return siblingPaths; + }; + + it('can checkpoint and revert', async () => { + const fork = await ws.fork(); + await fork.createCheckpoint(); + + const siblingPathsBefore = await getSiblingPaths(fork); + + await advanceState(fork); + + await compareState(fork, siblingPathsBefore, false); + + await fork.revertCheckpoint(); + + await compareState(fork, siblingPathsBefore, true); + + await fork.close(); + }); + + it('can checkpoint and commit', async () => { + const fork = await ws.fork(); + await fork.createCheckpoint(); + + const siblingPathsBefore = await getSiblingPaths(fork); + + const siblingPathsAfter = await advanceState(fork); + + await compareState(fork, siblingPathsBefore, false); + + await fork.commitCheckpoint(); + + await compareState(fork, siblingPathsAfter, true); + + await fork.close(); + }); + + it('can checkpoint from committed', async () => { + const fork = await ws.fork(); + await fork.createCheckpoint(); + + const siblingPathsBefore = await getSiblingPaths(fork); + + const siblingPathsAfter = 
await advanceState(fork); + + await compareState(fork, siblingPathsBefore, false); + + await fork.commitCheckpoint(); + + await compareState(fork, siblingPathsAfter, true); + + await fork.createCheckpoint(); + + await advanceState(fork); + + await fork.commitCheckpoint(); + + await compareState(fork, siblingPathsAfter, false); + + await fork.close(); + }); + + it('can checkpoint from reverted', async () => { + const fork = await ws.fork(); + await fork.createCheckpoint(); + + const siblingPathsBefore = await getSiblingPaths(fork); + + await advanceState(fork); + + await compareState(fork, siblingPathsBefore, false); + + await fork.revertCheckpoint(); + + await compareState(fork, siblingPathsBefore, true); + + await fork.createCheckpoint(); + + await advanceState(fork); + + await fork.commitCheckpoint(); + + await compareState(fork, siblingPathsBefore, false); + + await fork.close(); + }); + + it('can revert all deeper commits', async () => { + const fork = await ws.fork(); + const siblingPathsBefore = await getSiblingPaths(fork); + + // This is the base checkpoint, this will revert all of the others + await fork.createCheckpoint(); + await advanceState(fork); + + const numCommits = 10; + + for (let i = 0; i < numCommits; i++) { + await fork.createCheckpoint(); + await advanceState(fork); + } + + // now commit all of these, and also advance each committed state further + for (let i = 0; i < numCommits; i++) { + await fork.commitCheckpoint(); + await advanceState(fork); + } + + // every nested checkpoint has now been committed into the base checkpoint + // now revert the base checkpoint + await fork.revertCheckpoint(); + + await compareState(fork, siblingPathsBefore, true); + + await fork.close(); + }); + + it('can checkpoint many levels', async () => { + const fork = await ws.fork(); + + const stackDepth = 20; + + const siblingsAtEachLevel = []; + + let index = 0; + + for (; index < stackDepth - 1; index++) { + siblingsAtEachLevel[index] = await advanceState(fork); + await 
fork.createCheckpoint(); + } + + // Add one more depth + siblingsAtEachLevel[index] = await advanceState(fork); + + await compareState(fork, siblingsAtEachLevel[stackDepth - 1], true); + + let checkpointIndex = index; + + // Alternate committing and reverting half the levels + for (; index > stackDepth / 2; index--) { + if (index % 2 == 0) { + // Here we change the checkpoint index + await fork.revertCheckpoint(); + checkpointIndex = index - 1; + } else { + // We don't change the checkpoint index + await fork.commitCheckpoint(); + } + await compareState(fork, siblingsAtEachLevel[checkpointIndex], true); + } + + // Now go down the stack again + for (; index < stackDepth - 1; index++) { + siblingsAtEachLevel[index] = await advanceState(fork); + await fork.createCheckpoint(); + } + + // Add one more depth + siblingsAtEachLevel[index] = await advanceState(fork); + + await compareState(fork, siblingsAtEachLevel[stackDepth - 1], true); + + checkpointIndex = index; + + // Alternate committing and reverting all the levels + for (; index > 0; index--) { + if (index % 2 == 0) { + // Here we change the checkpoint index + await fork.revertCheckpoint(); + checkpointIndex = index - 1; + } else { + // We don't change the checkpoint index + await fork.commitCheckpoint(); + } + await compareState(fork, siblingsAtEachLevel[checkpointIndex], true); + } + + await fork.close(); + }); + + it('can commit and revert', async () => { + const fork = await ws.fork(); + + const getLeaf = async (index: bigint) => { + const leaf = await fork.getLeafValue(MerkleTreeId.NULLIFIER_TREE, index); + return Fr.fromBuffer(leaf!); + }; + + const getPath = async (index: bigint) => { + return await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, index); + }; + + await fork.createCheckpoint(); + + const siblingPaths = []; + let size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + let index = 0; + const initialSize = size; + const initialLeaf = await getLeaf(size - 1n); + const initialPath = 
await getPath(size - 1n); + + const nullifiers: Fr[] = []; + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.createCheckpoint(); + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.revertCheckpoint(); + index--; + + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + expect(await getPath(size - 1n)).toEqual(siblingPaths[index]); + + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.createCheckpoint(); + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.revertCheckpoint(); + index--; + + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + 
expect(await getPath(size - 1n)).toEqual(siblingPaths[index]); + + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.createCheckpoint(); + index++; + + nullifiers[index] = Fr.random(); + await fork.batchInsert(MerkleTreeId.NULLIFIER_TREE, [nullifiers[index].toBuffer()], 0); + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + + siblingPaths[index] = await fork.getSiblingPath(MerkleTreeId.NULLIFIER_TREE, size - 1n); + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + + await fork.commitCheckpoint(); + + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + expect(await getLeaf(size - 1n)).toEqual(nullifiers[index]); + expect(await getPath(size - 1n)).toEqual(siblingPaths[index]); + + await fork.revertCheckpoint(); + + index = 0; + size = (await fork.getTreeInfo(MerkleTreeId.NULLIFIER_TREE)).size; + expect(size).toBe(initialSize); + expect(await getLeaf(size - 1n)).toEqual(initialLeaf); + expect(await getPath(size - 1n)).toEqual(initialPath); + + await fork.close(); + }); + }); }); diff --git a/yarn-project/world-state/src/native/world_state_ops_queue.ts b/yarn-project/world-state/src/native/world_state_ops_queue.ts index ad786aa7ea46..73f0ca92615c 100644 --- a/yarn-project/world-state/src/native/world_state_ops_queue.ts +++ 
b/yarn-project/world-state/src/native/world_state_ops_queue.ts @@ -38,6 +38,9 @@ export const MUTATING_MSG_TYPES = new Set([ WorldStateMessageType.FINALISE_BLOCKS, WorldStateMessageType.UNWIND_BLOCKS, WorldStateMessageType.REMOVE_HISTORICAL_BLOCKS, + WorldStateMessageType.CREATE_CHECKPOINT, + WorldStateMessageType.COMMIT_CHECKPOINT, + WorldStateMessageType.REVERT_CHECKPOINT, ]); // This class implements the per-fork operation queue