AVL树的简洁写法

发布于:2025-06-30 ⋅ 阅读:(13) ⋅ 点赞:(0)


零、写在前面

大二学数据结构的时候写的AVL代码稀烂,回过头来重制一下,在不使用父指针的情况下以较为简洁的代码实现AVL。

只是为了方便自己快速复习下AVL树原理,不作过多原理说明,详见:AVL树详解[C++]


一、AVL 树定义

AVL 树 是一种平衡二叉搜索树,由两位俄罗斯的数学家 G.M.Adelson-Velskii 和 E.M.Landis 在1962年发明,并以他们名字的首字母命名。

1.1 性质

  1. 空二叉树是一个 AVL 树
  2. 如果 T 是一棵 AVL 树,那么其左右子树也是 AVL 树,并且 |height(lc) - height(rc) <= 1|,h 是其左右子树的高度
  3. 树高为 O(logn)

平衡因子:左子树高度 - 右子树高度

1.2 树高的证明

设 f n 为高度为 n 的 A V L 树所包含的最少节点数,则有 f n = { 1 ( n = 1 ) 2 ( n = 2 ) f n − 1 + f n − 2 + 1 ( n > 2 ) 根据常系数非齐次线性差分方程的解法 ( 或者对转移矩阵求特征向量并相似对角化 ) , { f n + 1 } 是一个斐波那契数列。这里 f n 的通项为: f n = 5 + 2 5 5 ( 1 + 5 2 ) n + 5 − 2 5 5 ( 1 − 5 2 ) n − 1 斐波那契数列以指数的速度增长,对于树高 n 有: n < log ⁡ 1 + 5 2 ( f n + 1 ) < 3 2 log ⁡ 2 ( f n + 1 ) 因此 A V L 树的高度为 O ( log ⁡ f n ) ,这里的 f n 为结点数。 设 f_{n} 为高度为 n 的 AVL 树所包含的最少节点数,则有 \\ f_{n}=\left\{\begin{array}{ll} 1 & (n=1) \\ 2 & (n=2) \\ f_{n-1}+f_{n-2}+1 & (n>2) \end{array}\right. \\ 根据常系数非齐次线性差分方程的解法(或者对转移矩阵求特征向量并相似对角化), \left\{f_{n}+1\right\} 是一个斐波那契数列。这里 f_{n} 的通项为: \\ f_{n}=\frac{5+2 \sqrt{5}}{5}\left(\frac{1+\sqrt{5}}{2}\right)^{n}+\frac{5-2 \sqrt{5}}{5}\left(\frac{1-\sqrt{5}}{2}\right)^{n}-1 \\ 斐波那契数列以指数的速度增长,对于树高 n 有: \\ n<\log _{\frac{1+\sqrt{5}}{2}}\left(f_{n}+1\right)<\frac{3}{2} \log _{2}\left(f_{n}+1\right) \\ 因此 AVL 树的高度为 O\left(\log f_{n}\right) ,这里的 f_{n} 为结点数。 fn为高度为nAVL树所包含的最少节点数,则有fn= 12fn1+fn2+1(n=1)(n=2)(n>2)根据常系数非齐次线性差分方程的解法(或者对转移矩阵求特征向量并相似对角化){fn+1}是一个斐波那契数列。这里fn的通项为:fn=55+25 (21+5 )n+5525 (215 )n1斐波那契数列以指数的速度增长,对于树高n有:n<log21+5 (fn+1)<23log2(fn+1)因此AVL树的高度为O(logfn),这里的fn为结点数。

二、AVL树实现(AVL树实现名次树)

2.1 节点定义

struct Node {
    Node *l = nullptr, *r = nullptr;	// 左右儿子
    Key key;						// 关键字
    u32 h = 1;						// 高度
    int siz = 1;					// 子树大小
    Node(Key k): key(k) {}			// 构造函数
} *root = nullptr;
// 树高
int height(Node *t) {
    return t ? t->h : 0;
}
// 树大小
int size(Node *t) {
    return t ? t->siz : 0;
}
// 修正
void pull(Node *t) {
    t->h = std::max(height(t->l), height(t->r)) + 1;
    t->siz = size(t->l) + 1 + size(t->r);
}
// 平衡因子
// 根据考研408 给出的标准 height(t->l) - height(t->r)
// 其他的教材,如邓公的代码则与本文相反
int factor(Node *t) {
    return height(t->l) - height(t->r);
}

2.2 左/右旋转

  • 旋转不改变中序
  • 可用于单倾斜型重平衡调整
// 左旋
void rotateL(Node *&t) {
    Node *r = t->r;
    t->r = r->l;
    r->l = t;
    pull(t);
    pull(r);
    t = r;
}
// 右旋
void rotateR(Node *&t) {
    Node *l = t->l;
    t->l = l->r;
    l->r = t;
    pull(t);
    pull(l);
    t = l;
}

2.3 zig-zag / zag-zig 双旋

  • 双旋用于重平衡
// 右左双旋
void rotateRL(Node *&t) {
    rotateR(t->r);
    rotateL(t);
}
// 左右双旋
void rotateLR(Node *&t) {
    rotateL(t->l);
    rotateR(t);
}

2.4 重平衡函数

  • 在插入或者删除过程中,自顶向上回溯调整
  • 如果是 单倾斜型则单旋
  • 否则双旋
  • 两种情况,分为四种子情况,代码对偶
void reBalance(Node *&t) {
    int diff = factor(t);
    if (diff == 2) {
        int diff = height(t->l->l) - height(t->l->r);
        if (diff >= 0) {
            rotateR(t);
        } else {
            rotateLR(t);
        }
    } else if (diff == -2) {
        int diff = height(t->r->r) - height(t->r->l);
        if (diff >= 0) {
            rotateL(t);
        } else {
            rotateRL(t);
        }
    }
    pull(t);
}

2.5 插入

bool insert(Node *&t, Key key) {
    if (!t) {
        t = new Node(key);
        return true;
    }
    // 是否多重集
    // if (t->key == key) return false;
    bool res = insert(key < t->key ? t->l : t->r, key);
    reBalance(t);
    return res;
}

2.6 删除

  • 对于删除节点,找到中序后继节点替换删除即可
bool erase(Node *&t, Key key) {
    if (!t) return false;
    bool res;
    if (t->key == key) {
        if (t->l && t->r) {
            Node *del = t->r;
            while (del->l) {
                del = del->l;
            }
            std::swap(t->key, del->key);
            erase(t->r, key);
        }
        else {
            t = t->l ? t->l : t->r;
        }
        res = true;
    }
    else {
        res = erase(key < t->key ? t->l : t->r, key);
    }
    if (t) {
        reBalance(t);
    }
    return res;
}

2.7 排名查询

  • 定义一个关键字x的排名为 树中比 x 小的节点数 + 1
int rank(Key key) {
    Node *t = root;
    int res = 0;
    while (t) {
        if (t->key < key) {
            res += size(t->l) + 1;
            t = t->r;
        } else {
            t = t->l;
        }
    }
    return res + 1;
}

2.8 查前驱/后继

Node *pre(Key key) {
    Node *res = nullptr;
    Node *t = root;
    while (t) {
        if (t->key < key) {
            res = t;
            t = t->r;
        } else {
            t = t->l;
        }
    }
    return res;
}

Node *suf(Key key) {
    Node *res = nullptr;
    Node *t = root;
    while (t) {
        if (t->key > key) {
            res = t;
            t = t->l;
        } else {
            t = t->r;
        }
    }
    return res;
}

2.9 查第 k 小

Node *kth(int k) {
    Node *res = root;
    while(res) {
        if (size(res->l) >= k) {
            res = res->l;
        } else if(size(res->l) + 1 == k) {
            break;
        } else {
            k -= size(res->l) + 1;
            res = res->r;
        }
    }
    return res;
}

2.10 完整代码

template<typename Key>
class AVLTree {
private:
    struct Node {
        Node *l = nullptr, *r = nullptr;
        Key key;
        u32 h = 1;
        int siz = 1;
        Node(Key k): key(k) {}
    } *root = nullptr;

    int height(Node *t) {
        return t ? t->h : 0;
    }

    int size(Node *t) {
        return t ? t->siz : 0;
    }

    void pull(Node *t) {
        t->h = std::max(height(t->l), height(t->r)) + 1;
        t->siz = size(t->l) + 1 + size(t->r);
    }

    int factor(Node *t) {
        return height(t->l) - height(t->r);
    }

    void rotateL(Node *&t) {
        Node *r = t->r;
        t->r = r->l;
        r->l = t;
        pull(t);
        pull(r);
        t = r;
    }

    void rotateR(Node *&t) {
        Node *l = t->l;
        t->l = l->r;
        l->r = t;
        pull(t);
        pull(l);
        t = l;
    }

    void rotateRL(Node *&t) {
        rotateR(t->r);
        rotateL(t);
    }

    void rotateLR(Node *&t) {
        rotateL(t->l);
        rotateR(t);
    }

    void reBalance(Node *&t) {
        int diff = factor(t);
        if (diff == 2) {
            int diff = height(t->l->l) - height(t->l->r);
            if (diff >= 0) {
                rotateR(t);
            } else {
                rotateLR(t);
            }
        } else if (diff == -2) {
            int diff = height(t->r->r) - height(t->r->l);
            if (diff >= 0) {
                rotateL(t);
            } else {
                rotateRL(t);
            }
        }
        pull(t);
    }

    bool insert(Node *&t, Key key) {
        if (!t) {
            t = new Node(key);
            return true;
        }
        // 是否多重集
        // if (t->key == key) return false;
        bool res = insert(key < t->key ? t->l : t->r, key);
        reBalance(t);
        return res;
    }

    bool erase(Node *&t, Key key) {
        if (!t) return false;
        bool res;
        if (t->key == key) {
            if (t->l && t->r) {
                Node *del = t->r;
                while (del->l) {
                    del = del->l;
                }
                std::swap(t->key, del->key);
                erase(t->r, key);
            }
            else {
                t = t->l ? t->l : t->r;
            }
            res = true;
        }
        else {
            res = erase(key < t->key ? t->l : t->r, key);
        }
        if (t) {
            reBalance(t);
        }
        return res;
    }
public:
    bool insert(Key key) {
        return insert(root, key);
    }

    bool erase(Key key) {
        return erase(root, key);
    }

    int rank(Key key) {
        Node *t = root;
        int res = 0;
        while (t) {
            if (t->key < key) {
                res += size(t->l) + 1;
                t = t->r;
            } else {
                t = t->l;
            }
        }
        return res + 1;
    }

    Node *pre(Key key) {
        Node *res = nullptr;
        Node *t = root;
        while (t) {
            if (t->key < key) {
                res = t;
                t = t->r;
            } else {
                t = t->l;
            }
        }
        return res;
    }

    Node *suf(Key key) {
        Node *res = nullptr;
        Node *t = root;
        while (t) {
            if (t->key > key) {
                res = t;
                t = t->l;
            } else {
                t = t->r;
            }
        }
        return res;
    }

    Node *kth(int k) {
        Node *res = root;
        while(res) {
            if (size(res->l) >= k) {
                res = res->l;
            } else if(size(res->l) + 1 == k) {
                break;
            } else {
                k -= size(res->l) + 1;
                res = res->r;
            }
        }
        return res;
    }

    std::vector<Key> getAll() {
        std::vector<Key> res;
        auto dfs = [&](auto &&self, Node *t) -> void{
            if (!t) return;
            self(self, t->l);
            res.push_back(t->key);
            self(self, t->r);
        };
        dfs(dfs, root);
        return res;
    }
};

三、online judge 验证

3.1 P6136 【模板】普通平衡树(数据加强版)

题目链接

P6136 【模板】普通平衡树(数据加强版)

AC代码

#include <bits/stdc++.h>

using i64 = long long;
using u32 = unsigned int;

template<typename Key>
class AVLTree {
private:
    struct Node {
        Node *l = nullptr, *r = nullptr;
        Key key;
        u32 h = 1;
        int siz = 1;
        Node(Key k): key(k) {}
    } *root = nullptr;

    int height(Node *t) {
        return t ? t->h : 0;
    }

    int size(Node *t) {
        return t ? t->siz : 0;
    }

    void pull(Node *t) {
        t->h = std::max(height(t->l), height(t->r)) + 1;
        t->siz = size(t->l) + 1 + size(t->r);
    }

    int factor(Node *t) {
        return height(t->l) - height(t->r);
    }

    void rotateL(Node *&t) {
        Node *r = t->r;
        t->r = r->l;
        r->l = t;
        pull(t);
        pull(r);
        t = r;
    }

    void rotateR(Node *&t) {
        Node *l = t->l;
        t->l = l->r;
        l->r = t;
        pull(t);
        pull(l);
        t = l;
    }

    void rotateRL(Node *&t) {
        rotateR(t->r);
        rotateL(t);
    }

    void rotateLR(Node *&t) {
        rotateL(t->l);
        rotateR(t);
    }

    void reBalance(Node *&t) {
        int diff = factor(t);
        if (diff == 2) {
            int diff = height(t->l->l) - height(t->l->r);
            if (diff >= 0) {
                rotateR(t);
            } else {
                rotateLR(t);
            }
        } else if (diff == -2) {
            int diff = height(t->r->r) - height(t->r->l);
            if (diff >= 0) {
                rotateL(t);
            } else {
                rotateRL(t);
            }
        }
        pull(t);
    }

    bool insert(Node *&t, Key key) {
        if (!t) {
            t = new Node(key);
            return true;
        }
        // 是否多重集
        // if (t->key == key) return false;
        bool res = insert(key < t->key ? t->l : t->r, key);
        reBalance(t);
        return res;
    }

    bool erase(Node *&t, Key key) {
        if (!t) return false;
        bool res;
        if (t->key == key) {
            if (t->l && t->r) {
                Node *del = t->r;
                while (del->l) {
                    del = del->l;
                }
                std::swap(t->key, del->key);
                erase(t->r, key);
            }
            else {
                t = t->l ? t->l : t->r;
            }
            res = true;
        }
        else {
            res = erase(key < t->key ? t->l : t->r, key);
        }
        if (t) {
            reBalance(t);
        }
        return res;
    }
public:
    bool insert(Key key) {
        return insert(root, key);
    }

    bool erase(Key key) {
        return erase(root, key);
    }

    int rank(Key key) {
        Node *t = root;
        int res = 0;
        while (t) {
            if (t->key < key) {
                res += size(t->l) + 1;
                t = t->r;
            } else {
                t = t->l;
            }
        }
        return res + 1;
    }

    Node *pre(Key key) {
        Node *res = nullptr;
        Node *t = root;
        while (t) {
            if (t->key < key) {
                res = t;
                t = t->r;
            } else {
                t = t->l;
            }
        }
        return res;
    }

    Node *suf(Key key) {
        Node *res = nullptr;
        Node *t = root;
        while (t) {
            if (t->key > key) {
                res = t;
                t = t->l;
            } else {
                t = t->r;
            }
        }
        return res;
    }

    Node *kth(int k) {
        Node *res = root;
        while(res) {
            if (size(res->l) >= k) {
                res = res->l;
            } else if(size(res->l) + 1 == k) {
                break;
            } else {
                k -= size(res->l) + 1;
                res = res->r;
            }
        }
        return res;
    }

    std::vector<Key> getAll() {
        std::vector<Key> res;
        auto dfs = [&](auto &&self, Node *t) -> void{
            if (!t) return;
            self(self, t->l);
            res.push_back(t->key);
            self(self, t->r);
        };
        dfs(dfs, root);
        return res;
    }
};

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m;
    std::cin >> n >> m;
    
    AVLTree<int> set;

    for (int i = 0; i < n; ++ i) {
        int x;
        std::cin >> x;
        set.insert(x);
    }

    int last = 0;
    int ans = 0;

    while(m --) {
        int t, x;
        std::cin >> t >> x;
        x ^= last;
        if (t == 1) {
            set.insert(x);
        } else if(t == 2) {
            set.erase(x);
        } else if(t == 3) {
            last = set.rank(x);
            ans ^= last;
        } else if(t == 4) {
            last = set.kth(x)->key;
            ans ^= last;
        } else if(t == 5) {
            last = set.pre(x)->key;
            ans ^= last;
        } else {
            last = set.suf(x)->key;
            ans ^= last;
        }
    }

    std::cout << ans << '\n';

    return 0;
}