题目分析:

(我竟然可以独立做出来正赛的题,表示震惊)
这个题面显然就很神仙,不好分析,我们进行转化一下题意:
给定一个 \(n \times m\) 的矩阵,对于每一行我们可以选择一个数也可以不选择,要求至少选择一个数,而且对于任意一列其被选择的次数都必须不超过总选择的点数的一半,一次选择的方案数就是所有选择的数的乘积,求所有的选择方案的方案数。

对于 \(m \le 3\),直接记录每一列选择了多少个然后去转移就好了,复杂度 \(O(n^{m+1})\)\(40\) 分到手

看到这个小于等于的限制,应该会象征性地向容斥想想,发现容斥出来也不好做的样子(?)
我们发现容斥的话其实就是无限制的方案数减去大于 \(\frac{k}{2}\) 的方案数,而我们大于 \(\frac{k}{2}\) 最多只有一个,也就是说我们容斥之后只有 \(O(n)\) 个状态是有用的,也就是可以去直接枚举哪一列大于 \(\frac{k}{2}\)
最后用所有列都没有限制的方案,减去这些所有的方案数的和就好了。

如果能想到这个结论应该就可以发现这肯定是一个很重要的结论,因为看上去就非常有用,这样的话大概就可以想到下面这种思路:
枚举列 \(m\) 大于 \(\frac{k}{2}\),设 \(dp[i][j][k]\) 表示前 \(i\) 行,选择了 \(j\) 个,第 \(m\) 列选择了 \(k\) 个的方案数,转移即:

\[dp[i][j][k] = dp[i-1][j][k] + s[i][m] \times dp[i-1][j-1][k] + a[i][m] \times dp[i-1][j-1][k-1]
\]

我们上面设 \(s[i][j]\) 表示第 \(i\) 行除去第 \(j\) 个数的和。

这样的话就可以做到 \(O(n^3m)\),就可以得到 \(84\)可以跑路了

发现这个方程其实推到这里就没什么前途了,必须另外想想别的办法,因为其实这种状态最后统计答案也很麻烦。稍微想想就能发现,其实所谓的超过一半,就意味着比其他的列选择的多啊,类似绝对众数那样,所以只需要维护第 \(m\) 列被选择的次数与其他列被选择的次数的差就好了。
也就是设 \(f[i][j]\) 表示前 \(i\) 行,第 \(m\) 列被选择的次数减去其他列被选择的次数为 \(j\) 的方案数,转移:

\[f[i][j] = f[i-1][j] + s[i][m] \times f[i-1][j+1] + a[i][m] \times f[i-1][j-1]
\]

最后答案的统计就是所有 \(\sum_{j = 1}^{n} dp[n][j]\)
这样的复杂度就是 \(O(n^2m)\)

还剩一个问题就是所有列都无限制的方案数,这个方案数显然就是:\(-1 + \prod_{i=1}^n (s[i] + 1)\),因为每行都可以不选,但是不能都不选。

代码:

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int MOD = 998244353,N = 105,M = 2005;
int sum[N][M],a[N][M];
int f[N][2*N];
int mod(int x){
	return (x%MOD + MOD)%MOD; 
}
signed main(){
	int n,m;scanf("%lld%lld",&n,&m);
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			scanf("%lld",&a[i][j]);
			sum[i][0] = mod(sum[i][0] + a[i][j]);
		}
	}
	for(int i=1; i<=n; i++){
		for(int j=1; j<=m; j++){
			sum[i][j] = mod(sum[i][0] - a[i][j]);
		}
	}
	int ans = 1;
	for(int i=1; i<=n; i++)	ans = mod(ans * (sum[i][0] + 1));
	ans = mod(ans - 1);
	for(int tmp = 1;tmp <= m; tmp++){
		memset(f,0,sizeof(f));
		f[0][n] = 1;
		for(int i=1; i<=n; i++){
			for(int j=n-i; j<=n+i; j++){
				f[i][j] = mod(f[i-1][j] + f[i-1][j+1] * sum[i][tmp] + f[i-1][j-1] * a[i][tmp]);
			}
		}
		for(int i=1; i<=n; i++)	ans = mod(ans - f[n][n+i]);
	}
	printf("%lld\n",ans); 
	return 0;
}