列表

详情


NC252417. 少女曾见的日本原风景

描述

东风谷早苗想起了一道以前见过的ACM题,她打算做出这道题然后跟灵梦显摆。题目是这样的:

定义 w(s,t) 为有多少个不同的回文串同时在 st 中出现。定义字符串 sval 函数为:
val(s) = \sum_{i=1}^{|s|-1} w^2(s_{1, i}, s_{i+1, |s|})
对于给定字符串 s,需要分割成 k 个子段,最大的子段的 val 值最小。

显然早苗不会,于是这道题是你的了。

输入描述

第一行两个正整数,表示字符串长度 n 和字段数 k
第二行一个字符串 s

输出描述

一个整数,表示分割后最小的 val 值。

示例1

输入:

10 2
aabbaaabbb

输出:

4

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(clang++ 11.0.1) 解法, 执行用时: 916ms, 内存消耗: 24944K, 提交时间: 2023-05-25 12:32:50

#include <bits/stdc++.h>
using namespace std;
#define int long long
int n, k;
const int maxn = 1e6 + 10;
int f[maxn], l[maxn];
int d1[maxn];
struct PAM
{
    string s;
    std::vector<int> e[maxn];
    int fail[maxn];     // fail指针
    int len[maxn];      // 该节点表示的字符串长度
    int tree[maxn][26]; // 同Trie,指向儿子
    int trans[maxn];    // trans指针
    int tot = 1, pre;   // tot代表节点数,pre代表上次插入字符后指向的回文树位置
    int getfail(int x, int i)
    {
        while (i - len[x] - 1 < 0 || s[i - len[x] - 1] != s[i])
        {
            x = fail[x];
        }
        return x;
    }
    void init()
    {
        // 1是奇根,0是偶根
        for (int i = 0; i <= tot + 1; i++)
        {
            len[i] = 0;
            l[i] = f[i] = 0;
            e[i].clear();
            memset(tree[i], 0, sizeof(tree[i]));
            fail[i] = 0;
        }
        fail[0] = 1;
        e[1].push_back(0);
        tot = 1, pre = 0;
        len[1] = -1;
    }
    int gettrans(int x, int i)
    {
        while (((len[x] + 2) << 1) > len[tot] || s[i - len[x] - 1] != s[i])
            x = fail[x];
        return x;
    }
    void insert(int u, int i)
    {
        int Fail = getfail(pre, i); // 找到符合要求的点
        if (!tree[Fail][u])
        {
            // 没建过就新建节点
            len[++tot] = len[Fail] + 2;                  // 长度自然是父亲长度+2
            fail[tot] = tree[getfail(fail[Fail], i)][u]; // fail为满足条件的次短回文串+u
            e[fail[tot]].push_back(tot);
            tree[Fail][u] = tot; // 指儿子
            if (len[tot] <= 2)
                trans[tot] = fail[tot]; // 特殊trans
            else
            {
                int Trans = gettrans(trans[Fail], i); // 求trans
                trans[tot] = tree[Trans][u];
            }
            f[tree[Fail][u]] = i + 1;
            l[tree[Fail][u]] = i + 1;
        }
        pre = tree[Fail][u]; // 更新pre
        d1[i] = pre;
        l[pre] = max(l[pre], i + 1);
    }
    void dfs(int u, int fa)
    {
        for (auto x : e[u])
        {
            dfs(x, u);
            f[u] = min(f[u], f[x]);
            l[u] = max(l[u], l[x]);
        }
    }
} t;
int d[maxn];
int cal(string s)
{
    t.init();
    t.s = s;
    for (int i = 0; i < s.length(); i++)
    {
        d[i + 1] = 0;
        t.insert(s[i] - 'a', i);
    }
    t.dfs(1, 0);
    for (int i = 2; i <= t.tot; i++)
    {
        if (f[i] <= l[i] - t.len[i] + 1)
        {
            d[f[i]]++;
            d[l[i] - t.len[i] + 1]--;
        }
    }
    for (int i = 1; i <= s.length(); i++)
    {
        d[i] += d[i - 1];
    }
    int ans = 0;
    for (int i = 1; i < s.length(); i++)
    {
        ans += d[i] * d[i];
    }
    return ans;
}

string s;
bool check(int x)
{
    int tot = 0;
    for (int i = 0; i < n;)
    {
        int j = 0;
        int len = 1;
        while (i + len <= n && cal(s.substr(i, len)) <= x)
        {
            len *= 2;
        }
        while (len > 1) // 二进制拆分
        {
            len /= 2;
            if (i + j + len <= n && cal(s.substr(i, j + len)) <= x)
            {
                j += len;
            }
        }
        i += j;
        tot++;
        if (tot > k)
        {
            return 0;
        }
    }
    return 1;
}
signed main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);
    cin >> n >> k;
    cin >> s;
    int l = 0, r = 1e9;
    int ans = 0;
    while (l <= r)
    {
        int mid = l + r >> 1;
        if (check(mid))
        {
            r = mid - 1;
            ans = mid;
        }
        else
        {
            l = mid + 1;
        }
    }
    cout << ans;
    return 0;
}

C++(g++ 7.5.0) 解法, 执行用时: 716ms, 内存消耗: 1076K, 提交时间: 2023-05-19 21:24:53

#include <bits/stdc++.h>

using namespace std;

typedef long long i64;

const int N = 100005;
int n, k;
char a[N];

namespace pam {
int cnt, lt, fail[N], ch[N][26], len[N], mn[N], mx[N];

void build(char *a, int n) {
  cnt = 1, lt = 1;
  fail[1] = 0, fail[0] = 1, len[1] = -1;
  for (int i = 1; i <= n; i++) {
    while (i - 1 - len[lt] < 1 || a[i - 1 - len[lt]] != a[i]) lt = fail[lt];
    if (!ch[lt][a[i] - 'a']) {
      cnt++;
      mn[cnt] = i;
      for (int x = fail[lt];; x = fail[x])
        if (a[i - 1 - len[x]] == a[i]) {
          fail[cnt] = ch[x][a[i] - 'a'];
          break;
        }
      ch[lt][a[i] - 'a'] = cnt, len[cnt] = len[lt] + 2;
    }
    lt = ch[lt][a[i] - 'a'];
    mx[lt] = i;
  }
}

int d[N];

i64 calc(int n) {
  for (int i = cnt; i > 1; i--) {
    mx[fail[i]] = max(mx[fail[i]], mx[i]);
    if (mx[i] - len[i] >= mn[i]) {
      d[mn[i]]++, d[mx[i] - len[i] + 1]--;
    }
  }

  i64 sum = 0;
  for (int i = 1; i <= n; i++) {
    d[i] += d[i - 1];
    sum += 1ll * d[i] * d[i];
  }
  memset(d, 0, (n + 1) << 2);
  return sum;
}

void clear() {
  for (int i = 0; i <= cnt; i++) {
    fail[i] = len[i] = mn[i] = mx[i] = 0;
    memset(ch[i], 0, sizeof ch[i]);
  }
  cnt = 0;
}

bool check(int l, int r, i64 k) {
  if (r > n) {
    return 0;
  }

  build(a + l - 1, r - l + 1);
  i64 tp = calc(r - l + 1);
  clear();
  return tp <= k;
}
}  // namespace pam

int main() {
  ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);

  cin >> n >> k >> (a + 1);
  i64 l = 0, r = 0, ans = 0;
  int gu = (n - 1) / k + 1;
  for (int i = 1; i != gu; i++) {
    int t = min(i, gu - i);
    r += 1ll * t * t;
  }
  ans = r--;

  while (l <= r) {
    i64 mid = (l + r) >> 1;
    int s = 0;
    bool ok = 1;

    for (int i = 1, j; i <= n; i = j + 1, s++) {
      if (s == k) {
        ok = 0;
        break;
      }

      int t = 1;
      while (1) {
        if (pam::check(i, i + (1 << t) - 1, mid)) {
          t++;
          continue;
        }
        break;
      }

      j = i + (1 << (t - 1)) - 1;
      int l = j + 1, r = min(n, i + (1 << t) - 2);
      while (l <= r) {
        int m = (l + r) >> 1;
        if (pam::check(i, m, mid)) {
          j = m;
          l = m + 1;
        } else {
          r = m - 1;
        }
      }
    }

    if (ok) {
      ans = mid;
      r = mid - 1;
    } else {
      l = mid + 1;
    }
  }

  cout << ans << endl;
}

上一题