列表

详情


NC200545. 函数求和

描述

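给出一个长度为 $n$ 的序列 $a_1,a_2,\dots,a_n$ 。
定义 $f(l,r)=\sum_{i=0}^{r-l}\binom{r-l}{i}a_{l+i}$ (等价地,$f(l,l)=a_l$,当 $l<r$ 时 $f(l,r)=f(l,r-1)+f(l+1,r)$)。
需要回答 $q$ 个询问,每次询问给出一个区间 $[l,r]$ ,求 $f(l,r)\bmod 998244353$ 。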

输入描述

第一行两个整数 n,q ,表示序列长度和询问次数。
第二行 n 个整数,第 i 个数表示 a_i 。
接下来 q 行,每行两个整数 l, r ,表示询问的区间端点。
每次输入的 l,r 需要异或上 lastans ,lastans 表示上一个询问取模后的答案,最初 lastans = 0 。
保证输入满足原题的数据范围(原页面中该范围公式缺失),且异或后 $1\le l\le r\le n$ 。

输出描述

对于每个询问,输出一行一个整数,表示 $f(l,r)\bmod 998244353$ 。

示例1

输入:

5 2
0 1 2 3 4 
2 3
2 7

输出:

3
12

说明:

第二次询问输入 2 7 ,与上一次的答案 lastans = 3 异或后,实际询问区间为 [1,4] ,答案为 $\binom30\cdot0+\binom31\cdot1+\binom32\cdot2+\binom33\cdot3=12$ 。

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 1260ms, 内存消耗: 93584K, 提交时间: 2023-05-25 17:17:24

#include<bits/stdc++.h>
using namespace std;
// Promote every `int` in this solution to 64-bit so that products of two
// residues modulo 998244353 (e.g. ret*x%mod) cannot overflow. Because of this
// macro the program's entry point has to be spelled `signed main()`.
#define int long long
// Left / right child of node p in a 1-indexed heap-ordered binary tree.
// The argument is parenthesised so that e.g. ls(x&y) means ((x&y)<<1)
// rather than (x&(y<<1)).
#define ls(p) ((p)<<1)
#define rs(p) ((p)<<1|1)
// #define mid ((l+r)>>1)
// #define double long double
// Small epsilon constant (not referenced by the visible solve()).
#define eps (1e-15)
// Lowest set bit of i. The negated operand must be parenthesised too:
// the old form ((i)&(-i)) expanded lowbit(a+b) to ((a+b)&(-a+b)).
#define lowbit(i) ((i)&(-(i)))
// Modulus for every answer: the NTT-friendly prime 119 * 2^23 + 1.
const int mod=998244353;
// Presumably an "infinity" sentinel (not referenced by the visible solve());
// 1e15 fits only because `int` is macro-expanded to long long.
const int inf=1e15;
// Number-theoretic transform and formal-power-series helpers modulo 998244353.
// Every function below shares the mutable globals `limit` (current transform
// length), `len` (log2 of limit) and `pos` (bit-reversal permutation of
// 0..limit-1). mul() and poly_inv() recompute them before transforming and
// add() overwrites `limit`, so none of these helpers is reentrant.
namespace NTT
{
    // g = 3: root used by forward transforms; gi = 332748118 = 3^{-1} mod p:
    // root used by inverse transforms. N is the capacity of pos[], i.e. an
    // upper bound on the transform length.
    const int g=3,gi=332748118,N=5e6,mod=998244353;
    int limit=1,len;
    signed pos[N];
    // Returns x^k mod p by binary exponentiation (k >= 0).
    inline int fast(int x,int k)
    {
        int ret=1;
        while(k)
        {
            if(k&1) ret=ret*x%mod;
            x=x*x%mod;
            k>>=1;
        }
        return ret;
    }
    // Iterative radix-2 NTT of the first `limit` entries of (a copy of) a.
    //   inv != 0: forward transform with root g.
    //   inv == 0: inverse transform with root gi, after which every entry is
    //             multiplied by limit^{-1}.
    // Requires pos[] to hold the bit-reversal permutation for the current
    // limit, and a.size() >= limit.
    inline vector<int> ntt(vector<int> a,int inv)
    {
        for(int i=0;i<limit;++i)
            if(i<pos[i]) swap(a[i],a[pos[i]]);
        for(int mid=1;mid<limit;mid<<=1)
        {
            // Primitive (2*mid)-th root of unity (or its inverse).
            int Wn=fast(inv?g:gi,(mod-1)/(mid<<1));
            for(int r=mid<<1,j=0;j<limit;j+=r)
            {
                int w=1;
                for(int k=0;k<mid;++k,w=w*Wn%mod)
                {
                    int x=a[j+k],y=w*a[j+k+mid]%mod;
                    a[j+k]=(x+y)%mod;
                    a[j+k+mid]=(x-y+mod)%mod;
                }
            }
        }
        if(inv) return a;
        // Inverse transform: the parameter `inv` is reused to hold limit^{-1}.
        inv=fast(limit,mod-2);
        for(int i=0;i<limit;++i) a[i]=a[i]*inv%mod;
        return a;
    }
    // Formal derivative of a, returned with exactly n coefficients
    // (a is padded/truncated to n first; the top coefficient is set to 0).
    inline vector<int> deriva(vector<int> a,int n)
    {
        a.resize(n);
        for(int i=1;i<n;++i) a[i-1]=a[i]*i%mod;
        a[n-1]=0;
        return a;
    }
    // Formal antiderivative of a with zero constant term, returned with exactly
    // n coefficients (coefficients of a beyond index n-2 are dropped).
    inline vector<int> integral(vector<int> a,int n)
    {
        a.resize(n);
        for(int i=n-1;i;--i) a[i]=a[i-1]*fast(i,mod-2)%mod;
        a[0]=0;
        return a;
    }
    // Coefficient-wise sum a + b. n and m default to the sizes of a and b;
    // the result has max(n, m) coefficients. Side effect: sets `limit`.
    inline vector<int> add(vector<int> a,vector<int> b,int n=-1,int m=-1)
    {
        if(n==-1) n=a.size();
        if(m==-1) m=b.size();
        limit=max(n,m);
        a.resize(limit),b.resize(limit);
        for(int i=0;i<limit;++i) a[i]=(a[i]+b[i])%mod;
        return a;
    }
    // Product a * b via NTT, returned with n + m - 1 coefficients. n and m
    // default to a.size() and b.size(). Callers should pass values no smaller
    // than the real sizes, because longer inputs are not truncated to n/m
    // before transforming. Recomputes limit, len and pos.
    inline vector<int> mul(vector<int> a,vector<int> b,int n=-1,int m=-1)
    {
        if(n==-1) n=a.size();
        if(m==-1) m=b.size();
        limit=1,len=0;
        while(limit<n+m) limit<<=1,++len;
        a.resize(limit,0),b.resize(limit,0);
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        a=ntt(a,1),b=ntt(b,1);
        for(int i=0;i<limit;++i) a[i]=a[i]*b[i]%mod;
        vector<int> c=ntt(a,0);
        c.resize(n+m-1);
        return c;
    }
    // Multiplicative inverse of a modulo x^n. It first computes the inverse
    // modulo x^ceil(n/2), then applies one Newton step b <- b * (2 - a * b).
    // Requires a[0] != 0 (mod p). Overwrites limit, len and pos.
    inline vector<int> poly_inv(vector<int> a,int n)
    {
        if(n==1)
        {
            vector<int> b(1);
            b[0]=fast(a[0],mod-2);
            return b;
        }
        vector<int> b=poly_inv(a,(n+1)>>1);
        limit=1,len=0;
        while(limit<n+n) limit<<=1,++len;
        for(int i=0;i<limit;++i) pos[i]=(pos[i>>1]>>1)|((i&1)<<(len-1));
        a.resize(limit),b.resize(limit);
        vector<int> c;
        c.resize(limit);
        for(int i=0;i<n;++i) c[i]=a[i];
        for(int i=n;i<limit;++i) c[i]=b[i]=0;
        c=ntt(c,1);b=ntt(b,1);
        for(int i=0;i<limit;++i) b[i]=(2-c[i]*b[i]%mod+mod)%mod*b[i]%mod;
        b=ntt(b,0);
        b.resize(n);
        return b;
    }
    // ln(a) modulo x^n, computed as the integral of a' / a. The constant term
    // of the result is always 0, so this presumes a[0] == 1.
    inline vector<int> ln(vector<int> a,int n=-1)
    {
        if(n==-1) n=a.size();
        return integral(mul(deriva(a,n),poly_inv(a,n),n,n),n);
    }
    // exp(a) modulo x^n via Newton iteration b <- b * (1 - ln b + a), starting
    // from b = 1. Presumably requires a[0] == 0 (TODO confirm). It reads
    // a[0..n-1], so a.size() must be at least n.
    inline vector<int> exp(vector<int> a,int n=-1)
    {
        if(n==-1) n=a.size();
        if(n==1)
        {
            vector<int> b(1);
            b[0]=1;
            return b;
        }
        vector<int> b=exp(a,(n+1)>>1);
        vector<int> f=ln(b,n);
        f[0]=(a[0]+1-f[0]+mod)%mod;
        for(int i=1;i<n;++i) f[i]=(a[i]-f[i]+mod)%mod;
        b=mul(b,f,n,n);
        b.resize(n);
        return b;
    }
}
void solve()
{
    int n,m;
    cin>>n>>m;
    vector<int> a(n);
    for(auto &x:a) cin>>x;
    int cnt=2500;
    vector C(cnt+1,vector<int>(cnt+1));
    for(int i=0;i<=cnt;++i)
    {
        C[i][0]=1;
        for(int j=1;j<=i;++j) C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod;
    }
    auto tmp=C[cnt];
    vector<vector<int> > p(cnt+1);
    p[0]=a;
    for(int i=1,j=n-cnt;j>=0;j-=cnt,++i)
    {
        p[i]=NTT::mul(p[i-1],tmp);
        p[i].erase(p[i].begin(),p[i].begin()+cnt);
        p[i].resize(j);
    }
    int ans=0;
    while(m--)
    {
        int l,r;
        cin>>l>>r;
        l^=ans,r^=ans;ans=0;
        int len=(r-l)%cnt,d=(r-l)/cnt;
        for(int i=0;i<=len;++i)
        {
            // cout<<p[d][l+i-1]<<' '<<"!!!!"<<endl;
            ans+=p[d][l+i-1]*C[len][i]%mod;
        }
        ans%=mod;
        cout<<ans<<'\n';
    }
}
signed main()
{
    // Unsynchronised, untied iostreams for fast bulk I/O.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    // Single test case; read T from input instead for multi-test variants.
    int T=1;
    for(;T>0;--T) solve();
    return 0;
}
/*
h[0]=0
h[1]=


*/

C++14(g++5.4) 解法, 执行用时: 2253ms, 内存消耗: 47236K, 提交时间: 2020-01-11 10:47:40

#include<bits/stdc++.h>
using namespace std;
// 64-bit integer used for intermediate modular products.
using ll = long long;
// Expands to a container's full iterator range, e.g. sort(all(v)).
#define all(x) (x).begin(),(x).end()
// Array bound constant (1e6 + 7); not referenced by this solution's visible code.
const int N = 1000007;

namespace NTT {
	// NTT-friendly prime 7 * 17 * 2^23 + 1 = 998244353, with generator G = 3.
	const int mod = 7 * 17 << 23 | 1;
	const int G = 3;

	// (x + y) mod `mod` for operands already reduced into [0, mod).
	inline int add(int x, int y) {
		return (x += y) >= mod ? x - mod : x;
	}

	// a^b mod `mod` by square-and-multiply (b >= 0).
	int power_mod(int a, int b) {
		int res = 1;
		for(; b; b >>= 1, a = (long long) a * a % mod)
			if(b & 1) res = (long long) res * a % mod;
		return res;
	}

	// Lazily grown tables shared by every transform:
	//   roots[2^k + j] = w^j, with w = G^((mod-1) / 2^(k+1)) a 2^(k+1)-th root of unity
	//   rev[i]         = i with its low `nbase` bits reversed
	// nbase is the number of bits the tables currently cover.
	int nbase = 1;
	std::vector<int> roots = {0, 1};
	std::vector<int> rev = {0, 1};

	// Extends roots/rev so that transforms of length up to 2^zeros can run.
	void prepare(int zeros) {
		if(nbase >= zeros) return;
		roots.resize(1 << zeros);
		rev.resize(1 << zeros);
		for(int i = 0; i < (1 << zeros); i++) {
			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (zeros - 1));
		}
		while(nbase < zeros) {
			int z = power_mod(G, (mod - 1) >> (nbase + 1));
			for(int i = 1 << (nbase - 1); i < (1 << nbase); i++) {
				roots[i << 1] = roots[i];
				roots[i << 1 | 1] = (long long) roots[i] * z % mod;
			}
			nbase++;
		}
	}

	// In-place forward NTT of the first n entries of a (n must be a power of
	// two; defaults to a.size()). The output is in natural order. Running it on
	// an index-reversed spectrum yields the inverse transform up to a factor
	// of n, which multiply() relies on.
	void ntt(std::vector<int> &a, int n = -1) {
		if(n == -1) n = a.size();
		assert((n & (n - 1)) == 0);
		int zeros = __builtin_ctz(n);
		prepare(zeros);
		int shift = nbase - zeros;
		for(int i = 0; i < n; i++) {
			if(i < (rev[i] >> shift)) {
				std::swap(a[i], a[rev[i] >> shift]);
			}
		}
		for(int i = 1; i < n; i <<= 1) {
			for(int j = 0; j < n; j += i * 2) {
				for(int k = 0; k < i; k++) {
					int z = (long long) a[i + j + k] * roots[i + k] % mod;
					a[i + j + k] = add(a[j + k], mod - z);
					a[j + k] = add(a[j + k], z);
				}
			}
		}
	}

	// Scratch buffers reused across multiply() calls to avoid reallocation.
	std::vector<int> ta, tb;

	// Returns the coefficients of a * b, of length a.size() + b.size() - 1.
	// Coefficients are expected to lie in [0, mod). An empty factor yields an
	// empty product.
	std::vector<int> multiply(const std::vector<int> &a, const std::vector<int> &b) {
		if(a.empty() || b.empty()) return {};
		int l1 = a.size(), l2 = b.size(), need = l1 + l2 - 1;
		// Smallest power of two >= need. need == 1 is handled separately
		// because __builtin_clz(0) is undefined behaviour.
		int sz = need == 1 ? 1 : 1 << (32 - __builtin_clz(need - 1));
		if((int) ta.size() < sz) {
			ta.resize(sz);
			tb.resize(sz);
		}
		std::copy(a.begin(), a.end(), ta.begin());
		std::copy(b.begin(), b.end(), tb.begin());
		std::fill(ta.begin() + l1, ta.begin() + sz, 0);
		std::fill(tb.begin() + l2, tb.begin() + sz, 0);
		ntt(ta, sz), ntt(tb, sz);
		int radio = power_mod(sz, mod - 2);
		// Pointwise product scaled by 1/sz, stored at index (sz - i) mod sz so
		// that the following forward transform acts as the inverse transform.
		for(int i = 0; i <= (sz >> 1); i++) {
			int j = (sz - 1) & (sz - i);
			int z = (long long) ta[j] * tb[j] % mod * radio % mod;
			if(i != j) ta[j] = (long long) ta[i] * tb[i] % mod * radio % mod;
			ta[i] = z;
		}
		ntt(ta, sz);
		return std::vector<int>(ta.begin(), ta.begin() + need);
	}
}
using namespace NTT;

// Globals shared with main(): the input sequence `a`, the precomputed
// layers `z[d]`, and the binomial row `t` used to advance one layer.
vector<int> a;
vector<int> z[600];
vector<int> t;

// Binomial table dimension; main() fills c[i][j] with C(i, j) mod p.
const int SZ = 2500;
int c[SZ][SZ];

// Binomial coefficient lookup: 0 whenever m lies outside [0, n].
// n must be below SZ when 0 <= m <= n.
int C(int n, int m) {
    const bool inTriangle = m >= 0 && m <= n;
    return inTriangle ? c[n][m] : 0;
}

int main() {
#ifdef local
	freopen("in.txt", "r", stdin);
	//freopen("out2.txt", "w", stdout);
#endif
	ios::sync_with_stdio(false);
	cin.tie(0), cout.tie(0);
	int n, q; cin >> n >> q;
    c[0][0] = 1;
    for(int i = 1; i < SZ; i++) {
        c[i][0] = 1;
        for(int j = 1; j <= i; j++) {
            c[i][j] = add(c[i - 1][j - 1], c[i - 1][j]);
        }
    }

    a.resize(n);
    for(int i = 0; i < n; i++) cin >> a[i];
    int B = 2.0 * sqrt(n * log(n));
    t = vector<int>(c[B], c[B] + B + 1);
    z[0] = a;
    for(int sz = B; sz < n; sz += B) {
        int id = sz / B;
        z[id] = multiply(z[id - 1], t);
        z[id].erase(z[id].begin(), z[id].begin() + B);
        z[id].resize(n - sz);
    }
    int lastans = 0;
    while(q--) {
        int l, r; cin >> l >> r; l ^= lastans, r ^= lastans;
        int sz = (r - l) / B, re = (r - l) % B;
        l--;
//        cout << l << ' ' << r << ' '<< sz << ' '<< re <<endl;
        int res = 0;
        for(int i = 0; i <= re; i++) {
            res = (res + (ll) z[sz][i + l] * C(re, i)) % mod;
        }
        cout << (lastans = res) << '\n';
    }

	return 0;
}

C++11(clang++ 3.9) 解法, 执行用时: 1756ms, 内存消耗: 45016K, 提交时间: 2020-02-16 14:16:53

#include <bits/stdc++.h>
using namespace std;
// Polynomial stored as a coefficient vector, lowest degree first.
typedef vector<int> Poly;
typedef long long ll;
// maxn: capacity of the input array a.  sz: block size (a query span is split
// into a multiple of sz plus a remainder < sz).  maxk: number of layers p[].
// P: modulus, with g = 3 used as the generator for the roots of unity.
// N = 2^18: longest transform length for which main() precomputes roots;
// K = 17: index of the top root in w/rw.
const int maxn = 131073, sz = 2500, maxk = 54, P = 998244353, g = 3, N = 262144, K = 17;
// w[d] / rw[d]: root of unity (and its inverse) for the NTT stage of half-length
// 2^d.  c[i][j] = C(i, j) mod P for j <= i <= sz.  l, r, o, li: current query
// bounds, remainder and layer index (all assigned in main).
int n, m, w[22], rw[22], a[maxn], c[sz + 1][sz + 3], l, r, li, o;
// trans: row sz of Pascal's triangle.  p[k][x] = sum_j C(k*sz, j) * a[x+j] mod P.
Poly trans, p[maxk];
// Returns p^e mod P by binary exponentiation (e >= 0).
int Pow(ll p, int e) {
  ll result = 1;
  while (e != 0) {
    if (e & 1) result = result * p % P;
    p = p * p % P;
    e >>= 1;
  }
  return result;
}
// In-place iterative radix-2 NTT of the first n entries of a. n must be a
// power of two and at most N, since roots are only precomputed up to w[K].
// t > 0 selects the roots w[] (forward); any other t selects rw[], and t < 0
// additionally scales the result by n^{-1} (inverse transform).
void NTT(Poly &a, int n, int t) {
  // Bit-reversal permutation; j tracks the bit-reverse of i incrementally.
  for (int i = 1, j = 0; i < n - 1; ++i) {
    int s = n;
    do j ^= s >>= 1; while (~j & s);
    if (i < j) swap(a[i], a[j]);
  }
  for (int d = 0; (1 << d) < n; d++) {
    const int half = 1 << d, step = half << 1;
    const int root = t > 0 ? w[d] : rw[d];
    for (int i = 0; i < n; i += step)
      for (int k = 0, cur = 1; k < half; ++k) {
        int &hi = a[i + k + half], &lo = a[i + k];
        const int prod = (ll)cur * hi % P;
        hi = lo - prod; if (hi < 0) hi += P;
        lo += prod; if (lo >= P) lo -= P;
        cur = (ll)cur * root % P;
      }
  }
  // Only the inverse transform needs n^{-1}; forward transforms skip the
  // modular exponentiation entirely.
  if (t < 0) {
    const int inv_n = Pow(n, P - 2);
    for (int i = 0; i < n; ++i) a[i] = (ll)a[i] * inv_n % P;
  }
}
// Product of two polynomials modulo P, returned with exactly
// x.size() + y.size() - 1 coefficients (empty if either factor is empty).
// That length rounded up to a power of two must not exceed N.
Poly operator * (const Poly &x, const Poly &y) {
  // Also prevents x.size() + y.size() - 1 from underflowing to SIZE_MAX.
  if (x.empty() || y.empty()) return Poly();
  static Poly a, b;  // reused across calls to avoid reallocations
  const size_t need = x.size() + y.size() - 1;
  int n = 1;
  while ((size_t)n < need) n <<= 1;
  a = x, b = y, a.resize(n), b.resize(n);
  NTT(a, n, 1), NTT(b, n, 1);
  for (int i = 0; i < n; ++i) a[i] = (ll)a[i] * b[i] % P;
  NTT(a, n, -1);
  // Drop the zero padding above degree need - 1.
  return Poly(a.begin(), a.begin() + need);
}
int main() {
  w[K] = Pow(g, (P - 1) / N), rw[K] = Pow(w[K], P - 2);
  for (int i = K - 1; ~i; --i)
    w[i] = (ll)w[i + 1] * w[i + 1] % P, rw[i] = (ll)rw[i + 1] * rw[i + 1] % P;
  scanf("%d%d", &n, &m);
  for (int i = 0; i < n; ++i)
    scanf("%d", a + i);
  for (int i = 0; i <= sz; ++i)
    for (int j = c[i][0] = 1; j <= i; ++j) {
      c[i][j] = c[i - 1][j - 1] + c[i - 1][j];
      if (c[i][j] >= P) c[i][j] -= P;
    }
  trans = Poly(c[sz], c[sz] + sz + 1), p[0] = Poly(a, a + n);
  for (int i = 1, j = n - sz; j >= 0; ++i, j -= sz) {
    p[i] = p[i - 1] * trans;
    p[i].erase(p[i].begin(), p[i].begin() + sz), p[i].resize(j);
  }
  ll ans = 0;
  while (m--) {
    scanf("%d%d", &l, &r);
    l ^= ans, r ^= ans, ans = 0;
    o = (r - l) % sz, li = (r - l) / sz;
    for (int i = 0; i <= o; ++i)
      ans += (ll)p[li][l - 1 + i] * c[o][i] % P;
    printf("%lld\n", ans %= P);
  }
  return 0;
}

上一题