graph: clean up and fix implicit graph example

The From method incorrectly returned nil for the empty case and node neighbour
expansion did not mark visited nodes. Fix these and exercise the corrected path
in the test. Also clean up the code structure.
This commit is contained in:
Dan Kortschak
2024-04-18 05:28:56 +09:30
parent 1b7d9ca04a
commit 7bd265b283

View File

@@ -18,11 +18,18 @@ type GraphNode struct {
id int64
neighbors []graph.Node
roots []*GraphNode
name string
}
// NewGraphNode returns a new GraphNode.
func NewGraphNode(id int64) *GraphNode {
return &GraphNode{id: id}
func NewGraphNode(id int64, name string) *GraphNode {
return &GraphNode{name: name, id: id}
}
// String returns the node's name.
func (g *GraphNode) String() string {
return g.name
}
// Node allows GraphNode to satisfy the graph.Graph interface.
@@ -33,11 +40,7 @@ func (g *GraphNode) Node(id int64) graph.Node {
seen := map[int64]struct{}{g.id: {}}
for _, root := range g.roots {
if root.ID() == id {
return root
}
if root.has(seen, id) {
if root.ID() == id || root.has(seen, id) {
return root
}
}
@@ -62,28 +65,22 @@ func (g *GraphNode) has(seen map[int64]struct{}, id int64) bool {
if _, ok := seen[root.ID()]; ok {
continue
}
seen[root.ID()] = struct{}{}
if root.ID() == id {
if root.ID() == id || root.has(seen, id) {
return true
}
if root.has(seen, id) {
return true
}
}
for _, n := range g.neighbors {
if _, ok := seen[n.ID()]; ok {
continue
}
seen[n.ID()] = struct{}{}
if n.ID() == id {
return true
}
if gn, ok := n.(*GraphNode); ok {
if gn.has(seen, id) {
return true
@@ -100,16 +97,14 @@ func (g *GraphNode) Nodes() graph.Nodes {
seen := map[int64]struct{}{g.id: {}}
for _, root := range g.roots {
nodes = append(nodes, root)
seen[root.ID()] = struct{}{}
nodes = root.nodes(nodes, seen)
nodes = root.nodes(append(nodes, root), seen)
}
for _, n := range g.neighbors {
nodes = append(nodes, n)
seen[n.ID()] = struct{}{}
nodes = append(nodes, n)
if gn, ok := n.(*GraphNode); ok {
nodes = gn.nodes(nodes, seen)
}
@@ -124,15 +119,15 @@ func (g *GraphNode) nodes(dst []graph.Node, seen map[int64]struct{}) []graph.Nod
continue
}
seen[root.ID()] = struct{}{}
dst = append(dst, graph.Node(root))
dst = root.nodes(dst, seen)
dst = root.nodes(append(dst, graph.Node(root)), seen)
}
for _, n := range g.neighbors {
if _, ok := seen[n.ID()]; ok {
continue
}
seen[n.ID()] = struct{}{}
dst = append(dst, n)
if gn, ok := n.(*GraphNode); ok {
@@ -168,7 +163,7 @@ func (g *GraphNode) From(id int64) graph.Nodes {
}
}
return nil
return graph.Empty
}
func (g *GraphNode) findNeighbors(id int64, seen map[int64]struct{}) []graph.Node {
@@ -259,6 +254,7 @@ func (g *GraphNode) edgeBetween(uid, vid int64, seen map[int64]struct{}) graph.E
continue
}
seen[root.ID()] = struct{}{}
if result := root.edgeBetween(uid, vid, seen); result != nil {
return result
}
@@ -268,8 +264,8 @@ func (g *GraphNode) edgeBetween(uid, vid int64, seen map[int64]struct{}) graph.E
if _, ok := seen[n.ID()]; ok {
continue
}
seen[n.ID()] = struct{}{}
if gn, ok := n.(*GraphNode); ok {
if result := gn.edgeBetween(uid, vid, seen); result != nil {
return result
@@ -319,25 +315,25 @@ func Example_implicit() {
// }
// graph G {
G := NewGraphNode(0)
G := NewGraphNode(0, "G")
// e
e := NewGraphNode(1)
e := NewGraphNode(1, "e")
// subgraph clusterA {
clusterA := NewGraphNode(2)
clusterA := NewGraphNode(2, "clusterA")
// a -- b
a := NewGraphNode(3)
b := NewGraphNode(4)
a := NewGraphNode(3, "a")
b := NewGraphNode(4, "b")
a.AddNeighbor(b)
b.AddNeighbor(a)
clusterA.AddRoot(a)
clusterA.AddRoot(b)
// subgraph clusterC {
clusterC := NewGraphNode(5)
clusterC := NewGraphNode(5, "clusterC")
// C -- D
C := NewGraphNode(6)
D := NewGraphNode(7)
C := NewGraphNode(6, "C")
D := NewGraphNode(7, "D")
C.AddNeighbor(D)
D.AddNeighbor(C)
@@ -348,10 +344,10 @@ func Example_implicit() {
// }
// subgraph clusterB {
clusterB := NewGraphNode(8)
clusterB := NewGraphNode(8, "clusterB")
// d -- f
d := NewGraphNode(9)
f := NewGraphNode(10)
d := NewGraphNode(9, "d")
f := NewGraphNode(10, "f")
d.AddNeighbor(f)
f.AddNeighbor(d)
clusterB.AddRoot(d)
@@ -379,7 +375,19 @@ func Example_implicit() {
fmt.Println("C--D--d--f is a path in G.")
}
fmt.Println("\nConnected components:")
for _, c := range topo.ConnectedComponents(G) {
fmt.Printf(" %s\n", c)
}
// Output:
//
// C--D--d--f is a path in G.
//
// Connected components:
// [G]
// [e clusterB clusterC]
// [d D C f]
// [clusterA]
// [a b]
}