NC17344. Double Palindrome
描述
输入描述
There are multiple test cases. The first line of input contains an integer $T$, the number of test cases. For each test case:
The first line contains a string $s$ ($1 \le |s| \le 10^5$) consisting of lowercase English letters.
It is guaranteed that the sum of all $|s|$ does not exceed $10^6$.
输出描述
For each test case, output an integer denoting the answer.
示例1
输入:
3
abba
aa
baa
输出:
4
1
2
C++14(g++5.4) 解法, 执行用时: 346ms, 内存消耗: 7660K, 提交时间: 2018-07-29 23:52:13
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <set>
#include <vector>
using namespace std;

// Double Palindrome -- solution 1 (hashing modulo 2^64 + binary search).
//
// What is counted (inferred from the samples, since the statement text is not
// included): the number of distinct index pairs i < j such that swapping
// s[i] and s[j] yields a string that splits into two non-empty palindromes.
// For every split point the code locates the mismatched mirror pairs of both
// parts and records the swaps that repair all of them.

const int MAXNUM = 111111;
typedef unsigned long long ull;

// Rolling hashes of s2[1..s2size]; all arithmetic wraps modulo 2^64.
//   xp[k]     = seed^k
//   xhash[i]  = sum over k in [i, n] of s2[k] * seed^(k - i)  (built right to left)
//   xhash2[i] = sum over k in [1, i] of s2[k] * seed^(i - k)  (built left to right)
ull xp[MAXNUM], xhash[MAXNUM], xhash2[MAXNUM];
char s2[MAXNUM];  // input string, stored 1-based at s2 + 1
const ull seed = 1e9 + 7;
int s2size;

// Precompute xp[k] = seed^k for every index the arrays can hold.
void init() {
    xp[0] = 1;
    for (int i = 1; i < MAXNUM; i++) xp[i] = xp[i - 1] * seed;
}

// Hash of s2[i..i+L-1] with s2[i] at weight seed^0, i.e. the substring read
// rightward starting at i.
inline ull gethash(ull xhash[], ull xp[], int i, int L) {
    return xhash[i] - xhash[i + L] * xp[L];
}

// Hash of s2[i-L+1..i] with s2[i] at weight seed^0, i.e. the substring read
// leftward starting at i.
inline ull gethash2(ull xhash2[], ull xp[], int i, int L) {
    return xhash2[i] - xhash2[i - L] * xp[L];
}

// True (up to hash collisions) iff s2[start - t] == s2[start2 + t] for every
// t in [0, len): the len characters read leftward from `start` equal the len
// characters read rightward from `start2`.
bool check(int start, int start2, int len) {
    return gethash2(xhash2, xp, start, len) == gethash(xhash, xp, start2, len);
}

// Build both hash tables for the current s2[1..s2size].
void sethash() {
    xhash[s2size + 1] = 0;  // hash of the empty suffix
    for (int i = s2size; i >= 1; i--) xhash[i] = xhash[i + 1] * seed + s2[i];
    xhash2[0] = 0;          // hash of the empty prefix
    for (int i = 1; i <= s2size; i++) xhash2[i] = xhash2[i - 1] * seed + s2[i];
}

ull pos[26];  // per-letter occurrence counts of the current string
typedef pair<int, int> pii;
vector<pii> v;  // mismatched mirror pairs (left, right) found for the current split

// Scan one segment from its centre outward. `start` and `start2` are the
// innermost mirror pair of the segment (they flank the centre, skipping the
// middle character of an odd-length segment) and `len` is the number of
// mirror pairs. Each mirror pair whose characters differ is appended to `v`.
// Scanning stops when the rest of the segment matches, when the segment is
// exhausted, or when `v` (shared by both parts of a split) holds 3 pairs --
// more than a single swap can repair.
void getlen(int start, int start2, int len) {
    if (len == 0) return;
    while (v.size() < 3) {
        // Longest run of matching mirror pairs starting at the centre. Binary
        // search is valid: if the first m pairs match, so do the first m-1.
        int matched = 0;
        int left = 1, right = len;
        while (left <= right) {
            int middle = (left + right) / 2;
            if (check(start, start2, middle)) {
                matched = middle;
                left = middle + 1;
            } else {
                right = middle - 1;
            }
        }
        if (matched == len) return;  // the remaining pairs all match
        v.push_back(pii(start - matched, start2 + matched));
        // Resume scanning just beyond the mismatched pair.
        start -= matched + 1;
        start2 += matched + 1;
        len -= matched + 1;
        if (len <= 0) return;
    }
}

typedef long long ll;
int flag;    // set when some split of the unmodified string is already a double palindrome
ll res;      // number of distinct swaps recorded in `s`
set<pii> s;  // recorded swaps, stored as (smaller index, larger index)

// Record the swap (k1, k2) once, regardless of argument order.
void adds(int k1, int k2) {
    if (k1 > k2) swap(k1, k2);
    if (s.insert(pii(k1, k2)).second) res++;
}

// `k` is the only mismatched pair of a split and `x2` is the middle position
// of an odd-length part (a palindrome's middle character is unconstrained).
// Record the swap of x2 with whichever end of `k` makes the pair match.
void sw(int x2, const pii &k) {
    if (s2[k.first] == s2[x2]) adds(x2, k.second);
    if (s2[k.second] == s2[x2]) adds(x2, k.first);
}

// For every split s2[1..i] | s2[i+1..n], collect the mismatched mirror pairs
// of both parts and record every swap that makes both parts palindromes.
void solve() {
    flag = 0;
    s.clear();
    res = 0;
    int start[2], start2[2], len[2];
    for (int i = 1; i < s2size; i++) {
        // Left part s2[1..i].
        start[0] = i / 2;
        start2[0] = (i & 1) ? i / 2 + 2 : i / 2 + 1;
        len[0] = i / 2;
        // Right part s2[i+1..n] of length i2, offset by i.
        int i2 = s2size - i;
        start[1] = i2 / 2 + i;
        start2[1] = ((i2 & 1) ? i2 / 2 + 2 : i2 / 2 + 1) + i;
        len[1] = i2 / 2;

        v.clear();
        getlen(start[0], start2[0], len[0]);
        getlen(start[1], start2[1], len[1]);

        if (v.empty()) {
            // Both parts are already palindromes. Swaps of two equal letters
            // leave the string unchanged and are counted once in main() via
            // `flag`, so only a swap of two differing middles is recorded.
            if ((i & 1) && (i2 & 1)) {
                int x1 = i / 2 + 1, x2 = i2 / 2 + 1 + i;
                if (s2[x1] != s2[x2]) adds(x1, x2);
            }
            flag = 1;
            continue;
        }
        if (v.size() > 2) continue;  // one swap repairs at most two pairs
        if (v.size() == 1) {
            if (i & 1) sw(i / 2 + 1, v[0]);
            if (i2 & 1) sw(i2 / 2 + 1 + i, v[0]);
        } else {
            // Two mismatched pairs: swap one end of the first with one end of
            // the second so that both pairs end up matching.
            if (s2[v[0].first] == s2[v[1].first] && s2[v[0].second] == s2[v[1].second]) {
                adds(v[0].first, v[1].second);
                adds(v[0].second, v[1].first);
            }
            if (s2[v[0].first] == s2[v[1].second] && s2[v[0].second] == s2[v[1].first]) {
                adds(v[0].first, v[1].first);
                adds(v[0].second, v[1].second);
            }
        }
    }
}

int main(void) {
    int t;
    init();
    if (scanf("%d", &t) != 1) return 0;
    while (t--) {
        if (scanf("%s", s2 + 1) != 1) break;
        s2size = strlen(s2 + 1);
        memset(pos, 0, sizeof(pos));
        sethash();
        solve();
        if (flag) {
            // The string is already a double palindrome, so every swap of two
            // equal letters (which leaves it unchanged) also qualifies.
            for (int i = 1; i <= s2size; i++) pos[s2[i] - 'a']++;
            for (int i = 0; i < 26; i++) res += (pos[i] * (pos[i] - 1)) / 2;
        }
        printf("%lld\n", res);
    }
    return 0;
}
C++ 解法, 执行用时: 296ms, 内存消耗: 7584K, 提交时间: 2021-10-04 16:49:05
#include <cstdio>
#include <cstring>
#include <set>
#include <utility>
#define rep(i,x,y) for(int i=x; i<=y; ++i)
#define repd(i,x,y) for(int i=x; i>=y; --i)
using namespace std;

// Double Palindrome -- solution 2 (polynomial hashing modulo 1e9+7).
//
// Counts (as inferred from the samples; the statement text is not included)
// the distinct index pairs i < j whose swap turns s into the concatenation
// of two non-empty palindromes.

const int N=100005;
char s[N];          // input string, stored 1-based at s + 1
int n,tot,cnt[26];  // |s|; number of mismatched pairs recorded in dat; letter counts
typedef long long LL;
const int base=19260817,mod=1000000007;
set <pair <int,int>> st;  // qualifying swaps as (smaller index, larger index)
// b[k] = base^k; h1 / h2 = prefix (left-to-right) / suffix (right-to-left)
// hashes of s with 'a' mapped to 0; all values reduced modulo `mod`.
LL b[N],h1[N],h2[N],ans;

// A mismatched mirror pair of positions, ordered by ins() so that s[x] < s[y].
struct D { int x,y; } dat[10];

// Two mismatched pairs carrying the same two letters: swapping x of one with
// y of the other repairs both.
bool operator == (D a,D c) { return s[a.x]==s[c.x] && s[a.y]==s[c.y]; }

// Build the powers and both hash tables for s[1..n].
void init() {
    b[0]=1;
    rep(i,1,n) b[i]=b[i-1]*base%mod;
    rep(i,1,n) h1[i]=(h1[i-1]*base+s[i]-'a')%mod;
    h2[n+1]=0;  // start from the empty suffix, not a previous test case's value
    repd(i,n,1) h2[i]=(h2[i+1]*base+s[i]-'a')%mod;
}

// Hash of s[x..y] read left to right (s[y] has weight base^0).
LL find1(int x,int y) { return ((h1[y]-h1[x-1]*b[y-x+1])%mod+mod)%mod; }

// Hash of s[x..y] read right to left (s[x] has weight base^0).
LL find2(int x,int y) { return ((h2[x]-h2[y+1]*b[y-x+1])%mod+mod)%mod; }

// True (up to hash collisions) iff s[l1..r1] equals s[l2..r2] reversed; with
// identical ranges this is a palindrome test.
bool check(int l1,int r1,int l2,int r2) { return find1(l1,r1)==find2(l2,r2); }

// True iff some split s[1..i] | s[i+1..n] has both parts palindromic.
bool checkd() {
    rep(i,1,n-1) if(check(1,i,1,i) && check(i+1,n,i+1,n)) return 1;
    return 0;
}

// Record the swap (x, y) once, regardless of argument order.
void add(int x,int y) { if(x>y) swap(x,y); st.insert(make_pair(x,y)); }

// Append the mismatched pair (x, y) to dat, ordered so that s[x] < s[y].
void ins(int x,int y) {
    if(s[x]>s[y]) swap(x,y);
    dat[++tot]=D{x,y};
}

// Test whether s[L..R] becomes a palindrome once at most two mismatched
// mirror pairs are set aside. Mismatched pairs are found from the outside in
// and appended to dat via ins(). Returns false if more than two exist.
bool fixable(int L,int R) {
    if(check(L,R,L,R)) return 1;
    rep(i,1,2) {
        // Binary-search the number of matching outer mirror pairs.
        int l=1,r=(R-L+1)>>1;
        while(l<=r) {
            int m=(l+r)>>1;
            if(check(L,L+m-1,R-m+1,R)) l=m+1;
            else r=m-1;
        }
        --l;          // l = number of matching outer pairs
        L+=l,R-=l;    // now s[L] != s[R]
        ins(L,R);
        ++L,--R;
        if(check(L,R,L,R)) return 1;
    }
    return 0;
}

// Read one test case and print its answer.
void solve() {
    if(scanf("%s",s+1)!=1) return;
    n=strlen(s+1);
    rep(i,0,25) cnt[i]=0;
    ans=0,init();
    st.clear();
    rep(i,1,n) ++cnt[s[i]-'a'];
    // Already a double palindrome: every swap of two equal letters qualifies.
    if(checkd()) rep(i,0,25) ans+=static_cast<LL>(cnt[i])*(cnt[i]-1)/2;
    rep(i,1,n-1) {
        tot=0;
        bool jdg=fixable(1,i);
        jdg&=fixable(i+1,n);
        if(!jdg) continue;
        if(tot==2) {
            // Both pairs hold the same letters: cross-swap repairs both.
            if(dat[1]==dat[2]) add(dat[1].x,dat[2].y),add(dat[1].y,dat[2].x);
        }
        if(tot==1) {
            // Swap the free middle of an odd-length part with one end of the pair.
            if(i&1) {
                int x=(i+1)>>1;
                if(s[x]==s[dat[1].x]) add(x,dat[1].y);
                if(s[x]==s[dat[1].y]) add(x,dat[1].x);
            }
            if((n-i)&1) {
                int y=(n+i+1)>>1;
                if(s[y]==s[dat[1].x]) add(y,dat[1].y);
                if(s[y]==s[dat[1].y]) add(y,dat[1].x);
            }
        }
        if(tot==0) {
            // Both parts odd-length palindromes: swapping two differing
            // middles keeps both palindromic.
            if((i&1) && !(n&1)) {
                int x=(i+1)>>1,y=(n+i+1)>>1;
                if(s[x]!=s[y]) add(x,y);
            }
        }
    }
    ans+=st.size();
    printf("%lld\n",ans);
}

int main() {
    int T;
    if(scanf("%d",&T)!=1) return 0;
    while(T--) solve();
    return 0;
}