列表

详情


NC25255. three kingdom

描述

    天才程序员菜哭武、张老师和石头三个人已经厌倦了N皇后问题。他们在一个n*m的棋盘上玩三人象棋。经过一番激烈的拼杀,最后只剩下三个国王。
    在这个n*m的棋盘上要放置三个国王,国王之间不能互相攻击。国王之间没有区别。国王的攻击范围为切比雪夫距离1以内的格子,即假设一个国王的位置是(r,c),那么可以攻击到所有 max(|r-i|, |c-j|)=1 的位置(i,j),而且不能有多个国王放置在同一个格子当中。
菜哭武很快发现了怎么计算有多少种放置方法,但是他太懒了,想让你帮他算。你只需要输出方案数对1,000,000,007取模的结果即可。

输入描述

    第一行两个整数t,表示有t组数据(1≤t≤105)
    接下来t行,每行两个整数n,m(1<=n,m<=109),表示棋盘的长和宽

输出描述

对于每一组数据,输出一个整数,表示放置方案数对1,000,000,007取模的结果

示例1

输入:

3
2 2
3 3
4 4

输出:

0
8
140

原站题解

上次编辑到这里,代码来自缓存 点击恢复默认模板

Python3(3.5.2) 解法, 执行用时: 1930ms, 内存消耗: 4564K, 提交时间: 2019-04-20 18:52:47

t = int(input())
modn = 1000000007
for i in range(t):
    n, m = map(int, input().split())
    ans = n*m *(n*m-1)*(n*m-2) // 6 % modn
    if n>2:
        ans -= (n-2)*m
        ans %= modn

    if m>2:
        ans -= n*(m-2)
        ans %= modn

    if (n>1) and (m>1):
        ans -= 4*(n-1)*(m-1)
        ans %= modn
    
    if (n>2) and (m>1):
        ans -= 6*(n-2)*(m-1)
        ans %= modn
    
    if (n>1) and (m>2):
        ans -= 6*(n-1)*(m-2)
        ans %= modn
    

    if (n>2) and (m>2):
        ans -= 2*(n-2)*(m-2)
        ans %= modn
    
    if n>1 and m>1:
        tmp = 0
        if n>2 and m>2:
            tmp += (n*m-9)*2 
            tmp += (n*m-8)*2 
        elif (n == 2 and m > 2) or (n > 2 and m == 2):
            tmp += (n*m-6)*2 

        if n > 3:
            if m > 2:
                tmp += (n*m-11)*(n-3)*2
            else:
                tmp += (n*m-8)*(n-3)
        if m > 3:
            if n > 2:
                tmp += (n*m-11)*(m-3)*2 
            else:
                tmp += (n*m-8)*(m-3)
        if n>2 and m>2:
            tmp += (n*m-14)*(n-3)*(m-3) % modn
    
        ans -= tmp*2
        ans %= modn
    
    if m > 1:
        tmp = 0
        if n > 1 and m > 2:
            tmp += (n*m-6)*4 
        elif n > 1 and m == 2:
            tmp += (n*m-4)*2
        elif n == 1 and m > 2:
            tmp += (n*m-3)*2
        
        if n > 2:
            if m > 2:
                tmp += (n*m-9)*(n-2)*2
            else:
                tmp += (n*m-6)*(n-2)
        if m > 3:
            if n > 1:
                tmp += (n*m-8)*(m-3)*2 
            else:
                tmp += (n*m-4)*(m-3)

        if n>1 and m>2:
            tmp += (n*m-12)*(n-2)*(m-3) % modn
        
        ans -= tmp
        ans %= modn
    
    if n > 1:
        tmp = 0
        if n > 2 and m > 1:
            tmp += (n*m-6)*4 
        elif n == 2 and m > 1:
            tmp += (n*m-4)*2
        elif n > 2 and m == 1:
            tmp += (n*m-3)*2
        
        if m > 2:
            if n > 2:
                tmp += (n*m-9)*(m-2)*2
            else:
                tmp += (n*m-6)*(m-2)
        if n > 3:
            if m > 1:
                tmp += (n*m-8)*(n-3)*2
            else:
                tmp += (n*m-4)*(n-3)
        
        if n > 2 and m > 1:
            tmp += (n*m-12)*(m-2)*(n-3) % modn
        
        ans -= tmp
        ans %= modn

    print(ans)

C++11(clang++ 3.9) 解法, 执行用时: 207ms, 内存消耗: 1356K, 提交时间: 2019-04-20 15:55:09

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<vector>
using namespace std;
typedef long long LL;
const int MOD=1e9+7;
int add (int x,int y)   {x=x+y;return x>=MOD?x-MOD:x;}
int mul (int x,int y)   {return (LL)x*y%MOD;}
int dec (int x,int y)   {x=x-y;return x<0?x+MOD:x;}
int Pow (int x,int y)
{
	if (y==0) return 1;
	if (y==1) return x;
	int lalal=Pow(x,y>>1);
	lalal=mul(lalal,lalal);
	if (y&1) lalal=mul(lalal,x);
	return lalal;
}
int main()
{
	int T;
	scanf("%d",&T);
	while (T--)
	{
		int n,m;
		scanf("%d%d",&n,&m);
		int cnt=mul(n,m);
		if (cnt<=2)	{printf("0\n");continue;}
		int ans=mul(cnt-2,mul(cnt,cnt-1));
		//printf("%d\n",ans);
		int tmp=0;
		for (int a=-1;a<=1;a++)
		for (int b=-1;b<=1;b++)
		{
			if (a==0&&b==0) continue;
			int nn=max(0,a)-min(0,a);
			int mm=max(0,b)-min(0,b);
			if (n<nn) continue;
			if (m<mm) continue;
			tmp=add(tmp,mul(n-nn,m-mm));
		}
		//printf("%d\n",tmp);
		tmp=mul(tmp,dec(cnt,2));
		ans=dec(ans,mul(3,tmp));
		tmp=0;
		for (int a=-1;a<=1;a++)
		for (int b=-1;b<=1;b++)
		for (int c=-1;c<=1;c++)
		for (int d=-1;d<=1;d++)
		{
			if (a==c&&b==d) continue;
			if (a==0&&b==0) continue;
			if (c==0&&d==0) continue;
			int nn=max(max(0,a),c)-min(min(0,a),c);
			int mm=max(max(0,b),d)-min(min(0,d),b);
			if (n<nn) continue;
			if (m<mm) continue;
			tmp=add(tmp,mul(n-nn,m-mm));
		}
		//printf("%d\n",tmp);
		ans=add(ans,mul(3,tmp));
		tmp=0;
		for (int a=-1;a<=1;a++)
		for (int b=-1;b<=1;b++)
		for (int c=-1;c<=1;c++)
		for (int d=-1;d<=1;d++)
		{
			if (a==c&&b==d) continue;
			if (a==0&&b==0) continue;
			if (c==0&&d==0) continue;
			if (max(a,c)-min(a,c)>1) continue;
			if (max(b,d)-min(b,d)>1) continue;
			int nn=max(max(0,a),c)-min(min(0,a),c);
			int mm=max(max(0,b),d)-min(min(0,d),b);
			if (n<nn) continue;
			if (m<mm) continue;
			tmp=add(tmp,mul(n-nn,m-mm));
		}
		ans=dec(ans,tmp);
		ans=mul(ans,Pow(6,MOD-2));
		printf("%d\n",ans);
	}
	return 0;
}

C++14(g++5.4) 解法, 执行用时: 83ms, 内存消耗: 3320K, 提交时间: 2019-05-01 21:22:01

#include <cstdio>
using namespace std;

const long long p = 1000000007;

int main() {
    int t;
    scanf("%d", &t);
    while(t--) {
        long long n, m;
        scanf("%lld%lld", &n, &m);
        if(n*m<=2) {
            printf("0\n");
            continue;
        }
        long long tot = 1;
        int k = 6;
        for(long long i=n*m; i>=n*m-2; i--) {
            long long t = i;
            if(t%2==0 && k%2==0) {
                t /= 2;
                k /= 2;
            }
            if(t%3==0 && k%3==0) {
                t /= 3;
                k /= 3;
            }
            tot = t%p*tot%p;
        }
        long long ans = 0;
        if(m==1)
            ans = (tot-(m*n-2)%p*(n-1)%p+(n-2)+p)%p;
        else if(n==1)
            ans = (tot-(m*n-2)%p*(m-1)%p+(m-2)+p)%p;
        else {
            long long a = 8*(m-1)%p*(n-1)%p+6*(n-2)%p*(m-1)%p+6*(m-2)%p*(n-1)%p+n*(m-2)%p+m*(n-2)%p+2*(m-2)%p*(n-2)%p;
            long long b = (m*n-2)%p*((n*(m-1)%p+m*(n-1)%p+2*(n-1)%p*(m-1)%p)%p)%p;
            ans = (tot+a-b+p)%p;
        }
        printf("%lld\n", ans);
    }
    return 0;
}

上一题