矩阵优化递推方程:构造矩阵与矩阵快速幂(2019.7.16训练)

poj 3233 Matrix Power Series

A S = A + A 2 + A 3 + + A k 题意:给你一个矩阵A,要你求矩阵 S = A + A^2 + A^3 + … + A^k AS=A+A2+A3++Ak

思路:用不化简的矩阵快速幂直接求和会超时,要推导出一个数学公式

S ( k ) = A + A 2 + A 3 + + A k S ( k ) = S ( k 1 ) + A k 设 S(k) = A + A^2 + A^3 + … + A^k, 则有 S(k) = S(k-1) + A^k S(k)=A+A2+A3++AkS(k)=S(k1)+Ak

2 × 2 B B [ <mstyle displaystyle="false" scriptlevel="0"> S ( k 1 ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> A k </mstyle> ] = [ <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> S ( k ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> <mstyle displaystyle="false" scriptlevel="0"> A k + 1 </mstyle> <mstyle displaystyle="false" scriptlevel="0"> </mstyle> ] 设2×2的矩阵B满足 :B* \left[ \begin{matrix} S(k-1) \\ A^k \end{matrix} \right] = \\ \left[ \begin{matrix} &amp;S(k)&amp; \\ &amp;A ^{k+1}&amp; \end{matrix} \right] 2×2BB[S(k1)Ak]=[S(k)Ak+1]

B = [ <mstyle displaystyle="false" scriptlevel="0"> E </mstyle> <mstyle displaystyle="false" scriptlevel="0"> E </mstyle> <mstyle displaystyle="false" scriptlevel="0"> O </mstyle> <mstyle displaystyle="false" scriptlevel="0"> A </mstyle> ] E O 则可以得到B= \left[ \begin{matrix} E &amp; E\\ O &amp; A \end{matrix} \right] (E为单位阵,O为零矩阵) B=[EOEA]EO

B B k + 1 = [ <mstyle displaystyle="false" scriptlevel="0"> E </mstyle> <mstyle displaystyle="false" scriptlevel="0"> E + A + A 2 + A 3 + + A k </mstyle> <mstyle displaystyle="false" scriptlevel="0"> O </mstyle> <mstyle displaystyle="false" scriptlevel="0"> A k + 1 </mstyle> ] 让B与自身相乘,可得到B^{k+1}= \left[ \begin{matrix} E &amp; E + A + A^2 + A^3 + … + A^k\\ O &amp; A^{k+1} \end{matrix} \right] BBk+1=[EOE+A+A2+A3++AkAk+1]

B k + 1 E 所以只要求出矩阵B的k+1次方,它的右上角的子矩阵减去单位阵E,即为答案。 Bk+1E

#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
ll n,k,mod;
struct node
{
    ll m[62][62];
};
node s,B;
node mul(node x,node y)
{
    for(int i=0;i<2*n;i++)
        for(int j=0;j<2*n;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<2*n;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s;
    for(int i=0;i<2*n;i++)
        for(int j=0;j<2*n;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>n>>k>>mod;
    for(int i=0;i<n;i++)
        B.m[i][i]=B.m[i][i+n]=1;//B矩阵左上角和右上角的子矩阵均为E
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        cin>>B.m[i+n][j+n];//B矩阵右下角的子矩阵为A
    s=quickpow(B,k+1);
    for(int i=0;i<n;i++)
        s.m[i][i+n]--;//右上角减去单位阵
    for(int i=0;i<n;i++)
        for(int j=n;j<2*n;j++)
        {
            s.m[i][j]=(s.m[i][j]%mod+mod)%mod;//防止出现负数
            j==2*n-1?printf("%lld\n",s.m[i][j]):printf("%lld ",s.m[i][j]);
        }
    return 0;
}

hdu 1588 Gauss Fibonacci

这题其实和上题有点联系,要求的和为 f(g(i)) for 0<=i<n,设这个和为S(n)
即S(n) = f(b) + f(k+b) + f(2*k+b) + … + f((n-1)*k+b)
由于f(n)是斐波那契数列,则有
2 × 2 A A [ <mstyle displaystyle="false" scriptlevel="0"> f ( n 1 ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> f ( n 2 ) </mstyle> ] = [ <mstyle displaystyle="false" scriptlevel="0"> f ( n ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> f ( n 1 ) </mstyle> ] A = [ <mstyle displaystyle="false" scriptlevel="0"> 1 </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 1 </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 1 </mstyle> <mstyle displaystyle="false" scriptlevel="0"> 0 </mstyle> ] 2×2的矩阵A满足 :A* \left[ \begin{matrix} f(n-1) \\ f(n-2) \end{matrix} \right] = \\ \left[ \begin{matrix} f(n) \\ f (n-1) \end{matrix} \right] 其中A= \left[ \begin{matrix} 1&amp;1 \\ 1&amp;0 \end{matrix} \right] 2×2AA[f(n1)f(n2)]=[f(n)f(n1)]A=[1110]

n &gt; = 2 A n [ <mstyle displaystyle="false" scriptlevel="0"> f ( 1 ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> f ( 0 ) </mstyle> ] = [ <mstyle displaystyle="false" scriptlevel="0"> f ( n + 1 ) </mstyle> <mstyle displaystyle="false" scriptlevel="0"> f ( n ) </mstyle> ] f ( 1 ) = 1 f ( 0 ) = 0 n&gt;=2时,有A^n* \left[ \begin{matrix} f(1) \\ f(0) \end{matrix} \right] = \\ \left[ \begin{matrix} f(n+1) \\ f (n) \end{matrix} \right] 其中f(1)=1,f(0)=0 n>=2An[f(1)f(0)]=[f(n+1)f(n)]f(1)=1f(0)=0

A n ( 2 × 2 ) A . m [ 1 ] [ 0 ] A . m [ 1 ] [ 1 ] 设A^n(2×2矩阵)的左下角元素为A.m[1][0],右下角元素为A.m[1][1] An(2×2)A.m[1][0]A.m[1][1]

A . m [ 1 ] [ 0 ] f ( 1 ) + A . m [ 1 ] [ 1 ] f ( 0 ) = f ( n ) f ( n ) = A . m [ 1 ] [ 0 ] 根据矩阵乘法,有A.m[1][0]*f(1)+A.m[1][1]*f(0)=f(n),即f(n)=A.m[1][0] A.m[1][0]f(1)+A.m[1][1]f(0)=f(n)f(n)=A.m[1][0]
则S(n) = Ab + Ak+b + A2*k+b + … + A(n-1)*k+b(A的幂次取[1][0]位置,也就是矩阵左下角的元素)
S(n) = Ab(E + Ak + A2*k + … + A(n-1)*k)
设B=Ak,则S(n) = Ab(E + B + B2 + … + Bn-1)
括号内的形式就和上题poj 3233差不多了,直接用上题解法求括号内的矩阵和,之后再乘以Ab,得到矩阵ans的左下角元素即为答案。

#include <bits/stdc++.h>
using namespace std;
const int N=4;
typedef long long ll;
ll n,k,b,mod;
struct node
{
    ll m[N][N];
};
node t,s1,s2,s3,ans,B,A={1,1,0,0,1};//不能写A={1,1,1,0},因为此处默认A是4*4的矩阵
node mul(node x,node y,ll n)
{
    node s;
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<n;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b,ll n)
{
    node s;
    for(int i=0;i<n;i++)
        for(int j=0;j<n;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a,n);}
        a=mul(a,a,n);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>k>>b>>n>>mod)
    {
        s1=quickpow(A,b,2);
        s2=quickpow(A,k,2);
        memset(B.m,0,sizeof(B.m));
        for(int i=0;i<2;i++)
        B.m[i][i]=B.m[i][i+2]=1;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
            B.m[i+2][j+2]=s2.m[i][j];
        s3=quickpow(B,n,4);
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
            t.m[i][j]=s3.m[i][j+2];
        ans=mul(s1,t,2);
        printf("%lld\n",ans.m[1][0]);
    }
    return 0;
}

hdu 4965 Fast Matrix Calculation

按照题目的步骤来算肯定是会超时的,因为A×B最大是1000×1000,再快速幂1e6次就超时了。
可以利用矩阵乘法的结合律,先算B×A,B×A最大只有6×6,这样快速幂能省很多时间。
原式(A×B)n*n = A×(B×A)n*n-1×B,利用这个公式计算即可。

#include <bits/stdc++.h>
using namespace std;
const int N=10,mod=6;
int n,k,ans,a[1010][10],b[10][1010],t[1010][10],s[1010][1010];
struct node
{
    int m[N][N];
};
node C,M;
node mul(node x,node y)
{
    node s;
    for(int i=0;i<k;i++)
        for(int j=0;j<k;j++)
        {
            s.m[i][j]=0;
            for(int p=0;p<k;p++)
                s.m[i][j]=(s.m[i][j]+x.m[i][p]*y.m[p][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,int b)
{
    node s;
    for(int i=0;i<k;i++)
        for(int j=0;j<k;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>k&&!(n==0&&k==0))
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                cin>>a[i][j];
        for(int i=0;i<k;i++)
            for(int j=0;j<n;j++)
                cin>>b[i][j];
        memset(C.m,0,sizeof(C.m));
        for(int i=0;i<k;i++)
            for(int j=0;j<k;j++)
                for(int p=0;p<n;p++)
                C.m[i][j]=(C.m[i][j]+b[i][p]*a[p][j]%mod)%mod;
        M=quickpow(C,n*n-1);//M=(B*A)^(n*n-1)
        memset(t,0,sizeof(t));
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                for(int p=0;p<k;p++)
                t[i][j]=(t[i][j]+a[i][p]*M.m[p][j]%mod)%mod;//t=A*M
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                for(int p=0;p<k;p++)
                s[i][j]=(s[i][j]+t[i][p]*b[p][j]%mod)%mod;//s=t*B
        ans=0;
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            ans=ans+s[i][j];
        printf("%d\n",ans);
    }
    return 0;
}

hdu 4920 Matrix multiplication

普通的两矩阵相乘取模,时间优化在于取模,如果你取模的顺序写得不对,就超时了,比如以下代码:

#include <bits/stdc++.h>
using namespace std;
const int N=810,mod=3;
int n,a[N][N],b[N][N],s[N][N];
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n)
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                cin>>a[i][j];
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                cin>>b[i][j];
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int k=0;k<n;k++)
                for(int j=0;j<n;j++)
                    s[i][j]=(s[i][j]+a[i][k]*b[k][j]%mod)%mod;
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                j==n-1?printf("%d\n",s[i][j]):printf("%d ",s[i][j]);
    }
    return 0;
}

超时原因应该是取模都写到三重循环里了,O(n3)取模比较费时间。(取模也很耗时间啊我枯了)
其实只要先对原矩阵a、b的每个元素取模,最后对答案矩阵的每个元素取模就AC了。

#include <bits/stdc++.h>
using namespace std;
const int N=810,mod=3;
int n,a[N][N],b[N][N],s[N][N];
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n)
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            {
                cin>>a[i][j];
                a[i][j]=a[i][j]%mod;//先对原矩阵每个元素取模
            }
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
            {
                cin>>b[i][j];
                b[i][j]=b[i][j]%mod;
            }
        memset(s,0,sizeof(s));
        for(int i=0;i<n;i++)
            for(int k=0;k<n;k++)
                for(int j=0;j<n;j++)
                    s[i][j]=s[i][j]+a[i][k]*b[k][j];
        for(int i=0;i<n;i++)
            for(int j=0;j<n;j++)
                j==n-1?printf("%d\n",s[i][j]%mod):printf("%d ",s[i][j]%mod);//输出答案时取模
    }
    return 0;
}

剩下的题目基本上都是一个套路,就是利用题目给你的递推方程构造矩阵A,然后求A的多少次幂乘前几项初始值构成的矩阵就能得到答案

HIT 2060 - Fibonacci Problem Again

#include <bits/stdc++.h>
using namespace std;
const int N=3,mod=1e9;
typedef long long ll;
ll a,b,ans1,ans2;
struct node
{
    ll m[N][N];
};
node s1,s2,A={1,1,1,0,1,1,0,1,0},E={1,0,0,0,1,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>a>>b&&!(a==0&&b==0))
    {
        s1=quickpow(A,b-1);//b>=1恒成立
        ans1=s1.m[0][0]*2+s1.m[0][1]+s1.m[0][2];
        if(a==0) ans2=0;
        else if(a==1) ans2=1;
        else
        {
            s2=quickpow(A,a-2);
            ans2=s2.m[0][0]*2+s2.m[0][1]+s2.m[0][2];
        }
        printf("%lld\n",((ans1-ans2)%mod+mod)%mod);//相减后可能负数要加上mod再取模,否则会错
    }
    return 0;
}

HIT 2255 - Not Fibonacci

#include <bits/stdc++.h>
using namespace std;
const int N=3,mod=1e7;
typedef long long ll;
ll t,a,b,p,q,s,e,ans1,ans2;
struct node
{
    ll m[N][N];
};
node s1,s2,A,E={1,0,0,0,1,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    cin>>t;
    while(t--)
    {
        cin>>a>>b>>p>>q>>s>>e;//[s,e]区间和
        A={1,p,q,0,p,q,0,1,0};
        if(e==0)ans1=a;
        else
        {
            s1=quickpow(A,e-1);
            ans1=s1.m[0][0]*(a+b)+s1.m[0][1]*b+s1.m[0][2]*a;
        }
        if(s==0) ans2=0;
        else if(s==1) ans2=a;
        else
        {
            s2=quickpow(A,s-2);
            ans2=s2.m[0][0]*(a+b)+s2.m[0][1]*b+s2.m[0][2]*a;
        }
        printf("%lld\n",((ans1-ans2)%mod+mod)%mod);
    }
    return 0;
}

hdu 3306 Another kind of Fibonacci

#include <bits/stdc++.h>
using namespace std;
const int N=4,mod=10007;
typedef long long ll;
ll n,p,q,ans;
struct node
{
    ll m[N][N];
};
node s,A,E={1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1};
node mul(node x,node y)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s=E;
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>n>>p>>q)
    {
        A={1,p*p%mod,2*p*q%mod,q*q%mod,0,p*p%mod,2*p*q%mod,q*q%mod,0,p%mod,q%mod,0,0,1,0,0};
        s=quickpow(A,n-1);
        ans=s.m[0][0]*2%mod+s.m[0][1]%mod+s.m[0][2]%mod+s.m[0][3]%mod;//一定要取模,否则会错
        printf("%lld\n",ans%mod);
    }
    return 0;
}

hdu 1757 A Simple Math Problem

#include <bits/stdc++.h>
using namespace std;
const int N=10;
typedef long long ll;
ll k,mod,ans,a[N];
struct node
{
    ll m[N][N];
};
node s,A;
node mul(node x,node y)
{
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            s.m[i][j]=0;
            for(int k=0;k<N;k++)
                s.m[i][j]=(s.m[i][j]+x.m[i][k]*y.m[k][j]%mod)%mod;
        }
    return s;
}
node quickpow(node a,ll b)
{
    node s;
    for(int i=0;i<N;i++)
        for(int j=0;j<N;j++)
        {
            if(i==j)s.m[i][j]=1;
            else s.m[i][j]=0;
        }
    while(b)
    {
        if(b&1){b--;s=mul(s,a);}
        a=mul(a,a);b=b/2;
    }
    return s;
}
int main()
{
    ios::sync_with_stdio(false);
    while(cin>>k>>mod)
    {
        for(int i=0;i<N;i++)
            cin>>a[i];
        if(k<10){printf("%lld\n",a[k]%mod);continue;}
        /*A={a[0],a[1],a[2],a[3],a[4],a[5],a[6],a[7],a[8],a[9], 1,0,0,0,0,0,0,0,0,0, 0,1,0,0,0,0,0,0,0,0, 0,0,1,0,0,0,0,0,0,0, 0,0,0,1,0,0,0,0,0,0, 0,0,0,0,1,0,0,0,0,0, 0,0,0,0,0,1,0,0,0,0, 0,0,0,0,0,0,1,0,0,0, 0,0,0,0,0,0,0,1,0,0, 0,0,0,0,0,0,0,0,1,0};*/
        for(int i=0;i<N;i++)
        {
            A.m[0][i]=a[i];
            if(i>=1)A.m[i][i-1]=1;
        }
        s=quickpow(A,k-9);
        ans=0;
        for(int i=0;i<N;i++)
            ans=ans+s.m[0][i]*(9-i)%mod;
        printf("%lld\n",ans%mod);
    }
    return 0;
}
全部评论

相关推荐

点赞 评论 收藏
分享
预计下个星期就能开奖吧,哪位老哥来给个准信
华孝子爱信等:对接人上周说的是这周
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务