Skip to content

Commit

Permalink
feat: add prim algorithm to find mst (#710)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafar75 authored Mar 15, 2024
1 parent 237d88f commit 0254892
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 6 deletions.
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,13 @@ Read our [Contribution Guidelines](CONTRIBUTING.md) before you contribute.
5. [`FloydWarshall`](./graph/floydwarshall.go#L15): FloydWarshall Returns all pair's shortest path using Floyd Warshall algorithm
6. [`GetIdx`](./graph/depthfirstsearch.go#L3): No description provided.
7. [`KruskalMST`](./graph/kruskal.go#L23): No description provided.
8. [`LowestCommonAncestor`](./graph/lowestcommonancestor.go#L111): For each node, we will precompute its ancestor above him, its ancestor two nodes above, its ancestor four nodes above, etc. Let's call `jump[j][u]` is the `2^j`-th ancestor above the node `u` with `u` in range `[0, numbersVertex)`, `j` in range `[0,MAXLOG)`. These information allow us to jump from any node to any ancestor above it in `O(MAXLOG)` time.
9. [`New`](./graph/graph.go#L16): Constructor functions for graphs (undirected by default)
10. [`NewTree`](./graph/lowestcommonancestor.go#L84): No description provided.
11. [`NewUnionFind`](./graph/unionfind.go#L24): Initialise a new union find data structure with s nodes
12. [`NotExist`](./graph/depthfirstsearch.go#L12): No description provided.
13. [`Topological`](./graph/topological.go#L7): Topological assumes that graph given is valid and that its possible to get a topological ordering. constraints are array of []int{a, b}, representing an edge going from a to b
8. [`PrimMST`](./graph/prim.go#30): Computes the minimum spanning tree of a weighted undirected graph
9. [`LowestCommonAncestor`](./graph/lowestcommonancestor.go#L111): For each node, we will precompute its ancestor above him, its ancestor two nodes above, its ancestor four nodes above, etc. Let's call `jump[j][u]` is the `2^j`-th ancestor above the node `u` with `u` in range `[0, numbersVertex)`, `j` in range `[0,MAXLOG)`. These information allow us to jump from any node to any ancestor above it in `O(MAXLOG)` time.
10. [`New`](./graph/graph.go#L16): Constructor functions for graphs (undirected by default)
11. [`NewTree`](./graph/lowestcommonancestor.go#L84): No description provided.
12. [`NewUnionFind`](./graph/unionfind.go#L24): Initialise a new union find data structure with s nodes
13. [`NotExist`](./graph/depthfirstsearch.go#L12): No description provided.
14. [`Topological`](./graph/topological.go#L7): Topological assumes that graph given is valid and that its possible to get a topological ordering. constraints are array of []int{a, b}, representing an edge going from a to b

---
##### Types
Expand Down
58 changes: 58 additions & 0 deletions graph/prim.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// The Prim's algorithm computes the minimum spanning tree for a weighted undirected graph
// Worst Case Time Complexity: O(E log V) using Binary heap, where V is the number of vertices and E is the number of edges
// Space Complexity: O(V + E)
// Implementation is based on the book 'Introduction to Algorithms' (CLRS)

package graph

import (
"container/heap"
)

type minEdge []Edge

func (h minEdge) Len() int { return len(h) }
func (h minEdge) Less(i, j int) bool { return h[i].Weight < h[j].Weight }
func (h minEdge) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

func (h *minEdge) Push(x interface{}) {
*h = append(*h, x.(Edge))
}

func (h *minEdge) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

func (g *Graph) PrimMST(start Vertex) ([]Edge, int) {
var mst []Edge
marked := make([]bool, g.vertices)
h := &minEdge{}
// Pushing neighbors of the start node to the binary heap
for neighbor, weight := range g.edges[int(start)] {
heap.Push(h, Edge{start, Vertex(neighbor), weight})
}
marked[start] = true
cost := 0
for h.Len() > 0 {
e := heap.Pop(h).(Edge)
end := int(e.End)
// To avoid cycles
if marked[end] {
continue
}
marked[end] = true
cost += e.Weight
mst = append(mst, e)
// Check for neighbors of the newly added edge's End vertex
for neighbor, weight := range g.edges[end] {
if !marked[neighbor] {
heap.Push(h, Edge{e.End, Vertex(neighbor), weight})
}
}
}
return mst, cost
}
143 changes: 143 additions & 0 deletions graph/prim_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package graph

import (
"fmt"
"reflect"
"testing"
)

func TestPrimMST(t *testing.T) {

var testCases = []struct {
edges []Edge
vertices int
start int
cost int
mst []Edge
}{
{
edges: []Edge{
{Start: 0, End: 1, Weight: 4},
{Start: 0, End: 2, Weight: 13},
{Start: 0, End: 3, Weight: 7},
{Start: 0, End: 4, Weight: 7},
{Start: 1, End: 2, Weight: 9},
{Start: 1, End: 3, Weight: 3},
{Start: 1, End: 4, Weight: 7},
{Start: 2, End: 3, Weight: 10},
{Start: 2, End: 4, Weight: 14},
{Start: 3, End: 4, Weight: 4},
},
vertices: 5,
start: 0,
cost: 20,
mst: []Edge{
{Start: 0, End: 1, Weight: 4},
{Start: 1, End: 3, Weight: 3},
{Start: 3, End: 4, Weight: 4},
{Start: 1, End: 2, Weight: 9},
},
},
{
edges: []Edge{
{Start: 0, End: 1, Weight: 4},
{Start: 0, End: 7, Weight: 8},
{Start: 1, End: 2, Weight: 8},
{Start: 1, End: 7, Weight: 11},
{Start: 2, End: 3, Weight: 7},
{Start: 2, End: 5, Weight: 4},
{Start: 2, End: 8, Weight: 2},
{Start: 3, End: 4, Weight: 9},
{Start: 3, End: 5, Weight: 14},
{Start: 4, End: 5, Weight: 10},
{Start: 5, End: 6, Weight: 2},
{Start: 6, End: 7, Weight: 1},
{Start: 6, End: 8, Weight: 6},
{Start: 7, End: 8, Weight: 7},
},
vertices: 9,
start: 3,
cost: 37,
mst: []Edge{
{Start: 3, End: 2, Weight: 7},
{Start: 2, End: 8, Weight: 2},
{Start: 2, End: 5, Weight: 4},
{Start: 5, End: 6, Weight: 2},
{Start: 6, End: 7, Weight: 1},
{Start: 2, End: 1, Weight: 8},
{Start: 1, End: 0, Weight: 4},
{Start: 3, End: 4, Weight: 9},
},
},
{
edges: []Edge{
{Start: 0, End: 1, Weight: 2},
{Start: 0, End: 3, Weight: 6},
{Start: 1, End: 2, Weight: 3},
{Start: 1, End: 3, Weight: 8},
{Start: 1, End: 4, Weight: 5},
{Start: 2, End: 4, Weight: 7},
{Start: 3, End: 4, Weight: 9},
},
vertices: 5,
start: 2,
cost: 16,
mst: []Edge{
{Start: 2, End: 1, Weight: 3},
{Start: 1, End: 0, Weight: 2},
{Start: 1, End: 4, Weight: 5},
{Start: 0, End: 3, Weight: 6},
},
},
{
edges: []Edge{
{Start: 0, End: 0, Weight: 0},
},
vertices: 1,
start: 0,
cost: 0,
mst: nil,
},
{
edges: []Edge{
{Start: 0, End: 1, Weight: 1},
{Start: 0, End: 2, Weight: 6},
{Start: 0, End: 3, Weight: 5},
{Start: 1, End: 2, Weight: 2},
{Start: 1, End: 4, Weight: 4},
{Start: 2, End: 4, Weight: 9},
},
vertices: 5,
start: 4,
cost: 12,
mst: []Edge{
{Start: 4, End: 1, Weight: 4},
{Start: 1, End: 0, Weight: 1},
{Start: 1, End: 2, Weight: 2},
{Start: 0, End: 3, Weight: 5},
},
},
}

for i, testCase := range testCases {
t.Run(fmt.Sprintf("Test Case %d", i), func(t *testing.T) {
// Initializing graph, adding edges
graph := New(testCase.vertices)
graph.Directed = false
for _, edge := range testCase.edges {
graph.AddWeightedEdge(int(edge.Start), int(edge.End), edge.Weight)
}

computedMST, computedCost := graph.PrimMST(Vertex(testCase.start))

// Compare the computed result with the expected result
if computedCost != testCase.cost {
t.Errorf("Test Case %d, Expected Cost: %d, Computed: %d", i, testCase.cost, computedCost)
}

if !reflect.DeepEqual(testCase.mst, computedMST) {
t.Errorf("Test Case %d, Expected MST: %v, Computed: %v", i, testCase.mst, computedMST)
}
})
}
}

0 comments on commit 0254892

Please sign in to comment.