Introduction to Union-Find
Union-Find (Disjoint Set Union) is a data structure that efficiently manages a partition of elements into disjoint sets. It supports two operations: finding which set an element belongs to, and uniting two sets.
Core Operations
Find
Determine which set an element belongs to.
Union
Merge two sets into one.
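Both operations can be illustrated on a bare parent array, before any optimization is applied. The sketch below is a minimal, unoptimized version; the function names `naive_find` and `naive_union` are hypothetical and chosen only for this illustration.

```python
def naive_find(parent, x):
    """Follow parent pointers until reaching a root (a node that is its own parent)."""
    while parent[x] != x:
        x = parent[x]
    return x


def naive_union(parent, x, y):
    """Link the root of x's set under the root of y's set."""
    root_x, root_y = naive_find(parent, x), naive_find(parent, y)
    if root_x != root_y:
        parent[root_x] = root_y


parent = list(range(4))      # four singleton sets: {0}, {1}, {2}, {3}
naive_union(parent, 0, 1)    # {0, 1}, {2}, {3}
naive_union(parent, 2, 3)    # {0, 1}, {2, 3}
print(naive_find(parent, 0) == naive_find(parent, 1))  # True
print(naive_find(parent, 1) == naive_find(parent, 2))  # False
```

Each set is a tree of parent pointers, and the root serves as the set's representative.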
Basic Implementation
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        """Find root of x's set"""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # Path compression
        return self.parent[x]

    def union(self, x, y):
        """Merge sets containing x and y"""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Union by rank
        if self.rank[root_x] < self.rank[root_y]:
            self.parent[root_x] = root_y
        elif self.rank[root_x] > self.rank[root_y]:
            self.parent[root_y] = root_x
        else:
            self.parent[root_y] = root_x
            self.rank[root_x] += 1
        return True

    def connected(self, x, y):
        """Check if x and y are in same set"""
        return self.find(x) == self.find(y)
Optimizations
Path Compression
During find, make every node on the path from x to the root point directly to the root, so later lookups on those nodes take a single step.
def find(self, x):
    if self.parent[x] != x:
        self.parent[x] = self.find(self.parent[x])  # Compress path
    return self.parent[x]
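The recursive form relies on the call stack. As an alternative sketch, the same compression can be done iteratively in two passes; the standalone function `find_iterative` and its `parent` list argument are assumptions made for this illustration.

```python
def find_iterative(parent, x):
    """Return the root of x, pointing every node on the path directly at it."""
    # Pass 1: walk up to the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: re-point each node on the path to the root.
    while parent[x] != root:
        next_x = parent[x]
        parent[x] = root
        x = next_x
    return root


parent = [0, 0, 1, 2, 3]           # a chain: 4 -> 3 -> 2 -> 1 -> 0
print(find_iterative(parent, 4))   # 0
print(parent)                      # [0, 0, 0, 0, 0]
```

After one call on the deepest node, the whole chain is flattened to depth one.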
Union by Rank
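Attach the root with the lower rank under the root with the higher rank. Rank is an upper bound on a tree's height, so this keeps trees shallow.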
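def union(self, x, y):
    root_x, root_y = self.find(x), self.find(y)
    if root_x == root_y:
        return  # Already in the same set; nothing to merge
    # Attach the root of lower rank under the root of higher rank
    if self.rank[root_x] < self.rank[root_y]:
        self.parent[root_x] = root_y
    else:
        self.parent[root_y] = root_x
        # Equal ranks: the merged tree is one level taller
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1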
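A closely related heuristic is union by size, which tracks the number of elements in each tree instead of its rank. The sketch below is a swapped-in variant, not the rank-based method above; the class name `UnionFindBySize` and the `size` list are assumptions for illustration.

```python
class UnionFindBySize:
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is meaningful only while r is a root

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Make root_x the larger tree, then attach root_y under it.
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        return True


uf = UnionFindBySize(5)
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
print(uf.size[uf.find(0)])  # 3
print(uf.size[uf.find(4)])  # 2
```

Tracking sizes has the side benefit that the size of any element's set can be read off its root.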
Time Complexity
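- Without optimizations: O(n) per operation in the worst case, because a tree can degenerate into a chain
- With path compression + union by rank: O(α(n)) amortized per operation
- α(n) is the inverse Ackermann function, which grows so slowly that it is effectively constant for any practical n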
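The gap between these bounds can be seen by measuring tree depth directly. The sketch below runs the same union sequence twice, once with naive linking and once with union by rank; path compression is deliberately left out so the tree shape stays visible. The helper names `find_root` and `max_depth` are hypothetical, and the depths shown apply to this particular sequence only.

```python
def find_root(parent, x):
    """Plain find with no compression; also returns the number of links followed."""
    steps = 0
    while parent[x] != x:
        x = parent[x]
        steps += 1
    return x, steps


def max_depth(parent):
    return max(find_root(parent, x)[1] for x in range(len(parent)))


n = 1024

# Naive union: always link the first root under the second.
naive = list(range(n))
for i in range(n - 1):
    root_a, _ = find_root(naive, i)
    root_b, _ = find_root(naive, i + 1)
    naive[root_a] = root_b

# Union by rank: link the lower-rank root under the higher-rank root.
ranked = list(range(n))
rank = [0] * n
for i in range(n - 1):
    root_a, _ = find_root(ranked, i)
    root_b, _ = find_root(ranked, i + 1)
    if rank[root_a] < rank[root_b]:
        root_a, root_b = root_b, root_a
    ranked[root_b] = root_a
    if rank[root_a] == rank[root_b]:
        rank[root_a] += 1

print(max_depth(naive))   # 1023
print(max_depth(ranked))  # 1
```

Naive linking builds a chain of n - 1 links here, while union by rank alone never lets a tree grow taller than log2(n).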
When to Use Union-Find
- Connected components in graphs
- Cycle detection in undirected graphs
- Minimum spanning trees (Kruskal's algorithm)
- Network connectivity
- Dynamic connectivity problems
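As one illustration of the uses listed above, Kruskal's algorithm can be sketched with an inline union-find. This is a simplified sketch under stated assumptions: the function name `kruskal` and the `(weight, u, v)` edge format are chosen for the example, path halving stands in for full path compression, and union by rank is omitted for brevity.

```python
def kruskal(n, weighted_edges):
    """Return (total_weight, chosen_edges) of a minimum spanning forest.

    weighted_edges is a list of (weight, u, v) tuples over nodes 0..n-1.
    """
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving: skip to the grandparent
            x = parent[x]
        return x

    total, chosen = 0, []
    for weight, u, v in sorted(weighted_edges):  # cheapest edges first
        root_u, root_v = find(u), find(v)
        if root_u != root_v:  # edge joins two components: keep it
            parent[root_u] = root_v
            total += weight
            chosen.append((u, v))
        # Otherwise the edge would close a cycle, so it is skipped.
    return total, chosen


edges = [(1, 0, 1), (4, 0, 2), (2, 1, 2), (7, 1, 3), (3, 2, 3)]
print(kruskal(4, edges))  # (6, [(0, 1), (1, 2), (2, 3)])
```

The same root comparison that rejects an edge here is what detects a cycle in an undirected graph.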
Example: Number of Connected Components
def countComponents(n, edges):
    """Count connected components in undirected graph"""
    uf = UnionFind(n)
    for u, v in edges:
        uf.union(u, v)
    # Count unique roots
    return len(set(uf.find(i) for i in range(n)))

# Example
print(countComponents(5, [[0,1], [1,2], [3,4]]))
# 2 components: {0,1,2} and {3,4}
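An alternative to collecting unique roots at the end is to keep a running count that drops by one on every successful merge. The sketch below is self-contained and uses an inline parent list instead of the class; the function name `count_components_running` is hypothetical, and union by rank is omitted for brevity.

```python
def count_components_running(n, edges):
    """Count components by decrementing a counter on each successful merge."""
    parent = list(range(n))

    def find(x):
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    components = n  # every node starts in its own set
    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            components -= 1  # each successful merge removes one set
    return components


print(count_components_running(5, [[0, 1], [1, 2], [3, 4]]))  # 2
print(count_components_running(4, [[0, 1], [1, 0]]))          # 3
```

The counter avoids the final pass over all n nodes and gives the component count after every edge, which suits problems where edges arrive one at a time.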
Example: Graph Valid Tree
def validTree(n, edges):
    """Check if edges form a valid tree"""
    # Tree has n-1 edges and no cycles
    if len(edges) != n - 1:
        return False
    uf = UnionFind(n)
    for u, v in edges:
        if not uf.union(u, v):
            return False  # Cycle detected
    return True

print(validTree(5, [[0,1], [0,2], [0,3], [1,4]]))
# True - forms a tree
Key Takeaway
Union-Find is the go-to data structure for managing disjoint sets and answering connectivity queries. With path compression and union by rank, each operation runs in amortized O(α(n)) time, which is effectively constant even for very large inputs.