题解 | #Hacker#
Hacker
https://ac.nowcoder.com/acm/contest/33188/H
后缀数组做法
后缀数组的做法一般都包括三个数组::所有后缀中字典序第 大的是从位置 开始的后缀; :位置 开始的后缀在所有后缀中字典序排第 ; :高度数组,字典序第 大得缀与字典序第 大的后缀的最长公共前缀(longest common prefix)。另外通过建立 上的 st表,可以求得任意两个后缀 和 的最长公共前缀: 。
大体思路:对于当前字符串 的每个位置 , 找到从 开始的最长的子串,且这个子串在 中出现过。也就是找到一个最长的长度 ,满足 的子串 同时也是 的一个子串;然后再就是找以位置 为左端点,长度在 之内所有的区间权值和的最大值,也就是 ,这部分可以对数组 的前缀和数组建 表,然后查询区间 的最大值。
关于如何找到 以位置 开头,在 中出现过的最长子串:
把 和所有 连在一起,中间用没出现过的字符(例如 '$')分隔,得到的字符串记为 。然后对 跑后缀数组。在拼接的过程中,维护每个 在 中的起始下标,和第 个字符对应于原来哪一个串,后面要用。
跑出来后缀数组后,按字典序遍历所有后缀。借助前面维护的信息,我们可以知道当前后缀对应哪个 或者说对应 ; 如果当前后缀是对应某个 的,就找到离它最近的,属于 串的后缀,求它们之间的 ,这个 就是我们前面要求的那个 。
比如题目样例一:
(左边的三列数字分别代表:字典序大小,sa 的值,lcp 的值)
字典序第 大和第 大的后缀都是原属于 的后缀,第 大的后缀对应于 从下标 1 开始的后缀,也就是它本身。 (为啥 \usderset这里用不了qaq),那么 从 1 开始,长度在 以内的子串都在 中出现过,求得这个最长的公共子串长度后,用 st 表求一个权重前缀和的最大值,更新 的答案即可。
有两个需要注意的点:
1、本来求 lcp 的部分我是用先从前到后和从后到前循环一遍,记录每个位置左边第一个和右边第一个属于 的后缀,然后 st 表查询区间 lcp 最小值,但是超时了。后来发现在循环的时候直接维护最小值就行了,不需要构建 st 表。
2、所有 之间可以用 '$' 分隔,但是 和 之间最好再换个,不然下面这种数据,跑出来的 lcp 可能处理起来有点麻烦
最后勉强 700+ms 跑过:
#include<iostream>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 1e9+7;
const long long INFq = 1e18+7;
const long long mode = 998244353;
const int MAX_N = 1300100;
char s[MAX_N];
///----- SA-IS template -----
int sa[MAX_N], Rank[MAX_N], lcp[MAX_N];
int str[MAX_N<<1], Type[MAX_N<<1], p[MAX_N], cnt[MAX_N], cur[MAX_N];
#define pushS(x) sa[cur[str[x]]--] = x
#define pushL(x) sa[cur[str[x]]++] = x
#define inducedSort(v) fill_n(sa, n, -1); fill_n(cnt, m, 0); \
for (int i = 0; i < n; i++) cnt[str[i]]++; \
for (int i = 1; i < m; i++) cnt[i] += cnt[i-1]; \
for (int i = 0; i < m; i++) cur[i] = cnt[i]-1; \
for (int i = n1-1; ~i; i--) pushS(v[i]); \
for (int i = 1; i < m; i++) cur[i] = cnt[i-1]; \
for (int i = 0; i < n; i++) if (sa[i] > 0 && Type[sa[i]-1]) pushL(sa[i]-1); \
for (int i = 0; i < m; i++) cur[i] = cnt[i]-1; \
for (int i = n-1; ~i; i--) if (sa[i] > 0 && !Type[sa[i]-1]) pushS(sa[i]-1)
void sais(int n, int m, int *str, int *Type, int *p) {
int n1 = Type[n-1] = 0, ch = Rank[0] = -1, *s1 = str+n;
for (int i = n-2; ~i; i--) Type[i] = str[i] == str[i+1] ? Type[i+1] : str[i] > str[i+1];
for (int i = 1; i < n; i++) Rank[i] = Type[i-1] && !Type[i] ? (p[n1] = i, n1++) : -1;
inducedSort(p);
for (int i = 0, x, y; i < n; i++) if (~(x = Rank[sa[i]])) {
if (ch < 1 || p[x+1] - p[x] != p[y+1] - p[y]) ch++;
else for (int j = p[x], k = p[y]; j <= p[x+1]; j++, k++)
if ((str[j]<<1|Type[j]) != (str[k]<<1|Type[k])) {ch++; break;}
s1[y = x] = ch;
}
if (ch+1 < n1) sais(n1, ch+1, s1, Type+n, p+n1);
else for (int i = 0; i < n1; i++) sa[s1[i]] = i;
for (int i = 0; i < n1; i++) s1[i] = p[sa[i]];
inducedSort(s1);
}
int mapCharToInt(int n) {
int m = *max_element(s, s+n);
fill_n(Rank, m+1, 0);
for (int i = 0; i < n; i++) Rank[s[i]] = 1;
for (int i = 0; i < m; i++) Rank[i+1] += Rank[i];
for (int i = 0; i < n; i++) str[i] = Rank[s[i]] - 1;
return Rank[m];
}
void SuffixArray(int n) {
// s[n] 一定要比 s 中所有字符 ascii 值小, s[n+1] 倒无所谓
s[n] = '!'; s[n+1]='\0';
int m = mapCharToInt(++n);
sais(n, m, str, Type, p);
for (int i = 0; i < n; i++) Rank[sa[i]] = i;
for (int i = 0, h = lcp[0] = 0; i < n-1; i++) {
int j = sa[Rank[i]-1];
while (i+h < n && j+h < n && s[i+h] == s[j+h]) h++;
if (lcp[Rank[i]-1] = h) h--;
}
s[n]='\0';
}
///----- End of SA-IS -----
long long st2[100010][20];
int lg[100010];
long long pref[100010];
void construct_st2(int n) {
for(int i=1;i<=n;i++)st2[i][0] = pref[i];
for(int k=1,len=2; len<=n; len*=2,k++) {
for(int i=1;i+len-1<=n;i++) {
st2[i][k] = max( st2[i][k-1], st2[i+len/2][k-1] );
}
}
}
inline long long query2(int x,int y) {
int k = lg[y-x+1];
return max( st2[x][k], st2[y-(1<<k)+1][k] );
}
int v[100010];
int start_pos[100010];
long long ans[100010];
int Map[MAX_N];
int main() {
ios::sync_with_stdio(false);
lg[1] = 0;
for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1;
int n,m,k;
cin >> n >> m >> k;
cin >> s;
for(int i=1;i<=m;i++) cin >> v[i];
int tot_len = n-1;
for(int i=0;i<n;i++) Map[i] = 0; // Map 用来映射 s[i] 对应于原来那个串,0 就是 A;
for(int i=1;i<=k;i++) {
++ tot_len;
s[tot_len] = '$'; Map[tot_len] = -1; // -1 代表是分隔符;
start_pos[i] = tot_len+1; // 记录开始位置
cin >> ( s + tot_len + 1 );
for(int j=tot_len+1; j<=tot_len+m; j++) Map[j] = i; // 代表 s[j] 原属于 B_i
tot_len += m;
}
s[n] = '#'; // A 和 B_1 之间用 '#' 而非 '$'
++ tot_len;
SuffixArray(tot_len ); // 板子传入的参数是字符串的长度,下标从 0 开始, tot_len 是 '\0' 的位置
// s[tot_len] = '\0';
// cout << "s = " << s << '\n';
// for(int i=0;i<=tot_len;i++) {
// printf("%3d %3d %3d %s\n",i,sa[i],lcp[i],s+sa[i]);
// }
// cout << '\n';
// 构建前缀和
pref[0] = 0;
for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
// 前缀和的区间最大值; 为啥是 st2?因为原本有个(多余的) st 用来求 lcp, 但超时了
construct_st2(m);
int Min = 0;
for(int i=1;i<=tot_len;i++) { // 从左到右遍历一边, 用每个后缀左边第一个属于 A 的后缀更新答案
int j = Map[ sa[i] ]; // sa[i] 代表字典序第 i 大的后缀在原串的起始位置,再用 Map 映射到原来对应的串
if( j == 0 ) {
Min = lcp[i]; // 是 A 的后缀,则重置 Min
}
else {
if( j > 0 && Min > 0 ) { // 对应 B_j 的某个后缀
int index = sa[i] - start_pos[j] + 1; // index 是对应的 B_j 的那个后缀的起始下标
long long Max = query2( index, index + Min - 1 ); // 查询区间 pref 最大值
ans[j] = max( ans[j] , Max - pref[index-1] ); // 更新答案
}
Min = min( Min, lcp[i] );
}
}
Min = 0;
for(int i=tot_len;i>0;i--) { // 从右到左遍历,用每个后缀右边第一个属于 A 的后缀更新答案,几乎一样的
int j = Map[ sa[i] ];
if( j == 0 ) {
Min = lcp[i-1];
}
else {
if( j > 0 && Min > 0 ) {
int index = sa[i] - start_pos[j] + 1;
long long Max = query2( index, index + Min - 1 );
ans[j] = max( ans[j] , Max - pref[index-1] );
}
Min = min( Min, lcp[i-1] );
}
}
for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}
后缀自动机做法
和后缀数组的思路是一样的,不过这里对于字符串 的每个位置 ,是找以 结尾的最长的在 中出现过的子串,后面查询的也是 之间前缀和的最小值。
怎么找:
对 建立后缀自动机后,记当前的 为 (这样我能少打一个下标qwq),在 的自动机上跑匹配,假设 串第 的位置在 的自动机上匹配的最大子串长度为 ,对应自动机上的节点为 ,那么以 结尾的串肯定是某个以 结尾的串后面加上字符 ,我们就从 开始在 树中向上转移,直到遇到第一个存在字符 的出边的节点位置,这个过程中记录 ,最后 +1 就是 的答案。
嗯,自己写的自己都看不懂写的什么东西。 还是看图吧
假设 为 , 为 ,开始 设为 ,代表根节点, ,因为根节点对应的子串为空串, 的自动机长这样子:(每个节点块最后一行{}里的是该节点的 集——节点代表的子串在原串中的结束位置;黑色的边是 parent 的边,蓝色带箭头的是自动机的转移边,旁边的字母是对应的出边的类型;黑边上也有字母是因为 parent 的边和自动机的边重了;len 代表当前节点所代表的子串的最大长度)
每个节点对应的子串:
首先是 a
,正好 号节点有 a
的出边,走到 5 号节点,++,最大长度为
然后是 b
, 号节点有 b
的出边,走到 号节点,++,最大长度为
然后是 c
, 号节点有 c
的出边,走到 号节点,++,最大长度为
然后是 d
, 号节点没有 d
的出边,沿 数向上走,走到 号节点有 d
的出边,走到 号节点,
#include<iostream>
#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
const int MAX_N = 100010;
int par[MAX_N<<1], sam[MAX_N<<1][26],len[MAX_N<<1];
int last,tot;
void sam_extend(int ch) {
int p = last;
tot++;
int np = last = tot;
len[np] = len[p] + 1;
while( p>0 && sam[p][ch]==0 ){
sam[p][ch] = np;
p = par[p];
}
if( p==0 ){
par[np] = 1;
}
else{
int q = sam[p][ch];
if( len[q] == len[p]+1 )par[np] = q;
else{
tot++;
int nq = tot;
len[nq] = len[p]+1;
par[nq] = par[q];
for(int i=0;i<26;i++)sam[nq][i] = sam[q][i];
par[np] = par[q] = nq;
while( p>0 && sam[p][ch]==q ){
sam[p][ch] = nq;
p = par[p];
}
}
}
}
int last_pos, max_len;
void Go(int ch) {
int p = last_pos;
while( p > 0 && sam[p][ch] == 0 ) {
p = par[p];
max_len = len[p];
}
if( p == 0 ) {
// 如果 1 号根节点都没有 ch 的出边,说明字符 ch 在字符串中不存在
last_pos = 1;
}
else {
int q = sam[p][ch]; // 沿着出边走出去
++ max_len; // 就是当前的最大长度
last_pos = q;
}
}
long long pref[100010];
long long st[MAX_N][20];
int lg[MAX_N];
void construct_st(int n) {
for(int i=0;i<=n;i++)st[i][0] = pref[i];
for(int k=1,len=2; len<=n; len*=2,k++) {
for(int i=0;i+len-1<=n;i++) {
st[i][k] = min( st[i][k-1], st[i+len/2][k-1] );
}
}
}
long long query(int left,int right) {
int k = lg[right-left+1];
return min( st[left][k], st[right-(1<<k)+1][k] );
}
char s[100010], t[100010];
int v[100010];
long long ans[100010];
void print_sam(){
vector<int>edge[20];
for(int i=2;i<=tot;i++)edge[par[i]].push_back(i);
for(int i=1;i<=tot;i++) {
printf("child %d :",i); for(int u : edge[i])printf(" %d",u); printf("\n");
}
for(int i=1;i<=tot;i++) {
printf("sam %d :\n",i);
for(int j=0;j<26;j++) {
if( sam[i][j] > 0 ) {
printf(" %c -> %d\n",'a'+j,sam[i][j]);
}
}
}
}
int main() {
// cin.tie(nullptr) -> sync_with_stdio(false);
lg[1] = 0;
for(int i=2;i<=100000;i++) lg[i] = lg[i >> 1] + 1;
int n,m,k;
cin >> n >> m >> k;
cin >> (s+1);
for(int i=1;i<=m;i++) cin >> v[i];
last = tot = 1;
for(int i=1;i<=n;i++) sam_extend(s[i] - 'a');
pref[0] = 0;
for(int i=1;i<=m;i++) pref[i] = pref[i-1] + v[i];
construct_st(m);
for(int j=1; j<=k; j++) {
cin >> (t+1);
last_pos = 1;
max_len = 0;
for(int i=1;i<=m;i++) {
Go(t[i] - 'a'); // 在 parent 树上沿着 last_pos 向上找到第一个有出边 t[i] 的节点
if( max_len > 0 )
ans[j] = max( ans[j], pref[i] - query(i-max_len,i) );
}
}
for(int i=1;i<=k;i++) cout << ans[i] << '\n';
}
最后闲扯些给不了解后缀自动机:
沿着 树向下走,相当于在左边添加字符,而越长的子串在原串中的的出现位置相对更少。为什么说 “从 开始在 树中向上转移,直到遇到第一个存在字符 的出边的节点位置”,因为向上走,相当于不断去掉左边的字符,越短的子串在原串的出现的位置相对更多,更“可能”会遇到一个后面跟着一个字符 的位置。
而沿着 的出边转移,相当于在子串的后面添加字符。