列表

详情


NC17344. Double Palindrome

描述

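Chiaki has a string s. She would like to know the number of pairs (i, j) ($1 \le i < j \le |s|$) such that, after swapping $s_i$ and $s_j$, s can be represented as the concatenation of two non-empty palindromes. (Swapping two equal characters leaves s unchanged, so such a pair counts exactly when s itself is already such a concatenation — see the sample "aa".)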
Note that a string is called a palindrome if and only if it reads the same backward as forward.

输入描述

There are multiple test cases. The first line of input is an integer T indicating the number of test cases. For each test case:
The first line contains a string s ($1 \le |s| \le 10^5$) consisting of lowercase English letters.
It is guaranteed that the sum of all $|s|$ does not exceed $10^6$.

输出描述

For each test case, output an integer denoting the answer.

示例1

输入:

3
abba
aa
baa

输出:

4
1
2

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 346ms, 内存消耗: 7660K, 提交时间: 2018-07-29 23:52:13

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <set>
#include <vector>
using namespace std;
// Capacity of the 1-based string buffer and hash tables (|s| <= 1e5 plus sentinels).
const int MAXNUM = 111111;
using ull = unsigned long long;
// xp[k] = seed^k, xhash = suffix hashes of s2, xhash2 = prefix hashes of s2.
// All arithmetic is unsigned and therefore wraps modulo 2^64.
ull xp[MAXNUM], xhash[MAXNUM], xhash2[MAXNUM];
// Current input string, stored 1-based in s2[1..s2size].
char s2[MAXNUM];
const ull seed = 1000000007ULL;
int s2size;

// Fills xp[0..100005] with successive powers of seed (xp[0] = 1).
void init()
{
	ull power = 1;
	for (int k = 0; k <= 100005; ++k)
	{
		xp[k] = power;
		power *= seed;
	}
}
// Forward hash of the L characters s2[i..i+L-1]: the character at offset k
// from i carries weight seed^k. Assumes the suffix table was built as
// xhash[i] = xhash[i+1]*seed + s2[i] (see sethash()).
inline ull gethash(ull xhash[], ull xp[], int i, int L)
{
	const ull suffixFromI = xhash[i];
	const ull suffixPastEnd = xhash[i + L] * xp[L];
	return suffixFromI - suffixPastEnd;
}
// Backward hash of the L characters ending at s2[i] (s2[i], s2[i-1], ...,
// s2[i-L+1]): the character at distance k from i carries weight seed^k.
// Assumes the prefix table was built as xhash2[i] = xhash2[i-1]*seed + s2[i]
// (see sethash()).
inline ull gethash2(ull xhash2[], ull xp[], int i, int L)
{
	const ull prefixToI = xhash2[i];
	const ull prefixBeforeStart = xhash2[i - L] * xp[L];
	return prefixToI - prefixBeforeStart;
}
bool check(int start, int start2, int len)
{
	ull k1 = gethash2(xhash2, xp, start, len), x1 = gethash(xhash, xp, start2, len);
	return k1 == x1;
}
// Builds both rolling-hash tables for s2[1..s2size]:
//   xhash[i]  = s2[i] + s2[i+1]*seed + s2[i+2]*seed^2 + ...   (suffix table)
//   xhash2[i] = s2[i] + s2[i-1]*seed + s2[i-2]*seed^2 + ...   (prefix table)
// The sentinels xhash[s2size+1] and xhash2[0] are reset to 0 because the
// arrays are reused across test cases.
void sethash()
{
	xhash[s2size + 1] = 0;
	for (int i = s2size; i >= 1; i--)
		xhash[i] = xhash[i + 1] * seed + s2[i];
	// Previously this reset was glued onto the loop above with a comma
	// operator, so it ran on every iteration; it belongs here, once.
	xhash2[0] = 0;
	for (int i = 1; i <= s2size; i++)
		xhash2[i] = xhash2[i - 1] * seed + s2[i];
}
ull pos[33];
#define pii pair<int,int>
vector<pii> v;
void getlen(int start, int start2, int len)
{
	if (len == 0)
		return;
	int left = 0, right = len;
	int res = 0, k = 0;
	while (v.size() < 3)
	{
		res = 0;
		left = 1, right = len;
		while (left <= right)
		{
			int middle = (left + right) / 2;
			if (check(start, start2, middle))
			{
				res = middle;
				left = middle + 1;
			}
			else right = middle - 1;
		}
		if (res == len)
			return;
		v.push_back(pii(start - res, start2 + res));
		start -= res + 1, start2 += res + 1;
		len -= res + 1;
		if (len <= 0)
			return;
	}
}
typedef long long ll;
int flag;
ll res;
set<pii> s;
void adds(int k1, int k2)
{
	if (k1 > k2)
		swap(k1, k2);
	if (!s.count(pii(k1, k2)))
	{
		res++;
		s.insert(pii(k1, k2));
	}
}
// Single-mismatch repair. The caller passes x2 = the centre of an odd-length
// half and k = the one mismatched pair (left, right).
// Swapping x2 with one end of k fixes the pair when s2[x2] equals the
// character at the other end.
void sw(int x2, pii &k)
{
	const char centre = s2[x2];
	if (centre == s2[k.first])
		adds(x2, k.second);
	if (centre == s2[k.second])
		adds(x2, k.first);
}
// For every split point i (prefix s2[1..i], suffix s2[i+1..s2size]):
//  - collects the mismatched mirror pairs of both halves via getlen();
//  - records, through adds(), the swaps that turn both halves into palindromes.
// Sets flag when some split needs no repair at all; main() then also counts
// swaps of equal characters.
void solve()
{
	flag = 0, s.clear();
	res = 0;
	int start[2], start2[2], len[2];
	for (int i = 1; i < s2size; i++)
	{
		// Prefix s2[1..i]: start walks left from just below the middle and start2
		// walks right from just above it. An odd-length prefix skips its centre
		// i/2+1, which mirrors onto itself.
		if (i & 1)
			start[0] = i / 2, start2[0] = i / 2 + 2;
		else start[0] = i / 2, start2[0] = i / 2 + 1;
		len[0] = i / 2;
		// Suffix of length i2: same layout, shifted right by i.
		int i2 = s2size - i;
		if (i2 & 1)
			start[1] = i2 / 2, start2[1] = i2 / 2 + 2;
		else start[1] = i2 / 2, start2[1] = i2 / 2 + 1;
		len[1] = i2 / 2;
		start[1] += i, start2[1] += i;
		// Gather up to 3 mismatched pairs across both halves into v.
		v.clear();
		getlen(start[0], start2[0], len[0]),getlen(start[1], start2[1], len[1]);
		if (v.size()==0)
		{
			// Both halves are already palindromes. Besides equal-letter swaps
			// (handled in main via flag), the swap tried here exchanges the two
			// centres, which exist when both halves have odd length.
			if ((i & 1) && (i2 & 1))
			{
				int x1 = i / 2 + 1, x2 = i2 / 2 + 1 + i;
				if (s2[x1] != s2[x2])
					adds(i / 2 + 1, i2 / 2 + 1 + i);
			}
			flag = 1;
			continue;
		}
		else {
			// One swap changes two positions, so it can fix at most two pairs.
			if (v.size() > 2)
				continue;
			if (v.size()==1)
			{
				// One mismatch: swap one of its ends with the centre of an odd-length
				// half; a centre mirrors onto itself, so its letter may change freely.
				if (i & 1)
					sw(i / 2 + 1, v[0]);
				if (i2 & 1)
					sw(i2 / 2 + 1 + i, v[0]);
			}
			else {
				// Two mismatches: swap one end of each pair. This works when both pairs
				// hold the same two letters, in either orientation.
				if (s2[v[0].first] == s2[v[1].first] && s2[v[0].second] == s2[v[1].second])
				{
					adds(v[0].first, v[1].second);
					adds(v[0].second, v[1].first);
				}
				if (s2[v[0].first] == s2[v[1].second] && s2[v[0].second] == s2[v[1].first])
				{
					adds(v[0].first, v[1].first);
					adds(v[0].second, v[1].second);
				}
			}
		}
	}
}
// Reads T test strings and prints, for each, the number of swaps (i < j) after
// which the string splits into two non-empty palindromes.
int main(void)
{
	// Initialised and checked so a missing count cannot drive the loop with an
	// indeterminate value.
	int t = 0;
	init();
	if (scanf("%d", &t) != 1)
		return 0;
	while (t--)
	{
		if (scanf("%s", s2 + 1) != 1)
			break;
		s2size = strlen(s2 + 1);
		memset(pos, 0, sizeof(pos));
		sethash();
		solve();
		if (flag)
		{
			// The string already splits into two palindromes, so every swap of two
			// equal letters (which leaves it unchanged) also counts: C(cnt, 2) each.
			for (int i = 1; i <= s2size; i++)
				pos[s2[i] - 'a']++;
			for (int i = 0; i < 26; i++)
				res += (pos[i] * (pos[i] - 1)) / 2;
		}
		printf("%lld\n", res);
	}
	return 0;
}

C++ 解法, 执行用时: 296ms, 内存消耗: 7584K, 提交时间: 2021-10-04 16:49:05

#include<bits/stdc++.h>
// Inclusive ascending / descending loops. Bounds are parenthesised so that
// expression arguments expand safely.
#define rep(i,x,y) for(int i=(x); i<=(y); ++i)
#define repd(i,x,y) for(int i=(x); i>=(y); --i)
// Binary-search midpoint; expects locals `l` and `r` at the use site.
// Explicit parentheses replace reliance on `+` binding tighter than `>>`.
#define mid (((l)+(r))>>1)
// Segment-tree child indices (not used by the code below).
#define lch ((rt)<<1)
#define rch ((rt)<<1|1)
#define pb push_back

using namespace std;
const int N=100005;
// Current string, 1-based: s[1..n].
char s[N];
// tot: number of mismatched pairs recorded in dat[]; cnt: per-letter counts.
int n,tot,cnt[N];
typedef long long LL;
const int base=19260817,mod=1000000007;
// Distinct swaps (smaller index, larger index) found for the current string.
set <pair <int,int>> st;
// b: powers of base mod mod; h1/h2: forward/backward rolling hashes; ans: answer.
LL b[N],h1[N],h2[N],ans;

// A mismatched mirror pair: 1-based positions x and y in s.
struct D
{
	int x, y;
};
D dat[10];

// NOTE: not positional equality. Two pairs compare equal when the letters at
// their x positions match and the letters at their y positions match.
bool operator == (D lhs,D rhs)
{
	if (s[lhs.x] != s[rhs.x])
		return false;
	return s[lhs.y] == s[rhs.y];
}

void init()
{
	b[0]=1;
	rep(i,1,n) b[i]=b[i-1]*base%mod;
	rep(i,1,n) h1[i]=(h1[i-1]*base+s[i]-'a')%mod;
	repd(i,n,1) h2[i]=(h2[i+1]*base+s[i]-'a')%mod;
}

// Hash of s[x..y] read left to right (s[x] most significant), normalised into
// [0, mod).
LL find1(int x,int y)
{
	const LL raw=(h1[y]-h1[x-1]*b[y-x+1])%mod;
	return raw<0 ? raw+mod : raw;
}

// Hash of s[x..y] read right to left (s[y] most significant), normalised into
// [0, mod).
LL find2(int x,int y)
{
	const LL raw=(h2[x]-h2[y+1]*b[y-x+1])%mod;
	return raw<0 ? raw+mod : raw;
}

bool check(int l1,int r1,int l2,int r2)
{
	return find1(l1,r1)==find2(l2,r2);
}

bool checkd()
{
	rep(i,1,n-1) if(check(1,i,1,i) && check(i+1,n,i+1,n)) return 1;
	return 0;
}

// Records the unordered swap {x, y} as (smaller, larger); the set removes
// duplicates reported by different split points.
void add(int x,int y)
{
	st.insert(x<y ? make_pair(x,y) : make_pair(y,x));
}

// Appends the mismatched mirror pair (x, y) to dat[1..tot], ordered so that
// s[x] <= s[y]. This canonical order lets operator==(D, D) detect two pairs
// made of the same two letters with a single comparison.
void ins(int x,int y)
{
	// Compare the letters directly. The old locals `a`/`b` shadowed the global
	// power table b[].
	if(s[x]>s[y]) swap(x,y);
	// Standard aggregate initialisation; `(D){x,y}` is a GNU compound-literal
	// extension, not ISO C++.
	dat[++tot]=D{x,y};
}

// Tests whether s[L..R] has at most two mismatched mirror pairs
// (s[L+k] != s[R-k]). Each mismatch encountered is recorded via ins(), so tot
// grows by at most 2 per call.
// Returns 1 if at most two mismatches exist, otherwise 0 (after recording the
// first two). The test is hash-based, so it is subject to collisions.
bool find(int L,int R)
{
	int x;
	// Already a palindrome: nothing to record.
	if(check(L,R,L,R)) return 1;
	rep(i,1,2)
	{
		// A matching run from the ends can be at most half the window long.
		x=(R-L+1)>>1;
		int l=1,r=x;
		// Binary search for the longest run with s[L+k] == s[R-k]; afterwards
		// l-1 is its length (`mid` expands to the midpoint of l and r).
		while(l<=r) check(L,L+mid-1,R-mid+1,R)?l=mid+1:r=mid-1;
		// Jump past the matching run to the first mismatched pair.
		--l,L+=l,R-=l;
		// Record the mismatch, then step inside it.
		ins(L,R),++L,--R;
		// The remaining window (possibly empty) is a palindrome: done.
		if(check(L,R,L,R)) return 1;
	}
	// A third mismatch exists.
	return 0;
}

// Reads one test string into s[1..n] and prints the number of swaps (i, j),
// i < j, after which s splits into two non-empty palindromes.
void solve()
{
	scanf("%s",s+1),n=strlen(s+1);
	rep(i,0,25) cnt[i]=0;
	ans=0,init();
	st.clear();
	rep(i,1,n) ++cnt[s[i]-'a'];
	// If s already splits into two palindromes, every swap of two equal letters
	// (which leaves s unchanged) counts: add C(cnt, 2) per letter.
	if(checkd()) rep(i,0,25) ans+=(LL)cnt[i]*(cnt[i]-1)/2;
	// Try every split s[1..i] | s[i+1..n]. Swaps of different letters that make
	// both halves palindromes are collected, deduplicated, in st.
	rep(i,1,n-1)
	{
		bool jdg=1;
		// Gather the mismatched mirror pairs of both halves into dat[1..tot].
		// Give up on this split if either half has more than two.
		tot=0,jdg&=find(1,i),jdg&=find(i+1,n);
		if(!jdg) continue;
		// tot of 3 or 4 matches no branch below: one swap changes only two
		// positions, so it can fix at most two pairs.
		// Two mismatches: swap one end of each pair. This works when both pairs
		// hold the same two letters (ins() orders each pair by letter, and
		// operator== on D compares letters).
		if(tot==2)
		{
			if(dat[1]==dat[2]) add(dat[1].x,dat[2].y),add(dat[1].y,dat[2].x);
		}
		// One mismatch: swap one of its ends with the centre of an odd-length
		// half. A centre mirrors onto itself, so its letter may change freely.
		if(tot==1)
		{
			if(i&1)
			{
				int x=(i+1)>>1;
				if(s[x]==s[dat[1].x]) add(x,dat[1].y);
				if(s[x]==s[dat[1].y]) add(x,dat[1].x);
			}
			if((n-i)&1)
			{
				int y=(n+i+1)>>1;
				if(s[y]==s[dat[1].x]) add(y,dat[1].y);
				if(s[y]==s[dat[1].y]) add(y,dat[1].x);
			}
		}
		// No mismatch: the swap tried here exchanges the two centres, which exist
		// when both halves are odd (i odd and n even).
		if(tot==0)
		{
			if(i&1 && !(n&1))
			{
				int x=(i+1)>>1,y=(n+i+1)>>1;
				if(s[x]!=s[y]) add(x,y);
			}	
		}
	}
	ans+=st.size();
	printf("%lld\n",ans);
}		

// Reads the number of test cases and solves each one.
int main()
{
	// Initialised and checked so a missing count cannot drive the loop with an
	// indeterminate value.
	int T=0;
	if(scanf("%d",&T)!=1) return 0;
	while(T--) solve();
	return 0;
}

上一题