diff --git a/graph/graph.go b/graph/graph.go index b225c92..7f04060 100644 --- a/graph/graph.go +++ b/graph/graph.go @@ -43,6 +43,46 @@ func (e Edge) id() string { return fmt.Sprintf("%q-%q->%q", e.Tail, e.Val, e.Head) } +// an edgeIndex maps valueIDs to a set of edgeIDs. Graph keeps two edgeIndex's, +// one for input edges and one for output edges. +type edgeIndex map[string]map[string]struct{} + +func (ei edgeIndex) cp() edgeIndex { + if ei == nil { + return edgeIndex{} + } + ei2 := make(edgeIndex, len(ei)) + for valID, edgesM := range ei { + edgesM2 := make(map[string]struct{}, len(edgesM)) + for id := range edgesM { + edgesM2[id] = struct{}{} + } + ei2[valID] = edgesM2 + } + return ei2 +} + +func (ei edgeIndex) add(valID, edgeID string) { + edgesM, ok := ei[valID] + if !ok { + edgesM = map[string]struct{}{} + ei[valID] = edgesM + } + edgesM[edgeID] = struct{}{} +} + +func (ei edgeIndex) del(valID, edgeID string) { + edgesM, ok := ei[valID] + if !ok { + return + } + + delete(edgesM, edgeID) + if len(edgesM) == 0 { + delete(ei, valID) + } +} + // Graph implements an immutable, unidirectional graph which can hold generic // values. All methods are thread-safe as they don't modify the Graph in any // way. @@ -53,11 +93,17 @@ func (e Edge) id() string { // Edges are in random order. type Graph struct { m map[string]Edge + + // these are indices mapping Value IDs to all the in/out edges for that + // Value in the Graph. + vIns, vOuts edgeIndex } func (g Graph) cp() Graph { g2 := Graph{ - m: make(map[string]Edge, len(g.m)), + m: make(map[string]Edge, len(g.m)), + vIns: g.vIns.cp(), + vOuts: g.vOuts.cp(), } for id, e := range g.m { g2.m[id] = e @@ -75,6 +121,8 @@ func (g Graph) AddEdge(e Edge) Graph { g2 := g.cp() g2.m[id] = e + g2.vIns.add(e.Head.ID, id) + g2.vOuts.add(e.Tail.ID, id) return g2 } @@ -88,6 +136,8 @@ func (g Graph) DelEdge(e Edge) Graph { g2 := g.cp() delete(g2.m, id) + g2.vIns.del(e.Head.ID, id) + g2.vOuts.del(e.Tail.ID, id) return g2 } @@ -121,14 +171,14 @@ func (g Graph) Edges() []Edge { // ValueEdges returns all input (e.Head==v) and output (e.Tail==v) Edges // for the given Value in the Graph. func (g Graph) ValueEdges(v Value) ([]Edge, []Edge) { - var in, out []Edge - for _, e := range g.m { - if e.Tail.ID == v.ID { - out = append(out, e) - } - if e.Head.ID == v.ID { - in = append(in, e) - } + in := make([]Edge, 0, len(g.vIns[v.ID])) + for edgeID := range g.vIns[v.ID] { + in = append(in, g.m[edgeID]) + } + + out := make([]Edge, 0, len(g.vOuts[v.ID])) + for edgeID := range g.vOuts[v.ID] { + out = append(out, g.m[edgeID]) } return in, out }