【Algorithm Notes】LCA倍增求法

定义

一个点的祖先:从该节点出发, 一路向上走能碰到的就是其祖先了

两个点的公共祖先: 就是同一棵树上两个节点的祖先集合中的交集

最近公共祖先就是这个交集里面最靠下

注意到以上的表述中经常出现向上靠下等字眼, 说明求最近公共祖先的算法肯定与求节点的高度有关

朴素算法法求LCA

想象一下这个过程:

  • 两个节点先跳到同一个高度

  • 如果两节点相遇 (即原先两个节点存在祖孙关系), 该点即为LCA, 退出

  • 否则, 一起向上跳, 直到相遇

优化

以上这种一层一层跳的方法太慢了, 面对一棵巨大的树时会跑得巨慢, 我们要尝试优化这个过程

试想一下: 如果用倍增跳呢? 是不是速度立刻就上去了?

一些前置工作

  • 用倍增的思想求出节点 $x$ 往上跳 $2^p$ 步后可以到达的节点, 存在 $fa[x][p]$ 中
  • 用 $\texttt{BFS}$ 或 $\texttt{DFS}$ 求出节点 $x$ 的祖先, 存在 $deep[x]$ 中

这样, 每次跳的时候, 可以通过比较 $deep[a]$ 是否等于 $deep[b]$ 来判断是否在同一层, 并增加跳的长度, 把朴素算法的时间复杂度优化到 $log$ 级别

具体实现

根节点遍历整棵树, 顺便记录一下子节点的信息

Code
1
2
3
4
5
6
7
8
9
10
11
12
node nod = que.front(); que.pop();
for (auto i = edges[nod.x].begin(); i != edges[nod.x].end(); i++)
{
if ((*i) != nod.fa) // 该节点不为父节点
{
que.emplace(node(*i, nod.x));
deep[*i] = deep[nod.x] + 1; // 求deep[]数组
fa[*i][0] = nod.x;
for (int j = 1; (1 << j) <= deep[*i]; j++) // 倍增法求fa[]数组
fa[*i][j] = fa[fa[*i][j - 1]][j - 1];
}
}

开始愉快地跳跃

  • 先跳到同一高度

    Code
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    if (deep[a] < deep[b])
    swap(a, b);
    if (!deep[b])
    return b;
    for (int i = MAXP; i >= 0; i--)
    {
    if (deep[fa[a][i]] >= deep[b])
    a = fa[a][i];
    if (a == b)
    return a;
    }
  • 两个点一起往上跳

    Code
    1
    2
    3
    4
    5
    6
    7
    8
    9
    for (int i = MAXP; i >= 0; i--)
    {
    if (fa[a][i] != fa[b][i])
    {
    a = fa[a][i];
    b = fa[b][i];
    }
    }
    return fa[a][0];

上代码

Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <cstdio>
#include <vector>
#include <queue>
#include <iostream>
#include <cmath>

using namespace std;

struct node
{
int x, fa;
node(int xx = 0, int ffa = 0) : x(xx), fa(ffa) {}
~node() = default;
};

const int SIZE = 5e6 + 1, MAXP = 20;
vector<int> edges[SIZE];
queue<node> que;
int n, m, s;
int fa[SIZE][MAXP + 1], deep[SIZE];

void bfs(int st)
{
que.emplace(node(st, 0));
deep[st] = 0;
while (!que.empty())
{
node nod = que.front(); que.pop();
for (auto i = edges[nod.x].begin(); i != edges[nod.x].end(); i++)
{
if ((*i) != nod.fa)
{
que.emplace(node(*i, nod.x));
deep[*i] = deep[nod.x] + 1;
fa[*i][0] = nod.x;
for (int j = 1; (1 << j) <= deep[*i]; j++)
fa[*i][j] = fa[fa[*i][j - 1]][j - 1];
}
}
}
}

int lca(int a, int b)
{
if (deep[a] < deep[b])
swap(a, b);
if (!deep[b])
return b;
for (int i = MAXP; i >= 0; i--)
{
if (deep[fa[a][i]] >= deep[b])
a = fa[a][i];
if (a == b)
return a;
}
for (int i = MAXP; i >= 0; i--)
{
if (fa[a][i] != fa[b][i])
{
a = fa[a][i];
b = fa[b][i];
}
}
return fa[a][0];
}

int main()
{
int u, v;
scanf("%d%d%d", &n, &m, &s);
for (int i = 1; i < n; i++)
{
scanf("%d%d", &u, &v);
edges[u].emplace_back(v);
edges[v].emplace_back(u);
}
bfs(s);
int a, b;
for (int i = 1; i <= m; i++)
{
scanf("%d%d", &a, &b);
printf("%d\n", lca(a, b));
}
return 0;
}

时间复杂度

$$\Theta{(n + nlog_2n + mlog_2n)} = \Theta{((n + m)logn)}$$