From 495cff8b21625552eb173351fb99345ae49682a6 Mon Sep 17 00:00:00 2001 From: Y L <86551259+soondubu137@users.noreply.github.com> Date: Sun, 8 Sep 2024 15:41:20 -0500 Subject: [PATCH] fix: added path compression to UnionFind (#734) * fix: implemented path compression in Find & removed unnecessary return value of Union * feat: added a few test cases * fix: modified kruskal implementation to conform to the updated Union method * fix: changed to pointer receivers --------- Co-authored-by: Rak Laptudirm --- graph/kruskal.go | 2 +- graph/unionfind.go | 45 ++++++++++++++++++++++------------------- graph/unionfind_test.go | 23 ++++++++++++++------- 3 files changed, 41 insertions(+), 29 deletions(-) diff --git a/graph/kruskal.go b/graph/kruskal.go index 193e3f236..1b0328da3 100644 --- a/graph/kruskal.go +++ b/graph/kruskal.go @@ -42,7 +42,7 @@ func KruskalMST(n int, edges []Edge) ([]Edge, int) { // Add the weight of the edge to the total cost cost += edge.Weight // Merge the sets containing the start and end vertices of the current edge - u = u.Union(int(edge.Start), int(edge.End)) + u.Union(int(edge.Start), int(edge.End)) } } diff --git a/graph/unionfind.go b/graph/unionfind.go index 7a922f3cc..42714ab39 100644 --- a/graph/unionfind.go +++ b/graph/unionfind.go @@ -3,12 +3,13 @@ // is used to efficiently maintain connected components in a graph that undergoes dynamic changes, // such as edges being added or removed over time // Worst Case Time Complexity: The time complexity of find operation is nearly constant or -//O(α(n)), where where α(n) is the inverse Ackermann function +//O(α(n)), where α(n) is the inverse Ackermann function // practically, this is a very slowly growing function making the time complexity for find //operation nearly constant. // The time complexity of the union operation is also nearly constant or O(α(n)) // Worst Case Space Complexity: O(n), where n is the number of nodes or element in the structure // Reference: https://www.scaler.com/topics/data-structures/disjoint-set/ +// https://en.wikipedia.org/wiki/Disjoint-set_data_structure // Author: Mugdha Behere[https://github.com/MugdhaBehere] // see: unionfind.go, unionfind_test.go @@ -17,43 +18,45 @@ package graph // Defining the union-find data structure type UnionFind struct { parent []int - size []int + rank []int } // Initialise a new union find data structure with s nodes func NewUnionFind(s int) UnionFind { parent := make([]int, s) - size := make([]int, s) - for k := 0; k < s; k++ { - parent[k] = k - size[k] = 1 + rank := make([]int, s) + for i := 0; i < s; i++ { + parent[i] = i + rank[i] = 1 } - return UnionFind{parent, size} + return UnionFind{parent, rank} } -// to find the root of the set to which the given element belongs, the Find function serves the purpose -func (u UnionFind) Find(q int) int { - for q != u.parent[q] { - q = u.parent[q] +// Find finds the root of the set to which the given element belongs. +// It performs path compression to make future Find operations faster. +func (u *UnionFind) Find(q int) int { + if q != u.parent[q] { + u.parent[q] = u.Find(u.parent[q]) } - return q + return u.parent[q] } -// to merge two sets to which the given elements belong, the Union function serves the purpose -func (u UnionFind) Union(a, b int) UnionFind { - rootP := u.Find(a) - rootQ := u.Find(b) +// Union merges the sets, if not already merged, to which the given elements belong. +// It performs union by rank to keep the tree as flat as possible. +func (u *UnionFind) Union(p, q int) { + rootP := u.Find(p) + rootQ := u.Find(q) if rootP == rootQ { - return u + return } - if u.size[rootP] < u.size[rootQ] { + if u.rank[rootP] < u.rank[rootQ] { u.parent[rootP] = rootQ - u.size[rootQ] += u.size[rootP] + } else if u.rank[rootP] > u.rank[rootQ] { + u.parent[rootQ] = rootP } else { u.parent[rootQ] = rootP - u.size[rootP] += u.size[rootQ] + u.rank[rootP]++ } - return u } diff --git a/graph/unionfind_test.go b/graph/unionfind_test.go index b95547649..35eea59d4 100644 --- a/graph/unionfind_test.go +++ b/graph/unionfind_test.go @@ -8,10 +8,10 @@ func TestUnionFind(t *testing.T) { u := NewUnionFind(10) // Creating a Union-Find data structure with 10 elements //union operations - u = u.Union(0, 1) - u = u.Union(2, 3) - u = u.Union(4, 5) - u = u.Union(6, 7) + u.Union(0, 1) + u.Union(2, 3) + u.Union(4, 5) + u.Union(6, 7) // Testing the parent of specific elements t.Run("Test Find", func(t *testing.T) { @@ -20,12 +20,21 @@ func TestUnionFind(t *testing.T) { } }) - u = u.Union(1, 5) // Additional union operation - u = u.Union(3, 7) // Additional union operation + u.Union(1, 5) // Additional union operation + u.Union(3, 7) // Additional union operation // Testing the parent of specific elements after more union operations t.Run("Test Find after Union", func(t *testing.T) { - if u.Find(0) != u.Find(5) || u.Find(2) != u.Find(7) { + if u.Find(0) != u.Find(5) || u.Find(1) != u.Find(4) || u.Find(2) != u.Find(7) || u.Find(3) != u.Find(6) { + t.Error("Union operation not functioning correctly") + } + }) + + u.Union(3, 7) // Repeated union operation + + // Testing that repeated union operations are idempotent + t.Run("Test Find after repeated Union", func(t *testing.T) { + if u.Find(2) != u.Find(6) || u.Find(2) != u.Find(7) || u.Find(3) != u.Find(6) || u.Find(3) != u.Find(7) { t.Error("Union operation not functioning correctly") } })