diff --git a/block/manager.go b/block/manager.go
index 0820814074..9c7c9000da 100644
--- a/block/manager.go
+++ b/block/manager.go
@@ -17,6 +17,7 @@ import (
 	goheader "github.com/celestiaorg/go-header"
 	ds "github.com/ipfs/go-datastore"
 	"github.com/libp2p/go-libp2p/core/crypto"
+	"golang.org/x/sync/errgroup"
 
 	coreda "github.com/rollkit/rollkit/core/da"
 	coreexecutor "github.com/rollkit/rollkit/core/execution"
@@ -46,9 +47,6 @@ const (
 	// This is temporary solution. It will be removed in future versions.
 	maxSubmitAttempts = 30
 
-	// Applies to most channels, 100 is a large enough buffer to avoid blocking
-	channelLength = 100
-
 	// Applies to the headerInCh and dataInCh, 10000 is a large enough number for headers per DA block.
 	eventInChLength = 10000
 
@@ -90,6 +88,10 @@ type BatchData struct {
 	Data [][]byte
 }
 
+type broadcaster[T any] interface {
+	WriteToStoreAndBroadcast(ctx context.Context, payload T) error
+}
+
 // Manager is responsible for aggregating transactions into blocks.
 type Manager struct {
 	lastState types.State
@@ -104,8 +106,8 @@ type Manager struct {
 
 	daHeight *atomic.Uint64
 
-	HeaderCh chan *types.SignedHeader
-	DataCh   chan *types.Data
+	headerBroadcaster broadcaster[*types.SignedHeader]
+	dataBroadcaster   broadcaster[*types.Data]
 
 	headerInCh  chan NewHeaderEvent
 	headerStore goheader.Store[*types.SignedHeader]
@@ -268,6 +270,8 @@ func NewManager(
 	logger log.Logger,
 	headerStore goheader.Store[*types.SignedHeader],
 	dataStore goheader.Store[*types.Data],
+	headerBroadcaster broadcaster[*types.SignedHeader],
+	dataBroadcaster broadcaster[*types.Data],
 	seqMetrics *Metrics,
 	gasPrice float64,
 	gasMultiplier float64,
@@ -326,15 +330,15 @@ func NewManager(
 	daH.Store(s.DAHeight)
 
 	m := &Manager{
-		signer:    signer,
-		config:    config,
-		genesis:   genesis,
-		lastState: s,
-		store:     store,
-		daHeight:  &daH,
+		signer:            signer,
+		config:            config,
+		genesis:           genesis,
+		lastState:         s,
+		store:             store,
+		daHeight:          &daH,
+		headerBroadcaster: headerBroadcaster,
+		dataBroadcaster:   dataBroadcaster,
 		// channels are buffered to avoid blocking on input/output operations, buffer sizes are arbitrary
-		HeaderCh:      make(chan *types.SignedHeader, channelLength),
-		DataCh:        make(chan *types.Data, channelLength),
 		headerInCh:    make(chan NewHeaderEvent, eventInChLength),
 		dataInCh:      make(chan NewDataEvent, eventInChLength),
 		headerStoreCh: make(chan struct{}, 1),
@@ -654,24 +658,14 @@ func (m *Manager) publishBlockInternal(ctx context.Context) error {
 
 	m.recordMetrics(data)
 
-	// Check for shut down event prior to sending the header and block to
-	// their respective channels. The reason for checking for the shutdown
-	// event separately is due to the inconsistent nature of the select
-	// statement when multiple cases are satisfied.
-	select {
-	case <-ctx.Done():
-		return fmt.Errorf("unable to send header and block, context done: %w", ctx.Err())
-	default:
+	g, ctx := errgroup.WithContext(ctx)
+	g.Go(func() error { return m.headerBroadcaster.WriteToStoreAndBroadcast(ctx, header) })
+	g.Go(func() error { return m.dataBroadcaster.WriteToStoreAndBroadcast(ctx, data) })
+	if err := g.Wait(); err != nil {
+		return err
 	}
 
-	// Publish header to channel so that header exchange service can broadcast
-	m.HeaderCh <- header
-
-	// Publish block to channel so that block exchange service can broadcast
-	m.DataCh <- data
-
 	m.logger.Debug("successfully proposed header", "proposer", hex.EncodeToString(header.ProposerAddress), "height", headerHeight)
-
 	return nil
 }
diff --git a/block/publish_block2_test.go b/block/publish_block2_test.go
new file mode 100644
index 0000000000..1b9ad02d9a
--- /dev/null
+++ b/block/publish_block2_test.go
@@ -0,0 +1,245 @@
+package block
+
+import (
+	"context"
+	cryptoRand "crypto/rand"
+	"errors"
+	"fmt"
+	"math/rand"
+	"path/filepath"
+	"sync"
+	"testing"
+	"time"
+
+	"cosmossdk.io/log"
+	ds "github.com/ipfs/go-datastore"
+	ktds "github.com/ipfs/go-datastore/keytransform"
+	syncdb "github.com/ipfs/go-datastore/sync"
+	logging "github.com/ipfs/go-log/v2"
+	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	coresequencer "github.com/rollkit/rollkit/core/sequencer"
+	"github.com/rollkit/rollkit/pkg/config"
+	genesispkg "github.com/rollkit/rollkit/pkg/genesis"
+	"github.com/rollkit/rollkit/pkg/p2p"
+	"github.com/rollkit/rollkit/pkg/p2p/key"
+	"github.com/rollkit/rollkit/pkg/signer"
+	"github.com/rollkit/rollkit/pkg/signer/noop"
+	"github.com/rollkit/rollkit/pkg/store"
+	rollkitSync "github.com/rollkit/rollkit/pkg/sync"
+	"github.com/rollkit/rollkit/types"
+)
+
+func TestSlowConsumers(t *testing.T) {
+	logging.SetDebugLogging()
+	blockTime := 100 * time.Millisecond
+	specs := map[string]struct {
+		headerConsumerDelay time.Duration
+		dataConsumerDelay   time.Duration
+	}{
+		"slow header consumer": {
+			headerConsumerDelay: blockTime * 2,
+			dataConsumerDelay:   0,
+		},
+		"slow data consumer": {
+			headerConsumerDelay: 0,
+			dataConsumerDelay:   blockTime * 2,
+		},
+		"both slow": {
+			headerConsumerDelay: blockTime,
+			dataConsumerDelay:   blockTime,
+		},
+		"both fast": {
+			headerConsumerDelay: 0,
+			dataConsumerDelay:   0,
+		},
+	}
+	for name, spec := range specs {
+		t.Run(name, func(t *testing.T) {
+			workDir := t.TempDir()
+			dbm := syncdb.MutexWrap(ds.NewMapDatastore())
+			ctx, cancel := context.WithCancel(t.Context())
+
+			pk, _, err := crypto.GenerateEd25519Key(cryptoRand.Reader)
+			require.NoError(t, err)
+			noopSigner, err := noop.NewNoopSigner(pk)
+			require.NoError(t, err)
+
+			manager, headerSync, dataSync := setupBlockManager(t, ctx, workDir, dbm, blockTime, noopSigner)
+			var lastCapturedDataPayload *types.Data
+			var lastCapturedHeaderPayload *types.SignedHeader
+			manager.dataBroadcaster = capturingTailBroadcaster[*types.Data](spec.dataConsumerDelay, &lastCapturedDataPayload, dataSync)
+			manager.headerBroadcaster = capturingTailBroadcaster[*types.SignedHeader](spec.headerConsumerDelay, &lastCapturedHeaderPayload, headerSync)
+
+			blockTime := manager.config.Node.BlockTime.Duration
+			aggCtx, aggCancel := context.WithCancel(ctx)
+			errChan := make(chan error, 1)
+			var wg sync.WaitGroup
+			wg.Add(1)
+			go func() {
+				manager.AggregationLoop(aggCtx, errChan)
+				wg.Done()
+			}()
+
+			// wait for messages to pile up
+			select {
+			case err := <-errChan:
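+				// a genuine error from the aggregation loop aborts the spec immediately;
+				// otherwise let a few block intervals elapse so payloads queue up behind
+				// the (possibly slow) broadcasters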
+				require.NoError(t, err)
+			case <-time.After(spec.dataConsumerDelay + spec.headerConsumerDelay + 3*blockTime):
+			}
+			aggCancel()
+			wg.Wait() // await aggregation loop to finish
+			t.Log("shutting down block manager")
+			require.NoError(t, dataSync.Stop(ctx))
+			require.NoError(t, headerSync.Stop(ctx))
+			cancel()
+			require.NotNil(t, lastCapturedHeaderPayload)
+			require.NotNil(t, lastCapturedDataPayload)
+
+			t.Log("restart with new block manager")
+			ctx, cancel = context.WithCancel(t.Context())
+			manager, headerSync, dataSync = setupBlockManager(t, ctx, workDir, dbm, blockTime, noopSigner)
+
+			var firstCapturedDataPayload *types.Data
+			var firstCapturedHeaderPayload *types.SignedHeader
+			manager.dataBroadcaster = capturingHeadBroadcaster[*types.Data](0, &firstCapturedDataPayload, dataSync)
+			manager.headerBroadcaster = capturingHeadBroadcaster[*types.SignedHeader](0, &firstCapturedHeaderPayload, headerSync)
+			go manager.AggregationLoop(ctx, errChan)
+			select {
+			case err := <-errChan:
+				require.NoError(t, err)
+			case <-time.After(spec.dataConsumerDelay + spec.headerConsumerDelay + 3*blockTime):
+			}
+			cancel()
+			require.NotNil(t, firstCapturedHeaderPayload)
+			assert.InDelta(t, lastCapturedHeaderPayload.Height(), firstCapturedHeaderPayload.Height(), 1)
+			require.NotNil(t, firstCapturedDataPayload)
+			assert.InDelta(t, lastCapturedDataPayload.Height(), firstCapturedDataPayload.Height(), 1)
+		})
+	}
+}
+
+func capturingTailBroadcaster[T interface{ Height() uint64 }](waitDuration time.Duration, target *T, next ...broadcaster[T]) broadcaster[T] {
+	var lastHeight uint64
+	return broadcasterFn[T](func(ctx context.Context, payload T) error {
+		if payload.Height() <= lastHeight {
+			panic(fmt.Sprintf("non-monotonic height: got %d, last seen %d", payload.Height(), lastHeight))
+		}
+
+		time.Sleep(waitDuration)
+		lastHeight = payload.Height()
+		*target = payload
+		var err error
+		for _, n := range next {
+			err = errors.Join(err, n.WriteToStoreAndBroadcast(ctx, payload))
+		}
+
+		return err
+	})
+}
+
+func capturingHeadBroadcaster[T interface{ Height() uint64 }](waitDuration time.Duration, target *T, next ...broadcaster[T]) broadcaster[T] {
+	var once sync.Once
+	return broadcasterFn[T](func(ctx context.Context, payload T) error {
+		once.Do(func() {
+			*target = payload
+		})
+		var err error
+		for _, n := range next {
+			err = errors.Join(err, n.WriteToStoreAndBroadcast(ctx, payload))
+		}
+		time.Sleep(waitDuration)
+		return err
+	})
+}
+
+type broadcasterFn[T any] func(ctx context.Context, payload T) error
+
+func (b broadcasterFn[T]) WriteToStoreAndBroadcast(ctx context.Context, payload T) error {
+	return b(ctx, payload)
+}
+
+func setupBlockManager(t *testing.T, ctx context.Context, workDir string, mainKV ds.Batching, blockTime time.Duration, signer signer.Signer) (*Manager, *rollkitSync.HeaderSyncService, *rollkitSync.DataSyncService) {
+	t.Helper()
+	nodeConfig := config.DefaultConfig
+	nodeConfig.Node.BlockTime = config.DurationWrapper{Duration: blockTime}
+	nodeConfig.RootDir = workDir
+	nodeKey, err := key.LoadOrGenNodeKey(filepath.Dir(nodeConfig.ConfigPath()))
+	require.NoError(t, err)
+
+	proposerAddr, err := signer.GetAddress()
+	require.NoError(t, err)
+	genesisDoc := genesispkg.Genesis{
+		ChainID:            "test-chain-id",
+		GenesisDAStartTime: time.Now(),
+		InitialHeight:      1,
+		ProposerAddress:    proposerAddr,
+	}
+
+	logger := log.NewTestLogger(t)
+	p2pClient, err := p2p.NewClient(nodeConfig, nodeKey, mainKV, logger, p2p.NopMetrics())
+	require.NoError(t, err)
+
+	// Start p2p client before creating sync service
+	err = p2pClient.Start(ctx)
+	require.NoError(t, err)
+
+	// keep rollkit data under its own key prefix, as the full node does
+	const RollkitPrefix = "0"
+	mainKV = ktds.Wrap(mainKV, ktds.PrefixTransform{Prefix: ds.NewKey(RollkitPrefix)})
+	headerSyncService, err := rollkitSync.NewHeaderSyncService(mainKV, nodeConfig, genesisDoc, p2pClient, logger.With("module", "HeaderSyncService"))
+	require.NoError(t, err)
+	require.NoError(t, headerSyncService.Start(ctx))
+	dataSyncService, err := rollkitSync.NewDataSyncService(mainKV, nodeConfig, genesisDoc, p2pClient, logger.With("module", "DataSyncService"))
+	require.NoError(t, err)
+	require.NoError(t, dataSyncService.Start(ctx))
+
+	result, err := NewManager(
+		ctx,
+		signer,
+		nodeConfig,
+		genesisDoc,
+		store.New(mainKV),
+		&mockExecutor{},
+		coresequencer.NewDummySequencer(),
+		nil,
+		logger.With("module", "BlockManager"),
+		headerSyncService.Store(),
+		dataSyncService.Store(),
+		nil,
+		nil,
+		NopMetrics(),
+		1.,
+		1.,
+	)
+	require.NoError(t, err)
+	return result, headerSyncService, dataSyncService
+}
+
+type mockExecutor struct{}
+
+func (m mockExecutor) InitChain(ctx context.Context, genesisTime time.Time, initialHeight uint64, chainID string) (stateRoot []byte, maxBytes uint64, err error) {
+	return bytesN(32), 10_000, nil
+}
+
+func (m mockExecutor) GetTxs(ctx context.Context) ([][]byte, error) {
+	panic("implement me")
+}
+
+func (m mockExecutor) ExecuteTxs(ctx context.Context, txs [][]byte, blockHeight uint64, timestamp time.Time, prevStateRoot []byte) (updatedStateRoot []byte, maxBytes uint64, err error) {
+	return bytesN(32), 10_000, nil
+}
+
+func (m mockExecutor) SetFinal(ctx context.Context, blockHeight uint64) error {
+	return nil
+}
+
+var rnd = rand.New(rand.NewSource(1)) //nolint:gosec // test code only
+
+func bytesN(n int) []byte {
+	data := make([]byte, n)
+	_, _ = rnd.Read(data)
+	return data
+}
diff --git a/block/publish_block_test.go b/block/publish_block_test.go
index 84480dabba..67512c38eb 100644
--- a/block/publish_block_test.go
+++ b/block/publish_block_test.go
@@ -33,16 +33,7 @@ func setupManagerForPublishBlockTest(
 	initialHeight uint64,
 	lastSubmittedHeight uint64,
 	logBuffer *bytes.Buffer,
-) (
-	*Manager,
-	*mocks.Store,
-	*mocks.Executor,
-	*mocks.Sequencer,
-	signer.Signer,
-	chan *types.SignedHeader,
-	chan *types.Data,
-	context.CancelFunc,
-) {
+) (*Manager, *mocks.Store, *mocks.Executor, *mocks.Sequencer, signer.Signer, context.CancelFunc) {
 	require := require.New(t)
 
 	mockStore := mocks.NewStore(t)
@@ -60,8 +51,6 @@ func setupManagerForPublishBlockTest(
 	cfg.Node.BlockTime.Duration = 1 * time.Second
 	genesis := genesispkg.NewGenesis("testchain", initialHeight, time.Now(), proposerAddr)
 
-	headerCh := make(chan *types.SignedHeader, 1)
-	dataCh := make(chan *types.Data, 1)
 	_, cancel := context.WithCancel(context.Background())
 	logger := log.NewLogger(
 		logBuffer,
@@ -74,18 +63,21 @@ func setupManagerForPublishBlockTest(
 	var headerStore *goheaderstore.Store[*types.SignedHeader]
 	var dataStore *goheaderstore.Store[*types.Data]
 
-	// Manager initialization (simplified, add fields as needed by tests)
 	manager := &Manager{
-		store:     mockStore,
-		exec:      mockExec,
-		sequencer: mockSeq,
-		signer:    testSigner,
-		config:    cfg,
-		genesis:   genesis,
-		logger:    logger,
-		HeaderCh:  headerCh,
-		DataCh:    dataCh,
+		store:     mockStore,
+		exec:      mockExec,
+		sequencer: mockSeq,
+		signer:    testSigner,
+		config:    cfg,
+		genesis:   genesis,
+		logger:    logger,
+		// no-op broadcasters by default; tests that care about published payloads override them
+		headerBroadcaster: broadcasterFn[*types.SignedHeader](func(ctx context.Context, payload *types.SignedHeader) error {
+			return nil
+		}),
+		dataBroadcaster: broadcasterFn[*types.Data](func(ctx context.Context, payload *types.Data) error {
+			return nil
+		}),
 		headerStore: headerStore,
 		daHeight:    &atomic.Uint64{},
 		dataStore:   dataStore,
@@ -112,7 +104,7 @@ func setupManagerForPublishBlockTest(
 		manager.lastState.LastBlockHeight = 0
 	}
 
-	return manager, mockStore, mockExec, mockSeq, testSigner, headerCh, dataCh, cancel
+	return manager, mockStore, mockExec, mockSeq, testSigner, cancel
 }
 
 // TestPublishBlockInternal_MaxPendingHeadersReached verifies that publishBlockInternal returns an error if the maximum number of pending headers is reached.
@@ -125,7 +117,7 @@ func TestPublishBlockInternal_MaxPendingHeadersReached(t *testing.T) {
 	maxPending := uint64(5)
 	logBuffer := new(bytes.Buffer)
 
-	manager, mockStore, mockExec, mockSeq, _, _, _, cancel := setupManagerForPublishBlockTest(t, currentHeight+1, lastSubmitted, logBuffer)
+	manager, mockStore, mockExec, mockSeq, _, cancel := setupManagerForPublishBlockTest(t, currentHeight+1, lastSubmitted, logBuffer)
 	defer cancel()
 
 	manager.config.Node.MaxPendingHeaders = maxPending
@@ -269,9 +261,13 @@ func Test_publishBlock_EmptyBatch(t *testing.T) {
 		},
 		headerCache: cache.NewCache[types.SignedHeader](),
 		dataCache:   cache.NewCache[types.Data](),
-		HeaderCh:    make(chan *types.SignedHeader, 1),
-		DataCh:      make(chan *types.Data, 1),
-		daHeight:    &daH,
+		headerBroadcaster: broadcasterFn[*types.SignedHeader](func(ctx context.Context, payload *types.SignedHeader) error {
+			return nil
+		}),
+		dataBroadcaster: broadcasterFn[*types.Data](func(ctx context.Context, payload *types.Data) error {
+			return nil
+		}),
+		daHeight: &daH,
 	}
 	m.publishBlock = m.publishBlockInternal
 
@@ -345,7 +341,7 @@ func Test_publishBlock_Success(t *testing.T) {
 	newHeight := initialHeight + 1
 	chainID := "testchain"
 
-	manager, mockStore, mockExec, mockSeq, _, headerCh, dataCh, _ := setupManagerForPublishBlockTest(t, initialHeight, 0, new(bytes.Buffer))
+	manager, mockStore, mockExec, mockSeq, _, _ := setupManagerForPublishBlockTest(t, initialHeight, 0, new(bytes.Buffer))
 	manager.lastState.LastBlockHeight = initialHeight
 
 	mockStore.On("Height", t.Context()).Return(initialHeight, nil).Once()
@@ -361,6 +357,25 @@ func Test_publishBlock_Success(t *testing.T) {
 	mockStore.On("UpdateState", t.Context(), mock.AnythingOfType("types.State")).Return(nil).Once()
 	mockStore.On("SetMetadata", t.Context(), LastBatchDataKey, mock.AnythingOfType("[]uint8")).Return(nil).Once()
 
+	headerCh := make(chan *types.SignedHeader, 1)
+	manager.headerBroadcaster = broadcasterFn[*types.SignedHeader](func(ctx context.Context, payload *types.SignedHeader) error {
+		select {
+		case headerCh <- payload:
+			return nil
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	})
+	dataCh := make(chan *types.Data, 1)
+	manager.dataBroadcaster = broadcasterFn[*types.Data](func(ctx context.Context, payload *types.Data) error {
+		select {
+		case dataCh <- payload:
+			return nil
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+	})
+
 	// --- Mock Executor ---
 	sampleTxs := [][]byte{[]byte("tx1"), []byte("tx2")}
 	// No longer mocking GetTxs since it's handled by reaper.go
@@ -383,7 +398,6 @@ func Test_publishBlock_Success(t *testing.T) {
 	mockSeq.On("GetNextBatch", t.Context(), batchReqMatcher).Return(batchResponse, nil).Once()
 	err := manager.publishBlock(t.Context())
 	require.NoError(err, "publishBlock should succeed")
-
 	select {
 	case publishedHeader := <-headerCh:
 		assert.Equal(t, newHeight, publishedHeader.Height(), "Published header height mismatch")
diff --git a/block/store_test.go b/block/store_test.go
index ee1c07fa1d..cd01e6770d 100644
--- a/block/store_test.go
+++ b/block/store_test.go
@@ -12,16 +12,16 @@ import (
 
 	"cosmossdk.io/log"
 	ds "github.com/ipfs/go-datastore"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+
 	"github.com/rollkit/rollkit/pkg/config"
 	"github.com/rollkit/rollkit/pkg/signer/noop"
-	// Use existing store mock if available, or define one
 	mocksStore "github.com/rollkit/rollkit/test/mocks"
 	extmocks "github.com/rollkit/rollkit/test/mocks/external"
 	"github.com/rollkit/rollkit/types"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/mock"
-	"github.com/stretchr/testify/require"
 )
 
 func setupManagerForStoreRetrieveTest(t *testing.T) (
diff --git a/go.mod b/go.mod
index 6dc32963b7..23de5397a2 100644
--- a/go.mod
+++ b/go.mod
@@ -16,6 +16,7 @@ require (
 	github.com/goccy/go-yaml v1.17.1
 	github.com/ipfs/go-datastore v0.8.2
 	github.com/ipfs/go-ds-badger4 v0.1.8
+	github.com/ipfs/go-log/v2 v2.5.1
 	github.com/libp2p/go-libp2p v0.41.1
 	github.com/libp2p/go-libp2p-kad-dht v0.29.1
 	github.com/libp2p/go-libp2p-pubsub v0.13.1
@@ -30,6 +31,7 @@ require (
 	github.com/stretchr/testify v1.10.0
 	golang.org/x/crypto v0.37.0
 	golang.org/x/net v0.38.0
+	golang.org/x/sync v0.13.0
 	google.golang.org/protobuf v1.36.6
 )
 
@@ -76,7 +78,6 @@ require (
 	github.com/ipfs/boxo v0.27.4 // indirect
 	github.com/ipfs/go-cid v0.5.0 // indirect
 	github.com/ipfs/go-log v1.0.5 // indirect
-	github.com/ipfs/go-log/v2 v2.5.1 // indirect
 	github.com/ipld/go-ipld-prime v0.21.0 // indirect
 	github.com/jackpal/go-nat-pmp v1.0.2 // indirect
 	github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
@@ -171,7 +172,6 @@ require (
 	golang.org/x/arch v0.15.0 // indirect
 	golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
 	golang.org/x/mod v0.24.0 // indirect
-	golang.org/x/sync v0.13.0 // indirect
 	golang.org/x/sys v0.32.0 // indirect
 	golang.org/x/text v0.24.0 // indirect
 	golang.org/x/time v0.9.0 // indirect
diff --git a/node/full.go b/node/full.go
index dce321dd36..062e7a6599 100644
--- a/node/full.go
+++ b/node/full.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/http/pprof"
+	"sync"
 	"time"
 
 	"cosmossdk.io/log"
@@ -27,7 +28,7 @@ import (
 	"github.com/rollkit/rollkit/pkg/service"
 	"github.com/rollkit/rollkit/pkg/signer"
 	"github.com/rollkit/rollkit/pkg/store"
-	"github.com/rollkit/rollkit/pkg/sync"
+	rollkitsync "github.com/rollkit/rollkit/pkg/sync"
 )
 
 // prefixes used in KV store to separate rollkit data from execution environment data (if the same data base is reused)
@@ -55,8 +56,8 @@ type FullNode struct {
 	da        coreda.DA
 	p2pClient *p2p.Client
 
-	hSyncService *sync.HeaderSyncService
-	dSyncService *sync.DataSyncService
+	hSyncService *rollkitsync.HeaderSyncService
+	dSyncService *rollkitsync.DataSyncService
 
 	Store        store.Store
 	blockManager *block.Manager
 	reaper       *block.Reaper
@@ -93,7 +94,7 @@ func newFullNode(
 		return nil, err
 	}
 
-	store := store.New(mainKV)
+	rktStore := store.New(mainKV)
 
 	blockManager, err := initBlockManager(
 		ctx,
@@ -101,7 +102,7 @@ func newFullNode(
 		exec,
 		nodeConfig,
 		genesis,
-		store,
+		rktStore,
 		sequencer,
 		da,
 		logger,
@@ -135,7 +136,7 @@ func newFullNode(
 		blockManager: blockManager,
 		reaper:       reaper,
 		da:           da,
-		Store:        store,
+		Store:        rktStore,
 		hSyncService: headerSyncService,
 		dSyncService: dataSyncService,
 	}
@@ -151,8 +152,8 @@ func initHeaderSyncService(
 	genesis genesispkg.Genesis,
 	p2pClient *p2p.Client,
 	logger log.Logger,
-) (*sync.HeaderSyncService, error) {
-	headerSyncService, err := sync.NewHeaderSyncService(mainKV, nodeConfig, genesis, p2pClient, logger.With("module", "HeaderSyncService"))
+) (*rollkitsync.HeaderSyncService, error) {
+	headerSyncService, err := rollkitsync.NewHeaderSyncService(mainKV, nodeConfig, genesis, p2pClient, logger.With("module", "HeaderSyncService"))
 	if err != nil {
 		return nil, fmt.Errorf("error while initializing HeaderSyncService: %w", err)
 	}
@@ -165,8 +166,8 @@ func initDataSyncService(
 	genesis genesispkg.Genesis,
 	p2pClient *p2p.Client,
 	logger log.Logger,
-) (*sync.DataSyncService, error) {
-	dataSyncService, err := sync.NewDataSyncService(mainKV, nodeConfig, genesis, p2pClient, logger.With("module", "DataSyncService"))
+) (*rollkitsync.DataSyncService, error) {
+	dataSyncService, err := rollkitsync.NewDataSyncService(mainKV, nodeConfig, genesis, p2pClient, logger.With("module", "DataSyncService"))
 	if err != nil {
 		return nil, fmt.Errorf("error while initializing DataSyncService: %w", err)
 	}
@@ -191,8 +192,8 @@ func initBlockManager(
 	sequencer coresequencer.Sequencer,
 	da coreda.DA,
 	logger log.Logger,
-	headerSyncService *sync.HeaderSyncService,
-	dataSyncService *sync.DataSyncService,
+	headerSyncService *rollkitsync.HeaderSyncService,
+	dataSyncService *rollkitsync.DataSyncService,
 	seqMetrics *block.Metrics,
 	gasPrice float64,
 	gasMultiplier float64,
@@ -211,6 +212,8 @@ func initBlockManager(
 		logger.With("module", "BlockManager"),
 		headerSyncService.Store(),
 		dataSyncService.Store(),
+		headerSyncService,
+		dataSyncService,
 		seqMetrics,
 		gasPrice,
 		gasMultiplier,
@@ -242,38 +245,6 @@ func (n *FullNode) initGenesisChunks() error {
 	return nil
 }
 
-func (n *FullNode) headerPublishLoop(ctx context.Context) {
-	for {
-		select {
-		case signedHeader := <-n.blockManager.HeaderCh:
-			err := n.hSyncService.WriteToStoreAndBroadcast(ctx, signedHeader)
-			if err != nil {
-				// failed to init or start headerstore
-				n.Logger.Error(err.Error())
-				return
-			}
-		case <-ctx.Done():
-			return
-		}
-	}
-}
-
-func (n *FullNode) dataPublishLoop(ctx context.Context) {
-	for {
-		select {
-		case data := <-n.blockManager.DataCh:
-			err := n.dSyncService.WriteToStoreAndBroadcast(ctx, data)
-			if err != nil {
-				// failed to init or start blockstore
-				n.Logger.Error(err.Error())
-				return
-			}
-		case <-ctx.Done():
-			return
-		}
-	}
-}
-
 // startInstrumentationServer starts HTTP servers for instrumentation (Prometheus metrics and pprof).
 // Returns the primary server (Prometheus if enabled, otherwise pprof) and optionally a secondary server.
 func (n *FullNode) startInstrumentationServer() (*http.Server, *http.Server) {
@@ -298,8 +269,7 @@ func (n *FullNode) startInstrumentationServer() (*http.Server, *http.Server) {
 	}
 
 	go func() {
-		if err := prometheusServer.ListenAndServe(); err != http.ErrServerClosed {
-			// Error starting or closing listener:
+		if err := prometheusServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			n.Logger.Error("Prometheus HTTP server ListenAndServe", "err", err)
 		}
 	}()
@@ -332,8 +302,7 @@ func (n *FullNode) startInstrumentationServer() (*http.Server, *http.Server) {
 	}
 
 	go func() {
-		if err := pprofServer.ListenAndServe(); err != http.ErrServerClosed {
-			// Error starting or closing listener:
+		if err := pprofServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			n.Logger.Error("pprof HTTP server ListenAndServe", "err", err)
 		}
 	}()
@@ -375,11 +344,10 @@ func (n *FullNode) Run(parentCtx context.Context) error {
 	}
 
 	go func() {
-		err := n.rpcServer.ListenAndServe()
-		if err != nil && err != http.ErrServerClosed {
+		n.Logger.Info("starting RPC server", "addr", n.nodeConfig.RPC.Address)
+		if err := n.rpcServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			n.Logger.Error("RPC server error", "err", err)
 		}
-		n.Logger.Info("started RPC server", "addr", n.nodeConfig.RPC.Address)
 	}()
 
 	n.Logger.Info("starting P2P client")
@@ -399,22 +367,28 @@ func (n *FullNode) Run(parentCtx context.Context) error {
 	// only the first error is propagated
 	// any error is an issue, so blocking is not a problem
 	errCh := make(chan error, 1)
-
+	// prepare to join the go routines later
+	var wg sync.WaitGroup
+	spawnWorker := func(f func()) {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			f()
+		}()
+	}
 	if n.nodeConfig.Node.Aggregator {
 		n.Logger.Info("working in aggregator mode", "block time", n.nodeConfig.Node.BlockTime)
-		go n.blockManager.AggregationLoop(ctx, errCh)
-		go n.reaper.Start(ctx)
-		go n.blockManager.HeaderSubmissionLoop(ctx)
-		go n.blockManager.BatchSubmissionLoop(ctx)
-		go n.headerPublishLoop(ctx)
-		go n.dataPublishLoop(ctx)
-		go n.blockManager.DAIncluderLoop(ctx, errCh)
+		spawnWorker(func() { n.blockManager.AggregationLoop(ctx, errCh) })
+		spawnWorker(func() { n.reaper.Start(ctx) })
+		spawnWorker(func() { n.blockManager.HeaderSubmissionLoop(ctx) })
+		spawnWorker(func() { n.blockManager.BatchSubmissionLoop(ctx) })
+		spawnWorker(func() { n.blockManager.DAIncluderLoop(ctx, errCh) })
 	} else {
-		go n.blockManager.RetrieveLoop(ctx)
-		go n.blockManager.HeaderStoreRetrieveLoop(ctx)
-		go n.blockManager.DataStoreRetrieveLoop(ctx)
-		go n.blockManager.SyncLoop(ctx, errCh)
-		go n.blockManager.DAIncluderLoop(ctx, errCh)
+		spawnWorker(func() { n.blockManager.RetrieveLoop(ctx) })
+		spawnWorker(func() { n.blockManager.HeaderStoreRetrieveLoop(ctx) })
+		spawnWorker(func() { n.blockManager.DataStoreRetrieveLoop(ctx) })
+		spawnWorker(func() { n.blockManager.SyncLoop(ctx, errCh) })
+		spawnWorker(func() { n.blockManager.DAIncluderLoop(ctx, errCh) })
 	}
 
 	select {
@@ -431,9 +405,12 @@ func (n *FullNode) Run(parentCtx context.Context) error {
 
 	// Perform cleanup
 	n.Logger.Info("halting full node and its sub services...")
+	// wait for all worker Go routines to finish so that we have
+	// no in-flight tasks while shutting down
+	wg.Wait()
 
 	// Use a timeout context to ensure shutdown doesn't hang
-	shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	shutdownCtx, cancel := context.WithTimeout(context.Background(), 9*time.Second)
 	defer cancel()
 
 	var multiErr error // Use a multierror variable
diff --git a/pkg/sync/sync_service.go b/pkg/sync/sync_service.go
index df43cd4ebf..d46fc539ec 100644
--- a/pkg/sync/sync_service.go
+++ b/pkg/sync/sync_service.go
@@ -44,12 +44,13 @@ type SyncService[H header.Header[H]] struct {
 
 	p2p *p2p.Client
 
-	ex           *goheaderp2p.Exchange[H]
-	sub          *goheaderp2p.Subscriber[H]
-	p2pServer    *goheaderp2p.ExchangeServer[H]
-	store        *goheaderstore.Store[H]
-	syncer       *goheadersync.Syncer[H]
-	syncerStatus *SyncerStatus
+	ex                *goheaderp2p.Exchange[H]
+	sub               *goheaderp2p.Subscriber[H]
+	p2pServer         *goheaderp2p.ExchangeServer[H]
+	store             *goheaderstore.Store[H]
+	syncer            *goheadersync.Syncer[H]
+	syncerStatus      *SyncerStatus
+	topicSubscription header.Subscription[H]
 }
 
 // DataSyncService is the P2P Sync Service for blocks.
@@ -204,7 +205,7 @@ func (syncService *SyncService[H]) setupP2P(ctx context.Context) ([]peer.ID, err
 	if err := syncService.sub.Start(ctx); err != nil {
 		return nil, fmt.Errorf("error while starting subscriber: %w", err)
 	}
-	if _, err := syncService.sub.Subscribe(); err != nil {
+	if syncService.topicSubscription, err = syncService.sub.Subscribe(); err != nil {
 		return nil, fmt.Errorf("error while subscribing: %w", err)
 	}
 	if err := syncService.store.Start(ctx); err != nil {
@@ -294,6 +295,11 @@ func (syncService *SyncService[H]) setFirstAndStart(ctx context.Context, peerIDs
 //
 // `store` is closed last because it's used by other services.
 func (syncService *SyncService[H]) Stop(ctx context.Context) error {
+	// unsubscribe from the topic first so that sub.Stop() does not fail;
+	// the subscription is nil when setupP2P never completed
+	if syncService.topicSubscription != nil {
+		syncService.topicSubscription.Cancel()
+	}
 	err := errors.Join(
 		syncService.p2pServer.Stop(ctx),
 		syncService.ex.Stop(ctx),
diff --git a/pkg/sync/sync_service_test.go b/pkg/sync/sync_service_test.go
new file mode 100644
index 0000000000..3dcadaba18
--- /dev/null
+++ b/pkg/sync/sync_service_test.go
@@ -0,0 +1,129 @@
+package sync
+
+import (
+	"context"
+	cryptoRand "crypto/rand"
+	"math/rand"
+	"path/filepath"
+	"testing"
+	"time"
+
+	sdklog "cosmossdk.io/log"
+	"github.com/ipfs/go-datastore"
+	"github.com/ipfs/go-datastore/sync"
+	logging "github.com/ipfs/go-log/v2"
+	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/stretchr/testify/require"
+
+	"github.com/rollkit/rollkit/pkg/config"
+	genesispkg "github.com/rollkit/rollkit/pkg/genesis"
+	"github.com/rollkit/rollkit/pkg/p2p"
+	"github.com/rollkit/rollkit/pkg/p2p/key"
+	"github.com/rollkit/rollkit/pkg/signer"
+	"github.com/rollkit/rollkit/pkg/signer/noop"
+	"github.com/rollkit/rollkit/types"
+)
+
+func TestHeaderSyncServiceRestart(t *testing.T) {
+	logging.SetDebugLogging()
+	mainKV := sync.MutexWrap(datastore.NewMapDatastore())
+	pk, _, err := crypto.GenerateEd25519Key(cryptoRand.Reader)
+	require.NoError(t, err)
+	noopSigner, err := noop.NewNoopSigner(pk)
+	require.NoError(t, err)
+	rnd := rand.New(rand.NewSource(1)) //nolint:gosec // test code only
+
+	proposerAddr := []byte("test")
+	genesisDoc := genesispkg.Genesis{
+		ChainID:            "test-chain-id",
+		GenesisDAStartTime: time.Now(),
+		InitialHeight:      1,
+		ProposerAddress:    proposerAddr,
+	}
+	conf := config.DefaultConfig
+	conf.RootDir = t.TempDir()
+	nodeKey, err := key.LoadOrGenNodeKey(filepath.Dir(conf.ConfigPath()))
+	require.NoError(t, err)
+	logger := sdklog.NewTestLogger(t)
+	p2pClient, err := p2p.NewClient(conf, nodeKey, mainKV, logger, p2p.NopMetrics())
+	require.NoError(t, err)
+
+	// Start p2p client before creating sync service
+	ctx, cancel := context.WithCancel(t.Context())
+	defer cancel()
+	err = p2pClient.Start(ctx)
+	require.NoError(t, err)
+
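+	// the service is deliberately built on the shared datastore: the restart
+	// below recreates it over the same store and must resume from the
+	// headers persisted by the first instance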
+	svc, err := NewHeaderSyncService(mainKV, conf, genesisDoc, p2pClient, logger)
+	require.NoError(t, err)
+	err = svc.Start(ctx)
+	require.NoError(t, err)
+
+	// broadcast genesis block
+	headerConfig := types.HeaderConfig{
+		Height:   genesisDoc.InitialHeight,
+		DataHash: bytesN(rnd, 32),
+		AppHash:  bytesN(rnd, 32),
+		Signer:   noopSigner,
+	}
+	signedHeader, err := types.GetRandomSignedHeaderCustom(&headerConfig, genesisDoc.ChainID)
+	require.NoError(t, err)
+	require.NoError(t, signedHeader.Validate())
+	require.NoError(t, svc.WriteToStoreAndBroadcast(ctx, signedHeader))
+
+	// broadcast example blocks up to height 9
+	for i := genesisDoc.InitialHeight + 1; i < 10; i++ {
+		signedHeader = nextHeader(t, signedHeader, genesisDoc.ChainID, noopSigner)
+		t.Logf("signed header: %d", i)
+		require.NoError(t, svc.WriteToStoreAndBroadcast(ctx, signedHeader))
+	}
+
+	// then stop and restart service
+	_ = p2pClient.Close()
+	_ = svc.Stop(ctx)
+	cancel()
+
+	p2pClient, err = p2p.NewClient(conf, nodeKey, mainKV, logger, p2p.NopMetrics())
+	require.NoError(t, err)
+
+	// Start p2p client again
+	ctx, cancel = context.WithCancel(t.Context())
+	defer cancel()
+	err = p2pClient.Start(ctx)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = p2pClient.Close() })
+
+	svc, err = NewHeaderSyncService(mainKV, conf, genesisDoc, p2pClient, logger)
+	require.NoError(t, err)
+	err = svc.Start(ctx)
+	require.NoError(t, err)
+	t.Cleanup(func() { _ = svc.Stop(context.Background()) })
+	// done with stop and restart service
+
+	// broadcast another 10 example blocks
+	lastHeight := signedHeader.Height()
+	for i := lastHeight + 1; i <= lastHeight+10; i++ {
+		signedHeader = nextHeader(t, signedHeader, genesisDoc.ChainID, noopSigner)
+		t.Logf("signed header: %d", i)
+		require.NoError(t, svc.WriteToStoreAndBroadcast(ctx, signedHeader))
+	}
+	cancel()
+}
+
+func nextHeader(t *testing.T, previousHeader *types.SignedHeader, chainID string, noopSigner signer.Signer) *types.SignedHeader {
+	newSignedHeader := &types.SignedHeader{
+		Header: types.GetRandomNextHeader(previousHeader.Header, chainID),
+		Signer: previousHeader.Signer,
+	}
+	b, err := newSignedHeader.Header.MarshalBinary()
+	require.NoError(t, err)
+	signature, err := noopSigner.Sign(b)
+	require.NoError(t, err)
+	newSignedHeader.Signature = signature
+	require.NoError(t, newSignedHeader.Validate())
+	return newSignedHeader
+}
+
+func bytesN(r *rand.Rand, n int) []byte {
+	data := make([]byte, n)
+	_, _ = r.Read(data)
+	return data
+}
diff --git a/test/e2e/base_test.go b/test/e2e/base_test.go
index c63de94d85..eec159dd7e 100644
--- a/test/e2e/base_test.go
+++ b/test/e2e/base_test.go
@@ -41,7 +41,7 @@ func TestBasic(t *testing.T) {
 
 	// start local da
 	localDABinary := filepath.Join(filepath.Dir(binaryPath), "local-da")
-	sut.StartNode(localDABinary)
+	sut.ExecCmd(localDABinary)
 
 	// Wait a moment for the local DA to initialize
 	time.Sleep(500 * time.Millisecond)
@@ -57,7 +57,7 @@ func TestBasic(t *testing.T) {
 	require.NoError(t, err, "failed to init aggregator", output)
 
 	// start aggregator
-	sut.StartNode(binaryPath,
+	sut.ExecCmd(binaryPath,
 		"start",
 		"--home="+node1Home,
 		"--chain_id=testing",
@@ -86,7 +86,7 @@ func TestBasic(t *testing.T) {
 	// Start the full node
 	node2RPC := "127.0.0.1:7332"
 	node2P2P := "/ip4/0.0.0.0/tcp/7676"
-	sut.StartNode(
+	sut.ExecCmd(
 		binaryPath,
 		"start",
 		"--home="+node2Home,
@@ -124,14 +124,7 @@ func TestBasic(t *testing.T) {
 	time.Sleep(2 * time.Second)
 
 	// verify a block has been produced
-	c := nodeclient.NewClient("http://127.0.0.1:7331")
-	require.NoError(t, err)
-
-	ctx, done = context.WithTimeout(context.Background(), time.Second)
-	defer done()
-	state, err := c.GetState(ctx)
-	require.NoError(t, err)
-	require.Greater(t, state.LastBlockHeight, uint64(1))
+	sut.AwaitNBlocks(t, 1, "http://127.0.0.1:7331", 1*time.Second)
 }
 
 func TestNodeRestartPersistence(t *testing.T) {
@@ -145,7 +138,7 @@ func TestNodeRestartPersistence(t *testing.T) {
 
 	// Start local DA if needed
 	localDABinary := filepath.Join(filepath.Dir(binaryPath), "local-da")
-	sut.StartNode(localDABinary)
+	sut.ExecCmd(localDABinary)
 	time.Sleep(500 * time.Millisecond)
 
 	// Init node
@@ -159,7 +152,7 @@ func TestNodeRestartPersistence(t *testing.T) {
 	require.NoError(t, err, "failed to init node", output)
 
 	// Start node
-	sut.StartNode(binaryPath,
+	sut.ExecCmd(binaryPath,
 		"start",
 		"--home="+nodeHome,
 		"--chain_id=testing",
@@ -173,24 +166,27 @@ func TestNodeRestartPersistence(t *testing.T) {
 	t.Log("Node started and is up.")
 
 	c := nodeclient.NewClient("http://127.0.0.1:7331")
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
-	defer cancel()
+	ctx, cancel := context.WithTimeout(t.Context(), time.Second)
+	t.Cleanup(cancel)
+
 	state, err := c.GetState(ctx)
 	require.NoError(t, err)
 	require.Greater(t, state.LastBlockHeight, uint64(1))
 
-	// Wait for a block to be produced
-	time.Sleep(1 * time.Second)
+	// Wait for some blocks to be produced
+	sut.AwaitNBlocks(t, 2, "http://127.0.0.1:7331", 2*time.Second)
 
-	// Shutdown node
-	sut.Shutdown()
+	// Shutdown all nodes but keep local da running
+	sut.ShutdownByCmd(binaryPath)
 	t.Log("Node stopped.")
 
 	// Wait a moment to ensure shutdown
-	time.Sleep(500 * time.Millisecond)
+	require.Eventually(t, func() bool {
+		return !sut.HasProcess(binaryPath)
+	}, 500*time.Millisecond, 10*time.Millisecond)
 
 	// Restart node
-	sut.StartNode(binaryPath,
+	sut.ExecCmd(binaryPath,
 		"start",
 		"--home="+nodeHome,
 		"--chain_id=testing",
@@ -202,13 +198,6 @@ func TestNodeRestartPersistence(t *testing.T) {
 	)
 	sut.AwaitNodeUp(t, "http://127.0.0.1:7331", 2*time.Second)
 	t.Log("Node restarted and is up.")
-
-	// Wait for a block to be produced after restart
-	time.Sleep(1 * time.Second)
-
-	ctx2, cancel2 := context.WithTimeout(context.Background(), time.Second)
-	defer cancel2()
-	state2, err := c.GetState(ctx2)
-	require.NoError(t, err)
-	require.Greater(t, state2.LastBlockHeight, state.LastBlockHeight)
+	// Wait for some blocks to be produced after restart
+	sut.AwaitNBlocks(t, 2, "http://127.0.0.1:7331", 2*time.Second)
 }
diff --git a/test/e2e/sut_helper.go b/test/e2e/sut_helper.go
index dd3afbfdbc..b5c0bd8049 100644
--- a/test/e2e/sut_helper.go
+++ b/test/e2e/sut_helper.go
@@ -6,10 +6,12 @@ import (
 	"context"
 	"fmt"
 	"io"
+	"iter"
 	"maps"
 	"os"
 	"os/exec"
 	"path/filepath"
+	"slices"
 	"strings"
 	"sync"
 	"syscall"
@@ -22,6 +24,7 @@ import (
 	"github.com/rollkit/rollkit/pkg/p2p/key"
 	"github.com/rollkit/rollkit/pkg/rpc/client"
+	pb "github.com/rollkit/rollkit/types/pb/rollkit/v1"
 )
 
 // WorkDir defines the default working directory for spawned processes.
@@ -34,19 +37,22 @@ type SystemUnderTest struct {
 	outBuff *ring.Ring
 	errBuff *ring.Ring
 
-	pidsLock sync.RWMutex
-	pids     map[int]struct{}
+	pidsLock  sync.RWMutex
+	pids      map[int]struct{}
+	cmdToPids map[string][]int
+	debug     bool
 }
 
 // NewSystemUnderTest constructor
 func NewSystemUnderTest(t *testing.T) *SystemUnderTest {
 	r := &SystemUnderTest{
-		t:       t,
-		pids:    make(map[int]struct{}),
-		outBuff: ring.New(100),
-		errBuff: ring.New(100),
+		t:         t,
+		pids:      make(map[int]struct{}),
+		cmdToPids: make(map[string][]int),
+		outBuff:   ring.New(100),
+		errBuff:   ring.New(100),
 	}
-	t.Cleanup(r.Shutdown)
+	t.Cleanup(r.ShutdownAll)
 	return r
 }
@@ -62,10 +68,11 @@ func (s *SystemUnderTest) RunCmd(cmd string, args ...string) (string, error) {
 	return string(combinedOutput), err
 }
 
-// StartNode starts a process for the given command and manages it cleanup on test end.
-func (s *SystemUnderTest) StartNode(cmd string, args ...string) {
+// ExecCmd starts a process for the given command and manages its cleanup on test end.
+func (s *SystemUnderTest) ExecCmd(cmd string, args ...string) {
+	executable := locateExecutable(cmd)
 	c := exec.Command( //nolint:gosec // used by tests only
-		locateExecutable(cmd),
+		executable,
 		args...,
 	)
 	c.Dir = WorkDir
@@ -73,7 +80,9 @@ func (s *SystemUnderTest) StartNode(cmd string, args ...string) {
 
 	err := c.Start()
 	require.NoError(s.t, err)
-
+	if s.debug {
+		s.logf("Exec cmd (pid: %d): %s %s", c.Process.Pid, executable, strings.Join(c.Args, " "))
+	}
 	// cleanup when stopped
 	s.awaitProcessCleanup(c)
 }
@@ -84,49 +93,54 @@ func (s *SystemUnderTest) AwaitNodeUp(t *testing.T, rpcAddr string, timeout time
 	t.Logf("Await node is up: %s", rpcAddr)
 	ctx, done := context.WithTimeout(context.Background(), timeout)
 	defer done()
-
-	started := make(chan struct{}, 1)
-	go func() { // query for a non empty block on status page
-		t.Logf("Checking node state: %s\n", rpcAddr)
-		for {
-			con := client.NewClient(rpcAddr)
-			if con == nil {
-				time.Sleep(100 * time.Millisecond)
-				continue
-			}
-			_, err := con.GetHealth(ctx)
-			if err != nil {
-				time.Sleep(100 * time.Millisecond)
-				continue
-			}
-			started <- struct{}{}
-			return
-		}
-	}()
-	select {
-	case <-started:
-	case <-ctx.Done():
-		if !assert.NoError(t, ctx.Err()) {
-			s.PrintBuffer()
-			s.t.FailNow()
-		}
-	case <-time.NewTimer(timeout).C:
-		s.PrintBuffer()
-		t.Fatalf("timeout waiting for node start: %s", timeout)
-	}
+	// assert (not require) inside EventuallyWithT so a failed attempt is
+	// collected and retried instead of aborting the whole test
+	require.EventuallyWithT(t, func(col *assert.CollectT) {
+		c := client.NewClient(rpcAddr)
+		if !assert.NotNil(col, c) {
+			return
+		}
+		_, err := c.GetHealth(ctx)
+		assert.NoError(col, err)
+	}, timeout, timeout/10, "node is not up")
+}
+
+// AwaitNBlocks waits until the chain behind rpcAddr has grown by n blocks.
+func (s *SystemUnderTest) AwaitNBlocks(t *testing.T, n uint64, rpcAddr string, timeout time.Duration) {
+	t.Helper()
+	ctx, done := context.WithTimeout(context.Background(), timeout)
+	defer done()
+	var c *client.Client
+	require.EventuallyWithT(t, func(col *assert.CollectT) {
+		c = client.NewClient(rpcAddr)
+		assert.NotNil(col, c)
+	}, timeout, 50*time.Millisecond, "client is not set up")
+	var baseState *pb.State
+	require.EventuallyWithT(t, func(col *assert.CollectT) {
+		state, err := c.GetState(ctx)
+		if assert.NoError(col, err) {
+			baseState = state
+		}
+	}, timeout, 50*time.Millisecond, "initial state not fetched")
+	require.EventuallyWithT(t, func(col *assert.CollectT) {
+		state, err := c.GetState(ctx)
+		if assert.NoError(col, err) {
+			assert.GreaterOrEqual(col, state.LastBlockHeight, baseState.LastBlockHeight+n)
+		}
+	}, timeout, 50*time.Millisecond, "expected block height not reached")
 }
 
 func (s *SystemUnderTest) awaitProcessCleanup(cmd *exec.Cmd) {
 	pid := cmd.Process.Pid
 	s.pidsLock.Lock()
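+	// track the pid both globally and per command so shutdowns can target a single binary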
 	s.pids[pid] = struct{}{}
+	cmdKey := filepath.Base(cmd.Path)
+	s.cmdToPids[cmdKey] = append(s.cmdToPids[cmdKey], pid)
 	s.pidsLock.Unlock()
 
 	go func() {
 		_ = cmd.Wait() // blocks until shutdown
-		s.logf("Node stopped: %d\n", pid)
+		s.logf("Process stopped, pid: %d\n", pid)
 		s.pidsLock.Lock()
+		defer s.pidsLock.Unlock()
 		delete(s.pids, pid)
-		s.pidsLock.Unlock()
+		remainingPids := slices.DeleteFunc(s.cmdToPids[cmdKey], func(p int) bool { return p == pid })
+		if len(remainingPids) == 0 {
+			delete(s.cmdToPids, cmdKey)
+		} else {
+			s.cmdToPids[cmdKey] = remainingPids
+		}
 	}()
 }
@@ -186,48 +200,92 @@ func (s *SystemUnderTest) logf(msg string, args ...any) {
 	s.log(fmt.Sprintf(msg, args...))
 }
 
-func (s *SystemUnderTest) hashPids() bool {
+// HasProcess reports whether any managed process is still running; when commands
+// are given, only processes started from those binaries are considered.
+func (s *SystemUnderTest) HasProcess(cmds ...string) bool {
 	s.pidsLock.RLock()
 	defer s.pidsLock.RUnlock()
-	return len(s.pids) != 0
+	if len(cmds) == 0 {
+		return len(s.pids) != 0
+	}
+	for _, cmd := range cmds {
+		if len(s.cmdToPids[filepath.Base(cmd)]) != 0 {
+			return true
+		}
+	}
+	return false
 }
 
-func (s *SystemUnderTest) withEachPid(cb func(p *os.Process)) {
-	s.pidsLock.RLock()
-	pids := maps.Keys(s.pids)
-	s.pidsLock.RUnlock()
+// ShutdownAll stops all processes managed by the SystemUnderTest by sending SIGTERM and SIGKILL signals if necessary.
+func (s *SystemUnderTest) ShutdownAll() {
+	s.gracefulStopProcesses(s.iterAllProcesses)
+}
 
-	for pid := range pids {
-		p, err := os.FindProcess(pid)
-		if err != nil {
-			continue
-		}
-		cb(p)
-	}
+// ShutdownByCmd stops all processes associated with the specified command by sending SIGTERM and SIGKILL if needed.
+func (s *SystemUnderTest) ShutdownByCmd(cmd string) {
+	s.gracefulStopProcesses(func() iter.Seq[*os.Process] { return s.iterProcessesByCmd(cmd) }, cmd)
 }
 
-// Shutdown stops all processes managed by the SystemUnderTest by sending SIGTERM and SIGKILL signals if necessary.
-func (s *SystemUnderTest) Shutdown() {
-	s.withEachPid(func(p *os.Process) {
-		go func() {
+// gracefulStopProcesses sends SIGTERM to the yielded processes, waits briefly for
+// them to exit (checking only the given cmds, or all processes when none are
+// given) and kills whatever is still running.
+func (s *SystemUnderTest) gracefulStopProcesses(iterFn func() iter.Seq[*os.Process], cmds ...string) {
+	for p := range iterFn() {
+		go func(p *os.Process) {
 			if err := p.Signal(syscall.SIGTERM); err != nil {
 				s.logf("failed to stop node with pid %d: %s\n", p.Pid, err)
 			}
-		}()
-	})
+		}(p)
+	}
+
+	// await graceful shutdown
 	for range 5 {
-		if !s.hashPids() {
+		if !s.HasProcess(cmds...) {
 			break
 		}
 		time.Sleep(50 * time.Millisecond)
 	}
-
-	s.withEachPid(func(p *os.Process) {
+	// kill remaining processes if necessary
+	for p := range iterFn() {
 		s.logf("killing node %d\n", p.Pid)
 		if err := p.Kill(); err != nil {
 			s.logf("failed to kill node with pid %d: %s\n", p.Pid, err)
 		}
-	})
+	}
+}
+
+// iterAllProcesses returns an iterator over all processes currently managed by the SystemUnderTest instance.
+func (s *SystemUnderTest) iterAllProcesses() iter.Seq[*os.Process] {
+	return func(yield func(*os.Process) bool) {
+		s.pidsLock.RLock()
+		pids := maps.Keys(s.pids)
+		s.pidsLock.RUnlock()
+
+		for pid := range pids {
+			p, err := os.FindProcess(pid)
+			if err != nil {
+				continue
+			}
+			if !yield(p) {
+				break
+			}
+		}
+	}
+}
+
+// iterProcessesByCmd returns an iterator over processes associated with the specified command.
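+// The matching pids are snapshotted under the read lock, so the sequence stays
+// safe to consume while processes exit concurrently.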
+func (s *SystemUnderTest) iterProcessesByCmd(cmd string) iter.Seq[*os.Process] {
+	cmdKey := filepath.Base(cmd)
+	return func(yield func(*os.Process) bool) {
+		s.pidsLock.RLock()
+		pids := slices.Clone(s.cmdToPids[cmdKey])
+		s.pidsLock.RUnlock()
+
+		for _, pid := range pids {
+			p, err := os.FindProcess(pid)
+			if err != nil {
+				continue
+			}
+			if !yield(p) {
+				break
+			}
+		}
+	}
+}
 
 // locateExecutable looks up the binary on the OS path.