NC249078. (A+B)^N%P Problem
描述
输入描述
一行一个字符串,格式为 "(ax+by)^n%p"。具体见样例。
输出描述
一行一个字符串,格式为 "(ax+by)^n%p = 答案"。
具体见样例。
示例1
输入:
(2x+3y)^2%5
输出:
(2x+3y)^2%5 = (4*x^2+2*x*y+4*y^2)%5
说明:
示例2
输入:
(2x+3y)^2%6
输出:
(2x+3y)^2%6 = (4*x^2+3*y^2)%6
说明:
示例3
输入:
(c+2c)^3%4
输出:
(c+2c)^3%4 = 3*c^3%4
说明:
示例4
输入:
(2a+2b)^2%2
输出:
(2a+2b)^2%2 = 0
说明:
取模后所有项系数均为 0，故输出 0。
示例5
输入:
(3c+6d)%4
输出:
(3c+6d)%4 = (3*c+2*d)%4
说明:
省略输入中的乘方部分即代表一次方，即 n=1。
示例6
输入:
(x+2y)^3%7
输出:
(x+2y)^3%7 = (x^3+6*x^2*y+5*x*y^2+y^3)%7
说明:
C++(g++ 7.5.0) 解法, 执行用时: 90ms, 内存消耗: 46788K, 提交时间: 2023-03-24 20:32:43
// Solution for "(ax+by)^n%p": expand the binomial power and print every
// coefficient reduced modulo p, in the format "(ax+by)^n%p = <answer>".
#include <cctype>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using ll = long long;

// Returns base^exp mod `mod` by binary exponentiation (exp >= 0, mod >= 1).
// The base is reduced first, so callers may pass an unreduced sum such as a+b.
ll ksm(ll base, ll exp, ll mod) {
    ll res = 1 % mod;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) res = res * base % mod;
        exp >>= 1;
        base = base * base % mod;
    }
    return res;
}

// Reads a run of decimal digits starting at s[pos] and advances pos past it.
// Returns `fallback` when s[pos] is not a digit (an omitted coefficient or
// exponent stands for 1). An explicit "0" is returned as 0.
ll read_number(const std::string& s, std::size_t& pos, ll fallback) {
    if (pos >= s.size() || !std::isdigit(static_cast<unsigned char>(s[pos]))) return fallback;
    ll value = 0;
    while (pos < s.size() && std::isdigit(static_cast<unsigned char>(s[pos]))) {
        value = value * 10 + (s[pos] - '0');
        ++pos;
    }
    return value;
}

// Formats `var` raised to `e`: "v" when e == 1, otherwise "v^e".
std::string power_str(char var, ll e) {
    std::string out(1, var);
    if (e != 1) out += "^" + std::to_string(e);
    return out;
}

// Expands the expression `s` of the form "(ax+by)^n%p" and returns the full
// output line "<s> = <answer>" (without a trailing newline).
std::string expand(const std::string& s) {
    std::size_t pos = 1;                       // skip '('
    ll a = read_number(s, pos, 1);
    const char x = s[pos++];
    ++pos;                                     // skip '+'
    ll b = read_number(s, pos, 1);
    const char y = s[pos++];
    ++pos;                                     // skip ')'
    ll k = 1;                                  // "^n" may be omitted: exponent 1
    if (pos < s.size() && s[pos] == '^') {
        ++pos;
        k = read_number(s, pos, 1);
    }
    ++pos;                                     // skip '%'
    const ll p = read_number(s, pos, 0);

    a %= p;
    b %= p;
    const std::string head = s + " = ";

    if (x == y) {
        // Both summands use the same variable: the result is (a+b)^k * x^k.
        const ll coef = ksm(a + b, k, p);
        if (coef == 0) return head + "0";
        std::string out = head;
        if (coef != 1) out += std::to_string(coef) + "*";
        out += power_str(x, k);
        return out + "%" + std::to_string(p);
    }

    // Row k of Pascal's triangle mod p. p need not be prime (samples use 4 and
    // 6), so the row is built additively instead of with modular inverses.
    // Only O(k) memory is used, so there is no fixed upper bound on k.
    std::vector<ll> binom(static_cast<std::size_t>(k) + 1, 0);
    binom[0] = 1 % p;
    for (ll i = 1; i <= k; ++i)
        for (ll j = i; j >= 1; --j)
            binom[j] = (binom[j] + binom[j - 1]) % p;

    std::vector<std::string> terms;
    for (ll i = 0; i <= k; ++i) {              // term: C(k,i) * (a x)^(k-i) * (b y)^i
        const ll coef = binom[i] * ksm(a, k - i, p) % p * ksm(b, i, p) % p;
        if (coef == 0) continue;
        std::string term;
        const auto append = [&term](const std::string& factor) {
            if (!term.empty()) term += '*';
            term += factor;
        };
        if (coef != 1) append(std::to_string(coef));
        if (k - i > 0) append(power_str(x, k - i));
        if (i > 0) append(power_str(y, i));
        terms.push_back(term);
    }
    if (terms.empty()) return head + "0";

    std::string body;
    for (std::size_t i = 0; i < terms.size(); ++i) {
        if (i != 0) body += '+';
        body += terms[i];
    }
    if (terms.size() > 1) body = "(" + body + ")";
    return head + body + "%" + std::to_string(p);
}

// Reads one expression from stdin and prints its expansion.
void solve() {
    std::string s;
    std::cin >> s;
    std::cout << expand(s) << '\n';
}

int main() {
    solve();
    return 0;
}
Python3 解法, 执行用时: 331ms, 内存消耗: 4988K, 提交时间: 2023-03-24 21:07:02
# Solution for "(ax+by)^n%p": expand the binomial power and reduce every
# coefficient modulo p.


def C(n: int, m: int) -> int:
    """Return the binomial coefficient C(n, m), or 0 when m > n."""
    if m > n:
        return 0
    res = 1
    for i in range(m):
        # res == C(n, i) here, and C(n, i) * (n - i) is divisible by (i + 1).
        res = res * (n - i) // (i + 1)
    return res


def pw(x: str, y: int) -> str:
    """Format variable ``x`` to the power ``y``: "" for 0, "x" for 1, else "x^y"."""
    if y == 0:
        return ""
    if y == 1:
        return x
    return x + '^' + str(y)


def poly(c: int, ch1: str, ch2: str, x: int, y: int) -> str:
    """Format the monomial c * ch1^x * ch2^y.

    A coefficient of 1 and zero powers are omitted; the remaining factors are
    joined with '*'.
    """
    factors = [] if c == 1 else [str(c)]
    factors.extend(f for f in (pw(ch1, x), pw(ch2, y)) if f)
    return "*".join(factors)


def solve(inp: str) -> str:
    """Return the output line "<inp> = <answer>" for an input "(ax+by)^n%p".

    An omitted coefficient or "^n" stands for 1. The answer is "0" when every
    reduced coefficient vanishes.
    """
    inp = inp.strip()
    ch1 = inp[inp.find('+') - 1]  # letter just before '+'
    ch2 = inp[inp.find(')') - 1]  # letter just before ')'

    p1 = inp.find('(')
    p2 = inp.find(ch1)
    a = int(inp[p1 + 1:p2]) if p1 + 1 != p2 else 1

    p3 = inp.find('+')
    p4 = inp.find(ch2, p3 + 1)
    b = int(inp[p3 + 1:p4]) if p3 + 1 != p4 else 1

    p5 = inp.find(')')
    p6 = inp.find('%')
    n = int(inp[p5 + 2:p6]) if p5 + 1 != p6 else 1  # p5 + 1 is the '^'
    p = int(inp[p6 + 1:])

    if ch1 == ch2:
        # Both summands share one variable: the result is (a+b)^n * ch1^n.
        con = pow(a + b, n, p)
        if con == 0:
            return inp + " = 0"
        ans = "" if con == 1 else str(con) + "*"
        ans += ch1 if n == 1 else ch1 + '^' + str(n)
        return inp + " = " + ans + '%' + str(p)

    terms = []
    binom = 1  # exact C(n, i), advanced in O(1) per term instead of recomputed
    for i in range(n + 1):
        coef = binom * pow(a, n - i, p) * pow(b, i, p) % p
        if coef != 0:
            terms.append(poly(coef, ch1, ch2, n - i, i))
        binom = binom * (n - i) // (i + 1)
    if not terms:
        return inp + " = 0"
    ans = "+".join(terms)
    if len(terms) > 1:
        ans = "(" + ans + ")"
    return inp + " = " + ans + "%" + str(p)


if __name__ == '__main__':
    print(solve(input()))
C++(clang++ 11.0.1) 解法, 执行用时: 12ms, 内存消耗: 7708K, 提交时间: 2023-03-30 18:14:27
// Solution for "(ax+by)^n%p": expand the binomial power and print every
// coefficient reduced modulo p, in the format "(ax+by)^n%p = <answer>".
#include <cctype>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

using ll = long long;

// Returns base^exp mod m by square-and-multiply (exp >= 0, m >= 1).
ll fpow(ll base, ll exp, ll m) {
    ll res = 1 % m;
    ll tmp = base % m;
    while (exp > 0) {
        if (exp & 1) res = res * tmp % m;
        tmp = tmp * tmp % m;
        exp >>= 1;
    }
    return res;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::string a;
    std::cin >> a;
    const std::size_t len = a.size();

    // Locate the two variable letters and the '%' sign.
    std::size_t ix = 1;
    while (ix < len && !std::islower(static_cast<unsigned char>(a[ix]))) ++ix;
    std::size_t iy = ix + 1;
    while (iy < len && !std::islower(static_cast<unsigned char>(a[iy]))) ++iy;
    const std::size_t ipct = a.find('%', iy);
    const std::size_t iclose = iy + 1;         // position of ')'

    // An empty digit run between `from` and `to` means an implicit 1.
    const auto number_or_one = [&a](std::size_t from, std::size_t to) -> ll {
        return from == to ? 1 : std::stoll(a.substr(from, to - from));
    };
    const char x = a[ix];
    const char y = a[iy];
    const ll ca = number_or_one(1, ix);
    const ll cb = number_or_one(ix + 2, iy);   // ix + 1 is the '+'
    // '%' right after ')' means "^n" was omitted; otherwise skip the '^'.
    const ll n = (ipct == iclose + 1)
                     ? 1
                     : std::stoll(a.substr(iclose + 2, ipct - iclose - 2));
    const ll p = std::stoll(a.substr(ipct + 1));

    std::string res;
    if (x == y) {
        // Both summands use the same variable: the result is (ca+cb)^n * x^n.
        const ll f = fpow(ca + cb, n, p);
        if (f == 0) {
            std::cout << a << " = 0\n";
            return 0;
        }
        if (f != 1) res += std::to_string(f) + '*';
        res += x;
        if (n > 1) res += "^" + std::to_string(n);
        std::cout << a << " = " << res << '%' << p << '\n';
        return 0;
    }

    // Row n of Pascal's triangle mod p. p need not be prime, so the row is
    // built additively; O(n) memory, so there is no fixed upper bound on n.
    std::vector<ll> row(static_cast<std::size_t>(n) + 1, 0);
    row[0] = 1 % p;
    for (ll r = 1; r <= n; ++r)
        for (ll c = r; c > 0; --c)
            row[c] = (row[c] + row[c - 1]) % p;

    std::vector<std::string> terms;
    for (ll i = n; i >= 0; --i) {              // i is the exponent of x
        const ll f = fpow(ca, i, p) * fpow(cb, n - i, p) % p * row[i] % p;
        if (f == 0) continue;
        std::string term;
        const auto factor = [&term](const std::string& s) {
            if (!term.empty()) term += '*';
            term += s;
        };
        if (f != 1) factor(std::to_string(f));
        if (i >= 1) factor(i > 1 ? std::string(1, x) + "^" + std::to_string(i) : std::string(1, x));
        if (n - i >= 1) factor(n - i > 1 ? std::string(1, y) + "^" + std::to_string(n - i) : std::string(1, y));
        terms.push_back(term);
    }
    if (terms.empty()) {
        std::cout << a << " = 0\n";
        return 0;
    }
    for (std::size_t i = 0; i < terms.size(); ++i) {
        if (i != 0) res += '+';
        res += terms[i];
    }
    if (terms.size() == 1)
        std::cout << a << " = " << res << '%' << p << '\n';
    else
        std::cout << a << " = (" << res << ")%" << p << '\n';
    return 0;
}
pypy3 解法, 执行用时: 651ms, 内存消耗: 30772K, 提交时间: 2023-03-24 20:04:05
# Solution for "(ax+by)^n%p": expand the binomial power and reduce every
# coefficient modulo p.
import re

# Groups: a, x, b, y, n, p. Each numeric group may be empty (meaning 1), and
# "\^*" lets the "^n" part be omitted entirely.
pattern = re.compile(r'\((\d*)([a-z])\+(\d*)([a-z])\)\^*(\d*)\%(\d*)')


def c(n, m):
    """Return the binomial coefficient C(n, m) for m >= 0.

    Iterative, so large m no longer exceeds Python's recursion limit.
    """
    res = 1
    for j in range(1, m + 1):
        # res == C(n - m + j - 1, j - 1) here; the division is exact.
        res = res * (n - m + j) // j
    return res


def _monomial(k, x, ex, y, ey):
    """Format k * x^ex * y^ey; a coefficient <= 1 and zero exponents are omitted."""
    parts = [str(k)] if k > 1 else []
    for var, e in ((x, ex), (y, ey)):
        if e == 1:
            parts.append(var)
        elif e > 1:
            parts.append(var + '^' + str(e))
    return '*'.join(parts)


def solve(s):
    """Return the output line "<s> = <answer>" for one expression "(ax+by)^n%p"."""
    s = s.strip()
    sa, x, sb, y, sn, sp = pattern.match(s).groups()
    a, b, n, p = (int(v) if v else 1 for v in (sa, sb, sn, sp))

    if x == y:
        # Both summands share one variable: the result is (a+b)^n * x^n.
        k = pow(a + b, n, p)
        if k == 0:
            return s + ' = 0'
        return s + ' = ' + _monomial(k, x, n, y, 0) + '%' + str(p)

    terms = []
    binom = 1  # exact C(n, i), advanced each step; modular pow avoids huge powers
    for i in range(n + 1):
        k = binom * pow(a, n - i, p) * pow(b, i, p) % p
        if k:
            terms.append(_monomial(k, x, n - i, y, i))
        binom = binom * (n - i) // (i + 1)
    if not terms:
        return s + ' = 0'
    ans = '+'.join(terms)
    if len(terms) > 1:
        ans = '(' + ans + ')'
    return s + ' = ' + ans + '%' + str(p)


if __name__ == '__main__':
    print(solve(input()))