题解 | #平衡的选择题#

平衡的选择题

http://www.nowcoder.com/practice/8245dc44313d47338015bf57278b9a8e

题意

做n次选择,每次选择(A,B,C,D,AB,AC,AD,BC,BD,CD,ABC,ACD,ABD,BCD,ABCD)中的一种

限制,每次选择完后,从最开始到当前选择的所有内容中,A和C出现次数差的绝对值不大于1,B和D出现次数差的绝对值不大于2

给定 n1e5n \leq 1e5n1e5, 求完成n次选择的方案数

算法

直接模拟递推

把题目抽象成数学。

  1. 每次在111151515 中选择一个数,以下是简单的映射关系
A B C D AB AC AD BC BD CD ABC ACD ABD BCD ABCD
0001 0010 0100 1000 0011 0101 1001 0110 1010 1100 0111 1101 1011 1110 1111

A,B,C,D是否被选,就和对应二进制位是否为1对应

  1. 记录A和C个数的差值,记录B和D个数的差值,并按照题目限制控制合法的方案。

ans[i][j]=ans[i][j] = ans[i][j]= 表示到当前位置,AC的差值为i,BD的差值为j的方案数

每次模拟选择一个数(1~15)

ans[AC][BD]+=ans[ACAC][BDBD]ans_{当前}[AC][BD] += ans_{上一次}[AC - 选择数导致的A和C差值的变化][BD - 选择数导致的B和D差值的变化]ans[AC][BD]+=ans[ACAC][BDBD]

  1. 变成代码

注意到 C++ 下标不能使用负数,我们分别做值映射

对于AC的差值 1,0,1=>0,1,2-1,0,1 => 0,1,21,0,1=>0,1,2

对于BD的差值 2,1,0,1,2=>0,1,2,3,4-2,-1,0,1,2 => 0,1,2,3,42,1,0,1,2=>0,1,2,3,4

所以默认值 ans[i][j]=0ans[i][j] = 0ans[i][j]=0,其中 ans[1][2]=1ans[1][2] = 1ans[1][2]=1 表示,还未选择前,AC和BD的差值都为0的情况

以此循环n次,可以计算出n次选择后,AC和BD不同差值的方案数

最终的答案为i=0..2,j=0..4ans[i][j]\sum_{i=0..2,j=0..4} ans[i][j]i=0..2,j=0..4ans[i][j]

对于需要取模的部分注意取模即可

代码

class Solution {
public:
    /**
     * 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
     * @param n int整型 
     * @return int整型
     */
    int solve(int n) {
        const int mod = 1000000007;
        // arr[i][j] = 
        //  A的个数减去C的个数+1为i
        //  B的个数减去D的个数+2为j
        // 时的方案数
        vector<vector<long long> > arr = vector<vector<long long> >(3,vector<long long>(5,0));
        arr[1][2] = 1;
        // 递推n个位置
        for(int i = 0;i < n;i++){
            // 状态转移的结果为ans
            vector<vector<long long> > ans = vector<vector<long long> >(3,vector<long long>(5,0));
            for(int ac = 0;ac < 3;ac++){ // 枚举现有的ac差
                for(int bd = 0;bd < 5;bd++){  // 枚举现有的bd的差
                    for(int j = 1;j < 16;j++){ // 枚举选项
                        int diff_ac = (j>>0)%2 - (j>>1)%2; // 计算ac差的变化
                        int diff_bd = (j>>2)%2 - (j>>3)%2; // 计算bd差的变化

                        int new_ac = ac+diff_ac; // 新的ac的差
                        if( new_ac < 0 || new_ac >= 3)continue;
                        int new_bd = bd+diff_bd; // 新的bd的差
                        if( new_bd < 0 || new_bd >= 5)continue;
                        (ans[new_ac][new_bd] += arr[ac][bd])%=mod; // 状态转移
                    }
                }
            }
            arr = ans;
        }
        long long result = 0;
        for(int ac = 0;ac<3;ac++){ // 枚举ac的差
            for(int bd = 0;bd < 5;bd++){ // 枚举bd的差
                (result += arr[ac][bd])%=mod; // 统计答案
            }
        }
        return result;
    }
};

复杂度分析

时间复杂度: 我们循环了nnn次,每一次模拟选择,模拟选择的代价是常数35163\cdot 5\cdot 163516, 所以 总时间复杂度为O(n)O(n)O(n)

空间复杂度: 我们仅用了 一个常数大小的结果数组,和一个常数大小的临时数组来记录方案,所以空间复杂度是常数O(1)O(1)O(1)

矩阵乘法/快速幂

我们发现,上面的递推关系中,与nnn无关,且每次转换关系又是线性加和。

所以我们把AC和BD的值看成一个状态整体(代码中encode函数实现),有35=153\cdot 5=1535=15

不同状态整体的转义系数是常数,满足这个条件,就可以变成矩阵乘法 其中iiijjj列表示,上一个状态是i,转移为状态jjj的方案数,矩阵为151515\cdot151515的大小,太大不要手动推算,矩阵具体的值由代码算出。

而矩阵乘法可以使用快速幂来提高效率

考虑到初始 矩阵为 (010)(0\cdots 1 \cdots 0)(010), 仅有表示AC和BD差值为0的项(也就是 encode(1,2))为1

所以最终的答案为 i=0..15(basematrix)n[encode(1,2)][i]\sum_{i=0..15} (basematrix)^n[encode(1,2)][i]i=0..15(basematrix)n[encode(1,2)][i]

代码

class Solution {
public:
    typedef long long ll;
    #define rep(i,a,n) for (ll i=a;i<n;i++)
    const int mod = 1000000007;
    // 矩阵乘法
    vector<vector<ll>> mul(vector<vector<ll>>& m1,vector<vector<ll>>& m2){
        vector<vector<ll>> res = vector<vector<ll>>(m1.size(),vector<ll>(m2[0].size(),0));
        rep(i,0,m1.size()){
            rep(j,0,m2[0].size()){
                rep(k,0,m1[0].size()){
                    (res[i][j]+=m1[i][k]*m2[k][j]%mod)%=mod;
                }
            }
        }
        return res ;
    }
    // 矩阵幂次
    vector<vector<ll>> matrixp(vector<vector<ll>>& m1, ll pwr){
        // 单位矩阵
        vector<vector<ll>> res = vector<vector<ll>>(m1.size(),vector<ll>(m1.size(),0));
        rep(i,0,m1.size()){
            res[i][i] = 1;
        }
        // 快速幂
        while(pwr){ // 幂次不为0
            if(pwr%2)res = mul(res,m1); // 当前二进制位为1 则乘上翻倍后的基数
            m1 = mul(m1,m1); // 基数翻倍
            pwr/=2; // 幂次除以2
        }
        return res ;
    }
    
    int encode(int v0,int v1){ // v1 最大小于5,所以编码两个数成一个数作为状态
        return v0*5+v1;
    }
    
    /**
     * 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
     * @param n int整型 
     * @return int整型
     */
    int solve(int n) {
        const int sz = 3*5;
        // 状态转换关系计算
        vector<vector<long long> > matrix = vector<vector<long long> >(sz,vector<long long>(sz,0));
        for(int ac = 0;ac < 3;ac++){ // 枚举所有A-C+1的差
            for(int bd = 0;bd < 5;bd++){  // 枚举所有B-D+2的差
                for(int j = 1;j < 16;j++){ // 枚举所有选项
                    int diff_ac = (j>>0)%2 - (j>>1)%2; // AC差的变化
                    int diff_bd = (j>>2)%2 - (j>>3)%2; // BD差的变化
                    int new_ac = ac+diff_ac; // 新的AC的差
                    if( new_ac < 0 || new_ac >= 3)continue;
                    int new_bd = bd+diff_bd; // 新的BD的差
                    if( new_bd < 0 || new_bd >= 5)continue;
                    matrix[encode(ac,bd)][encode(new_ac,new_bd)] += 1; // 写状态转移矩阵
                }
            }
        }
        // 矩阵n次方
        vector<vector<long long> > matrixResult = matrixp(matrix, n);
        
        long long result = 0;
        for(int idx = 0;idx<sz;idx++){ // 所有合法的状态
            (result += matrixResult[encode(1,2)][idx])%=mod; // 统计合法的结果
        }
        return result;
    }
};

复杂度分析

时间复杂度: 我们通过快速幂,计算的是转换矩阵的nnn次方,矩阵大小为常数,所以总时间复杂度为O(log(n))O(log(n))O(log(n))

空间复杂度: 我们仅用了 一个常数大小的结果矩阵,和一个计算矩阵幂次的非递归函数,所以空间复杂度是常数O(1)O(1)O(1)

知识点

  1. 对问题的抽象化,虽然题目是ABCD,但是实际上因为是选择所有情况,熟悉二进制的应该能立刻想到1到15能完成一一映射
  2. int 的溢出,虽然输入输出都是int,但是 涉及到int的加法乘法,可能有溢出的情况时,记得使用long long 来完成中间过程的运算避免溢出
  3. 与项数无关的递推式,可以想到矩阵乘法
全部评论

相关推荐

10-25 00:32
香梨想要offer:感觉考研以后好好学 后面能乱杀,目前这简历有点难
点赞 评论 收藏
分享
11-01 20:03
已编辑
门头沟学院 算法工程师
Amazarashi66:这种也是幸存者偏差了,拿不到这个价的才是大多数
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务