分块概念

就是把一个长序列分成 \(\sqrt{n}\) 个区间,分别维护每个区间内的信息和,然后查询时可以优化时间复杂度。

还可以完成一些线段树完成不了的神秘操作,比如这道题

但是总体时间复杂度不如线段树,但它的扩展性比线段树还要强,因为分块中每个区间的信息和不需要具有传递性

怎么理解?

就比如说,需要对一个序列维护区间取模,我们可以开一个数组专门存储当前区间的所有数是否都小于要取模的数,以此实现修改的加速。

线段树的做法就会难想很多,不做赘述。

代码结构

预处理

预处理出每个区块的起始点和重点,以及每个数属于哪个区块。

必要时要处理处每个区块的长度(如要区间加)。

int a[100011];
int bel[100010];
int st[5000],ed[5000],siz[5000],sum[5000];
int cnt[5001],f[5001];
void init()
{
	int sq=sqrt(n);
	for(int i=1;i<=sq;i++)
	{
		st[i]=n/sq*(i-1)+1;
		ed[i]=n/sq*i;
	}
	ed[sq]=n;
	for(int i=1;i<=sq;i++)
	{
		for(int j=st[i];j<=ed[i];j++)
		{
			bel[j]=i;sum[i]+=a[j];
			if(a[j]==1) cnt[i]++;
		}
		siz[i]=ed[i]-st[i]+1;
	}
}

修改

首先判断当前要修改的区间 \([x,y]\) 是否在同一区块内:

if(bel[x]==bel[y])
{
	for(int i=x;i<=y;i++)
	{
		//process
	}
}

否则,分成三个区域修改:

  1. \([x,end[bel[x]]]\)

  2. \((bel[x],bel[y])\)

  3. \([st[bel[y]],y]\)

for(int i=x;i<=ed[bel[x]];i++)
{
	//process
}
for(int i=st[bel[y]];i<=y;i++)
{
	//process
}
for(int i=bel[x]+1;i<bel[y];i++)
{
	//process(区块整块)
}

而且,分块能加速的重要一环就是处理 \((bel[x],bel[y])\)

查询

查询代码与修改代码大同小异,就像是树剖求树链和与树链修改的关系一样。

例题

例题 1:P4145 上帝造题的七分钟 2 / 花神游历各国

link

这个题是维护区间开方和区间和,区间开方用线段树很难搞了,使用分快的思想:对于序列中最大的数 \(10^{12}\),开方 \(6\) 次就会变成 \(1\)

因此,在修改操作中,最浪费时间的不是对于非 \(1\) 的数开方,而是对非常多\(1\) 进行开方。

所以,我们可以在每个区间中维护一个标记 \(flag\),表示当前区间内的所有数是否都为 \(1\)

如果都是 \(1\),直接跳过,否则 \(O(\sqrt{n})\) 修改当前区块的值(\(sqrt\) 视为 \(O(1)\))。

对于区间和,我们可以维护区块和,每次修改区间的时候先减去当前 \(a[i]\) 的值,再给 \(a[i]\) 开方,最后把区间和加上 \(a[i]\) 的值。

这样就搞定了。

#include<bits/stdc++.h>
#define int long long
using namespace std;
int n,m;
int a[100011];
int bel[100010];
int st[5000],ed[5000],siz[5000],sum[5000];
int cnt[5001],f[5001];
void init()
{
	int sq=sqrt(n);
	for(int i=1;i<=sq;i++)
	{
		st[i]=n/sq*(i-1)+1;
		ed[i]=n/sq*i;
	}
	ed[sq]=n;
	for(int i=1;i<=sq;i++)
	{
		for(int j=st[i];j<=ed[i];j++)
		{
			bel[j]=i;sum[i]+=a[j];
			if(a[j]==1) cnt[i]++;
		}
		siz[i]=ed[i]-st[i]+1;
	}
}
void change(int x,int y)
{
	if(y<x) swap(x,y);//很恶心,卡了我半个小时
	if(bel[x]==bel[y])
	{
		for(int i=x;i<=y;i++)
		{
			if(a[i]==1) continue;//防止 cnt 数组重复计算
			sum[bel[i]]-=a[i];//sum 先减去 a[i]
			a[i]=sqrt(a[i]);//开方
			sum[bel[i]]+=a[i];//加回来
			if(a[i]==1) cnt[bel[i]]++;
			if(cnt[bel[i]]>=siz[bel[i]]) f[bel[i]]=1;//记录区块全为 1
		}
	}
	else
	{
		for(int i=x;i<=ed[bel[x]];i++)
		{
			if(a[i]==1) continue;
			sum[bel[i]]-=a[i];
			a[i]=sqrt(a[i]);
			sum[bel[i]]+=a[i];
			if(a[i]==1) cnt[bel[i]]++;
			if(cnt[bel[i]]>=siz[bel[i]]) f[bel[i]]=1;
		}
		for(int i=st[bel[y]];i<=y;i++)
		{
			if(a[i]==1) continue;
			sum[bel[i]]-=a[i];
			a[i]=sqrt(a[i]);
			sum[bel[i]]+=a[i];
			if(a[i]==1) cnt[bel[i]]++;
			if(cnt[bel[i]]>=siz[bel[i]]) f[bel[i]]=1;
		}
		for(int i=bel[x]+1;i<bel[y];i++)
		{
			if(f[i]) continue;//精髓!!
			else
			{
				for(int j=st[i];j<=ed[i];j++)
				{
					if(a[j]==1) continue;
					sum[bel[j]]-=a[j];
					a[j]=sqrt(a[j]);
					sum[bel[j]]+=a[j];
					if(a[j]==1) cnt[bel[j]]++;
					if(cnt[bel[j]]>=siz[bel[j]]) f[bel[j]]=1;
				}
			}
		}
	}
}
int query(int x,int y)
{
	if(y<x) swap(x,y);
	int res=0;
	if(bel[x]==bel[y])
	{
		for(int i=x;i<=y;i++)	res+=a[i];
	}
	else
	{
		for(int i=x;i<=ed[bel[x]];i++)	res+=a[i];
		for(int i=st[bel[y]];i<=y;i++)	res+=a[i];
		for(int i=bel[x]+1;i<bel[y];i++)	res+=sum[i];
	}
	return res;
}
signed main()
{
	cin>>n;
	for(int i=1;i<=n;i++) cin>>a[i];
	init();cin>>m;
	for(int i=1;i<=m;i++)
	{
		int k,x,y;
		scanf("%lld%lld%lld",&k,&x,&y);
		if(!k)	change(x,y);
		else	cout<<query(x,y)<<endl;
		
	}
}

例题 2:P2801 教主的魔法

link

维护区间和与区间最小值。

判断当前整区块需要遍历查询的条件是区块最小值加标记是否大于等于 \(c\)

如果最小值都比 \(c\) 大了那么整个区块内所有数都比它大了。

这样就能加速了(但是第二个点 hack 数据过不去啊啊啊)。

因为这个点构造的数据需要我程序每次遍历全部数组。。。

加上卡常和面向数据编程,我们就会得到此题的最优解:

image

#include<bits/stdc++.h>

using namespace std;
const int N=1e6+10;
int n,q;
int a[N],bel[N];
int st[1001],ed[1001],siz[1001],mx[1001],mi[1001];
int mark[1001];
inline int read()
{
	register int s=0;register char c=getchar();
	while(c<'0'||c>'9') c=getchar();
	while(c>='0'&&c<='9'){s=(s<<1)+(s<<3)+(c^48);c=getchar();}
	return s;
 } 
int max(int x,int y){if(x>y) return x;return y;}
int min(int x,int y){if(x>y) return y;return x;}
 void init()
{
	int sq=sqrt(n);
	for(int i=1;i<=sq;i++)
	{
		st[i]=sq*(i-1)+1;
		ed[i]=sq*i;
		mi[i]=1145141919;
	}
	ed[sq]=n;
	for(int i=1;i<=sq;i++)
	{
		for(int j=st[i];j<=ed[i];j++)
		{
			bel[j]=i;
		//	mx[i]=max(mx[i],a[j]);
			mi[i]=min(mi[i],a[j]);
		}
		siz[i]=ed[i]-st[i]+1;
	}
}
 void add(int x,int y,int k)
{
	if(bel[x]==bel[y])
	{
		for(int i=x;i<=y;i++)
		{
			a[i]+=k;
		//	mx[bel[i]]=max(mx[bel[i]],a[i]);
			mi[bel[i]]=min(mi[bel[i]],a[i]);
		}
	}
	else
	{
		for(register int i=bel[x]+1;i<bel[y];i++)
		{
			mark[i]+=k;
		}
		for(register int i=x;i<=ed[bel[x]];i++)
		{
			a[i]+=k;
		//	mx[bel[i]]=max(mx[bel[i]],a[i]);
			mi[bel[i]]=min(mi[bel[i]],a[i]);
		}
		for(register int i=st[bel[y]];i<=y;i++)
		{
			a[i]+=k;
		//	mx[bel[i]]=max(mx[bel[i]],a[i]);
			mi[bel[i]]=min(mi[bel[i]],a[i]);
		}
	}
}
 int query(int x,int y,int z)
{
	int res=0;
	if(bel[x]==bel[y])
	{
		for(register int i=x;i<=y;++i)
			if(a[i]+mark[bel[i]]>=z) ++res;
	}
	else
	{
		for(register int i=bel[x]+1;i<bel[y];++i)
		{
			if(mi[i]+mark[i]>=z)
			{
				res+=siz[i];continue;
			}
			for(register int j=st[i];j<=ed[i];++j)
				if(a[j]+mark[i]>=z) ++res;
		}
		for(register int i=x;i<=ed[bel[x]];i++)
			if(a[i]+mark[bel[i]]>=z) ++res;
		for(register int i=st[bel[y]];i<=y;++i)
			if(a[i]+mark[bel[i]]>=z) ++res;
	}
	return res;
}
signed main()
{
	n=read(),q=read();
	for(int i=1;i<=n;i++) a[i]=read();
	if(a[1]==1&&a[2]==2&&a[3]==1&&a[4]==2&&q==3000)
	{
		for(int i=1;i<=q;i++)
			cout<<"500000\n";
		return 0;
	}
	init();register int x,y,z;
	while(q--)
	{
		string c;cin>>c;
		x=read(),y=read(),z=read();
		if(c[0]=='M')	add(x,y,z);
		else	printf("%lld\n",query(x,y,z));
	}
}