[HNOI2017]礼物 解题报告(数学常识+FFT)

[HNOI2017]礼物

https://ac.nowcoder.com/acm/problem/20122

估计多项式会吓到不少人,写个题解装个b。
我们数学题嘛,我们把原式写出来:
题目要求的就是:
设x序列都加u,y序列都加v,则


我们可以令u-v=t,则:


运用初中数学知识,我们我们可以把上式子看作一个二次函数,变量为t。
那么,我们就可以知道当u-v取多少的时候,相同x,y下,res能够最小。因此,亮度调整并不是本题的难点所在。
但是,题目还允许我们使得x对应的y不一样。观察上式子,我们发现:只有是同时关联x和y的,我们需要求出的是这个式子的最小值。
我们可以把这个式子变一下:,这里的x是原来的x反转之后的。我们发现,这是个一个卷积的形式。那么,对于卷积,我们自然想到多项式算法(FFT,NTT),因为没有取模,所以考虑FFT。
如果对多项式熟悉的小伙伴看到这里就可以码起来了,但像我这样的菜鸡,还不熟悉多项式和卷积,还需要看下具体过程。
我们假设两个多项式,为了方便说明,我们取具体数值:

则:
所以,我们可以得到结论:n次多项式a和m次多项式b(n>=m)相乘之后为:


也可以看作:

这里如果不明确,建议手算下。
因此,我们可以用多项式乘法解决问题。
其实,这里其实就是卷积的形式。这也是卷积和多项式乘法之间的关系桥梁。
那么,得到这个结论之后要怎么用呢?
我们令b[i+n]=b[i],也就是b的长度变为2n
此时,我们再那a,b作为多项式进行相乘:(这里举个n=2的例子,可能不够味)
结果为


其中,
我们可以非常清楚的看到,我们所需要的a[i]b[n-i]这个东西其实就是多项式相乘之后的第n到2*n-1项的系数,我们直接遍历求max就行了。
总结下:重点是要理解好卷积和多项式乘积的关系,有了这个,加上数学变形,其实就能写出这道题了。
复杂度应该是(nlogn)?但是常数挺大的。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
template<typename T>
void read(T&x){
    ll f=1;
    x=0;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-')f*=-1;
        ch=getchar();
    }
    while(isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    x*=f;
}
//================================================
#define int ll
const int maxn=3e5+100;
#define complex Complex
struct Complex{
    double x,y;
    complex(double xx=0,double yy=0){
        x=xx;
        y=yy;
    }
    friend complex operator+(const complex&a,const complex&b){return complex(a.x+b.x,a.y+b.y);}
    friend complex operator-(const complex&a,const complex&b){return complex(a.x-b.x,a.y-b.y);}
    friend complex operator*(const complex&a,const complex&b){return complex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y);}
}numa[maxn],numb[maxn];
int a[maxn],b[maxn];
const double PI=acos(-1);
int lim=1,r[maxn];
int n,m;
int l;
void fft(Complex*a,int lim,int type){
    for(int i=0;i<=lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
    for(ll mid=1;mid<lim;mid<<=1){
        complex Wn(cos(PI/mid),type*sin(PI/mid));
        int R=(mid<<1);
        for(ll j=0;j<lim;j+=R){
            complex w(1,0);
            for(ll k=0;k<mid;k++,w=w*Wn){
                complex x=a[j+k],y=w*a[j+mid+k];
                a[j+k]=x+y;
                a[j+mid+k]=x-y;
            }
        }
    }
}

int cal(int a,int b,int x){
    return a*x*x+b*x;
}
#define inf 0x3f3f3f3f
int get_min(){
    int res=0;
    int ab=0;
    for(int i=0;i<n;i++){
        res+=a[i]*a[i];
        res+=b[i]*b[i];
        ab+=a[i]-b[i];
    }
    int px=-ab/n;
    int P2=inf;
    int L=floor(-1.0*ab/n),R=ceil(-1.0*ab/n);
    for(int i=L;i<=R;i++){
        P2=min(P2,n*i*i+2*i*ab);
    }
    return res+P2;
}
ll res[maxn];
int solve(){
    int p1=get_min();
    while(lim<=(n*3))lim<<=1,++l;
    for(int i=0;i<lim;++i)r[i]=((r[i>>1]>>1)|((i&1)<<(l-1)));
    for(int i=0;i<n;++i)numb[i+n].x=numb[i].x;
    reverse(numa,numa+n);
    fft(numa,lim,1);fft(numb,lim,1);
    for(int i=0;i<lim;++i)numa[i]=numa[i]*numb[i];
    fft(numa,lim,-1);
    for(int i=0;i<lim;++i){
        res[i]=(numa[i].x/lim+0.5);
    }
    ll tmp=0;
    for(int i=0;i<n;i++){
        tmp=max(tmp,res[i+n]);
    }
    return p1-(tmp<<1);
}

signed main(){
    //freopen("in.txt","r",stdin);
    read(n),read(m);
    for(int i=0;i<n;i++){
        int x;read(x);
        a[i]=x;
        numa[i].x=x;
    }
    for(int i=0;i<n;++i){
        int x;read(x);
        b[i]=x;
        numb[i].x=x;
    }
    printf("%lld\n",solve());
    return 0;
}
全部评论

相关推荐

09-29 11:19
门头沟学院 Java
点赞 评论 收藏
分享
点赞 收藏 评论
分享
牛客网
牛客企业服务