ST表

原理

image.png

首先声明几点:

  1. st[i][j]表示这段区间上从第i个位置自己开始起,往后数2^i个数之间的最大值
  2. lg[i]表示不超过i的最大2次幂指数(就是2的i次方不超过i且i尽量大)
  3. 输出RMQ最值的时候,最好用函数来写

第一部分:输入并预处理lg数组

由预备区1可以很容易推出,对于我们输入的每个数,其实就是sti,~(因为2^0=1 往后数1还是自己 这很显然吧)~

对于lg数组呢,因为它只有在2的整数次幂的位置才会更新出值,这个不太好讲,其实可以从log[0]=-1开始往后推一下(注:i>>1,i右移1,即i/2。在这里用位运算会快很多,亲测),推到16,甚至8就能看出来了。这个lg常数便是倍增优化了。

具体代码实现:

cin >> n >> m;
lg[0] = -1;
for (int i = 1; i <= n; i++) {
    cin >> f[i][0];
    lg[i] = lg[i >> 1] + 1;
}

第二部分:更新最值

大家可以手画个图,没错就画一条线段。用二分的方法把它覆盖住。观察每条线段的左右端点。

st[i][j]=max(st[i][j-1],st[i+(1<<(j-1))][j-1]);

请回想预备区1,则这个转移式所表达的意义就是:不妨假设从第i个位置开始,往后数2^(j-1)的区间内的最值为m1,则sti表示的就是m1;那第二段区间的起点即为2^(j-1)+1(注意!要+1!)依旧往后数2^(j-1)个数,该区间内最值为m2,则st[i][j]的值为max(m1,m2)

具体代码实现:

for (int j = 1; j <= lg[n]; j++) {
    for (int i = 1; i + (1 << j) - 1 <= n; i++) {
        f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
    }
}

又有人会问了,为什么j要在外层?

观察st[i][j]是怎么得来的,联想一下它的具体含义。

这是因为我们在for循环生成st表时,是根据区间长度由小到大逐个生成的,对于每个起始位置都要求出能覆盖整条线段的二分方法中的最小值。只有前一个位置的值全部生成完后,才能去处理下一个点。

现在再回想一下刚才的过程,我们确实能实现log级别的预处理了。

第三部分:查询(query)

查询时又会出现很多问题:若是从i^(j-1)+1的位置往后取区间最小值,很容易取不全或者是多取。比如说查询(1,7),返回值应该是RMQ(1,4)。问题来了,第二个区间的长度是3,st数组的j究竟该取多少?1?2^1=2,取不全;2?2^2=4,会往后多取了一个数,这时答案就不对了。怎么解决呢?

我们想一下刚刚的问题根本所在,是无法确定后半区间如何覆盖。那么我们~简单粗暴一点~干脆从后往前取,设右端点为r,那st数组的第一维可以写成r-(2^z)+1,依旧要+1,z具体代表什么后面揭晓。(其实很多人也能看出来了吧)

从后往前取重复了怎么办?这没问题,取的是最值,重复区间并不影响最值。

具体代码实现:

int query(int l, int r)
{
    int t = lg[r - l + 1];
    return max(f[l][t], f[r - (1 << t) + 1][t]);
}

t代表的是小于等于i + (1 << (j - 1))最右边的数也就是最短和前面一半的区间相交的地方。

例题

image(1).png

题目链接

链接

题解

#include <algorithm>
#include <bitset>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using namespace std;
#define x first
#define y second
#define endl '\n'
#define IOS                       \
    ios_base::sync_with_stdio(0); \
    cin.tie(0);                   \
    cout.tie(0);
typedef long long ll;
typedef pair<ll, ll> pll;
typedef pair<int, int> pii;
typedef unsigned long long ull;

const int INF = 0x3f3f3f3f, mod = 1000000007;
const int N = 2e6 + 10;
int n, m;
int f[N][22];
int lg[N];

int query(int l, int r)
{
    int t = lg[r - l + 1];
    return max(f[l][t], f[r - (1 << t) + 1][t]);//注意在这是1<<t,不是t,减少的是长度
}

void solve()
{
    cin >> n >> m;
    lg[0] = -1;
    for (int i = 1; i <= n; i++) {
        cin >> f[i][0];
        lg[i] = lg[i >> 1] + 1;
    }
    for (int j = 1; j <= lg[n]; j++) {
        for (int i = 1; i + (1 << j) - 1 <= n; i++) {
            f[i][j] = max(f[i][j - 1], f[i + (1 << (j - 1))][j - 1]);
        }
    }
    while (m--) {
        int x, y;
        cin >> x >> y;
        cout << query(x, y) << endl;
    }
}
int main()
{
    IOS;
    int t = 1;
    // cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
最后修改:2023 年 04 月 14 日
如果觉得我的文章对你有用,请随意赞赏