1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91
| #include <cstdio> #include <iostream> #include <algorithm> #include <cmath> #include <cstring> #define ll long long using namespace std; template <class T> inline void read(T &x) { x=0; char c=getchar(); int f=1; while(!isdigit(c)){if(c=='-')f=-1; c=getchar();} while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f; }
const int N=40010,M=100010; int pos[N<<1],vis[N],p[N][17],head[N],dep[N],n,ecnt,c[N],b[N],id; int In[N],Out[N]; struct ed{int to,next;}e[N<<1]; void ad(int x,int y) { e[++ecnt]=(ed){y,head[x]}; head[x]=ecnt; e[++ecnt]=(ed){x,head[y]}; head[y]=ecnt; } void dfs(int u,int last) { dep[u]=dep[last]+1,p[u][0]=last,In[u]=++id,pos[id]=u; for(int j=1;j<17;++j) p[u][j]=p[p[u][j-1]][j-1]; for(int k=head[u];k;k=e[k].next) { int v=e[k].to; if(v==last) continue; dfs(v,u); } Out[u]=++id,pos[id]=u; } int buc[N],ans,Ans[M]; void Rev(int u) { if(vis[u]) ans-=buc[c[u]]==1,buc[c[u]]--; else ans+=buc[c[u]]==0,buc[c[u]]++; vis[u]^=1; } struct Que{int l,r,id,L;}Q[M]; int LCA(int x,int y) { if(dep[x]<dep[y]) swap(x,y); for(int j=16;j>=0;--j) if(dep[x]-(1<<j)>=dep[y]) x=p[x][j]; if(x==y) return x; for(int j=16;j>=0;--j) if(p[x][j]!=p[y][j]) x=p[x][j],y=p[y][j]; return p[x][0]; } int T; bool cmp(Que A,Que B){if((A.l+1)/T==(B.l+1)/T) return A.r<B.r;return A.l<B.l;} void Update(int l,int r) { for(int i=l;i<=r;++i) Rev(pos[i]); } int main() { int m,x,y,tmp; read(n),read(m); for(int i=1;i<=n;++i) read(c[i]),b[i]=c[i]; sort(b+1,b+n+1); int sz=unique(b+1,b+n+1)-b-1; for(int i=1;i<=n;++i) c[i]=lower_bound(b+1,b+sz+1,c[i])-b; for(int i=1;i<n;++i) read(x),read(y),ad(x,y); dfs(1,0);
for(int i=1;i<=m;++i) { read(x),read(y),Q[i].id=i; int L=LCA(x,y); if(In[x]>In[y]) swap(x,y); if(L==x||L==y) Q[i].l=In[x],Q[i].r=In[y]; else Q[i].l=Out[x],Q[i].r=In[y],Q[i].L=L;
} T=ceil(sqrt(n*2)); sort(Q+1,Q+m+1,cmp); int L=Q[1].l,R=Q[1].r; for(int i=L;i<=R;++i) Rev(pos[i]); if(Q[1].L) Rev(Q[1].L); Ans[Q[1].id]=ans; if(Q[1].L) Rev(Q[1].L); for(int i=2;i<=m;++i) { Update(L,Q[i].l-1),Update(Q[i].l,L-1); Update(Q[i].r+1,R),Update(R+1,Q[i].r); if(Q[i].L) Rev(Q[i].L); Ans[Q[i].id]=ans; if(Q[i].L) Rev(Q[i].L); L=Q[i].l,R=Q[i].r; } for(int i=1;i<=m;++i) printf("%d\n",Ans[i]); return 0; }
|