列表

详情


NC249078. (A+B)^N%P Problem

描述

输入一个字符串,为一个代数式,请把代数式拆分(运算)。

每一项的系数 \%p ,如系数 \%p 后等于 0 则省略输出整一项,如系数 \%p 后等于 1 则省略输出系数。

每一项按照第一个字母降幂输出,如为 1 次则省略 "^1",不可省略乘号。

如果答案中只有一项,省略括号。

字符串格式仅为 "(ax+by)^n%p"。

其中 x,y 为可以为26个小写字母中任意一个字符, a,b,n,p 为正整数,若 n=1 ,则 "^n" 省略输入。如 a,b 等于 1 ,则输入中会省略。

1\le n \le 10^31\le p \le 10^91\le a,b \le 10^6

输入描述

一行一个字符串,格式为 "(ax+by)^n%p"。
具体见样例。

输出描述

一行一个字符串,格式为 "(ax+by)^n%p = 答案"。
具体见样例。

示例1

输入:

(2x+3y)^2%5

输出:

(2x+3y)^2%5 = (4*x^2+2*x*y+4*y^2)%5

说明:

(2x+3y)^2=4x^2+12xy+9y^2
4\%5=4,12\%5=2,9\%5=4

示例2

输入:

(2x+3y)^2%6

输出:

(2x+3y)^2%6 = (4*x^2+3*y^2)%6

说明:

(2x+3y)^2=4x^2+12xy+9y^2
4\%6=4
12\%6=0,故输出中省略该项。
9\%6=3

示例3

输入:

(c+2c)^3%4

输出:

(c+2c)^3%4 = 3*c^3%4

说明:

(c+2c)^3=(3c)^3=27c^3
27\%4=3,由于答案中只有一项,省略括号。

示例4

输入:

(2a+2b)^2%2

输出:

(2a+2b)^2%2 = 0

说明:

取模后所有项系数均为 0,故输出 0

示例5

输入:

(3c+6d)%4

输出:

(3c+6d)%4 = (3*c+2*d)%4

说明:

省略输入乘方即代表一次方,n=1

示例6

输入:

(x+2y)^3%7

输出:

(x+2y)^3%7 = (x^3+6*x^2*y+5*x*y^2+y^3)%7

说明:

(x+2y)^3=x^3+6x^2y+12xy^2+8y^3

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

C++(g++ 7.5.0) 解法, 执行用时: 90ms, 内存消耗: 46788K, 提交时间: 2023-03-24 20:32:43

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
const int N = 1e6 + 10;
ll p, a, b, k;
int n;
char s[N], x, y;
ll C[3010][3010];
ll ksm(ll a, ll b) {
	ll res = 1;
	while (b) {
		if (b & 1) res = res * a % p;
		b >>= 1;
		a = a * a % p;
	}
	return res;
}
void solve() {
	scanf("%s", s + 1);
	n = strlen(s + 1);
	int r, bf;
	for (r = 1; r <= n; r++)
		if (s[r] == ')')break;
	for (bf = 1; bf <= n; bf++)
		if (s[bf] == '%')break;
	p = k = 0;
	for (int i = bf + 1; i <= n; i++) {
		p = p * 10 + s[i] - '0';
	}
	if (s[r + 1] != '^') k = 1;
	else {
		for (int i = r + 2; i < bf; i++) {
			k = k * 10 + s[i] - '0';
		}
	}
	int pos;
	for (pos = 2; ; pos++) {
		if (s[pos] < '0' || s[pos] > '9') {
			x = s[pos];
			break;
		}
		a = a * 10 + s[pos] - '0';
	}
	a = max(a, 1ll);
	for (pos = pos + 2; ; pos++) {
		if (s[pos] < '0' || s[pos] > '9') {
			y = s[pos];
			break;
		}
		b = b * 10 + s[pos] - '0';
	}
	b = max(b, 1ll);
	a %= p, b %= p;
	if (x == y) {
		a = a + b;
		a = ksm(a, k);
		if (a == 0) {
			printf("%s = 0\n", s + 1);
			return;
		}
		printf("%s = ", s + 1);
		if (a != 1) printf("%lld*", a);
		printf("%c", x);
		if (k != 1) printf("^%lld", k);
		printf("%%%lld\n", p);
        return;
	}
	for (int i = 0; i <= 3000; i++) C[i][0] = 1;
	for (int i = 1; i <= 3000; i++) {
		for (int j = 1; j <= i; j++) {
			C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % p;
		}
	}
	vector<ll> v;
	int cnt = 0;
	for (int i = k; i >= 0; i--) {
		ll res = C[k][i] * ksm(a, i) % p * ksm(b, k - i) % p;
		v.push_back(res);
		if (res != 0) cnt++;
	}
	if (cnt == 0) {
		printf("%s = 0\n", s + 1);
		return;
	}
	printf("%s = ", s + 1);
	if (cnt != 1) printf("(");
	bool flag = false;
	for (int i = 0; i < v.size(); i++) {
		if (v[i] == 0) continue;
		if (flag) printf("+");
		flag = true;
		bool tmp = false;
		if (v[i] != 1) printf("%lld", v[i]), tmp = true;
		if (i != k) {
			if (tmp == true) printf("*");
			printf("%c", x);
			if (i != k - 1) printf("^%d", k - i);
			tmp = true;
		}
		if (i != 0) {
			if (tmp == true) printf("*");
			printf("%c", y);
			if (i != 1) printf("^%d", i);
		}
	}
	if(cnt != 1) printf(")");
	printf("%%%lld\n", p);
}
int main() {
	int _ = 1;
	//cin >> _;
	while (_--) solve();
	return 0;
}

Python3 解法, 执行用时: 331ms, 内存消耗: 4988K, 提交时间: 2023-03-24 21:07:02

# (ax+by)^n%p

def C(n : int, m : int) -> int:
    if m > n:
        return 0
    res = 1
    for i in range(m):
        res = res * (n - i) // (i + 1)
    return res

def pw(x : str, y : int) -> str:
    if y == 0:
        return ""
    if y == 1:
        return x
    return x + '^' + str(y)
def poly(c : int, ch1 : str, ch2 : str, x : int, y : int) -> str:
    res = ""
    if c != 1:
        res = str(c)
    X = pw(ch1, x)
    if len(X) > 0:
        if len(res) == 0:
            res = X
        else:
            res += "*" + X
    Y = pw(ch2, y)
    if len(Y) > 0:
        if len(res) == 0:
            res = Y
        else:
            res += "*" + Y
    return res

if __name__ == '__main__':
    inp = input()

    ch1 = inp[inp.find('+') - 1]
    ch2 = inp[inp.find(')') - 1]
    # print(ch1, ch2)

    p1 = inp.find('(')
    p2 = inp.find(ch1)
    a = 1
    if p1 + 1 != p2:
        a = int(inp[p1 + 1:p2])
    p3 = inp.find('+')
    p4 = inp.find(ch2, p3 + 1)
    b = 1
    if p3 + 1 != p4:
        b = int(inp[p3 + 1:p4])
    p5 = inp.find(')')
    p6 = inp.find('%')
    n = 1
    if p5 + 1 != p6:
        n = int(inp[p5 + 2:p6])
    p = int(inp[p6 + 1:])
    # print(a, b, n, p)

    ans = "0"
    if ch1 == ch2:
        a += b
        # print(a, n, p)

        con = pow(a, n, p)
        if con > 0:
            ans = ""
            if con != 1:
                ans = str(con) + "*"

            if n == 1:
                ans += ch1 + '%' + str(p)
            else:
                ans += ch1 + '^' + str(n) + '%' + str(p)

    else:
        cons = []
        cnt = 0
        for i in range(n + 1):
            cons.append(C(n, i) * pow(a, n - i, p) * pow(b, i, p) % p)
            if cons[-1] != 0:
                if cnt != 0:
                    ans += "+" + poly(cons[-1], ch1, ch2, n - i, i)
                else:
                    ans = poly(cons[-1], ch1, ch2, n - i, i)
                cnt += 1

        if cnt > 1:
            ans = "(" + ans + ")"
        if cnt > 0:
            ans += "%" + str(p)

    print(inp + " = " + ans)

C++(clang++ 11.0.1) 解法, 执行用时: 12ms, 内存消耗: 7708K, 提交时间: 2023-03-30 18:14:27

#include <bits/stdc++.h>

#define ios ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)

#define int long long

#define debug(a) cout<<#a<<": "<<a<<"\n"

using namespace std;

const int N=1e3+10;

string a,res;

int _2=0,_5=0,n=0,p=0;

int len;

int fpow(int a,int b) {

	int res=1,tmp=a%p;	while(b) {

		if(b&1)res=res*tmp%p;

		tmp=tmp*tmp%p;

		b>>=1;

	}

	return res;

}

int c[N][N];

void init() {

	c[0][0]=1;

	for(int i=1;i<N;i++){

		for(int j=0;j<=i;j++){

			if(j==0)c[i][j]=1;

			else c[i][j]=(c[i-1][j-1]+c[i-1][j])%p;

		}

	}

	

}

signed main() {

	ios;

	cin>>a;

	len=a.size();

	char x,y;

	int i=0,j;

	for(;i<len;i++)if(islower(a[i]))break;

	if(i==1)_2=1;

	else _2=stoi(a.substr(1,i-1));

	x=a[i];

	j=i+1;

	i++; 

	

	for(;i<len;i++)if(islower(a[i]))break;

	if(i==j+1)_5=1;

	else _5=stoi(a.substr(j+1,i-j));

	y=a[i];

	j=i+1;

	i++;

	

	for(;i<len;i++)if(a[i]=='%')break;

	if(i==j+1)n=1;

	else n=stoi(a.substr(j+2,i-j-2));

	j=i+1;

	

// 	cout<<i<<"\n";

	p=stoi(a.substr(i+1));

	

	if(x==y) {

		int f=fpow(_2+_5,n);

		if(f==0)return cout<<a+" = 0\n",0;

		else if(f==1)res+=x;

		else res+=to_string(f)+'*'+x;

		

		if(n>1)res+='^'+to_string(n);

		cout<<a+" = "+res+'%'<<p<<"\n";

	}

	else { //一般情况

		init();

		

		vector<string>tmp;

		for(int i=n; i>=0; i--) { //x的幂

			int f=0;

			f=fpow(_2,i)*fpow(_5,n-i)%p*c[n][i]%p;

			string now;

			if(f==0)continue;

			if(f!=1)now+=to_string(f);

			

			if(i>1)now=now+'*'+x+'^'+to_string(i);

			else if(i==1)now=now+'*'+x;

			if(n-i>1)now=now+'*'+y+'^'+to_string(n-i);

			else if(n-i==1)now=now+'*'+y;

			reverse(now.begin(),now.end());

			while(now.back()=='*')now.pop_back();

			reverse(now.begin(),now.end());

			tmp.emplace_back(now);

			

		}

		int sz=tmp.size();

		if(sz==0)return cout<<a+" = 0\n",0;

		for(int i=0; i<sz; i++) {

			if(i)res=res+'+';

			res+=tmp[i];

		}
        if(sz==1)cout<<a+" = "+res+'%'<<p<<"\n";
		else cout<<a+" = "+'('+res+')'+'%'<<p<<"\n";

	}

}

pypy3 解法, 执行用时: 651ms, 内存消耗: 30772K, 提交时间: 2023-03-24 20:04:05

c = lambda n, m: 1 if m == 0 else c(n-1, m-1) * n // m

import re
pattern = re.compile(r'\((\d*)([a-z])\+(\d*)([a-z])\)\^*(\d*)\%(\d*)')
s = input()
l = []
for i in pattern.match(s).groups():
    try:
        l.append(int(i))
    except:
        if i:
            l.append(i)
        else:
            l.append(1)
a, x, b, y, n, p = int(l[0]), l[1], int(l[2]), l[3], int(l[4]), int(l[5])
ans = ''
if x != y:
    cnt = 0
    for i in range(n+1):
        temp = ''
        flag = 0
        k = a**(n-i)*b**i*c(n, i) % p
        if k == 0:
            continue
        elif k > 1:
            temp += str(k)
            flag = 1
        if (n-i) == 1:
            if flag:
                temp += '*'
            temp += x
            flag = 1
        elif (n-i) > 1:
            if flag:
                temp += '*'
            temp += x + '^' + str(n-i)
            flag = 1
        if i == 1:
            if flag:
                temp += '*'
            temp += y
        elif i > 1:
            if flag:
                temp += '*'
            temp += y + '^' + str(i)
        if ans:
            ans += '+'
        if temp:
            cnt += 1
        ans += temp
    if ans and cnt > 1:
        ans = '(' + ans + ')' + '%' + str(p)
    elif ans and cnt == 1:
        ans += '%' + str(p)
    else:
        ans = '0'
else:
    k = (a+b)**n % p
    flag = 0
    if k == 0:
        print(s + ' = 0')
        exit()
    elif k > 1:
        ans += str(k)
        flag = 1
    if n == 1:
        if flag:
            ans += '*'
        ans += x
    elif n > 1:
        if flag:
            ans += '*'
        ans += x + '^' + str(n)
    ans += '%' + str(p)
print(s + ' = ' + ans)

上一题