洛谷上的题面
有两个长度为$n$的多项式$A$和$B$,需要求出$AB^C$。其中的$ $运算定义如下:
数据范围:$n\le 5\times 10^5,C\le 10^9$。保证$n$一定可以被分解成若干个不超过$10$的正整数的乘积,并且$n+1$是质数。
Solution: 我们用FFT计算卷积的过程,本质上是构造了一个$trans(p,i)$(贡献系数),令DFT后的数组为$A’[p]=\sum_{i=0}^n trans(p,i)\times A[i]$,我们必须让$trans(p,i)\times trans(p,j)=trans(p,i+j)$成立。FFT中,我们的$trans(p,i)=\omega_n^{pi}$。而$trans(p,i)\times trans(p,j)=\omega_n^{pi}\cdot \omega_n^{pj} =\omega_n^{p(i+j)}=trans(p,i+j)$。故而这个变换可以满足我们的要求。
考虑用类似的思想来解决这个问题,我们要构造一个$trans(p,i)$,使得$trans(p,i)\times trans(p,j)=trans(p,(i+j)\text{ mod }n)$。有一个令人惊喜的发现:$trans(p,i)=\omega_n^{pi}$是满足条件的,因为$\omega_n^k=\omega_n^{k-n}$。
这就意味着,我们可以FFT,然后直接对点值快速幂,再IDFT回去。
可是,整个数组的长度并不是$2^k$。而且$C$很大,直接$power$肯定会炸精度,而如果用NTT,模数也满足$t\cdot 2^k +1$的形式。怎么办呢?
题目里面有一个条件:保证$n$一定可以被分解成若干个不超过$10$的正整数的乘积。(实际上,这样的$n$被称作smooth number)。我们进行分治的时候,不一定将一个长度为$n$的区间分成两个长度为${n\over 2}$的区间;可以取出当前区间长度的一个质因子$p$,然后将当前区间分成$p$个区间,分治完了之后再合并。具体地,假设我们现在分治序列$a_0,a_2,\cdots a_{n-1}$,那么我们取$n$的一个质因子$p$,然后按照模$p$的余数分组。以$n=6,p=3$为例,我们希望计算对于每一个$k\in [0,n)$,$f(\omega_n^k)$的值,其中$f(x)=\sum_{i=0}^{n-1}a_ix^i$。稍加推导:
由于$\omega_n^k = \omega_{n\over p}^{k\over p}$,我们可以递归下去,对每一个长度为$2$的区间进行分治,算出$k\in [0,{n\over 3})$时,该区间$f(x)$的取值。我们不需要算得更多,是因为$\omega_n^k =\omega_n^{k+n}$。然后合并就可以了。由于$p$总是很小,也就是划分的区间的数量很小,合并的复杂度可以视作常数。所以总复杂度是$n\log n$的。
还有精度的问题:可以直接用模数的原根替代单位根,因为我们只需要满足,模数-1是所有分治区间的长度的倍数就可以了。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 #include <cstdio> #include <iostream> #include <algorithm> #include <cstring> #define ll long long using namespace std ;template <class T >inline void read (T &x ){ x=0 ; char c=getchar(); int f=1 ; while (!isdigit (c)){if (c=='-' )f=-1 ; c=getchar();} while (isdigit (c)) x=x*10 -'0' +c,c=getchar(); x*=f; } int mod,G;ll Pow (ll x,ll y) {ll res=1 ; while (y){if (y&1 )res=res*x%mod; x=x*x%mod,y>>=1 ;} return res;}namespace P_root{ const int N=720 +22 ; int flg[N],pri[N],num,d[N],tot; void predo (int n) {for (int i=2 ;i<=n;++i){if (!flg[i])pri[num++]=i; for (int j=0 ;j<num&&pri[j]*i<=n;++j){flg[i*pri[j]]=1 ; if (i%pri[j]==0 ) break ;}}} void get (int n) {tot=0 ;for (int j=0 ;j<num&&pri[j]*pri[j]<=n;++j)if (n%pri[j]==0 ){d[++tot]=pri[j];while (n%pri[j]==0 )n/=pri[j];}if (n>1 )d[++tot]=n;} bool che (int x) {for (int j=1 ;j<=tot;++j)if (Pow(x,(mod-1 )/d[j])==1 )return 0 ; return 1 ;} int cal () {predo(N-22 ),get(mod-1 );for (int i=2 ;i<mod;++i) if (che(i)) return i; return -1 ;} } const int N=5e5 +10 ;int di[33 ],num,n,pos[N];ll A[N],B[N],t[N]; int findpos (int pr,int x,int dep) {return dep==num+1 ?pr+x:findpos(pr*di[dep]+x%di[dep],x/di[dep],dep+1 );}void cpy (ll A[],ll B[]) {for (int i=0 ;i<n;++i) A[i]=B[i];}void FFT (ll A[],int f) { for (int i=0 ;i<n;++i) t[pos[i]]=A[i]; cpy(A,t); for (int lt=1 ,l=di[num],cur=num;cur>=1 ;cur--,lt=l,l*=di[cur]) { ll wn=Pow(G,f==1 ?n/l:n-n/l); for (int i=0 ;i<n;i+=l) { ll wk=1 ; for (int k=0 ;k<l;++k,wk=wk*wn%mod) { ll tmp=0 ; ll w=1 ; for (int j=k%lt;j<l;j+=lt,w=w*wk%mod) tmp=(tmp+w*A[i+j])%mod; t[i+k]=tmp; } } cpy(A,t); } if (f==-1 ){ll Inv=Pow(n,mod-2 ); for (int i=0 ;i<n;++i) A[i]=A[i]*Inv%mod;} } int main () { int t; read(n),read(t); mod=n+1 ,G=P_root::cal(); int tmp=n; for (int i=2 ;i*i<=tmp;++i){while (tmp%i==0 )tmp/=i,di[++num]=i;}if (tmp>1 )di[++num]=tmp; for (int i=0 ;i<n;++i) pos[i]=findpos(0 ,i,1 ); for (int i=0 ;i<n;++i) read(A[i]); FFT(A,1 ); for (int i=0 ;i<n;++i) read(B[i]); FFT(B,1 ); for (int i=0 ;i<n;++i) A[i]=A[i]*Pow(B[i],t)%mod; FFT(A,-1 ); for (int i=0 ;i<n;++i) printf ("%lld\n" ,A[i]); return 0 ; }