NC20200. [JSOI2013]快乐的JYY
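描述
给定两个由大写字母组成的字符串A和B。用A(i,j)表示A中第i个到第j个字符组成的子串,B(x,y)同理。A和B的紧密程度定义为同时满足以下条件的4元组(i,j,x,y)的个数:1 ≤ i ≤ j ≤ |A|,1 ≤ x ≤ y ≤ |B|,A(i,j) = B(x,y),且A(i,j)是回文串。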
输入描述
输入数据包含两行,分别是由大写字母组成的字符串A和B
1 ≤ |A|,|B| ≤ 50000。
输出描述
包含一行一个整数,表示紧密程度,也就是满足要求的4元组个数
示例1
输入:
PUPPY PUPPUP
输出:
17
C++(g++ 7.5.0) 解法, 执行用时: 10ms, 内存消耗: 6452K, 提交时间: 2022-08-15 12:38:05
#include <array>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

namespace {

constexpr int kAlphabetSize = 26;  // the input alphabet is 'A'..'Z'
constexpr int kEvenRoot = 0;       // node of the empty palindrome (length 0)
constexpr int kOddRoot = 1;        // imaginary node of length -1, parent of single letters

/// Returns true when every character of `word` is an uppercase ASCII letter.
/// The tree indexes transitions with `ch - 'A'`, so anything else would be
/// an out-of-range access.
bool is_uppercase_word(const std::string& word) {
    for (const char ch : word) {
        if (ch < 'A' || ch > 'Z') {
            return false;
        }
    }
    return true;
}

/// Palindromic tree (eertree) of a single string over 'A'..'Z'.
///
/// Usage: `init(a)` builds the tree of `a`; `solve(b)` then counts the pairs
/// (occurrence of a palindrome in `a`, occurrence of the same palindrome in `b`).
/// Storage lives in vectors, so an instance is cheap to place on the stack.
struct pam {
    struct Node {
        std::array<int, kAlphabetSize> next{};  // next[c] == 0 means "no child"
        int len = 0;                            // length of the palindrome
        int fail = 0;                           // longest proper palindromic suffix
        long long cnt = 0;                      // times this node was the longest suffix palindrome
    };

    std::vector<Node> nodes;
    std::string text;       // text[0] is a sentinel; real characters start at index 1
    int last = kEvenRoot;   // node of the longest palindromic suffix of `text`

    /// Resets the tree to the two roots and an empty text.
    void clear() {
        nodes.assign(2, Node{});
        nodes[kEvenRoot].fail = kOddRoot;
        nodes[kOddRoot].len = -1;
        text.assign(1, '@');  // '@' never equals an uppercase letter
        last = kEvenRoot;
    }

    /// Walks fail links from `x` until the palindrome of the reached node can
    /// be extended on both sides by the last character of `text`.
    /// Always terminates at the odd root, whose length -1 makes the compared
    /// positions coincide.
    int getfail(int x) const {
        const int pos = static_cast<int>(text.size()) - 1;
        const char ch = text[pos];
        while (text[pos - nodes[x].len - 1] != ch) {
            x = nodes[x].fail;
        }
        return x;
    }

    /// Appends one uppercase letter `c` to the text and updates the tree.
    void add(char c) {
        text.push_back(c);
        const int letter = c - 'A';
        const int parent = getfail(last);
        if (nodes[parent].next[letter] == 0) {
            Node fresh;
            fresh.len = nodes[parent].len + 2;
            // Resolve the fail link before the new node is attached to
            // `parent`, so the lookup can never return the new node itself.
            fresh.fail = nodes[getfail(nodes[parent].fail)].next[letter];
            nodes.push_back(fresh);
            nodes[parent].next[letter] = static_cast<int>(nodes.size()) - 1;
        }
        last = nodes[parent].next[letter];
        ++nodes[last].cnt;
    }

    /// Rebuilds the tree from scratch for `str` (uppercase letters only).
    void init(const std::string& str) {
        clear();
        nodes.reserve(str.size() + 2);  // at most one new node per character
        text.reserve(str.size() + 1);
        for (const char ch : str) {
            add(ch);
        }
    }

    /// Counts pairs (occurrence in the built string, occurrence in `str`) of
    /// equal palindromic substrings. `str` must contain uppercase letters only.
    /// Does not modify the tree, so it may be called repeatedly.
    long long solve(const std::string& str) const {
        const int node_count = static_cast<int>(nodes.size());

        // weight[v]: total number of occurrences, in the built string, of the
        // palindrome of v and of every palindrome on its fail chain.
        std::vector<long long> weight(nodes.size());
        for (int v = 0; v < node_count; ++v) {
            weight[v] = nodes[v].cnt;
        }
        // A fail link always points to an earlier node, so a reverse sweep
        // turns "longest suffix" counts into full occurrence counts.
        for (int v = node_count - 1; v >= 2; --v) {
            weight[nodes[v].fail] += weight[v];
        }
        weight[kEvenRoot] = 0;
        weight[kOddRoot] = 0;
        for (int v = 2; v < node_count; ++v) {
            weight[v] += weight[nodes[v].fail];
        }

        long long total = 0;
        int cur = kOddRoot;
        for (std::size_t i = 0; i < str.size(); ++i) {
            const char ch = str[i];
            const int letter = ch - 'A';
            // Shrink until the current palindrome can be wrapped in `ch` on
            // both sides inside `str` AND that wrapped palindrome is in the tree.
            while (cur != kOddRoot) {
                const std::size_t len = static_cast<std::size_t>(nodes[cur].len);
                const bool mirrored = i > len && str[i - len - 1] == ch;
                if (mirrored && nodes[cur].next[letter] != 0) {
                    break;
                }
                cur = nodes[cur].fail;
            }
            // From the odd root this is the single letter, or the even root
            // (weight 0) when that letter never occurs in the built string.
            cur = nodes[cur].next[letter];
            total += weight[cur];
        }
        return total;
    }
};

}  // namespace

/// Reads strings A and B (whitespace separated) from stdin and prints the
/// number of 4-tuples on one line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string a;
    std::string b;
    std::cin >> a >> b;  // a missing token simply leaves the string empty
    if (!is_uppercase_word(a) || !is_uppercase_word(b)) {
        std::cerr << "input must consist of uppercase letters A-Z\n";
        return 1;
    }

    pam sol;
    sol.init(a);
    std::cout << sol.solve(b) << '\n';
    return 0;
}
C++(clang++ 11.0.1) 解法, 执行用时: 10ms, 内存消耗: 6296K, 提交时间: 2022-08-18 10:44:31
#include <array>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

namespace {

constexpr int kAlphabetSize = 26;  // the input alphabet is 'A'..'Z'
constexpr int kEvenRoot = 0;       // node of the empty palindrome (length 0)
constexpr int kOddRoot = 1;        // imaginary node of length -1, parent of single letters

/// One node of the palindromic tree shared by both input strings.
struct Node {
    std::array<int, kAlphabetSize> child{};  // child[c] == 0 means "no child"
    int len = 0;                             // length of the palindrome
    int fail = 0;                            // longest proper palindromic suffix
    std::array<long long, 2> cnt{};          // per string: times this was the longest suffix palindrome
};

/// Returns true when every character of `word` is an uppercase ASCII letter.
/// Transitions are indexed with `ch - 'A'`, so anything else would be an
/// out-of-range access.
bool is_uppercase_word(const std::string& word) {
    for (const char ch : word) {
        if (ch < 'A' || ch > 'Z') {
            return false;
        }
    }
    return true;
}

/// Creates a tree that contains only the even and the odd root.
std::vector<Node> make_tree() {
    std::vector<Node> tree(2);
    tree[kEvenRoot].fail = kOddRoot;
    tree[kOddRoot].len = -1;
    return tree;
}

/// Follows fail links from `u` until the palindrome of the reached node,
/// ending at position `p - 1` of `s`, can be extended by `s[p]` on both sides.
/// `s[0]` must be a sentinel that differs from every letter; the walk always
/// stops at the odd root at the latest.
int getfail(const std::vector<Node>& tree, const std::string& s, int u, int p) {
    while (s[p - tree[u].len - 1] != s[p]) {
        u = tree[u].fail;
    }
    return u;
}

/// Inserts all characters of `word` into the shared tree, counting each
/// position's longest suffix palindrome in slot `id` (0 or 1).
/// Nodes created by an earlier word are reused.
void insert_word(std::vector<Node>& tree, const std::string& word, std::size_t id) {
    const std::string s = '!' + word;  // 1-based text with a sentinel in front
    int last = kEvenRoot;              // every word starts from the empty palindrome
    for (int p = 1; p < static_cast<int>(s.size()); ++p) {
        const int c = s[p] - 'A';
        const int u = getfail(tree, s, last, p);
        if (tree[u].child[c] == 0) {
            Node fresh;
            fresh.len = tree[u].len + 2;
            // Resolve the fail link before attaching the node to `u`, so the
            // lookup can never return the new node itself.
            fresh.fail = tree[getfail(tree, s, tree[u].fail, p)].child[c];
            tree.push_back(fresh);
            tree[u].child[c] = static_cast<int>(tree.size()) - 1;
        }
        last = tree[u].child[c];
        ++tree[last].cnt[id];
    }
}

/// Sums, over every distinct palindrome, (occurrences in word 0) times
/// (occurrences in word 1). Leaves the tree untouched.
long long count_common(const std::vector<Node>& tree) {
    const int node_count = static_cast<int>(tree.size());
    std::vector<std::array<long long, 2>> occ(tree.size());
    for (int v = 0; v < node_count; ++v) {
        occ[v] = tree[v].cnt;
    }
    // A fail link always points to an earlier node, so one reverse sweep
    // turns "longest suffix" counts into full occurrence counts.
    for (int v = node_count - 1; v > 1; --v) {
        const int f = tree[v].fail;
        occ[f][0] += occ[v][0];
        occ[f][1] += occ[v][1];
    }
    long long ans = 0;
    for (int v = 2; v < node_count; ++v) {
        ans += occ[v][0] * occ[v][1];
    }
    return ans;
}

}  // namespace

/// Reads strings A and B (whitespace separated) from stdin and prints the
/// number of 4-tuples.
int main() {
    std::string a;
    std::string b;
    std::cin >> a >> b;  // a missing token leaves the string empty instead of reusing stale data
    if (!is_uppercase_word(a) || !is_uppercase_word(b)) {
        std::cerr << "input must consist of uppercase letters A-Z\n";
        return 1;
    }

    std::vector<Node> tree = make_tree();
    tree.reserve(a.size() + b.size() + 2);  // at most one new node per character
    insert_word(tree, a, 0);
    insert_word(tree, b, 1);

    std::cout << count_common(tree);
    return 0;
}