题目
算法标签: d f s dfs dfs, 树的重心, 倍增优化, 换根 d p dp dp
思路
对于一棵树来说, 树的重心只有一个或者两个, 问题就是枚举删除每一条边, 算这两个数的所有重心的编号之和
如果根节点不是重心, 那么必然存在一个子树的点数 > n 2 > \frac{n}{2} >2n, 而且只存在一个子树的节点个数 > n 2 > \frac{n}{2} >2n
如何寻找重心?
按照 > n 2 > \frac{n}{2} >2n的子树走最后一个位置, 也就是说所有儿子的点数都严格小于 ≤ n 2 \le \frac{n}{2} ≤2n, 当前点一定是重心, 也就是重心是一定存在的
如何判断树中是否存在两个重心?
假设红色节点是重心. 那么有红色方向点数 ≤ n 2 \le \frac{n}{2} ≤2n,并且蓝色区域点数是 ≥ n 2 \ge \frac{n}{2} ≥2n, 如果蓝色点也是重心, 那么是相反的情况, 最终推出每个区域点数必须严格等于 n 2 \frac{n}{2} 2n(要求总点数是偶数), 并且两个重心是相邻的
在寻找 v v v子树重心的过程中时候可以使用倍增进行优化
但是上半部分如何计算?
将 u u u部分换根换到 v v v, 将上面部分换位以 v v v为根, 换完根之后只有 u u u节点自己会变, s [ u ] = n − s [ v ] s[u] = n - s[v] s[u]=n−s[v], f [ u ] f[u] f[u]看 u u u的新的重儿子是否是 v v v, 如果是 v v v那么需要更换为除了 v v v之外的重儿子
详细注释代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 300010, M = N * 2, K = 19; // K是倍增数组的层数
int n;
int h[N], e[M], ne[M], idx; // 邻接表存图
// f[u][k]表示从u节点出发,沿着重链走2^k步到达的节点
// g是f数组的备份,用于回溯
// p记录父节点,sz记录子树大小
int p[N], f[N][K], g[N][K], sz[N];
LL ans; // 存储最终答案
// 添加边
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 计算倍增数组,u是当前节点,son是u的重儿子
void calc_st(int u, int son) {
f[u][0] = son; // 第一步走重儿子
for (int i = 1; i < K; i++)
f[u][i] = f[f[u][i - 1]][i - 1]; // 倍增思想
}
// 计算以u为根的子树的重心,并累加到答案中
void calc_ans(int u) {
int tot = sz[u]; // 当前子树的总大小
// 从高位到低位尝试走,找到最接近子树中间位置的节点
for (int i = K - 1; i >= 0; i--)
if (f[u][i] && tot - sz[f[u][i]] <= tot / 2)
u = f[u][i];
// 累加找到的重心
ans += u;
// 如果子树大小是偶数且当前节点恰好是中间两个节点之一,则累加另一个重心
if (tot % 2 == 0 && sz[u] == tot / 2) ans += p[u];
}
// 第一次DFS,计算子树大小、父节点和重儿子
void dfs1(int u, int fa) {
int son = 0; // 重儿子初始化为0
sz[u] = 1, p[u] = fa;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dfs1(j, u);
sz[u] += sz[j];
// 更新重儿子
if (sz[j] > sz[son]) son = j;
}
// 计算u的倍增数组
calc_st(u, son);
}
// 第二次DFS,换根DP计算每个子树的重心
void dfs2(int u, int fa) {
int s1 = 0, s2 = 0, szu = sz[u]; // s1是重儿子,s2是次重儿子
// 找出u的重儿子和次重儿子
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (sz[j] >= sz[s1]) s2 = s1, s1 = j;
else if (sz[j] > sz[s2]) s2 = j;
}
// 遍历所有儿子
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
// 计算当前子树j的重心
calc_ans(j);
// 换根操作:将j作为新的根
p[u] = j, sz[u] = n - sz[j];
// 如果j不是重儿子,那么u的新重儿子是s1
// 否则是次重儿子s2
if (j != s1) calc_st(u, s1);
else calc_st(u, s2);
// 计算以u为根的子树的重心
calc_ans(u);
// 递归处理子树
dfs2(j, u);
}
// 恢复现场
p[u] = fa, sz[u] = szu;
memcpy(f[u], g[u], sizeof f[u]);
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
// 初始化邻接表
scanf("%d", &n);
memset(h, -1, sizeof h), idx = 0;
// 读入树结构
for (int i = 0; i < n - 1; i++) {
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
// 第一次DFS预处理
dfs1(1, -1);
ans = 0;
// 备份f数组,用于后续恢复
memcpy(g, f, sizeof f);
// 第二次DFS计算答案
dfs2(1, -1);
printf("%lld\n", ans);
}
return 0;
}
精简注释代码
#pragma GCC optimize(2)
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long LL;
const int N = 300010, K = 19, INF = 0x3f3f3f3f;
int n;
vector<int> head[N];
int p[N], f[N][K], g[N][K], sz[N];
LL ans;
void add(int u, int v) {
head[u].push_back(v);
}
// 倍增向儿子节点走
void calc_f(int u, int son) {
f[u][0] = son;
for (int i = 1; i < K; ++i) f[u][i] = f[f[u][i - 1]][i - 1];
}
// 计算以u为根节点子树的重心
void calc_ans(int u) {
int s = sz[u];
for (int i = K - 1; i >= 0; --i) {
if (f[u][i] && s - sz[f[u][i]] <= s / 2) u = f[u][i];
}
ans += u;
if (s % 2 == 0 && sz[u] == s / 2) ans += p[u];
}
void dfs1(int u, int fa) {
int son = 0;
sz[u] = 1, p[u] = fa;
for (int v : head[u]) {
if (v == fa) continue;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v] > sz[son]) son = v;
}
// 计算当前节点沿着重儿子走能到达的节点
calc_f(u, son);
}
// 换根
void dfs2(int u, int fa) {
int s1 = 0, s2 = 0, u_sz = sz[u];
for (int v : head[u]) {
if (sz[v] > sz[s1]) {
s2 = s1;
s1 = v;
}
else if (sz[v] > sz[s2]) s2 = v;
}
for (int v : head[u]) {
if (v == fa) continue;
// 计算当前儿子的重心
calc_ans(v);
// 换根
p[u] = v, sz[u] = n - sz[v];
// 如果当前v是重儿子向第二大的儿子走, 否则直接走v
v == s1 ? calc_f(u, s2) : calc_f(u, s1);
calc_ans(u);
dfs2(v, u);
}
// 恢复现场
p[u] = fa, sz[u] = u_sz;
// 因为换根后倍增更改了f[u]数组, 将f数组恢复
memcpy(f[u], g[u], sizeof g[u]);
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
int T;
cin >> T;
while (T--) {
cin >> n;
for (int i = 0; i <= n; ++i) head[i].clear();
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
add(u, v), add(v, u);
}
dfs1(1, -1);
ans = 0;
memcpy(g, f, sizeof f);
dfs2(1, -1);
cout << ans << "\n";
}
return 0;
}