diff --git a/graph.go b/graph.go index d33c7ec8..74baba04 100644 --- a/graph.go +++ b/graph.go @@ -120,3 +120,6 @@ type DStarGraph interface { Move(target Node) ChangedEdges() (newCostFunc func(Node, Node) float64, changedEdges []Edge) } + +// A function that returns the cost from one node to another +type CostFun func(Node, Node) float64 diff --git a/search/graphSearch.go b/search/graphSearch.go index ecea06d8..03aa8ef1 100644 --- a/search/graphSearch.go +++ b/search/graphSearch.go @@ -30,13 +30,14 @@ import ( // To run Uniform Cost Search, run A* with the NullHeuristic // // To run Breadth First Search, run A* with both the NullHeuristic and UniformCost (or any cost function that returns a uniform positive value) -func AStar(start, goal gr.Node, graph gr.Graph, Cost, HeuristicCost func(gr.Node, gr.Node) float64) (path []gr.Node, cost float64, nodesExpanded int) { - successors, _, _, _, _, _, Cost, HeuristicCost := setupFuncs(graph, Cost, HeuristicCost) +func AStar(start, goal gr.Node, graph gr.Graph, cost, heuristicCost gr.CostFun) (path []gr.Node, pathCost float64, nodesExpanded int) { + sf := setupFuncs(graph, cost, heuristicCost) + successors, cost, heuristicCost := sf.successors, sf.cost, sf.heuristicCost closedSet := make(map[int]internalNode) openSet := &aStarPriorityQueue{nodes: make([]internalNode, 0), indexList: make(map[int]int)} heap.Init(openSet) - node := internalNode{start, 0, HeuristicCost(start, goal)} + node := internalNode{start, 0, heuristicCost(start, goal)} heap.Push(openSet, node) predecessor := make(map[int]gr.Node) @@ -56,15 +57,15 @@ func AStar(start, goal gr.Node, graph gr.Graph, Cost, HeuristicCost func(gr.Node continue } - g := curr.gscore + Cost(curr.Node, neighbor) + g := curr.gscore + cost(curr.Node, neighbor) if existing, exists := openSet.Find(neighbor.ID()); !exists { predecessor[neighbor.ID()] = curr - node = internalNode{neighbor, g, g + HeuristicCost(neighbor, goal)} + node = internalNode{neighbor, g, g + heuristicCost(neighbor, goal)} heap.Push(openSet, node) } else if g < existing.gscore { predecessor[neighbor.ID()] = curr - openSet.Fix(neighbor.ID(), g, g+HeuristicCost(neighbor, goal)) + openSet.Fix(neighbor.ID(), g, g+heuristicCost(neighbor, goal)) } } } @@ -89,8 +90,10 @@ func BreadthFirstSearch(start, goal gr.Node, graph gr.Graph) ([]gr.Node, int) { // Like A*, Dijkstra's Algorithm likely won't run correctly with negative edge weights -- use Bellman-Ford for that instead // // Dijkstra's algorithm usually only returns a cost map, however, since the data is available this version will also reconstruct the path to every node -func Dijkstra(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (paths map[int][]gr.Node, costs map[int]float64) { - successors, _, _, _, _, _, Cost, _ := setupFuncs(graph, Cost, nil) +func Dijkstra(source gr.Node, graph gr.Graph, cost gr.CostFun) (paths map[int][]gr.Node, costs map[int]float64) { + + sf := setupFuncs(graph, cost, nil) + successors, cost := sf.successors, sf.cost nodes := graph.NodeList() openSet := &aStarPriorityQueue{nodes: make([]internalNode, 0), indexList: make(map[int]int)} @@ -111,7 +114,7 @@ func Dijkstra(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) float6 closedSet.Add(node.ID()) for _, neighbor := range successors(node) { - tmpCost := costs[node.ID()] + Cost(node, neighbor) + tmpCost := costs[node.ID()] + cost(node, neighbor) if cost, ok := costs[neighbor.ID()]; !ok { costs[neighbor.ID()] = tmpCost predecessor[neighbor.ID()] = node @@ -140,8 +143,9 @@ func Dijkstra(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) float6 // // Like Dijkstra's, along with the costs this implementation will also construct all the paths for you. In addition, it has a third return value which will be true if the algorithm was aborted // due to the presence of a negative edge weight cycle. -func BellmanFord(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (paths map[int][]gr.Node, costs map[int]float64, err error) { - successors, _, _, _, _, _, Cost, _ := setupFuncs(graph, Cost, nil) +func BellmanFord(source gr.Node, graph gr.Graph, cost gr.CostFun) (paths map[int][]gr.Node, costs map[int]float64, err error) { + sf := setupFuncs(graph, cost, nil) + successors, cost := sf.successors, sf.cost predecessor := make(map[int]gr.Node) costs = make(map[int]float64) @@ -155,7 +159,7 @@ func BellmanFord(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) flo nodeIDMap[node.ID()] = node succs := successors(node) for _, succ := range succs { - weight := Cost(node, succ) + weight := cost(node, succ) nodeIDMap[succ.ID()] = succ if dist := costs[node.ID()] + weight; dist < costs[succ.ID()] { @@ -169,7 +173,7 @@ func BellmanFord(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) flo for _, node := range nodes { for _, succ := range successors(node) { - weight := Cost(node, succ) + weight := cost(node, succ) if costs[node.ID()]+weight < costs[succ.ID()] { return nil, nil, errors.New("Negative edge cycle detected") } @@ -195,8 +199,10 @@ func BellmanFord(source gr.Node, graph gr.Graph, Cost func(gr.Node, gr.Node) flo // // Its return values are, in order: a map from the source node, to the destination node, to the path between them; a map from the source node, to the destination node, to the cost of the path between them; // and a bool that is true if Bellman-Ford detected a negative edge weight cycle -- thus causing it (and this algorithm) to abort (if aborted is true, both maps will be nil). -func Johnson(graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (nodePaths map[int]map[int][]gr.Node, nodeCosts map[int]map[int]float64, err error) { - successors, _, _, _, _, _, Cost, _ := setupFuncs(graph, Cost, nil) +func Johnson(graph gr.Graph, cost gr.CostFun) (nodePaths map[int]map[int][]gr.Node, nodeCosts map[int]map[int]float64, err error) { + sf := setupFuncs(graph, cost, nil) + successors, cost := sf.successors, sf.cost + /* Copy graph into a mutable one since it has to be altered for this algorithm */ dummyGraph := concrete.NewGonumGraph(true) for _, node := range graph.NodeList() { @@ -204,12 +210,12 @@ func Johnson(graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (nodePaths map if !dummyGraph.NodeExists(node) { dummyGraph.AddNode(node, neighbors) for _, neighbor := range neighbors { - dummyGraph.SetEdgeCost(concrete.GonumEdge{node, neighbor}, Cost(node, neighbor)) + dummyGraph.SetEdgeCost(concrete.GonumEdge{node, neighbor}, cost(node, neighbor)) } } else { for _, neighbor := range neighbors { dummyGraph.AddEdge(concrete.GonumEdge{node, neighbor}) - dummyGraph.SetEdgeCost(concrete.GonumEdge{node, neighbor}, Cost(node, neighbor)) + dummyGraph.SetEdgeCost(concrete.GonumEdge{node, neighbor}, cost(node, neighbor)) } } } @@ -229,7 +235,7 @@ func Johnson(graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (nodePaths map /* Step 3: reweight the graph and remove the dummy node */ for _, node := range graph.NodeList() { for _, succ := range successors(node) { - dummyGraph.SetEdgeCost(concrete.GonumEdge{node, succ}, Cost(node, succ)+costs[node.ID()]-costs[succ.ID()]) + dummyGraph.SetEdgeCost(concrete.GonumEdge{node, succ}, cost(node, succ)+costs[node.ID()]-costs[succ.ID()]) } } @@ -249,7 +255,9 @@ func Johnson(graph gr.Graph, Cost func(gr.Node, gr.Node) float64) (nodePaths map // Expands the first node it sees trying to find the destination. Depth First Search is *not* guaranteed to find the shortest path, // however, if a path exists DFS is guaranteed to find it (provided you don't find a way to implement a Graph with an infinite depth) func DepthFirstSearch(start, goal gr.Node, graph gr.Graph) []gr.Node { - successors, _, _, _, _, _, _, _ := setupFuncs(graph, nil, nil) + sf := setupFuncs(graph, nil, nil) + successors := sf.successors + closedSet := set.NewSet() openSet := xifo.GonumStack([]interface{}{start}) predecessor := make(map[int]gr.Node) @@ -299,7 +307,8 @@ func CopyGraph(dst gr.MutableGraph, src gr.Graph) { dst.EmptyGraph() dst.SetDirected(false) - successors, _, _, _, _, _, cost, _ := setupFuncs(src, nil, nil) + sf := setupFuncs(src, nil, nil) + successors, cost := sf.successors, sf.cost for _, node := range src.NodeList() { succs := successors(node) @@ -335,7 +344,7 @@ func Tarjan(graph gr.Graph) (sccs [][]gr.Node) { lowlinks := make(map[int]int, len(nodes)) indices := make(map[int]int, len(nodes)) - successors, _, _, _, _, _, _, _ := setupFuncs(graph, nil, nil) + successors := setupFuncs(graph, nil, nil).successors var strongconnect func(gr.Node) []gr.Node @@ -384,7 +393,8 @@ func Tarjan(graph gr.Graph) (sccs [][]gr.Node) { // // Special case: a nil or zero length path is considered valid (true), a path of length 1 (only one node) is the trivial case, but only if the node listed in path exists. func IsPath(path []gr.Node, graph gr.Graph) bool { - _, _, _, isSuccessor, _, _, _, _ := setupFuncs(graph, nil, nil) + isSuccessor := setupFuncs(graph, nil, nil).isSuccessor + if path == nil || len(path) == 0 { return true } else if len(path) == 1 { @@ -405,14 +415,9 @@ func IsPath(path []gr.Node, graph gr.Graph) bool { // Generates a minimum spanning tree with sets. // // As with other algorithms that use Cost, the order of precedence is Argument > Interface > UniformCost -func Prim(dst gr.MutableGraph, graph gr.EdgeListGraph, Cost func(gr.Node, gr.Node) float64) { - if Cost == nil { - if cgraph, ok := graph.(gr.Coster); ok { - Cost = cgraph.Cost - } else { - Cost = UniformCost - } - } +func Prim(dst gr.MutableGraph, graph gr.EdgeListGraph, cost gr.CostFun) { + cost = setupFuncs(graph, cost, nil).cost + dst.EmptyGraph() dst.SetDirected(false) @@ -433,9 +438,9 @@ func Prim(dst gr.MutableGraph, graph gr.EdgeListGraph, Cost func(gr.Node, gr.Nod edgeWeights := make(edgeSorter, 0) for _, edge := range edgeList { if dst.NodeExists(edge.Head()) && remainingNodes.Contains(edge.Tail().ID()) { - edgeWeights = append(edgeWeights, WeightedEdge{Edge: edge, Weight: Cost(edge.Head(), edge.Tail())}) + edgeWeights = append(edgeWeights, WeightedEdge{Edge: edge, Weight: cost(edge.Head(), edge.Tail())}) } else if dst.NodeExists(edge.Tail()) && remainingNodes.Contains(edge.Head().ID()) { - edgeWeights = append(edgeWeights, WeightedEdge{Edge: edge, Weight: Cost(edge.Tail(), edge.Head())}) + edgeWeights = append(edgeWeights, WeightedEdge{Edge: edge, Weight: cost(edge.Tail(), edge.Head())}) } } @@ -459,7 +464,8 @@ func Prim(dst gr.MutableGraph, graph gr.EdgeListGraph, Cost func(gr.Node, gr.Nod // // As with other algorithms with Cost, the precedence goes Argument > Interface > UniformCost func Kruskal(dst gr.MutableGraph, graph gr.EdgeListGraph, cost func(gr.Node, gr.Node) float64) { - _, _, _, _, _, _, cost, _ = setupFuncs(graph, cost, nil) + cost = setupFuncs(graph, cost, nil).cost + dst.EmptyGraph() dst.SetDirected(false) @@ -506,7 +512,7 @@ func Dominators(start gr.Node, graph gr.Graph) map[int]*set.Set { allNodes.Add(node.ID()) } - _, predecessors, _, _, _, _, _, _ := setupFuncs(graph, nil, nil) + predecessors := setupFuncs(graph, nil, nil).predecessors for _, node := range nlist { dominators[node.ID()] = set.NewSet() @@ -550,7 +556,8 @@ func Dominators(start gr.Node, graph gr.Graph) map[int]*set.Set { // // This returns all possible post-dominators for all nodes, it does not prune for strict postdominators, immediate postdominators etc func PostDominators(end gr.Node, graph gr.Graph) map[int]*set.Set { - successors, _, _, _, _, _, _, _ := setupFuncs(graph, nil, nil) + successors := setupFuncs(graph, nil, nil).successors + allNodes := set.NewSet() nlist := graph.NodeList() dominators := make(map[int]*set.Set, len(nlist)) diff --git a/search/internals.go b/search/internals.go index 9fd80126..a14b8c3b 100644 --- a/search/internals.go +++ b/search/internals.go @@ -6,51 +6,57 @@ import ( gr "github.com/gonum/graph" ) +type searchFuncs struct { + successors, predecessors, neighbors func(gr.Node) []gr.Node + isSuccessor, isPredecessor, isNeighbor func(gr.Node, gr.Node) bool + cost, heuristicCost gr.CostFun +} + // Sets up the cost functions and successor functions so I don't have to do a type switch every time. // This almost always does more work than is necessary, but since it's only executed once per function, and graph functions are rather costly, the "extra work" // should be negligible. -func setupFuncs(graph gr.Graph, cost, heuristicCost func(gr.Node, gr.Node) float64) (successorsFunc, predecessorsFunc, neighborsFunc func(gr.Node) []gr.Node, isSuccessorFunc, isPredecessorFunc, - isNeighborFunc func(gr.Node, gr.Node) bool, - costFunc, heuristicCostFunc func(gr.Node, gr.Node) float64) { +func setupFuncs(graph gr.Graph, cost, heuristicCost gr.CostFun) searchFuncs { + + sf := searchFuncs{} switch g := graph.(type) { case gr.DirectedGraph: - successorsFunc = g.Successors - predecessorsFunc = g.Predecessors - neighborsFunc = g.Neighbors - isSuccessorFunc = g.IsSuccessor - isPredecessorFunc = g.IsPredecessor - isNeighborFunc = g.IsNeighbor + sf.successors = g.Successors + sf.predecessors = g.Predecessors + sf.neighbors = g.Neighbors + sf.isSuccessor = g.IsSuccessor + sf.isPredecessor = g.IsPredecessor + sf.isNeighbor = g.IsNeighbor default: - successorsFunc = g.Neighbors - predecessorsFunc = g.Neighbors - neighborsFunc = g.Neighbors - isSuccessorFunc = g.IsNeighbor - isPredecessorFunc = g.IsNeighbor - isNeighborFunc = g.IsNeighbor + sf.successors = g.Neighbors + sf.predecessors = g.Neighbors + sf.neighbors = g.Neighbors + sf.isSuccessor = g.IsNeighbor + sf.isPredecessor = g.IsNeighbor + sf.isNeighbor = g.IsNeighbor } if heuristicCost != nil { - heuristicCostFunc = heuristicCost + sf.heuristicCost = heuristicCost } else { if g, ok := graph.(gr.HeuristicCoster); ok { - heuristicCostFunc = g.HeuristicCost + sf.heuristicCost = g.HeuristicCost } else { - heuristicCostFunc = NullHeuristic + sf.heuristicCost = NullHeuristic } } if cost != nil { - costFunc = cost + sf.cost = cost } else { if g, ok := graph.(gr.Coster); ok { - costFunc = g.Cost + sf.cost = g.Cost } else { - costFunc = UniformCost + sf.cost = UniformCost } } - return + return sf } /* Purely internal data structures and functions (mostly for sorting) */