0%

KD-Tree 学习笔记

参考资料

  1. OI-wiki
  2. wikipedia
  3. 一份有详细图解的课件

KD-Tree 是什么

KD-Tree(k-dimensional tree 的简写)是一种用来维护 $k$ 维的点集的数据结构。它是一棵深度为 $O(\log n)$ 的二叉搜索树,每个结点代表了点集中的一个点。

它可以支持以下的操作(设 $n$ 为点集的大小):

  1. 插入/删除一个点,时间复杂度为均摊 $O(\log n)$
  2. 查询一个边界平行于坐标轴的矩形/超矩形内的点的信息,时间复杂度上界为 $O(n^{1-\frac{1}{k}})$
  3. 对一个边界平行于坐标轴的矩形/超矩形内的点进行支持标记合并的修改(维护和线段树类似的标记),时间复杂度上界为 $O(n^{1-\frac{1}{k}})$
  4. 查询一个点的最近点/最远点,数据随机的时候时间复杂度期望为 $O(\log n)$ ,最坏复杂度为 $O(n)$ 。

建树

选择一个维度,然后在点集中选择一个这一维的坐标是中位数的点作为根,然后把其它的点按照这一维的坐标和根的大小关系分成左右两棵子树,递归到子树内进行建造。

定义一个结点的范围是它子树内的点的坐标范围,也就是一个 $k$ 维的、所有边都平行于坐标轴的超矩形。

注意到:对于某个点,显然它的左右子树的范围的交要么为空,要么只包含了这个点所在的、垂直于这一次所选的维度的坐标轴的一条线段。

以下是二维情形的建树代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/*
rec[c] 代表 c 结点的范围
tr[c] 表示 c 结点所代表的点的坐标
*/
void build(int l,int r,int &c,int d) {
if(l>r) return (void)(c=0);
int mid=l+r>>1;
nth_element(P+l,P+mid,P+r+1,(d&1?cmpx:cmpy));
tr[c=newnode()]=P[mid];
build(l,mid-1,ch[c][0],d+1),build(mid+1,r,ch[c][1],d+1);
rec[c].lx=min(tr[c].x,min(rec[ch[c][0]].lx,rec[ch[c][1]].lx));
rec[c].rx=max(tr[c].x,max(rec[ch[c][0]].rx,rec[ch[c][1]].rx));
rec[c].ly=min(tr[c].y,min(rec[ch[c][0]].ly,rec[ch[c][1]].ly));
rec[c].ry=max(tr[c].y,max(rec[ch[c][0]].ry,rec[ch[c][1]].ry));
}

插入/删除

按照 KD-Tree 的子树划分方式递归到相应的位置,然后插入/删除即可。

为了保证平衡,我们要在某个点的某个儿子的子树大小大于这个点的子树大小 $\times \alpha$ 的时候对这个子树进行重构。其中 $\alpha$ 一般取 0.75 左右。

单次插入/删除的时间复杂度是均摊 $O(\log n)$ 的。

以下是二维情形的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
bool cmpx(POINT A,POINT B) { return A.x<B.x; }
bool cmpy(POINT A,POINT B) { return A.y<B.y; }
bool cmptx(int x,int y) { return tr[x].x<tr[y].x; }
bool cmpty(int x,int y) { return tr[x].y<tr[y].y; }
void push_up(int c) {
sum[c]=sum[ch[c][0]]+sum[ch[c][1]]+val[c];
sz[c]=sz[ch[c][0]]+sz[ch[c][1]]+1;
}
vector<int> b;
void rebuild(int l,int r,int &c,int d) {
if(l>r) return (void)(c=0);
int mid=l+r>>1;
nth_element(b.begin()+l,b.begin()+mid,b.begin()+r+1,(d?cmptx:cmpty));
c=b[mid];
rebuild(l,mid-1,ch[c][0],d^1);
rebuild(mid+1,r,ch[c][1],d^1);
rec[c].lx=min(tr[c].x,min(rec[ch[c][0]].lx,rec[ch[c][1]].lx));
rec[c].rx=max(tr[c].x,max(rec[ch[c][0]].rx,rec[ch[c][1]].rx));
rec[c].ly=min(tr[c].y,min(rec[ch[c][0]].ly,rec[ch[c][1]].ly));
rec[c].ry=max(tr[c].y,max(rec[ch[c][0]].ry,rec[ch[c][1]].ry));
push_up(c);
}
void recycle(int rt) {
if(!rt) return; b.PB(rt);
recycle(ch[rt][0]),recycle(ch[rt][1]);
}
void rebuild(int &rt,int d) {
recycle(rt),rebuild(0,b.size()-1,rt,d),b.clear();
}
void ins(int &c,int d,POINT p,int v,int flg) {
if(!c) {
tr[c=++ncnt]=p,val[c]=v,push_up(c);
rec[c].lx=rec[c].rx=tr[c].x;
rec[c].ly=rec[c].ry=tr[c].y;
return;
}
int lr=(d?cmpx(tr[c],p):cmpy(tr[c],p));
int cur_flg=(sz[c]-1)*alpha<=max(sz[ch[c][lr]]+1,sz[ch[c][lr^1]]);
ins(ch[c][lr],d^1,p,v,flg|cur_flg);
rec[c].lx=min(rec[c].lx,p.x);
rec[c].rx=max(rec[c].rx,p.x);
rec[c].ly=min(rec[c].ly,p.y);
rec[c].ry=max(rec[c].ry,p.y);
push_up(c);
if(!flg&&cur_flg) rebuild(c,d);
}

矩形查询

从根开始往下递归:

  1. 如果当前结点的范围与查询的范围没有交,直接退出;
  2. 如果当前结点的范围被完全包含在查询的范围内,返回当前结点的子树信息
  3. 否则考虑当前结点所代表的点的贡献,并递归到子树内继续查询

可以证明,单次操作的时间复杂度上界为 $O(n^{1-\frac{1}{k}})$ 。

以下是二维情形的代码:

1
2
3
4
5
6
7
8
9
int qry(int c,int d,int lx,int rx,int ly,int ry) {
if(rx<rec[c].lx||lx>rec[c].rx||ry<rec[c].ly||ly>rec[c].ry) return 0;
if(lx<=rec[c].lx&&rx>=rec[c].rx&&ly<=rec[c].ly&&ry>=rec[c].ry) return sum[c];
int tot=0;
if(lx<=tr[c].x&&rx>=tr[c].x&&ly<=tr[c].y&&ry>=tr[c].y) tot+=val[c];
tot+=qry(ch[c][0],d^1,lx,rx,ly,ry);
tot+=qry(ch[c][1],d^1,lx,rx,ly,ry);
return tot;
}

矩形修改

与矩形查询是类似的。

从根开始往下递归:

  1. 如果当前结点的范围与查询的范围没有交,直接退出;
  2. 如果当前结点的范围被完全包含在查询的范围内,对当前结点打子树修改标记
  3. 否则考虑修改对当前结点代表的点的贡献,并递归到当前结点的子树内继续以上过程

单次操作的时间复杂度上界是 $O(n^{1-\frac{1}{k}})$ 。

以下是二维情形的代码:

1
2
3
4
5
6
7
8
void upd(int c,int lx,int rx,int ly,int ry,int op,int v) {
if(rec[c].lx>rx||rec[c].rx<lx||rec[c].ly>ry||rec[c].ry<ly) return;
if(rec[c].lx>=lx&&rec[c].rx<=rx&&rec[c].ly>=ly&&rec[c].ry<=ry) return (void)(op==1?add(c,v):mul(c,v));
push_down(c);
if(lx<=tr[c].x&&rx>=tr[c].x&&ly<=tr[c].y&&ry>=tr[c].y) val[c]=(op==1?(val[c]+v)%mod:1ll*val[c]*v%mod);
upd(ch[c][0],lx,rx,ly,ry,op,v);
upd(ch[c][1],lx,rx,ly,ry,op,v);
}

查询一个点的最近/最远点

本质上是对暴力搜索的剪枝。

剪枝 1 :如果一个结点的范围内的所有点到查询点的距离都不如现在的最优答案优秀,那么直接退出。

剪枝 2 :在决定先往左子树走还是往右子树走的时候,走范围边界上最优的那个点更优的那个子树。

可以证明,随机数据的时候单次查询的期望复杂度为 $O(\log n)$ ,最坏情况下的时间复杂度为 $O(n)$ 。

拓展:查询一个点的第 $k$ 近/远的点

用优先队列维护现在已经搜到的答案最优的 $k$ 个,搜索的过程中进行和前面同理的剪枝;复杂度和上面最近/最远点的复杂度是相同的。

二维 $k$ 远点:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
db Sqr(db x) { return x*x; }
db dis_to_rec(int c) {
if(!c) return -1;
return max(Sqr(rec[c].lx-qx),Sqr(rec[c].rx-qx))+max(Sqr(rec[c].ly-qy),Sqr(rec[c].ry-qy));
}
db dis(int c) { return Sqr(tr[c].x-qx)+Sqr(tr[c].y-qy); }
void qry(int c) {
if(!c) return;
item tmp=(item){tr[c].id,dis(c)};
if(tmp<Q.top()) Q.pop(),Q.push(tmp);
db d[2]={dis_to_rec(ch[c][0]),dis_to_rec(ch[c][1])};
int lr=(d[1]>d[0]);
if(sgn(d[lr]-Q.top().d)>=0) qry(ch[c][lr]);
if(sgn(d[lr^1]-Q.top().d)>=0) qry(ch[c][lr^1]);
}

题表

  • luogu P2093 【国家集训队】JZPFAR
  • luogu P3710 方方方的数据结构
  • luogu P3710 方方方的数据结构
  • luogu P4148 简单题
  • luogu P4357 【CQOI2016】K远点对