PS/Algorithm

그래프 알고리즘 - (4) Union-find(disjoint set)

hyelie 2022. 6. 22. 12:59
이 글은 포스텍 오은진 교수님의 알고리즘(CSED331) 강의를 기반으로 재구성한 것입니다.

 

1. Disjoint Set과 Union-Find 

Disjoint Set

 disjoint set은 서로 중복되지 않은 부분 집합들로 이루어진 자료구조, "서로소 집합"이라고 이해하면 될 것 같다.

 

 

Union-Find

 union-find는 disjoint set의 표현 및 연산에 사용하는 알고리즘이다. 보통 Directed Tree를 이용해 구현하며, tree의 root는 해당 set의 대표값으로 생각하면 된다. 아래 예시에서는 {1, 2, 3, 4,}, {5, 6, 7, 8}, {9}가 각각의 set이며, 첫 번째 set의 대표값은 4, 두 번째 set의 대표값은 5, 세 번째 set의 대표값은 9이며 이 값들을 tree의 root로 생각하겠다는 것이다.

 

set을 directed tree로 생각하자.


 이렇게 설정했을 때 아래 3가지의 연산이 존재한다.

  • makeset(x) : x를 포함하는 새로운 set을 만든다.
  • find(x) : x가 속해있는 set의 root를 리턴한다.
  • union(x, y) : x가 속해있는 set과 y가 속해있는 set을 하나로 합친다.

 tree로 구현하는 이유를 생각해 보자. tree로 구현하면 find(x) 함수는 O(x가 속한 tree의 height), union 함수는 O(x가 속한 tree의 height + y가 속한 tree의 height)이다. 반면 배열로 구현할 경우 find(x)는 O(1)이지만 union 함수가 O(n)이 된다. union-find의 특성상 union(x, y)를 많이 사용하기 때문에 tree로 구현하는 것이 더 효율적이며, 추후 path compression을 사용하면 tree가 월등히 더 빨라진다.

 또한, tree를 이용할 경우 tree의 root를 해당 set의 대표값으로 이용한다는 생각에서 나온 것이 union-find를 tree로 이용한다는 접근이다.

 

 

Pseudo Code

procedure makeset(s)
    parent(x) = x
    rank(x) = 0

procedure find(x)
    while x != parent(x) :
        x = parent(x)
    end while
    return x

procedure union(x, y)
    rx = find(x)
    ry = find(y)

    if rx == ry :
        return
    end if

    if rank(rx) < rank(ry) :
        parent(rx) = ry
    end if
    else :
        parent(ry) = rx
        if rank(rx) == rank(ry) :
            rank(rx) = rank(rx) + 1
        end if
    end else

 위 pseudo code에는 앞에서 말하지 않은 rank라는 함수가 있다. rank(x)는 최적화를 위해 들어간 함수이며 x가 속해있는 set(subtree)의 height를 리턴한다. 그렇다면 rank(x)가 어떻게 최적화를 할까?

 union(x, y)에서 보면 rank가 더 작은 것을 rank가 더 큰 것에 합병시킨다. 만약 rank가 같다면 합치고 기존 rank에 1을 더해준다. 돌려 말하자면, height가 더 작은 tree를 height가 더 큰 tree에 merge시키는 것이다. 이렇게 하지 않으면 tree height가 worst O(n-1)이 될 수 있기 때문에 이런 방법을 추가한다.

 

 

Property of Union-Find

root가 아닌 x에 대해 rank(x) < rank(parent(x))이다.

  • 자명하다. parent의 height가 더 클 것이다.

x가 root가 아니면 rank(x)는 불변이다.

  • rank(x)가 바뀌는 경우는 union(x, y)를 할 때 뿐이며, 그마저도 root의 rank만 변경한다. 즉, child node의 rank는 변하지 않는다.

rank가 k인 모든 root node의 subtree에는 적어도 2^k개의 node가 있다.

  • induction으로 증명. k=0에서 1개의 node가 있다.
  • k-1에서 성립 가정, rank k인 node는 rank k-1개 2개를 합쳐야 만들 수 있다. rank k-1인 각각의 subtree는 2^(k-1)개의 node를 가지므로 2 * 2^(k-1) = 2^k이므로 성립이다.

총 n개의 node가 있을 때 rank k인 node는 최대 n/(2^k)개 존재한다.

 

즉 max rank는 logn이기 때문에 tree height <= logn이다.

  • 바로 위의 property로부터 유도된다.
  • 즉슨, tree height가 logn보다 작거나 같기 때문에 union(x, y)와 find(x)의 시간복잡도가 O(logn)이다.

 

 

 

2. Path Compression

 union(x, y)와 find(x)의 시간복잡도는 tree height에 비례하기 때문에 tree height를 짧게 유지하면 더 좋은 성능의 알고리즘을 만들 수 있다. 그 방법을 path compression이라 부르고 find(x)를 할 때 찾아지는 모든 node를 root 바로 밑에 붙일 것이다.

procedure find(x) :
    if x == parent(x) :
        return x
    end if
    else :
        parent(x) = find(parent(x)) // x의 parent를 x의 root로 변경
    end else
    return parent(x)

 

 

Property of Union-Find After Path Compression

root가 아닌 x에 대해 rank(x) < rank(parent(x))이다.

  • 자명하다. parent의 height가 더 클 것이다.

x가 root가 아니면 rank(x)는 불변이다.

  • rank(x)가 바뀌는 경우는 union(x, y)를 할 때 뿐이며, 그마저도 root의 rank만 변경한다. 즉, child node의 rank는 변하지 않는다.

rank가 k인 모든 root node의 subtree에는 적어도 2^k개의 node가 있다.

  • induction으로 증명. k=0에서 1개의 node가 있다.
  • k-1에서 성립 가정, rank k인 node는 rank k-1개 2개를 합쳐야 만들 수 있다. rank k-1인 각각의 subtree는 2^(k-1)개의 node를 가지므로 2 * 2^(k-1) = 2^k이므로 성립이다.

총 n개의 node가 있을 때 rank k인 node는 최대 n/(2^k)개 존재한다.

 

즉 max rank는 logn이기 때문에 tree height <= logn이다.

  • 바로 위의 property로부터 유도된다.
  • 즉슨, tree height가 logn보다 작거나 같기 때문에 union(x, y)와 find(x)의 시간복잡도가 O(logn)이다.

아래 2개의 property는 사라진다.

 

Time Complexity

 path compression 이후 시간복잡도를 계산해 보자. 이를 위해 log*n이라는 새 수를 정의하자. log*n은 n을 1까지 끌어내리는 데 필요한 log 연산의 개수이다.

 

$\log ^*n\begin{cases}0&if\ n\ \le \ 1\\1+\log ^*\left(\log _2n\right)&if\ n\ >\ 1\end{cases}$

 

0보다 크거나 같은 정수 k에 대해 $0 <= rank(x) <= log n$인 각각의 rank는 {k+1, k+2, ... , $2^{k}$}의 형태를 가진다. 

k = 0일 때 {1}

k = 1일 때 {2}

k = 2일 때 {3, 4}

k = 4일 때 {5, 6, ... , 16}

k = 16일 때 {17, 18, ... , $2^{16}$}

k = $2^{16}$ (= 65536)일 때 {65537, 65538, ... , $2^{65536}$}

, ...


 이제 edge를 아래와 같이 분류해 보자.
 type-1 edge는 다른 group 사이의 edge, type-2 edge는 같은 group의 edge이다. 그러면 union(x, y)와 find(x)를 m번 호출한다고 했을 때

전체 edge 수
 = type 1 edge(find m번으로 지워지는 type 1 edge 개수 == worst log*n)
    + sum of (x에 인접한 type-2 edge)
 = sum of(m번의 find call로 불러진 type-1 edge) + sum of (x에 인접한 type-2 edge)
 = O(m log*n) + sum of (x에 인접한 type-2 edge)
이다. 

 sum of (x에 인접한 type-2 edge)를 조금 자세히 구해보자. 모든 subtree의 element 개수의 상한은 2^k이다. rank k인 node 수는 property에 의해 n/(2^k)이므로 O(n log*n)이다.

 보통 vertex보다 edge 수가 더 많으므로 O(m log*n)이 된다.

 

∴ m번의 find(x), union(x, y) 연산에 걸리는 시간은 O(m log*n)이다.
1번의 연산에 걸리는 평균 시간이 O(log*n)이다.

 

소스코드 : O(log*n)

 제일 중요한 구현 방법이다. 다만 유의할 점, c++에서 find와 union, parent과 rank 이미 사용되고 있는 함수/변수이므로 앞의 것을 대문자로 바꿔 사용하자.

int Parent[1000001];
int Rank[1000001];

void Makeset(int s){
    Parent[s] = s;
    Rank[s] = 0;
    return;
}

int Find(int x){
    if(x == Parent[x]) return x;
    
    Parent[x] = Find(Parent[x]); // x의 parent를 x의 root로 설정
    return Parent[x];
}

void Union(int x, int y){
    int rx = Find(x);
    int ry = Find(y);
    if(rx == ry) return;
    
    if(Rank[rx] < Rank[ry]) // ry의 height가 더 크면 rx가 밑으로 글어가야 함
        Parent[rx] = ry;
    else{ // 그렇지 않다면 ry가 밑으로 들어가야 함
        Parent[ry] = rx;
        if(Rank[rx] == Rank[ry]){ // rank가 같으면 rank 조절
            Rank[rx]++;
        }
    }
}