diff --git a/CHANGELOG.md b/CHANGELOG.md index cfd924b0f..0b6576431 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The following emojis are used to highlight certain changes: - `routing/http`: `GET /routing/v1/dht/closest/peers/{key}` per [IPIP-476](https://github.com/ipfs/specs/pull/476) - upgrade to `go-libp2p-kad-dht` [v0.36.0](https://github.com/libp2p/go-libp2p-kad-dht/releases/tag/v0.36.0) +- `ipld/merkledag`: Added fetched node size reporting to the progress tracker. See [kubo#8915](https://github.com/ipfs/kubo/issues/8915) ### Changed diff --git a/ipld/merkledag/merkledag.go b/ipld/merkledag/merkledag.go index 7e21383df..590b4b1fd 100644 --- a/ipld/merkledag/merkledag.go +++ b/ipld/merkledag/merkledag.go @@ -132,6 +132,20 @@ func GetLinksDirect(serv format.NodeGetter) GetLinks { } } +// GetLinksDirectWithProgressTracker creates a function like GetLinksDirect, but +// updates the ProgressTracker with the raw block data size of the retrieved node. +func GetLinksDirectWithProgressTracker(serv format.NodeGetter, tracker *ProgressTracker) GetLinks { + return func(ctx context.Context, c cid.Cid) ([]*format.Link, error) { + nd, err := serv.Get(ctx, c) + if err != nil { + return nil, err + } + // We don't use Size() as it returns cumulative size including linked nodes. + tracker.Update(uint64(len(nd.RawData()))) + return nd.Links(), nil + } +} + type sesGetter struct { bs *bserv.Session decoder *legacy.Decoder @@ -208,20 +222,13 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s // We default to Concurrent() walk. opts = append([]WalkOption{Concurrent()}, opts...) - // If we have a ProgressTracker, we wrap the visit function to handle it. + // If we have a ProgressTracker, we wrap the get links function to handle it. v, _ := ctx.Value(progressContextKey).(*ProgressTracker) if v == nil { return WalkDepth(ctx, GetLinksDirect(ng), root, visit, opts...) 
} - visitProgress := func(c cid.Cid, depth int) bool { - if visit(c, depth) { - v.Increment() - return true - } - return false - } - return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, opts...) + return WalkDepth(ctx, GetLinksDirectWithProgressTracker(ng, v), root, visit, opts...) } // GetMany gets many nodes from the DAG at once. @@ -457,10 +464,18 @@ func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, d return nil } +// ProgressStat represents the progress of a fetch operation. +type ProgressStat struct { + // Nodes is the total number of nodes fetched. + Nodes int + // Bytes is the total bytes of raw block data. + Bytes uint64 +} + // ProgressTracker is used to show progress when fetching nodes. type ProgressTracker struct { - Total int - lk sync.Mutex + stat ProgressStat + lk sync.Mutex } // DeriveContext returns a new context with value "progress" derived from the @@ -469,18 +484,26 @@ func (p *ProgressTracker) DeriveContext(ctx context.Context) context.Context { return context.WithValue(ctx, progressContextKey, p) } -// Increment adds one to the total progress. -func (p *ProgressTracker) Increment() { +// Update adds one to the total nodes and updates the total bytes. +func (p *ProgressTracker) Update(bytes uint64) { p.lk.Lock() defer p.lk.Unlock() - p.Total++ + p.stat.Nodes++ + p.stat.Bytes += bytes } // Value returns the current progress. func (p *ProgressTracker) Value() int { p.lk.Lock() defer p.lk.Unlock() - return p.Total + return p.stat.Nodes +} + +// ProgressStat returns the current progress stat. 
+func (p *ProgressTracker) ProgressStat() ProgressStat { + p.lk.Lock() + defer p.lk.Unlock() + return p.stat } func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *walkOptions) error { diff --git a/ipld/merkledag/merkledag_test.go b/ipld/merkledag/merkledag_test.go index 08b240d47..6ac926a61 100644 --- a/ipld/merkledag/merkledag_test.go +++ b/ipld/merkledag/merkledag_test.go @@ -1161,7 +1161,7 @@ func TestProgressIndicatorNoChildren(t *testing.T) { func testProgressIndicator(t *testing.T, depth int) { ds := dstest.Mock() - top, numChildren := mkDag(ds, depth) + top, numChildren, totalSize := mkDag(ds, depth) v := new(ProgressTracker) ctx := v.DeriveContext(context.Background()) @@ -1175,9 +1175,19 @@ func testProgressIndicator(t *testing.T, depth int) { t.Errorf("wrong number of children reported in progress indicator, expected %d, got %d", numChildren+1, v.Value()) } + + if v.ProgressStat().Nodes != numChildren+1 { + t.Errorf("wrong number of children reported in progress stat indicator, expected %d, got %d", + numChildren+1, v.ProgressStat().Nodes) + } + + if v.ProgressStat().Bytes != totalSize { + t.Errorf("wrong bytes reported in progress stat indicator, expected %d, got %d", + totalSize, v.ProgressStat().Bytes) + } } -func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int) { +func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int, uint64) { ctx := context.Background() totalChildren := 0 @@ -1213,7 +1223,12 @@ func mkDag(ds ipld.DAGService, depth int) (cid.Cid, int) { panic(err) } - return nd.Cid(), totalChildren + totalSize, err := nd.Size() + if err != nil { + panic(err) + } + + return nd.Cid(), totalChildren, totalSize } func mkNodeWithChildren(getChild func() *ProtoNode, width int) *ProtoNode {