NC21529. Periodic Palindrome
描述
输入描述
There are multiple test cases. The first line of input contains an integer T indicating the number of test cases. For each test case:
The first line contains a string s ($2 \le |s| \le 10^6$) consisting only of lowercase English letters.
It is guaranteed that the sum of |s| over all test cases does not exceed $10^6$.
输出描述
For each test case, output an integer denoting the answer.
示例1
输入:
1 aaa
输出:
5
C++11(clang++ 3.9) 解法, 执行用时: 254ms, 内存消耗: 41572K, 提交时间: 2020-03-14 14:11:38
// Reads T lowercase strings from stdin and prints one integer per string.
//
// NOTE(review): the problem description is not included with this file, so
// the comments below state what each step computes; why the accumulated sum
// is the required answer should be confirmed against the original statement.
//
// Every array is 1-indexed over the current string ss[1..n] and is reused
// across test cases without being cleared; each pass rewrites the range it
// later reads.
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cstring>
#include <utility>

typedef long long LL;

// Longest token accepted per test case. Must stay equal to the field width
// written in the scanf format string in main().
const int MAX_LEN = 1000003;
// Capacity of every work array.
const int N = 1000005 * 2;
// manacher() writes s1[0 .. 2n+2], so the interleaved string needs 2n+3 slots.
static_assert(2 * MAX_LEN + 3 <= N, "work arrays too small for MAX_LEN");

char ss[N];  // current string, characters at ss[1..n]
int n;       // length of the current string
LL ans = 0;  // answer accumulated for the current test case

int nxt[N];  // nxt[u]: longest common prefix of ss[u..n] and ss[1..n]
bool ok[N];  // ok[u]: the suffix ss[u..n] equals a prefix of ss

// Computes the Z-function of ss[1..n] into nxt[] and derives ok[] from it.
// ok[u] is true when the match starting at u runs to position n; ok[n+1] is
// set to true for the empty suffix.
void exkmp() {
    nxt[1] = n;
    int id = 1;  // start of the match reaching furthest right; read only once mx >= u
    int mx = 0;  // rightmost index covered by any match found so far
    for (int u = 2; u <= n; u++) {
        if (mx >= u) {
            // u lies inside a known match: reuse the value at the mirrored
            // prefix position, clipped to the part already verified.
            nxt[u] = std::min(nxt[u - id + 1], mx - u + 1);
        } else {
            nxt[u] = 0;
        }
        while (u + nxt[u] <= n && ss[u + nxt[u]] == ss[nxt[u] + 1]) {
            nxt[u]++;
        }
        if (u + nxt[u] - 1 > mx) {
            mx = u + nxt[u] - 1;
            id = u;
        }
    }
    for (int u = 1; u <= n; u++) {
        ok[u] = (u + nxt[u] - 1 == n);
    }
    ok[n + 1] = true;
}

char s1[N];   // "$#c1#c2#...#cn#%": separators at odd indices, sentinels at both ends
int p[N];     // palindrome radius around each index of s1
bool ok1[N];  // ok1[x]: the palindrome centred on the gap before ss[x] extends back to position 1
bool ok2[N];  // ok2[x]: the palindrome centred on the gap before ss[x] extends forward to position n
int f[N];     // f[x]: number of character pairs matched around the gap before ss[x]
int sum[N];   // suffix counts in calc1(), then Fenwick tree storage in calc2()

// Runs Manacher's algorithm on the separator-interleaved copy of ss and
// fills f[], ok1[] and ok2[] for every gap x = 1 .. n+1 (gap x sits
// immediately before ss[x]; gap n+1 is the end of the string).
void manacher() {
    s1[0] = '$';
    int now = 0;
    for (int u = 1; u <= n; u++) {
        s1[++now] = '#';
        s1[++now] = ss[u];
    }
    s1[++now] = '#';
    s1[++now] = '%';  // now == 2n+2; the two distinct sentinels stop every expansion

    p[0] = 1;
    int mx = 0;
    int id = 0;
    for (int i = 1; i < now; i++) {
        if (mx > i) {
            p[i] = std::min(mx - i, p[2 * id - i]);
        } else {
            p[i] = 1;
        }
        while (s1[i + p[i]] == s1[i - p[i]]) {
            p[i]++;
        }
        if (i + p[i] > mx) {
            mx = i + p[i];
            id = i;
        }
    }

    // Odd indices of s1 are the separators, i.e. the gaps between characters.
    for (int u = 1; u < now; u += 2) {
        const int x = u / 2 + 1;  // gap index in the original string
        f[x] = p[u] / 2;
        ok1[x] = (p[u] == u);
        ok2[x] = (u + p[u] == now);
    }
    ok2[n + 1] = true;
    f[n + 1] = 0;
}

// First contribution to ans. sum[u] is filled with the number of gaps v in
// (u, n+1] for which ok2[v] holds; then, for every gap u with ok1[u], the
// count sum[t] with t = max(2u-1, (n+u+1)/2) is added.
void calc1() {
    sum[n + 1] = 0;
    for (int u = n; u >= 1; u--) {
        sum[u] = sum[u + 1] + ok2[u + 1];
    }
    for (int u = 1; u <= n; u++) {
        if (!ok1[u]) continue;
        // u >= 1, so 2u-1 >= u and the original extra max(..., u) was redundant.
        const int t = std::max(2 * u - 1, (n + u + 1) / 2);
        ans += sum[t];
    }
    ans--;  // original author's note: "[1,n] is not allowed"
}

// Lowest set bit of x.
int lb(int x) { return x & (-x); }

// Fenwick tree point update over sum[1..n]: adds y at position x.
void add(int x, int y) {
    while (x <= n) {
        sum[x] += y;
        x += lb(x);
    }
}

// Fenwick tree prefix query: total of positions 1..x (0 when x < 1).
int get(int x) {
    int total = 0;
    while (x >= 1) {
        total += sum[x];
        x -= lb(x);
    }
    return total;
}

std::pair<int, int> g[N];  // (f[x], x) pairs, sorted ascending in calc2()

// Second contribution to ans. Iterates over every odd u in [3, n+1] with
// ok[u] set and uses t = u/2 (original note: "length of a single section").
// The Fenwick tree holds a 1 at each position x with f[x] >= t; t never
// decreases, so positions are removed once, in ascending order of f.
void calc2() {
    for (int u = 1; u <= n; u++) sum[u] = 0;
    for (int u = 1; u <= n; u++) {
        add(u, 1);
        g[u] = std::make_pair(f[u], u);
    }
    std::sort(g + 1, g + 1 + n);

    int removed = 0;  // g[1..removed] have already been taken out of the tree
    for (int u = 3; u <= n + 1; u += 2) {  // original note: "first, find a period"
        if (!ok[u]) continue;
        const int t = u / 2;
        while (removed < n && g[removed + 1].first < t) {
            removed++;
            add(g[removed].second, -1);
        }

        // Surviving positions in [u-t, u-1], counted once.
        ans += static_cast<LL>(get(u - 1) - get(u - t - 1));

        // The same window is then split into two consecutive ranges, each
        // weighted by the quotient (n - l + 1) / t taken at its own left end.
        const int rem = (n - (u - t) + 1) % t;
        int l = u - t;
        int r = std::min(u - 1, l + rem);
        int x = (n - l + 1) / t;
        ans += (get(r) - get(l - 1)) * static_cast<LL>(x);

        l = r + 1;
        r = u - 1;
        x = (n - l + 1) / t;
        ans += (get(r) - get(l - 1)) * static_cast<LL>(x);
    }
}

// Entry point: reads the test count, then one string per test case, and
// prints the accumulated answer for each. Stops quietly on malformed or
// truncated input instead of re-processing a stale buffer.
int main() {
    int T = 0;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T-- > 0) {
        ans = 0;
        // Field width equals MAX_LEN so an over-long token cannot overrun
        // ss[] or, through 2n+3, s1[].
        if (std::scanf("%1000003s", ss + 1) != 1) break;
        n = static_cast<int>(std::strlen(ss + 1));
        exkmp();
        manacher();
        calc1();
        calc2();
        std::printf("%lld\n", ans);
    }
    return 0;
}