字符串 hash + poj 2774 (hash+二分)

所谓字符串hash,就是将一个字符串映射(hash)为对应的一个值,一般用于后续的查找,匹配等等

比较常见的方法是,将字符串看作是一个p进制的数,然后对这个很大的数进行取模,得到的就是该字符串的hash值。根据前人的经验,为了减少冲突,我们将p 选择为31 ,131 这些数字。并且为了方便,我们直接把数值类型定义为unsigned long long,这样我们就不需要手动取模:当数值溢出的时候, unsigned long long 自然会对2的64次方取模,然后就可以得到一个字符串的hash值了

poj 2774 就是让你求两个字符串的最长公共子串的长度
解题思路是,二分公共子串的长度x,把第一个串所有长度为x的子串的hash值排序,再对第二个串每个长度为x的子串的hash值做二分查找,判断是否存在即可
复杂度O(n log² n):二分长度共O(log n)轮,每轮的排序和查找都是O(n log n)

代码

#include <iostream>
#include <algorithm>
#include <stdlib.h>
#include <cstring>
#include <vector>
#include <map>
#include <set>
#include <stdio.h>
#include <queue>
#include <stack>
#define cl(a) memset(a,0,sizeof(a))
#define ll long long
#define pb(i) push_back(i)
//#define mp make_pair
using namespace std;
// Constants and the pair typedef below are not referenced by any function in
// this listing; they are kept so the surrounding program is unchanged.
const int maxn=2e5+50;
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
typedef pair<int,int> PII;
// The two input strings and their lengths. main() assigns all four before it
// calls check(); check() relies on len1 == s.size() and len2 == s2.size().
string s,s2;
int len1,len2;
typedef unsigned long long ULL;
// Scratch buffer reused across check() calls: the hash of every length-x
// window of `s`, sorted so it can be binary-searched.
vector<ULL>shash;
// Base of the polynomial hash. Arithmetic is done in ULL, so every hash is
// implicitly reduced modulo 2^64.
const int p=31;

// Reports whether `s` and `s2` share a substring of length `x`.
//
// Every length-x window of `s` is hashed with a rolling polynomial hash and
// stored in sorted order; each length-x window of `s2` is then hashed the same
// way and looked up. Windows are compared by hash value only, so two different
// windows with equal hashes would be reported as a match.
//
// Returns false for x < 0 and for x longer than either string (the original
// indexed the strings out of range in those cases), and true for x == 0
// (the empty substring is always common, which is also what the hashing path
// used to compute after building and sorting len1 + 1 zero hashes).
bool check(int x)
{
    if(x<0 || x>len1 || x>len2)
    {
        return false;
    }
    if(x==0)
    {
        return true;
    }

    // One pass over the first x characters yields the initial window hash of
    // both strings and base = p^x, the weight of the character that leaves the
    // window after the hash has been shifted by one position.
    ULL base=1;
    ULL tmp=0;
    ULL Stmp=0;
    for(int i=0;i<x;i++)
    {
        base = base*p;
        tmp = tmp*p+s[i];
        Stmp = Stmp*p+s2[i];
    }

    // Collect the hash of every window of `s`, then sort for binary search.
    shash.clear();
    shash.push_back(tmp);
    for(int i=x;i<len1;i++)
    {
        // Slide by one: shift, append s[i], drop s[i-x].
        tmp = tmp*p+s[i]-base*s[i-x];
        shash.push_back(tmp);
    }
    sort(shash.begin(),shash.end());

    // Probe with every window of `s2`, stopping at the first hit.
    if(binary_search(shash.begin(),shash.end(),Stmp))
    {
        return true;
    }
    for(int i=x;i<len2;i++)
    {
        Stmp = Stmp*p+s2[i]-base*s2[i-x];
        if(binary_search(shash.begin(),shash.end(),Stmp))
        {
            return true;
        }
    }
    return false;
}

// Scratch globals that no function in this listing reads or writes
// (fpow's own parameter `k` shadows the global `k`).
int n,d,k;
// Square-and-multiply power with a cut-off.
//
// Walks the bits of `k` from least to most significant, multiplying the
// accumulator by the current power of `n1` whenever the bit is set. After each
// bit, if the accumulator has grown past 4e5 the function gives up and
// returns 1; otherwise it returns the accumulated product once `k` is
// exhausted. Not called by check() or main() in this listing.
ll fpow(ll n1, ll k)
{
    ll acc = 1;
    while (k)
    {
        if (k & 1)
        {
            acc *= n1;
        }
        // Cut-off: report 1 as soon as the running product exceeds 4e5.
        if (acc > 4e5)
        {
            return 1;
        }
        n1 *= n1;
        k >>= 1;
    }
    return acc;
}
// Reads pairs of whitespace-separated strings until input ends and prints, for
// each pair, the largest length accepted by check().
int main()
{
    // Untie the C++ streams from C stdio and from each other for faster I/O.
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    while (true)
    {
        if (!(cin >> s >> s2))
        {
            break;
        }
        len1 = s.size();
        len2 = s2.size();

        // Binary search over the candidate length in [0, min(len1, len2)]:
        // a length that check() accepts moves the lower bound up, a rejected
        // one moves the upper bound down.
        int best = 0;
        int lo = 0;
        int hi = min(len1, len2);
        while (lo <= hi)
        {
            const int mid = lo + (hi - lo) / 2;
            if (!check(mid))
            {
                hi = mid - 1;
                continue;
            }
            best = mid;
            lo = mid + 1;
        }
        cout << best << endl;
    }

    return 0;
}

看到有人说改成数组过不了,我改了试了一下,发现没有问题,如果有需要的话,参考下面的代码

#include <iostream>
#include <algorithm>
#include <stdlib.h>
#include <cstring>
#include <vector>
#include <map>
#include <set>
#include <stdio.h>
#include <string>
#include <queue>
#include <stack>
#define cl(a) memset(a,0,sizeof(a))
#define ll long long
#define pb(i) push_back(i)
//#define mp make_pair
using namespace std;
// Capacity of the fixed hash buffer `thash` below.
const int maxn=2e5+50;
// inf, mod and PII are not referenced by any function in this listing; they
// are kept so the surrounding program is unchanged.
const int inf=0x3f3f3f3f;
const int mod=1e9+7;
typedef pair<int,int> PII;
// The two input strings and their lengths. main() assigns all four before it
// calls check(); check() relies on len1 == s.size() and len2 == s2.size().
string s,s2;
int len1,len2;
typedef unsigned long long ULL;
// Fixed-size replacement for the vector<ULL> used by the first listing: holds
// the hash of every length-x window of `s`.
ULL thash[maxn];
// Base of the polynomial hash. Arithmetic is done in ULL, so every hash is
// implicitly reduced modulo 2^64.
const int p=31;

// Reports whether `s` and `s2` share a substring of length `x`.
//
// Every length-x window of `s` is hashed with a rolling polynomial hash and
// stored, sorted, in `thash`; each length-x window of `s2` is then hashed the
// same way and looked up. Windows are compared by hash value only, so two
// different windows with equal hashes would be reported as a match.
//
// Returns false for x < 0 and for x longer than either string (the original
// indexed the strings out of range in those cases), and true for x == 0
// (the empty substring is always common, which is also what the hashing path
// used to compute).
//
// Precondition: len1 - x + 1 <= maxn. `thash` is not bounds-checked, so a
// longer first string would write past the end of the array.
bool check(int x)
{
    if(x<0 || x>len1 || x>len2)
    {
        return false;
    }
    if(x==0)
    {
        return true;
    }

    // One pass over the first x characters yields the initial window hash of
    // both strings and base = p^x, the weight of the character that leaves the
    // window after the hash has been shifted by one position.
    ULL base=1;
    ULL tmp=0;
    ULL Stmp=0;
    for(int i=0;i<x;i++)
    {
        base = base*p;
        tmp = tmp*p+s[i];
        Stmp = Stmp*p+s2[i];
    }

    // Collect the hash of every window of `s`, then sort for binary search.
    int idx=0;
    thash[idx++]=tmp;
    for(int i=x;i<len1;i++)
    {
        // Slide by one: shift, append s[i], drop s[i-x].
        tmp = tmp*p+s[i]-base*s[i-x];
        thash[idx++]=tmp;
    }
    sort(thash,thash+idx);

    // Probe with every window of `s2`, stopping at the first hit.
    if(binary_search(thash,thash+idx,Stmp))
    {
        return true;
    }
    for(int i=x;i<len2;i++)
    {
        Stmp = Stmp*p+s2[i]-base*s2[i-x];
        if(binary_search(thash,thash+idx,Stmp))
        {
            return true;
        }
    }
    return false;
}

// Scratch globals that no function in this listing reads or writes
// (fpow's own parameter `k` shadows the global `k`).
int n,d,k;
// Square-and-multiply power with a cut-off.
//
// Walks the bits of `k` from least to most significant, multiplying the
// accumulator by the current power of `n1` whenever the bit is set. After each
// bit, if the accumulator has grown past 4e5 the function gives up and
// returns 1; otherwise it returns the accumulated product once `k` is
// exhausted. Not called by check() or main() in this listing.
ll fpow(ll n1, ll k)
{
    ll acc = 1;
    while (k)
    {
        if (k & 1)
        {
            acc *= n1;
        }
        // Cut-off: report 1 as soon as the running product exceeds 4e5.
        if (acc > 4e5)
        {
            return 1;
        }
        n1 *= n1;
        k >>= 1;
    }
    return acc;
}
// Reads pairs of whitespace-separated strings until input ends and prints, for
// each pair, the largest length accepted by check().
int main()
{
    // Untie the C++ streams from C stdio and from each other for faster I/O.
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    while (true)
    {
        if (!(cin >> s >> s2))
        {
            break;
        }
        len1 = s.size();
        len2 = s2.size();

        // Binary search over the candidate length in [0, min(len1, len2)]:
        // a length that check() accepts moves the lower bound up, a rejected
        // one moves the upper bound down.
        int best = 0;
        int lo = 0;
        int hi = min(len1, len2);
        while (lo <= hi)
        {
            const int mid = lo + (hi - lo) / 2;
            if (!check(mid))
            {
                hi = mid - 1;
                continue;
            }
            best = mid;
            lo = mid + 1;
        }
        cout << best << endl;
    }

    return 0;
}
全部评论

相关推荐

点赞 评论 收藏
分享
评论
点赞
收藏
分享
牛客网
牛客企业服务