Palindrome Mouse - 回文树 & dfs
这道题是牛客2019暑期多校第六场的C题,题目链接
大意是,给出一个字符串,对于它所有的回文子串,若其中某个是另一个的子串,则视其为一对,求总对数。
看到回文子串就想到了回文树。简单回顾一下,回文树每个节点代表以当前字符结尾的最长回文子串,next指向前后添上某个字符后更大的那个子串,fail(suffix)指向当前子串的最长回文后缀。初步的想法是,对于某个节点,它fail链上的所有节点以及反向next链上的所有节点,是它的所有回文子串,对于所有节点求出这个数量,并相加。可是实际操作中发现,fail链和反next链上的节点有重复,导致结果会重复计数某个节点。
于是考虑通过标记避免重复计数,我们从两个根分别沿着next链向上做dfs,同时向下求算fail链的深度并逐个标记vis。如果在某个节点求算fail链深度的过程中首次发现了被标记的节点,那么说明这个节点既是后缀fail节点,也是反向next节点。到这里我们可以停下,把这个节点的子串数加上当前求得的fail深度,再加1(代表next链上更高一层),就是当前节点的子串数了。遍历完当前节点的next指针后,需要把之前打的标记撤销,因为接下来算的节点中并不包含这些子串。
具体可以看代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=200010;
struct palindrome_tree{
map<char,int> trans[maxn];
int len[maxn],suf[maxn];
int lst,cnt;
char s[maxn];
bool vis[maxn];
int dp[maxn];
long long ans;
palindrome_tree(char* _s){
memcpy(s,_s,strlen(_s));
}
int new_node(int _len,int _suf){
len[cnt]=_len;
suf[cnt]=_suf;
return cnt++;
}
void init(){
ans=0;
memset(dp,0,sizeof(dp));
memset(vis,0,sizeof(vis));
cnt=0;
new_node(0,1);
int odd_root=new_node(-1,1);
lst=odd_root;
}
void extend(int cur){
char c=s[cur];
int p=lst;
while(s[cur-len[p]-1]!=s[cur])
p=suf[p];
if(!trans[p].count(c)){
int v=suf[p];
while(s[cur-len[v]-1]!=s[cur])
v=suf[v];
trans[p][c]=new_node(len[p]+2,trans[v][c]);
}
lst=trans[p][c];
}
// int dfs(int id){
// int res=0;
// if(trans[id].empty())
// return (n_pir[id]=1);
// for(auto it=trans[id].begin();it!=trans[id].end();it++){
// int add=(n_pir[it->second]==0?dfs(it->second):n_pir[it->second]);
// res+=add;
// }
// res++;
// return res;
// }
//
// int count_pir(){
// int res=0;
// for(int i=2;i<cnt;i++){
// if(n_pir[i]==0)
// n_pir[i]=dfs(i);
// res+=(f_pir[i]*n_pir[i]-1);
// }
// return res;
// }
int fail_dep(int x){
int res=0;
vis[x]=1;
int p=suf[x];
while(p>1&&!vis[p]){
res++;
vis[p]=1;
p=suf[p];
}
return res;
}
void clear_vis(int x,int depth){
vis[x]=0;
while(depth--){
x=suf[x];
vis[x]=0;
}
}
void dfs(int x,int fa){
int dep=fail_dep(x);
if(x>1&&fa>1){
dp[x]=dp[fa]+dep+1;
}else
dp[x]=dep;
ans+=dp[x];
for(auto it=trans[x].begin();it!=trans[x].end();it++){
dfs(it->second,x);
}
clear_vis(x,dep);
}
};
int main(){
int t;
cin>>t;
char s[maxn];
for(int i=1;i<=t;i++){
scanf("%s",s);
palindrome_tree *pa=new palindrome_tree(s);
pa->init();
int len=strlen(s);
for(int i=0;i<len;i++)
pa->extend(i);
pa->dfs(0,0);
pa->dfs(1,1);
printf("Case #%d: %lld\n",i,pa->ans);
}
return 0;
}