高级数据结构-树状数组(Fenwick Tree)

发布于:2024-12-18 ⋅ 阅读:(62) ⋅ 点赞:(0)

类模版

def lowerBits(x):
    return x & (-x)


class BinaryIndexTree:
    def __init__(self, size):
        self.size = size
        self.tree = [0] * (size + 1)

    def update(self, index, delta):
        while index <= self.size:
            self.tree[index] += delta
            index += lowerBits(index)

    def query(self, index):
        res = 0
        while index >= 1:
            res += self.tree[index]
            index -= lowerBits(index)
        return res

    def rangeQuery(self, left, right):
        return self.query(right) - self.query(left - 1)

树状数组的结构与使用非常简洁有效。

动机

给出一个数组 a [ n ] a[n] a[n],问题给出的需求为查询下标为 [ left , right ] [\text{left}, \text{right}] [left,right]全闭区间内的数组和。记区间长度为 m m m,则每次查询的时间复杂度为 O ( m ) O(m) O(m),当查询次数 k k k O ( n ) O(n) O(n)级别且区间长度也为 O ( n ) O(n) O(n)级别时。总体时间复杂度为 O ( n 2 ) O(n^2) O(n2)

第一种改进思路
使用前缀和数组 b [ n ] b[n] b[n],在查询时使用差分计算。这样单次查询的时间复杂度为 O ( 1 ) O(1) O(1)
然而我们的问题此时如果提出额外的需求,不光查询,还包含若干施加到数组 a [ n ] a[n] a[n]某元素的修改操作。
记每次修改前元素 a [ i ] a[i] a[i] c i c_i ci,修改后为 c i ′ c_i^{\prime} ci。前缀和数组需要对 b [ j ] , j ≥ i b[j], j \ge i b[j],ji,进行差分更新,每个元素的delta为 c i ′ − c i c_i^{\prime}-c_i cici。即修改操作的时间复杂度为 O ( n − i ) O(n-i) O(ni),若 i i i服从均匀分布,修改操作的时间复杂度为 Θ ( n 2 ) \Theta(\frac{n}{2}) Θ(2n)

树状数组
以上分析指出,在数组既要查询又要修改时需要一种更优秀的数据结构。前缀和数组在一定程度上缓解了问题。只要能改善修改时的时间复杂度即可。修改时时间复杂度高的原因是前缀和数组 b [ i ] b[i] b[i]管辖的是 [ 0 , i ] [0, i] [0,i]的区间和,如果可以有效缩短管辖区间则可以在数组修改时做更少的数据结构存储修改。

树状数组的数据存储结构tree为1-index。每个元素 tree [ i ] \text{tree}[i] tree[i],管辖的 a [ i ] a[i] a[i]
个数为 lowBit ( i ) \text{lowBit}(i) lowBit(i),具体区间为 [ i − lowBit ( i ) + 1 , i ] [i-\text{lowBit}(i)+1, i] [ilowBit(i)+1,i]

Low-Bits
一个非负整数的二进制数制表示中从低位开始数第一个1以及所有低位的0共同构成一个数的低位比特。

此时单次查询或更新操作的时间复杂度均为 O ( log ⁡ ( n ) ) O(\log(n)) O(log(n))(仅证明单点查询)
证明:

  • 查询:
    给出一个索引 i i i,想获得 [ 0 , i ] [0, i] [0,i]区间数组 a [ j ] a[j] a[j]的元素和。 tree [ i ] \text{tree}[i] tree[i]保存了 [ i − lowBit ( i ) + 1 , i ] [i-\text{lowBit}(i)+1, i] [ilowBit(i)+1,i]的区间和,将此区间结果加到res上,迭代 i ′ = i − lowBit ( i ) i^\prime=i-\text{lowBit}(i) i=ilowBit(i)
  1. 首先证明迭代一定停止:
    i − lowBit ( i ) + 1 ≤ i i-\text{lowBit}(i)+1 \le i ilowBit(i)+1i,i为奇数时取等号。 i ′ = i − lowBit ( i ) < i − lowBit ( i ) + 1 ≤ i i^\prime=i-\text{lowBit}(i) < i-\text{lowBit}(i)+1 \le i i=ilowBit(i)<ilowBit(i)+1i,所以迭代严格单调减,又因为 i ≥ 1 i \ge 1 i1,所以有下界。证毕。
  2. 证明操作至多 ⌈ log ⁡ ( n + 1 ) ⌉ \lceil \log(n+1) \rceil log(n+1)⌉次:
    因为每次迭代 i i i减少一个二进制1,而 i i i的二进制1至多 ⌈ log ⁡ ( i + 1 ) ⌉ \lceil \log(i+1) \rceil log(i+1)⌉位, ⌈ log ⁡ ( i + 1 ) ⌉ ≤ ⌈ log ⁡ ( n + 1 ) ⌉ \lceil \log(i+1) \rceil \le \lceil \log(n+1) \rceil log(i+1)⌉log(n+1)⌉。证毕。
  • 更新:
    我们需要更新所有包含了 a [ i ] a[i] a[i] tree [ j ] \text{tree}[j] tree[j],首先 tree [ i ] \text{tree}[i] tree[i]一定包含,因为区间是全闭的。数据是类似前缀和的,因此 j ≥ i j \ge i ji,需要证明递推公式 j ′ = j + lowBit ( j ) j^\prime=j+\text{lowBit}(j) j=j+lowBit(j)的正确性与次数渐进上限为 O ( log ⁡ ( n ) ) O(\log(n)) O(log(n))
  1. 观察到 [ k − lowBit ( k ) + 1 , k ] [k-\text{lowBit}(k)+1, k] [klowBit(k)+1,k] [ k ′ − lowBit ( k ′ ) + 1 , k ′ ] [k^\prime-\text{lowBit}(k^\prime)+1, k^\prime] [klowBit(k)+1,k]的子集,其中 k ′ k^\prime k使用以上公式迭代产生。因为新区间
    [ k + lowBit ( k ) − lowBit ( k + lowBit ( k ) ) + 1 , k + lowBit ( k ) ] [k+\text{lowBit}(k)-\text{lowBit}(k+\text{lowBit}(k))+1, k+\text{lowBit}(k)] [k+lowBit(k)lowBit(k+lowBit(k))+1,k+lowBit(k)]
    k < k + lowBit ( k ) k<k+\text{lowBit}(k) k<k+lowBit(k),且
    ( k + lowBit ( k ) − lowBit ( k + lowBit ( k ) ) + 1 ) − ( k − lowBit ( k ) + 1 ) = 2 lowBit ( k ) − lowBit ( k + lowBit ( k ) ) \begin{align} & (k+\text{lowBit}(k)-\text{lowBit}(k+\text{lowBit}(k))+1)-(k-\text{lowBit}(k)+1) \\ &= 2\text{lowBit}(k)-\text{lowBit}(k+\text{lowBit}(k)) \end{align} (k+lowBit(k)lowBit(k+lowBit(k))+1)(klowBit(k)+1)=2lowBit(k)lowBit(k+lowBit(k))
    k + lowBit ( k ) k+\text{lowBit}(k) k+lowBit(k)会使 k k k lowBit ( k ) \text{lowBit}(k) lowBit(k)至少左移一位,即至少为 2 lowBit ( k ) 2\text{lowBit}(k) 2lowBit(k),遇到进位情况则会更大。因此 ( 2 ) (2) (2)式小于等于 0 0 0
    因此以上迭代公式涉及的 k k k确实都需要更新。
    接着需要证明没有任何包含 a [ i ] a[i] a[i] tree [ j ] \text{tree}[j] tree[j]被落下。上述证明表明使用递推公式可以找出一个数列 tree [ j k ] \text{tree}[j_k] tree[jk],这个数列的所有项是需要更新的项,假设存在其它包含 a [ i ] a[i] a[i] tree [ j ] \text{tree}[j] tree[j]被落下,按照递推序,我们可以定义这些被落下的项组成数列的首项,若包含多组数列,将首项中下标最小的取出,证明这项也已经被包含即可。记这项为 tree [ i + d ] , d > 0 \text{tree}[i+d], d>0 tree[i+d],d>0
    明显 tree [ i ] \text{tree}[i] tree[i]是按照下标排序第一个包含 a [ i ] a[i] a[i]的项,且已经包含在递推公式中。则有 i + d − lowBit ( i + d ) + 1 < i < i + d i+d-\text{lowBit}(i+d)+1< i < i+d i+dlowBit(i+d)+1<i<i+d,第一个不等式可以推出 d + 1 < lowBit ( i + d ) d +1< \text{lowBit}(i+d) d+1<lowBit(i+d)
    lowBit ( i ) ≠ lowBit ( d ) \text{lowBit}(i) \ne \text{lowBit}(d) lowBit(i)=lowBit(d)时, lowBit ( i + d ) = min ⁡ { lowBit ( i ) , lowBit ( d ) } \text{lowBit}(i+d) = \min \{ \text{lowBit}(i), \text{lowBit}(d) \} lowBit(i+d)=min{lowBit(i),lowBit(d)},第一种情况 d + 1 < lowBit ( d ) ≤ d d+1<\text{lowBit}(d) \le d d+1<lowBit(d)d lowBit ( d ) < lowBit ( i ) \text{lowBit}(d)<\text{lowBit}(i) lowBit(d)<lowBit(i)矛盾,第二种情况 d + 1 < lowBit ( i ) d+1<\text{lowBit}(i) d+1<lowBit(i) lowBit ( i ) < lowBit ( d ) \text{lowBit}(i)<\text{lowBit}(d) lowBit(i)<lowBit(d),即 d + 1 < lowBit ( d ) ≤ d d+1<\text{lowBit}(d) \le d d+1<lowBit(d)d矛盾。
    因此 lowBit ( i ) = lowBit ( d ) \text{lowBit}(i) = \text{lowBit}(d) lowBit(i)=lowBit(d),由于一定有 lowBit ( d ) ≤ d \text{lowBit}(d) \le d lowBit(d)d,推出 lowBit ( i ) ≤ d \text{lowBit}(i) \le d lowBit(i)d。所以满足条件的最小 d d d lowBit ( i ) \text{lowBit}(i) lowBit(i),即 tree [ i + d ] \text{tree}[i+d] tree[i+d]是递推公式中包含的项,与假设矛盾。

Tips: 数学归纳法正向推导也是可行的。

  1. j ′ = j + lowBit ( j ) , j ≤ n j^\prime=j+\text{lowBit}(j), j \le n j=j+lowBit(j),jn迭代至多 O ( log ⁡ ( n ) ) O(\log(n)) O(log(n))次。因为每次会使 j j j lowBit ( j ) \text{lowBit}(j) lowBit(j)至少左移一位,总共只有 O ( log ⁡ ( n ) ) O(\log(n)) O(log(n))位。
    审讯树也是可以用来证明与理解树状数组的方式。

离散化树状数组

在例如扫描线等问题中,数据定义在无限二维平面上,某维度的值可能取值范围很大,但是数据相对来说较为稀疏。可以将所有值按照序进行排序压缩映射到一个更小的范围来节约树状数组的空间。一般实际应用都是使用状态压缩了的离散化树状数组。

解题应用

一维离散化问题

LCR 170. 交易逆序对的总数
常规解法:

class Solution:
    def reversePairs(self, record: List[int]) -> int:        
        n=len(record)
        if n==0:
            return 0
        record_=record.copy()
        record_.sort()
        record_compression = {record_[0]:1}
        r_idx=2
        for r in record_[1:]:
            if record_compression.get(r, None):
                pass 
            else:
                record_compression[r]=r_idx 
                r_idx+=1 
        bit = BinaryIndexTree(len(record_compression))
        res=0
        # for i in range(n-1, -1, -1):
        #     r=record[i]
        #     res+=(n-1-i-bit.rangeQuery(record_compression[r], len(record_compression)))
        #     bit.update(record_compression[r])
        bit.update(record_compression[record[0]])
        for i in range(1, n):
            r=record[i]
            res+=(i-bit.query(record_compression[r]))
            bit.update(record_compression[r])
        return res

更巧妙解法:
从后向前遍历,出现了更小的就是逆序。可以节省一部分减法。

class Solution:
    def reversePairs(self, record: List[int]) -> int:        
        n=len(record)
        if n==0:
            return 0
        record_=record.copy()
        record_.sort()
        record_compression = {record_[0]:1}
        r_idx=2
        for r in record_[1:]:
            if record_compression.get(r, None):
                pass 
            else:
                record_compression[r]=r_idx 
                r_idx+=1 
        bit = BinaryIndexTree(len(record_compression))
        res=0

        for i in range(n-1, -1, -1):
            r=record[i]
            res+=bit.query(record_compression[r]-1)
            bit.update(record_compression[r])
        return res

扫描线配合树状数组解决计算几何问题(二维)

3382. 用点构造面积最大的矩形 II

class Solution:
    def maxRectangleArea(self, x_coords: List[int], y_coords: List[int]) -> int:
        def get_lowest_bit(x):
            return x & -x

        class FenwickTree:
            def __init__(self, size):
                self.tree = [0] * (size + 1)
                self.size = size

            def update(self, index, delta=1):
                while index <= self.size:
                    self.tree[index] += delta
                    index += get_lowest_bit(index)

            def query(self, index):
                result = 0
                while index >= 1:
                    result += self.tree[index]
                    index -= get_lowest_bit(index)
                return result

        # Compress y-coordinates to a unique index
        y_compression = {}
        max_area = -1
        events = [list(event) for event in zip(x_coords, y_coords)]
        events.sort(key=lambda x: (x[0], x[1]))
        y_values = [events[0][1]]
        vertical_lines = [[events[0]]]

        for event in events[1:]:
            y_values.append(event[1])
            if event[0] != vertical_lines[-1][0][0]:
                vertical_lines.append([event])
            else:
                vertical_lines[-1].append(event)
        y_values.sort()

        # Assign compressed indices to y-values
        y_index = 1
        for i, y in enumerate(y_values):
            if y not in y_compression:
                y_compression[y] = y_index
                y_index += 1

        bit_tree = FenwickTree(len(y_compression))
        y_pair_count_map = {}

        bit_tree.update(y_compression[vertical_lines[0][0][1]])
        for i, point in enumerate(vertical_lines[0][1:], start=1):
            ya = vertical_lines[0][i - 1][1]
            yb = point[1]
            bit_tree.update(y_compression[yb])
            cnt = bit_tree.query(y_compression[yb]) - bit_tree.query(y_compression[ya] - 1)
            y_pair_count_map[(ya, yb)] = [cnt, vertical_lines[0][i - 1][0]]

        for j, vertical in enumerate(vertical_lines[1:], start=1):
            bit_tree.update(y_compression[vertical[0][1]])
            for i, point in enumerate(vertical[1:], start=1):
                ya = vertical[i - 1][1]
                yb = point[1]
                bit_tree.update(y_compression[yb])
                cnt = bit_tree.query(y_compression[yb]) - bit_tree.query(y_compression[ya] - 1)
                if (ya, yb) in y_pair_count_map:
                    if cnt - y_pair_count_map[(ya, yb)][0] == 2:
                        current_area = (yb - ya) * (point[0] - y_pair_count_map[(ya, yb)][1])
                        max_area = max(max_area, current_area)
                y_pair_count_map[(ya, yb)] = [cnt, point[0]]
        return max_area


网站公告

今日签到

点亮在社区的每一天
去签到