列表

详情


NC17895. [NOI2016]优秀的拆分

描述

如果一个字符串可以被拆分为 AABB的形式,其中 A 和 B 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。

例如,对于字符串 aabaabaa,如果令 A=aab,B=a,我们就找到了这个字符串拆分成 AABB 的一种方式。

一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。比如我们令 A=a,B=baa,也可以用 AABB表示出上述字符串;但是,字符串 abaabaa 就没有优秀的拆分。

现在给出一个长度为 nn 的字符串 SS,我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。

以下事项需要注意:

  1. 出现在不同位置的相同子串,我们认为是不同的子串,它们的优秀拆分均会被记入答案。
  2. 在一个拆分中,允许出现 A=B。例如 cccc 存在拆分 A=B=c。
  3. 字符串本身也是它的一个子串。

输入描述

包含多组数据。第一行只有一个整数 T,表示数据的组数。保证 1≤T≤10

接下来 T 行,每行包含一个仅由英文小写字母构成的字符串 S,意义如题所述。

输出描述

输出 T 行,每行包含一个整数,表示字符串 S 所有子串的所有拆分中,总共有多少个是优秀的拆分。

示例1

输入:

4
aabbbb
cccccc
aabaabaabaa
bbaabaababaaba

输出:

3
5
4
7

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++14(g++5.4) 解法, 执行用时: 352ms, 内存消耗: 20184K, 提交时间: 2019-10-17 19:42:12

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
using namespace std;
#define maxn 100010
typedef long long ll;
char s[maxn];
int n, c[maxn], x[maxn], y[maxn], f[maxn], g[maxn];


struct SuffixArray 
{
	int sa[maxn], rnk[maxn], lcp[maxn];

	void build_sa() 
	{
		int m = 122;
		memset(x, 0, sizeof(x));
		memset(y, 0, sizeof(y));
		memset(sa, 0, sizeof(sa));
		memset(rnk, 0, sizeof(rnk));
		memset(lcp, 0, sizeof(lcp));
		for (int i = 1; i <= m; i++) 
			c[i] = 0;
		for (int i = 1; i <= n; i++)
			++c[x[i] = s[i]];
		for (int i = 2; i <= m; i++)
			c[i] += c[i - 1];
		for (int i = n; i >= 1; i--) 
			sa[c[x[i]]--] = i;
		for (int k = 1; k <= n; k <<= 1)
		{
			int p = 0;
			for (int i = n - k + 1; i <= n; i++) 
				y[++p] = i;
			for (int i = 1; i <= n; i++) 
				if (sa[i]>k)
					y[++p] = sa[i] - k;
			for (int i = 1; i <= m; i++)
				c[i] = 0;
			for (int i = 1; i <= n; i++)
				++c[x[y[i]]];
			for (int i = 1; i <= m; i++)
				c[i] += c[i - 1];
			for (int i = n; i >= 1; i--)
				sa[c[x[y[i]]]--] = y[i];
			swap(x, y); 
			p = 1; 
			x[sa[1]] = 1;
			for (int i = 2; i <= n; i++)
			{
				x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p : ++p;
			}
			if (p >= n) break;
			m = p;
		}
		for (int i = 1; i <= n; i++) 
			rnk[sa[i]] = i;
		for (int i = 1, k = 0; i <= n; i++)
		{
			if (k) k--;
			int j = sa[rnk[i] - 1];
			while (s[i + k] == s[j + k])
				k++;
			lcp[rnk[i]] = k;
		}
	}
	int st[maxn][20];
	void build_st() 
	{
		memset(st, 0, sizeof(st));
		for (int i = 1; i <= n; i++) 
			st[i][0] = lcp[i];
		for (int j = 1; (1 << j) <= n; j++) 
		{
			for (int i = 1; (i + (1 << j)) - 1 <= n; i++) 
			{
				st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
			}
		}
	}
	int query(int l, int r) 
	{
		int k = 0; l = rnk[l]; 
		r = rnk[r];
		if (l>r)
			swap(l, r); 
		l++; //一定要l++ 
		while ((1 << (k + 1)) <= r - l + 1) 
			k++;
		return min(st[l][k], st[r - (1 << k) + 1][k]);
	}
}A, B;
int main()
{
	int t; 
	scanf("%d", &t);
	while (t--) 
	{
		scanf("%s", s + 1);
		n = strlen(s + 1);
		A.build_sa();
		A.build_st();
		reverse(&s[1], &s[n + 1]);
		B.build_sa();
		B.build_st();
		memset(f, 0, sizeof(f));
		memset(g, 0, sizeof(g));
		for (int len = 1; len <= n / 2; len++) 
		{
			for (int i = len; i <= n; i += len)
			{
				int j = i + len;
				int Lcp = min(A.query(i, j), len), Lcs = min(B.query(n - i + 2, n - j + 2), len - 1);
				int t = Lcp + Lcs - len + 1;
				if (Lcp + Lcs >= len) 
				{
					g[i - Lcs]++, g[i - Lcs + t]--;
					f[j + Lcp - t]++, f[j + Lcp]--;
				}
			}
		}
		for (int i = 1; i <= n; i++) 
			f[i] += f[i - 1], g[i] += g[i - 1];
		ll ans = 0;
		for (int i = 1; i<n; i++) 
			ans += 1ll*f[i] * g[i + 1];
		printf("%lld\n", ans);
	}
	return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 498ms, 内存消耗: 1304K, 提交时间: 2020-09-30 18:02:52

// luogu-judger-enable-o2
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const ll base = 223, mod = 1000000097;
const int maxn = 30005;
int maxp[maxn], n;
ll ta[maxn], tb[maxn], ans;
char s[maxn];
ll f[maxn], pw[maxn];
ll geth(int l, int r) { return ((f[r] - f[l - 1] * pw[r - l + 1] % mod) % mod + mod) % mod; }
int main() {
	pw[0] = 1;
	for (int i = 1; i < maxn; i++) pw[i] = (pw[i - 1] * base) % mod;
	int T;
	scanf("%d", &T);
	while (T--) {
		memset(ta, 0, sizeof(ta)), memset(tb, 0, sizeof(tb)), ans = 0;
		scanf("%s", s + 1);
		n = strlen(s + 1);
		for (int i = 1; i <= n; i++) f[i] = (f[i - 1] * base + s[i]) % mod;
		for (int i = 1; (i << 1) <= n; i++)
			for (int j = 1; j + i <= n; j += i) {
				int l = 0, r = min(n - i - j + 1, i) + 1, p, q;
				while (l + 1 < r) {
					int mid = l + r >> 1;
					if (geth(j, j + mid - 1) == geth(j + i, j + i + mid - 1)) l = mid;
					else r = mid;
				}
				q = j + l - 1, l = 0, r = min(i, j) + 1;
				while (l + 1 < r) {
					int mid = l + r >> 1;
					if (geth(j - mid + 1, j) == geth(j + i - mid + 1, j + i)) l = mid;
					else r = mid;
				}
				p = max(j - l + 1, 1);
				if (q - p + 1 >= i) {
					++tb[p], --tb[q - i + 2];
					++ta[p + 2 * i - 1], --ta[q + i + 1];
				}
			}
		for (int i = 1; i <= n; i++) ta[i] += ta[i - 1], tb[i] += tb[i - 1];
		for (int i = 1; i < n; i++) ans += ta[i] * tb[i + 1];
		printf("%lld\n", ans);
	}
	return 0;
}

上一题