0%

常系数齐次线性递推 - O(k log k log n) 求第n项

定义

对于一个数列$h_0,h_1\cdots h_n,\cdots $,称这个数列满足k阶线性递推关系是指存在量$a_1,a_2,\cdots a_k (a_k \not= 0 )$和量$b_n$($a_1,a_2,\cdots a_k,b_n$都可能以来于$n$),使得$h_n = a_1h_{n-1} + a_2h_{n-2} \cdots a_k h_{n-k} + b_n(n\ge k)$成立。

如果$b_n$是常数$0$,我们称这个线性递推关系式齐次的。

如果$a_1,a_2\cdots a_k$都是常量,我们称这个线性递推关系式常系数的。


1)

可以用矩阵表示是常系数线性递推关系(以$k=4$为例):

设上面的那个$k\times k$的矩阵为$A$。

求数列的第$n$项等价于求$A^n \cdot
\begin {bmatrix}
f_3 \\
f_2 \\
f_1 \\
f_0
\end {bmatrix}$。

利用矩阵乘法的结合律,可以用矩阵快速幂在$O(k^3 \log n)$的时间内解决。

2)

对于一个矩阵$A$,我们定义它的特征多项式为:

可以证明$P(A)=0$(也就是Cayley-Hamilton定理),其中$0$表示$0$矩阵。

对于一个$k$阶常系数线性递推关系,它的矩阵的特征多项式是:

可以直接展开行列式就得到。

考虑矩阵$\lambda E - A$的行列式:

如果确定了哪一列选择第一行,那么容易发现其他的列选择的行都是唯一确定的。如果是第$i$列选了第一行,那么$i$之前的列$j$一定会选择第$j+1$行,$i$之后的列$j$一定会选择第$j$行。故而,这个矩阵的行列式就是:

其中,第一个$-1$是这一列之前的所有的列的选择,然后$\lambda$是这一列之后的列的选择,最后还要乘$-1$是因为考虑逆序对的数量。

所以$A$的特征多项式是$P(\lambda) = \det ( \lambda E - A)$。

3)

也就是说,我们得到了

这意味着,对于任意一个$n\ge k$,$A^n$都可以用$A^0,A^1,\cdots A^{k-1}$的线性组合表示出来。

考虑计算两个矩阵的乘积:设$A^x = \sum_{i=0}^{k-1} a_i A^i ,A^y = \sum_{i=0}^{k-1}b_iA^i $,那么:

这是一个卷积的形式。卷积完之后,为了保证项数小于等于$k$,还要做一次多项式取模。

也就是说我们现在有了一个在$O(k^2)$或者$O(k\log k)$计算两个矩阵的乘积的方法。矩阵快速幂的复杂度被优化到$O(k^2\log n)$或者$O(k\log k \log n)$。

4)

但是,直接用这样的结果来计算答案复杂度大约是$k^4$的。

设$B_i$是$k\times 1$的向量$ \begin {bmatrix} f_{i+k-1} \ \vdots \ f_{i+2} \ f_{i+1} \ f_{i} \end {bmatrix} $。

我们要求的就是$A^n \cdot B_0 = (\sum_{i=0}^{k-1} a_i A^i) B_0 $,它等于:

也就是说$f_n = \sum_{i=0}^{k-1} a_if_i$。这一步的计算在$O(k)$的时间内就可以完成。

综上,我们可以在$O(k\log k\log n)$的时间内算出一个满足线性齐次递推关系的数列的第$n$项。


模板 & 例题

luogu P4723 模板 线性递推

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
// luogu-judger-enable-o2
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define PII pair<int,int>
#define MP make_pair
#define fir first
#define sec second
#define PB push_back
#define db long double
#define ll long long
using namespace std;
template <class T>
inline void rd(T &x) {
x=0; char c=getchar(); int f=1;
while(!isdigit(c)) { if(c=='-') f=-1; c=getchar(); }
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
const int N=(1<<16)+10,mod=998244353;
inline void Add(int &x,int y) { x+=y; if(x>=mod) x-=mod; }
inline void Dec(int &x,int y) { x-=y; if(x<0) x+=mod; }
inline int Add(int x) { return x>=mod?x-mod:x; }
inline int Dec(int x) { return x<0?x+mod:x; }
int Pow(int x,int y) {
int res=1;
while(y) {
if(y&1) res=res*(ll)x%mod;
x=x*(ll)x%mod,y>>=1;
}
return res;
}
int G[N],Q[N],InvQ[N],m;
namespace Poly {
int wn[2][N];
void getwn(int l) {
for(int i=1;i<(1<<l);i<<=1) {
int w0=Pow(3,(mod-1)/(i<<1));
int w1=Pow(3,mod-1-(mod-1)/(i<<1));
wn[0][i]=wn[1][i]=1;
for(int j=1;j<i;++j) {
wn[0][i+j]=wn[0][i+j-1]*(ll)w0%mod;
wn[1][i+j]=wn[1][i+j-1]*(ll)w1%mod;
}
}
}
int rev[N];
void getr(int l) { for(int i=0;i<(1<<l);++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<l-1); }
void FFT(int *A,int len,int f) {
for(int i=0;i<len;++i) if(rev[i]<i) swap(A[i],A[rev[i]]);
for(int l=1;l<len;l<<=1)
for(int p=l<<1,i=0;i<len;i+=p)
for(int j=0;j<l;++j) {
int t1=A[i+j],t2=A[i+l+j]*(ll)wn[f][l+j]%mod;
A[i+j]=Add(t1+t2);
A[i+l+j]=Dec(t1-t2);
}
if(f==1) {
int ilen=Pow(len,mod-2);
for(int i=0;i<len;++i) A[i]=A[i]*(ll)ilen%mod;
}
}
void Mul(int *A,int *B,int *C,int l1,int l2,int l3) {
static int a[N],b[N];
int len=1,cnt=0; while(len<=max(l1-1+l2-1,l3-1)) len<<=1,cnt++; getr(cnt);
for(int i=0;i<len;++i) a[i]=b[i]=0;
for(int i=0;i<l1;++i) a[i]=A[i];
for(int i=0;i<l2;++i) b[i]=B[i];
FFT(a,len,0),FFT(b,len,0);
for(int i=0;i<len;++i) a[i]=a[i]*(ll)b[i]%mod;
FFT(a,len,1);
for(int i=0;i<l3;++i) C[i]=a[i];
}
int C[N],P[N];
void Inv(int *A,int *B,int n) {
if(n==1) return (void)(B[0]=Pow(A[0],mod-2));
int l=(n+1)>>1; Inv(A,B,l);
Mul(B,B,C,l,l,n);
Mul(C,A,C,n,n,n);
for(int i=l;i<n;++i) B[i]=0;
for(int i=0;i<n;++i)
B[i]=(2ll*B[i]-C[i]+mod)%mod;
}
void Rev(int *A,int *B,int n) {
for(int i=0;i<=n;++i) B[i]=A[i];
for(int i=0;i<=n;++i) {
if(n-i<=i) break;
swap(B[i],B[n-i]);
}
}
void Div(int *F,int *B,int n) {
static int A[N],D[N];
Rev(F,P,n);
// for(int i=0;i<=n-m;++i) printf("%d ",Q[i]); puts("");
// Inv(Q,D,n-m+1);
Mul(InvQ,P,A,n-m+1,n-m+1,n-m+1);
Rev(A,A,n-m);
Mul(A,G,D,n-m+1,m+1,n+1);
for(int i=0;i<=n;++i) B[i]=(F[i]-D[i]+mod)%mod;
}
void predoG(int n) {
Rev(G,Q,m);
Inv(Q,InvQ,n-m+1);
}
}
using Poly::Mul;
using Poly::Div;
struct Mat {
int a[N];
Mat () { memset(a,0,sizeof(a)); }
int& operator [] (int i) { return a[i]; }
friend Mat operator *(Mat A,Mat B) {
Mat C;
Mul(A.a,B.a,C.a,m,m,2*m-1);
Div(C.a,C.a,2*m-2);
return C;
}
};
Mat Pow(Mat x,int y) {
Mat res; res[0]=1;
while(y) {
if(y&1) res=res*x;
x=x*x,y>>=1;
}
return res;
}
int h[N];
int main() {
Poly::getwn(16);
int n; rd(n),rd(m);
for(int i=1;i<=m;++i) {
int x; rd(x); x=(x%mod+mod)%mod;
G[m-i]=(mod-x)%mod;
}
G[m]=1;
Poly::predoG(2*m-2);
for(int i=0;i<m;++i) rd(h[i]),h[i]=(h[i]%mod+mod)%mod;
if(n<m) { printf("%d\n",h[n]); return 0; }
if(m==1) { printf("%d\n",h[0]*(ll)Pow((mod-G[0])%mod,n)%mod); return 0; }
Mat A; A[1]=1;
A=Pow(A,n);
int ans=0;
for(int i=0;i<m;++i) Add(ans,A[i]*(ll)h[i]%mod);
printf("%d",ans);
return 0;
}

loj

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#define PII pair<int,int>
#define MP make_pair
#define fir first
#define sec second
#define PB push_back
#define db long double
#define ll long long
using namespace std;
template <class T>
inline void rd(T &x) {
x=0; char c=getchar(); int f=1;
while(!isdigit(c)) { if(c=='-') f=-1; c=getchar(); }
while(isdigit(c)) x=x*10-'0'+c,c=getchar(); x*=f;
}
const int N=2010,mod=998244353;
inline void Add(int &x,int y) { x+=y; if(x>=mod) x-=mod; }
inline void Dec(int &x,int y) { x-=y; if(x<0) x+=mod; }
int Pow(int x,int y) {
int res=1;
while(y) {
if(y&1) res=res*(ll)x%mod;
x=x*(ll)x%mod,y>>=1;
}
return res;
}
int g[N][N],f[N][N],h[N],n,m,q,p,invq;
int pw[N];
int G[N];
struct Mat {
int a[N];
Mat() { memset(a,0,sizeof(a)); }
int& operator [](int i) { return a[i]; }
friend Mat operator *(Mat A,Mat B) {
Mat C;
for(int i=0;i<m;++i)
for(int j=0;j<m;++j)
Add(C[i+j],A[i]*(ll)B[j]%mod);
for(int i=m*2-2;i>=m;--i) if(C[i]) {
for(int j=0;j<m;++j) Dec(C[i-m+j],C[i]*(ll)G[j]%mod);
C[i]=0;
}
return C;
}
};
int sol() {
memset(f,0,sizeof(f));
memset(g,0,sizeof(g));
memset(h,0,sizeof(h));
memset(G,0,sizeof(G));
// for(int i=0;i<=m;++i) printf("%d ",pw[i]); printf(" - %d\n",q); puts("");
f[m+1][0]=g[m+1][0]=1;
for(int j=m;j>0;--j) {
g[j][0]=f[j][0]=1;
for(int i=1;i*j<=m;++i) {
for(int k=1;k<=i;++k)
Add(f[j][i],g[j+1][k-1]*(ll)g[j][i-k]%mod*pw[j]%mod*q%mod);
g[j][i]=(g[j+1][i]+f[j][i])%mod;
}
}
h[0]=1;
for(int i=1;i<=m+1;++i)
for(int j=0;j+1<=i&&j<=m;++j)
h[i]=(h[i]+h[i-j-1]*(ll)g[1][j]%mod*q)%mod;
if(n+1<=m+1) return h[n+1]*(ll)invq%mod;
G[m+1]=1;
for(int i=0;i<=m;++i) G[m-i]=(mod-g[1][i]*(ll)q%mod)%mod;
m++;
Mat x,res; x[1]=1,res[0]=1;
int y=n+1;
while(y) {
if(y&1) res=res*x;
x=x*x,y>>=1;
}
int ans=0;
for(int i=0;i<m;++i) Add(ans,res[i]*(ll)h[i]%mod);
return ans*(ll)invq%mod;
}
int main() {
int _m,x,y;
rd(n),rd(_m),rd(x),rd(y);
p=x*(ll)Pow(y,mod-2)%mod;
q=(1ll-p+mod)%mod; invq=Pow(q,mod-2);
pw[0]=1; for(int i=1;i<=_m;++i) pw[i]=pw[i-1]*(ll)p%mod;
int ans=0;
m=_m; ans=sol();
m=_m-1; Dec(ans,sol());
printf("%d",ans);
return 0;
}

参考:

https://www.cnblogs.com/Troywar/p/9078013.html