Leetcode 3569. Maximize Count of Distinct Primes After Split

发布于:2025-06-02 ⋅ 阅读:(28) ⋅ 点赞:(0)

1. 解题思路

这一题的话思路倒是还好,显然,要找出所有distinct的质数的切分,我们首先就是先将 10 5 10^5 105以内的所有质数找出来,这样,我们就能够快速的找出原始的数组当中的所有的质数的位置了。

然后,我们考察最优的切分方式,注意到:

  • 所有仅出现一次的质数,无论怎么切分,其贡献值都为一;
  • 对于出现过多次的质数,考察其第一次出现的位置与最后一次出现的位置,我们只要在其中间任意位置切一刀,则其贡献值就会由1变成2。

因此,这道题事实上也就变成了,给定若干个区域 [ l i , r i ] [l_i, r_i] [li,ri],找出一个位置 p p p,使之同时存在的区域最多。这个的话我们只需要依次将所有的 [ l i , r i ] [l_i, r_i] [li,ri]上的元素加一,然后考察线段当中的最大值即可。

但是,我们需要频繁地对对元素进行修改以及query,因此我们需要不断地修改整段区间 [ l k , r k ] [l_k, r_k] [lk,rk]上的值,然后再去进行query,这个就是一个标准的Lazy Segment Tree的题目了,虽然我还是不能熟练地写出对应的代码,不过相关的内容网上多的是,deepseek也能很快速地给出对应的代码实现,所以这里就不过多赘述了。

事实上,对应的Lazy Segment Tree的算法实现,我也是直接让deepseek帮我写作完成的……

2. 代码实现

给出python代码实现如下:

def get_primes(n):
    status = [0 for _ in range(n+1)]
    primes = set()
    for i in range(2, n+1):
        if status[i] != 0:
            continue
        primes.add(i)
        for j in range(i, n+1, i):
            status[j] = 1
    return primes

PRIMES = get_primes(10**5)

class LazySegmentTree:
    def __init__(self, arr):
        """
        根据数组 arr 构建惰性线段树
        :param arr: 输入数组
        """
        self.n = len(arr)
        self.arr = arr
        # 初始化线段树和惰性标记数组(4倍原始数组大小)
        self.size = 4 * self.n
        self.tree = [-float('inf')] * self.size  # 存储区间最大值
        self.lazy = [0] * self.size             # 存储惰性标记
        self._build(0, 0, self.n - 1)           # 从根节点开始建树

    def _build(self, node, start, end):
        """
        递归构建线段树
        :param node: 当前节点索引
        :param start: 当前节点表示的区间起始索引
        :param end: 当前节点表示的区间结束索引
        """
        if start == end:
            # 叶子节点,直接存储数组值
            self.tree[node] = self.arr[start]
            return
        
        mid = (start + end) // 2
        left_node = 2 * node + 1  # 左子节点索引
        right_node = 2 * node + 2  # 右子节点索引
        
        # 递归构建左右子树
        self._build(left_node, start, mid)
        self._build(right_node, mid + 1, end)
        
        # 当前节点值为左右子树的最大值
        self.tree[node] = max(self.tree[left_node], self.tree[right_node])

    def _push_down(self, node, start, end):
        """
        下推惰性标记到子节点
        :param node: 当前节点索引
        :param start: 当前节点表示的区间起始索引
        :param end: 当前节点表示的区间结束索引
        """
        if self.lazy[node] != 0:
            # 更新当前节点的值
            self.tree[node] += self.lazy[node]
            
            if start != end:  # 非叶子节点,标记下推
                left_node = 2 * node + 1
                right_node = 2 * node + 2
                # 将惰性标记添加到子节点
                self.lazy[left_node] += self.lazy[node]
                self.lazy[right_node] += self.lazy[node]
            
            # 清除当前节点的惰性标记
            self.lazy[node] = 0

    def update(self, l, r, val):
        """
        将闭区间 [l, r] 内的所有元素增加 val
        :param l: 区间左边界
        :param r: 区间右边界
        :param val: 要增加的值
        """
        self._update(0, 0, self.n - 1, l, r, val)

    def _update(self, node, start, end, l, r, val):
        """
        递归执行区间更新
        :param node: 当前节点索引
        :param start: 当前节点表示的区间起始索引
        :param end: 当前节点表示的区间结束索引
        :param l: 更新区间左边界
        :param r: 更新区间右边界
        :param val: 要增加的值
        """
        # 先下推当前节点的惰性标记
        self._push_down(node, start, end)
        
        # 当前节点区间与更新区间无交集
        if start > r or end < l:
            return
        
        # 当前节点区间完全包含在更新区间内
        if l <= start and end <= r:
            # 更新当前节点值
            self.tree[node] += val
            if start != end:  # 非叶子节点,更新子节点惰性标记
                left_node = 2 * node + 1
                right_node = 2 * node + 2
                self.lazy[left_node] += val
                self.lazy[right_node] += val
            return
        
        # 部分重叠,递归更新子区间
        mid = (start + end) // 2
        left_node = 2 * node + 1
        right_node = 2 * node + 2
        
        self._update(left_node, start, mid, l, r, val)
        self._update(right_node, mid + 1, end, l, r, val)
        
        # 更新当前节点值为左右子树的最大值
        self.tree[node] = max(self.tree[left_node], self.tree[right_node])

    def query(self, l, r):
        """
        查询闭区间 [l, r] 内的最大元素值
        :param l: 查询区间左边界
        :param r: 查询区间右边界
        :return: 区间内的最大值
        """
        return self._query(0, 0, self.n - 1, l, r)

    def _query(self, node, start, end, l, r):
        """
        递归执行区间查询
        :param node: 当前节点索引
        :param start: 当前节点表示的区间起始索引
        :param end: 当前节点表示的区间结束索引
        :param l: 查询区间左边界
        :param r: 查询区间右边界
        :return: 区间内的最大值
        """
        # 先下推当前节点的惰性标记
        self._push_down(node, start, end)
        
        # 当前节点区间与查询区间无交集
        if start > r or end < l:
            return -float('inf')
        
        # 当前节点区间完全包含在查询区间内
        if l <= start and end <= r:
            return self.tree[node]
        
        # 部分重叠,递归查询子区间
        mid = (start + end) // 2
        left_node = 2 * node + 1
        right_node = 2 * node + 2
        
        left_max = self._query(left_node, start, mid, l, r)
        right_max = self._query(right_node, mid + 1, end, l, r)
        
        return max(left_max, right_max)


class Solution:
    def maximumCount(self, nums: List[int], queries: List[List[int]]) -> List[int]:
        primes_locs = defaultdict(list)
        for i, x in enumerate(nums):
            if x in PRIMES:
                primes_locs[x].append(i)

        n = len(nums)
        segment_tree = LazySegmentTree([0 for _ in range(n)])
        for p, locs in primes_locs.items():
            if len(locs) > 1:
                segment_tree.update(locs[0], locs[-1], 1)

        def query(idx, val):
            nonlocal primes_locs, segment_tree, nums
            old_val = nums[idx]
            nums[idx] = val
            if old_val == val:
                return len(primes_locs) + segment_tree.query(0, n-1)
            if old_val in PRIMES:
                l, r = primes_locs[old_val][0], primes_locs[old_val][-1]
                primes_locs[old_val].pop(bisect.bisect_left(primes_locs[old_val], idx))
                if l == r:
                    primes_locs.pop(old_val)
                    pass
                elif len(primes_locs[old_val]) == 1:
                    segment_tree.update(l, r, -1)
                elif l == idx:
                    new_l = primes_locs[old_val][0]
                    segment_tree.update(idx, new_l-1, -1)
                elif r == idx:
                    new_r = primes_locs[old_val][-1]
                    segment_tree.update(new_r+1, idx, -1)
            if val in PRIMES:
                if val not in primes_locs:
                    primes_locs[val] = [idx]
                else:
                    l, r = primes_locs[val][0], primes_locs[val][-1]
                    bisect.insort(primes_locs[val], idx)
                    if len(primes_locs[val]) == 2:
                        l, r = primes_locs[val][0], primes_locs[val][-1]
                        segment_tree.update(l, r, 1)
                    elif idx < l:
                        segment_tree.update(idx, l-1, 1)
                    elif idx > r:
                        segment_tree.update(r+1, idx, 1)
            return len(primes_locs) + segment_tree.query(0, n-1)

        return [query(idx, val) for idx, val in queries]

提交代码评测得到:耗时7026ms,占用内存46.94MB。


网站公告

今日签到

点亮在社区的每一天
去签到