动态开点线段树说明

作者:Grey

原文地址:

博客园:动态开点线段树说明

CSDN:动态开点线段树说明

说明

针对普通线段树,参考使用线段树解决数组任意区间元素修改问题

在普通线段树中,线段树在预处理的时候,需要申请 4 倍大小的数组空间来存放划分的区域,

而本文介绍的动态开点线段树,它和普通线段树的区别是,动态开点线段树不需要像普通线段树那样提前申请 4 倍大小的数据空间来存放划分区域,等到实际使用的时候,再来申请。

先讲一种比较简单的动态开点线段树,这种线段树只支持单点的更新和查询。

即支持如下两个方法

void add(i, v);

该方法表示在 i 上的值加上 v;

int query(int s, int e)

该方法用于获取 s 到 e 区间内的累加和信息。

该线段树只需要定义一个节点数据结构即可

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

其中 sum 表示 Node 所在区间的累加和,left 表示节点左孩子信息,right 表示节点右孩子信息。

线段树初始化过程也只需要

  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }
  }

size 表示线段树支持的范围,这个范围从线段树一开始初始化的时候设定好(编号1 到 编号size就是区间范围)。和普通线段树不一样的地方在于,节点只建立了 root 节点,未初始化所有区间。

接下来看add方法,

    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

这个方法调用了线段树内部的私有add方法,

    // c-> cur 当前节点!表达的范围 l~r
    // i位置的数,增加v
    // 潜台词!i一定在l~r范围上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 还可以划分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

这个add方法的几个参数分别代表

c : 表示 add 操作的区间代表节点是多少

l...r 表示任务区间,由于初始化 size,所以在调用公开的 add 方法时候,l = 1, r = size,表示在初始化区间范围内操作。

i:表示要操作的位置

v: 表示要增加的值

整个 add 私有方法逻辑也比较简单,核心代码

        // i 在节点左边
        if (i <= mid) { 
            // 如果节点的左树为空,则建立新节点
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { 
            // i 在节点右边
            // 如果节点右树为空,则建立新节点
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        // 最后当前节点要汇聚左右树的结果,之所以要判空是因为左右树可能不需要都建立出来
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);

查询方法的逻辑也比较简单

    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

调用了内部的一个私有 query 方法,

    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) { 
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

这个私有方法的几个参数说明如下

c:表示要操作的线段树的代表节点是什么;

l...r 是划分的区间范围

s...e 是任务的区间范围

核心逻辑如下

// 如果任务的区间已经包含了划分的区间,直接返回结果
      if (s <= l && r <= e) { 
        return c.sum;
      }
      // 否则,去左右区间拿累加和
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        // 整合成自己的累加和返回
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }

整个支持单点更新的动态线段树的完整代码如下(含对数器代码)

// 只支持单点增加 + 范围查询的动态开点线段树(累加和)
public class Code01_DynamicSegmentTree {

  public static class Node {
    public int sum;
    public Node left;
    public Node right;
  }

  // arr[0] -> 1
  // 线段树,从1开始下标!
  public static class DynamicSegmentTree {
    public Node root;
    public int size;

    public DynamicSegmentTree(int max) {
      root = new Node();
      size = max;
    }

    // 下标i这个位置的数,增加v
    public void add(int i, int v) {
      add(root, 1, size, i, v);
    }

    // c-> cur 当前节点!表达的范围 l~r
    // i位置的数,增加v
    // 潜台词!i一定在l~r范围上!
    private void add(Node c, int l, int r, int i, int v) {
      if (l == r) {
        c.sum += v;
      } else { // l~r 还可以划分
        int mid = (l + r) / 2;
        if (i <= mid) { // l ~ mid
          if (c.left == null) {
            c.left = new Node();
          }
          add(c.left, l, mid, i, v);
        } else { // mid + 1 ~ r
          if (c.right == null) {
            c.right = new Node();
          }
          add(c.right, mid + 1, r, i, v);
        }
        c.sum = (c.left != null ? c.left.sum : 0) + (c.right != null ? c.right.sum : 0);
      }
    }

    // s~e范围的累加和
    public int query(int s, int e) {
      return query(root, 1, size, s, e);
    }

    // 当前节点c,表达的范围l~r
    // 收到了一个任务,s~e这个任务!
    // s~e这个任务,影响了多少l~r范围的数,把答案返回!
    private int query(Node c, int l, int r, int s, int e) {
      if (c == null) {
        return 0;
      }
      if (s <= l && r <= e) {
        return c.sum;
      }
      int mid = (l + r) / 2;
      if (e <= mid) {
        return query(c.left, l, mid, s, e);
      } else if (s > mid) {
        return query(c.right, mid + 1, r, s, e);
      } else {
        return query(c.left, l, mid, s, e) + query(c.right, mid + 1, r, s, e);
      }
    }
  }

  public static class Right {
    public int[] arr;

    public Right(int size) {
      arr = new int[size + 1];
    }

    public void add(int i, int v) {
      arr[i] += v;
    }

    public int query(int s, int e) {
      int sum = 0;
      for (int i = s; i <= e; i++) {
        sum += arr[i];
      }
      return sum;
    }
  }

  public static void main(String[] args) {
    int size = 10000;
    int testTime = 50000;
    int value = 500;
    DynamicSegmentTree dst = new DynamicSegmentTree(size);
    Right right = new Right(size);
    System.out.println("测试开始");
    for (int k = 0; k < testTime; k++) {
      if (Math.random() < 0.5) {
        int i = (int) (Math.random() * size) + 1;
        int v = (int) (Math.random() * value);
        dst.add(i, v);
        right.add(i, v);
      } else {
        int a = (int) (Math.random() * size) + 1;
        int b = (int) (Math.random() * size) + 1;
        int s = Math.min(a, b);
        int e = Math.max(a, b);
        int ans1 = dst.query(s, e);
        int ans2 = right.query(s, e);
        if (ans1 != ans2) {
          System.out.println("出错了!");
          System.out.println(ans1);
          System.out.println(ans2);
        }
      }
    }
    System.out.println("测试结束");
  }
}

接下来看一个使用动态开点线段树来解决的一个问题

即:LeetCode 315. Count of Smaller Numbers After Self

注:本题可以用归并排序,树状数组,有序表来解,也可以用动态开点线段树来解。

主要思路如下

以如下数组为例来说明

nums = {5,8,7,4,2,9}

首先,初始化一个 List,这个 List 用于存放每个位置的右侧比其小的数有几个,List 的大小和原始数组一样

List<Integer> ans = new ArrayList<>(nums.length);

ans 在初始化的时候,均设置为 0 ,表示,所有位置都还没计算过。

ans = [0,0,0,0,0,0]

接下来对原始数组进行排序(注意:排序的时候,不能只使用值来排序,要带上这个值所在的位置,这样排序后才不会丢失该值在原始数组中的位置信息)

    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要记录值,也要记录位置,防止排序后找不到值对应的位置在哪里
      arr[i] = new int[] {nums[i], i};
    }
    // 排序按值排序
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));

排序后,arr 按如下顺序组织

{值:2,原始位置:4}
{值:4,原始位置:3}
{值:5,原始位置:0}
{值:7,原始位置:2}
{值:8,原始位置:1}
{值:9,原始位置:5}

接下来初始化开点线段树,线段树的size就是原始数组的大小,且每个位置都是0,

按顺序遍历这个 arr 数组,最小值 2 被取出,其原始位置是 4,且 4 号位置右侧没有比自己更小的数,接下来在开点线段树中把把 4 号位置的值加1,表示 4 号位置被处理过了,在线段树中查4号位置以后并没有任何标记记录,说明没有比这个数更小的数了,直接设置4号位置的ans值为0

ans = [0,0,0,0,0,0]

线段树中

seg = [0,0,0,0,1,0]

接下来是 3 号位置的4,在线段树中查到,有一个比它小的,直接设置到 ans 中,然后在线段树中把 3 号位置也标记为 1,说明处理过,

ans = [0,0,0,1,0,0]

线段树中

seg = [0,0,0,1,1,0]

接下来是0号位置的5, 在线段树中,查到右侧有两个标记过的,说明有两个比它小的数,直接在 ans 中把 0 号位置设置为 2, 然后在线段树中把 0 号位置标记为 1 ,说明处理过,此时

ans = [2,0,0,1,0,0]

线段树中

seg = [1,0,0,1,1,0]

接下来是 2 号位置的 7, 在线段树中,查到右侧有两个标记过的,说明有两个比它小的数,直接在 ans 中把 2 号位置设置为 2, 然后在线段树中把 2 号位置标记为 1 ,说明处理过,此时

ans = [2,0,2,1,0,0]

线段树中

seg = [1,0,1,1,1,0]

接下来是 1 号位置的 8, 在线段树中,查到右侧有三个标记过的,说明有三个比它小的数,直接在 ans 中把 1 号位置设置为 3, 然后在线段树中把 1 号位置标记为 1 ,说明处理过,此时

ans = [2,3,2,1,0,0]

线段树中

seg = [1,1,1,1,1,0]

接下来是 5 号位置的 9, 在线段树中,查到右侧没有标记过的,说明没有比它小的数,直接在 ans 中把 5 号位置设置为 0, 然后在线段树中把 5 号位置标记为 1 ,说明处理过,此时

ans = [2,3,2,1,0,0]

线段树中

seg = [1,1,1,1,1,1]

以上就是整个流程。

核心代码如下

  public static List<Integer> countSmaller(int[] nums) {
    if (nums == null || nums.length == 0) {
      return new ArrayList<>();
    }
    int n = nums.length;
    List<Integer> ans = new ArrayList<>(n);
    for (int i = 0; i < n; i++) {
      ans.add(0);
    }
    int[][] arr = new int[n][];
    for (int i = 0; i < n; i++) {
        // 要记录值,也要记录位置,防止排序后找不到值对应的位置在哪里
      arr[i] = new int[] {nums[i], i};
    }
    Arrays.sort(arr, Comparator.comparingInt(a -> a[0]));
    DynamicSegmentTree dst = new DynamicSegmentTree(n);
    for (int[] num : arr) {
      ans.set(num[1], dst.query(num[1] + 1, n));
      dst.add(num[1] + 1, 1);
    }
    return ans;
  }

其中 DynamicSegmentTree 结构就是前面提到的动态开点线段树的实现。

更多

算法和数据结构笔记