ComputerScience/Algorithm

[Algorithm] Diameter Of Tree(트리의 지름)

kyungmin.yu 2019. 4. 10. 17:57

SW Expert Academy에서 트리의 지름에 관련된 문제를 풀게 되었다.

(https://swexpertacademy.com/main/code/problem/problemDetail.do?contestProbId=AV5PN4CqACYDFAUq&)

 

이 문제를 처음 봤을 때 트리의 지름, 반지름, 중심에 관련된 문제이구나 하는 생각이 들어서 문제를 푸는데 안풀린다..

아는 개념이라고 생각했었는데 잘 안풀려서 이번 기회에 복습해 보려고 한다.

 

 

이 문제의 풀이는 다음과 같다.

 

1. 모든 독립된 연결 요소들을 하나의 트리를 보고 각각의 트리에서 지름을 구한다.

 

2. 지름을 구하는 과정에서 구한 지름의 양끝점 u에서 v로 가는 경로를 탐색하면서 반지름을 구한다.

 

3. 반지름들을 이용해서 연결요소들을 연결 할 때 임의의 두 정점사이의 최대 거리가 최소가 되게 한다.

 

4. 생각해 보면 한 트리의 중심에 모두 연결 시키는 것이 최선일 수밖에 없다는 것을 알수 있고

   예외적으로 트리들을 연결하는데 드는 비용보다 반지름이 작은 경우에 대한 예외 처리를 해준다면 문제를 풀 수 있다. 

 

 

이 문제의 풀이에서 나온 용어를 정리해 보면 트리의 지름, 반지름, 중심이 있다.

그리고 트리의 중심에 대해서 알기 위해서는 이심률이라는 용어를 알아야 했다.

 

우선 이심률(Eccentricity)이란 트리의 한 정점에서 가장 먼 정점까지의 거리이다.

 

그리고 트리의 지름은 트리 상의 임의의 정점 u, v 사이의 거리들중 가장 긴 거리를 말한다.

즉, 트리에서 구할 수 있는 모든 이심률 중에서 가장 긴 것을 의미 한다.

 

Brute Force하게 트리의 지름을 구한다면 당연히 N개의 정점중에 중에 임의의 2개의 정점을 골라서 다 비교하면 된다.

이 경우에는  NC2 = N(N - 1)/2 = O(N^2)의 시간복잡도를 가지게 된다.

 

그런데 그런 방법이 아니라 BFS 또는 DFS를 2번 사용해서 트리의 지름을 구할 수도 있다.

그렇게 된다면 시간 복잡도는 O(N)이 될것이다.

 

BFS 또는 DFS를 이용해서 트리의 지름을 구하는 방법은 다음과 같다.

 

1. 임의의 한 점에서 BFS 또는 DFS를 통해서 그 점에서 가장 먼 한 점 u를 구한다.

 

2. 그 다음으로 u에서 다시 한 번  BFS 또는 DFS를 사용해서 u에서 가장 먼 점 v를 구한다.

 

3. 2에서 구해진 가장 먼 거리가 트리의 지름이 된다.

. . .
 
int bfs(int v) {
    _queue<int> q;
    q.push(v);
    path[v] = -1;
    dist[v] = 0;
    chk[v] = ++chkVal;
    mx = -1;
    int ret = 0;
    while (!q.empty()) {
        v = q.front(); q.pop();
        if (mx < dist[v]) {
            mx = dist[v];
            ret = v;
        }
        for (register int i = 0; i < G[v].size(); i++) {
            int nxt = G[v][i].v, cst = G[v][i].c;
            if (chk[nxt] == chkVal) continue;
            chk[nxt] = chkVal;
            dist[nxt] = dist[v] + cst;
            path[nxt] = v;
            q.push(nxt);
        }
    }
    return ret;
}
 
int main(){
 
   . . .
 
    for (register int i = 0; i < n; i++){
        if (chk[i]) continue;
        u[rind++] = bfs(i);
    }
 
    int res = 0;
    for (register int i = 0; i < rind; i++){
        int cur = bfs(u[i]);
        rad[i] = dist[cur];
        
        . . .
 
    }
 
. . .
 
    return 0;
}

i번째 연결 요소에서 만들어진 트리의 지름의 크기 rad[i]와 그 양 끝점 u[i]와 cur사이의 경로로 사이에서 반지름은 

경로상의 임의의 한 점에서 u[i]까지의 거리와 cur까지의 거리의 최대값들중에 최소값을 구하면 된다.

for (register int i = 0; i < rind; i++){
    int cur = bfs(u[i]);
    rad[i] = dist[cur];
    res = Max(res, rad[i]);
    while (cur != -1){
        rad[i] = Min(rad[i], Max(dist[cur], mx - dist[cur]));
        cur = path[cur];
    }
}

이 문제에서 요구하는 임의의 두 정점이 오가는데 최대시간이 최소시간이 되는 후보는 크게 3가지가 있다.

 

1. 한 연결 요소의 최대시간이 최소가 되는 경우

   이 경우에는 연결 요소에서 가장 먼 거리. 즉, rad[i]들의 최대값을 구하면된다.

 

2. 임의의 두 연결 요소를 연결했을 때 최대시간이 최소가 되는 경우

   이 경우는 rad배열을 내림차순으로 정렬해서 가장 긴 두 rad를 L의 비용으로 연결하면 된다.

 

3. 마지막으로 트리들을 연결하는데 드는 비용보다 반지름이 작은 경우

   가장 긴 rad 값이 아니라 두번째와 세번째 rad를 지나서 가장 긴 rad로 연결시키면 된다.

   (마지막 케이스때문에 계속 틀렸었는데 이것저것 계속 시도해 보다가 맞은거라

    아직도 이렇게 이해하는게 맞는지 잘 모르겠다. 누가 이 부분에 대한 정확한 풀이를 알려줬으면 좋겠다.)

 

 

아래의 코드가 풀이한 코드이다. STL을 안쓰고 풀어보려고 이것저것 잡다한게 많이 들어가서 코드가 조금 더럽다...

#include <stdio.h>
 
const int NMAX = 100005;
const int INF = 1e9;
 
template <class T>
class _vector{
private:
    T* ele;
    int cap, _size;
public:
    _vector(){
        cap = 32;
        _size = 0;
        ele = new T[cap];
    }
    ~_vector(){    delete[] ele;}
    int full(){ return cap == _size; }
    int size(){ return _size; }
    void clear(){ _size = 0; }
    void resize(int ncap){
        cap = ncap;
        T* tmp = new T[cap];
        for (register int i = 0; i < _size; i++)
            tmp[i] = ele[i];
        delete[] ele;
        ele = tmp;
    }
    void push_back(T e){
        if (full()) resize(cap * 2);
        ele[_size++] = e;
    }
    T& operator[] (int idx) const{ return ele[idx];}
};
 
template <class T>
class _queue{
private:
    T* ele;
    int cap, _size;
public:
    _queue(){
        cap = 32;
        _size = 0;
        ele = new T[cap];
    }
    ~_queue(){ delete[] ele; }
    int full(){ return cap == _size; }
    int empty(){ return _size == 0; }
    int size(){ return _size; }
    void clear(){ _size = 0; }
    void resize(int ncap){
        cap = ncap;
        T* tmp = new T[cap];
        for (register int i = 0; i < _size; i++)
            tmp[i] = ele[i];
        delete[] ele;
        ele = tmp;
    }
    void push(T e){
        if (full()) resize(cap * 2);
        ele[_size++] = e;
    }
    void pop(){
        if (empty()) return;
        _size--;
    }
    T front(){
        if (empty()) return T();
        return ele[_size - 1];
    }
};
 
template <typename T1, typename T2>
struct _pair{
    T1 v; T2 c;
    _pair(){}
    _pair(T1 v, T2 c):v(v), c(c){}
    int operator < (_pair p)const{
        return c > p.c;
    }
};
 
_vector<_pair<int, int> > G[NMAX];
int dist[NMAX], path[NMAX], chk[NMAX], u[NMAX], rad[NMAX];
int rind, chkVal, mx;
 
template <typename T>
T Max(T n1, T n2){ return n1 > n2 ? n1 : n2; }
template <typename T>
T Min(T n1, T n2){ return n1 < n2 ? n1 : n2; }
void merge(int l, int m, int r){
    int* tmp = new int[r - l + 1];
    int lp = l, p = 0, rp = m + 1;
 
    while (lp <= m && rp <= r){
        if (rad[lp] > rad[rp]) tmp[p++] = rad[lp++];
        else tmp[p++] = rad[rp++];
    }
    while (lp <= m) tmp[p++] = rad[lp++];
    while (rp <= r) tmp[p++] = rad[rp++];
 
    for (register int i = 0; i < p; i++)
        rad[l + i] = tmp[i];
    delete[] tmp;
}
void msort(int l, int r){
    if (r <= l) return;
    int m = (l + r) / 2;
    msort(l, m);
    msort(m + 1, r);
    merge(l, m, r);
}
void clear(){
    for (register int i = 0; i < NMAX; i++)
        rad[i] = dist[i] = chk[i] = path[i] = 0;
}
int bfs(int v) {
    _queue<int> q;
    q.push(v);
    path[v] = -1;
    dist[v] = 0;
    chk[v] = ++chkVal;
    mx = -1;
    int ret = 0;
    while (!q.empty()) {
        v = q.front(); q.pop();
        if (mx < dist[v]) {
            mx = dist[v];
            ret = v;
        }
        for (register int i = 0; i < G[v].size(); i++) {
            int nxt = G[v][i].v, cst = G[v][i].c;
            if (chk[nxt] == chkVal) continue;
            chk[nxt] = chkVal;
            dist[nxt] = dist[v] + cst;
            path[nxt] = v;
            q.push(nxt);
        }
    }
    return ret;
}
 
int main(){
    int T; scanf("%d", &T);
    for (int tc = 1; tc <= T; ++tc){
        rind = chkVal = 0;
        for (register int i = 0; i < NMAX; i++) G[i].clear();
        clear();
 
        int n, m, l; scanf("%d %d %d", &n, &m, &l);
        for (register int i = 0; i < m; i++){
            int from, to, cost;
            scanf("%d %d %d", &from, &to, &cost);
            G[to].push_back({ from, cost });
            G[from].push_back({ to, cost });
        }
 
        for (register int i = 0; i < n; i++){
            if (chk[i]) continue;
            u[rind++] = bfs(i);
        }
 
        int res = 0;
        for (register int i = 0; i < rind; i++){
            int cur = bfs(u[i]);
            rad[i] = dist[cur];
            res = Max(res, rad[i]);
            while (cur != -1){
                rad[i] = Min(rad[i], Max(dist[cur], mx - dist[cur]));
                cur = path[cur];
            }
        }
        
        msort(0, rind - 1);
 
        if (rind >= 2) res = Max(res, rad[0] + rad[1] + l);
        if (rind > 2)  res = Max(res, rad[1] + rad[2] + l * 2);
        printf("#%d %d\n", tc, res);
    }
    return 0;
}