Skip to main content Link Menu Expand (external link) Document Search Copy Copied

1855. 영준이의 진짜 BFS

Code Battle (`23 동계 대학생 S/W 알고리즘 특강 기초학습 문제)


본 게시글은 개인 학습 용도로 작성한 게시글입니다.

(문제 출처: SW Expert Academy)


노드의 최대 개수 $n$이 최대 $100,000$이므로, $O(n^2)$ 알고리즘으로는 풀이 할 수 없는 문제이다.

이 문제의 핵심은 크게 세 가지이다.

  • 입력에서 번호가 작은 노드부터 순서대로 부모노드의 번호가 주어진다. 그리고, 부모의 노드번호가 항상 자식노드보다 작다는 제약사항이 있다.
  • Breadth-First-Search (BFS, 너비 우선 탐색) 알고리즘으로 $O(n)$시간에 탐색을 수행한다.
  • Lowest Common Ancestor (LCA, 가장 가까운 공통 조상) 알고리즘으로 이동 경로를 $O(\log{n})$ 시간에 구한다.

트리를 구성할 때, 문제의 제약사항을 활용하여 $O(n)$시간에 트리를 구성할 수 있다. 이는 입력 받고있는 노드의 부모노드는 이미 입력 받아져 있음이 보장된다는 점에 주목하면 쉽게 구현할 수 있다.

BFS를 수행하면서, 각 노드를 방문 할 때, 직전에 방문한 노드와의 LCA를 구하면 최소 이동거리를 계산할 수 있다. 이 경우 답을 구하는데 $O(n\log{n})$의 시간이 소요될 것을 기대할 수 있다.

단, 트리의 높이가 $n$에 가까운 Edge-case에서는 LCA를 찾는데에만 $O(n)$시간이 소요될 수 있다. 이 경우 풀이의 시간복잡도는 $O(n^2)$이 되어 Time Limit Exceed를 피할 수 없다.

따라서, 조금 더 개선된 형태의 LCA 알고리즘을 사용하여야 한다. Sparse Matrix를 적용한 LCA를 사용하면 안정적으로 쿼리를 $O(\log{n})$ 시간에 수행할 수 있다. 이 알고리즘은 전처리에 $O(n \log{n})$ 시간이 소요되나, 최적의 경우 $O(\log{\log{n}})$ 시간에 수행 될 것을 기대할 수 있다. 무엇보다, 최악의 경우에도 $O(\log{n})$ 시간에 안정적으로 쿼리를 수행할 수 있다는 장점이 있다.

사소한 주의할 점이 있는데, 바로 극단적인 테스트케이스에서 정답의 값이 $n^2$에 근사하게 될 수도 있다는 것이다. int 자료형을 사용하고 있을 경우에는 정답의 값이 자료형의 표현범위를 넘어 오버플로우가 발생할 수 있음에 유의해야 한다.


#include <iostream>

#define MAX_NODES 100000
#define MAX_NODES_LOG2 17
#define ROOT 1
#define UNDEFINED 0

using namespace std;

int parent[MAX_NODES+1][MAX_NODES_LOG2+1]; // parent[node][i] means node's 2^i-th parent.
int firstChild[MAX_NODES+1];
int lastChild[MAX_NODES+1];
int nextSibling[MAX_NODES+1];
int depth[MAX_NODES+1];

int treeSize;

int queueHead;
int queueTail;
int nextQueueItem[MAX_NODES+1];

void init() {
    cin >> treeSize;

    // clear tree
    for (int node = 1; node <= treeSize; node++) {
        firstChild[node] = UNDEFINED;
        lastChild[node] = UNDEFINED;
        nextSibling[node] = UNDEFINED;
        depth[node] = 0;
    }
    // clear sparse matrix
    for (int node = 1; node <= treeSize; node++) {
        for (int i = 0; i <= MAX_NODES_LOG2; i++) {
            parent[node][i] = UNDEFINED;
        }
    }
    // clear queue
    queueHead = UNDEFINED;
    queueTail = UNDEFINED;
    for (int node = 1; node <= treeSize; node++) {
        nextQueueItem[node] = UNDEFINED;
    }

    // make tree
    for (int node = 2; node <= treeSize; node++) {
        cin >> parent[node][0];
        depth[node] = depth[parent[node][0]] + 1;
        // add child node to parent node
        if (firstChild[parent[node][0]] == UNDEFINED) {
            firstChild[parent[node][0]] = node;
        }
        if (lastChild[parent[node][0]] != UNDEFINED) {
            nextSibling[lastChild[parent[node][0]]] = node;
        }
        lastChild[parent[node][0]] = node;
    }

    // pre-compute sparse matrix of parents
    for (int i = 1; i <= MAX_NODES_LOG2; i++) {
        for (int node = 2; node <= treeSize; node++) {
            if (parent[node][i-1] != UNDEFINED) {
                parent[node][i] = parent[parent[node][i-1]][i-1];
            }
        }
    }
}

int lowestCommonAncestor(int u, int v) {
    // ensure v is deeper than u.
    if (depth[v] < depth[u]) {
        swap(v, u);
    }
    // bring both nodes to the same level.
    int diff = depth[v] - depth[u];
    for (int i = 0; i <= MAX_NODES_LOG2; i++) {
        if ((diff >> i) & 1) {
            v = parent[v][i];
        }
    }
    // now depth of u is same as depth of v.
    if (u == v) {
        return u;
    }
    for (int i = MAX_NODES_LOG2; i >= 0; i--) {
        if (parent[u][i] != parent[v][i]) {
            u = parent[u][i];
            v = parent[v][i];
        }
    }
    return parent[u][0];
}

void enqueue(int i) {
    if (queueHead == UNDEFINED) {
        queueHead = i;
    }
    if (queueTail != UNDEFINED) {
        nextQueueItem[queueTail] = i;
    }
    queueTail = i;
}

int dequeue() {
    int queueTop = queueHead;
    queueHead = nextQueueItem[queueTop];
    nextQueueItem[queueTop] = UNDEFINED;
    return queueTop;
}

bool queueHasElement() {
    return queueHead != UNDEFINED;
}

long long testcase() {
    long long answer = 0;
    int curr;
    int prev = UNDEFINED;
    init();
    enqueue(ROOT);
    while (queueHasElement()) {
        curr = dequeue();
        if (prev != UNDEFINED) {
            answer += depth[prev] + depth[curr];
            answer -= 2 * depth[lowestCommonAncestor(prev, curr)];
        }
        for (int child = firstChild[curr]; child != UNDEFINED; child = nextSibling[child]) {
            enqueue(child);
        }
        prev = curr;
    }
    return answer;
}

int main()
{
    int T;
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    setbuf(stdout, NULL);
    cin >> T;
    for (int tc = 1; tc <= T; tc++) {
        cout << '#' << tc << ' ' << testcase() << '\n';
    }
    return 0;
}

Back to top