0%

斯坦纳树

定义

有一张无向图$(V,E)$,指定了一个$V$的子集$S$,需要求出一个边集$E’$,使得:

  • 对于任意两个点$x,y\in S$,存在一条从$x$到$y$的、由$E’$中的边组成的路径(路径可以经过不在$S$中的点)。
  • 在此基础上,最小化$E’$中所有的边的边权和。

显然$E’$会组成一棵树。称这棵树为斯坦纳树。

求法

考虑状态压缩$dp$:$f[u][s]$表示$s$集合中的点和$u$点已经连通,所需要的最小的花费。其中$u$是$V$中的任意一个点,而$s$是$S$的子集。

更新有两种:

  • 将和$u$连通的两个连通块合并起来:$f[u][s] = \min \{ f[u][t] + f[u][s-t]\}$
  • 扩展到另一个不在$S$中的点:$f[v][s] = \min \{ f[u][s] + e[u][v] \}$

首先从小到大枚举$s$,然后对于每个$s$计算所有的$f[u][s]$。先枚举第一种转移(用更小的$s$来更新当前的$s$);然后考虑第二种转移:实际上是最短路的形式,可以用spfa或者dijkstra解决。

时间复杂度$O( n 3^{|S|} + nm2^{|S|} )$。

模板

hdu4085

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <queue>
#define ll long long
using namespace std;
template <class T>
inline void rd(T &x) {
x=0; char c=getchar(); int f=1;
while(!isdigit(c)) { if(c=='-') f=-1; c=getchar(); }
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
const int N=55;
int d[N][N],n,m,K;
int f[N][1<<10],g[1<<10],vid[N];
queue<int> que;
int vis[N];
void spfa(int c) {
while(!que.empty()) {
int u=que.front(); que.pop(),vis[u]=0;
for(int v=1;v<=n;++v)
if(f[v][c]>f[u][c]+d[u][v]) {
f[v][c]=f[u][c]+d[u][v];
if(!vis[v]) que.push(v),vis[v]=1;
}
}
}
int main() {
int T; rd(T);
while(T--) {
rd(n),rd(m),rd(K);
memset(d,0x3f,sizeof(d));
for(int x,y,w,i=1;i<=m;++i) {
rd(x),rd(y),rd(w);
d[x][y]=min(d[x][y],w);
d[y][x]=min(d[y][x],w);
}
memset(vid,0,sizeof(vid));
for(int i=1;i<=K;++i) vid[i]=i;
for(int i=n-K+1;i<=n;++i) vid[i]=K+(n-i+1);

memset(f,0x3f,sizeof(f));
for(int i=1;i<=n;++i) {
if(vid[i]) f[i][1<<vid[i]-1]=0;
f[i][0]=0;
}
for(int s=1;s<(1<<(K*2));++s) {
for(int i=1;i<=n;++i)
for(int t=s&(s-1);t;t=(t-1)&s)
f[i][s]=min(f[i][s],f[i][t]+f[i][s^t]);
for(int i=1;i<=n;++i)
if(f[i][s]<1e7) que.push(i),vis[i]=1;
spfa(s);
}
memset(g,0x3f,sizeof(g));
for(int s=1;s<(1<<(K*2));++s) {
int cnt1=0,cnt2=0;
for(int i=1;i<=K;++i) if(s&(1<<i-1)) cnt1++;
for(int i=K+1;i<=2*K;++i) if(s&(1<<i-1)) cnt2++;
if(cnt1!=cnt2) continue;
for(int i=1;i<=n;++i) g[s]=min(g[s],f[i][s]);
}
for(int s=0;s<(1<<(2*K));++s)
for(int t=s&(s-1);t;t=(t-1)&s)
g[s]=min(g[s],g[t]+g[s^t]);
int ans=g[(1<<(2*K))-1];
if(ans>1e7) printf("No solution\n");
else printf("%d\n",ans);
}
return 0;
}