简要题意

给你一个 \(n\) 个元素的集合,它由前 \(n\) 个正整数构成。你需要求出它有多少个非空子集,满足若 \(x\) 在这个子集中, \(2x,3x\) 不能在子集中。由于答案可能很大,你只需要对 \(10^9+1\) 取模即可。

\(1 \leq n \leq 10^5\)

思路

这道题的构造思想非常清奇。我们可以构造一个矩阵:

\[A=\begin{bmatrix}
1 & 2 & 4 & 8 & \cdots\\
3 & 6 & 12 & 24 & \cdots\\
9 & 18 & 36 & 72 & \cdots\\
27 & 54 & 108 & 216 & \cdots\\
81 & 162 & 324 & 648 & \cdots \\
\cdots & \cdots & \cdots & \cdots & \cdots
\end{bmatrix}
\]

具体来说:

\[A_{i,j}=\begin{cases}
& 1 & i=1,j=1\\
& 3\cdot f_{i-1,j} & i\neq 1,j=1\\
& 2\cdot f_{i,j-1} & \text{otherwise}
\end{cases}
\]

这样子我们就将原问题转化成了给出一个矩阵,如果你选择 \((i,j)\),就不能选择 \((i-1,j)\)\((i,j-1)\),求方案数。这个问题可以使用状压 DP 解决。

我们设 \(f_{i,S}\) 为考虑到第 \(i\) 行,这一行选择 \(S\) 中的元素的方案数。不难发现:

\[f_{i,S}=\begin{cases}
& \operatorname{valid}(S) & i=1 \\
& \sum\limits_{T\cup S=\emptyset,\operatorname{valid}(T)}{f_{i-1,T}} & \text{otherwise}
\end{cases}
\]

其中 \(\operatorname{valid}(S)\) 是指选择该行中 \(S\) 中的元素是否合法,也就是两两是否相邻。用状态压缩的话可以简单地这样实现:

\[\operatorname{valid}(S)=S\&(S>>1)?0:1
\]

当然左移也可以。其实原理就是将原本一样的位错开,相邻的进行与运算。

最后注意这个表不是所有元素都会覆盖到(具体来说,只会覆盖到 \(\forall i,j\in \mathbb{N},2^{i}3^{j}\))。所以我们如果遇到了一个没有被之前覆盖到的元素,我们需要将它设为 \(f_{1,1}\) 重新生成矩阵 \(A\),并重新 DP,最后按照乘法原理(因为这些都可以同时选)将结果累乘。

然后这道题就做完了。最后提醒大家一句,位运算优先级比较低,建议大家勤添括号。

代码

点击查看代码
#include <bits/stdc++.h>
#define int long long
#define valid(x) (x&(x>>1)?0:1)
using namespace std;

const int mod = 1e9+1;
int M(const int x){return (x%mod+mod)%mod;}
const int N = 1e5+5;

int n,vis[N],a[25][25],col[N],f[25][1000005],final,ans=1;

inline void init(int x){
    for(int i=1;i<=11;i++){
        if(i==1) a[i][1]=x;
        else a[i][1]=a[i-1][1]*3;
        if(a[i][1]>n) break;
        vis[a[i][1]]=1;col[i]=1;final=i;
        for(int j=2;j<=18;j++){
            a[i][j]=a[i][j-1]<<1;
            if(a[i][j]>n) break;
            col[i]=j;vis[a[i][j]]=1;
        }
    }
}

inline int dp(int x){
    for(int i=0;i<(1<<col[1]);i++){
        f[1][i]=valid(i);
    }
    for(int i=2;i<=final;i++){
        for(int j=0;j<(1<<col[i]);j++){
            if(!valid(j)) continue;
            f[i][j]=0;
            for(int k=0;k<(1<<col[i-1]);k++){
                if(valid(k) && ((k&j) == 0)) f[i][j]=M(f[i][j]+f[i-1][k]);
            }
        }
    }
    int ret=0;
    for(int i=0;i<(1<<col[final]);i++) ret=M(ret+f[final][i]);
    return ret;
}

signed main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        if(vis[i]) continue;
        init(i);ans=M(ans*dp(i));
    }
    cout<<ans;
    return 0;
}