NC15857. Similarity
描述
Bob has two strings a and b and wants to calculate their similarity. The similarity is defined as the sum, over every substring of b (each occurrence in b counted separately), of the number of times that substring occurs in string a.
输入描述
The first line is the string a.
The second line is the string b.
All strings consist of only lowercase English letters.
The length of string a and the length of string b will not exceed 500000.
输出描述
Output the similarity of two strings.
示例1
输入:
abb
bba
输出:
6
说明:
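For the substring “b”, it appears 2 times in string a. Since “b” occurs twice in b, it contributes 2 + 2 = 4; “a” and “bb” each appear once in a, while “ba” and “bba” do not appear, so the total is 4 + 1 + 1 = 6.
C++(g++ 7.5.0) 解法, 执行用时: 449ms, 内存消耗: 200536K, 提交时间: 2022-11-01 11:31:47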
// NC15857 "Similarity" -- solution 1.
//
// Answer = sum over every distinct string t of occA(t) * occB(t), where occX(t)
// is the number of occurrences of t in X. This equals summing, over each
// substring occurrence in b, its number of occurrences in a.
//
// A single suffix automaton is built over a + SEP + b, where SEP (symbol 26)
// occurs nowhere else, so strings that straddle the boundary are never counted
// on both sides. Each state v represents len[v] - len[fa[v]] distinct strings
// that share one end-position set, hence share occA and occB.
//
// Counts are pushed up the suffix-link tree in decreasing order of len
// (counting sort) rather than by a recursive DFS, whose depth can reach
// |a| + |b| (e.g. both strings all 'a') and overflow the call stack.
#include <cstdio>
#include <cstring>

namespace {

constexpr int kMaxLen = 500000;   // bound on |a| and |b| from the statement
constexpr int kSigma = 27;        // 'a'..'z' plus the separator
constexpr int kSeparator = 26;
// A suffix automaton of a string of length L has at most 2L - 1 states; here
// L = |a| + 1 + |b| and indices start at 1 (index 0 means "no state").
constexpr int kMaxStates = 2 * (2 * kMaxLen + 1) + 5;

char sa[kMaxLen + 5], sb[kMaxLen + 5];
int ch[kMaxStates][kSigma], len[kMaxStates], fa[kMaxStates];
int last = 1, cnt = 1;            // state 1 is the root
long long sza[kMaxStates], szb[kMaxStates];
int bucket[2 * kMaxLen + 5], order_by_len[kMaxStates];

// Appends symbol c (0..26) to the automaton (standard online construction).
void extend(int c) {
  int p = last;
  const int np = ++cnt;
  len[np] = len[p] + 1;
  for (; p && ch[p][c] == 0; p = fa[p]) ch[p][c] = np;
  if (p == 0) {
    fa[np] = 1;
  } else {
    const int q = ch[p][c];
    if (len[q] == len[p] + 1) {
      fa[np] = q;
    } else {
      const int nq = ++cnt;
      len[nq] = len[p] + 1;
      fa[nq] = fa[q];
      std::memcpy(ch[nq], ch[q], sizeof ch[q]);
      fa[q] = fa[np] = nq;
      for (; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
    }
  }
  last = np;
}

// Walks s[1..n] from the root and adds 1 to the state reached after each
// character, i.e. to the state containing each prefix s[1..i].
void mark_prefixes(const char* s, int n, long long* occ) {
  int state = 1;
  for (int i = 1; i <= n; ++i) {
    state = ch[state][s[i] - 'a'];
    ++occ[state];
  }
}

// Propagates the prefix marks to suffix-link ancestors and sums every
// state's contribution sza * szb * (number of strings it represents).
// max_len is the length of the whole indexed string (largest len).
long long sum_contributions(int max_len) {
  for (int v = 1; v <= cnt; ++v) ++bucket[len[v]];
  for (int l = 1; l <= max_len; ++l) bucket[l] += bucket[l - 1];
  for (int v = 1; v <= cnt; ++v) order_by_len[bucket[len[v]]--] = v;

  long long ans = 0;
  // order_by_len[1] is the root (the only state with len 0); it stands for the
  // empty string only and is skipped. A link child always has a larger len
  // than its parent, so v's counts are complete when v is reached.
  for (int i = cnt; i >= 2; --i) {
    const int v = order_by_len[i];
    ans += sza[v] * szb[v] * (len[v] - len[fa[v]]);
    sza[fa[v]] += sza[v];
    szb[fa[v]] += szb[v];
  }
  return ans;
}

}  // namespace

int main() {
  if (std::scanf("%s %s", sa + 1, sb + 1) < 1) {
    std::puts("0");  // no input: both strings empty
    return 0;
  }
  const int lena = static_cast<int>(std::strlen(sa + 1));
  const int lenb = static_cast<int>(std::strlen(sb + 1));

  for (int i = 1; i <= lena; ++i) extend(sa[i] - 'a');
  extend(kSeparator);
  for (int i = 1; i <= lenb; ++i) extend(sb[i] - 'a');

  mark_prefixes(sa, lena, sza);
  mark_prefixes(sb, lenb, szb);

  std::printf("%lld\n", sum_contributions(lena + 1 + lenb));
  return 0;
}
C++11(clang++ 3.9) 解法, 执行用时: 291ms, 内存消耗: 166656K, 提交时间: 2020-03-30 14:23:12
// NC15857 "Similarity" -- solution 2.
//
// Build the suffix automaton of a with occurrence counts, then stream b through
// it. For each end position j of b, let l be the length of the longest suffix of
// b[1..j] that occurs in a and v the state containing it. Exactly the suffixes
// of b[1..j] with length 1..l occur in a: lengths len[fa[v]]+1..l lie in v
// (occ[v] occurrences each), shorter ones lie in v's suffix-link ancestors.
// With
//   acc[u] = sum over u and its ancestors w != root of occ[w] * (len[w] - len[fa[w]])
// position j contributes acc[fa[v]] + occ[v] * (l - len[fa[v]]).
// Summing over j counts every substring occurrence of b once, in O(|a| + |b|),
// instead of enumerating every distinct common substring (quadratic on long
// random inputs, and deeply recursive).
#include <cstdio>
#include <cstring>

namespace {

constexpr int kMaxLen = 500000 + 10;
constexpr int kMaxStates = 2 * kMaxLen;  // a SAM of length n has <= 2n - 1 states

char s[kMaxLen];
int nxt[kMaxStates][26], len[kMaxStates], fa[kMaxStates];
long long occ[kMaxStates];  // endpos size of each state, after count_occurrences()
long long acc[kMaxStates];  // prefix sums along suffix links (see header)
int bucket[kMaxLen], order_by_len[kMaxStates];
int sz = 2, last = 1;       // state 1 is the root; 0 means "no state"

// Appends character ch (0..25) to the automaton of a.
void insert(int ch) {
  const int np = sz++;
  int p = last;
  last = np;
  occ[np] = 1;  // np ends a new prefix of a: one end position of its own
  len[np] = len[p] + 1;
  for (; p && !nxt[p][ch]; p = fa[p]) nxt[p][ch] = np;
  if (!p) {
    fa[np] = 1;
    return;
  }
  const int q = nxt[p][ch];
  if (len[p] + 1 == len[q]) {
    fa[np] = q;
    return;
  }
  const int nq = sz++;  // clone: owns no end position itself, occ[nq] stays 0
  len[nq] = len[p] + 1;
  std::memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
  fa[nq] = fa[q];
  fa[np] = fa[q] = nq;
  for (; p && nxt[p][ch] == q; p = fa[p]) nxt[p][ch] = nq;
}

// Sorts states by len (counting sort), pushes occ up the suffix-link tree and
// fills acc. n is |a|, the largest len of any state.
void count_occurrences(int n) {
  for (int v = 1; v < sz; ++v) ++bucket[len[v]];
  for (int l = 1; l <= n; ++l) bucket[l] += bucket[l - 1];
  for (int v = sz - 1; v >= 1; --v) order_by_len[--bucket[len[v]]] = v;
  // order_by_len[0] is the root; a link child always has a larger len.
  for (int i = sz - 2; i >= 1; --i) {
    const int v = order_by_len[i];
    occ[fa[v]] += occ[v];
  }
  for (int i = 1; i <= sz - 2; ++i) {
    const int v = order_by_len[i];
    acc[v] = acc[fa[v]] + occ[v] * (len[v] - len[fa[v]]);
  }
}

}  // namespace

int main() {
  if (std::scanf("%s", s + 1) != 1) s[1] = '\0';
  const int n = static_cast<int>(std::strlen(s + 1));
  for (int i = 1; i <= n; ++i) insert(s[i] - 'a');
  count_occurrences(n);

  // b reuses the buffer; its own length bounds the scan.
  if (std::scanf("%s", s + 1) != 1) s[1] = '\0';
  const int m = static_cast<int>(std::strlen(s + 1));

  long long ans = 0;
  int v = 1;
  int l = 0;
  for (int j = 1; j <= m; ++j) {
    const int ch = s[j] - 'a';
    while (v != 1 && !nxt[v][ch]) {
      v = fa[v];
      l = len[v];
    }
    if (nxt[v][ch]) {
      v = nxt[v][ch];
      ++l;
    } else {
      l = 0;  // v is the root here: b[j] does not occur in a at all
    }
    if (l > 0) ans += acc[fa[v]] + occ[v] * (l - len[fa[v]]);
  }
  std::printf("%lld\n", ans);
  return 0;
}
C++14(g++5.4) 解法, 执行用时: 271ms, 内存消耗: 195960K, 提交时间: 2020-07-31 17:52:05
// NC15857 "Similarity" -- solution 3.
//
// Builds one suffix automaton over a + SEP + b (SEP = symbol 26, unique) and
// records, per state, how many end positions lie inside a and inside b. Each
// state v stands for len[v] - len[fa[v]] distinct strings sharing those counts,
// so the answer is the sum over non-root states of
//   cnt_a(v) * cnt_b(v) * (len[v] - len[fa[v]]).
#include <cstdio>
#include <cstring>

namespace {

using ll = long long;

constexpr int kMaxLen = 500000;
// At most 2L - 1 states for L = |a| + 1 + |b|; indices start at 1. Sized to
// this bound (not twice it) to keep static storage for ch at ~216 MB.
constexpr int N = 2 * (2 * kMaxLen + 1) + 5;

char s[kMaxLen + 5];
int ch[N][27], len[N], fa[N];
int last = 1, tot = 1;  // state 1 is the root
// endpos_size[v][0] / [1]: number of end positions of v's strings inside a / b,
// i.e. occurrences in a / b of each string of v (after accumulation).
ll endpos_size[N][2];
int bucket[N], order_by_len[N];

// Appends symbol c. owner is 0 for a character of a, 1 for b, and -1 for the
// separator, whose position belongs to neither string.
void add(int c, int owner) {
  int p = last;
  const int np = last = ++tot;
  len[np] = len[p] + 1;
  if (owner >= 0) endpos_size[np][owner] = 1;
  for (; p && !ch[p][c]; p = fa[p]) ch[p][c] = np;
  if (!p) {
    fa[np] = 1;
    return;
  }
  const int q = ch[p][c];
  if (len[q] == len[p] + 1) {
    fa[np] = q;
    return;
  }
  const int nq = ++tot;
  len[nq] = len[p] + 1;
  std::memcpy(ch[nq], ch[q], sizeof ch[q]);
  fa[nq] = fa[q];
  fa[q] = fa[np] = nq;
  for (; p && ch[p][c] == q; p = fa[p]) ch[p][c] = nq;
}

// Reads one token into s[1..] and appends its characters with the given owner.
// scanf is used because extracting into a char* (cin >> s + 1) is ill-formed
// since C++20.
void add_string(int owner) {
  if (std::scanf("%s", s + 1) != 1) return;
  for (int i = 1; s[i]; ++i) add(s[i] - 'a', owner);
}

}  // namespace

int main() {
  add_string(0);
  add(26, -1);
  add_string(1);

  // Counting sort of states by len; order_by_len is 1-based, root first.
  const int max_len = len[last];  // the whole indexed string is the longest
  for (int v = 1; v <= tot; ++v) ++bucket[len[v]];
  for (int l = 1; l <= max_len; ++l) bucket[l] += bucket[l - 1];
  for (int v = 1; v <= tot; ++v) order_by_len[bucket[len[v]]--] = v;

  ll ans = 0;
  // Longest states first, so each state's counts are final when it is reached;
  // order_by_len[1] is the root (empty string) and is skipped.
  for (int i = tot; i >= 2; --i) {
    const int v = order_by_len[i];
    ans += endpos_size[v][0] * endpos_size[v][1] * (len[v] - len[fa[v]]);
    endpos_size[fa[v]][0] += endpos_size[v][0];
    endpos_size[fa[v]][1] += endpos_size[v][1];
  }
  std::printf("%lld\n", ans);
  return 0;
}