NC17895. [NOI2016]优秀的拆分
描述
如果一个字符串可以被拆分为 AABB的形式,其中 A 和 B 是任意非空字符串,则我们称该字符串的这种拆分是优秀的。
例如,对于字符串 aabaabaa,如果令 A=aab,B=a,我们就找到了这个字符串拆分成 AABB 的一种方式。
一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。比如我们令 A=a,B=baa,也可以用 AABB表示出上述字符串;但是,字符串 abaabaa 就没有优秀的拆分。
现在给出一个长度为 nn 的字符串 SS,我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。
以下事项需要注意:
输入描述
包含多组数据。第一行只有一个整数 T,表示数据的组数。保证 1≤T≤10
接下来 T 行,每行包含一个仅由英文小写字母构成的字符串 S,意义如题所述。
输出描述
输出 T 行,每行包含一个整数,表示字符串 S 所有子串的所有拆分中,总共有多少个是优秀的拆分。
示例1
输入:
4 aabbbb cccccc aabaabaabaa bbaabaababaaba
输出:
3 5 4 7
C++14(g++5.4) 解法, 执行用时: 352ms, 内存消耗: 20184K, 提交时间: 2019-10-17 19:42:12
#include<cstdio> #include<cstring> #include<algorithm> #include<iostream> using namespace std; #define maxn 100010 typedef long long ll; char s[maxn]; int n, c[maxn], x[maxn], y[maxn], f[maxn], g[maxn]; struct SuffixArray { int sa[maxn], rnk[maxn], lcp[maxn]; void build_sa() { int m = 122; memset(x, 0, sizeof(x)); memset(y, 0, sizeof(y)); memset(sa, 0, sizeof(sa)); memset(rnk, 0, sizeof(rnk)); memset(lcp, 0, sizeof(lcp)); for (int i = 1; i <= m; i++) c[i] = 0; for (int i = 1; i <= n; i++) ++c[x[i] = s[i]]; for (int i = 2; i <= m; i++) c[i] += c[i - 1]; for (int i = n; i >= 1; i--) sa[c[x[i]]--] = i; for (int k = 1; k <= n; k <<= 1) { int p = 0; for (int i = n - k + 1; i <= n; i++) y[++p] = i; for (int i = 1; i <= n; i++) if (sa[i]>k) y[++p] = sa[i] - k; for (int i = 1; i <= m; i++) c[i] = 0; for (int i = 1; i <= n; i++) ++c[x[y[i]]]; for (int i = 1; i <= m; i++) c[i] += c[i - 1]; for (int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i]; swap(x, y); p = 1; x[sa[1]] = 1; for (int i = 2; i <= n; i++) { x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p : ++p; } if (p >= n) break; m = p; } for (int i = 1; i <= n; i++) rnk[sa[i]] = i; for (int i = 1, k = 0; i <= n; i++) { if (k) k--; int j = sa[rnk[i] - 1]; while (s[i + k] == s[j + k]) k++; lcp[rnk[i]] = k; } } int st[maxn][20]; void build_st() { memset(st, 0, sizeof(st)); for (int i = 1; i <= n; i++) st[i][0] = lcp[i]; for (int j = 1; (1 << j) <= n; j++) { for (int i = 1; (i + (1 << j)) - 1 <= n; i++) { st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]); } } } int query(int l, int r) { int k = 0; l = rnk[l]; r = rnk[r]; if (l>r) swap(l, r); l++; //一定要l++ while ((1 << (k + 1)) <= r - l + 1) k++; return min(st[l][k], st[r - (1 << k) + 1][k]); } }A, B; int main() { int t; scanf("%d", &t); while (t--) { scanf("%s", s + 1); n = strlen(s + 1); A.build_sa(); A.build_st(); reverse(&s[1], &s[n + 1]); B.build_sa(); B.build_st(); memset(f, 0, sizeof(f)); memset(g, 0, sizeof(g)); for (int len = 1; len <= n / 2; len++) { for (int i = len; i <= n; i += len) { int j = i + len; int Lcp = min(A.query(i, j), len), Lcs = min(B.query(n - i + 2, n - j + 2), len - 1); int t = Lcp + Lcs - len + 1; if (Lcp + Lcs >= len) { g[i - Lcs]++, g[i - Lcs + t]--; f[j + Lcp - t]++, f[j + Lcp]--; } } } for (int i = 1; i <= n; i++) f[i] += f[i - 1], g[i] += g[i - 1]; ll ans = 0; for (int i = 1; i<n; i++) ans += 1ll*f[i] * g[i + 1]; printf("%lld\n", ans); } return 0; }
C++11(clang++ 3.9) 解法, 执行用时: 498ms, 内存消耗: 1304K, 提交时间: 2020-09-30 18:02:52
// luogu-judger-enable-o2 #include <bits/stdc++.h> using namespace std; typedef long long ll; const ll base = 223, mod = 1000000097; const int maxn = 30005; int maxp[maxn], n; ll ta[maxn], tb[maxn], ans; char s[maxn]; ll f[maxn], pw[maxn]; ll geth(int l, int r) { return ((f[r] - f[l - 1] * pw[r - l + 1] % mod) % mod + mod) % mod; } int main() { pw[0] = 1; for (int i = 1; i < maxn; i++) pw[i] = (pw[i - 1] * base) % mod; int T; scanf("%d", &T); while (T--) { memset(ta, 0, sizeof(ta)), memset(tb, 0, sizeof(tb)), ans = 0; scanf("%s", s + 1); n = strlen(s + 1); for (int i = 1; i <= n; i++) f[i] = (f[i - 1] * base + s[i]) % mod; for (int i = 1; (i << 1) <= n; i++) for (int j = 1; j + i <= n; j += i) { int l = 0, r = min(n - i - j + 1, i) + 1, p, q; while (l + 1 < r) { int mid = l + r >> 1; if (geth(j, j + mid - 1) == geth(j + i, j + i + mid - 1)) l = mid; else r = mid; } q = j + l - 1, l = 0, r = min(i, j) + 1; while (l + 1 < r) { int mid = l + r >> 1; if (geth(j - mid + 1, j) == geth(j + i - mid + 1, j + i)) l = mid; else r = mid; } p = max(j - l + 1, 1); if (q - p + 1 >= i) { ++tb[p], --tb[q - i + 2]; ++ta[p + 2 * i - 1], --ta[q + i + 1]; } } for (int i = 1; i <= n; i++) ta[i] += ta[i - 1], tb[i] += tb[i - 1]; for (int i = 1; i < n; i++) ans += ta[i] * tb[i + 1]; printf("%lld\n", ans); } return 0; }