Skip to content

Commit 0254892

Browse files
authored
feat: add prim algorithm to find mst (#710)
1 parent 237d88f commit 0254892

File tree

3 files changed

+208
-6
lines changed

3 files changed

+208
-6
lines changed

README.md

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -455,12 +455,13 @@ Read our [Contribution Guidelines](CONTRIBUTING.md) before you contribute.
455455
5. [`FloydWarshall`](./graph/floydwarshall.go#L15): FloydWarshall Returns all pair's shortest path using Floyd Warshall algorithm
456456
6. [`GetIdx`](./graph/depthfirstsearch.go#L3): No description provided.
457457
7. [`KruskalMST`](./graph/kruskal.go#L23): No description provided.
458-
8. [`LowestCommonAncestor`](./graph/lowestcommonancestor.go#L111): For each node, we will precompute its ancestor above him, its ancestor two nodes above, its ancestor four nodes above, etc. Let's call `jump[j][u]` is the `2^j`-th ancestor above the node `u` with `u` in range `[0, numbersVertex)`, `j` in range `[0,MAXLOG)`. These information allow us to jump from any node to any ancestor above it in `O(MAXLOG)` time.
459-
9. [`New`](./graph/graph.go#L16): Constructor functions for graphs (undirected by default)
460-
10. [`NewTree`](./graph/lowestcommonancestor.go#L84): No description provided.
461-
11. [`NewUnionFind`](./graph/unionfind.go#L24): Initialise a new union find data structure with s nodes
462-
12. [`NotExist`](./graph/depthfirstsearch.go#L12): No description provided.
463-
13. [`Topological`](./graph/topological.go#L7): Topological assumes that graph given is valid and that its possible to get a topological ordering. constraints are array of []int{a, b}, representing an edge going from a to b
458+
8. [`PrimMST`](./graph/prim.go#30): Computes the minimum spanning tree of a weighted undirected graph
459+
9. [`LowestCommonAncestor`](./graph/lowestcommonancestor.go#L111): For each node, we will precompute its ancestor above him, its ancestor two nodes above, its ancestor four nodes above, etc. Let's call `jump[j][u]` is the `2^j`-th ancestor above the node `u` with `u` in range `[0, numbersVertex)`, `j` in range `[0,MAXLOG)`. These information allow us to jump from any node to any ancestor above it in `O(MAXLOG)` time.
460+
10. [`New`](./graph/graph.go#L16): Constructor functions for graphs (undirected by default)
461+
11. [`NewTree`](./graph/lowestcommonancestor.go#L84): No description provided.
462+
12. [`NewUnionFind`](./graph/unionfind.go#L24): Initialise a new union find data structure with s nodes
463+
13. [`NotExist`](./graph/depthfirstsearch.go#L12): No description provided.
464+
14. [`Topological`](./graph/topological.go#L7): Topological assumes that graph given is valid and that its possible to get a topological ordering. constraints are array of []int{a, b}, representing an edge going from a to b
464465

465466
---
466467
##### Types

graph/prim.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// The Prim's algorithm computes the minimum spanning tree for a weighted undirected graph
2+
// Worst Case Time Complexity: O(E log V) using Binary heap, where V is the number of vertices and E is the number of edges
3+
// Space Complexity: O(V + E)
4+
// Implementation is based on the book 'Introduction to Algorithms' (CLRS)
5+
6+
package graph
7+
8+
import (
9+
"container/heap"
10+
)
11+
12+
type minEdge []Edge
13+
14+
func (h minEdge) Len() int { return len(h) }
15+
func (h minEdge) Less(i, j int) bool { return h[i].Weight < h[j].Weight }
16+
func (h minEdge) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
17+
18+
func (h *minEdge) Push(x interface{}) {
19+
*h = append(*h, x.(Edge))
20+
}
21+
22+
func (h *minEdge) Pop() interface{} {
23+
old := *h
24+
n := len(old)
25+
x := old[n-1]
26+
*h = old[0 : n-1]
27+
return x
28+
}
29+
30+
func (g *Graph) PrimMST(start Vertex) ([]Edge, int) {
31+
var mst []Edge
32+
marked := make([]bool, g.vertices)
33+
h := &minEdge{}
34+
// Pushing neighbors of the start node to the binary heap
35+
for neighbor, weight := range g.edges[int(start)] {
36+
heap.Push(h, Edge{start, Vertex(neighbor), weight})
37+
}
38+
marked[start] = true
39+
cost := 0
40+
for h.Len() > 0 {
41+
e := heap.Pop(h).(Edge)
42+
end := int(e.End)
43+
// To avoid cycles
44+
if marked[end] {
45+
continue
46+
}
47+
marked[end] = true
48+
cost += e.Weight
49+
mst = append(mst, e)
50+
// Check for neighbors of the newly added edge's End vertex
51+
for neighbor, weight := range g.edges[end] {
52+
if !marked[neighbor] {
53+
heap.Push(h, Edge{e.End, Vertex(neighbor), weight})
54+
}
55+
}
56+
}
57+
return mst, cost
58+
}

graph/prim_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package graph
2+
3+
import (
4+
"fmt"
5+
"reflect"
6+
"testing"
7+
)
8+
9+
func TestPrimMST(t *testing.T) {
10+
11+
var testCases = []struct {
12+
edges []Edge
13+
vertices int
14+
start int
15+
cost int
16+
mst []Edge
17+
}{
18+
{
19+
edges: []Edge{
20+
{Start: 0, End: 1, Weight: 4},
21+
{Start: 0, End: 2, Weight: 13},
22+
{Start: 0, End: 3, Weight: 7},
23+
{Start: 0, End: 4, Weight: 7},
24+
{Start: 1, End: 2, Weight: 9},
25+
{Start: 1, End: 3, Weight: 3},
26+
{Start: 1, End: 4, Weight: 7},
27+
{Start: 2, End: 3, Weight: 10},
28+
{Start: 2, End: 4, Weight: 14},
29+
{Start: 3, End: 4, Weight: 4},
30+
},
31+
vertices: 5,
32+
start: 0,
33+
cost: 20,
34+
mst: []Edge{
35+
{Start: 0, End: 1, Weight: 4},
36+
{Start: 1, End: 3, Weight: 3},
37+
{Start: 3, End: 4, Weight: 4},
38+
{Start: 1, End: 2, Weight: 9},
39+
},
40+
},
41+
{
42+
edges: []Edge{
43+
{Start: 0, End: 1, Weight: 4},
44+
{Start: 0, End: 7, Weight: 8},
45+
{Start: 1, End: 2, Weight: 8},
46+
{Start: 1, End: 7, Weight: 11},
47+
{Start: 2, End: 3, Weight: 7},
48+
{Start: 2, End: 5, Weight: 4},
49+
{Start: 2, End: 8, Weight: 2},
50+
{Start: 3, End: 4, Weight: 9},
51+
{Start: 3, End: 5, Weight: 14},
52+
{Start: 4, End: 5, Weight: 10},
53+
{Start: 5, End: 6, Weight: 2},
54+
{Start: 6, End: 7, Weight: 1},
55+
{Start: 6, End: 8, Weight: 6},
56+
{Start: 7, End: 8, Weight: 7},
57+
},
58+
vertices: 9,
59+
start: 3,
60+
cost: 37,
61+
mst: []Edge{
62+
{Start: 3, End: 2, Weight: 7},
63+
{Start: 2, End: 8, Weight: 2},
64+
{Start: 2, End: 5, Weight: 4},
65+
{Start: 5, End: 6, Weight: 2},
66+
{Start: 6, End: 7, Weight: 1},
67+
{Start: 2, End: 1, Weight: 8},
68+
{Start: 1, End: 0, Weight: 4},
69+
{Start: 3, End: 4, Weight: 9},
70+
},
71+
},
72+
{
73+
edges: []Edge{
74+
{Start: 0, End: 1, Weight: 2},
75+
{Start: 0, End: 3, Weight: 6},
76+
{Start: 1, End: 2, Weight: 3},
77+
{Start: 1, End: 3, Weight: 8},
78+
{Start: 1, End: 4, Weight: 5},
79+
{Start: 2, End: 4, Weight: 7},
80+
{Start: 3, End: 4, Weight: 9},
81+
},
82+
vertices: 5,
83+
start: 2,
84+
cost: 16,
85+
mst: []Edge{
86+
{Start: 2, End: 1, Weight: 3},
87+
{Start: 1, End: 0, Weight: 2},
88+
{Start: 1, End: 4, Weight: 5},
89+
{Start: 0, End: 3, Weight: 6},
90+
},
91+
},
92+
{
93+
edges: []Edge{
94+
{Start: 0, End: 0, Weight: 0},
95+
},
96+
vertices: 1,
97+
start: 0,
98+
cost: 0,
99+
mst: nil,
100+
},
101+
{
102+
edges: []Edge{
103+
{Start: 0, End: 1, Weight: 1},
104+
{Start: 0, End: 2, Weight: 6},
105+
{Start: 0, End: 3, Weight: 5},
106+
{Start: 1, End: 2, Weight: 2},
107+
{Start: 1, End: 4, Weight: 4},
108+
{Start: 2, End: 4, Weight: 9},
109+
},
110+
vertices: 5,
111+
start: 4,
112+
cost: 12,
113+
mst: []Edge{
114+
{Start: 4, End: 1, Weight: 4},
115+
{Start: 1, End: 0, Weight: 1},
116+
{Start: 1, End: 2, Weight: 2},
117+
{Start: 0, End: 3, Weight: 5},
118+
},
119+
},
120+
}
121+
122+
for i, testCase := range testCases {
123+
t.Run(fmt.Sprintf("Test Case %d", i), func(t *testing.T) {
124+
// Initializing graph, adding edges
125+
graph := New(testCase.vertices)
126+
graph.Directed = false
127+
for _, edge := range testCase.edges {
128+
graph.AddWeightedEdge(int(edge.Start), int(edge.End), edge.Weight)
129+
}
130+
131+
computedMST, computedCost := graph.PrimMST(Vertex(testCase.start))
132+
133+
// Compare the computed result with the expected result
134+
if computedCost != testCase.cost {
135+
t.Errorf("Test Case %d, Expected Cost: %d, Computed: %d", i, testCase.cost, computedCost)
136+
}
137+
138+
if !reflect.DeepEqual(testCase.mst, computedMST) {
139+
t.Errorf("Test Case %d, Expected MST: %v, Computed: %v", i, testCase.mst, computedMST)
140+
}
141+
})
142+
}
143+
}

0 commit comments

Comments
 (0)