0%

任意模数NTT(NTT+CRT,MTT的优化)

问题:给出两个多项式,以及模数$P$,计算它们的卷积,不保证$P$是$t\cdot 2^k + 1$的形式的质数。

模数不支持NTT,那么我们肯定只能通过某种方式算出精确值,然后再取模。然而,两个这样的多项式相乘,每一项的系数最大是$n\cdot P^2$,大约$10^{23}$,FFT会爆精度,也不可能单模数NTT……


方法一:三模数NTT + CRT

我们选三个乘积大于了$n\cdot P^2$的、可以NTT的质数,算出系数模它们的余数,再将得到的余数利用中国剩余定理合并。然而粘板子是不行的,因为会爆long long。

假设现在有三个方程:

其中$m_1,m_2,m_3$都是$10^9$级别的。我们先中国剩余定理合并前面两个,得到:

其中$m=m_1\cdot m_2$。设$x=c+k_1m=c_2+k_2m_2$,则$k_1m-k_2m_2=c_2-c$,可以直接用扩展欧几里得求出一组解,因为扩展欧几里得算法解不定方程$ax+by=gcd(a,b)$,求出的解一定满足$|x|\le |b|$且$|y|\le |a|$。

整个同余方程组的最小正解,一定就是卷积结果的系数的真实值。因为任何一个比最小正解大的解,一定大于$m_1\cdot m_2\cdot m_3$,而这已经到了$10^{27}$的级别了,但这道题的答案不会超过$10^{23}$。

那么我们现在算出了最小的$k_1$,带回$x=c+k_1m$就可以得到$x$的值了。这里的乘法计算,当然要先对$P$取模再算乘法(要不然还是会爆long long QAQ)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
//这三个模数的原根都是3
int M[3]={1004535809,104857601,998244353};
int C[3][N],A[N],B[N];
void solve(int a[],int b[],int P)
{
for(int k=0;k<3;++k)
{
for(int i=0;i<len;++i) A[i]=a[i],B[i]=b[i];
FFT(A,1,M[k]),FFT(B,1,M[k]);
for(int i=0;i<len;++i) C[k][i]=A[i]*(ll)B[i]%M[k];
FFT(C[k],-1,M[k]);
}

ll P1=M[0]*(ll)M[1];
for(int i=0;i<len;++i)
{
ll k0,k1; exgcd(M[0],M[1],k0,k1); k0=(k0*(C[1][i]-C[0][i])%M[1]+M[1])%M[1];
ll x1=(C[0][i]+k0*M[0]%P1)%P1;
exgcd(P1,M[2],k0,k1); k0=(k0%M[2]*((C[2][i]-x1)%M[2])+M[2])%M[2];
a[i]=(x1+k0%P*(P1%P))%P;
}
}

方法二:MTT

用FFT计算的瓶颈在于精度爆炸,那么我们可以拆系数,把系数的值域变小,用时间换精度。

我们把原来的多项式$A(x)$拆成$A_0(x)\cdot M+A_1(x)$,其中$M$是一个常数。即,令${A_0}_i=\lfloor{A_i\over M}\rfloor,{A_1}_i=A_i\mod M$。这样计算两个多项式乘积就是算:

如果$M$取到$\sqrt P$左右的话,我们需要计算卷积的那些多项式的系数都只有$\sqrt P$级别了,精度得到有效改善。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
void Mul(int A[],int B[],int C[])
{
for(int i=0;i<len;++i) T1[i]=Comp(A[i],0),T2[i]=Comp(B[i],0);
FFT(T1,1),FFT(T2,1); for(int i=0;i<len;++i) T1[i]=T1[i]*T2[i]; FFT(T1,-1);
for(int i=0;i<len;++i) C[i]=((ll)(T1[i].a+0.5))%mod;
}
int a0[N],a1[N],b0[N],b1[N],c[N];
void MTT(int A[],int B[],int C[])
{
for(int i=0;i<len;++i)
{
a0[i]=A[i]/M,a1[i]=A[i]%M;
b0[i]=B[i]/M,b1[i]=B[i]%M;
C[i]=0;
}
Mul(a0,b0,c); for(int i=0;i<len;++i) C[i]=(C[i]+c[i]%mod*(ll)M%mod*M%mod)%mod;
Mul(a1,b0,c); for(int i=0;i<len;++i) C[i]=(C[i]+c[i]%mod*(ll)M%mod)%mod;
Mul(a0,b1,c); for(int i=0;i<len;++i) C[i]=(C[i]+c[i]%mod*(ll)M%mod)%mod;
Mul(a1,b1,c); for(int i=0;i<len;++i) C[i]=(C[i]+c[i])%mod;
}

FFT优化:DFT与IDFT合并

这是一种可以通过一次FFT,算出两个多项式的DFT/IDFT的骚操作。

假设要算$A(x),B(x)$的DFT。那么我们构造:

如果我们得到了$P,Q$的DFT,我们就可以得到$A,B$的DFT。而$P$和$Q$有一个非常奇妙的性质(下面设$W=k\cdot {2\pi \over l}$):

上下式子好像啊(废话,本来就构造得很相似),而且,$Q$最终式子的前半部分,$\sin$前面的都变了号,而$\cos$前面的没有变;后半部分$\cos$变了号,而$sin$不变。而正好$\sin$是奇函数,$\cos$是偶函数,即$\sin (-x) = - \sin x,\cos (-x) = \cos x$于是可以得到:

因为$w_n^{-k} = w_n^{n-k}$,所以$conj(P(\omega_n^{n-k})) = Q(\omega_n^k)$。

这样,如果我们算出了$P$的DFT,就可以得到$Q$的DFT。然后就可以得到:

那么DFT怎么办呢?考虑构造$M(\omega_n^k) = A(\omega_n^k)+iB(\omega_n^k) = P(\omega_n^k)$,那么我们把$P$IDFT回去以后,分别取实部和虚部,就可以得到$A,B$的IDFT。

由此,我们的DFT/IDFT的运算次数可以减少一半,只需要算4次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
Comp T[N];
void DFT(int A[],int B[],Comp a[],Comp b[])
{
for(int i=0;i<len;++i) T[i]=Comp(A[i],B[i]); FFT(T,1);
for(int i=0;i<len;++i)
{
int j=(len-i)%len;
a[i]=(T[i]+T[j].conj())*Comp(0.5,0);
b[i]=(T[i]-T[j].conj())*Comp(0,-0.5);
}
}
void IDFT(Comp A[],Comp B[],int a[],int b[])
{
for(int i=0;i<len;++i) T[i]=A[i]+B[i]*Comp(0,1); FFT(T,-1);
for(int i=0;i<len;++i) a[i]=((ll)(T[i].a+0.5))%mod,b[i]=((ll)(T[i].b+0.5))%mod;
}
int a0[N],a1[N],b0[N],b1[N];
Comp A0[N],A1[N],B0[N],B1[N],T1[N],T2[N];
void MTT(int A[],int B[],int C[])
{
for(int i=0;i<len;++i)
{
a0[i]=A[i]/M,a1[i]=A[i]%M;
b0[i]=B[i]/M,b1[i]=B[i]%M;
C[i]=0;
}
DFT(a0,a1,A0,A1); DFT(b0,b1,B0,B1);
for(int i=0;i<len;++i) T1[i]=A0[i]*B0[i],T2[i]=A1[i]*B1[i];
IDFT(T1,T2,a0,a1);
for(int i=0;i<len;++i) C[i]=(C[i]+M*(ll)M%mod*a0[i]%mod+a1[i])%mod;
for(int i=0;i<len;++i) T1[i]=A0[i]*B1[i]+A1[i]*B0[i];
FFT(T1,-1);
for(int i=0;i<len;++i) C[i]=(C[i]+M*((ll)(T1[i].a+0.5)%mod))%mod;
}