diff --git a/gg/gg.go b/gg/gg.go index 2ba705d..64db16d 100644 --- a/gg/gg.go +++ b/gg/gg.go @@ -437,3 +437,43 @@ func Equal(g1, g2 *Graph) bool { } return true } + +// Walk will traverse the Graph, calling the callback on every Vertex in the +// Graph once. If startWith is non-nil then that Vertex will be the first one +// passed to the callback and used as the starting point of the traversal. If +// the callback returns false the traversal is stopped. +func (g *Graph) Walk(startWith *Vertex, callback func(*Vertex) bool) { + g.makeView() + if len(g.view) == 0 { + return + } + + seen := make(map[*Vertex]bool, len(g.view)) + var innerWalk func(*Vertex) bool + innerWalk = func(v *Vertex) bool { + if seen[v] { + return true + } else if !callback(v) { + return false + } + seen[v] = true + for _, e := range v.In { + if !innerWalk(e.From) { + return false + } + } + return true + } + + if startWith != nil { + if !innerWalk(startWith) { + return + } + } + + for _, v := range g.view { + if !innerWalk(v) { + return + } + } +} diff --git a/gg/gg_test.go b/gg/gg_test.go index 8ca7012..d226bd1 100644 --- a/gg/gg_test.go +++ b/gg/gg_test.go @@ -68,14 +68,37 @@ func assertVertexEqual(t *T, exp, got *Vertex, msgAndArgs ...interface{}) bool { return assertInner(exp, got, map[*Vertex]bool{}) } -type graphTest struct { - name string - out func() *Graph - exp []*Vertex +func assertWalk(t *T, expVals, expJuncs int, g *Graph, msgAndArgs ...interface{}) { + seen := map[*Vertex]bool{} + var gotVals, gotJuncs int + g.Walk(nil, func(v *Vertex) bool { + assert.NotContains(t, seen, v, msgAndArgs...) + seen[v] = true + if v.VertexType == Value { + gotVals++ + } else { + gotJuncs++ + } + return true + }) + assert.Equal(t, expVals, gotVals, msgAndArgs...) + assert.Equal(t, expJuncs, gotJuncs, msgAndArgs...) } -func mkTest(name string, out func() *Graph, exp ...*Vertex) graphTest { - return graphTest{name: name, out: out, exp: exp} +type graphTest struct { + name string + out func() *Graph + exp []*Vertex + numVals, numJuncs int +} + +func mkTest(name string, out func() *Graph, numVals, numJuncs int, exp ...*Vertex) graphTest { + return graphTest{ + name: name, + out: out, + exp: exp, + numVals: numVals, numJuncs: numJuncs, + } } func TestGraph(t *T) { @@ -85,6 +108,7 @@ func TestGraph(t *T) { func() *Graph { return Null.AddValueIn(ValueOut(id("v0"), id("e0")), id("v1")) }, + 2, 0, value("v0"), value("v1", edge("e0", value("v0"))), ), @@ -95,6 +119,7 @@ func TestGraph(t *T) { g0 := Null.AddValueIn(ValueOut(id("v0"), id("e0")), id("v2")) return g0.AddValueIn(ValueOut(id("v1"), id("e1")), id("v2")) }, + 3, 0, value("v0"), value("v1"), value("v2", @@ -109,6 +134,7 @@ func TestGraph(t *T) { g0 := Null.AddValueIn(ValueOut(id("v0"), id("e0")), id("v1")) return g0.AddValueIn(ValueOut(id("v2"), id("e2")), id("v3")) }, + 4, 0, value("v0"), value("v1", edge("e0", value("v0"))), value("v2"), @@ -120,6 +146,7 @@ func TestGraph(t *T) { func() *Graph { return Null.AddValueIn(ValueOut(id("v0"), id("e")), id("v0")) }, + 1, 0, value("v0", edge("e", value("v0"))), ), @@ -129,6 +156,7 @@ func TestGraph(t *T) { g0 := Null.AddValueIn(ValueOut(id("v0"), id("e0")), id("v1")) return g0.AddValueIn(ValueOut(id("v1"), id("e1")), id("v0")) }, + 2, 0, value("v0", edge("e1", value("v1", edge("e0", value("v0"))))), value("v1", edge("e0", value("v0", edge("e1", value("v1"))))), ), @@ -140,6 +168,7 @@ func TestGraph(t *T) { g1 := g0.AddValueIn(ValueOut(id("v1"), id("e1")), id("v2")) return g1.AddValueIn(ValueOut(id("v2"), id("e2")), id("v1")) }, + 3, 0, value("v0"), value("v1", edge("e0", value("v0")), @@ -159,6 +188,7 @@ func TestGraph(t *T) { ej0 := JunctionOut([]OpenEdge{e0, e1}, id("ej0")) return Null.AddValueIn(ej0, id("v2")) }, + 3, 1, value("v0"), value("v1"), value("v2", junction("ej0", edge("e0", value("v0")), @@ -178,6 +208,7 @@ func TestGraph(t *T) { ej2 := JunctionOut([]OpenEdge{ej0, ej1}, id("ej2")) return Null.AddValueIn(ej2, id("v2")) }, + 3, 3, value("v0"), value("v1"), value("v2", junction("ej2", junction("ej0", @@ -203,6 +234,7 @@ func TestGraph(t *T) { e21 := ValueOut(id("v2"), id("e21")) return g1.AddValueIn(e21, id("v1")) }, + 3, 1, value("v0", edge("e20", value("v2", junction("ej0", edge("e0", value("v0")), edge("e1", value("v1", edge("e21", value("v2")))), @@ -219,6 +251,7 @@ func TestGraph(t *T) { } for i := range tests { + t.Logf("test[%d]:%q", i, tests[i].name) out := tests[i].out() for j, exp := range tests[i].exp { msgAndArgs := []interface{}{ @@ -232,8 +265,16 @@ func TestGraph(t *T) { assertVertexEqual(t, exp, v, msgAndArgs...) } + msgAndArgs := []interface{}{ + "tests[%d].name:%q", + i, tests[i].name, + } + // sanity check that graphs are equal to themselves - assert.True(t, Equal(out, out)) + assert.True(t, Equal(out, out), msgAndArgs...) + + // test the Walk method in here too + assertWalk(t, tests[i].numVals, tests[i].numJuncs, out, msgAndArgs...) } }