diff --git a/.gitignore b/.gitignore index 415e2c378f..119f7a628b 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,7 @@ coverage.out execution/evm/jwttoken target /.claude/settings.local.json - +apps/testapp/testapp docs/.vitepress/dist node_modules diff --git a/apps/testapp/cmd/rollback.go b/apps/testapp/cmd/rollback.go new file mode 100644 index 0000000000..e1b0eb519b --- /dev/null +++ b/apps/testapp/cmd/rollback.go @@ -0,0 +1,65 @@ +package cmd + +import ( + "context" + "fmt" + "strconv" + + kvexecutor "github.com/evstack/ev-node/apps/testapp/kv" + rollcmd "github.com/evstack/ev-node/pkg/cmd" + "github.com/evstack/ev-node/pkg/store" + "github.com/spf13/cobra" +) + +var RollbackCmd = &cobra.Command{ + Use: "rollback ", + Short: "Rollback the testapp node", + Args: cobra.RangeArgs(0, 1), + RunE: func(cmd *cobra.Command, args []string) error { + nodeConfig, err := rollcmd.ParseConfig(cmd) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + datastore, err := store.NewDefaultKVStore(nodeConfig.RootDir, nodeConfig.DBPath, "testapp") + if err != nil { + return err + } + storeWrapper := store.New(datastore) + + executor, err := kvexecutor.NewKVExecutor(nodeConfig.RootDir, nodeConfig.DBPath) + if err != nil { + return err + } + + cmd.Println("Starting rollback operation") + currentHeight, err := storeWrapper.Height(ctx) + if err != nil { + return fmt.Errorf("failed to get current height: %w", err) + } + + var targetHeight uint64 = currentHeight - 1 + if len(args) > 0 { + targetHeight, err = strconv.ParseUint(args[0], 10, 64) + if err != nil { + return fmt.Errorf("failed to parse target height: %w", err) + } + } + + // rollback ev-node store + if err := storeWrapper.Rollback(ctx, targetHeight); err != nil { + return fmt.Errorf("rollback failed: %w", err) + } + + // rollback execution store + if err := executor.Rollback(ctx, targetHeight); err != nil { + return fmt.Errorf("rollback failed: %w", err) + } + + cmd.Println("Rollback completed successfully") + return nil + }, +} diff --git a/apps/testapp/cmd/root.go b/apps/testapp/cmd/root.go index d7001557f6..8e58cc794d 100644 --- a/apps/testapp/cmd/root.go +++ b/apps/testapp/cmd/root.go @@ -9,9 +9,8 @@ import ( const ( // AppName is the name of the application, the name of the command, and the name of the home directory. AppName = "testapp" -) -const ( + // flagKVEndpoint is the flag for the KV endpoint flagKVEndpoint = "kv-endpoint" ) diff --git a/apps/testapp/kv/http_server_test.go b/apps/testapp/kv/http_server_test.go index fd6077033a..7845204fbf 100644 --- a/apps/testapp/kv/http_server_test.go +++ b/apps/testapp/kv/http_server_test.go @@ -277,3 +277,40 @@ func TestHTTPServerContextCancellation(t *testing.T) { t.Fatal("Expected connection error after shutdown, but got none") } } + +func TestHTTPIntegration_GetKVWithMultipleHeights(t *testing.T) { + exec, err := NewKVExecutor(t.TempDir(), "testdb") + if err != nil { + t.Fatalf("Failed to create KVExecutor: %v", err) + } + ctx := context.Background() + + // Execute transactions at different heights for the same key + txsHeight1 := [][]byte{[]byte("testkey=original_value")} + _, _, err = exec.ExecuteTxs(ctx, txsHeight1, 1, time.Now(), []byte("")) + if err != nil { + t.Fatalf("ExecuteTxs failed for height 1: %v", err) + } + + txsHeight2 := [][]byte{[]byte("testkey=updated_value")} + _, _, err = exec.ExecuteTxs(ctx, txsHeight2, 2, time.Now(), []byte("")) + if err != nil { + t.Fatalf("ExecuteTxs failed for height 2: %v", err) + } + + server := NewHTTPServer(exec, ":0") + + // Test GET request - should return the latest value + req := httptest.NewRequest(http.MethodGet, "/kv?key=testkey", nil) + rr := httptest.NewRecorder() + + server.handleKV(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected status 200, got %d", rr.Code) + } + + if rr.Body.String() != "updated_value" { + t.Errorf("expected body 'updated_value', got %q", rr.Body.String()) + } +} diff --git a/apps/testapp/kv/kvexecutor.go b/apps/testapp/kv/kvexecutor.go index 8733665e7b..ad62e723fa 100644 --- a/apps/testapp/kv/kvexecutor.go +++ b/apps/testapp/kv/kvexecutor.go @@ -16,6 +16,7 @@ import ( var ( genesisInitializedKey = ds.NewKey("/genesis/initialized") genesisStateRootKey = ds.NewKey("/genesis/stateroot") + heightKeyPrefix = ds.NewKey("/height") finalizedHeightKey = ds.NewKey("/finalizedHeight") // Define a buffer size for the transaction channel txChannelBufferSize = 10000 @@ -49,18 +50,56 @@ func NewKVExecutor(rootdir, dbpath string) (*KVExecutor, error) { } // GetStoreValue is a helper for the HTTP interface to retrieve the value for a key from the database. +// It searches across all block heights to find the latest value for the given key. func (k *KVExecutor) GetStoreValue(ctx context.Context, key string) (string, bool) { - dsKey := ds.NewKey(key) - valueBytes, err := k.db.Get(ctx, dsKey) - if errors.Is(err, ds.ErrNotFound) { + // Query all keys to find height-prefixed versions of this key + q := query.Query{} + results, err := k.db.Query(ctx, q) + if err != nil { + fmt.Printf("Error querying DB for key '%s': %v\n", key, err) return "", false } - if err != nil { - // Log the error or handle it appropriately - fmt.Printf("Error getting value from DB: %v\n", err) + defer results.Close() + + heightPrefix := heightKeyPrefix.String() + var latestValue string + var latestHeight uint64 + found := false + + for result := range results.Next() { + if result.Error != nil { + fmt.Printf("Error iterating query results for key '%s': %v\n", key, result.Error) + return "", false + } + + resultKey := result.Key + // Check if this is a height-prefixed key that matches our target key + if strings.HasPrefix(resultKey, heightPrefix+"/") { + // Extract height and actual key: /height/{height}/{actual_key} + parts := strings.Split(strings.TrimPrefix(resultKey, heightPrefix+"/"), "/") + if len(parts) >= 2 { + var keyHeight uint64 + if _, err := fmt.Sscanf(parts[0], "%d", &keyHeight); err == nil { + // Reconstruct the actual key by joining all parts after the height + actualKey := strings.Join(parts[1:], "/") + if actualKey == key { + // This key matches - check if it's the latest height + if !found || keyHeight > latestHeight { + latestHeight = keyHeight + latestValue = string(result.Value) + found = true + } + } + } + } + } + } + + if !found { return "", false } - return string(valueBytes), true + + return latestValue, true } // computeStateRoot computes a deterministic state root by querying all keys, sorting them, @@ -206,11 +245,14 @@ func (k *KVExecutor) ExecuteTxs(ctx context.Context, txs [][]byte, blockHeight u if key == "" { return nil, 0, errors.New("empty key in transaction") } - dsKey := ds.NewKey(key) + + dsKey := getTxKey(blockHeight, key) + // Prevent writing reserved keys via transactions if reservedKeys[dsKey] { return nil, 0, fmt.Errorf("transaction attempts to modify reserved key: %s", key) } + err = batch.Put(ctx, dsKey, []byte(value)) if err != nil { // This error is unlikely for Put unless the context is cancelled. @@ -263,3 +305,96 @@ func (k *KVExecutor) InjectTx(tx []byte) { // Consider adding metrics here } } + +// Rollback reverts the state to the previous block height. +func (k *KVExecutor) Rollback(ctx context.Context, height uint64) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + // Validate height constraints + if height == 0 { + return fmt.Errorf("cannot rollback to height 0: invalid height") + } + + // Create a batch for atomic rollback operation + batch, err := k.db.Batch(ctx) + if err != nil { + return fmt.Errorf("failed to create batch for rollback: %w", err) + } + + // Query all keys to find those with height > target height + q := query.Query{} + results, err := k.db.Query(ctx, q) + if err != nil { + return fmt.Errorf("failed to query keys for rollback: %w", err) + } + defer results.Close() + + keysToDelete := make([]ds.Key, 0) + heightPrefix := heightKeyPrefix.String() + + for result := range results.Next() { + if result.Error != nil { + return fmt.Errorf("error iterating query results during rollback: %w", result.Error) + } + + key := result.Key + // Check if this is a height-prefixed key + if strings.HasPrefix(key, heightPrefix+"/") { + // Extract height from key: /height/{height}/{actual_key} (see getTxKey) + parts := strings.Split(strings.TrimPrefix(key, heightPrefix+"/"), "/") + if len(parts) > 0 { + var keyHeight uint64 + if _, err := fmt.Sscanf(parts[0], "%d", &keyHeight); err == nil { + // If this key's height is greater than target, mark for deletion + if keyHeight > height { + keysToDelete = append(keysToDelete, ds.NewKey(key)) + } + } + } + } + } + + // Delete all keys with height > target height + for _, key := range keysToDelete { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + err = batch.Delete(ctx, key) + if err != nil { + return fmt.Errorf("failed to stage delete operation for key '%s' during rollback: %w", key.String(), err) + } + } + + // Update finalized height if necessary - it should not exceed rollback height + finalizedHeightKey := ds.NewKey("/finalizedHeight") + if finalizedHeightBytes, err := k.db.Get(ctx, finalizedHeightKey); err == nil { + var finalizedHeight uint64 + if _, err := fmt.Sscanf(string(finalizedHeightBytes), "%d", &finalizedHeight); err == nil { + if finalizedHeight > height { + err = batch.Put(ctx, finalizedHeightKey, fmt.Appendf([]byte{}, "%d", height)) + if err != nil { + return fmt.Errorf("failed to update finalized height during rollback: %w", err) + } + } + } + } + + // Commit the batch atomically + err = batch.Commit(ctx) + if err != nil { + return fmt.Errorf("failed to commit rollback batch: %w", err) + } + + return nil +} + +func getTxKey(height uint64, txKey string) ds.Key { + return heightKeyPrefix.Child(ds.NewKey(fmt.Sprintf("%d/%s", height, txKey))) +} diff --git a/apps/testapp/main.go b/apps/testapp/main.go index f1ed6307b6..ea2bc094db 100644 --- a/apps/testapp/main.go +++ b/apps/testapp/main.go @@ -20,6 +20,7 @@ func main() { rollcmd.NetInfoCmd, rollcmd.StoreUnsafeCleanCmd, rollcmd.KeysCmd(), + cmds.RollbackCmd, initCmd, ) diff --git a/block/publish_block_p2p_test.go b/block/publish_block_p2p_test.go index 86cb15d766..779bb89a3a 100644 --- a/block/publish_block_p2p_test.go +++ b/block/publish_block_p2p_test.go @@ -17,6 +17,7 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/rs/zerolog" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" coresequencer "github.com/evstack/ev-node/core/sequencer" @@ -28,6 +29,7 @@ import ( "github.com/evstack/ev-node/pkg/signer/noop" "github.com/evstack/ev-node/pkg/store" evSync "github.com/evstack/ev-node/pkg/sync" + "github.com/evstack/ev-node/test/mocks" "github.com/evstack/ev-node/types" ) @@ -200,13 +202,18 @@ func setupBlockManager(t *testing.T, ctx context.Context, workDir string, mainKV require.NoError(t, err) require.NoError(t, dataSyncService.Start(ctx)) + mockExecutor := mocks.NewMockExecutor(t) + mockExecutor.On("InitChain", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bytesN(32), uint64(10_000), nil).Maybe() + mockExecutor.On("ExecuteTxs", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(bytesN(32), uint64(10_000), nil).Maybe() + mockExecutor.On("SetFinal", mock.Anything, mock.Anything).Return(nil).Maybe() + result, err := NewManager( ctx, signer, nodeConfig, genesisDoc, store.New(mainKV), - &mockExecutor{}, + mockExecutor, coresequencer.NewDummySequencer(), nil, blockManagerLogger, @@ -221,24 +228,6 @@ func setupBlockManager(t *testing.T, ctx context.Context, workDir string, mainKV return result, headerSyncService, dataSyncService } -type mockExecutor struct{} - -func (m mockExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (stateRoot []byte, maxBytes uint64, err error) { - return bytesN(32), 10_000, nil -} - -func (m mockExecutor) GetTxs(ctx context.Context) ([][]byte, error) { - panic("implement me") -} - -func (m mockExecutor) ExecuteTxs(ctx context.Context, txs [][]byte, blockHeight uint64, timestamp time.Time, prevStateRoot []byte) (updatedStateRoot []byte, maxBytes uint64, err error) { - return bytesN(32), 10_000, nil -} - -func (m mockExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { - return nil -} - var rnd = rand.New(rand.NewSource(1)) //nolint:gosec // test code only func bytesN(n int) []byte { diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index 79b3b34fa3..58d463d7f3 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -276,8 +276,6 @@ func NewServiceHandler(store store.Store, peerManager p2p.P2PRPC, logger zerolog mux := http.NewServeMux() - fmt.Println("Registering gRPC reflection service...") - compress1KB := connect.WithCompressMinBytes(1024) reflector := grpcreflect.NewStaticReflector( rpc.StoreServiceName, diff --git a/pkg/store/keys.go b/pkg/store/keys.go index f2a8b06686..dc465a9d03 100644 --- a/pkg/store/keys.go +++ b/pkg/store/keys.go @@ -42,8 +42,8 @@ func getSignatureKey(height uint64) string { return GenerateKey([]string{signaturePrefix, strconv.FormatUint(height, 10)}) } -func getStateKey() string { - return statePrefix +func getStateAtHeightKey(height uint64) string { + return GenerateKey([]string{statePrefix, strconv.FormatUint(height, 10)}) } func getMetaKey(key string) string { diff --git a/pkg/store/store.go b/pkg/store/store.go index 7b1cefd8fc..1d911f1e84 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -179,6 +179,11 @@ func (s *DefaultStore) GetSignature(ctx context.Context, height uint64) (*types. // UpdateState updates state saved in Store. Only one State is stored. // If there is no State in Store, state will be saved. func (s *DefaultStore) UpdateState(ctx context.Context, state types.State) error { + currentHeight, err := s.Height(ctx) + if err != nil { + return fmt.Errorf("failed to get current height: %w", err) + } + pbState, err := state.ToProto() if err != nil { return fmt.Errorf("failed to convert type state to protobuf type: %w", err) @@ -187,12 +192,17 @@ func (s *DefaultStore) UpdateState(ctx context.Context, state types.State) error if err != nil { return fmt.Errorf("failed to marshal state to protobuf: %w", err) } - return s.db.Put(ctx, ds.NewKey(getStateKey()), data) + return s.db.Put(ctx, ds.NewKey(getStateAtHeightKey(currentHeight)), data) } // GetState returns last state saved with UpdateState. func (s *DefaultStore) GetState(ctx context.Context) (types.State, error) { - blob, err := s.db.Get(ctx, ds.NewKey(getStateKey())) + currentHeight, err := s.Height(ctx) + if err != nil { + return types.State{}, fmt.Errorf("failed to get current height: %w", err) + } + + blob, err := s.db.Get(ctx, ds.NewKey(getStateAtHeightKey(currentHeight))) if err != nil { return types.State{}, fmt.Errorf("failed to retrieve state: %w", err) } @@ -207,6 +217,30 @@ func (s *DefaultStore) GetState(ctx context.Context) (types.State, error) { return state, err } +// GetStateAtHeight returns the state at the given height. +// If no state is stored at that height, it returns an error. +func (s *DefaultStore) GetStateAtHeight(ctx context.Context, height uint64) (types.State, error) { + blob, err := s.db.Get(ctx, ds.NewKey(getStateAtHeightKey(height))) + if err != nil { + if errors.Is(err, ds.ErrNotFound) { + return types.State{}, fmt.Errorf("no state found at height %d", height) + } + return types.State{}, fmt.Errorf("failed to retrieve state at height %d: %w", height, err) + } + + var pbState pb.State + if err := proto.Unmarshal(blob, &pbState); err != nil { + return types.State{}, fmt.Errorf("failed to unmarshal state from protobuf at height %d: %w", height, err) + } + + var state types.State + if err := state.FromProto(&pbState); err != nil { + return types.State{}, fmt.Errorf("failed to convert protobuf to state at height %d: %w", height, err) + } + + return state, nil +} + // SetMetadata saves arbitrary value in the store. // // Metadata is separated from other data by using prefix in KV. @@ -227,6 +261,93 @@ func (s *DefaultStore) GetMetadata(ctx context.Context, key string) ([]byte, err return data, nil } +// Rollback rolls back block data until the given height from the store. +// NOTE: this function does not rollback metadata. Those should be handled separately. +func (s *DefaultStore) Rollback(ctx context.Context, height uint64) error { + batch, err := s.db.Batch(ctx) + if err != nil { + return fmt.Errorf("failed to create a new batch: %w", err) + } + + currentHeight, err := s.Height(ctx) + if err != nil { + return fmt.Errorf("failed to get current height: %w", err) + } + + if currentHeight <= height { + return nil + } + + daIncludedHeightBz, err := s.GetMetadata(ctx, DAIncludedHeightKey) + if err != nil && !errors.Is(err, ds.ErrNotFound) { + return fmt.Errorf("failed to get DA included height: %w", err) + } else if len(daIncludedHeightBz) == 8 { // valid height stored, so able to check + daIncludedHeight := binary.LittleEndian.Uint64(daIncludedHeightBz) + if daIncludedHeight > height { + return fmt.Errorf("DA included height is greater than the rollback height: cannot rollback a finalized height") + } + } + + for currentHeight > height { + header, err := s.GetHeader(ctx, currentHeight) + if err != nil { + return fmt.Errorf("failed to get header at height %d: %w", currentHeight, err) + } + + if err := batch.Delete(ctx, ds.NewKey(getHeaderKey(currentHeight))); err != nil { + return fmt.Errorf("failed to delete header blob in batch: %w", err) + } + + if err := batch.Delete(ctx, ds.NewKey(getDataKey(currentHeight))); err != nil { + return fmt.Errorf("failed to delete data blob in batch: %w", err) + } + + if err := batch.Delete(ctx, ds.NewKey(getSignatureKey(currentHeight))); err != nil { + return fmt.Errorf("failed to delete signature of block blob in batch: %w", err) + } + + hash := header.Hash() + if err := batch.Delete(ctx, ds.NewKey(getIndexKey(hash))); err != nil { + return fmt.Errorf("failed to delete index key in batch: %w", err) + } + + currentHeight-- + } + + // set height -- using set height checks the current height + // so we cannot use that + heightBytes := encodeHeight(height) + if err := batch.Put(ctx, ds.NewKey(getHeightKey()), heightBytes); err != nil { + return fmt.Errorf("failed to set height: %w", err) + } + + targetState, err := s.GetStateAtHeight(ctx, height) + if err != nil { + return fmt.Errorf("failed to get state at height %d: %w", height, err) + } + + // update state manually to keep using the batch + pbState, err := targetState.ToProto() + if err != nil { + return fmt.Errorf("failed to convert type state to protobuf type: %w", err) + } + + data, err := proto.Marshal(pbState) + if err != nil { + return fmt.Errorf("failed to marshal state to protobuf: %w", err) + } + + if err := batch.Put(ctx, ds.NewKey(getStateAtHeightKey(height)), data); err != nil { + return fmt.Errorf("failed to set state at height %d: %w", height, err) + } + + if err := batch.Commit(ctx); err != nil { + return fmt.Errorf("failed to commit batch: %w", err) + } + + return nil +} + const heightLength = 8 func encodeHeight(height uint64) []byte { diff --git a/pkg/store/store_test.go b/pkg/store/store_test.go index 22b6979484..ce2b4ad6de 100644 --- a/pkg/store/store_test.go +++ b/pkg/store/store_test.go @@ -2,6 +2,8 @@ package store import ( "context" + "encoding/binary" + "errors" "fmt" "testing" @@ -23,6 +25,7 @@ type mockBatchingDatastore struct { unmarshalErrorOnCall int // New field: 0 for no unmarshal error, 1 for first Get, 2 for second Get, etc. getCallCount int // Tracks number of Get calls getErrors []error // Specific errors for sequential Get calls + getMetadataError error // Specific error for GetMetadata calls } // mockBatch is a mock implementation of ds.Batch for testing error cases. @@ -40,6 +43,11 @@ func (m *mockBatchingDatastore) Put(ctx context.Context, key ds.Key, value []byt } func (m *mockBatchingDatastore) Get(ctx context.Context, key ds.Key) ([]byte, error) { + // Check for specific metadata error for DA included height key + if m.getMetadataError != nil && key.String() == "/m/d" { + return nil, m.getMetadataError + } + m.getCallCount++ if len(m.getErrors) >= m.getCallCount && m.getErrors[m.getCallCount-1] != nil { return nil, m.getErrors[m.getCallCount-1] @@ -511,15 +519,18 @@ func TestGetStateError(t *testing.T) { _, err := sGet.GetState(t.Context()) require.Error(err) require.Contains(err.Error(), mockErrGet.Error()) - require.Contains(err.Error(), "failed to retrieve state") // Simulate proto.Unmarshal error mockDsUnmarshal, _ := NewDefaultInMemoryKVStore() - mockBatchingDsUnmarshal := &mockBatchingDatastore{Batching: mockDsUnmarshal, unmarshalErrorOnCall: 1} + mockBatchingDsUnmarshal := &mockBatchingDatastore{Batching: mockDsUnmarshal, unmarshalErrorOnCall: 3} sUnmarshal := New(mockBatchingDsUnmarshal) // Put some data that will cause unmarshal error - err = mockBatchingDsUnmarshal.Put(t.Context(), ds.NewKey(getStateKey()), []byte("invalid state proto")) + height := uint64(1) + err = sUnmarshal.SetHeight(t.Context(), height) + require.NoError(err) + + err = mockBatchingDsUnmarshal.Put(t.Context(), ds.NewKey(getStateAtHeightKey(height)), []byte("invalid state proto")) require.NoError(err) _, err = sUnmarshal.GetState(t.Context()) @@ -612,3 +623,459 @@ func TestGetHeader(t *testing.T) { }) } } + +// TestRollback verifies that rollback successfully removes blocks and updates height +func TestRollback(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback" + maxHeight := uint64(10) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Verify initial state + height, err := store.Height(ctx) + require.NoError(err) + require.Equal(maxHeight, height) + + // Verify all blocks exist + for h := uint64(1); h <= maxHeight; h++ { + _, _, err := store.GetBlockData(ctx, h) + require.NoError(err, "block at height %d should exist", h) + } + + // Execute rollback to height 7 + rollbackToHeight := uint64(7) + err = store.Rollback(ctx, rollbackToHeight) + require.NoError(err) + + // Verify new height + newHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(rollbackToHeight, newHeight) + + // Verify blocks exist only up to rollback height + for h := uint64(1); h <= rollbackToHeight; h++ { + _, _, err := store.GetBlockData(ctx, h) + require.NoError(err, "block at height %d should still exist after rollback", h) + } + + // Verify blocks above rollback height are removed + for h := rollbackToHeight + 1; h <= maxHeight; h++ { + _, _, err := store.GetBlockData(ctx, h) + require.Error(err, "block at height %d should be removed after rollback", h) + } + + // Verify state is rolled back + state, err := store.GetState(ctx) + require.NoError(err) + require.Equal(rollbackToHeight, state.LastBlockHeight) +} + +// TestRollbackToSameHeight verifies that rollback to same height is a no-op +func TestRollbackToSameHeight(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create one block + chainID := "test-rollback-same" + height := uint64(5) + header, data := types.GetRandomBlock(height, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, height) + require.NoError(err) + + // Execute rollback to same height + err = store.Rollback(ctx, height) + require.NoError(err) + + // Verify height unchanged + newHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(height, newHeight) + + // Verify block still exists + _, _, err = store.GetBlockData(ctx, height) + require.NoError(err) +} + +// TestRollbackToHigherHeight verifies that rollback to higher height is a no-op +func TestRollbackToHigherHeight(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create one block + chainID := "test-rollback-higher" + currentHeight := uint64(5) + header, data := types.GetRandomBlock(currentHeight, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, currentHeight) + require.NoError(err) + + // Execute rollback to higher height + rollbackToHeight := uint64(10) + err = store.Rollback(ctx, rollbackToHeight) + require.NoError(err) + + // Verify height unchanged + newHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(currentHeight, newHeight) + + // Verify block still exists + _, _, err = store.GetBlockData(ctx, currentHeight) + require.NoError(err) +} + +// TestRollbackBatchError verifies that rollback handles batch creation errors +func TestRollbackBatchError(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + mock := &mockBatchingDatastore{ + Batching: mustNewInMem(), + batchError: errors.New("batch creation failed"), + } + store := New(mock) + + err := store.Rollback(ctx, uint64(5)) + require.Error(err) + require.Contains(err.Error(), "failed to create a new batch") +} + +// TestRollbackHeightError verifies that rollback handles height retrieval errors +func TestRollbackHeightError(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + mock := &mockBatchingDatastore{ + Batching: mustNewInMem(), + getError: errors.New("height retrieval failed"), + } + store := New(mock) + + err := store.Rollback(ctx, uint64(5)) + require.Error(err) + require.Contains(err.Error(), "failed to get current height") +} + +// TestRollbackDAIncludedHeightValidation verifies DA included height validation during rollback +func TestRollbackDAIncludedHeightValidation(t *testing.T) { + t.Parallel() + require := require.New(t) + + // Test case 1: Rollback to height below DA included height should fail + t.Run("rollback below DA included height fails", func(t *testing.T) { + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback-da-fail" + maxHeight := uint64(10) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Set DA included height to 8 + daIncludedHeight := uint64(8) + heightBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(heightBytes, daIncludedHeight) + err := store.SetMetadata(ctx, DAIncludedHeightKey, heightBytes) + require.NoError(err) + + // Rollback to height below DA included height should fail + err = store.Rollback(ctx, uint64(6)) + require.Error(err) + require.Contains(err.Error(), "DA included height is greater than the rollback height: cannot rollback a finalized height") + }) + + // Test case 2: Rollback to height equal to DA included height should succeed + t.Run("rollback to DA included height succeeds", func(t *testing.T) { + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback-da-equal" + maxHeight := uint64(10) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Set DA included height to 8 + daIncludedHeight := uint64(8) + heightBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(heightBytes, daIncludedHeight) + err := store.SetMetadata(ctx, DAIncludedHeightKey, heightBytes) + require.NoError(err) + + // Rollback to height equal to DA included height should succeed + err = store.Rollback(ctx, uint64(8)) + require.NoError(err) + + // Verify height was rolled back to 8 + currentHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(uint64(8), currentHeight) + }) + + // Test case 3: Rollback to height above DA included height should succeed + t.Run("rollback above DA included height succeeds", func(t *testing.T) { + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback-da-above" + maxHeight := uint64(10) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Set DA included height to 8 + daIncludedHeight := uint64(8) + heightBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(heightBytes, daIncludedHeight) + err := store.SetMetadata(ctx, DAIncludedHeightKey, heightBytes) + require.NoError(err) + + // Rollback to height above DA included height should succeed + err = store.Rollback(ctx, uint64(9)) + require.NoError(err) + + // Verify height was rolled back to 9 + currentHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(uint64(9), currentHeight) + }) +} + +// TestRollbackDAIncludedHeightNotSet verifies rollback works when DA included height is not set +func TestRollbackDAIncludedHeightNotSet(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback-da-notset" + maxHeight := uint64(5) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Don't set DA included height - it should not exist + // Rollback should succeed since no DA included height is set + err := store.Rollback(ctx, uint64(3)) + require.NoError(err) + + // Verify height was rolled back to 3 + currentHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(uint64(3), currentHeight) +} + +// TestRollbackDAIncludedHeightInvalidLength verifies rollback works with invalid DA included height data +func TestRollbackDAIncludedHeightInvalidLength(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + store := New(mustNewInMem()) + + // Setup: create and save multiple blocks + chainID := "test-rollback-da-invalid" + maxHeight := uint64(5) + + for h := uint64(1); h <= maxHeight; h++ { + header, data := types.GetRandomBlock(h, 2, chainID) + sig := &header.Signature + + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + + err = store.SetHeight(ctx, h) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: chainID, + InitialHeight: 1, + LastBlockHeight: h, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + } + + // Set DA included height with invalid length (not 8 bytes) + invalidHeightData := []byte{1, 2, 3, 4} // only 4 bytes + err := store.SetMetadata(ctx, DAIncludedHeightKey, invalidHeightData) + require.NoError(err) + + // Rollback should succeed since invalid length data is ignored + err = store.Rollback(ctx, uint64(3)) + require.NoError(err) + + // Verify height was rolled back to 3 + currentHeight, err := store.Height(ctx) + require.NoError(err) + require.Equal(uint64(3), currentHeight) +} + +// TestRollbackDAIncludedHeightGetMetadataError verifies rollback handles GetMetadata errors for DA included height +func TestRollbackDAIncludedHeightGetMetadataError(t *testing.T) { + t.Parallel() + require := require.New(t) + + ctx := context.Background() + mock := &mockBatchingDatastore{ + Batching: mustNewInMem(), + } + store := New(mock) + + // Setup: create one block to ensure height > rollback target + header, data := types.GetRandomBlock(uint64(2), 2, "test-chain") + sig := &header.Signature + err := store.SaveBlockData(ctx, header, data, sig) + require.NoError(err) + err = store.SetHeight(ctx, uint64(2)) + require.NoError(err) + + // Create and update state for this height + state := types.State{ + ChainID: "test-chain", + InitialHeight: 1, + LastBlockHeight: 2, + LastBlockTime: header.Time(), + AppHash: header.AppHash, + } + err = store.UpdateState(ctx, state) + require.NoError(err) + + // Configure mock to return error when getting DA included height metadata + mock.getMetadataError = errors.New("metadata retrieval failed") + + // Rollback should fail due to GetMetadata error + err = store.Rollback(ctx, uint64(1)) + require.Error(err) + require.Contains(err.Error(), "failed to get DA included height") + require.Contains(err.Error(), "metadata retrieval failed") +} diff --git a/pkg/store/types.go b/pkg/store/types.go index 18f6b6e90c..ca935c08b4 100644 --- a/pkg/store/types.go +++ b/pkg/store/types.go @@ -35,6 +35,8 @@ type Store interface { UpdateState(ctx context.Context, state types.State) error // GetState returns last state saved with UpdateState. GetState(ctx context.Context) (types.State, error) + // GetStateAtHeight returns state saved at given height, or error if it's not found in Store. + GetStateAtHeight(ctx context.Context, height uint64) (types.State, error) // SetMetadata saves arbitrary value in the store. // @@ -44,6 +46,9 @@ type Store interface { // GetMetadata returns values stored for given key with SetMetadata. GetMetadata(ctx context.Context, key string) ([]byte, error) + // Rollback deletes x height from the ev-node store. + Rollback(ctx context.Context, height uint64) error + // Close safely closes underlying data storage, to ensure that data is actually saved. Close() error } diff --git a/test/mocks/Store.go b/test/mocks/store.go similarity index 87% rename from test/mocks/Store.go rename to test/mocks/store.go index 7bd6597bc6..45c10e684f 100644 --- a/test/mocks/Store.go +++ b/test/mocks/store.go @@ -566,6 +566,72 @@ func (_c *MockStore_GetState_Call) RunAndReturn(run func(ctx context.Context) (t return _c } +// GetStateAtHeight provides a mock function for the type MockStore +func (_mock *MockStore) GetStateAtHeight(ctx context.Context, height uint64) (types.State, error) { + ret := _mock.Called(ctx, height) + + if len(ret) == 0 { + panic("no return value specified for GetStateAtHeight") + } + + var r0 types.State + var r1 error + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64) (types.State, error)); ok { + return returnFunc(ctx, height) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64) types.State); ok { + r0 = returnFunc(ctx, height) + } else { + r0 = ret.Get(0).(types.State) + } + if returnFunc, ok := ret.Get(1).(func(context.Context, uint64) error); ok { + r1 = returnFunc(ctx, height) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// MockStore_GetStateAtHeight_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetStateAtHeight' +type MockStore_GetStateAtHeight_Call struct { + *mock.Call +} + +// GetStateAtHeight is a helper method to define mock.On call +// - ctx context.Context +// - height uint64 +func (_e *MockStore_Expecter) GetStateAtHeight(ctx interface{}, height interface{}) *MockStore_GetStateAtHeight_Call { + return &MockStore_GetStateAtHeight_Call{Call: _e.mock.On("GetStateAtHeight", ctx, height)} +} + +func (_c *MockStore_GetStateAtHeight_Call) Run(run func(ctx context.Context, height uint64)) *MockStore_GetStateAtHeight_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 uint64 + if args[1] != nil { + arg1 = args[1].(uint64) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_GetStateAtHeight_Call) Return(state types.State, err error) *MockStore_GetStateAtHeight_Call { + _c.Call.Return(state, err) + return _c +} + +func (_c *MockStore_GetStateAtHeight_Call) RunAndReturn(run func(ctx context.Context, height uint64) (types.State, error)) *MockStore_GetStateAtHeight_Call { + _c.Call.Return(run) + return _c +} + // Height provides a mock function for the type MockStore func (_mock *MockStore) Height(ctx context.Context) (uint64, error) { ret := _mock.Called(ctx) @@ -626,6 +692,63 @@ func (_c *MockStore_Height_Call) RunAndReturn(run func(ctx context.Context) (uin return _c } +// Rollback provides a mock function for the type MockStore +func (_mock *MockStore) Rollback(ctx context.Context, height uint64) error { + ret := _mock.Called(ctx, height) + + if len(ret) == 0 { + panic("no return value specified for Rollback") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64) error); ok { + r0 = returnFunc(ctx, height) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockStore_Rollback_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Rollback' +type MockStore_Rollback_Call struct { + *mock.Call +} + +// Rollback is a helper method to define mock.On call +// - ctx context.Context +// - height uint64 +func (_e *MockStore_Expecter) Rollback(ctx interface{}, height interface{}) *MockStore_Rollback_Call { + return &MockStore_Rollback_Call{Call: _e.mock.On("Rollback", ctx, height)} +} + +func (_c *MockStore_Rollback_Call) Run(run func(ctx context.Context, height uint64)) *MockStore_Rollback_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 uint64 + if args[1] != nil { + arg1 = args[1].(uint64) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockStore_Rollback_Call) Return(err error) *MockStore_Rollback_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockStore_Rollback_Call) RunAndReturn(run func(ctx context.Context, height uint64) error) *MockStore_Rollback_Call { + _c.Call.Return(run) + return _c +} + // SaveBlockData provides a mock function for the type MockStore func (_mock *MockStore) SaveBlockData(ctx context.Context, header *types.SignedHeader, data *types.Data, signature *types.Signature) error { ret := _mock.Called(ctx, header, data, signature)