graph: refactor to use Node type

This commit is contained in:
Brian Picciano 2018-08-21 14:46:17 -04:00
parent 9534ff5c13
commit 20b2a80a3c
2 changed files with 118 additions and 79 deletions

View File

@ -111,9 +111,9 @@ func (g Graph) cp() Graph {
return g2 return g2
} }
// AddEdge returns a new Graph instance with the given Edge added to it. If the // Add returns a new Graph instance with the given Edge added to it. If the
// original Graph already had that Edge this returns the original Graph. // original Graph already had that Edge this returns the original Graph.
func (g Graph) AddEdge(e Edge) Graph { func (g Graph) Add(e Edge) Graph {
id := e.id() id := e.id()
if _, ok := g.m[id]; ok { if _, ok := g.m[id]; ok {
return g return g
@ -126,9 +126,9 @@ func (g Graph) AddEdge(e Edge) Graph {
return g2 return g2
} }
// DelEdge returns a new Graph instance without the given Edge in it. If the // Del returns a new Graph instance without the given Edge in it. If the
// original Graph didn't have that Edge this returns the original Graph. // original Graph didn't have that Edge this returns the original Graph.
func (g Graph) DelEdge(e Edge) Graph { func (g Graph) Del(e Edge) Graph {
id := e.id() id := e.id()
if _, ok := g.m[id]; !ok { if _, ok := g.m[id]; !ok {
return g return g
@ -141,24 +141,6 @@ func (g Graph) DelEdge(e Edge) Graph {
return g2 return g2
} }
// Values returns all Values which have incoming or outgoing Edges in the Graph.
func (g Graph) Values() []Value {
values := make([]Value, 0, len(g.m))
found := map[string]bool{}
tryAdd := func(v Value) {
if ok := found[v.ID]; !ok {
values = append(values, v)
found[v.ID] = true
}
}
for _, e := range g.m {
tryAdd(e.Head)
tryAdd(e.Tail)
}
return values
}
// Edges returns all Edges which are part of the Graph // Edges returns all Edges which are part of the Graph
func (g Graph) Edges() []Edge { func (g Graph) Edges() []Edge {
edges := make([]Edge, 0, len(g.m)) edges := make([]Edge, 0, len(g.m))
@ -168,38 +150,88 @@ func (g Graph) Edges() []Edge {
return edges return edges
} }
// ValueEdges returns all input (e.Head==v) and output (e.Tail==v) Edges // NOTE the Node type exists primarily for convenience. As far as Graph's
// for the given Value in the Graph. // internals are concerned it doesn't _really_ exist, and no Graph method should
func (g Graph) ValueEdges(v Value) ([]Edge, []Edge) { // ever take Node as a parameter (except the callback functions like in
in := make([]Edge, 0, len(g.vIns[v.ID])) // Traverse, where it's not really being taken in).
for edgeID := range g.vIns[v.ID] {
in = append(in, g.m[edgeID])
}
out := make([]Edge, 0, len(g.vOuts[v.ID])) // Node wraps a Value in a Graph to include that Node's input and output Edges
for edgeID := range g.vOuts[v.ID] { // in that Graph.
out = append(out, g.m[edgeID]) type Node struct {
Value
// All Edges in the Graph with this Node's Value as their Head and Tail,
// respectively.
Ins, Outs []Edge
}
// Node returns the Node for the given Value, or false if the Graph doesn't
// contain the Value.
func (g Graph) Node(v Value) (Node, bool) {
n := Node{Value: v}
for edgeID := range g.vIns[v.ID] {
n.Ins = append(n.Ins, g.m[edgeID])
} }
return in, out for edgeID := range g.vOuts[v.ID] {
n.Outs = append(n.Outs, g.m[edgeID])
}
return n, len(n.Ins) > 0 || len(n.Outs) > 0
}
// Nodes returns a Node for each Value which has at least one Edge in the Graph,
// with the Nodes mapped by their Value's ID.
func (g Graph) Nodes() map[string]Node {
nodesM := make(map[string]Node, len(g.m)*2)
for _, edge := range g.m {
// if head and tail are modified at the same time it messes up the case
// where they are the same node
{
head := nodesM[edge.Head.ID]
head.Value = edge.Head
head.Ins = append(head.Ins, edge)
nodesM[head.ID] = head
}
{
tail := nodesM[edge.Tail.ID]
tail.Value = edge.Tail
tail.Outs = append(tail.Outs, edge)
nodesM[tail.ID] = tail
}
}
return nodesM
}
// Has returns true if the Graph contains at least one Edge with a Head or Tail
// of Value.
func (g Graph) Has(v Value) bool {
if _, ok := g.vIns[v.ID]; ok {
return true
} else if _, ok := g.vOuts[v.ID]; ok {
return true
}
return false
} }
// Traverse is used to traverse the Graph until a stopping point is reached. // Traverse is used to traverse the Graph until a stopping point is reached.
// Traversal starts with the cursor at the given start value. Each hop is // Traversal starts with the cursor at the given start Value. Each hop is
// performed by passing the cursor value along with its input and output Edges // performed by passing the cursor Value's Node into the next function. The
// into the next function. The cursor moves to the returned Value and next is // cursor moves to the returned Value and next is called again, and so on.
// called again, and so on.
// //
// If the boolean returned from the next function is false traversal stops and // If the boolean returned from the next function is false traversal stops and
// this method returns. // this method returns.
// //
// If start has no Edges in the Graph, or a Value returned from next doesn't, // If start has no Edges in the Graph, or a Value returned from next doesn't,
// this will still call next, but the in/out params will both be empty. // this will still call next, but the Node will be the zero value.
func (g Graph) Traverse(start Value, next func(v Value, in, out []Edge) (Value, bool)) { func (g Graph) Traverse(start Value, next func(n Node) (Value, bool)) {
curr := start curr := start
var ok bool
for { for {
in, out := g.ValueEdges(curr) currNode, ok := g.Node(curr)
if curr, ok = next(curr, in, out); !ok { if ok {
curr, ok = next(currNode)
} else {
curr, ok = next(Node{})
}
if !ok {
return return
} }
} }

View File

@ -57,56 +57,63 @@ func TestGraph(t *T) {
Apply: func(ss mchk.State, a mchk.Action) (mchk.State, error) { Apply: func(ss mchk.State, a mchk.Action) (mchk.State, error) {
s, p := ss.(state), a.Params.(params) s, p := ss.(state), a.Params.(params)
if p.add != (Edge{}) { if p.add != (Edge{}) {
s.Graph = s.Graph.AddEdge(p.add) s.Graph = s.Graph.Add(p.add)
s.m[p.add.id()] = p.add s.m[p.add.id()] = p.add
} else { } else {
s.Graph = s.Graph.DelEdge(p.del) s.Graph = s.Graph.Del(p.del)
delete(s.m, p.del.id()) delete(s.m, p.del.id())
} }
{ // test Values and Edges methods { // test Nodes and Edges methods
vals := s.Graph.Values() nodes := s.Graph.Nodes()
edges := s.Graph.Edges() edges := s.Graph.Edges()
var aa []massert.Assertion var aa []massert.Assertion
found := map[string]bool{} vals := map[string]bool{}
tryAssert := func(v Value) { ins, outs := map[string]int{}, map[string]int{}
if ok := found[v.ID]; !ok {
found[v.ID] = true
aa = append(aa, massert.Has(vals, v))
}
}
for _, e := range s.m { for _, e := range s.m {
aa = append(aa, massert.Has(edges, e)) aa = append(aa, massert.Has(edges, e))
tryAssert(e.Head) aa = append(aa, massert.HasKey(nodes, e.Head.ID))
tryAssert(e.Tail) aa = append(aa, massert.Has(nodes[e.Head.ID].Ins, e))
aa = append(aa, massert.HasKey(nodes, e.Tail.ID))
aa = append(aa, massert.Has(nodes[e.Tail.ID].Outs, e))
vals[e.Head.ID] = true
vals[e.Tail.ID] = true
ins[e.Head.ID]++
outs[e.Tail.ID]++
} }
aa = append(aa, massert.Len(vals, len(found)))
aa = append(aa, massert.Len(edges, len(s.m))) aa = append(aa, massert.Len(edges, len(s.m)))
aa = append(aa, massert.Len(nodes, len(vals)))
for id, node := range nodes {
aa = append(aa, massert.Len(node.Ins, ins[id]))
aa = append(aa, massert.Len(node.Outs, outs[id]))
}
if err := massert.All(aa...).Assert(); err != nil { if err := massert.All(aa...).Assert(); err != nil {
return nil, err return nil, err
} }
} }
{ // test ValueEdges { // test Node and Has. Nodes has already been tested so we can use
for _, val := range s.Graph.Values() { // its returned Nodes as the expected ones
in, out := s.Graph.ValueEdges(val) var aa []massert.Assertion
var expIn, expOut []Edge for _, expNode := range s.Graph.Nodes() {
for _, e := range s.m { var naa []massert.Assertion
if e.Tail.ID == val.ID { node, ok := s.Graph.Node(expNode.Value)
expOut = append(expOut, e) naa = append(naa, massert.Equal(true, ok))
} naa = append(naa, massert.Equal(true, s.Graph.Has(expNode.Value)))
if e.Head.ID == val.ID { naa = append(naa, massert.Subset(expNode.Ins, node.Ins))
expIn = append(expIn, e) naa = append(naa, massert.Len(node.Ins, len(expNode.Ins)))
} naa = append(naa, massert.Subset(expNode.Outs, node.Outs))
} naa = append(naa, massert.Len(node.Outs, len(expNode.Outs)))
if err := massert.Comment(massert.All(
massert.Subset(expIn, in), aa = append(aa, massert.Comment(massert.All(naa...), "v:%q", expNode.ID))
massert.Len(in, len(expIn)), }
massert.Subset(expOut, out), _, ok := s.Graph.Node(strV("zz"))
massert.Len(out, len(expOut)), aa = append(aa, massert.Equal(false, ok))
), "val:%q", val.V).Assert(); err != nil { aa = append(aa, massert.Equal(false, s.Graph.Has(strV("zz"))))
return nil, err
} if err := massert.All(aa...).Assert(); err != nil {
return nil, err
} }
} }
@ -147,10 +154,10 @@ func TestSubGraphAndEqual(t *T) {
Apply: func(ss mchk.State, a mchk.Action) (mchk.State, error) { Apply: func(ss mchk.State, a mchk.Action) (mchk.State, error) {
s, p := ss.(state), a.Params.(params) s, p := ss.(state), a.Params.(params)
if p.add1 { if p.add1 {
s.g1 = s.g1.AddEdge(p.e) s.g1 = s.g1.Add(p.e)
} }
if p.add2 { if p.add2 {
s.g2 = s.g2.AddEdge(p.e) s.g2 = s.g2.Add(p.e)
} }
s.expSubGraph = s.expSubGraph && p.add1 s.expSubGraph = s.expSubGraph && p.add1
s.expEqual = s.expEqual && p.add1 && p.add2 s.expEqual = s.expEqual && p.add1 && p.add2