树上启发式合并(DSU on Tree)详细讲解
一、引言:什么是树上启发式合并?
树上启发式合并(DSU on Tree,即 Disjoint Set Union on Tree),也称 Sack 或 轻重链启发式合并,是一种用于高效处理离线树上子树查询的技巧。
它能够在 O(n log n)
的时间复杂度内,解决许多“对每个节点,统计其子树中满足某种条件的节点数量”的问题。
二、核心思想
2.1 朴素做法的问题
假设我们有一个问题:
对每个节点
u
,求出以u
为根的子树中,不同颜色的节点个数。
最直接的做法是:
- 对每个节点
u
,DFS 遍历其子树,用一个集合(如set
或map
)记录所有出现的颜色。 - 时间复杂度为
O(n^2 log n)
,对于n = 10^5
来说不可接受。
2.2 启发式合并的核心思想
我们观察到:子树之间是独立的,可以合并信息。
但如果我们每次都新建一个数据结构来统计子树信息,代价太高。
启发式合并的关键是:
保留重儿子的信息,只合并轻儿子的信息。
因为从任意节点到根的路径上,最多只有 O(log n)
条轻边,所以每个节点最多被“合并” O(log n)
次。
从而总复杂度为 O(n log n)
。
三、前置知识:重链剖分(轻重儿子)
3.1 定义
- 子树大小(size):以
u
为根的子树包含的节点数。 - 重儿子(Heavy Child):子树大小最大的儿子。
- 轻儿子(Light Child):其他儿子。
- 重链:由重儿子连接形成的链。
void dfs_size(int u, int parent) {
sz[u] = 1;
heavy[u] = -1; // 没有重儿子
for (int v : adj[u]) {
if (v == parent) continue;
dfs_size(v, u);
sz[u] += sz[v];
if (heavy[u] == -1 || sz[v] > sz[heavy[u]])
heavy[u] = v;
}
}
四、算法流程(DSU on Tree)
4.1 步骤概览
对每个节点 u
执行以下操作:
- 递归处理所有轻儿子,并清除它们留下的信息。
- 递归处理重儿子,并保留它的信息。
- 将轻儿子的信息合并到当前节点。
- 加入当前节点本身的信息。
- 回答关于
u
子树的查询。
4.2 关键点
- 只有重儿子的信息被保留,其余都临时计算后合并。
- 使用一个全局数据结构(如数组
cnt[]
)来维护当前正在统计的信息。 - 在处理轻儿子时,要记得在 DFS 返回后清空其贡献。
五、C++ 实现模板
我们以经典问题为例:
对每个节点
u
,求其子树中出现次数最多的颜色的出现次数(或所有颜色的频次统计)。
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
typedef long long ll;
const int MAXN = 100010;
vector<int> adj[MAXN];
int color[MAXN];
int sz[MAXN]; // 子树大小
int heavy[MAXN]; // 重儿子
ll cnt[MAXN]; // cnt[c] = 颜色c的出现次数(全局统计)
ll max_cnt = 0; // 当前最大频次
ll sum = 0; // 当前最大频次的颜色的权值和(可选)
int n;
// 第一步:计算子树大小和重儿子
void dfs_size(int u, int par) {
sz[u] = 1;
heavy[u] = -1;
for (int v : adj[u]) {
if (v == par) continue;
dfs_size(v, u);
sz[u] += sz[v];
if (heavy[u] == -1 || sz[v] > sz[heavy[u]])
heavy[u] = v;
}
}
// 清除子树贡献
void add(int u, int par, int delta) {
// delta = 1 表示添加,-1 表示删除
cnt[color[u]] += delta;
// 维护最大频次(可选)
if (delta == 1) {
if (cnt[color[u]] > max_cnt) {
max_cnt = cnt[color[u]];
sum = color[u];
} else if (cnt[color[u]] == max_cnt) {
sum += color[u];
}
} else if (delta == -1) {
if (cnt[color[u]] + 1 == max_cnt) {
if (color[u] == sum) {
// 需要重新计算 max_cnt 和 sum
// 但这里我们只在主函数中统一处理
}
}
}
for (int v : adj[u]) {
if (v == par || v == heavy[u]) continue; // 跳过父节点和重儿子(在主流程中处理)
add(v, u, delta);
}
}
// 主函数:DSU on Tree
void dfs(int u, int par, bool keep) {
// 1. 先处理所有轻儿子,并清除它们的贡献
for (int v : adj[u]) {
if (v == par || v == heavy[u]) continue;
dfs(v, u, false); // 不保留信息
}
// 2. 处理重儿子,保留其贡献
if (heavy[u] != -1) {
dfs(heavy[u], u, true);
}
// 3. 将轻儿子的信息加进来
for (int v : adj[u]) {
if (v == par || v == heavy[u]) continue;
add(v, u, 1);
}
// 4. 加入当前节点
cnt[color[u]]++;
if (cnt[color[u]] > max_cnt) {
max_cnt = cnt[color[u]];
sum = color[u];
} else if (cnt[color[u]] == max_cnt) {
sum += color[u];
}
// 5. 此时 cnt[] 中保存的是 u 子树的完整信息
// 可以回答查询:比如 ans[u] = max_cnt 或 sum
cout << "Node " << u << ": max frequency = " << max_cnt
<< ", sum of max-colors = " << sum << endl;
// 6. 如果不需要保留信息(即当前是轻儿子),则清空整个子树贡献
if (!keep) {
add(u, par, -1); // 删除整个子树的贡献
max_cnt = 0;
sum = 0;
}
}
六、完整测试代码
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> color[i];
}
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs_size(1, -1);
dfs(1, -1, false); // 最终根节点也不保留(可设true,但无影响)
return 0;
}
✅ 这是均摊分析的结果。
七、支持的操作类型
DSU on Tree 适用于以下类型的子树查询:
问题类型 | 是否支持 |
---|---|
子树中不同颜色数量 | ✅ |
某颜色出现次数 | ✅ |
出现频率最高的颜色 | ✅ |
子树中众数(mode) | ✅ |
某值是否在子树中出现 | ✅ |
子树中满足某条件的节点数 | ✅ |
❌ 不适用于:
- 需要支持修改的在线查询(可用树链剖分 + 线段树)
- 路径查询(非子树)
- 动态树
八、优化技巧
8.1 使用 vector 替代 map
如果颜色值域大但实际使用少,可以用 map
,但通常颜色编号可离散化。
// 离散化颜色
vector<int> colors;
// ... 收集所有 color[i]
sort(colors.begin(), colors.end());
colors.erase(unique(colors.begin(), colors.end()), colors.end());
for (int i = 1; i <= n; i++) {
color[i] = lower_bound(colors.begin(), colors.end(), color[i]) - colors.begin();
}
然后 cnt
数组大小为 colors.size()
。
8.2 避免递归清除时栈溢出
对于深树,add()
函数递归可能爆栈。可改用栈或 BFS。
九、扩展应用
9.1 查询子树中出现至少 k 次的颜色数量
只需维护 freq[cnt]
数组,表示频次为 cnt
的颜色有多少种。
9.2 子树中众数的最小编号
在 add
函数中额外记录。
9.3 多种属性联合查询
如:颜色 + 深度,可用 map<pair<int,int>, int>
,但注意复杂度。
十、常见错误与调试建议
错误 | 说明 |
---|---|
忘记跳过父节点 | 导致死循环 |
忘记跳过重儿子 | 重复添加 |
keep=false 时未清除 |
内存污染 |
cnt[] 数组未清零 |
多组数据出错 |
颜色未离散化 | 数组越界 |