数据结构与算法——字典(前缀)树的实现

发布于:2025-08-02 ⋅ 阅读:(11) ⋅ 点赞:(0)

参考视频:左程云--算法讲解044【必备】前缀树原理和代码详解

类实现（每个节点用长度为 26 的数组保存子节点，只支持小写字母 a–z）：

// Prefix tree (trie) over the lowercase letters 'a'..'z'.
//
// Every node keeps two counters:
//   pass - number of stored words (counted with multiplicity) whose path goes
//          through this node; the root's pass is the total number of words
//   end  - number of stored words that end exactly at this node
//
// Children are owned by std::unique_ptr: detaching an edge frees the whole
// subtree below it, and destroying the Trie frees every node. Because of that
// a Trie is movable but not copyable; a moved-from Trie may only be destroyed
// or assigned to. Nodes are released recursively (one level per character),
// so destroying very long words uses proportionally deep recursion.
//
// Precondition for every method: strings contain only 'a'..'z'; any other
// character would index outside the 26-slot child table.
class Trie {
  private:
    class TrieNode {
      public:
        int pass;
        int end;
        vector<unique_ptr<TrieNode>> nexts;  // slot k = child for 'a' + k; null = absent
        TrieNode(): pass(0), end(0), nexts(26) {}
    };

    unique_ptr<TrieNode> root;  // root node

    // Follows `s` from the root. Returns the node reached, or nullptr as soon
    // as a required child is missing. The empty string yields the root.
    const TrieNode* findNode(const string& s) const {
        const TrieNode* node = root.get();
        for (char c : s) {
            node = node->nexts[c - 'a'].get();
            if (node == nullptr) {
                return nullptr;
            }
        }
        return node;
    }

  public:
    Trie() : root(new TrieNode()) {}

    // Adds one occurrence of `word`, creating missing nodes along its path.
    void insert(const string& word) {
        TrieNode* node = root.get();
        node->pass++;
        for (char c : word) {
            unique_ptr<TrieNode>& child = node->nexts[c - 'a'];
            if (!child) {
                child.reset(new TrieNode());
            }
            node = child.get();
            node->pass++;
        }
        node->end++;
    }

    // Removes one occurrence of `word`; does nothing if `word` is not stored.
    void mydelete(const string& word) {
        if (countWordsEqualTo(word) == 0) {
            return;
        }
        TrieNode* node = root.get();
        node->pass--;
        for (char c : word) {
            int path = c - 'a';
            // The word is stored, so every child on its path exists.
            if (--node->nexts[path]->pass == 0) {
                // No other word passes through this node, so the entire
                // subtree below it belonged to this word alone.
                funcDelete(node, path);
                return;
            }
            node = node->nexts[path].get();
        }
        node->end--;
    }

    // Detaches node->nexts[path] and frees it together with its whole subtree.
    void funcDelete(TrieNode* node, int path) {
        node->nexts[path].reset();
    }

    // True if at least one occurrence of `word` is stored.
    bool search(const string& word) const {
        return countWordsEqualTo(word) > 0;
    }

    // Number of stored occurrences of exactly `word`.
    int countWordsEqualTo(const string& word) const {
        const TrieNode* node = findNode(word);
        return node == nullptr ? 0 : node->end;
    }

    // Number of stored words (with multiplicity) that start with `pre`.
    int prefixNumber(const string& pre) const {
        const TrieNode* node = findNode(pre);
        return node == nullptr ? 0 : node->pass;
    }
};

在上述类实现的基础上，用哈希表（unordered_map）代替定长数组保存子节点，只为实际出现的字符创建表项：

// Prefix tree (trie) whose children live in a hash map keyed by `c - 'a'`.
// Distinct chars give distinct keys, so any char value can be stored, and
// memory is only spent on edges that actually occur.
//
// Every node keeps two counters:
//   pass - number of stored words (counted with multiplicity) whose path goes
//          through this node; the root's pass is the total number of words
//   end  - number of stored words that end exactly at this node
//
// Children are owned by std::unique_ptr: erasing a map entry frees the whole
// subtree below it, and destroying the Trie frees every node. Because of that
// a Trie is movable but not copyable; a moved-from Trie may only be destroyed
// or assigned to.
class Trie {
private:
    class TrieNode {
    public:
        int pass;
        int end;
        unordered_map<int, unique_ptr<TrieNode>> nexts;  // key: c - 'a'
        TrieNode() : pass(0), end(0) {}
    };

    unique_ptr<TrieNode> root;  // root node

    // Follows `s` from the root. Returns the node reached, or nullptr as soon
    // as a required child is missing. The empty string yields the root.
    const TrieNode* findNode(const string& s) const {
        const TrieNode* node = root.get();
        for (char c : s) {
            auto it = node->nexts.find(c - 'a');
            if (it == node->nexts.end()) {
                return nullptr;
            }
            node = it->second.get();
        }
        return node;
    }

public:
    Trie() : root(new TrieNode()) {}

    // Adds one occurrence of `word`, creating missing nodes along its path.
    void insert(const string& word) {
        TrieNode* node = root.get();
        node->pass++;
        for (char c : word) {
            int path = c - 'a';
            auto it = node->nexts.find(path);
            if (it == node->nexts.end()) {
                // Insert only a fully constructed child, so the map never
                // holds an empty slot even if allocation throws.
                it = node->nexts.emplace(path, unique_ptr<TrieNode>(new TrieNode())).first;
            }
            node = it->second.get();
            node->pass++;
        }
        node->end++;
    }

    // Removes one occurrence of `word`; does nothing if `word` is not stored.
    void mydelete(const string& word) {
        if (countWordsEqualTo(word) == 0) {
            return;
        }
        TrieNode* node = root.get();
        node->pass--;
        for (char c : word) {
            // The word is stored, so every edge on its path exists.
            auto it = node->nexts.find(c - 'a');
            if (--it->second->pass == 0) {
                // No other word passes through this node: erasing the entry
                // destroys the unique_ptr and frees the whole subtree.
                node->nexts.erase(it);
                return;
            }
            node = it->second.get();
        }
        node->end--;
    }

    // True if at least one occurrence of `word` is stored.
    bool search(const string& word) const {
        return countWordsEqualTo(word) > 0;
    }

    // Number of stored occurrences of exactly `word`.
    int countWordsEqualTo(const string& word) const {
        const TrieNode* node = findNode(word);
        return node == nullptr ? 0 : node->end;
    }

    // Number of stored words (with multiplicity) that start with `pre`.
    int prefixNumber(const string& pre) const {
        const TrieNode* node = findNode(pre);
        return node == nullptr ? 0 : node->pass;
    }
};

静态数组实现 + 牛客测试

测试链接:牛客题霸--字典树的实现

#include <bits/stdc++.h>
using namespace std;
// Static-array trie. Nodes are integer ids handed out from a fixed pool;
// id 1 is the root and id 0 means "no child".
constexpr int N = 1000005;   // pool capacity: usable node ids are 1 .. N-1
int trieArr[N][26];          // trieArr[id][k]: child of node id along 'a' + k (0 = none)
int passArr[N];              // passArr[id]: stored words whose path goes through id
int endArr[N];               // endArr[id]: stored words that end exactly at id
int cnt;                     // largest node id allocated so far
void helpFunc(int op, string& word);
class trie{
public:
    // Resets the allocator so that the next new node gets id 2 (1 is the root).
    static void buildTrie(){
        cnt = 1;
    }

    // Adds one occurrence of `word` (letters 'a'..'z'), allocating ids on demand.
    static void insert(const string& word){
        int node = 1;
        ++passArr[node];
        for (size_t i = 0; i < word.size(); ++i) {
            int& slot = trieArr[node][word[i] - 'a'];
            if (!slot) {
                slot = ++cnt;
            }
            node = slot;
            ++passArr[node];
        }
        ++endArr[node];
    }

    // Number of stored occurrences of exactly `word`.
    static int searchCount(const string& word){
        const int node = locate(word);
        return node ? endArr[node] : 0;
    }

    // True if `word` has been stored at least once.
    static bool search(const string& word){
        return searchCount(word) > 0;
    }

    // Number of stored words (with multiplicity) that start with `pre`.
    static int prefixCount(const string& pre){
        const int node = locate(pre);
        return node ? passArr[node] : 0;
    }

    // Removes one occurrence of `word`; a no-op when it is not stored.
    // A link whose pass count drops to zero is cut; the ids below it are
    // simply abandoned in the pool.
    static void deleteWord(const string& word){
        if (searchCount(word) <= 0) {
            return;
        }
        int node = 1;
        --passArr[node];
        for (size_t i = 0; i < word.size(); ++i) {
            int& slot = trieArr[node][word[i] - 'a'];
            if (--passArr[slot] == 0) {
                slot = 0;
                return;
            }
            node = slot;
        }
        --endArr[node];
    }

    // Zeroes every node id that has been allocated (1 .. cnt).
    static void rebuildTrie(){
        for (int id = 1; id <= cnt; ++id) {
            for (int k = 0; k < 26; ++k) {
                trieArr[id][k] = 0;
            }
            passArr[id] = 0;
            endArr[id] = 0;
        }
    }

private:
    // Follows `s` from the root and returns the node id reached, or 0 if a
    // link along the way is missing. The empty string yields the root (1).
    static int locate(const string& s){
        int node = 1;
        for (size_t i = 0; i < s.size(); ++i) {
            node = trieArr[node][s[i] - 'a'];
            if (node == 0) {
                return 0;
            }
        }
        return node;
    }
};

// Reads the number of operations m, then m lines of the form "op word", and
// hands each one to helpFunc. Stops early if the input ends or is malformed,
// instead of acting on half-read values.
int main() {
    // All I/O goes through iostreams, so decoupling from C stdio and untying
    // cin from cout speeds up bulk judge input.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int m = 0;
    if (!(cin >> m)) {
        return 0;
    }
    trie::buildTrie();
    // `> 0` also guards against a negative m, which would otherwise loop
    // until the counter overflows.
    while (m-- > 0) {
        int op;
        string word;
        if (!(cin >> op >> word)) {
            break;
        }
        helpFunc(op, word);
    }
    trie::rebuildTrie();
    return 0;
}

// Executes one command against the global static-array trie:
//   1 word -> insert one occurrence of word
//   2 word -> delete one occurrence of word
//   3 word -> print "YES" if word is stored, otherwise "NO"
//   4 word -> print how many stored words start with word
// Any other op code is ignored.
// `word` is taken by non-const reference only to match the forward
// declaration; it is not modified here.
void helpFunc(int op, string& word){
    switch (op) {
        case 1:
            trie::insert(word);
            break;
        case 2:
            trie::deleteWord(word);
            break;
        case 3:
            // '\n' rather than endl: flushing after every query line is slow
            // for large inputs; cout is flushed when the program exits.
            cout << (trie::search(word) ? "YES" : "NO") << '\n';
            break;
        case 4:
            cout << trie::prefixCount(word) << '\n';
            break;
        default:
            break;
    }
}