Segment Tree 线段树
线段树参考整理自:九章算法线段树教程
1 线段树是什么?
Google常考数据结构,国内也经常问这个。
- 线段树是一种高级数据结构,也是一种树结构,准确的说是二叉树,它能够高效的处理区间修改查询等问题。
- 线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
- 线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。
2 线段树的创建
因为每次将区间的长度一分为二,所有创造的节点个数,即底层有n个节点,那么倒数第二次约n/2个节点,倒数第三次约n/4个节点,依次类推:
n + 1/2 * n + 1/4 * n + 1/8 * n + ...
= (1 + 1/2 + 1/4 + 1/8 + ...) * n
= 2n
所以构造线段树的时间复杂度和空间复杂度都为O(n)
LintCode 201. Segment Tree Build
线段树的创建,其实就是按照区间的index进行二分,然后Recursion的定义!
"""
Definition of SegmentTreeNode:
class SegmentTreeNode:
def __init__(self, start, end):
self.start, self.end = start, end
self.left, self.right = None, None
"""
class Solution:
def build(self, start, end):
if start > end: return
root = SegmentTreeNode(start, end)
if start == end: return root
mid = (start + end ) / 2
root.left = self.build(start, mid)
root.right = self.build(mid + 1, end)
return root
LintCode 439. Segment Tree Build II
每个节点node除了有区间index的信息外,还包括其他信息,比如区间内的最大值。Node(start, end, val)
"""
Definition of SegmentTreeNode:
class SegmentTreeNode:
def __init__(self, start, end, max):
self.start, self.end, self.max = start, end, max
self.left, self.right = None, None
"""
class Solution:
def build(self, A):
if not A: return
return self.buildTree(0, len(A)-1, A)
def buildTree(self, start, end, A):
if start > end: return
root = SegmentTreeNode(start, end, A[start])
if start == end: return root
mid = (start + end) / 2
root.left = self.buildTree(start, mid, A)
root.right = self.buildTree(mid+1, end, A)
# Post Order update Max
if root.left and root.left.max > root.max:
root.max = root.left.max
if root.right and root.right.max > root.max:
root.max = root.right.max
return root
如果需要区间的最小值:
root.min = Math.min(root.left.min, root.right.min);
如果需要区间的和:
root.sum = root.left.sum + root.right.sum;
3 线段树的更新
更新是从叶子节点一路走到根节点, 去更新线段树上的值。因为线段树的高度为log(n),所以更新序列中一个节点的复杂度为log(n)。
LintCode 203. Segment Tree Modify
给一个Maximum Segment Tree, 更新某个index的value。
class Solution:
def modify(self, root, index, value):
if not root: return
if root.start == root.end:
root.val = value
root.max = value
return
mid = (root.start + root.end) / 2
if index <= mid:
self.modify(root.left, index, value)
else:
self.modify(root.right, index, value)
if root.right and root.left:
root.max = max(root.right.max, root.left.max)
4 线段树的查询
构造线段树的目的就是为了更快的查询。
- 给定一个区间,要求区间中最大的值。线段树的区间查询操作就是将当前区间分解为较小的子区间,然后由子区间的最大值就可以快速得到需要查询区间的最大值。
- 任意长度的线段,最多被拆分成logn条线段树上存在的线段,所以查询的时间复杂度为O(log(n))
LintCode 202. Segment Tree Query
class Solution:
def query(self, root, start, end):
if not root: return
if root.start == start and root.end == end:
return root.max
mid = (root.start + root.end) / 2
if start > mid:
return self.query(root.right, start, end)
elif end <= mid:
return self.query(root.left, start, end)
else:
return max(self.query(root.left, start, mid),
self.query(root.right, mid+1, end))
5 线段树的应用
线段树的基本应用
- 支持动态更改数组一个元素的值 O(logn)
- 求区间的和、最大值、最小值 O(logn)
- 创建,更新,求和或者求最大最小,只有这三个function!
307. Range Sum Query - Mutable
区间求和,如果原序列不变的话直接用preSum前缀和就可以了,但是如果序列可变的话update前缀和的复杂度就变成O(n)
# 完整版Segment Tree 充分利用上面的基础操作,结果TLE =。=
# 这里最好用 Binary Indexed Tree
class SegmentTreeNode(object):
def __init__(self, start, end, val):
self.start, self.end = start, end
self.sum = val
self.left, self.right = None ,None
class SegmentTree:
def build(self, start, end, nums):
if start > end:
return None
root = SegmentTreeNode(start, end, nums[start])
if start == end:
return root
mid = (start + end) / 2
root.left = self.build(start, mid, nums)
root.right = self.build(mid+1, end, nums)
root.sum = (root.left.sum if root.left else 0) + (root.right.sum if root.right else 0)
return root
def update(self, root, index, value):
if not root:
return
if root.start == root.end:
root.sum = value
return
mid = (root.start + root.end) / 2
if index <= mid:
self.update(root.left, index, value)
else:
self.update(root.right, index, value)
root.sum = (root.left.sum if root.left else 0) + (root.right.sum if root.right else 0)
def rangeSum(self, root, start, end):
if not root:
return
if root.start == root.end :
return root.sum
mid = (root.start + root.end) / 2
if start > mid:
return self.rangeSum(root.right, start, end)
elif end <= mid:
return self.rangeSum(root.left, start, end)
else:
return self.rangeSum(root.left, start, mid) + self.rangeSum(root.right, mid+1, end)
class NumArray(object):
def __init__(self, nums):
self.stree = SegmentTree()
self.root = self.stree.build(0, len(nums)-1, nums)
def update(self, i, val):
self.stree.update(self.root, i, val)
def sumRange(self, i, j):
return self.stree.rangeSum(self.root, i, j)
LintCode 249. Count of Smaller Number before itsel
统计数组中每个元素后面比自己小的数。
- 初始化[0, max(nums)]的数组全部都是0
- 把数组中的数值当作index,统计数量就转化成了求区间和
其他解法:Binary Indexed Tree,Binary Search Tree
class SegTree:
def __init__(self, start, end):
self.start, self.end = start, end
self.cnt = 0
self.left, self.right = None, None
if start != end:
mid = (start + end) / 2
self.left = SegTree(start, mid)
self.right = SegTree(mid+1, end)
def inc(self, index):
if self.start == self.end:
self.cnt += 1
return
if index <= self.left.end:
self.left.inc(index)
else:
self.right.inc(index)
self.cnt = self.left.cnt + self.right.cnt
def sum(self, start, end):
if start <= self.start and end >= self.end:
return self.cnt
if self.start == self.end:
return 0
if start >= self.right.start:
return self.right.sum(start, end)
elif end <= self.left.end:
return self.left.sum(start, end)
else:
return (self.left.sum(start, self.left.end)
+ self.right.sum(self.right.start, end))
class Solution:
def countOfSmallerNumberII(self, A):
if not A:
return []
root = SegTree(0, max(A))
res = []
for n in A:
res.append(root.sum(0, n-1))
root.inc(n)
return res
315. Count of Smaller Numbers After Self
- 值为负数或者最大值特别大的时候用Segment Tree,val无法作为index,导致 非常不合适!
- 所以这个题得考虑别的解决方案