diff --git a/test/dummy.go b/test/dummy.go index 93be076..74532ac 100644 --- a/test/dummy.go +++ b/test/dummy.go @@ -1,9 +1,12 @@ package test import ( + "bytes" + "context" "crypto/sha512" "fmt" + "slices" "time" "github.com/rollkit/go-execution/types" @@ -38,7 +41,6 @@ func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, in // GetTxs returns the list of transactions (types.Tx) within the DummyExecutor instance and an error if any. func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) { txs := e.injectedTxs - e.injectedTxs = nil return txs, nil } @@ -56,6 +58,7 @@ func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHei } pending := hash.Sum(nil) e.pendingRoots[blockHeight] = pending + e.removeExecutedTxs(txs) return pending, e.maxBytes, nil } @@ -68,3 +71,9 @@ func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error } return fmt.Errorf("cannot set finalized block at height %d", blockHeight) } + +func (e *DummyExecutor) removeExecutedTxs(txs []types.Tx) { + e.injectedTxs = slices.DeleteFunc(e.injectedTxs, func(tx types.Tx) bool { + return slices.ContainsFunc(txs, func(t types.Tx) bool { return bytes.Equal(tx, t) }) + }) +} diff --git a/test/dummy_test.go b/test/dummy_test.go index 94af8b6..17330c7 100644 --- a/test/dummy_test.go +++ b/test/dummy_test.go @@ -1,9 +1,14 @@ package test import ( + "context" "testing" + "time" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/rollkit/go-execution/types" ) type DummyTestSuite struct { @@ -17,3 +22,37 @@ func (s *DummyTestSuite) SetupTest() { func TestDummySuite(t *testing.T) { suite.Run(t, new(DummyTestSuite)) } + +func TestTxRemoval(t *testing.T) { + exec := NewDummyExecutor() + tx1 := types.Tx([]byte{1, 2, 3}) + tx2 := types.Tx([]byte{3, 2, 1}) + + exec.InjectTx(tx1) + exec.InjectTx(tx2) + + // first execution of GetTxs - nothing special + txs, err := exec.GetTxs(context.Background()) + require.NoError(t, err) + require.Len(t, txs, 2) + require.Contains(t, txs, tx1) + require.Contains(t, txs, tx2) + + // ExecuteTxs was not called, so 2 txs should still be returned + txs, err = exec.GetTxs(context.Background()) + require.NoError(t, err) + require.Len(t, txs, 2) + require.Contains(t, txs, tx1) + require.Contains(t, txs, tx2) + + state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), nil) + require.NoError(t, err) + require.NotEmpty(t, state) + + // ExecuteTxs was called, 1 tx remaining in mempool + txs, err = exec.GetTxs(context.Background()) + require.NoError(t, err) + require.Len(t, txs, 1) + require.NotContains(t, txs, tx1) + require.Contains(t, txs, tx2) +}