Leetcode 2316. Count Unreachable Pairs of Nodes in an Undirected Graph
이 문제를 처음 딱 봤을 때, disjoint set의 개수를 구하고, 이 배열을 arr이라고 했을 때, index i에 대해 [arr[i] * sum of arr[i+1] to end]를 구하면 되는 것을 알 수 있다.
disjoint set의 개수를 구하는 것은 DFS, 또는 union-find를 이용해 쉽게 구할 수 있다. 그렇지만 index i에 대해 [arr[i] * sum of arr[i+1] to end]를 어떻게 쉽게 구할 수 있을까?
첫 번째 접근 : DFS + prefix sum
나는 prefix sum을 떠올렸다. i > j일 때 psum[i] - psum[j] = arr[j+1] + arr[j+2] + ... + arr[i]이다. 이것을 O(1)에 구할 수 있기 때문에 prefix sum을 이용해 쉽게 구할 수 있었다.
// Runtime 564 ms Beats 82.93%
// Memory 185.2 MB Beats 59.15%
typedef long long ll;
class Solution {
public:
vector<bool> visited;
vector<vector<int>> edges;
int DFS(int v){
int num = 1;
for(int w : edges[v]){
if(!visited[w]){
visited[w] = true;
num += DFS(w);
}
}
return num;
}
long long countPairs(int n, vector<vector<int>>& adjacents) {
// init
visited.resize(n);
fill(visited.begin(), visited.end(), false);
edges.resize(n);
for(vector<int>& adjacent : adjacents){
edges[adjacent[0]].push_back(adjacent[1]);
edges[adjacent[1]].push_back(adjacent[0]);
}
vector<int> counts; // counts : disjoint set의 vertex 개수 vector
for(int i = 0; i<n; i++){
if(!visited[i]){
visited[i] = true;
counts.push_back(DFS(i));
}
}
// psum 쓰면 될 것 같은데.
int csize = counts.size();
vector<ll> psum(csize, 0); // psum[i] : 0부터 i까지 sum
psum[0] = counts[0];
for(int i = 1; i<csize; i++){
psum[i] = (ll)counts[i] + (ll)psum[i-1];
}
// psum[i] - psum[j] : j+1 ~ i까지 sum
ll answer = 0;
for(int i = 0; i<csize-1; i++){
answer += (ll)(psum[csize-1] - psum[i]) * (ll)counts[i];
}
return answer;
}
};
시간복잡도
DFS에 O(V + E). 이 때 V = n이고 E = worst case 2n. 그리고 counts와 psum 배열은 worst case size n이므로 O(n)이다.
공간복잡도
counts, psum 배열은 O(n), stack은 worst case O(V)만큼 쌓인다. 따라서 O(n)
두 번째 접근 : DFS
solution을 보고 깨달은 것. 굳이 prefix sum을 사용할 필요가 없다.(공간 낭비이다.)
- 전체 배열의 sum을 저장한다.
- 모든 i에 대해, 현재 index를 sum에서 뺀다.
- i = 0일 때는 i = 1부터 끝까지,
- i = 1일 때는 i = 0이 이미 sum에서 빠진 상태, 따라서 i=1의 값을 sum에서 빼면 i=2부터 끝까지,
- ...
- 이 방법으로 prefix sum 없이 [index i 이후에 오는 모든 배열의 합]을 O(1)로 구할 수 있다.
// Runtime 532 ms Beats 92.53%
// Memory 184.1 MB Beats 61.89%
// Runtime 564 ms Beats 82.93%
// Memory 185.2 MB Beats 59.15%
class Solution {
public:
long long countPairs(int n, vector<vector<int>>& adjacents) {
// ...
ll sum = 0, answer = 0;
for(int count : counts){
sum += (ll) count;
}
for(int i = 0; i<counts.size(); i++){
sum -= (ll)counts[i];
answer += (ll)counts[i] * sum;
}
return answer;
}
};
시간복잡도 & 공간복잡도
변하지 않는다. O(n).
세 번째 접근 : Union-Find + prefix sum
이 문제를 처음 봤을 때 disjoint set의 개수를 알아야 하므로 union-find를 이용해 disjoint set의 개수를 구하면 되겠지 싶었다. 그러나 아래 *참고의 이유로 union 시 오류가 발생했다. 예시를 보자.
* 참고
Union-find에서 union을 막 사용하면 안 된다. 예를 들어 union 시, parent[rx] = ry라고 작성했다고 해 보자.
이 때 vertex 1, 2, 3이 있고, edge [1-2], [2-3], [1-3]을 차례로 union한다고 생각해 보자. 그러면
- 초기 상태. parent 1, 2, 3은 차례대로 1, 2, 3이다.
- union [1-2]. parent 1, 2, 3은 차례대로 2, 2, 3이다.
- union [2-3]. parent 1, 2, 3은 차례대로 2, 3, 3이다.
- union [1-3]. parent 1, 2, 3은 차례대로 2, 3, 1이다.
- 그러면 find(1)을 하면, 1 - 2 - 3 - 1 - 2- 3 - ... cycle이 만들어진다.
따라서 union할 때, 무작정 parent[rx] = ry를 사용하면 안된다! rank를 이용해 확실하게 해야 한다.
rank를 이용해 union-find를 한 코드이다.
// Runtime 460 ms Beats 97.87%
// Memory 140.7 MB Beats 95.43%
typedef long long ll;
class Solution {
public:
vector<int> parent, rank;
int find(int v){
if(v == parent[v]) return v;
parent[v] = find(parent[v]);
return parent[v];
}
void Union(int x, int y){
int rx = find(x);
int ry = find(y);
if(rx == ry) return;
if(rank[rx] < rank[ry]) // ry의 height가 더 크면 rx가 밑으로 글어가야 함
parent[rx] = ry;
else{ // 그렇지 않다면 ry가 밑으로 들어가야 함
parent[ry] = rx;
if(rank[rx] == rank[ry]){ // rank가 같으면 rank 조절
rank[rx]++;
}
}
}
long long countPairs(int n, vector<vector<int>>& edges) {
// init
parent.resize(n);
rank.resize(n);
for(int i = 0; i<n; i++){
parent[i] = i;
rank[i] = 0;
}
// union all
for(vector<int>& edge : edges){
Union(edge[0], edge[1]);
}
vector<int> counts(n, 0); // counts[i] : root가 i인 disjoint set의 vertex 개수
for(int i = 0; i<n; i++){
counts[find(i)]++;
}
// psum 쓰면 될 것 같은데.
vector<ll> psum(n, 0); // psum[i] : 0부터 i까지 sum
psum[0] = counts[0];
for(int i = 1; i<n; i++){
psum[i] = (ll)counts[i] + (ll)psum[i-1];
}
// psum[i] - psum[j] : j+1 ~ i까지 sum
ll answer = 0;
for(int i = 0; i<n-1; i++){
answer += (ll)(psum[n-1] - psum[i]) * (ll)counts[i];
}
return answer;
}
};
여기에 추가로 prefix sum을 사용하지 않고, 두 번째 접근의 풀이방식을 사용한다면 더 최적화될 것이다.
시간복잡도
union-find에서 find와 union은 log*n, 따라서 모든 연산에 O(nlog*n)이 걸린다. 이후 prefix sum을 계산하는 데 O(n)이 걸린다. 따라서 O(nlog*n)이다. log*n의 경우, (16, 65536]에 4, (65537, 2^{65536}]에 5이다. 따라서 이 문제의 경우 단순 상수로 봐도 무방할 것이다.
실제로 코드 수행 시간을 보면 수행시간이 더 적은데, 이는 함수 스택이 덜 쌓이기 때문이라 유추할 수 있다.
공간복잡도
union-find는 size n인 parent, rank vector를 사용하고 prefix sum에 size n인 vector를 사용한다. 따라서 O(n)
function stack이 덜 쌓이기 때문에 DFS를 이용한 풀이와 공간복잡도는 동일하지만 그 양이 매우 크게 차이나는 것을 확인할 수 있다. 추가로 두 번째 풀이처럼 prefix sum을 사용하지 않는다면 더 적은 메모리를 사용가능하다.
후기
unioin-find를 조심히 사용하자.
'PS > PS Log' 카테고리의 다른 글
23.03.27. 풀었던 문제들 (0) | 2023.03.27 |
---|---|
23.03.26. 풀었던 문제들 (1) | 2023.03.26 |
23.03.24. 풀었던 문제들 (2) | 2023.03.24 |
23.03.23. 풀었던 문제들 (0) | 2023.03.23 |
23.03.22. 풀었던 문제들 (0) | 2023.03.22 |