树上启发式合并
\(\text{By DaiRuiChen007}\)
一、算法简介
在解决树上问题时,我们经常遇到需要统计多个节点各自的子树信息的情况,对于一般暴力统计的 \(\Theta(n^2)\) 复杂度,对于解决题目往往是不够的,这时,我们可以考虑采用 DSU on Tree 即树上启发式合并(或称静态链分治)的算法将其优化到 \(\Theta(\log n)\) 的复杂度
DSU on Tree 相对于平凡的暴力的优化在于:DSU on Tree 在每次递归时,会保留其重儿子的答案并且加到当前节点的贡献中
实现 DSU on Tree 一般分为四步执行:
- 递归解决该节点轻儿子的问题,不保留当前答案
- 递归解决该节点重儿子的问题,保留当前答案
- 统计该节点所有轻儿子的子树的贡献,得出最终答案
- 如果当前答案不需要保留,清空该子树的答案
DSU on Tree 的核心在于函数:add
和 del
,他们的作用分别是统计或删除某一个节点的贡献,在这点上 DSU on Tree 和莫队较为相似
关于 DSU on Tree 的时间复杂度:由于每个儿子被考虑的次数等同于这个节点到根上的轻边数量,显然,这个复杂度是 \(\Theta(\log n)\) 的,所以其总体复杂度为 \(\Theta(n\log n)\)
关于对树上路径的统计,除了点分治算法之外,DSU on Tree 也不失为一种强力的算法,我们可以将树上路径 \(u\to v\) 拆成 \(u\to\operatorname{LCA}(u,v)\) 和 \(\operatorname{LCA}(u,v)\to v\),每次统计答案的时候将当前节点作为 \(\operatorname{LCA}(u,v)\),对于每一个子节点 \(u\),查询在 \(u\) 以外的子树中的 \(v\) 的个数,注意由于我们需要保证树上路径为简单路径,所以我们选择的路径 \((u,v)\) 不能在当前节点的同一棵子树内,要先统计这棵子树做为 \(u\) 时的答案,然后再将这棵子树里的所有节点作为可能的 \(v\) 计算贡献
二、典例分析
I. [CodeForces375D] - Tree and Queries
思路分析
用 \(cnt_i\) 维护当前颜色 \(i\) 出现的次数,用 \(sum_x\) 维护出现次数 \(\ge x\) 的颜色个数,那么我们可以在每次加减 \(cnt\) 的同时加减 \(sum\),得到如下的贡献统计函数
时间复杂度 \(\Theta(n\log n)\)
inline void add(int color) {
++cnt[color];
++sum[cnt[color]];
}
inline void del(int color) {
--sum[cnt[color]];
--cnt[color];
}
代码呈现
#include <bits/stdc++.h>
using namespace std;
const int MAXN=1e5+1;
int n,m;
int col[MAXN],cnt[MAXN],sum[MAXN],ans[MAXN];
vector <int> edge[MAXN];
struct node {
int k,id;
};
vector <node> query[MAXN];
inline void add(int color) {
++cnt[color];
++sum[cnt[color]];
}
inline void del(int color) {
--sum[cnt[color]];
--cnt[color];
}
int fa[MAXN],siz[MAXN],son[MAXN],id[MAXN],rnk[MAXN],idcnt;
inline void dfs(int p,int f) {
fa[p]=f,siz[p]=1;
id[p]=++idcnt,rnk[idcnt]=p;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
if(siz[v]>siz[son[p]]) son[p]=v;
}
}
inline void dsu(int p,bool rsv) {
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
dsu(v,false);
}
if(son[p]) dsu(son[p],true);
add(col[p]);
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
for(int x=id[v];x<id[v]+siz[v];++x) add(col[rnk[x]]);
}
for(auto t:query[p]) ans[t.id]=sum[t.k];
if(!rsv) for(int x=id[p];x<id[p]+siz[p];++x) del(col[rnk[x]]);
}
signed main() {
scanf("%d%d",&n,&m);
for(int i=1;i<=n;++i) scanf("%d",&col[i]);
for(int i=1;i<n;++i) {
int u,v;
scanf("%d%d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
for(int i=1;i<=m;++i) {
int u,k;
scanf("%d%d",&u,&k);
query[u].push_back((node){k,i});
}
dfs(1,0);
dsu(1,true);
for(int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}
拓展解法
记录每个节点的 dfn 序,查询点 \(p\) 的答案时候等价于在区间 \([dfn_p,dfn_p+siz_p-1]\) 上做查询,也可以使用莫队解决,时间复杂度 \(\Theta(n\sqrt n)\)
另解代码
#include <bits/stdc++.h>
using namespace std;
const int MAXN=1e5+1;
int n,m,block;
int col[MAXN],cnt[MAXN],sum[MAXN],ans[MAXN];
int siz[MAXN],id[MAXN],rnk[MAXN],idcnt;
vector <int> edge[MAXN];
struct node {
int l,r,k,id;
inline friend bool operator <(const node &x,const node &y) {
if(x.l/block==y.l/block) return x.r<y.r;
else return x.l<y.l;
}
} q[MAXN];
inline void add(int pos) {
int color=col[rnk[pos]];
++cnt[color];
++sum[cnt[color]];
}
inline void del(int pos) {
int color=col[rnk[pos]];
--sum[cnt[color]];
--cnt[color];
}
inline void dfs(int p,int f) {
siz[p]=1,id[p]=++idcnt,rnk[idcnt]=p;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
}
}
signed main() {
scanf("%d%d",&n,&m);
block=pow(n,0.455);
for(int i=1;i<=n;++i) scanf("%d",&col[i]);
for(int i=1;i<n;++i) {
int u,v;
scanf("%d%d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1,0);
for(int i=1;i<=m;++i) {
int u,k;
scanf("%d%d",&u,&k);
q[i]=(node){id[u],id[u]+siz[u]-1,k,i};
}
sort(q+1,q+m+1);
int l=1,r=0;
for(int i=1;i<=m;++i) {
while(l<q[i].l) del(l++);
while(l>q[i].l) add(--l);
while(r>q[i].r) del(r--);
while(r<q[i].r) add(++r);
l=q[i].l,r=q[i].r,ans[q[i].id]=sum[q[i].k];
}
for(int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}
II. [CodeForces600E] Most gelral
思路分析
类似上一题,用 \(sum_x\) 维护其出现次数 \(\ge x\) 的颜色的编号和,用 \(lst\) 维护出现次数最多的颜色的出现次数,每次插入删除的时候顺便更新一下 \(lst\) 即可(更像莫队了。。。),得到如下贡献统计:
inline void add(int clr) {
++cnt[clr];
if(!sum[cnt[clr]]) lst=cnt[clr];
sum[cnt[clr]]+=clr;
}
inline void del(int clr) {
sum[cnt[clr]]-=clr;
if(!sum[cnt[clr]]) lst=cnt[clr]-1;
--cnt[clr];
时间复杂度 \(\Theta(n\log n)\)
代码呈现
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=1e5+1;
int n;
int col[MAXN],cnt[MAXN],sum[MAXN],ans[MAXN],lst=0;
vector <int> edge[MAXN];
struct node {
int k,id;
};
vector <node> query[MAXN];
inline void add(int clr) {
++cnt[clr];
if(!sum[cnt[clr]]) lst=cnt[clr];
sum[cnt[clr]]+=clr;
}
inline void del(int clr) {
sum[cnt[clr]]-=clr;
if(!sum[cnt[clr]]) lst=cnt[clr]-1;
--cnt[clr];
}
int fa[MAXN],siz[MAXN],son[MAXN],id[MAXN],rnk[MAXN],idcnt;
inline void dfs(int p,int f) {
fa[p]=f,siz[p]=1;
id[p]=++idcnt,rnk[idcnt]=p;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
if(siz[v]>siz[son[p]]) son[p]=v;
}
}
inline void dsu(int p,bool rsv) {
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
dsu(v,false);
}
if(son[p]) dsu(son[p],true);
add(col[p]);
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
for(int x=id[v];x<id[v]+siz[v];++x) add(col[rnk[x]]);
}
ans[p]=sum[lst];
if(!rsv) for(int x=id[p];x<id[p]+siz[p];++x) del(col[rnk[x]]);
}
signed main() {
scanf("%lld",&n);
for(int i=1;i<=n;++i) scanf("%lld",&col[i]);
for(int i=1;i<n;++i) {
int u,v;
scanf("%lld%lld",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1,0);
dsu(1,true);
for(int i=1;i<=n;++i) printf("%lld ",ans[i]);
puts("");
return 0;
}
III. [CodeForces246E] Blood Cousins Return
思路分析
对于到根节点的每个深度维护一个桶记录当前子树对应深度的名字,由于要满足不可重,所以这里直接采用 set
维护,插入删除等价于在 set
中执行 insert
或 delete
操作,注意查询答案是在深度为 \(dep_{u}+k\) 的桶里查询,为了防止访问的时候出现数组越界,建议把桶的下标开到两倍
注意原题是森林,不妨把 \(0\) 直接当成根节点
时间复杂度 \(\Theta(n\log^2 n)\)
代码呈现
#include <bits/stdc++.h>
using namespace std;
const int MAXN=1e5+1;
int n,m,ans[MAXN],root;
int dep[MAXN],fa[MAXN],siz[MAXN],son[MAXN],id[MAXN],rnk[MAXN],idcnt;
string str[MAXN];
set <string> ver[MAXN<<1];
vector <int> edge[MAXN];
struct node {
int k,id;
};
vector <node> query[MAXN];
inline void add(int id) {
ver[dep[id]].insert(str[id]);
}
inline void del(int id) {
ver[dep[id]].erase(str[id]);
}
inline void dfs(int p,int f) {
fa[p]=f,siz[p]=1,dep[p]=dep[f]+1;
id[p]=++idcnt,rnk[idcnt]=p;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
if(siz[v]>siz[son[p]]) son[p]=v;
}
}
inline void dsu(int p,bool rsv) {
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
dsu(v,false);
}
if(son[p]) dsu(son[p],true);
add(p);
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
for(int x=id[v];x<id[v]+siz[v];++x) add(rnk[x]);
}
for(auto t:query[p]) ans[t.id]=ver[dep[p]+t.k].size();
if(!rsv) for(int x=id[p];x<id[p]+siz[p];++x) del(rnk[x]);
}
signed main() {
cin>>n;
for(int i=1;i<=n;++i) {
int f;
cin>>str[i]>>f;
edge[f].push_back(i);
edge[i].push_back(f);
}
cin>>m;
for(int i=1;i<=m;++i) {
int u,k;
cin>>u>>k;
query[u].push_back((node){k,i});
}
dfs(0,0);
dsu(0,true);
for(int i=1;i<=m;++i) printf("%d\n",ans[i]);
return 0;
}
IV. [Codeforces570D] Tree Requests
思路分析
对于每个深度 \(d\) 统计当前子树中深度为 \(d\) 的节点里字母 a
到 z
分别出现多少次,想要让他们重排后得到一个回文串,其充分必要条件是出现次数为奇数的字符数量 \(\le 1\),和上一题类似,查答案的时候暴力即可,也可以维护对应的奇数出现次数,时间复杂度 \(\Theta(n\log n)\)
代码呈现
#include<bits/stdc++.h>
using namespace std;
const int MAXN=5e5+1;
int n,m;
int cnt[MAXN][26],siz[MAXN],son[MAXN],id[MAXN],rnk[MAXN],fa[MAXN],dep[MAXN],idcnt;
bool ans[MAXN];
char ch[MAXN];
struct node {
int k,id;
};
vector <int> edge[MAXN];
vector <node> query[MAXN];
inline void add(int pos) {
++cnt[dep[pos]][ch[pos]-'a'];
}
inline void del(int pos) {
--cnt[dep[pos]][ch[pos]-'a'];
}
inline bool check(int dpth) {
bool use=false;
for(int i=0;i<26;++i) {
if(cnt[dpth][i]&1) {
if(use) return false;
else use=true;
}
}
return true;
}
inline void dfs(int p,int f) {
siz[p]=1,fa[p]=f,dep[p]=dep[f]+1;
id[p]=++idcnt,rnk[idcnt]=p;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
if(siz[v]>siz[son[p]]) son[p]=v;
}
}
inline void dsu(int p,bool rsv) {
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
dsu(v,false);
}
if(son[p]) dsu(son[p],true);
add(p);
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
for(int x=id[v];x<id[v]+siz[v];++x) add(rnk[x]);
}
for(node t:query[p]) ans[t.id]=check(t.k);
if(!rsv) for(int x=id[p];x<id[p]+siz[p];++x) del(rnk[x]);
return ;
}
signed main() {
scanf("%d%d",&n,&m);
for(int i=2;i<=n;++i) {
int f;
scanf("%d",&f);
edge[f].push_back(i);
edge[i].push_back(f);
}
scanf("%s",ch+1);
for(int i=1;i<=m;++i) {
int u,k;
scanf("%d%d",&u,&k);
query[u].push_back((node){k,i});
}
dfs(1,0);
dsu(1,true);
for(int i=1;i<=m;++i) {
if(ans[i]) puts("Yes");
else puts("No");
}
return 0;
}
V. [计蒜客A1082] - 青云的机房组网方案
思路分析
数论+树论=?
设 \(m=\max\limits_{i=1}^{n}\{a_i\}\le10^5\),\(p_1,p_2\sim,p_k\) 为 \(1\sim m\) 中的所有质数
设 \(\mathbf P_{x}\) 表示所有满足 \(x|a_i\) 的点两两之间构成的路径集合,\(d(\mathbf P_x)\) 表示 \(\mathbf P_x\) 中所有路径的长度之和,则有:
\text{Answer}
&=d(\mathbf P_1)-d(\mathbf P_{p_1}\cup\mathbf P_{p_2}\cup\cdots\cup\mathbf P_{p_m})\\
&=d(\mathbf P_1)-\sum d(\mathbf P_{p_i})+\sum d(\mathbf P_{p_i}\cap\mathbf P_{p_j})-\sum d(\mathbf P_{p_i}\cap \mathbf P_{p_j}\cap\mathbf P_{p_j})+\cdots\\
&=d(\mathbf P_1)-\sum d(\mathbf P_{p_i})+\sum d(\mathbf P_{p_i\times p_j})-\sum d(\mathbf P_{p_i\times p_j\times p_k})+\cdots\\
\end{aligned}
\]
对于 \(1\sim m\) 中的每个数 \(x\),考虑 \(d(\mathbf P_x)\) 对答案的贡献 \(\lambda(x)\),其中 \(\lambda(x)\) 是一个特殊的数论函数
则又有:
\]
比较一下上下两式的系数可得:
- 若 \(x=1\),则 \(\lambda(x)=1\)
- 若 \(x\) 中含有两个相同的质因子,则 \(\lambda(x)=0\)
- 若 \(x\) 中仅含有 \(k\) 个互不相同的质因子,则 \(\lambda(x)=(-1)^k\)
通过比对,可以发现 \(\lambda(x)=\mu(x)\)
因此原式化为:
\]
所以对于某条路径 \(u\to v\),设 \(p=\operatorname{LCA}(u,v)\),其对答案的贡献应该为:
\]
对于类似树上路径统计的问题,我们优先考虑 DSU on Tree,枚举 \(\operatorname{LCA}(u,v)\) 和 \(u\),则对于所有 \(v\) 的贡献之和应该为:
\]
考虑把 \(dep_u-2\times dep_p\) 的贡献提取出来,我们开两个桶分别维护:
cnt_{x,0}&=\sum\limits_{v} [x|a_v]\\
cnt_{x,1}&=\sum\limits_{v} [x|a_v]\times dep_v
\end{aligned}
\]
统计所有 \(v\) 的贡献之和可以转为:
\]
所以我们只需要在每次插入/删除/统计节点的时候枚举 \(a_u\) 的每个子节点的贡献即可
设 \(d(x)\) 表示 \(x\) 的因子个数,则时间复杂度 \(\Theta(\log n\times \sum d(a_i))\)
代码呈现
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MAXN=1e5+1;
int mu[MAXN],a[MAXN],n;
int id[MAXN],rnk[MAXN],siz[MAXN],fa[MAXN],dep[MAXN],son[MAXN],idcnt;
int cnt[MAXN][2],ans;
bool isc[MAXN];
vector <int> edge[MAXN],factor[MAXN];
inline void init() {
for(int i=1;i<MAXN;++i) mu[i]=1;
for(int i=2;i<MAXN;++i) {
if(isc[i]) continue;
for(int j=i;j<MAXN;j+=i) {
isc[j]=true;
if(j%(i*i)!=0) mu[j]*=-1;
else mu[j]=0;
}
}
for(int i=1;i<MAXN;++i) {
for(int j=i;j<MAXN;j+=i) {
factor[j].push_back(i);
}
}
}
inline void dfs(int p,int f) {
id[p]=++idcnt,rnk[idcnt]=p;
siz[p]=1,fa[p]=f,dep[p]=dep[f]+1;
for(int v:edge[p]) {
if(v==f) continue;
dfs(v,p);
siz[p]+=siz[v];
if(siz[v]>siz[son[p]]) son[p]=v;
}
}
inline void modify(int u,int op) {
for(int x:factor[a[u]]) {
cnt[x][0]+=op;
cnt[x][1]+=op*dep[u];
}
}
inline void calc(int u,int lca) {
for(int x:factor[a[u]]) {
ans+=mu[x]*(cnt[x][0]*(dep[u]-dep[lca]*2)+cnt[x][1]);
}
}
inline void dsu(int p) {
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
dsu(v);
for(int x=id[v];x<id[v]+siz[v];++x) modify(rnk[x],-1);
}
if(son[p]) dsu(son[p]);
for(int v:edge[p]) {
if(v==fa[p]||v==son[p]) continue;
for(int x=id[v];x<id[v]+siz[v];++x) calc(rnk[x],p);
for(int x=id[v];x<id[v]+siz[v];++x) modify(rnk[x],1);
}
calc(p,p);
modify(p,1);
}
signed main() {
init();
scanf("%lld",&n);
for(int i=1;i<=n;++i) scanf("%lld",&a[i]);
for(int i=1;i<n;++i) {
int u,v;
scanf("%lld%lld",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1,0);
dsu(1);
printf("%lld\n",ans);
return 0;
}