NC19998. [HAOI2016]找相同字符
描述
输入描述
两行,两个字符串s1,s2,长度分别为n1,n2。1 ≤ n1, n2 ≤ 200000,字符串中只有小写字母
输出描述
输出一个整数表示答案
示例1
输入:
aabb bbaa
输出:
10
C++14(g++5.4) 解法, 执行用时: 184ms, 内存消耗: 105716K, 提交时间: 2019-11-03 08:54:55
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; typedef long long ll; int const xn=8e5+5; int n,m,cnt=1,fa[xn],go[xn][30],l[xn],tax[xn],q[xn],d1[xn],d2[xn]; char dc[xn]; int work(int p,int w) { int nq=++cnt,q=go[p][w]; l[nq]=l[p]+1; fa[nq]=fa[q]; fa[q]=nq; memcpy(go[nq],go[q],sizeof go[q]); for(;p&&go[p][w]==q;p=fa[p])go[p][w]=nq; return nq; } int ext(int p,int w) { if(go[p][w]) { int q=go[p][w]; if(l[q]==l[p]+1)return q; return work(p,w); } int np=++cnt; l[np]=l[p]+1; for(;p&&!go[p][w];p=fa[p])go[p][w]=np; if(!p)fa[np]=1; else { int q=go[p][w]; if(l[q]==l[p]+1)fa[np]=q; else fa[np]=work(p,w); } return np; } void rsort() { for(int i=1;i<=cnt;i++)tax[l[i]]++; for(int i=1;i<=cnt;i++)tax[i]+=tax[i-1]; for(int i=cnt;i;i--)q[tax[l[i]]--]=i; } int main() { scanf("%s",dc); n=strlen(dc); for(int lst=1,i=0;i<n;i++)lst=ext(lst,dc[i]-'a'+1),d1[lst]++; scanf("%s",dc); m=strlen(dc); for(int lst=1,i=0;i<n;i++)lst=ext(lst,dc[i]-'a'+1),d2[lst]++; rsort(); for(int i=cnt,x;i;i--)d1[fa[x=q[i]]]+=d1[x],d2[fa[x]]+=d2[x]; ll ans=0; for(int i=1;i<=cnt;i++)ans+=(ll)(l[i]-l[fa[i]])*d1[i]*d2[i]; printf("%lld\n",ans); return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 237ms, 内存消耗: 99832K, 提交时间: 2020-01-14 14:29:25
#include<bits/stdc++.h> using namespace std; #define ll long long const int N=1e6+100; char s[N]; int t[N][27],lst=1,cnt=1,fa[N],len[N]; int size[N][3]; void insert(int x,int op) { int p=lst,now=++cnt; lst=now;len[now]=len[p]+1; if(op!=2) size[now][op]++; for(;p&&t[p][x]==0;p=fa[p]) t[p][x]=now; if(p==0) fa[now]=1; else { int q=t[p][x]; if(len[q]==len[p]+1)fa[now]=q; else { int nq=++cnt; len[nq]=len[p]+1; fa[nq]=fa[q]; memcpy(t[nq],t[q],sizeof t[q]); fa[now]=fa[q]=nq; for(;p&&t[p][x]==q;p=fa[p]) t[p][x]=nq; } } } int a[N],c[N]; ll ans=0; void getans() { for(int i=1;i<=cnt;i++) c[i]=0; for(int i=1;i<=cnt;i++) c[len[i]]++; for(int i=1;i<=cnt;i++) c[i]+=c[i-1]; for(int i=cnt;i>=0;i--) a[c[len[i]]--]=i; for(int i=cnt;i>=1;i--) { int p=a[i]; size[fa[p]][0]+=size[p][0]; size[fa[p]][1]+=size[p][1]; ans+=1ll*size[p][0]*size[p][1]*(len[p]-len[fa[p]]); } } int main() { cin>>s+1; for(int i=1;s[i];i++) insert(s[i]-'a',0); insert(26,2); cin>>s+1; for(int i=1;s[i];i++) insert(s[i]-'a',1); getans(); cout<<ans<<endl; return 0; }