LeetCode第 454 场周赛题解

发布于:2025-06-16 ⋅ 阅读:(17) ⋅ 点赞:(0)

题目地址

https://leetcode.cn/contest/weekly-contest-454/

锐评

参赛人数不足2k,跟巅峰时期没法比。随着AI加入,感觉会越来越多人弃坑。

题目难度适中。每一题基本上思路还算明显,就是细节上可能有些小坑,看榜上一片WA。可惜了第四题,思路是对的,赛时一直没调出来,错在了两个点,一是树节点深度跟距离用混了,二是倍增时,右侧方向枚举反了。

题解

Q1. 为视频标题生成标签

题意

给你一个字符串 caption,表示一个视频的标题。

需要按照以下步骤 按顺序 生成一个视频的 有效标签 :

  1. 所有单词 组合为单个 驼峰命名字符串 ,并在前面加上 ‘#’。驼峰命名字符串 指的是除第一个单词外,其余单词的首字母大写,且每个单词的首字母之后的字符必须是小写。

  2. 移除 所有不是英文字母的字符,但 保留 第一个字符 ‘#’。

  3. 将结果 截断 为最多 100 个字符。

对 caption 执行上述操作后,返回生成的 标签 。

示例
示例1
输入: caption = "Leetcode daily streak achieved"
输出: "#leetcodeDailyStreakAchieved"
说明

除了 “leetcode” 以外的所有单词的首字母需要大写。

示例2
输入: caption = "can I Go There"
输出: "#canIGoThere"
说明

除了 “can” 以外的所有单词的首字母需要大写。

示例3
输入: caption = "hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh"
输出: "#hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh"
说明

由于第一个单词长度为 101,因此需要从单词末尾截去最后两个字符。

提示
  • 1 < = c a p t i o n . l e n g t h < = 150 1 <= caption.length <= 150 1<=caption.length<=150
  • caption 仅由英文字母和 ’ ’ 组成。
解题思路:模拟

简单题。按照题意模拟即可,时间复杂度为 O ( n ) O(n) O(n)

PS:可能会有空字符串,需要注意下。幸好我的解法不需要考虑这种case。赛时差点切Java,好在最后没有,不然可能WA一发。逃)

参考代码(C++)
class Solution {
public:
    string generateTag(string caption) {
        string ans;
        ans.push_back('#');
        bool f = false;
        int n = caption.size();
        for (int i = 0; i < n; ++i)
            if (isalpha(caption[i])) {
                if (!f)
                    ans.push_back(tolower(caption[i]));
                else {
                    if (i - 1 >= 0 && !isalpha(caption[i - 1]))
                        ans.push_back(toupper(caption[i]));
                    else
                        ans.push_back(tolower(caption[i]));
                }
                f = true;
            }
        return ans.substr(0, min(int(ans.size()), 100));
    }
};

Q2. 统计特殊三元组

题意

给你一个整数数组 nums。

特殊三元组 定义为满足以下条件的下标三元组 (i, j, k):

  • 0 <= i < j < k < n,其中 n = nums.length
  • nums[i] == nums[j] * 2
  • nums[k] == nums[j] * 2

返回数组中 特殊三元组 的总数。

由于答案可能非常大,请返回结果对 10 9 + 7 10^9 + 7 109+7 取余数后的值。

示例
示例1
输入: nums = [6,3,6]
输出: 1
说明

唯一的特殊三元组是 (i, j, k) = (0, 1, 2),其中:

  • nums[0] = 6, nums[1] = 3, nums[2] = 6
  • nums[0] = nums[1] * 2 = 3 * 2 = 6
  • nums[2] = nums[1] * 2 = 3 * 2 = 6
示例2
输入: nums = [0,1,0,0]
输出: 1
说明

唯一的特殊三元组是 (i, j, k) = (0, 2, 3),其中:

  • nums[0] = 0, nums[2] = 0, nums[3] = 0
  • nums[0] = nums[2] * 2 = 0 * 2 = 0
  • nums[3] = nums[2] * 2 = 0 * 2 = 0
示例3
输入: nums = [8,4,2,8,4]
输出: 2
说明

共有两个特殊三元组:

  • (i, j, k) = (0, 1, 3)
  • nums[0] = 8, nums[1] = 4, nums[3] = 8
  • nums[0] = nums[1] * 2 = 4 * 2 = 8
  • nums[3] = nums[1] * 2 = 4 * 2 = 8
  • (i, j, k) = (1, 2, 4)
  • nums[1] = 4, nums[2] = 2, nums[4] = 4
  • nums[1] = nums[2] * 2 = 2 * 2 = 4
  • nums[4] = nums[2] * 2 = 2 * 2 = 4
提示
  • 3 < = n = = n u m s . l e n g t h < = 10 5 3 <= n == nums.length <= 10^5 3<=n==nums.length<=105
  • 0 < = n u m s [ i ] < = 10 5 0 <= nums[i] <= 10^5 0<=nums[i]<=105
解题思路:计数+前后缀分解

中等题。根据题目中的式子,很明显,我们可以枚举中间的位置,然后计算出它前面和后面满足条件的数的个数,根据 乘法原理 将二者乘起来计入答案即可,时间复杂度为 O ( n ) O(n) O(n)

参考代码(C++)
class Solution {
    const int mod = 1'000'000'007;
public:
    int specialTriplets(vector<int>& nums) {
        unordered_map<int, int> prec, sufc;
        for (int& x : nums)
            ++sufc[x];
        int ans = 0;
        for (int& x : nums) {
            --sufc[x];
            int y = x << 1;
            if (prec.count(y) && sufc.count(y)) {
                ans += 1LL * prec[y] * sufc[y] % mod;
                ans %= mod;
            }
            ++prec[x];
        }
        return ans;
    }
};

Q3. 子序列首尾元素的最大乘积

题意

给你一个整数数组 nums 和一个整数 m。

返回任意大小为 m 的 子序列 中首尾元素乘积的最大值

子序列 是可以通过删除原数组中的一些元素(或不删除任何元素),且不改变剩余元素顺序而得到的数组。

示例
示例1
输入: nums = [-1,-9,2,3,-2,-3,1], m = 1
输出: 81
说明

子序列 [-9] 的首尾元素乘积最大:-9 * -9 = 81。因此,答案是 81。

示例2
输入: nums = [1,3,-5,5,6,-4], m = 3
输出: 20
说明

子序列 [-5, 6, -4] 的首尾元素乘积最大。

示例3
输入: nums = [2,-1,2,-6,5,2,-5,7], m = 2
输出: 35
说明

子序列 [5, 7] 的首尾元素乘积最大。

提示
  • 1 < = n u m s . l e n g t h < = 10 5 1 <= nums.length <= 10^5 1<=nums.length<=105
  • − 10 5 < = n u m s [ i ] < = 10 5 -10^5 <= nums[i] <= 10^5 105<=nums[i]<=105
  • 1 < = m < = n u m s . l e n g t h 1 <= m <= nums.length 1<=m<=nums.length
解题思路:线段树/滑动窗口

中等题。根据题意,既然是首尾,那么我们可以枚举首/尾,然后去找尾/首。因为子序列长度要为 m,题目又没要求所有的子序列,因此只要满足条件即可。

显然,第一个满足条件的位置要基于当前位置滑动 m - 1 个位置,这时恰好有 m 个元素。进而基于当前位置滑动 m 个以上位置的地方仍然可作为首/尾,因为此时该区间元素个数已经大于 m 了,固定首尾,中间随便选择 m - 2 个位置即可。

那么答案怎么来呢?题目要求乘积最大,而根据提示数据有正有负,对于我们枚举的这个首/尾,怎么样才可能有最大值呢?显然,要么乘以一个最大值,要么乘以一个最小值。

根据上面的分析,那么问题就转化为:对于每个位置,求符合条件的某个位置区间的最大最小值。这个问题解法有很多,例如 ST表/线段树 等。赛时,为了快速A掉,直接上了线段树,时间复杂度为 O ( n l o g n ) O(nlogn) O(nlogn)

但其实你会发现这题区间是随着枚举位置的移动只会变大的,每移动一个位置,区间边界扩大一。因此,完全没必要用 线段树 。直接用一个 有序集合 就可以了,不断扩张该集合,每次取一下最大最小值更新答案,时间复杂度为 O ( n l o g n ) O(nlogn) O(nlogn)

再仔细想一下,你会发现 有序集合 也是多余的,我只需要知道当前满足条件的区间最大最小值是多少即可,因此,每扩展一个区间端点,更新一下当前区间最大最小值,时间复杂度为 O ( n ) O(n) O(n)

参考代码(C++)

线段树版本

using ll = long long;
const int maxn = 100'005;
const ll inf = 0x3f3f3f3f3f3f3f3fLL;
const ll mod = inf;
const ll maxvi = -inf;
const ll minvi = inf;
const ll setvi = -1;
const ll addvi = 0;
const ll revvi = 0;
struct seg_info {
    ll sumv, maxv, minv;

    seg_info() {
        sumv = 0;
        maxv = maxvi;
        minv = minvi;
    }

    seg_info(ll sumv, ll maxv, ll minv): sumv(sumv), maxv(maxv), minv(minv) {}

    seg_info operator + (const seg_info& p) const {
        seg_info ans;
        ans.sumv = (sumv + p.sumv) % mod;
        ans.maxv = max(maxv, p.maxv);
        ans.minv = min(minv, p.minv);
        return ans;
    }

    seg_info& operator += (const seg_info& p) {
        sumv = (sumv + p.sumv) % mod;
        maxv = max(maxv, p.maxv);
        minv = min(minv, p.minv);
        return *this;
    }
};

struct seg_node_info {
    seg_info si;
    ll setv, addv, revv;

    seg_node_info() {
        si = seg_info();
    }

    void set(int l, int r, ll val) {
        setv = val;
        si.sumv = val % mod * (r - l + 1) % mod;
        si.maxv = val;
        si.minv = val;
    }

    void add(int l, int r, ll val) {
        addv = (addv + val % mod) % mod;
        si.sumv = (si.sumv + val % mod * (r - l + 1) % mod) % mod;
        si.maxv = (si.maxv == maxvi) ? val : (si.maxv + val);
        si.minv = (si.minv == minvi) ? val : (si.minv + val);
    }

    void rev(int l, int r, ll val) {
        revv ^= val;
        si.sumv = (r - l + 1) - si.sumv;
        si.maxv = (si.maxv == maxvi) ? val : (si.maxv ^ val);
        si.minv = (si.minv == minvi) ? val : (si.minv ^ val);
    }
};

struct segment_tree {
    #ifndef SEG_SET
    #define SEG_SET
    #endif
//     #ifndef SEG_ADD
//     #define SEG_ADD
//     #endif
//     #ifndef SEG_REV
//     #define SEG_REV
//     #endif
    #define lrt rt << 1
    #define rrt rt << 1 | 1
    #define lson l, mid, lrt
    #define rson mid + 1, r, rrt
    #define op_set 0
    #define op_add 1
    #define op_rev 2
    seg_node_info sni[maxn << 2];

    void push_up(int rt) {
        sni[rt].si = sni[lrt].si + sni[rrt].si;
    }

    void build(int l, int r, int rt) {
        sni[rt] = seg_node_info();
        #ifdef SEG_SET
        sni[rt].setv = setvi;
        #endif
        #ifdef SEG_ADD
        sni[rt].addv = addvi;
        #endif
        #ifdef SEG_REV
        sni[rt].revv = revvi;
        #endif
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        build(lson);
        build(rson);
        push_up(rt);
    }

    void push_down(int l, int r, int rt) {
        int mid = (l + r) >> 1;
        #ifdef SEG_SET
        if (sni[rt].setv != setvi) {
            sni[lrt].set(l, mid, sni[rt].setv);
            sni[rrt].set(mid + 1, r, sni[rt].setv);
            sni[lrt].addv = addvi;
            sni[rrt].addv = addvi;
            sni[rt].setv = setvi;
        }
        #endif
        #ifdef SEG_ADD
        if (sni[rt].addv != addvi) {
            sni[lrt].add(l, mid, sni[rt].addv);
            sni[rrt].add(mid + 1, r, sni[rt].addv);
            sni[rt].addv = addvi;
        }
        #endif
        #ifdef SEG_REV
        if (sni[rt].revv != revvi) {
            sni[lrt].rev(l, mid, sni[rt].revv);
            sni[rrt].rev(mid + 1, r, sni[rt].revv);
            sni[rt].revv = revvi;
        }
        #endif
    }

    void update(int op, int cl, int cr, ll val, int l, int r, int rt) {
        if (cl <= l && r <= cr) {
            if (op == op_set) {
                sni[rt].set(l, r, val);
                sni[rt].addv = addvi;
            } else if (op == op_add)
                sni[rt].add(l, r, val);
            else if (op == op_rev)
                sni[rt].rev(l, r, val);
            return;
        }
        push_down(l, r, rt);
        int mid = (l + r) >> 1;
        if (cl <= mid)
            update(op, cl, cr, val, lson);
        if (cr > mid)
            update(op, cl, cr, val, rson);
        push_up(rt);
    }

    void update_set(int cl, int cr, ll val, int l, int r, int rt) {
        update(op_set, cl, cr, val, l, r, rt);
    }

    void update_add(int cl, int cr, ll val, int l, int r, int rt) {
        update(op_add, cl, cr, val, l, r, rt);
    }

    void update_rev(int cl, int cr, int l, int r, int rt) {
        update(op_rev, cl, cr, 1, l, r, rt);
    }

    seg_info queries(int ql, int qr, int l, int r, int rt) {
        if (ql <= l && r <= qr)
            return sni[rt].si;
        push_down(l, r, rt);
        int mid = (l + r) >> 1;
        seg_info ans = seg_info();
        if (ql <= mid)
            ans += queries(ql, qr, lson);
        if (qr > mid)
            ans += queries(ql, qr, rson);
        return ans;
    }
} st;

class Solution {
public:
    long long maximumProduct(vector<int>& nums, int m) {
        int n = nums.size();
        st.build(0, n - 1, 1);
        for (int i = 0; i < n; ++i)
            st.update_set(i, i, nums[i], 0, n - 1, 1);
        ll ans = -inf;
        for (int i = 0; i < n; ++i) {
            int j = i + m - 1;
            if (j < n) {
                auto res = st.queries(j, n - 1, 0, n - 1, 1);
                ans = max({ans, nums[i] * res.minv, nums[i] * res.maxv});
            }
        }
        return ans;
    }
};

有序集合版本

class Solution {
    using ll = long long;
public:
    long long maximumProduct(vector<int>& nums, int m) {
        int n = nums.size();
        multiset<int> mst;
        ll ans = LLONG_MIN;
        for (int i = m - 1; i < n; ++i) {
            int j = i - m + 1;
            mst.insert(nums[j]);
            ans = max({ans, 1LL * nums[i] * (*mst.begin()), 1LL * nums[i] * (*mst.rbegin())});
        }
        return ans;
    }
};

滑动窗口版本

class Solution {
    using ll = long long;
public:
    long long maximumProduct(vector<int>& nums, int m) {
        int n = nums.size(), minv = INT_MAX, maxv = INT_MIN;
        ll ans = LLONG_MIN;
        for (int i = m - 1; i < n; ++i) {
            int j = i - m + 1;
            minv = min(minv, nums[j]);
            maxv = max(maxv, nums[j]);
            ans = max({ans, 1LL * nums[i] * minv, 1LL * nums[i] * maxv});
        }
        return ans;
    }
};

Q4. 树中找到带权中位节点

题意

给你一个整数 n,以及一棵 无向带权 树,根节点为节点 0,树中共有 n 个节点,编号从 0 到 n - 1。该树由一个长度为 n - 1 的二维数组 edges 表示,其中 e d g e s [ i ] = [ u i , v i , w i ] edges[i] = [u_i, v_i, w_i] edges[i]=[ui,vi,wi] 表示存在一条从节点 u i u_i ui v i v_i vi 的边,权重为 w i w_i wi

带权中位节点 定义为从 u i u_i ui v i v_i vi 路径上的 第一个 节点 x,使得从 u i u_i ui 到 x 的边权之和 大于等于 该路径总权值和的一半。

给你一个二维整数数组 queries。对于每个 q u e r i e s [ j ] = [ u j , v j ] queries[j] = [u_j, v_j] queries[j]=[uj,vj],求出从 u j u_j uj v j v_j vj 路径上的带权中位节点。

返回一个数组 ans,其中 ans[j] 表示查询 queries[j] 的带权中位节点编号。

示例
示例1
输入: n = 2, edges = [[0,1,7]], queries = [[1,0],[0,1]]
输出: [0,1]
说明

查询 路径 边权 总路径权值和 一半 解释 答案
[1, 0] 1 → 0 [7] 7 3.5 从 1 → 0 的权重和为 7 >= 3.5,中位节点是 0。 0
[0, 1] 0 → 1 [7] 7 3.5 从 0 → 1 的权重和为 7 >= 3.5,中位节点是 1。 1
示例2
输入: n = 3, edges = [[0,1,2],[2,0,4]], queries = [[0,1],[2,0],[1,2]]
输出: [1,0,2]
说明

查询 路径 边权 总路径权值和 一半 解释 答案
[0, 1] 0 → 1 [2] 2 1 从 0 → 1 的权值和为 2 >= 1,中位节点是 1。 1
[2, 0] 2 → 0 [4] 4 2 从 2 → 0 的权值和为 4 >= 2,中位节点是 0。 0
[1, 2] 1 → 0 → 2 [2, 4] 6 3 从 1 → 0 = 2 < 3,
从 1 → 2 = 6 >= 3,中位节点是 2。
2
示例3
输入: n = 5, edges = [[0,1,2],[0,2,5],[1,3,1],[2,4,3]], queries = [[3,4],[1,2]]
输出: [2,2]
说明

查询 路径 边权 总路径权值和 一半 解释 答案
[3, 4] 3 → 1 → 0 → 2 → 4 [1, 2, 5, 3] 11 5.5 从 3 → 1 = 1 < 5.5,
从 3 → 0 = 3 < 5.5,
从 3 → 2 = 8 >= 5.5,中位节点是 2。
2
[1, 2] 1 → 0 → 2 [2, 5] 7 3.5 从 1 → 0 = 2 < 3.5,
从 1 → 2 = 7 >= 3.5,中位节点是 2。
2
提示
  • 2 < = n < = 10 5 2 <= n <= 10^5 2<=n<=105
  • e d g e s . l e n g t h = = n − 1 edges.length == n - 1 edges.length==n1
  • e d g e s [ i ] = = [ u i , v i , w i ] edges[i] == [u_i, v_i, w_i] edges[i]==[ui,vi,wi]
  • 0 < = u i , v i < n 0 <= u_i, v_i < n 0<=ui,vi<n
  • 1 < = w i < = 10 9 1 <= w_i <= 10^9 1<=wi<=109
  • 1 < = q u e r i e s . l e n g t h < = 10 5 1 <= queries.length <= 10^5 1<=queries.length<=105
  • q u e r i e s [ j ] = = [ u j , v j ] queries[j] == [u_j, v_j] queries[j]==[uj,vj]
  • 0 < = u j , v j < n 0 <= u_j, v_j < n 0<=uj,vj<n
  • 输入保证 edges 表示一棵合法的树。
解题思路:二分/树上倍增+LCA

困难题。首先,树上任意两点之间的路径是唯一的,那么从树上一点到另一点经过的点的权值和是非递减的。又因为询问是 10 5 10^5 105 级别的,我们不能每次都去遍历,这样肯定会 TLE。那么怎么快速计算出两点之间的距离呢?显然我们可以在 LCA 计算过程中计算出基于根节点的每个点 u u u 的深度 d e p u dep_u depu 以及权值和 d i s u dis_u disu,那么对于任意两点 u , v u, v u,v,其路径上的权值和如下。
d i s u v = d i s u + d i s v − 2 ∗ d i s f a ,其中 f a = L C A ( u , v ) dis_{uv} = dis_u + dis_v - 2 * dis_{fa},其中 fa = LCA(u, v) disuv=disu+disv2disfa,其中fa=LCA(u,v)

如上所述,其实路径被分为了两条链, u → f a u \rightarrow fa ufa f a → v fa \rightarrow v fav,那么我们可以分别 二分 两条链中的节点,判断该节点是否满足条件,至于中点怎么确定呢?这个时候就需要用到深度了,取 u u u f a fa fa 的深度差,就可以表示他最多能往上走几步,时间复杂度为 O ( q l o g 2 n ) O(qlog^{2}n) O(qlog2n)(其中二分步长需要 O ( l o g n ) O(logn) O(logn),倍增算出向上走具体步数所对应的节点需要 O ( l o g n ) O(logn) O(logn),总共有 q q q 次询问)。

其实求出 f a fa fa 后就能够确认答案在哪条链上了。假如 d i s u − d i s f a > = ⌈ d i s u v 2 ⌉ dis_u - dis_{fa} >= \lceil{\frac{dis_{uv}}{2}}\rceil disudisfa>=2disuv ,答案在第一条链上,否则在第二条链上。加个分类讨论,可以简化代码写法,时间复杂度不变。

再思考一下,既然路径确定,是否可以直接用 倍增 来锁定节点?显然是可以的,跟计算 LCA 是一样的思路,从高位枚举跳跃步数即可,而且可以优化掉二分的时间,时间复杂度为 O ( q l o g n ) O(qlogn) O(qlogn)

参考代码(C++)

二分+LCA版本

using ll = long long;
using pii = pair<int, int>;
const int maxn = 100'005;
const int maxm = 20;
struct tree_ancestor {
    int n, m, id;
    int d[maxn];
    ll dis[maxn];
    int p[maxn][maxm];
    pii pv[maxn];
    vector<pii> adj[maxn];

    void init(int n) {
        this->n = n;
        id = 0;
        for (int i = 0; i < n; ++i)
            adj[i].clear(), d[i] = dis[i] = 0;
        m = 32 - __builtin_clz(n);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                p[i][j] = -1;
    }

    void add_edge(int u, int v, int w) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }

    void dfs(int u, int fu) {
        p[u][0] = fu;
        pv[u].first = id++;
        for (auto& [v, w] : adj[u])
            if (v != fu) {
                d[v] = d[u] + 1;
                dis[v] = dis[u] + w;
                dfs(v, u);
            }
        pv[u].second = id - 1;
    }

    void calc(int root) {
        dfs(root, -1);
        for (int j = 1; j < m; ++j)
            for (int i = 0; i < n; ++i)
                if (int pa = p[i][j - 1]; pa != -1)
                    p[i][j] = p[pa][j - 1];
    }

    int get_kth_ancestor(int u, int k) {
        while (k && u != -1) {
            u = p[u][__builtin_ctz(k)];
            k &= k - 1;
        }
        return u;
    }

    int get_lca(int u, int v) {
        // 返回 u 和 v 的最近公共祖先(节点编号从 0 开始)
        if (d[u] > d[v])
            swap(u, v);
        // 使 u 和 v 在同一深度
        v = get_kth_ancestor(v, d[v] - d[u]);
        if (v == u)
            return v;
        for (int j = m - 1; j >= 0; --j) {
            int pu = p[u][j], pv = p[v][j];
            if (pu != pv)
                u = pu, v = pv;
        }
        return p[u][0];
    }

    bool is_ancestor(int u, int v) {
        return pv[u].first < pv[v].first && pv[v].first <= pv[u].second;
    }
} ta;

class Solution {
public:
    vector<int> findMedian(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) {
        ta.init(n);
        for (auto& e : edges)
            ta.add_edge(e[0], e[1], e[2]);
        ta.calc(0);
        vector<int> ans;
        for (auto& q : queries) {
            int u = q[0], v = q[1];
            int fa = ta.get_lca(u, v);
            ll dt = ta.dis[u] + ta.dis[v] - (ta.dis[fa] << 1);
            int l = 0, r = ta.d[u] - ta.d[fa], res = -1;
            while (l <= r) {
                int mid = (l + r) >> 1;
                int fu = ta.get_kth_ancestor(u, mid);
                if (((ta.dis[u] - ta.dis[fu]) << 1) >= dt) {
                    res = fu;
                    r = mid - 1;
                } else
                    l = mid + 1;
            }
            if (res != -1)
                ans.push_back(res);
            else {
                l = 0, r = ta.d[v] - ta.d[fa], res = fa;
                while (l <= r) {
                    int mid = (l + r) >> 1;
                    int fv = ta.get_kth_ancestor(v, mid);
                    if (((ta.dis[fv] + ta.dis[u] - (ta.dis[fa] << 1)) << 1) >= dt) {
                        res = fv;
                        l = mid + 1;
                    } else
                        r = mid - 1;
                }
                ans.push_back(res);
            }
        }
        return ans;
    }
};

树上倍增(左右尝试)+LCA版本

using ll = long long;
using pii = pair<int, int>;
const int maxn = 100'005;
const int maxm = 20;
struct tree_ancestor {
    int n, m, id;
    int d[maxn];
    ll dis[maxn];
    int p[maxn][maxm];
    pii pv[maxn];
    vector<pii> adj[maxn];

    void init(int n) {
        this->n = n;
        id = 0;
        for (int i = 0; i < n; ++i)
            adj[i].clear(), d[i] = dis[i] = 0;
        m = 32 - __builtin_clz(n);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                p[i][j] = -1;
    }

    void add_edge(int u, int v, int w) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }

    void dfs(int u, int fu) {
        p[u][0] = fu;
        pv[u].first = id++;
        for (auto& [v, w] : adj[u])
            if (v != fu) {
                d[v] = d[u] + 1;
                dis[v] = dis[u] + w;
                dfs(v, u);
            }
        pv[u].second = id - 1;
    }

    void calc(int root) {
        dfs(root, -1);
        for (int j = 1; j < m; ++j)
            for (int i = 0; i < n; ++i)
                if (int pa = p[i][j - 1]; pa != -1)
                    p[i][j] = p[pa][j - 1];
    }

    int get_kth_ancestor(int u, int k) {
        while (k && u != -1) {
            u = p[u][__builtin_ctz(k)];
            k &= k - 1;
        }
        return u;
    }

    int get_lca(int u, int v) {
        // 返回 u 和 v 的最近公共祖先(节点编号从 0 开始)
        if (d[u] > d[v])
            swap(u, v);
        // 使 u 和 v 在同一深度
        v = get_kth_ancestor(v, d[v] - d[u]);
        if (v == u)
            return v;
        for (int j = m - 1; j >= 0; --j) {
            int pu = p[u][j], pv = p[v][j];
            if (pu != pv)
                u = pu, v = pv;
        }
        return p[u][0];
    }

    bool is_ancestor(int u, int v) {
        return pv[u].first < pv[v].first && pv[v].first <= pv[u].second;
    }
} ta;

class Solution {
public:
    vector<int> findMedian(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) {
        ta.init(n);
        for (auto& e : edges)
            ta.add_edge(e[0], e[1], e[2]);
        ta.calc(0);
        vector<int> ans;
        for (auto& q : queries) {
            int u = q[0], v = q[1];
            int fa = ta.get_lca(u, v);
            ll dt = ta.dis[u] + ta.dis[v] - (ta.dis[fa] << 1);
            int ut = u;
            for (int j = ta.m - 1; j >= 0; --j)
                if (ta.p[ut][j] != -1 && ta.d[ta.p[ut][j]] >= ta.d[fa] && ((ta.dis[u] - ta.dis[ta.p[ut][j]]) << 1) < dt)
                    ut = ta.p[ut][j];
            // cout << "ut1:" << ut << '\n';
            if (ut != fa) {
                ut = ta.p[ut][0];
                if (ut != -1 && ta.d[ut] >= ta.d[fa]) {
                    ans.push_back(ut);
                    continue;
                }
            }
            ut = v;
            for (int j = ta.m - 1; j >= 0; --j)
                if (ta.p[ut][j] != -1 && ta.d[ta.p[ut][j]] >= ta.d[fa] && ((ta.dis[u] + ta.dis[ta.p[ut][j]] - (ta.dis[fa] << 1)) << 1) >= dt)
                    ut = ta.p[ut][j];
            // cout << "ut2:" << ut << '\n';
            ans.push_back(ut);
        }
        return ans;
    }
};

树上倍增(判定左右分支)+LCA版本

using ll = long long;
using pii = pair<int, int>;
const int maxn = 100'005;
const int maxm = 20;
struct tree_ancestor {
    int n, m, id;
    int d[maxn];
    ll dis[maxn];
    int p[maxn][maxm];
    pii pv[maxn];
    vector<pii> adj[maxn];

    void init(int n) {
        this->n = n;
        id = 0;
        for (int i = 0; i < n; ++i)
            adj[i].clear(), d[i] = dis[i] = 0;
        m = 32 - __builtin_clz(n);
        for (int i = 0; i < n; ++i)
            for (int j = 0; j < m; ++j)
                p[i][j] = -1;
    }

    void add_edge(int u, int v, int w) {
        adj[u].push_back({v, w});
        adj[v].push_back({u, w});
    }

    void dfs(int u, int fu) {
        p[u][0] = fu;
        pv[u].first = id++;
        for (auto& [v, w] : adj[u])
            if (v != fu) {
                d[v] = d[u] + 1;
                dis[v] = dis[u] + w;
                dfs(v, u);
            }
        pv[u].second = id - 1;
    }

    void calc(int root) {
        dfs(root, -1);
        for (int j = 1; j < m; ++j)
            for (int i = 0; i < n; ++i)
                if (int pa = p[i][j - 1]; pa != -1)
                    p[i][j] = p[pa][j - 1];
    }

    int get_kth_ancestor(int u, int k) {
        while (k && u != -1) {
            u = p[u][__builtin_ctz(k)];
            k &= k - 1;
        }
        return u;
    }

    int get_lca(int u, int v) {
        // 返回 u 和 v 的最近公共祖先(节点编号从 0 开始)
        if (d[u] > d[v])
            swap(u, v);
        // 使 u 和 v 在同一深度
        v = get_kth_ancestor(v, d[v] - d[u]);
        if (v == u)
            return v;
        for (int j = m - 1; j >= 0; --j) {
            int pu = p[u][j], pv = p[v][j];
            if (pu != pv)
                u = pu, v = pv;
        }
        return p[u][0];
    }

    bool is_ancestor(int u, int v) {
        return pv[u].first < pv[v].first && pv[v].first <= pv[u].second;
    }
} ta;

class Solution {
public:
    vector<int> findMedian(int n, vector<vector<int>>& edges, vector<vector<int>>& queries) {
        ta.init(n);
        for (auto& e : edges)
            ta.add_edge(e[0], e[1], e[2]);
        ta.calc(0);
        vector<int> ans;
        for (auto& q : queries) {
            int u = q[0], v = q[1];
            if (u == v) {
                ans.push_back(u);
                continue;
            }
            int fa = ta.get_lca(u, v);
            ll dt = ta.dis[u] + ta.dis[v] - (ta.dis[fa] << 1);
            ll half = (dt + 1) >> 1;
            ll dr = ta.dis[u] - ta.dis[fa];
            // cout << "dt:" << dt << ";half:" << half << ";dr:" << dr << '\n';
            if (dr >= half) {
                int ut = u;
                for (int j = ta.m - 1; j >= 0; --j)
                    if (ta.p[ut][j] != -1 && ta.d[ta.p[ut][j]] >= ta.d[fa] && ta.dis[u] - ta.dis[ta.p[ut][j]] < half)
                        ut = ta.p[ut][j];
                ans.push_back(ta.p[ut][0]);
            } else {
                int ut = v;
                for (int j = ta.m - 1; j >= 0; --j)
                    if (ta.p[ut][j] != -1 && ta.d[ta.p[ut][j]] >= ta.d[fa] && ta.dis[u] + ta.dis[ta.p[ut][j]] - (ta.dis[fa] << 1) >= half)
                        ut = ta.p[ut][j];
                ans.push_back(ut);
            }
        }
        return ans;
    }
};

网站公告

今日签到

点亮在社区的每一天
去签到