Lowest Common Ancestor (LCA)

smsh0722·2022년 4월 15일
0

Tree

목록 보기
1/5

Tree에서 두 nodes u와 v의 LCA는, root로부터 가장 멀리(deepest) 있는 공통 조상이다.

그렇다면 어떻게 LCA를 구할 수 있을까?

Method 1

Naive 하게 root에서 각 node까지의 경로를 비교하여 풀 수 있다.

  1. root에서 u, v까지의 경로를 각각 배열에 저장한다
  2. 두 배열을 비교하여 얻은 공통 조상 중, root에서 가장 멀리 떨어진 것이 LCA이다.

시간 복잡도는 O(N)이다.

Method 2

또 다른 Naive 한 풀이는, Parent Pointer를 이용하는 것이다. method 1과 다르게, u 또는 v에서 시작해 root 방향으로 순회를 한다.

  1. 먼저, u에서 root로 순회하면서 ancestors를 저장한다.
  2. v에서 root로 순회할 때, u에서의 ancestors와 동일한 것이 있는지 확인한다. 이때, 가장 먼저 동일한 것이 곧 LCA이다.

시간 복잡도는 O(N) 이다.

Method 3

	   1
     /   \
    2     3
  /   \
 4     5

위와 같은 tree가 있을 때, DFS를 이용한 Euler tour를 한다고 해보자. sub-tree의 root인 r에 도달 했을 때, r의 자손인 u로 갔다가, 다시 r로 돌아와 다른 자손 v로 내려가게 된다. 즉, u와 v는 공통 조상인 r을 반드시 거치게 되는데, 이 특징을 이용해서 문제를 풀 수 있다.

먼저, 세 가지 배열이 필요하다.

  • dfs_node[]: Euler tour로 방문하는 nodes를 순서대로 저장
  • dfs_level[]: Euler tour로 방문하는 nodes의 level을 순서대로 저장
  • firstOccur[]: 각 node가 Euler tour에서 처음으로 나타난 index

위 상황에서는, 다음과 같이 배열을 생성한다.

dfs_node[]   = {1, 2, 4, 2, 5, 2, 1, 3, 1}
dfs_level[]  = {0, 1, 2, 1, 2, 1, 0, 1, 0}
firstOccur[] = {0, 1, 7, 2, 4 }

이를 이용해서 LCA(4, 3)을 구해보자. 각 node 가 처음 발생한 index는 7, 2다. 공통 조상은 2와 7 사이에 level이 가장 낮은 node일 것이다. 이에 따라 dfs_level[2:7]를 조사하면, index = 6 일 때, level = 0으로 최소인 이 node가 곧 lca가 되며, dfs_node[6]를 통해 node 1인 것을 확인할 수 있다.

이때, 구간의 최소를 하나하나 비교하여 찾는 것이 아니라, Segment Tree를 기반으로 한 Range Minimum Query(RMQ)로 찾으면 더 빠를 것이다.

시간복잡도는, DFS를 통해 배열을 생성하고 Segment Tree를 생성하는데 O(N)이고, 매 Query마다 O(logN)이다.


/* NOTE: RMQ-based LCA 
Level + SegTree
*/
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

const int INT_INF = ~(1<<31);

struct SegTree{
    size_t N;
    struct Node{
        int level;
        int nodeNum;
    };
    vector<Node> arr;
};

// num of nodes
int N;
// num of queries
int M;

vector<vector<int>> adjList;
vector<bool> visited;
vector<int> firstAppear;
vector<int> seqMemo;
vector<int> levelMemo;
SegTree st;

int level = 0;
int seqIdx = 0;

void DoMemo( int node )
{
    if ( firstAppear[node] == -1 )
        firstAppear[node] = seqIdx;
    seqMemo.push_back( node );
    levelMemo.push_back( level );
    seqIdx++;
}

void DFS( int node )
{
    visited[node] = true;
    level++;
    
    DoMemo(node);    

    for ( size_t i = 0; i < adjList[node].size(); i++ ){
        int dst = adjList[node][i];
        if ( visited[dst] == false ){
            DFS(dst);

            DoMemo(node);
        }
    }

    level--;
}

SegTree::Node BuildST( int node, int l, int r )
{
    if ( l == r ){
        return (st.arr[node] = {levelMemo[l], seqMemo[l]} );
    }

    int mid = (r-l)/2 + l;
    SegTree::Node lRst = BuildST( node*2+1, l, mid );
    SegTree::Node rRst = BuildST(node*2+2, mid+1,r);
    if ( lRst.level < rRst.level ){
        st.arr[node] = lRst;
    }
    else 
        st.arr[node] = rRst;
    return st.arr[node];
}

SegTree::Node MinST( int node, int l, int r, int tl, int tr )
{
    if ( tl <= l && r <= tr )
        return st.arr[node];
    if ( r < tl || tr < l )
        return {INT_INF, -1};
    
    int mid = (r-l)/2+l;
    SegTree::Node lRst = MinST(node*2+1, l, mid, tl, tr);
    SegTree::Node rRst = MinST(node*2+2, mid+1, r, tl, tr );
    if ( lRst.level < rRst.level )
        return lRst;
    else
        return rRst;
}

void PrintDebug( const vector<int>& arr )
{
    for ( size_t i = 0; i < arr.size(); i++ ){
        cout << arr[i] << " ";
    } cout << endl;
}

int main( void )
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);

    cin >> N;
    adjList.resize(N+1);
    visited.resize(N+1, false);
    firstAppear.resize(N+1,-1);
    for ( int i = 0; i < N-1; i++ ){
        int a, b;
        cin >> a >> b;
        adjList[a].push_back(b);
        adjList[b].push_back(a);
    }

    // Memo DFS
    DFS(1);

    st.N = seqIdx;
    int h = ceil(log2(st.N));
    int size = (1<<(h+1))-1;
    st.arr.resize( size );
    BuildST(0, 0, st.N-1);

    // cout << "firstApear: ";
    // PrintDebug( firstAppear );
    // cout << "seqMemo: ";
    // PrintDebug( seqMemo );
    // cout << "levelMemo: ";
    // PrintDebug( levelMemo );

    cin >> M;
    for ( int m = 0; m < M; m++ ){
        int a, b;
        cin >> a >> b;
        int aFA = firstAppear[a];
        int bFA = firstAppear[b];
        if ( aFA > bFA )
            swap(aFA,bFA);
        
        // cout << "AF, BF: " << aFA<< " " << bFA << endl; // Debug

        SegTree::Node rst = MinST( 0, 0, st.N-1, aFA, bFA );
        cout << rst.nodeNum << "\n";
    }

    return 0;
}

Method 4

method 2에서는 root 방향으로 순회하며 서로의 ancestors를 비교한다. 그러나, 모든 ancestors를 하나하나 비교하기 때문에 빠르지 않다. 대신에, Binary Search에서 찾으려는 값이 mid보다 작으면 (l, mid-1)를 조사하고, mid보다 크면 (mid+1, r)을 조사하는 것처럼, ancestor의 조사 구간을 선택하면서 풀면 더 빠를 것이다.
먼저, 다음과 같이 각 node의 2^k 번째 조상을 미리 저장한다.

memo[i][j] = i-th node의 (2^j)-th ancestor
		   = memo[ memo[i][j-1] ][j-1]

( 0 <= j <= log ), ( log = ceil( log2(N) ) )

이를 이용하여 Binary 하게 LCA를 찾는다.

1. u 와 v의 level을 동일하게 맞춘다. 이때, u == v 라면, 다른 하나가 조상인 것이다.
2. loop j: log to 0
	if ( memo[u][j] != memo[v][j] ){
    	u = memo[u][j];
        v = memo[v][j];
    }
3. memo[u][0] 이 LCA이다. (memo[v][0])

이렇게 Dynamic Programming을 이용해서 Binary하게 LCA를 찾을 수 있는데, 이를 Binary Lifting 방법이라고 부른다.

시간 복잡도는, dp 생성에 O(NlogN)이고, 매 Query마다 O(logN)이다.
(이러한 형식의 rmq용 data structure를 sparse table 이라고도 부른다.)

/* NOTE: Binary Lifting
*/
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;

// num of nodes
int N;
// num of queries
int M;
int maxLevel;

vector<vector<int>> adjList;
// depthOfNode[node] = depth of Node's in tree, val == -1 ? unvisited
vector<int> depthOfNode;
// parents[node][i] = node's 2^i-th parent, val == -1 ? unvisited
vector<vector<int>> parents;

void DFS( int cur, int prev )
{
    depthOfNode[cur] = depthOfNode[prev]+1;
    parents[cur][0] = prev;

    for ( size_t i = 0; i < adjList[cur].size(); i++ ){
        int dst = adjList[cur][i];
        if ( depthOfNode[dst] == -1 ){
            DFS( dst, cur );
        }
    }
}

void BuildParents()
{

    for ( int level = 1; level <= maxLevel; level++ ){
        for ( int node = 1; node <= N; node++ ){
            if ( parents[node][level-1] != -1 )
                parents[node][level] = parents[parents[node][level-1]][level-1];
        }
    }
}

int LCA( int a, int b )
{
    // Set to same depth
    if ( depthOfNode[a] > depthOfNode[b] )
        swap(a, b);
    int depthDiff = depthOfNode[b] - depthOfNode[a];
    for ( int i = 0; i <= maxLevel; i++ ){
        if( depthDiff&(1<<i) )
            b = parents[b][i];
    }
    
    if ( a== b)
        return a;
    
    // find lca
    for ( int level = maxLevel; level >= 0; level-- ){
        if( parents[a][level] != parents[b][level] ){
            a = parents[a][level];
            b = parents[b][level];
        }
    }
    
    return parents[a][0];
}

int main ( void )
{
    ios_base::sync_with_stdio(false); cin.tie(NULL);

    cin >> N;
    maxLevel = log2(N);
    adjList.resize(N+1);
    depthOfNode.resize(N+1, -1);
    parents.resize(N+1, vector<int>(maxLevel+1, -1));
    for ( int i = 0; i < N-1; i++ ){
        int a, b;
        cin >> a >> b;
        adjList[a].push_back(b);
        adjList[b].push_back(a);
    }

    // DFS
    DFS(1, 1);

    // BuildTable
    BuildParents();

    // LCA
    cin >> M;
    for ( int m = 0; m < M; m++ ){
        int a, b;
        cin >> a >> b;
        cout << LCA(a, b) << "\n";
    }

    return 0;
}

0개의 댓글