简要题意

给定 \(n\)\(k\) 和值域 \([1,k]\)\(n\) 个整数 \(h_i\),求有多少个长为 \(n\) 的整数序列 \(a\) 满足值域 \([1,k]\),且 \(\sum\limits_{i=1}^n[a_i=h_i]<\sum\limits_{i=1}^n[a_i=h_{(i\bmod{n})+1}]\)。答案对 \(998244353\) 取模。

\(1 \leq n \leq 2\times 10^5,1 \leq k \leq 10^9\)

思路

如果你不会这道题的 Easy Version,可以先看看 这篇博客。这道题与简单版的最大区别是这道题无法使用 \(O(n^2)\) 的 DP。

简单版的解法是:令 \(f_{i,j}\) 为截止到第 \(i\) 位,比原来的少错 \(j\) 个数的方案数。然后可以 \(O(1)\) 转移。总时间复杂度 \(O(n^2)\)

  • 结论 1: \(f_{n,i}=f_{n,-i}\)。因为对 \(j\) 道题与错 \(j\) 道题的方案本质上是对称的。
  • 结论 2:\(\sum_{i=1}^{n}f_{n,i}=\dfrac{k^n-f_0}{2}\)。因为 \(f_{n,i}=f_{n,-i}\),所以 \(k^n-f_0=\sum_{i=-n}^{-1}f_i+\sum_{i=1}^{n}f_i=2\sum_{i=1}^{n}f_{n,i}\)

所以我们其实要求的是 \(f_{n,0}\)。也就是不改变正确个数的方案总数。

不改变正确的个数的方案总数一定是变换时正确了 \(k\) 个,错了 \(k\) 个。且 \(0\leq k\leq\lfloor\dfrac{M}{2}\rfloor\),其中 \(M\)\(h\) 中变换后不同的个数。

所以:

\[f_{n,0}=\sum_{i=1}^{\frac{M}{2}}{F(i)}
\]

\(F(i)\) 是给出正确 / 错误的个数,求方案数。我们考虑如何求 \(F(i)\)

考虑 \(F(i)\) 的来自三个部分的贡献:

  1. 正确的 \(i\) 个,为 \(\binom{M}{i}\)
  2. 错误的 \(i\) 个,为 \(\binom{M-i}{i}(k-2)^{M-2i}\)。注意原来的和错误的都要错,所以是 \(M-2i\)。要排除正确的和本身的,所以是 \(k-2\)
  3. 没有贡献的,为 \(k^{M}\)

所以 \(f_{n,0}\) 就求出来了:

\[f_{n,0}=\sum_{i=1}^{\frac{M}{2}}\binom{M}{i}\binom{M-i}{i}(k-2)^{M-2i}k^{M}
\]

预处理阶乘、阶乘逆元、\(k\) 的幂和 \(k-2\) 的幂,可以做到严格线性。

时间复杂度 \(O(n)\)

代码

#include <bits/stdc++.h>
#pragma GCC optimize("Ofast", "inline", "-ffast-math")
#pragma GCC target("avx,sse2,sse3,sse4,mmx")
#define int long long
using namespace std;

const int mod = 998244353;
const int N = 2e5+5;

int n,k,m,h[N],fact[N]={1}, inv[N]={1,1}, ret, k2p[N]={1}, kp[N]={1};

int M(const int x){return (x%mod+mod)%mod;}
int c(int m, int n){assert(n-m>=0);return M(fact[n]*M(inv[n-m]*inv[m]));}

signed main(){
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
    cin>>n>>k;
    if(k==1){cout<<0;return 0;}
    for(int i=1;i<=n;i++) cin>>h[i];
    for(int i=1;i<=n;i++) m+=(h[i]!=h[i%n+1]);
    for(int i=1;i<=n;i++) fact[i]=M(fact[i-1]*i);
    for(int i=2;i<=n;i++) inv[i]=M(M(mod-mod/i)*inv[mod%i]);
    for(int i=2;i<=n;i++) inv[i]=M(inv[i-1]*inv[i]);
    for(int i=1;i<=n;i++) kp[i]=M(kp[i-1]*k);
    for(int i=1;i<=n;i++) k2p[i]=M(k2p[i-1]*(k-2));
    for(int i=0;i<=(m>>1);i++){
        ret=M(ret+M(M(c(i,m)*c(i,m-i))*M(k2p[m-(i<<1)]*kp[n-m])));
    }
    cout<<M((kp[n]-ret)*inv[2])<<'\n';;
    return 0;
}