0%

CTSC2010 循环卷积

洛谷上的题面

有两个长度为$n$的多项式$A$和$B$,需要求出$AB^C$。其中的$$运算定义如下:

数据范围:$n\le 5\times 10^5,C\le 10^9$。保证$n$一定可以被分解成若干个不超过$10$的正整数的乘积,并且$n+1$是质数。

Solution:
我们用FFT计算卷积的过程,本质上是构造了一个$trans(p,i)$(贡献系数),令DFT后的数组为$A’[p]=\sum_{i=0}^n trans(p,i)\times A[i]$,我们必须让$trans(p,i)\times trans(p,j)=trans(p,i+j)$成立。FFT中,我们的$trans(p,i)=\omega_n^{pi}$。而$trans(p,i)\times trans(p,j)=\omega_n^{pi}\cdot \omega_n^{pj} =\omega_n^{p(i+j)}=trans(p,i+j)$。故而这个变换可以满足我们的要求。

考虑用类似的思想来解决这个问题,我们要构造一个$trans(p,i)$,使得$trans(p,i)\times trans(p,j)=trans(p,(i+j)\text{ mod }n)$。有一个令人惊喜的发现:$trans(p,i)=\omega_n^{pi}$是满足条件的,因为$\omega_n^k=\omega_n^{k-n}$。

这就意味着,我们可以FFT,然后直接对点值快速幂,再IDFT回去。

可是,整个数组的长度并不是$2^k$。而且$C$很大,直接$power$肯定会炸精度,而如果用NTT,模数也满足$t\cdot 2^k +1$的形式。怎么办呢?

题目里面有一个条件:保证$n$一定可以被分解成若干个不超过$10$的正整数的乘积。(实际上,这样的$n$被称作smooth number)。我们进行分治的时候,不一定将一个长度为$n$的区间分成两个长度为${n\over 2}$的区间;可以取出当前区间长度的一个质因子$p$,然后将当前区间分成$p$个区间,分治完了之后再合并。具体地,假设我们现在分治序列$a_0,a_2,\cdots a_{n-1}$,那么我们取$n$的一个质因子$p$,然后按照模$p$的余数分组。以$n=6,p=3$为例,我们希望计算对于每一个$k\in [0,n)$,$f(\omega_n^k)$的值,其中$f(x)=\sum_{i=0}^{n-1}a_ix^i$。稍加推导:

由于$\omega_n^k = \omega_{n\over p}^{k\over p}$,我们可以递归下去,对每一个长度为$2$的区间进行分治,算出$k\in [0,{n\over 3})$时,该区间$f(x)$的取值。我们不需要算得更多,是因为$\omega_n^k =\omega_n^{k+n}$。然后合并就可以了。由于$p$总是很小,也就是划分的区间的数量很小,合并的复杂度可以视作常数。所以总复杂度是$n\log n$的。

还有精度的问题:可以直接用模数的原根替代单位根,因为我们只需要满足,模数-1是所有分治区间的长度的倍数就可以了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define ll long long
using namespace std;
template <class T>
inline void read(T &x)
{
x=0; char c=getchar(); int f=1;
while(!isdigit(c)){if(c=='-')f=-1; c=getchar();}
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
int mod,G;
ll Pow(ll x,ll y){ll res=1; while(y){if(y&1)res=res*x%mod; x=x*x%mod,y>>=1;} return res;}
namespace P_root{
const int N=720+22;
int flg[N],pri[N],num,d[N],tot;
void predo(int n){for(int i=2;i<=n;++i){if(!flg[i])pri[num++]=i; for(int j=0;j<num&&pri[j]*i<=n;++j){flg[i*pri[j]]=1; if(i%pri[j]==0) break;}}}
void get(int n){tot=0;for(int j=0;j<num&&pri[j]*pri[j]<=n;++j)if(n%pri[j]==0){d[++tot]=pri[j];while(n%pri[j]==0)n/=pri[j];}if(n>1)d[++tot]=n;}
bool che(int x){for(int j=1;j<=tot;++j)if(Pow(x,(mod-1)/d[j])==1)return 0; return 1;}
int cal(){predo(N-22),get(mod-1);for(int i=2;i<mod;++i) if(che(i)) return i; return -1;}
}
const int N=5e5+10;
int di[33],num,n,pos[N];
ll A[N],B[N],t[N];
int findpos(int pr,int x,int dep){return dep==num+1?pr+x:findpos(pr*di[dep]+x%di[dep],x/di[dep],dep+1);}
void cpy(ll A[],ll B[]){for(int i=0;i<n;++i) A[i]=B[i];}
void FFT(ll A[],int f)
{
for(int i=0;i<n;++i) t[pos[i]]=A[i]; cpy(A,t);
for(int lt=1,l=di[num],cur=num;cur>=1;cur--,lt=l,l*=di[cur])
{
ll wn=Pow(G,f==1?n/l:n-n/l);
for(int i=0;i<n;i+=l)
{
ll wk=1;
for(int k=0;k<l;++k,wk=wk*wn%mod) // now calculate i+k, w->wn^k
{
ll tmp=0; ll w=1;
for(int j=k%lt;j<l;j+=lt,w=w*wk%mod)
tmp=(tmp+w*A[i+j])%mod;
t[i+k]=tmp;
}
}
cpy(A,t);
}
if(f==-1){ll Inv=Pow(n,mod-2); for(int i=0;i<n;++i) A[i]=A[i]*Inv%mod;}
}

int main()
{
int t; read(n),read(t); mod=n+1,G=P_root::cal();
int tmp=n; for(int i=2;i*i<=tmp;++i){while(tmp%i==0)tmp/=i,di[++num]=i;}if(tmp>1)di[++num]=tmp;
// sort(di+1,di+num+1);
for(int i=0;i<n;++i) pos[i]=findpos(0,i,1);
for(int i=0;i<n;++i) read(A[i]); FFT(A,1);
for(int i=0;i<n;++i) read(B[i]); FFT(B,1);
for(int i=0;i<n;++i) A[i]=A[i]*Pow(B[i],t)%mod;
FFT(A,-1);
for(int i=0;i<n;++i) printf("%lld\n",A[i]);
return 0;
}