-
-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add prim algorithm to find mst (#710)
- Loading branch information
Showing
3 changed files
with
208 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
// The Prim's algorithm computes the minimum spanning tree for a weighted undirected graph | ||
// Worst Case Time Complexity: O(E log V) using Binary heap, where V is the number of vertices and E is the number of edges | ||
// Space Complexity: O(V + E) | ||
// Implementation is based on the book 'Introduction to Algorithms' (CLRS) | ||
|
||
package graph | ||
|
||
import ( | ||
"container/heap" | ||
) | ||
|
||
type minEdge []Edge | ||
|
||
func (h minEdge) Len() int { return len(h) } | ||
func (h minEdge) Less(i, j int) bool { return h[i].Weight < h[j].Weight } | ||
func (h minEdge) Swap(i, j int) { h[i], h[j] = h[j], h[i] } | ||
|
||
func (h *minEdge) Push(x interface{}) { | ||
*h = append(*h, x.(Edge)) | ||
} | ||
|
||
func (h *minEdge) Pop() interface{} { | ||
old := *h | ||
n := len(old) | ||
x := old[n-1] | ||
*h = old[0 : n-1] | ||
return x | ||
} | ||
|
||
func (g *Graph) PrimMST(start Vertex) ([]Edge, int) { | ||
var mst []Edge | ||
marked := make([]bool, g.vertices) | ||
h := &minEdge{} | ||
// Pushing neighbors of the start node to the binary heap | ||
for neighbor, weight := range g.edges[int(start)] { | ||
heap.Push(h, Edge{start, Vertex(neighbor), weight}) | ||
} | ||
marked[start] = true | ||
cost := 0 | ||
for h.Len() > 0 { | ||
e := heap.Pop(h).(Edge) | ||
end := int(e.End) | ||
// To avoid cycles | ||
if marked[end] { | ||
continue | ||
} | ||
marked[end] = true | ||
cost += e.Weight | ||
mst = append(mst, e) | ||
// Check for neighbors of the newly added edge's End vertex | ||
for neighbor, weight := range g.edges[end] { | ||
if !marked[neighbor] { | ||
heap.Push(h, Edge{e.End, Vertex(neighbor), weight}) | ||
} | ||
} | ||
} | ||
return mst, cost | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
package graph | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"testing" | ||
) | ||
|
||
func TestPrimMST(t *testing.T) { | ||
|
||
var testCases = []struct { | ||
edges []Edge | ||
vertices int | ||
start int | ||
cost int | ||
mst []Edge | ||
}{ | ||
{ | ||
edges: []Edge{ | ||
{Start: 0, End: 1, Weight: 4}, | ||
{Start: 0, End: 2, Weight: 13}, | ||
{Start: 0, End: 3, Weight: 7}, | ||
{Start: 0, End: 4, Weight: 7}, | ||
{Start: 1, End: 2, Weight: 9}, | ||
{Start: 1, End: 3, Weight: 3}, | ||
{Start: 1, End: 4, Weight: 7}, | ||
{Start: 2, End: 3, Weight: 10}, | ||
{Start: 2, End: 4, Weight: 14}, | ||
{Start: 3, End: 4, Weight: 4}, | ||
}, | ||
vertices: 5, | ||
start: 0, | ||
cost: 20, | ||
mst: []Edge{ | ||
{Start: 0, End: 1, Weight: 4}, | ||
{Start: 1, End: 3, Weight: 3}, | ||
{Start: 3, End: 4, Weight: 4}, | ||
{Start: 1, End: 2, Weight: 9}, | ||
}, | ||
}, | ||
{ | ||
edges: []Edge{ | ||
{Start: 0, End: 1, Weight: 4}, | ||
{Start: 0, End: 7, Weight: 8}, | ||
{Start: 1, End: 2, Weight: 8}, | ||
{Start: 1, End: 7, Weight: 11}, | ||
{Start: 2, End: 3, Weight: 7}, | ||
{Start: 2, End: 5, Weight: 4}, | ||
{Start: 2, End: 8, Weight: 2}, | ||
{Start: 3, End: 4, Weight: 9}, | ||
{Start: 3, End: 5, Weight: 14}, | ||
{Start: 4, End: 5, Weight: 10}, | ||
{Start: 5, End: 6, Weight: 2}, | ||
{Start: 6, End: 7, Weight: 1}, | ||
{Start: 6, End: 8, Weight: 6}, | ||
{Start: 7, End: 8, Weight: 7}, | ||
}, | ||
vertices: 9, | ||
start: 3, | ||
cost: 37, | ||
mst: []Edge{ | ||
{Start: 3, End: 2, Weight: 7}, | ||
{Start: 2, End: 8, Weight: 2}, | ||
{Start: 2, End: 5, Weight: 4}, | ||
{Start: 5, End: 6, Weight: 2}, | ||
{Start: 6, End: 7, Weight: 1}, | ||
{Start: 2, End: 1, Weight: 8}, | ||
{Start: 1, End: 0, Weight: 4}, | ||
{Start: 3, End: 4, Weight: 9}, | ||
}, | ||
}, | ||
{ | ||
edges: []Edge{ | ||
{Start: 0, End: 1, Weight: 2}, | ||
{Start: 0, End: 3, Weight: 6}, | ||
{Start: 1, End: 2, Weight: 3}, | ||
{Start: 1, End: 3, Weight: 8}, | ||
{Start: 1, End: 4, Weight: 5}, | ||
{Start: 2, End: 4, Weight: 7}, | ||
{Start: 3, End: 4, Weight: 9}, | ||
}, | ||
vertices: 5, | ||
start: 2, | ||
cost: 16, | ||
mst: []Edge{ | ||
{Start: 2, End: 1, Weight: 3}, | ||
{Start: 1, End: 0, Weight: 2}, | ||
{Start: 1, End: 4, Weight: 5}, | ||
{Start: 0, End: 3, Weight: 6}, | ||
}, | ||
}, | ||
{ | ||
edges: []Edge{ | ||
{Start: 0, End: 0, Weight: 0}, | ||
}, | ||
vertices: 1, | ||
start: 0, | ||
cost: 0, | ||
mst: nil, | ||
}, | ||
{ | ||
edges: []Edge{ | ||
{Start: 0, End: 1, Weight: 1}, | ||
{Start: 0, End: 2, Weight: 6}, | ||
{Start: 0, End: 3, Weight: 5}, | ||
{Start: 1, End: 2, Weight: 2}, | ||
{Start: 1, End: 4, Weight: 4}, | ||
{Start: 2, End: 4, Weight: 9}, | ||
}, | ||
vertices: 5, | ||
start: 4, | ||
cost: 12, | ||
mst: []Edge{ | ||
{Start: 4, End: 1, Weight: 4}, | ||
{Start: 1, End: 0, Weight: 1}, | ||
{Start: 1, End: 2, Weight: 2}, | ||
{Start: 0, End: 3, Weight: 5}, | ||
}, | ||
}, | ||
} | ||
|
||
for i, testCase := range testCases { | ||
t.Run(fmt.Sprintf("Test Case %d", i), func(t *testing.T) { | ||
// Initializing graph, adding edges | ||
graph := New(testCase.vertices) | ||
graph.Directed = false | ||
for _, edge := range testCase.edges { | ||
graph.AddWeightedEdge(int(edge.Start), int(edge.End), edge.Weight) | ||
} | ||
|
||
computedMST, computedCost := graph.PrimMST(Vertex(testCase.start)) | ||
|
||
// Compare the computed result with the expected result | ||
if computedCost != testCase.cost { | ||
t.Errorf("Test Case %d, Expected Cost: %d, Computed: %d", i, testCase.cost, computedCost) | ||
} | ||
|
||
if !reflect.DeepEqual(testCase.mst, computedMST) { | ||
t.Errorf("Test Case %d, Expected MST: %v, Computed: %v", i, testCase.mst, computedMST) | ||
} | ||
}) | ||
} | ||
} |