hyelie
hyelie
Hyeil Jeong
       
글쓰기    관리    수식입력
  • 전체보기 (495)
    • PS (283)
      • Algorithm (28)
      • PS Log (244)
      • Contest (6)
      • Tips (5)
    • Development (52)
      • Java (14)
      • Spring (23)
      • SQL (2)
      • Node.js (2)
      • Socket.io (3)
      • Study (4)
      • Utils (4)
    • DevOps (36)
      • Git (5)
      • Docker (4)
      • Kubernetes (2)
      • GCP (3)
      • Environment Set Up (8)
      • Tutorial (12)
      • Figma (2)
    • CS (74)
      • OOP (7)
      • OS (24)
      • DB (2)
      • Network (24)
      • Architecture (0)
      • Security (2)
      • Software Design (0)
      • Parallel Computing (15)
    • Project (15)
      • Project N2T (5)
      • Project ASG (0)
      • Project Meerkat (1)
      • Model Checking (7)
      • Ideas (2)
    • 내가 하고싶은 것! (34)
      • Plan (16)
      • Software Maestro (10)
      • 취준 (8)
hELLO · Designed By 정상우.
hyelie

hyelie

PS/PS Log

23.03.25. 풀었던 문제들

Leetcode 2316. Count Unreachable Pairs of Nodes in an Undirected Graph

 이 문제를 처음 딱 봤을 때, disjoint set의 개수를 구하고, 이 배열을 arr이라고 했을 때, index i에 대해 [arr[i] * sum of arr[i+1] to end]를 구하면 되는 것을 알 수 있다.

 disjoint set의 개수를 구하는 것은 DFS, 또는 union-find를 이용해 쉽게 구할 수 있다. 그렇지만 index i에 대해 [arr[i] * sum of arr[i+1] to end]를 어떻게 쉽게 구할 수 있을까?

 

첫 번째 접근 : DFS + prefix sum

 나는 prefix sum을 떠올렸다. i > j일 때 psum[i] - psum[j] = arr[j+1] + arr[j+2] + ... + arr[i]이다. 이것을 O(1)에 구할 수 있기 때문에 prefix sum을 이용해 쉽게 구할 수 있었다.

// Runtime 564 ms Beats 82.93%
// Memory 185.2 MB Beats 59.15%

typedef long long ll;

class Solution {
public:
    vector<bool> visited;
    vector<vector<int>> edges;
    int DFS(int v){
        int num = 1;
        for(int w : edges[v]){
            if(!visited[w]){
                visited[w] = true;
                num += DFS(w);
            }
        }
        return num;
    }
    long long countPairs(int n, vector<vector<int>>& adjacents) {
        // init
        visited.resize(n);
        fill(visited.begin(), visited.end(), false);
        edges.resize(n);
        for(vector<int>& adjacent : adjacents){
            edges[adjacent[0]].push_back(adjacent[1]);
            edges[adjacent[1]].push_back(adjacent[0]);
        }

        vector<int> counts; // counts : disjoint set의 vertex 개수 vector
        for(int i = 0; i<n; i++){
            if(!visited[i]){
                visited[i] = true;
                counts.push_back(DFS(i));
            }
        }

        // psum 쓰면 될 것 같은데.
        int csize = counts.size();
        vector<ll> psum(csize, 0); // psum[i] : 0부터 i까지 sum
        psum[0] = counts[0];
        for(int i = 1; i<csize; i++){
            psum[i] = (ll)counts[i] + (ll)psum[i-1];
        }

        // psum[i] - psum[j] : j+1 ~ i까지 sum
        ll answer = 0;
        for(int i = 0; i<csize-1; i++){
            answer += (ll)(psum[csize-1] - psum[i]) * (ll)counts[i];
        }
        return answer;
    }
};

 

시간복잡도

 DFS에 O(V + E). 이 때 V = n이고 E = worst case 2n. 그리고 counts와 psum 배열은 worst case size n이므로 O(n)이다.

 

공간복잡도

 counts, psum 배열은 O(n), stack은 worst case O(V)만큼 쌓인다. 따라서 O(n)

 

 

두 번째 접근 : DFS

 solution을 보고 깨달은 것. 굳이 prefix sum을 사용할 필요가 없다.(공간 낭비이다.)

  • 전체 배열의 sum을 저장한다.
  • 모든 i에 대해, 현재 index를 sum에서 뺀다.
    • i = 0일 때는 i = 1부터 끝까지,
    • i = 1일 때는 i = 0이 이미 sum에서 빠진 상태, 따라서 i=1의 값을 sum에서 빼면 i=2부터 끝까지,
    • ...
  • 이 방법으로 prefix sum 없이 [index i 이후에 오는 모든 배열의 합]을 O(1)로 구할 수 있다.
// Runtime 532 ms Beats 92.53%
// Memory 184.1 MB Beats 61.89%

// Runtime 564 ms Beats 82.93%
// Memory 185.2 MB Beats 59.15%


class Solution {
public:
    long long countPairs(int n, vector<vector<int>>& adjacents) {
        // ...

        ll sum = 0, answer = 0;
        for(int count : counts){
            sum += (ll) count;
        }

        for(int i = 0; i<counts.size(); i++){
            sum -= (ll)counts[i];
            answer += (ll)counts[i] * sum;
        }
        return answer;
    }
};

 

시간복잡도 & 공간복잡도

 변하지 않는다. O(n).

 

 

세 번째 접근 : Union-Find + prefix sum

 이 문제를 처음 봤을 때 disjoint set의 개수를 알아야 하므로 union-find를 이용해 disjoint set의 개수를 구하면 되겠지 싶었다. 그러나 아래 *참고의 이유로 union 시 오류가 발생했다. 예시를 보자.

 

* 참고

 Union-find에서 union을 막 사용하면 안 된다. 예를 들어 union 시, parent[rx] = ry라고 작성했다고 해 보자.

 이 때 vertex 1, 2, 3이 있고, edge [1-2], [2-3], [1-3]을 차례로 union한다고 생각해 보자. 그러면

  • 초기 상태. parent 1, 2, 3은 차례대로 1, 2, 3이다.
  • union [1-2]. parent 1, 2, 3은 차례대로 2, 2, 3이다.
  • union [2-3]. parent 1, 2, 3은 차례대로 2, 3, 3이다.
  • union [1-3]. parent 1, 2, 3은 차례대로 2, 3, 1이다.
  • 그러면 find(1)을 하면, 1 - 2 - 3 - 1 - 2- 3 - ... cycle이 만들어진다.

따라서 union할 때, 무작정 parent[rx] = ry를 사용하면 안된다! rank를 이용해 확실하게 해야 한다.

 

 rank를 이용해 union-find를 한 코드이다. 

// Runtime 460 ms Beats 97.87%
// Memory 140.7 MB Beats 95.43%

typedef long long ll;

class Solution {
public:
    vector<int> parent, rank;
    int find(int v){
        if(v == parent[v]) return v;
        parent[v] = find(parent[v]);
        return parent[v];
    }
    void Union(int x, int y){
        int rx = find(x);
        int ry = find(y);
        if(rx == ry) return;
        
        if(rank[rx] < rank[ry]) // ry의 height가 더 크면 rx가 밑으로 글어가야 함
            parent[rx] = ry;
        else{ // 그렇지 않다면 ry가 밑으로 들어가야 함
            parent[ry] = rx;
            if(rank[rx] == rank[ry]){ // rank가 같으면 rank 조절
                rank[rx]++;
            }
        }
    }
    long long countPairs(int n, vector<vector<int>>& edges) {
        // init
        parent.resize(n);
        rank.resize(n);
        for(int i = 0; i<n; i++){
            parent[i] = i;
            rank[i] = 0;
        }

        // union all
        for(vector<int>& edge : edges){
            Union(edge[0], edge[1]);
        }

        vector<int> counts(n, 0); // counts[i] : root가 i인 disjoint set의 vertex 개수
        for(int i = 0; i<n; i++){
            counts[find(i)]++;
        }

        // psum 쓰면 될 것 같은데.
        vector<ll> psum(n, 0); // psum[i] : 0부터 i까지 sum
        psum[0] = counts[0];
        for(int i = 1; i<n; i++){
            psum[i] = (ll)counts[i] + (ll)psum[i-1];
        }

        // psum[i] - psum[j] : j+1 ~ i까지 sum
        ll answer = 0;
        for(int i = 0; i<n-1; i++){
            answer += (ll)(psum[n-1] - psum[i]) * (ll)counts[i];
        }
        return answer;
    }
};

 여기에 추가로 prefix sum을 사용하지 않고, 두 번째 접근의 풀이방식을 사용한다면 더 최적화될 것이다.

 

시간복잡도

 union-find에서 find와 union은 log*n, 따라서 모든 연산에 O(nlog*n)이 걸린다. 이후 prefix sum을 계산하는 데 O(n)이 걸린다. 따라서 O(nlog*n)이다. log*n의 경우, (16, 65536]에 4, (65537, 2^{65536}]에 5이다. 따라서 이 문제의 경우 단순 상수로 봐도 무방할 것이다.

 실제로 코드 수행 시간을 보면 수행시간이 더 적은데, 이는 함수 스택이 덜 쌓이기 때문이라 유추할 수 있다.

 

공간복잡도

 union-find는 size n인 parent, rank vector를 사용하고 prefix sum에 size n인 vector를 사용한다. 따라서 O(n)

 function stack이 덜 쌓이기 때문에 DFS를 이용한 풀이와 공간복잡도는 동일하지만 그 양이 매우 크게 차이나는 것을 확인할 수 있다. 추가로 두 번째 풀이처럼 prefix sum을 사용하지 않는다면 더 적은 메모리를 사용가능하다.

 

 

후기

 unioin-find를 조심히 사용하자.

 

 

 

 

 

 

저작자표시 (새창열림)

'PS > PS Log' 카테고리의 다른 글

23.03.27. 풀었던 문제들  (0) 2023.03.27
23.03.26. 풀었던 문제들  (1) 2023.03.26
23.03.24. 풀었던 문제들  (2) 2023.03.24
23.03.23. 풀었던 문제들  (0) 2023.03.23
23.03.22. 풀었던 문제들  (0) 2023.03.22
    hyelie
    hyelie

    티스토리툴바