NC20088. [HNOI2011]勾股定理
描述
沫沫最近在研究勾股定理。对于两个正整数 A 与 B,若存在正整数 C 使得 A2+B2=C2,且 A 与 B 互质,则称(A,B)为一个互质勾股数对。
有一天,沫沫得到了 N 根木棍,其长度都是正整数,她准备从中挑选出若干根木棍来玩拼图游戏,为了使拼出的图案有凌乱美,她希望挑选出的木棍中任意两根的长度均不是互质勾股数对。现在,沫沫想知道有多少种满足要求的挑选木棍的方案。由于答案可能很大,你只要输出答案对 10^9+7 取模的结果。
输入描述
从文件input.txt中读入数据,输入文件第一行是一个正整数N,表示共有多少根木棍。
输入文件第二行是用空格隔开的N个正整数h1, h2, …, hN,其中对1≤i≤N,hi表示第i根木棍的长度。
输入的数据保证30%的数据满足对1≤i≤N有1≤hi≤3000,
另外30%的数据满足对1≤i≤N有1≤hi≤200000,
剩下的40%的数据满足对1≤i≤N有20000≤hi≤1000000,
100%的数据满足N≤1000000。
输出描述
输出文件 output.txt 仅包含一个非负整数,表示满足要求的挑选木棍的方案数对 109+7 取模的结果。
示例1
输入:
4 5 12 35 5
输出:
8
说明:
样例解释:(5,12)与(12,35)是互质勾股数对,故满足要求的挑选木棍的方案有8种,即:C++(clang++ 11.0.1) 解法, 执行用时: 536ms, 内存消耗: 51136K, 提交时间: 2022-12-05 16:05:23
#include<bits/stdc++.h> using namespace std; const int mod = 1e9 + 7; int PW2[1000005]; const int maxN = 1e6; int n; int num[1000005]; vector<int> to[1000005]; bool vis[1000005]; bool ins[1000005]; int sat[1000005]; vector<int> QE; void dfs_init( int x, int f ) { vis[x] = true; for( auto N : to[x] ) if( N ^ f ) { if( !vis[N] ) dfs_init( N, x ); else { if( !ins[x] ) QE.push_back( x ); if( !ins[N] ) QE.push_back( N ); ins[x] = ins[N] = true; } } } int dp[1000005][2]; int des[1000005]; int pnt = 0; int dfs_dp( int x ) { dp[x][0] = 1; dp[x][1] = PW2[ num[x] ] - 1; des[x] = pnt; for( auto N : to[x] ) if( des[N] ^ pnt ) { dp[x][0] = 1ll * dp[x][0] * dfs_dp( N ) % mod; dp[x][1] = 1ll * dp[x][1] * dp[N][0] % mod; } if( sat[x] == 1 ) dp[x][0] = 0; if( sat[x] == -1 ) dp[x][1] = 0; return ( dp[x][0] + dp[x][1] ) % mod; } bool check() { for( auto P : QE ) for( auto N : to[P] ) { if( sat[P] == 1 and sat[N] == 1 ) return false; } return true; } int query(int x) { QE.clear(); dfs_init( x, x ); int ans = 0; int len = 1 << QE.size(); for( int i = 0; i < len; i ++ ) { for( int j = 0; j < QE.size(); j ++ ) { sat[ QE[j] ] = (i & (1 << j)) ? 1 : -1; } if( check() ) pnt ++, ( ans += dfs_dp( x ) ) %= mod; } for( int i = 0; i < QE.size(); i ++ ) sat[ QE[i] ] = 0; return ans; } int main(){ int n; cin >> n; PW2[0] = 1; for(int i = 1;i <= n;i ++) PW2[i] = PW2[i - 1] * 2 % mod; for(int i = 1;i <= n;i ++) { int x; cin >> x; num[x] ++; } for(int i = 1;i * i <= maxN;i ++) for(int j = i + 1;2 * i * j <= maxN;j ++) { if( j * j > 2 * maxN ) break; int x = j * j - i * i, y = 2 * i * j; if( x > maxN or y > maxN ) continue; if( !num[x] or !num[y] or __gcd( x, y ) != 1 ) continue; to[x].push_back(y); to[y].push_back(x); } int ans = 1; for(int i = 1;i <= maxN;i ++) if( num[i] and !vis[i] ) { ans = 1ll * ans * query(i) % mod; } cout << ( ans - 1 + mod ) % mod; return 0; }