0%

树上路径数颜色(树上莫队)

参考:一篇cf上的很好的blog

食用方法:树上路径数颜色。

我们构造一棵树的括号序:即,dfs一棵树的时候,在发现一个节点时,记此时的$idx$为这个节点的$In$,然后让$idx++$;结束对一个节点的访问时,记此时的$idx$为这个节点的$Out$,然后让$idx++$。这样,$u$和$v$之间路径,就转化括号序中的区间$[Out_u,In_v]$,其中我们要求区间中出现了$0$次或者$2$次的节点都不计贡献(说白了就是挪指针的时候进行类似异或的操作),即:

1
2
3
4
5
6
void Rev(int u)
{
if(vis[u]) ans-=buc[c[u]]==1,buc[c[u]]--;
else ans+=buc[c[u]]==0,buc[c[u]]++;
vis[u]^=1;
}

然后我们还需要特判LCA,因为LCA的贡献可能没有被计算到。

模板:SPOJ Count on a Tree II 树上路径数颜色

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstring>
#define ll long long
using namespace std;
template <class T>
inline void read(T &x)
{
x=0; char c=getchar(); int f=1;
while(!isdigit(c)){if(c=='-')f=-1; c=getchar();}
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
// 2^16 = 65536
const int N=40010,M=100010;
int pos[N<<1],vis[N],p[N][17],head[N],dep[N],n,ecnt,c[N],b[N],id;
int In[N],Out[N];
struct ed{int to,next;}e[N<<1];
void ad(int x,int y)
{
e[++ecnt]=(ed){y,head[x]}; head[x]=ecnt;
e[++ecnt]=(ed){x,head[y]}; head[y]=ecnt;
}
void dfs(int u,int last)
{
dep[u]=dep[last]+1,p[u][0]=last,In[u]=++id,pos[id]=u;
for(int j=1;j<17;++j) p[u][j]=p[p[u][j-1]][j-1];
for(int k=head[u];k;k=e[k].next)
{
int v=e[k].to; if(v==last) continue;
dfs(v,u);
}
Out[u]=++id,pos[id]=u;
}
int buc[N],ans,Ans[M];
void Rev(int u)
{
if(vis[u]) ans-=buc[c[u]]==1,buc[c[u]]--;
else ans+=buc[c[u]]==0,buc[c[u]]++;
vis[u]^=1;
}
struct Que{int l,r,id,L;}Q[M];
int LCA(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int j=16;j>=0;--j) if(dep[x]-(1<<j)>=dep[y]) x=p[x][j]; if(x==y) return x;
for(int j=16;j>=0;--j) if(p[x][j]!=p[y][j]) x=p[x][j],y=p[y][j]; return p[x][0];
}
int T;
bool cmp(Que A,Que B){if((A.l+1)/T==(B.l+1)/T) return A.r<B.r;return A.l<B.l;}
void Update(int l,int r)
{
for(int i=l;i<=r;++i) Rev(pos[i]);
}

int main()
{
int m,x,y,tmp; read(n),read(m);
for(int i=1;i<=n;++i) read(c[i]),b[i]=c[i];
sort(b+1,b+n+1); int sz=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;++i) c[i]=lower_bound(b+1,b+sz+1,c[i])-b;
for(int i=1;i<n;++i) read(x),read(y),ad(x,y);
dfs(1,0);

// for(int i=1;i<=2*n;++i) cout<<pos[i]<<' '; cout<<endl;

for(int i=1;i<=m;++i)
{
read(x),read(y),Q[i].id=i;
int L=LCA(x,y); if(In[x]>In[y]) swap(x,y);
if(L==x||L==y) Q[i].l=In[x],Q[i].r=In[y];
else Q[i].l=Out[x],Q[i].r=In[y],Q[i].L=L;
// cout<<Q[i].l<<' '<<Q[i].r<<' '<<Q[i].L<<endl;
}
T=ceil(sqrt(n*2)); sort(Q+1,Q+m+1,cmp);
int L=Q[1].l,R=Q[1].r; for(int i=L;i<=R;++i) Rev(pos[i]);
if(Q[1].L) Rev(Q[1].L); Ans[Q[1].id]=ans; if(Q[1].L) Rev(Q[1].L);

for(int i=2;i<=m;++i)
{
Update(L,Q[i].l-1),Update(Q[i].l,L-1);
Update(Q[i].r+1,R),Update(R+1,Q[i].r);
if(Q[i].L) Rev(Q[i].L);
Ans[Q[i].id]=ans;
if(Q[i].L) Rev(Q[i].L);
L=Q[i].l,R=Q[i].r;
}
for(int i=1;i<=m;++i) printf("%d\n",Ans[i]);
return 0;
}