列表

详情


NC19786. Palindrome

描述

修修在蒜头送给他的奖杯上看到了一个长度为n的字符串s。
他希望从s中选择两个非空子串a,b(可以有重叠的部分),使得它们拼起来是一个回文串。
修修很快就算出了方案数,他听说你也会数数,就让你也来解决一下这个问题。两个方案不同当且仅当a,b中至少一个的长度或位置不同。

输入描述

第一行一个整数n (1 ≤ n ≤ 2*105),第二行一个字符串s。保证s只包含小写字母。

输出描述

输出一行一个整数表示方案数。

示例1

输入:

3
aba

输出:

16

示例2

输入:

10
abbaaababb

输出:

360

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++ 11.0.1) 解法, 执行用时: 144ms, 内存消耗: 78872K, 提交时间: 2022-12-11 14:50:52

#include <iostream>
#include <cstring>
using namespace std;

const int N=400005;
int pam[N][26],fail[N],num[N],len_1[N],pos[N],last_1,p_1,n;
int sam[N][26],link[N],len[N],sz[N],c[N],a[N],cnt,last_2,lenlen;
__int128 fsize[N],ans;
char s[N],rs[N];

int newnode(int l)
{
	memset(pam[p_1], 0, sizeof(pam[p_1]));
	num[p_1]=0;
	len_1[p_1]=l;
	return p_1++;
}

void init()
{
	last_1=p_1=0;
	newnode(0),newnode(-1);
	s[0]='?',rs[0]='?';
	fail[0]=1;
	
	last_2=cnt=1;
	memset(sam, 0, sizeof(sam));
	memset(sz, 0, sizeof(sz));
}

int get_fail(int x, int flag)
{
	if(flag)
	{
		while(s[n-len_1[x]-1]!=s[n])
			x=fail[x];
	}
	else
	{
		while(rs[n-len_1[x]-1]!=rs[n])
			x=fail[x];
	}
	return x;
}

void add_1(int c, int flag)
{
	int old=get_fail(last_1, flag);
	if(!pam[old][c])
	{
		int now=newnode(len_1[old]+2);
		fail[now]=pam[get_fail(fail[old], flag)][c];
		pam[old][c]=now;
		num[now]=num[fail[now]]+1;
	}
	last_1=pam[old][c];
	pos[n]=num[last_1];
}

void add_2(int c)
{
	int p,cur=++cnt;
	len[cur]=len[last_2]+1;
	for(p=last_2; p && !sam[p][c]; p=link[p])
		sam[p][c]=cur;
	if(!p)
		p=1,link[cur]=1;
	else
	{
		int q=sam[p][c];
		if(len[q]==len[p]+1)
			link[cur]=q;
		else
		{
			int cl=++cnt;
			len[cl]=len[p]+1;
			memcpy(sam[cl], sam[q], sizeof(sam[q]));
			link[cl]=link[q];
			while(p && sam[p][c]==q)
			{
				sam[p][c]=cl;
				p=link[p];
			}
			link[q]=link[cur]=cl;
		}
	}
	sz[cur]=1;
	last_2=cur;
}

void calc()
{
	memset(c, 0, sizeof(c));
	for(int i=1; i<=cnt; i++)
		c[len[i]]++;
	for(int i=1; i<=cnt; i++)
		c[i]+=c[i-1];
	for(int i=1; i<=cnt; i++)
		a[c[len[i]]--]=i;
	for(int i=cnt; i>=1; i--)
		sz[link[a[i]]]+=sz[a[i]];
	for(int i=1; i<=cnt; i++)
		fsize[a[i]]=fsize[link[a[i]]]+1ll*(len[a[i]]-len[link[a[i]]])*sz[a[i]];
}

void fun(char * ss, int addv)
{
	int p=1,now_len=0;
	for(int i=1; i<=lenlen; i++)
	{
		int c=ss[i]-'a';
		while(p && !sam[p][c])
			p=link[p];
		if(!p)
			now_len=0,p=1;
		else
		{
			now_len=min(now_len, len[p])+1;
			p=sam[p][c];
			ans+=(fsize[link[p]]+1ll*(now_len-len[link[p]])*sz[p])*(pos[i+1]+addv);
		}
	}
}

void print(__int128 x)
{
	if(x)
	{
		print(x/10);
		cout << (int)(x%10);
	}
}

int main()
{
	scanf("%d", &lenlen);
	scanf("%s", s+1);
	for(int i=1; i<=lenlen; i++)
		rs[i]=s[lenlen-i+1];
	init();
	for(n=1; n<=lenlen; n++)
		add_1(rs[n]-'a', 0);
	for(int i=1; i<=lenlen/2; i++)
		swap(pos[i], pos[lenlen-i+1]);
	for(int i=1; i<=lenlen; i++)
		add_2(rs[i]-'a');
	calc();
	fun(s, 1);
	
	init();
	for(n=1; n<=lenlen; n++)
		add_1(s[n]-'a', 1);
	for(int i=1; i<=lenlen/2; i++)
		swap(pos[i], pos[lenlen-i+1]);
	for(int i=1; i<=lenlen; i++)
		add_2(s[i]-'a');
	calc();
	fun(rs, 0);
	print(ans);
	puts("");
	return 0;
}

上一题