[HNOI2017]礼物 解题报告(数学常识+FFT)
[HNOI2017]礼物
https://ac.nowcoder.com/acm/problem/20122
估计多项式会吓到不少人,写个题解装个b。
我们数学题嘛,我们把原式写出来:
题目要求的就是:
设x序列都加u,y序列都加v,则
我们可以令u-v=t,则:
运用初中数学知识,我们我们可以把上式子看作一个二次函数,变量为t。
那么,我们就可以知道当u-v取多少的时候,相同x,y下,res能够最小。因此,亮度调整并不是本题的难点所在。
但是,题目还允许我们使得x对应的y不一样。观察上式子,我们发现:只有是同时关联x和y的,我们需要求出的是这个式子的最小值。
我们可以把这个式子变一下:,这里的x是原来的x反转之后的。我们发现,这是个一个卷积的形式。那么,对于卷积,我们自然想到多项式算法(FFT,NTT),因为没有取模,所以考虑FFT。
如果对多项式熟悉的小伙伴看到这里就可以码起来了,但像我这样的菜鸡,还不熟悉多项式和卷积,还需要看下具体过程。
我们假设两个多项式,为了方便说明,我们取具体数值:
则:
所以,我们可以得到结论:n次多项式a和m次多项式b(n>=m)相乘之后为:
也可以看作:
这里如果不明确,建议手算下。
因此,我们可以用多项式乘法解决问题。
其实,这里其实就是卷积的形式。这也是卷积和多项式乘法之间的关系桥梁。
那么,得到这个结论之后要怎么用呢?
我们令b[i+n]=b[i],也就是b的长度变为2n
此时,我们再那a,b作为多项式进行相乘:(这里举个n=2的例子,可能不够味)
结果为
其中,
我们可以非常清楚的看到,我们所需要的a[i]b[n-i]这个东西其实就是多项式相乘之后的第n到2*n-1项的系数,我们直接遍历求max就行了。
总结下:重点是要理解好卷积和多项式乘积的关系,有了这个,加上数学变形,其实就能写出这道题了。
复杂度应该是(nlogn)?但是常数挺大的。
#include <bits/stdc++.h> using namespace std; typedef long long ll; template<typename T> void read(T&x){ ll f=1; x=0; char ch=getchar(); while(!isdigit(ch)){ if(ch=='-')f*=-1; ch=getchar(); } while(isdigit(ch)){ x=x*10+(ch-'0'); ch=getchar(); } x*=f; } //================================================ #define int ll const int maxn=3e5+100; #define complex Complex struct Complex{ double x,y; complex(double xx=0,double yy=0){ x=xx; y=yy; } friend complex operator+(const complex&a,const complex&b){return complex(a.x+b.x,a.y+b.y);} friend complex operator-(const complex&a,const complex&b){return complex(a.x-b.x,a.y-b.y);} friend complex operator*(const complex&a,const complex&b){return complex(a.x*b.x-a.y*b.y,a.x*b.y+b.x*a.y);} }numa[maxn],numb[maxn]; int a[maxn],b[maxn]; const double PI=acos(-1); int lim=1,r[maxn]; int n,m; int l; void fft(Complex*a,int lim,int type){ for(int i=0;i<=lim;i++)if(i<r[i])swap(a[i],a[r[i]]); for(ll mid=1;mid<lim;mid<<=1){ complex Wn(cos(PI/mid),type*sin(PI/mid)); int R=(mid<<1); for(ll j=0;j<lim;j+=R){ complex w(1,0); for(ll k=0;k<mid;k++,w=w*Wn){ complex x=a[j+k],y=w*a[j+mid+k]; a[j+k]=x+y; a[j+mid+k]=x-y; } } } } int cal(int a,int b,int x){ return a*x*x+b*x; } #define inf 0x3f3f3f3f int get_min(){ int res=0; int ab=0; for(int i=0;i<n;i++){ res+=a[i]*a[i]; res+=b[i]*b[i]; ab+=a[i]-b[i]; } int px=-ab/n; int P2=inf; int L=floor(-1.0*ab/n),R=ceil(-1.0*ab/n); for(int i=L;i<=R;i++){ P2=min(P2,n*i*i+2*i*ab); } return res+P2; } ll res[maxn]; int solve(){ int p1=get_min(); while(lim<=(n*3))lim<<=1,++l; for(int i=0;i<lim;++i)r[i]=((r[i>>1]>>1)|((i&1)<<(l-1))); for(int i=0;i<n;++i)numb[i+n].x=numb[i].x; reverse(numa,numa+n); fft(numa,lim,1);fft(numb,lim,1); for(int i=0;i<lim;++i)numa[i]=numa[i]*numb[i]; fft(numa,lim,-1); for(int i=0;i<lim;++i){ res[i]=(numa[i].x/lim+0.5); } ll tmp=0; for(int i=0;i<n;i++){ tmp=max(tmp,res[i+n]); } return p1-(tmp<<1); } signed main(){ //freopen("in.txt","r",stdin); read(n),read(m); for(int i=0;i<n;i++){ int x;read(x); a[i]=x; numa[i].x=x; } for(int i=0;i<n;++i){ int x;read(x); b[i]=x; numb[i].x=x; } printf("%lld\n",solve()); return 0; }