diff --git a/ipld/merkledag/merkledag.go b/ipld/merkledag/merkledag.go index 3153cf41e..c035dd424 100644 --- a/ipld/merkledag/merkledag.go +++ b/ipld/merkledag/merkledag.go @@ -172,7 +172,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s ng = &sesGetter{bserv.NewSession(ctx, ds.Blocks)} } - set := make(map[string]int) + set := make(map[cid.Cid]int) // Visit function returns true when: // * The element is not in the set and we're not over depthLim @@ -182,15 +182,14 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s // depthLim = -1 means we only return true if the element is not in the // set. visit := func(c cid.Cid, depth int) bool { - key := string(c.Bytes()) - oldDepth, ok := set[key] + oldDepth, ok := set[c] if (ok && depthLim < 0) || (depthLim >= 0 && depth > depthLim) { return false } if !ok || oldDepth > depth { - set[key] = depth + set[c] = depth return true } return false @@ -198,7 +197,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s v, _ := ctx.Value(progressContextKey).(*ProgressTracker) if v == nil { - return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visit) + return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit) } visitProgress := func(c cid.Cid, depth int) bool { @@ -208,7 +207,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s } return false } - return EnumerateChildrenAsyncDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress) + return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress) } // GetMany gets many nodes from the DAG at once. @@ -282,33 +281,31 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks { } } -// EnumerateChildren will walk the dag below the given root node and add all -// unseen children to the passed in set. -// TODO: parallelize to avoid disk latency perf hits? 
-func EnumerateChildren(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error { +// Walk will walk the dag in order (depth first) starting at the given root. +func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error { visitDepth := func(c cid.Cid, depth int) bool { return visit(c) } - return EnumerateChildrenDepth(ctx, getLinks, root, 0, visitDepth) + return WalkDepth(ctx, getLinks, root, 0, visitDepth) } -// EnumerateChildrenDepth walks the dag below the given root and passes the -// current depth to a given visit function. The visit function can be used to -// limit DAG exploration. -func EnumerateChildrenDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error { +// WalkDepth walks the dag starting at the given root and passes the current +// depth to a given visit function. The visit function can be used to limit DAG +// exploration. +func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error { + if !visit(root, depth) { + return nil + } + links, err := getLinks(ctx, root) if err != nil { return err } for _, lnk := range links { - c := lnk.Cid - if visit(c, depth+1) { - err = EnumerateChildrenDepth(ctx, getLinks, c, depth+1, visit) - if err != nil { - return err - } + if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil { + return err } } return nil @@ -344,23 +341,23 @@ func (p *ProgressTracker) Value() int { // 'fetchNodes' will start at a time var FetchGraphConcurrency = 32 -// EnumerateChildrenAsync is equivalent to EnumerateChildren *except* that it -// fetches children in parallel. +// WalkParallel is equivalent to Walk *except* that it explores multiple paths +// in parallel. // // NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. 
-func EnumerateChildrenAsync(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error { +func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error { visitDepth := func(c cid.Cid, depth int) bool { return visit(c) } - return EnumerateChildrenAsyncDepth(ctx, getLinks, c, 0, visitDepth) + return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth) } -// EnumerateChildrenAsyncDepth is equivalent to EnumerateChildrenDepth *except* -// that it fetches children in parallel (down to a maximum depth in the graph). +// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches +// children in parallel. // // NOTE: It *does not* make multiple concurrent calls to the passed `visit` function. -func EnumerateChildrenAsyncDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error { +func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error { type cidDepth struct { cid cid.Cid depth int diff --git a/ipld/merkledag/merkledag_test.go b/ipld/merkledag/merkledag_test.go index bc87f3bb5..e56bb52d1 100644 --- a/ipld/merkledag/merkledag_test.go +++ b/ipld/merkledag/merkledag_test.go @@ -29,7 +29,7 @@ import ( // makeDepthTestingGraph makes a small DAG with two levels. The level-two // nodes are both children of the root and of one of the level 1 nodes. -// This is meant to test the EnumerateChildren*Depth functions. +// This is meant to test the Walk*Depth functions. 
func makeDepthTestingGraph(t *testing.T, ds ipld.DAGService) ipld.Node { root := NodeWithData(nil) l11 := NodeWithData([]byte("leve1_node1")) @@ -334,7 +334,7 @@ func TestFetchGraph(t *testing.T) { offlineDS := NewDAGService(bs) - err = EnumerateChildren(context.Background(), offlineDS.GetLinks, root.Cid(), func(_ cid.Cid) bool { return true }) + err = Walk(context.Background(), offlineDS.GetLinks, root.Cid(), func(_ cid.Cid) bool { return true }) if err != nil { t.Fatal(err) } @@ -347,11 +347,11 @@ func TestFetchGraphWithDepthLimit(t *testing.T) { } tests := []testcase{ - testcase{1, 3}, - testcase{0, 0}, - testcase{-1, 5}, - testcase{2, 5}, - testcase{3, 5}, + testcase{1, 4}, + testcase{0, 1}, + testcase{-1, 6}, + testcase{2, 6}, + testcase{3, 6}, } testF := func(t *testing.T, tc testcase) { @@ -383,7 +383,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) { } - err = EnumerateChildrenDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF) + err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF) if err != nil { t.Fatal(err) } @@ -400,7 +400,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) { } } -func TestEnumerateChildren(t *testing.T) { +func TestWalk(t *testing.T) { bsi := bstest.Mocks(1) ds := NewDAGService(bsi[0]) @@ -409,7 +409,7 @@ func TestEnumerateChildren(t *testing.T) { set := cid.NewSet() - err := EnumerateChildren(context.Background(), ds.GetLinks, root.Cid(), set.Visit) + err := Walk(context.Background(), ds.GetLinks, root.Cid(), set.Visit) if err != nil { t.Fatal(err) } @@ -736,7 +736,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) { } cset := cid.NewSet() - err = EnumerateChildrenAsync(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit) + err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit) if err == nil { t.Fatal("this should have failed") }