From 6d6be7cefd56dc61b04068f5c15753db008783de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Zdyba=C5=82?= Date: Tue, 5 Nov 2024 23:25:00 +0100 Subject: [PATCH 1/2] refactor: add context parameter to all executor methods This commit adds context.Context as the first parameter to all executor methods and updates relevant function calls and mocks accordingly. This change enhances context propagation and allows for better control over request lifecycles and timeouts. Resolves #25. --- execution.go | 28 ++----- mocks/mock_Executor.go | 119 +++++++++++++++------------- proxy/grpc/client.go | 20 +---- proxy/grpc/client_server_test.go | 9 ++- proxy/grpc/proxy_test.go | 2 +- proxy/grpc/server.go | 11 +-- proxy/jsonrpc/client.go | 20 ++--- proxy/jsonrpc/client_server_test.go | 18 +++-- proxy/jsonrpc/server.go | 18 ++--- test/dummy.go | 9 ++- test/suite.go | 9 ++- 11 files changed, 118 insertions(+), 145 deletions(-) diff --git a/execution.go b/execution.go index 88574eb..2eeba4c 100644 --- a/execution.go +++ b/execution.go @@ -1,6 +1,7 @@ package execution import ( + "context" "time" "github.com/rollkit/go-execution/types" @@ -9,33 +10,14 @@ import ( // Executor defines a common interface for interacting with the execution client. type Executor interface { // InitChain initializes the blockchain with genesis information. - InitChain( - genesisTime time.Time, - initialHeight uint64, - chainID string, - ) ( - stateRoot types.Hash, - maxBytes uint64, - err error, - ) + InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (stateRoot types.Hash, maxBytes uint64, err error) // GetTxs retrieves all available transactions from the execution client's mempool. - GetTxs() ([]types.Tx, error) + GetTxs(ctx context.Context) ([]types.Tx, error) // ExecuteTxs executes a set of transactions to produce a new block header. - ExecuteTxs( - txs []types.Tx, - blockHeight uint64, - timestamp time.Time, - prevStateRoot types.Hash, - ) ( - updatedStateRoot types.Hash, - maxBytes uint64, - err error, - ) + ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (updatedStateRoot types.Hash, maxBytes uint64, err error) // SetFinal marks a block at the given height as final. - SetFinal( - blockHeight uint64, - ) error + SetFinal(ctx context.Context, blockHeight uint64) error } diff --git a/mocks/mock_Executor.go b/mocks/mock_Executor.go index 7245228..2110bf6 100644 --- a/mocks/mock_Executor.go +++ b/mocks/mock_Executor.go @@ -3,7 +3,10 @@ package mocks import ( + context "context" + header "github.com/celestiaorg/go-header" + mock "github.com/stretchr/testify/mock" time "time" @@ -24,9 +27,9 @@ func (_m *MockExecutor) EXPECT() *MockExecutor_Expecter { return &MockExecutor_Expecter{mock: &_m.Mock} } -// ExecuteTxs provides a mock function with given fields: txs, blockHeight, timestamp, prevStateRoot -func (_m *MockExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash) (header.Hash, uint64, error) { - ret := _m.Called(txs, blockHeight, timestamp, prevStateRoot) +// ExecuteTxs provides a mock function with given fields: ctx, txs, blockHeight, timestamp, prevStateRoot +func (_m *MockExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash) (header.Hash, uint64, error) { + ret := _m.Called(ctx, txs, blockHeight, timestamp, prevStateRoot) if len(ret) == 0 { panic("no return value specified for ExecuteTxs") @@ -35,25 +38,25 @@ func (_m *MockExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp var r0 header.Hash var r1 uint64 var r2 error - if rf, ok := ret.Get(0).(func([]types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)); ok { - return rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(0).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)); ok { + return rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } - if rf, ok := ret.Get(0).(func([]types.Tx, uint64, time.Time, header.Hash) header.Hash); ok { - r0 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(0).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) header.Hash); ok { + r0 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(header.Hash) } } - if rf, ok := ret.Get(1).(func([]types.Tx, uint64, time.Time, header.Hash) uint64); ok { - r1 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(1).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) uint64); ok { + r1 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { r1 = ret.Get(1).(uint64) } - if rf, ok := ret.Get(2).(func([]types.Tx, uint64, time.Time, header.Hash) error); ok { - r2 = rf(txs, blockHeight, timestamp, prevStateRoot) + if rf, ok := ret.Get(2).(func(context.Context, []types.Tx, uint64, time.Time, header.Hash) error); ok { + r2 = rf(ctx, txs, blockHeight, timestamp, prevStateRoot) } else { r2 = ret.Error(2) } @@ -67,17 +70,18 @@ type MockExecutor_ExecuteTxs_Call struct { } // ExecuteTxs is a helper method to define mock.On call +// - ctx context.Context // - txs []types.Tx // - blockHeight uint64 // - timestamp time.Time // - prevStateRoot header.Hash -func (_e *MockExecutor_Expecter) ExecuteTxs(txs interface{}, blockHeight interface{}, timestamp interface{}, prevStateRoot interface{}) *MockExecutor_ExecuteTxs_Call { - return &MockExecutor_ExecuteTxs_Call{Call: _e.mock.On("ExecuteTxs", txs, blockHeight, timestamp, prevStateRoot)} +func (_e *MockExecutor_Expecter) ExecuteTxs(ctx interface{}, txs interface{}, blockHeight interface{}, timestamp interface{}, prevStateRoot interface{}) *MockExecutor_ExecuteTxs_Call { + return &MockExecutor_ExecuteTxs_Call{Call: _e.mock.On("ExecuteTxs", ctx, txs, blockHeight, timestamp, prevStateRoot)} } -func (_c *MockExecutor_ExecuteTxs_Call) Run(run func(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash)) *MockExecutor_ExecuteTxs_Call { +func (_c *MockExecutor_ExecuteTxs_Call) Run(run func(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot header.Hash)) *MockExecutor_ExecuteTxs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([]types.Tx), args[1].(uint64), args[2].(time.Time), args[3].(header.Hash)) + run(args[0].(context.Context), args[1].([]types.Tx), args[2].(uint64), args[3].(time.Time), args[4].(header.Hash)) }) return _c } @@ -87,14 +91,14 @@ func (_c *MockExecutor_ExecuteTxs_Call) Return(updatedStateRoot header.Hash, max return _c } -func (_c *MockExecutor_ExecuteTxs_Call) RunAndReturn(run func([]types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)) *MockExecutor_ExecuteTxs_Call { +func (_c *MockExecutor_ExecuteTxs_Call) RunAndReturn(run func(context.Context, []types.Tx, uint64, time.Time, header.Hash) (header.Hash, uint64, error)) *MockExecutor_ExecuteTxs_Call { _c.Call.Return(run) return _c } -// GetTxs provides a mock function with given fields: -func (_m *MockExecutor) GetTxs() ([]types.Tx, error) { - ret := _m.Called() +// GetTxs provides a mock function with given fields: ctx +func (_m *MockExecutor) GetTxs(ctx context.Context) ([]types.Tx, error) { + ret := _m.Called(ctx) if len(ret) == 0 { panic("no return value specified for GetTxs") @@ -102,19 +106,19 @@ func (_m *MockExecutor) GetTxs() ([]types.Tx, error) { var r0 []types.Tx var r1 error - if rf, ok := ret.Get(0).(func() ([]types.Tx, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(context.Context) ([]types.Tx, error)); ok { + return rf(ctx) } - if rf, ok := ret.Get(0).(func() []types.Tx); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(context.Context) []types.Tx); ok { + r0 = rf(ctx) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]types.Tx) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) } else { r1 = ret.Error(1) } @@ -128,13 +132,14 @@ type MockExecutor_GetTxs_Call struct { } // GetTxs is a helper method to define mock.On call -func (_e *MockExecutor_Expecter) GetTxs() *MockExecutor_GetTxs_Call { - return &MockExecutor_GetTxs_Call{Call: _e.mock.On("GetTxs")} +// - ctx context.Context +func (_e *MockExecutor_Expecter) GetTxs(ctx interface{}) *MockExecutor_GetTxs_Call { + return &MockExecutor_GetTxs_Call{Call: _e.mock.On("GetTxs", ctx)} } -func (_c *MockExecutor_GetTxs_Call) Run(run func()) *MockExecutor_GetTxs_Call { +func (_c *MockExecutor_GetTxs_Call) Run(run func(ctx context.Context)) *MockExecutor_GetTxs_Call { _c.Call.Run(func(args mock.Arguments) { - run() + run(args[0].(context.Context)) }) return _c } @@ -144,14 +149,14 @@ func (_c *MockExecutor_GetTxs_Call) Return(_a0 []types.Tx, _a1 error) *MockExecu return _c } -func (_c *MockExecutor_GetTxs_Call) RunAndReturn(run func() ([]types.Tx, error)) *MockExecutor_GetTxs_Call { +func (_c *MockExecutor_GetTxs_Call) RunAndReturn(run func(context.Context) ([]types.Tx, error)) *MockExecutor_GetTxs_Call { _c.Call.Return(run) return _c } -// InitChain provides a mock function with given fields: genesisTime, initialHeight, chainID -func (_m *MockExecutor) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (header.Hash, uint64, error) { - ret := _m.Called(genesisTime, initialHeight, chainID) +// InitChain provides a mock function with given fields: ctx, genesisTime, initialHeight, chainID +func (_m *MockExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (header.Hash, uint64, error) { + ret := _m.Called(ctx, genesisTime, initialHeight, chainID) if len(ret) == 0 { panic("no return value specified for InitChain") @@ -160,25 +165,25 @@ func (_m *MockExecutor) InitChain(genesisTime time.Time, initialHeight uint64, c var r0 header.Hash var r1 uint64 var r2 error - if rf, ok := ret.Get(0).(func(time.Time, uint64, string) (header.Hash, uint64, error)); ok { - return rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint64, string) (header.Hash, uint64, error)); ok { + return rf(ctx, genesisTime, initialHeight, chainID) } - if rf, ok := ret.Get(0).(func(time.Time, uint64, string) header.Hash); ok { - r0 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(0).(func(context.Context, time.Time, uint64, string) header.Hash); ok { + r0 = rf(ctx, genesisTime, initialHeight, chainID) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(header.Hash) } } - if rf, ok := ret.Get(1).(func(time.Time, uint64, string) uint64); ok { - r1 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(1).(func(context.Context, time.Time, uint64, string) uint64); ok { + r1 = rf(ctx, genesisTime, initialHeight, chainID) } else { r1 = ret.Get(1).(uint64) } - if rf, ok := ret.Get(2).(func(time.Time, uint64, string) error); ok { - r2 = rf(genesisTime, initialHeight, chainID) + if rf, ok := ret.Get(2).(func(context.Context, time.Time, uint64, string) error); ok { + r2 = rf(ctx, genesisTime, initialHeight, chainID) } else { r2 = ret.Error(2) } @@ -192,16 +197,17 @@ type MockExecutor_InitChain_Call struct { } // InitChain is a helper method to define mock.On call +// - ctx context.Context // - genesisTime time.Time // - initialHeight uint64 // - chainID string -func (_e *MockExecutor_Expecter) InitChain(genesisTime interface{}, initialHeight interface{}, chainID interface{}) *MockExecutor_InitChain_Call { - return &MockExecutor_InitChain_Call{Call: _e.mock.On("InitChain", genesisTime, initialHeight, chainID)} +func (_e *MockExecutor_Expecter) InitChain(ctx interface{}, genesisTime interface{}, initialHeight interface{}, chainID interface{}) *MockExecutor_InitChain_Call { + return &MockExecutor_InitChain_Call{Call: _e.mock.On("InitChain", ctx, genesisTime, initialHeight, chainID)} } -func (_c *MockExecutor_InitChain_Call) Run(run func(genesisTime time.Time, initialHeight uint64, chainID string)) *MockExecutor_InitChain_Call { +func (_c *MockExecutor_InitChain_Call) Run(run func(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string)) *MockExecutor_InitChain_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(time.Time), args[1].(uint64), args[2].(string)) + run(args[0].(context.Context), args[1].(time.Time), args[2].(uint64), args[3].(string)) }) return _c } @@ -211,22 +217,22 @@ func (_c *MockExecutor_InitChain_Call) Return(stateRoot header.Hash, maxBytes ui return _c } -func (_c *MockExecutor_InitChain_Call) RunAndReturn(run func(time.Time, uint64, string) (header.Hash, uint64, error)) *MockExecutor_InitChain_Call { +func (_c *MockExecutor_InitChain_Call) RunAndReturn(run func(context.Context, time.Time, uint64, string) (header.Hash, uint64, error)) *MockExecutor_InitChain_Call { _c.Call.Return(run) return _c } -// SetFinal provides a mock function with given fields: blockHeight -func (_m *MockExecutor) SetFinal(blockHeight uint64) error { - ret := _m.Called(blockHeight) +// SetFinal provides a mock function with given fields: ctx, blockHeight +func (_m *MockExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { + ret := _m.Called(ctx, blockHeight) if len(ret) == 0 { panic("no return value specified for SetFinal") } var r0 error - if rf, ok := ret.Get(0).(func(uint64) error); ok { - r0 = rf(blockHeight) + if rf, ok := ret.Get(0).(func(context.Context, uint64) error); ok { + r0 = rf(ctx, blockHeight) } else { r0 = ret.Error(0) } @@ -240,14 +246,15 @@ type MockExecutor_SetFinal_Call struct { } // SetFinal is a helper method to define mock.On call +// - ctx context.Context // - blockHeight uint64 -func (_e *MockExecutor_Expecter) SetFinal(blockHeight interface{}) *MockExecutor_SetFinal_Call { - return &MockExecutor_SetFinal_Call{Call: _e.mock.On("SetFinal", blockHeight)} +func (_e *MockExecutor_Expecter) SetFinal(ctx interface{}, blockHeight interface{}) *MockExecutor_SetFinal_Call { + return &MockExecutor_SetFinal_Call{Call: _e.mock.On("SetFinal", ctx, blockHeight)} } -func (_c *MockExecutor_SetFinal_Call) Run(run func(blockHeight uint64)) *MockExecutor_SetFinal_Call { +func (_c *MockExecutor_SetFinal_Call) Run(run func(ctx context.Context, blockHeight uint64)) *MockExecutor_SetFinal_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(uint64)) + run(args[0].(context.Context), args[1].(uint64)) }) return _c } @@ -257,7 +264,7 @@ func (_c *MockExecutor_SetFinal_Call) Return(_a0 error) *MockExecutor_SetFinal_C return _c } -func (_c *MockExecutor_SetFinal_Call) RunAndReturn(run func(uint64) error) *MockExecutor_SetFinal_Call { +func (_c *MockExecutor_SetFinal_Call) RunAndReturn(run func(context.Context, uint64) error) *MockExecutor_SetFinal_Call { _c.Call.Return(run) return _c } diff --git a/proxy/grpc/client.go b/proxy/grpc/client.go index 3d1e1d3..7b14f73 100644 --- a/proxy/grpc/client.go +++ b/proxy/grpc/client.go @@ -51,10 +51,7 @@ func (c *Client) Stop() error { } // InitChain initializes the blockchain with genesis information. -func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { resp, err := c.client.InitChain(ctx, &pb.InitChainRequest{ GenesisTime: genesisTime.Unix(), InitialHeight: initialHeight, @@ -71,10 +68,7 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID } // GetTxs retrieves all available transactions from the execution client's mempool. -func (c *Client) GetTxs() ([]types.Tx, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) GetTxs(ctx context.Context) ([]types.Tx, error) { resp, err := c.client.GetTxs(ctx, &pb.GetTxsRequest{}) if err != nil { return nil, err @@ -89,10 +83,7 @@ func (c *Client) GetTxs() ([]types.Tx, error) { } // ExecuteTxs executes a set of transactions to produce a new block header. -func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { req := &pb.ExecuteTxsRequest{ Txs: make([][]byte, len(txs)), BlockHeight: blockHeight, @@ -115,10 +106,7 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T } // SetFinal marks a block at the given height as final. -func (c *Client) SetFinal(blockHeight uint64) error { - ctx, cancel := context.WithTimeout(context.Background(), c.config.DefaultTimeout) - defer cancel() - +func (c *Client) SetFinal(ctx context.Context, blockHeight uint64) error { _, err := c.client.SetFinal(ctx, &pb.SetFinalRequest{ BlockHeight: blockHeight, }) diff --git a/proxy/grpc/client_server_test.go b/proxy/grpc/client_server_test.go index 14b22b5..af9bc34 100644 --- a/proxy/grpc/client_server_test.go +++ b/proxy/grpc/client_server_test.go @@ -1,11 +1,14 @@ package grpc_test import ( + "context" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -45,7 +48,7 @@ func TestClientServer(t *testing.T) { require.NoError(t, err) defer func() { _ = client.Stop() }() - mockExec.On("GetTxs").Return([]types.Tx{}, nil).Maybe() + mockExec.On("GetTxs", mock.Anything).Return([]types.Tx{}, nil).Maybe() t.Run("InitChain", func(t *testing.T) { genesisTime := time.Now().UTC().Truncate(time.Second) @@ -64,10 +67,10 @@ func TestClientServer(t *testing.T) { unixTime := genesisTime.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("InitChain", expectedTime, initialHeight, chainID). + mockExec.On("InitChain", mock.Anything, expectedTime, initialHeight, chainID). Return(stateRootHash, expectedMaxBytes, nil).Once() - stateRoot, maxBytes, err := client.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := client.InitChain(context.TODO(), genesisTime, initialHeight, chainID) require.NoError(t, err) assert.Equal(t, stateRootHash, stateRoot) diff --git a/proxy/grpc/proxy_test.go b/proxy/grpc/proxy_test.go index a88b871..ab45ae0 100644 --- a/proxy/grpc/proxy_test.go +++ b/proxy/grpc/proxy_test.go @@ -65,7 +65,7 @@ func (s *ProxyTestSuite) SetupTest() { require.NoError(s.T(), err) for i := 0; i < 10; i++ { - if _, err := client.GetTxs(); err == nil { + if _, err := client.GetTxs(context.TODO()); err == nil { break } time.Sleep(100 * time.Millisecond) diff --git a/proxy/grpc/server.go b/proxy/grpc/server.go index e285179..40c257f 100644 --- a/proxy/grpc/server.go +++ b/proxy/grpc/server.go @@ -48,11 +48,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // Convert Unix timestamp to UTC time genesisTime := time.Unix(req.GenesisTime, 0).UTC() - stateRoot, maxBytes, err := s.exec.InitChain( - genesisTime, - req.InitialHeight, - req.ChainId, - ) + stateRoot, maxBytes, err := s.exec.InitChain(context.TODO(), genesisTime, req.InitialHeight, req.ChainId) if err != nil { return nil, err } @@ -65,7 +61,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // GetTxs handles GetTxs method call from execution API. func (s *Server) GetTxs(ctx context.Context, req *pb.GetTxsRequest) (*pb.GetTxsResponse, error) { - txs, err := s.exec.GetTxs() + txs, err := s.exec.GetTxs(context.TODO()) if err != nil { return nil, err } @@ -91,6 +87,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb copy(prevStateRoot[:], req.PrevStateRoot) updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs( + context.TODO(), txs, req.BlockHeight, time.Unix(req.Timestamp, 0), @@ -108,7 +105,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb // SetFinal handles SetFinal method call from execution API. func (s *Server) SetFinal(ctx context.Context, req *pb.SetFinalRequest) (*pb.SetFinalResponse, error) { - err := s.exec.SetFinal(req.BlockHeight) + err := s.exec.SetFinal(context.TODO(), req.BlockHeight) if err != nil { return nil, err } diff --git a/proxy/jsonrpc/client.go b/proxy/jsonrpc/client.go index cdebf9f..b737467 100644 --- a/proxy/jsonrpc/client.go +++ b/proxy/jsonrpc/client.go @@ -47,7 +47,7 @@ func (c *Client) Stop() error { } // InitChain initializes the blockchain with genesis information. -func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { +func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { params := map[string]interface{}{ "genesis_time": genesisTime.Unix(), "initial_height": initialHeight, @@ -59,7 +59,7 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID MaxBytes uint64 `json:"max_bytes"` } - if err := c.call("init_chain", params, &result); err != nil { + if err := c.call(context.TODO(), "init_chain", params, &result); err != nil { return types.Hash{}, 0, err } @@ -75,12 +75,12 @@ func (c *Client) InitChain(genesisTime time.Time, initialHeight uint64, chainID } // GetTxs retrieves all available transactions from the execution client's mempool. -func (c *Client) GetTxs() ([]types.Tx, error) { +func (c *Client) GetTxs(context.Context) ([]types.Tx, error) { var result struct { Txs []string `json:"txs"` } - if err := c.call("get_txs", nil, &result); err != nil { + if err := c.call(context.TODO(), "get_txs", nil, &result); err != nil { return nil, err } @@ -97,7 +97,7 @@ func (c *Client) GetTxs() ([]types.Tx, error) { } // ExecuteTxs executes a set of transactions to produce a new block header. -func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { +func (c *Client) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { // Encode txs to base64 encodedTxs := make([]string, len(txs)) for i, tx := range txs { @@ -116,7 +116,7 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T MaxBytes uint64 `json:"max_bytes"` } - if err := c.call("execute_txs", params, &result); err != nil { + if err := c.call(context.TODO(), "execute_txs", params, &result); err != nil { return types.Hash{}, 0, err } @@ -132,15 +132,15 @@ func (c *Client) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.T } // SetFinal marks a block at the given height as final. -func (c *Client) SetFinal(blockHeight uint64) error { +func (c *Client) SetFinal(ctx context.Context, blockHeight uint64) error { params := map[string]interface{}{ "block_height": blockHeight, } - return c.call("set_final", params, nil) + return c.call(context.TODO(), "set_final", params, nil) } -func (c *Client) call(method string, params interface{}, result interface{}) error { +func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error { request := struct { JSONRPC string `json:"jsonrpc"` Method string `json:"method"` @@ -158,7 +158,7 @@ func (c *Client) call(method string, params interface{}, result interface{}) err return fmt.Errorf("failed to marshal request: %w", err) } - req, err := http.NewRequestWithContext(context.Background(), "POST", c.endpoint, bytes.NewReader(reqBody)) + req, err := http.NewRequestWithContext(ctx, "POST", c.endpoint, bytes.NewReader(reqBody)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } diff --git a/proxy/jsonrpc/client_server_test.go b/proxy/jsonrpc/client_server_test.go index 2295663..ebb6a5d 100644 --- a/proxy/jsonrpc/client_server_test.go +++ b/proxy/jsonrpc/client_server_test.go @@ -1,11 +1,13 @@ package jsonrpc_test import ( + "context" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/rollkit/go-execution/mocks" @@ -47,10 +49,10 @@ func TestClientServer(t *testing.T) { unixTime := genesisTime.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("InitChain", expectedTime, initialHeight, chainID). + mockExec.On("InitChain", mock.Anything, expectedTime, initialHeight, chainID). Return(stateRootHash, expectedMaxBytes, nil).Once() - stateRoot, maxBytes, err := client.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := client.InitChain(context.TODO(), genesisTime, initialHeight, chainID) require.NoError(t, err) assert.Equal(t, stateRootHash, stateRoot) @@ -60,9 +62,9 @@ func TestClientServer(t *testing.T) { t.Run("GetTxs", func(t *testing.T) { expectedTxs := []types.Tx{[]byte("tx1"), []byte("tx2")} - mockExec.On("GetTxs").Return(expectedTxs, nil).Once() + mockExec.On("GetTxs", mock.Anything).Return(expectedTxs, nil).Once() - txs, err := client.GetTxs() + txs, err := client.GetTxs(context.TODO()) require.NoError(t, err) assert.Equal(t, expectedTxs, txs) mockExec.AssertExpectations(t) @@ -85,10 +87,10 @@ func TestClientServer(t *testing.T) { unixTime := timestamp.Unix() expectedTime := time.Unix(unixTime, 0).UTC() - mockExec.On("ExecuteTxs", txs, blockHeight, expectedTime, prevStateRoot). + mockExec.On("ExecuteTxs", mock.Anything, txs, blockHeight, expectedTime, prevStateRoot). Return(expectedStateRoot, expectedMaxBytes, nil).Once() - updatedStateRoot, maxBytes, err := client.ExecuteTxs(txs, blockHeight, timestamp, prevStateRoot) + updatedStateRoot, maxBytes, err := client.ExecuteTxs(context.TODO(), txs, blockHeight, timestamp, prevStateRoot) require.NoError(t, err) assert.Equal(t, expectedStateRoot, updatedStateRoot) @@ -98,9 +100,9 @@ func TestClientServer(t *testing.T) { t.Run("SetFinal", func(t *testing.T) { blockHeight := uint64(1) - mockExec.On("SetFinal", blockHeight).Return(nil).Once() + mockExec.On("SetFinal", mock.Anything, blockHeight).Return(nil).Once() - err := client.SetFinal(blockHeight) + err := client.SetFinal(context.TODO(), blockHeight) require.NoError(t, err) mockExec.AssertExpectations(t) }) diff --git a/proxy/jsonrpc/server.go b/proxy/jsonrpc/server.go index a5d94ad..b91b24b 100644 --- a/proxy/jsonrpc/server.go +++ b/proxy/jsonrpc/server.go @@ -1,6 +1,7 @@ package jsonrpc import ( + "context" "encoding/base64" "encoding/json" "net/http" @@ -90,11 +91,7 @@ func (s *Server) handleInitChain(params json.RawMessage) (interface{}, *jsonRPCE return nil, ErrInvalidParams } - stateRoot, maxBytes, err := s.exec.InitChain( - time.Unix(p.GenesisTime, 0).UTC(), - p.InitialHeight, - p.ChainID, - ) + stateRoot, maxBytes, err := s.exec.InitChain(context.TODO(), time.Unix(p.GenesisTime, 0).UTC(), p.InitialHeight, p.ChainID) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -106,7 +103,7 @@ func (s *Server) handleInitChain(params json.RawMessage) (interface{}, *jsonRPCE } func (s *Server) handleGetTxs() (interface{}, *jsonRPCError) { - txs, err := s.exec.GetTxs() + txs, err := s.exec.GetTxs(context.TODO()) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -152,12 +149,7 @@ func (s *Server) handleExecuteTxs(params json.RawMessage) (interface{}, *jsonRPC var prevStateRoot types.Hash copy(prevStateRoot[:], prevStateRootBytes) - updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs( - txs, - p.BlockHeight, - time.Unix(p.Timestamp, 0).UTC(), - prevStateRoot, - ) + updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs(context.TODO(), txs, p.BlockHeight, time.Unix(p.Timestamp, 0).UTC(), prevStateRoot) if err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } @@ -177,7 +169,7 @@ func (s *Server) handleSetFinal(params json.RawMessage) (interface{}, *jsonRPCEr return nil, ErrInvalidParams } - if err := s.exec.SetFinal(p.BlockHeight); err != nil { + if err := s.exec.SetFinal(context.TODO(), p.BlockHeight); err != nil { return nil, &jsonRPCError{Code: ErrCodeInternal, Message: err.Error()} } diff --git a/test/dummy.go b/test/dummy.go index 2214aea..bac37a4 100644 --- a/test/dummy.go +++ b/test/dummy.go @@ -1,6 +1,7 @@ package test import ( + "context" "time" "github.com/rollkit/go-execution/types" @@ -24,22 +25,22 @@ func NewDummyExecutor() *DummyExecutor { // InitChain initializes the chain state with the given genesis time, initial height, and chain ID. // It returns the state root hash, the maximum byte size, and an error if the initialization fails. -func (e *DummyExecutor) InitChain(genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { +func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (types.Hash, uint64, error) { return e.stateRoot, e.maxBytes, nil } // GetTxs returns the list of transactions (types.Tx) within the DummyExecutor instance and an error if any. -func (e *DummyExecutor) GetTxs() ([]types.Tx, error) { +func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) { return e.txs, nil } // ExecuteTxs simulate execution of transactions. -func (e *DummyExecutor) ExecuteTxs(txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { +func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uint64, timestamp time.Time, prevStateRoot types.Hash) (types.Hash, uint64, error) { e.txs = append(e.txs, txs...) return e.stateRoot, e.maxBytes, nil } // SetFinal marks block at given height as finalized. Currently not implemented. -func (e *DummyExecutor) SetFinal(blockHeight uint64) error { +func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error { return nil } diff --git a/test/suite.go b/test/suite.go index 244dc90..c0536fb 100644 --- a/test/suite.go +++ b/test/suite.go @@ -1,6 +1,7 @@ package test import ( + "context" "time" "github.com/stretchr/testify/suite" @@ -21,7 +22,7 @@ func (s *ExecutorSuite) TestInitChain() { initialHeight := uint64(1) chainID := "test-chain" - stateRoot, maxBytes, err := s.Exec.InitChain(genesisTime, initialHeight, chainID) + stateRoot, maxBytes, err := s.Exec.InitChain(context.TODO(), genesisTime, initialHeight, chainID) s.Require().NoError(err) s.NotEqual(types.Hash{}, stateRoot) s.Greater(maxBytes, uint64(0)) @@ -29,7 +30,7 @@ func (s *ExecutorSuite) TestInitChain() { // TestGetTxs tests GetTxs method. func (s *ExecutorSuite) TestGetTxs() { - txs, err := s.Exec.GetTxs() + txs, err := s.Exec.GetTxs(context.TODO()) s.Require().NoError(err) s.NotNil(txs) } @@ -41,7 +42,7 @@ func (s *ExecutorSuite) TestExecuteTxs() { timestamp := time.Now().UTC() prevStateRoot := types.Hash{1, 2, 3} - stateRoot, maxBytes, err := s.Exec.ExecuteTxs(txs, blockHeight, timestamp, prevStateRoot) + stateRoot, maxBytes, err := s.Exec.ExecuteTxs(context.TODO(), txs, blockHeight, timestamp, prevStateRoot) s.Require().NoError(err) s.NotEqual(types.Hash{}, stateRoot) s.Greater(maxBytes, uint64(0)) @@ -49,6 +50,6 @@ func (s *ExecutorSuite) TestExecuteTxs() { // TestSetFinal tests SetFinal method. func (s *ExecutorSuite) TestSetFinal() { - err := s.Exec.SetFinal(1) + err := s.Exec.SetFinal(context.TODO(), 1) s.Require().NoError(err) } From b81ed4d8f3d5d9bc85e5bd6c06534f068f35200d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20Zdyba=C5=82?= Date: Thu, 7 Nov 2024 21:25:34 +0100 Subject: [PATCH 2/2] refactor: replace `context.TODO()` with `ctx` arg --- proxy/grpc/server.go | 8 ++++---- proxy/jsonrpc/client.go | 10 +++++----- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/proxy/grpc/server.go b/proxy/grpc/server.go index 40c257f..629360e 100644 --- a/proxy/grpc/server.go +++ b/proxy/grpc/server.go @@ -48,7 +48,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // Convert Unix timestamp to UTC time genesisTime := time.Unix(req.GenesisTime, 0).UTC() - stateRoot, maxBytes, err := s.exec.InitChain(context.TODO(), genesisTime, req.InitialHeight, req.ChainId) + stateRoot, maxBytes, err := s.exec.InitChain(ctx, genesisTime, req.InitialHeight, req.ChainId) if err != nil { return nil, err } @@ -61,7 +61,7 @@ func (s *Server) InitChain(ctx context.Context, req *pb.InitChainRequest) (*pb.I // GetTxs handles GetTxs method call from execution API. func (s *Server) GetTxs(ctx context.Context, req *pb.GetTxsRequest) (*pb.GetTxsResponse, error) { - txs, err := s.exec.GetTxs(context.TODO()) + txs, err := s.exec.GetTxs(ctx) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb copy(prevStateRoot[:], req.PrevStateRoot) updatedStateRoot, maxBytes, err := s.exec.ExecuteTxs( - context.TODO(), + ctx, txs, req.BlockHeight, time.Unix(req.Timestamp, 0), @@ -105,7 +105,7 @@ func (s *Server) ExecuteTxs(ctx context.Context, req *pb.ExecuteTxsRequest) (*pb // SetFinal handles SetFinal method call from execution API. func (s *Server) SetFinal(ctx context.Context, req *pb.SetFinalRequest) (*pb.SetFinalResponse, error) { - err := s.exec.SetFinal(context.TODO(), req.BlockHeight) + err := s.exec.SetFinal(ctx, req.BlockHeight) if err != nil { return nil, err } diff --git a/proxy/jsonrpc/client.go b/proxy/jsonrpc/client.go index b737467..294a88e 100644 --- a/proxy/jsonrpc/client.go +++ b/proxy/jsonrpc/client.go @@ -59,7 +59,7 @@ func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHe MaxBytes uint64 `json:"max_bytes"` } - if err := c.call(context.TODO(), "init_chain", params, &result); err != nil { + if err := c.call(ctx, "init_chain", params, &result); err != nil { return types.Hash{}, 0, err } @@ -75,12 +75,12 @@ func (c *Client) InitChain(ctx context.Context, genesisTime time.Time, initialHe } // GetTxs retrieves all available transactions from the execution client's mempool. -func (c *Client) GetTxs(context.Context) ([]types.Tx, error) { +func (c *Client) GetTxs(ctx context.Context) ([]types.Tx, error) { var result struct { Txs []string `json:"txs"` } - if err := c.call(context.TODO(), "get_txs", nil, &result); err != nil { + if err := c.call(ctx, "get_txs", nil, &result); err != nil { return nil, err } @@ -116,7 +116,7 @@ func (c *Client) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHeight uin MaxBytes uint64 `json:"max_bytes"` } - if err := c.call(context.TODO(), "execute_txs", params, &result); err != nil { + if err := c.call(ctx, "execute_txs", params, &result); err != nil { return types.Hash{}, 0, err } @@ -137,7 +137,7 @@ func (c *Client) SetFinal(ctx context.Context, blockHeight uint64) error { "block_height": blockHeight, } - return c.call(context.TODO(), "set_final", params, nil) + return c.call(ctx, "set_final", params, nil) } func (c *Client) call(ctx context.Context, method string, params interface{}, result interface{}) error {