ST表
原理
首先声明几点:
st[i][j]
表示这段区间上从第i个位置自己开始起,往后数2^i
个数之间的最大值lg[i]
表示不超过i的最大2次幂指数(就是2的i次方不超过i且i尽量大)- 输出RMQ最值的时候,最好用函数来写
第一部分:输入并预处理lg数组
由预备区1可以很容易推出,对于我们输入的每个数,其实就是sti,~(因为2^0=1 往后数1还是自己 这很显然吧)~
对于lg数组呢,因为它只有在2的整数次幂的位置才会更新出值,这个不太好讲,其实可以从log[0]=-1开始往后推一下(注:i>>1,i右移1,即i/2。在这里用位运算会快很多,亲测),推到16,甚至8就能看出来了。这个lg常数便是倍增优化了。
具体代码实现:
cin >> n >> m;
lg[0] = -1;
for (int i = 1; i <= n; i++) {
cin >> f[i][0];
lg[i] = lg[i >> 1] + 1;
}
第二部分:更新最值
大家可以手画个图,没错就画一条线段。用二分的方法把它覆盖住。观察每条线段的左右端点。
st[i][j]=max(st[i][j-1],st[i+(1<<(j-1))][j-1]);
请回想预备区1,则这个转移式所表达的意义就是:不妨假设从第i个位置开始,往后数2^(j-1)
的区间内的最值为m1,则sti表示的就是m1;那第二段区间的起点即为2^(j-1)+1
(注意!要+1!)依旧往后数2^(j-1)
个数,该区间内最值为m2,则st[i][j]
的值为max(m1,m2)
。
具体代码实现:
for (int j = 1; j <= lg[n]; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
又有人会问了,为什么j要在外层?
观察st[i][j]
是怎么得来的,联想一下它的具体含义。
这是因为我们在for循环生成st表时,是根据区间长度由小到大逐个生成的,对于每个起始位置都要求出能覆盖整条线段的二分方法中的最小值。只有前一个位置的值全部生成完后,才能去处理下一个点。
现在再回想一下刚才的过程,我们确实能实现log级别的预处理了。
第三部分:查询(query)
查询时又会出现很多问题:若是从i^(j-1)+1的位置往后取区间最小值,很容易取不全或者是多取。比如说查询(1,7),返回值应该是RMQ(1,4)。问题来了,第二个区间的长度是3,st数组的j究竟该取多少?1?2^1=2,取不全;2?2^2=4,会往后多取了一个数,这时答案就不对了。怎么解决呢?
我们想一下刚刚的问题根本所在,是无法确定后半区间如何覆盖。那么我们~简单粗暴一点~干脆从后往前取,设右端点为r,那st数组的第一维可以写成r-(2^z)+1,依旧要+1,z具体代表什么后面揭晓。(其实很多人也能看出来了吧)
从后往前取重复了怎么办?这没问题,取的是最值,重复区间并不影响最值。
具体代码实现:
int query(int l, int r)
{
int t = lg[r - l + 1];
return max(f[l][t], f[r - (1 << t) + 1][t]);
}
t代表的是小于等于i + (1 << (j - 1))最右边的数
也就是最短和前面一半的区间相交的地方。
例题
题目链接
题解
#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using namespace std;
#define x first
#define y second
#define endl '\n'
#define IOS \
ios_base::sync_with_stdio(0); \
cin.tie(0); \
cout.tie(0);
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef unsigned long long ull;
const int INF = 0x3f3f3f3f, mod = 1000000007;
const int N = 2e6 + 10;
int n, m;
int f[N][22];
int lg[N];
int query(int l, int r)
{
int t = lg[r - l + 1];
return max(f[l][t], f[r - (1 << t) + 1][t]);//注意在这是1<<t,不是t,减少的是长度
}
void solve()
{
cin >> n >> m;
lg[0] = -1;
for (int i = 1; i <= n; i++) {
cin >> f[i][0];
lg[i] = lg[i >> 1] + 1;
}
for (int j = 1; j <= lg[n]; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
}
}
while (m--) {
int x, y;
cin >> x >> y;
cout << query(x, y) << endl;
}
}
int main()
{
IOS;
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}