Segment Tree 线段树

线段树参考整理自:九章算法线段树教程

1 线段树是什么?

Google常考数据结构,国内也经常问这个。

  • 线段树是一种高级数据结构,也是一种树结构,准确的说是二叉树,它能够高效的处理区间修改查询等问题。
  • 线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点
  • 线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2]右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。

2 线段树的创建

因为每次将区间的长度一分为二,所有创造的节点个数,即底层有n个节点,那么倒数第二次约n/2个节点,倒数第三次约n/4个节点,依次类推:

n + 1/2 * n + 1/4 * n + 1/8 * n + ...
=   (1 + 1/2 + 1/4 + 1/8 + ...) * n
=   2n

所以构造线段树的时间复杂度和空间复杂度都为O(n)

LintCode 201. Segment Tree Build

线段树的创建,其实就是按照区间的index进行二分,然后Recursion的定义!

"""
Definition of SegmentTreeNode:
class SegmentTreeNode:
    def __init__(self, start, end):
        self.start, self.end = start, end
        self.left, self.right = None, None
"""

class Solution:
    def build(self, start, end):
        if start > end: return 

        root = SegmentTreeNode(start, end)
        if start == end: return root

        mid = (start + end ) / 2
        root.left = self.build(start, mid)
        root.right = self.build(mid + 1, end)
        return root

LintCode 439. Segment Tree Build II

每个节点node除了有区间index的信息外,还包括其他信息,比如区间内的最大值。Node(start, end, val)

"""
Definition of SegmentTreeNode:
class SegmentTreeNode:
    def __init__(self, start, end, max):
        self.start, self.end, self.max = start, end, max
        self.left, self.right = None, None
"""
class Solution:
    def build(self, A):
        if not A: return
        return self.buildTree(0, len(A)-1, A)

    def buildTree(self, start, end, A):
        if start > end: return

        root = SegmentTreeNode(start, end, A[start])
        if start == end: return root

        mid = (start + end) / 2
        root.left = self.buildTree(start, mid, A)
        root.right = self.buildTree(mid+1, end, A)

        # Post Order update Max
        if root.left and root.left.max > root.max:
            root.max = root.left.max
        if root.right and root.right.max > root.max:
            root.max = root.right.max

        return root

如果需要区间的最小值:

root.min = Math.min(root.left.min, root.right.min);

如果需要区间的和:

root.sum = root.left.sum + root.right.sum;

3 线段树的更新

更新是从叶子节点一路走到根节点, 去更新线段树上的值。因为线段树的高度为log(n),所以更新序列中一个节点的复杂度为log(n)。

LintCode 203. Segment Tree Modify

给一个Maximum Segment Tree, 更新某个index的value。

class Solution:
    def modify(self, root, index, value):
        if not root: return

        if root.start == root.end:
            root.val = value
            root.max = value
            return 

        mid = (root.start + root.end) / 2
        if index <= mid:
            self.modify(root.left, index, value)
        else:
            self.modify(root.right, index, value)

        if root.right and root.left:
            root.max = max(root.right.max, root.left.max)

4 线段树的查询

构造线段树的目的就是为了更快的查询

  • 给定一个区间,要求区间中最大的值。线段树的区间查询操作就是将当前区间分解为较小的子区间,然后由子区间的最大值就可以快速得到需要查询区间的最大值。
  • 任意长度的线段,最多被拆分成logn条线段树上存在的线段,所以查询的时间复杂度为O(log(n))

LintCode 202. Segment Tree Query

class Solution:
    def query(self, root, start, end):
        if not root: return 

        if root.start == start and root.end == end:
            return root.max

        mid = (root.start + root.end) / 2
        if start > mid:
            return self.query(root.right, start, end)
        elif end <= mid:
            return self.query(root.left, start, end)
        else:
            return max(self.query(root.left, start, mid), 
                        self.query(root.right, mid+1, end))

5 线段树的应用

线段树的基本应用

  • 支持动态更改数组一个元素的值 O(logn)
  • 区间的和、最大值、最小值 O(logn)
  • 创建,更新,求和或者求最大最小,只有这三个function!

307. Range Sum Query - Mutable

区间求和,如果原序列不变的话直接用preSum前缀和就可以了,但是如果序列可变的话update前缀和的复杂度就变成O(n)

# 完整版Segment Tree 充分利用上面的基础操作,结果TLE =。=
# 这里最好用 Binary Indexed Tree
class SegmentTreeNode(object):
    def __init__(self, start, end, val):
        self.start, self.end = start, end
        self.sum = val
        self.left, self.right = None ,None

class SegmentTree:

    def build(self, start, end, nums):
        if start > end: 
            return None
        root = SegmentTreeNode(start, end, nums[start])
        if start == end:
            return root
        mid = (start + end) / 2
        root.left = self.build(start, mid, nums)
        root.right = self.build(mid+1, end, nums)
        root.sum = (root.left.sum if root.left else 0) + (root.right.sum if root.right else 0)
        return root

    def update(self, root, index, value):
        if not root:
            return

        if root.start == root.end:
            root.sum = value
            return 

        mid = (root.start + root.end) / 2
        if index <= mid:
            self.update(root.left, index, value)
        else:
            self.update(root.right, index, value)

        root.sum = (root.left.sum if root.left else 0) + (root.right.sum if root.right else 0)

    def rangeSum(self, root, start, end):
        if not root:
            return

        if root.start == root.end :
            return root.sum

        mid = (root.start + root.end) / 2
        if start > mid:
            return self.rangeSum(root.right, start, end)
        elif end <= mid:
            return self.rangeSum(root.left, start, end)
        else:
            return self.rangeSum(root.left, start, mid) + self.rangeSum(root.right, mid+1, end)

class NumArray(object):

    def __init__(self, nums):
        self.stree = SegmentTree()
        self.root = self.stree.build(0, len(nums)-1, nums)

    def update(self, i, val):
        self.stree.update(self.root, i, val)


    def sumRange(self, i, j):
        return self.stree.rangeSum(self.root, i, j)

LintCode 249. Count of Smaller Number before itsel

统计数组中每个元素后面比自己小的数。

  • 初始化[0, max(nums)]的数组全部都是0
  • 把数组中的数值当作index,统计数量就转化成了求区间和

其他解法:Binary Indexed Tree,Binary Search Tree

class SegTree:

    def __init__(self, start, end):
        self.start, self.end = start, end
        self.cnt = 0
        self.left, self.right = None, None
        if start != end:
            mid = (start + end) / 2
            self.left = SegTree(start, mid)
            self.right = SegTree(mid+1, end)

    def inc(self, index):
        if self.start == self.end:
            self.cnt += 1
            return

        if index <= self.left.end:
            self.left.inc(index)
        else:
            self.right.inc(index)

        self.cnt = self.left.cnt + self.right.cnt

    def sum(self, start, end):
        if start <= self.start and end >= self.end:
            return self.cnt

        if self.start == self.end:
            return 0

        if start >= self.right.start:
            return self.right.sum(start, end)
        elif end <= self.left.end:
            return self.left.sum(start, end)
        else:
            return (self.left.sum(start, self.left.end)
                    + self.right.sum(self.right.start, end))

class Solution:
    def countOfSmallerNumberII(self, A):
        if not A:
            return []

        root = SegTree(0, max(A))
        res = []
        for n in A:
            res.append(root.sum(0, n-1))
            root.inc(n)
        return res

315. Count of Smaller Numbers After Self

  • 值为负数或者最大值特别大的时候用Segment Tree,val无法作为index,导致 非常不合适!
  • 所以这个题得考虑别的解决方案

results matching ""

    No results matching ""