【C++】map和set的模拟实现

发布于:2025-05-14 ⋅ 阅读:(13) ⋅ 点赞:(0)

1.底层红黑树节点的定义

enum Colur
{
	RED,
	BLACK
};
template<class T> 
struct RBTreeNode
{
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;
	T _data;
	Colur _col;

	RBTreeNode(const T& data)
		:_left(nullptr)
		, _right(nullptr)
		, _parent(nullptr)
		,_data(data)
		,_col(RED)
	{ }
};

只需要一个模板参数就可以实现map和set,T可以接收key作为set,也可以接收一个键值对作为map。map和set用的同一个类模板,通过不同的模板参数实现不同的类

2.底层红黑树迭代器的实现

typedef重定义

在这里插入图片描述
1.Self: 表示当前实例化的 _TreeIterator 类型,其指针类型和引用类型由模板参数 Ptr 和 Ref 决定。这使得 Self 可以适用于不同的指针和引用类型,通常用于类内部实现,例如在成员函数中返回或传递迭代器时,保持与当前迭代器类型一致。。
2.terator: 表示一个具体的 _TreeIterator 类型,其中指针类型是 T*,引用类型是 T&。这是一个特定的迭代器类型通常用于提供给用户的标准接口,例如在容器类中提供 begin 和 end 方法返回这种类型的迭代器。。

构造函数

在这里插入图片描述
1.是关键,由上图定义可知,iterator一直是一个普通迭代器。有以下两种情况A:当前类被实例化为const迭代器时,这是一个构造函数支持普通迭代器构造const迭代器,将普通迭代器 it 的节点指针 it.node 赋值给当前 const 迭代器的 node 成员变量。
B:当前类被实例化为普通迭代器时,该函数为拷贝构造
作用: 允许 _TreeIterator 类型的迭代器对象能够从另一个 Iterator 类型的迭代器对象进行初始化,确保两者指向同一个节点。
2.node: 这是一个指向红黑树节点的指针。_node(node): 通过初始化列表将传入的节点指针赋给当前迭代器的 _node 成员变量。
作用: 允许直接通过一个节点指针初始化 _TreeIterator 对象,使得迭代器指向指定的节点。

重载前置++

在这里插入图片描述

Self& operator++()
{
	//右子树不为空访问其最左节点(最小)
	if (_node->_right)
	{
		Node* subleft = _node->_right;
		while (subleft->_left)
		{
			subleft = subleft->_left;
		}
		_node = subleft;
	}
	//else//找孩子是父亲左的那个祖先节点,就是下一个要访问的节点
	//{
	//	Node* cur = _node;
	//	Node* parent = cur->_parent;
	//	while (parent)
	//	{
	//		if (parent->_left == cur)
	//		{
	//			break;
	//		}
	//		else
	//		{
	//			//继续往上更新节点
	//			cur = cur->_parent;
	//			parent = parent->_parent;
	//		}
	//	}
	//	_node = parent;
	//}
	else
	{
		Node* cur = _node;
		Node* parent = cur->_parent;
		while (parent && cur == parent->_right)
		{
			cur = cur->_parent;
			parent = parent->_parent;
		}
		_node = parent;
	}
	return *this;
}

else语句有两种实现方式,第二种更为简洁清晰。

重载前置–

与++顺序相反,按右子树 根 左子树的顺序访问节点 。实现思路反过来即可:
1.左树不为空,访问左树的最右节点(最大节点)
2.左树为空,代表该子树已访问完成,访问孩子是父亲右的那个祖先节点

Self& operator--()
{
	if (_node->_left)
	{
		Node* subright = _node->_left;
		while (subright->_right)
		{
			subright = subright->_right;
		}
		_node = subright;
	}
	else
	{
		Node* cur = _node;
		Node* parent = cur->_parent;
		while (parent && cur == parent->_left)
		{
			cur = cur->_parent;
			parent = parent->_parent;
		}
		_node = parent;
	}
	return *this;
}

看一下库里面的实现方式:
在这里插入图片描述
相比我们自己的模拟实现,增加了一个哨兵位节点,与根节点双向链接,通过leftmost函数找到左树中的最右节点,结束标志即返回哨兵位节点为空。这样在查找下一个访问的节点时通过调用函数固然很方便,但在旋转维护红黑树性质时增加了一定的复杂性。
自己模拟实现时通过while (parent && cur == parent->_left)该条件加上一个哨兵位节点也可以实现库中方式,该循环条件避免了cur和parent的无限循环错误

基本操作

在这里插入图片描述

3.底层红黑树的实现

通用模板

在这里插入图片描述
K:键类型,用于在红黑树中进行排序和查找
T:存储在红黑树中的元素类型。对于set,T就是K;对于映射map,T是一个键值对pair。
keyofT:一个函数对象,用于从元素类型T中提取键K。主要服务于map,对于映射map就是pair的第一个元素迭代器。set为了保持结构一致也用该函数,不过就直接返回key即可

获取迭代器的成员函数

//起始位置是树的最左节点,即最小节点
iterator begin()
{
	Node* leftmin = _root;
	while (leftmin&&leftmin->_left)
	{
		leftmin = leftmin->_left;
	}
	//调用迭代器的构造函数
	return iterator(leftmin);
}

iterator end()
{
//调用迭代器的构造函数
	return iterator(nullptr);
}

const_iterator begin()const
{
	Node* leftmin = _root;
	while (leftmin && leftmin->_left)
	{
		leftmin = leftmin->_left;
	}
	return const_iterator(leftmin);
}

const_iterator end()const
{
	return const_iterator(nullptr);
}

查找节点

Node* Find(const K& key)
{
	Node* cur = _root;
	keyofT kot;
	while (cur)
	{
		if (kot(cur->_data) < key)
		{
			cur = cur->_right;
		}
		else if (kot(cur->_data) > key)
		{
			cur = cur->_left;
		}
		else//相等找到了
		{
			return cur;
		}
	}
	return nullptr;
}

需注意与普通红黑树的比较方式不同,需调用kot来进行比较。
kot 的作用是从节点的数据中提取键值,使得红黑树的查找、插入和删除等操作能够基于键值进行比较和决策。这种设计使得红黑树模板具有通用性,可以适应不同的数据类型和键提取逻辑。

插入节点

与普通红黑树实现思路相同,只做出以下改变
1.通过kot提取节点数据的键值进行比较
2.返回值变成一个键值对,first为迭代器,second为bool值表示插入是否成功

pair<iterator,bool> Insert(const T& data)
{
	if (_root == nullptr)
	{
		_root = new Node(data);
		_root->_col = BLACK;
		return make_pair(iterator(_root),true);
	}
	//查找插入位置
	Node* cur = _root;
	Node* parent = nullptr;

	//map和set共用模板,需要将data转化为键值
	keyofT kot;
	while (cur)
	{
		if (kot(cur->_data)<kot(data))
		{
			parent = cur;
			cur = cur->_right;
		}
		else if (kot(cur->_data)>kot(data))
		{
			parent = cur;
			cur = cur->_left;
		}
		else
		{//找到相同键值
			return make_pair(iterator(cur), false);
		}
	}
	//插入新节点
	cur = new Node(data);
	cur->_col = RED;
	//cur要去遍历,需提前保存
	Node* newnode = cur;
	if (kot(parent->_data)>kot(data))
	{
		parent->_left = cur;
	}
	else
	{
		parent->_right = cur;
	}
	cur->_parent = parent;
	while (parent&&parent->_col==RED)
	{
		Node* grandfather = parent->_parent;
		if (parent == grandfather->_left)
		{
			Node* uncle = grandfather->_right;
			//uncle存在且为红
			if (uncle && uncle->_col == RED)
			{
				//变色
				parent->_col = uncle->_col = BLACK;
				grandfather->_col = RED;
				//更新节点.继续向上处理
				cur= grandfather;//注意赋值的顺序
				parent = grandfather->_parent;
			}
			else//uncle不存在或为黑
			{//判断是单旋还是双旋
				if (cur == parent->_left)
				{
					//	  g
					//	p
					//c
					RotateR(grandfather);
					parent->_col = BLACK;
					grandfather->_col = RED;
				 }
				else
				{
					//    g
					//  p
					//	  c
					RotateL(parent);
					RotateR(grandfather);
					cur->_col = BLACK;
					grandfather->_col = RED;
				}
				break;
			}
		}
		else//parent == grandfather->_right
		{
			Node* uncle = grandfather->_left;
			//uncle存在且为红
			if (uncle && uncle->_col == RED)
			{
				parent->_col = uncle->_col = BLACK;
				grandfather->_col = RED;
				//更新节点.继续向上处理
				cur= grandfather;
				parent = grandfather->_parent;
			}
			else//uncle不存在或为黑
			{//判断是单旋还是双旋
				if (cur == parent->_right)
				{
					//g  
					//  p
					//    c
					RotateL(grandfather);
					parent->_col = BLACK;
					grandfather->_col = RED;
				}
				else
				{
					//g
					//  p
					//c
					RotateR(parent);
					RotateL(grandfather);
					cur->_col = BLACK;
					grandfather->_col = RED;
				}
				break;
			}
		}
	}
	_root->_col = BLACK;
	return make_pair(iterator(newnode), true);
}

其余代码如单双旋,测试部分与i普通红黑树相同,详见文章末尾总体代码

4.map的模拟实现

map的底层结构就是红黑树,因此在map中直接封装一棵红黑树,然后将其接口包装下即可

#pragma once
#include"RBTree.h"
namespace ee
{
	template<class K,class V>
	class map
	{
		struct MapkeyofT
		{
			const K&operator() (const pair<K, V>&kv)
			{
				return kv.first;
			}
		};
	public:
		typedef typename RBTree<K, pair<const K, V>, MapkeyofT>::iterator iterator;
		typedef typename RBTree<K, pair<const K, V>, MapkeyofT>::const_iterator const_iterator;


		iterator begin()
		{
			return _t.begin();
		}

		iterator end()
		{
			return _t.end();
		}

		const_iterator begin()const
		{
			return _t.begin();
		}

		const_iterator end()const
		{
			return _t.end();
		}

		V& operator[](const K& key)
		{
			pair<iterator, bool> ret = insert(make_pair(key, c));
			return ret.first->second;
		}

		pair<iterator,bool> insert(const pair<K, V>& kv)
		{
			return _t.Insert(kv);
		}
	private:
		RBTree<K, pair<const K, V>, MapkeyofT> _t;
	};
}

解析:
1.map 类是一个模板类,K 表示键的类型,V 表示值的类型。
2.迭代器中pair<const K, V>保证了键值不可修改的同时能修改元素值,const迭代器就两个都不能修改
3.[]重载:提供了对指定键的值的访问。

默认构造值V():当键不存在时,operator[] 会插入一个默认构造的值。这意味着每次访问不存在的键时,都会自动创建一个新的键值对。
如果键已存在,返回已存在值的引用。

5.set的模拟实现

#pragma once
#include"RBTree.h"
namespace ee
{
	template<class K>
	class set
	{
		struct SetkeyofT
		{
			const K& operator()(const K&key)
			{
				return key;
			}
		};
	public:
		typedef typename RBTree<K, K, SetkeyofT>::const_iterator iterator;
		typedef typename RBTree<K, K, SetkeyofT>::const_iterator const_iterator;

		//普通迭代器底层也是const迭代器,直接定义成const
		const_iterator begin()const
		{
			return _t.begin();
		}

		const_iterator end()const
		{
			return _t.end();
		}

		pair<iterator,bool> insert(const K& key)
		{
			return _t.Insert(key);
		}
		//pair<iterator, bool> insert(const K& key)
		//{
		//	// pair<RBTree::iterator, bool>
		//	pair<typename RBTree<K, K, SetkeyofT>::iterator, bool> ret = _t.Insert(key);
		//	return pair<iterator, bool>(ret.first, ret.second);
		//}
	private:
		RBTree<K, K, SetkeyofT> _t;
	};
}

解析:
与map不同的是,迭代器的定义const和非const迭代器的底层都为const迭代器,保证了键值的不可修改.但这样会引起插入时的问题

在这里插入图片描述
由于set模拟实现中定义的迭代器底层也是const迭代器,所以存在返回值不匹配的问题。_t是一个普通对象调用的是普通迭代器,所以我们指定返回底层红黑树中也就是普通迭代器然后再通过构造函数用普通迭代器构造const迭代器,如下图
在这里插入图片描述

6.整体代码

  • 底层红黑树改装
#pragma once
#include<iostream>
#include<vector>
using namespace std;

enum Colur
{
	RED,
	BLACK
};
template<class T> 
struct RBTreeNode
{
	RBTreeNode<T>* _left;
	RBTreeNode<T>* _right;
	RBTreeNode<T>* _parent;
	T _data;
	Colur _col;

	RBTreeNode(const T& data)
		:_left(nullptr)
		, _right(nullptr)
		, _parent(nullptr)
		,_data(data)
		,_col(RED)
	{ }
};

template<class T,class Ptr,class Ref>
struct _TreeIterator
{
	typedef RBTreeNode<T> Node;
	typedef _TreeIterator<T,Ptr,Ref> Self;
	typedef _TreeIterator<T, T*, T&> Iterator;
	Node* _node;

	// 因为这里写了这个转换,所以可以从非 const iterator转化到const的iterator,
	// 如果注释掉了的话,代码就编译出错了
	_TreeIterator(const Iterator&it)
		:_node(it._node)
	{ }

	_TreeIterator(Node* node)
		:_node(node)
	{}

	Ref operator*()
	{
		return _node->_data;
	}

	Ptr operator->()
	{
		return &_node->_data;
	}

	bool operator!=(const Self& s)
	{
		return _node != s._node;
	}
	//这里是重载前置++
	Self& operator++()
	{
		//右子树不为空访问其最左节点(最小)
		if (_node->_right)
		{
			Node* subleft = _node->_right;
			while (subleft->_left)
			{
				subleft = subleft->_left;
			}
			_node = subleft;
		}
		//else//找孩子是父亲左的那个祖先节点,就是下一个要访问的节点
		//{
		//	Node* cur = _node;
		//	Node* parent = cur->_parent;
		//	while (parent)
		//	{
		//		if (parent->_left == cur)
		//		{
		//			break;
		//		}
		//		else
		//		{
		//			//继续往上更新节点
		//			cur = cur->_parent;
		//			parent = parent->_parent;
		//		}
		//	}
		//	_node = parent;
		//}
		else
		{
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_right)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}
	
	Self& operator--()
	{
		if (_node->_left)
		{
			Node* subright = _node->_left;
			while (subright->_right)
			{
				subright = subright->_right;
			}
			_node = subright;
		}
		else
		{
			Node* cur = _node;
			Node* parent = cur->_parent;
			while (parent && cur == parent->_left)
			{
				cur = cur->_parent;
				parent = parent->_parent;
			}
			_node = parent;
		}
		return *this;
	}
};

// set->RBTree<K, K, SetKeyOfT> _t;
// map->RBTree<K, pair<K, V>, MapKeyOfT> _t;
template<class K,class T,class keyofT>
struct RBTree
{
	typedef RBTreeNode<T> Node;
public:
	typedef _TreeIterator<T,T*,T&> iterator;
	typedef _TreeIterator<T,const T*,const T&> const_iterator;

	iterator begin()
	{
		Node* leftmin = _root;
		while (leftmin&&leftmin->_left)
		{
			leftmin = leftmin->_left;
		}
		//调用迭代器的构造函数
		return iterator(leftmin);
	}

	iterator end()
	{
		return iterator(nullptr);
	}

	const_iterator begin()const
	{
		Node* leftmin = _root;
		while (leftmin && leftmin->_left)
		{
			leftmin = leftmin->_left;
		}
		//调用迭代器的构造函数
		return const_iterator(leftmin);
	}

	const_iterator end()const
	{
		return const_iterator(nullptr);
	}

	Node* Find(const K& key)
	{
		Node* cur = _root;
		keyofT kot;
		while (cur)
		{
			if (kot(cur->_data) < key)
			{
				cur = cur->_right;
			}
			else if (kot(cur->_data) > key)
			{
				cur = cur->_left;
			}
			else//相等找到了
			{
				return cur;
			}
		}
		return nullptr;
	}

	pair<iterator,bool> Insert(const T& data)
	{
		if (_root == nullptr)
		{
			_root = new Node(data);
			_root->_col = BLACK;
			return make_pair(iterator(_root),true);
		}
		//查找插入位置
		Node* cur = _root;
		Node* parent = nullptr;

		//map和set共用模板,需要将data转化为键值
		keyofT kot;
		while (cur)
		{
			if (kot(cur->_data)<kot(data))
			{
				parent = cur;
				cur = cur->_right;
			}
			else if (kot(cur->_data)>kot(data))
			{
				parent = cur;
				cur = cur->_left;
			}
			else
			{//找到相同键值
				return make_pair(iterator(cur), false);
			}
		}
		//插入新节点
		cur = new Node(data);
		cur->_col = RED;
		//cur要去遍历,需提前保存
		Node* newnode = cur;
		if (kot(parent->_data)>kot(data))
		{
			parent->_left = cur;
		}
		else
		{
			parent->_right = cur;
		}
		cur->_parent = parent;
		while (parent&&parent->_col==RED)
		{
			Node* grandfather = parent->_parent;
			if (parent == grandfather->_left)
			{
				Node* uncle = grandfather->_right;
				//uncle存在且为红
				if (uncle && uncle->_col == RED)
				{
					//变色
					parent->_col = uncle->_col = BLACK;
					grandfather->_col = RED;
					//更新节点.继续向上处理
					cur= grandfather;//注意赋值的顺序
					parent = grandfather->_parent;
				}
				else//uncle不存在或为黑
				{//判断是单旋还是双旋
					if (cur == parent->_left)
					{
						//	  g
						//	p
						//c
						RotateR(grandfather);
						parent->_col = BLACK;
						grandfather->_col = RED;
					 }
					else
					{
						//    g
						//  p
						//	  c
						RotateL(parent);
						RotateR(grandfather);
						cur->_col = BLACK;
						grandfather->_col = RED;
					}
					break;
				}
			}
			else//parent == grandfather->_right
			{
				Node* uncle = grandfather->_left;
				//uncle存在且为红
				if (uncle && uncle->_col == RED)
				{
					parent->_col = uncle->_col = BLACK;
					grandfather->_col = RED;
					//更新节点.继续向上处理
					cur= grandfather;
					parent = grandfather->_parent;
				}
				else//uncle不存在或为黑
				{//判断是单旋还是双旋
					if (cur == parent->_right)
					{
						//g  
						//  p
						//    c
						RotateL(grandfather);
						parent->_col = BLACK;
						grandfather->_col = RED;
					}
					else
					{
						//g
						//  p
						//c
						RotateR(parent);
						RotateL(grandfather);
						cur->_col = BLACK;
						grandfather->_col = RED;
					}
					break;
				}
			}
		}
		_root->_col = BLACK;
		return make_pair(iterator(newnode), true);
	}
	void RotateL(Node* parent)
	{
		_rotateCount++;
		Node* cur = parent->_right;
		Node* curleft = cur->_left;
		Node* ppnode = parent->_parent;
		//第一次改变链接
		parent->_right = curleft;
		if (curleft)
		{
			curleft->_parent = parent;
		}
		//第二次改变链接
		cur->_left = parent;
		parent->_parent = cur;
		//判断根节点的链接情况
		//为根节点调整平衡因子情况
		if (parent == _root)
		{
			_root = cur;
			cur->_parent = nullptr;
		}
		//树中的部分调整情况
		else
		{
			if (ppnode->_left == parent)
			{
				ppnode->_left = cur;
			}
			else
			{
				ppnode->_right = cur;
			}
			cur->_parent = ppnode;
		}
	}

	void RotateR(Node* parent)
	{
		_rotateCount++;
		Node* cur = parent->_left;
		Node* curright = cur->_right;
		Node* ppnode = parent->_parent;
		//第一次链接
		parent->_left = curright;
		if (curright)

		{
			curright->_parent = parent;
		}
		//第二次链接
		cur->_right = parent;
		parent->_parent = cur;
		//调整根节点链接关系
		if (parent == _root)
		{
			_root = cur;
			cur->_parent = nullptr;
		}
		else
		{
			if (ppnode->_left == parent)
			{
				ppnode->_left = cur;
			}
			else
			{
				ppnode->_right = cur;
			}
			cur->_parent = ppnode;
		}
	}

	bool CheckColur(Node* root, int blacknum, int benchmark)
	{
		if (root == nullptr)
		{
			if (blacknum != benchmark)
			{
				return false;
			}
			return true;
		}
		
		//判断树中黑节点数量
		if (root->_col==BLACK)
		{
			blacknum++;
		}
		//判断树中是否有连续红节点情况
		if (root->_col == RED && root->_parent && root->_parent->_col == RED)
		{
			cout << root->_kv.first << "出现连续红色节点" << endl;
			return false;
		}

		return CheckColur(root->_left, blacknum, benchmark)
			&& CheckColur(root->_right, blacknum, benchmark);
	}

	bool IsBalance()
	{
		return IsBalance(_root);
	}
	bool IsBalance(Node* root)
	{
		if (root == nullptr)
		{
			return true;
		}
		if (root->_col != BLACK)
		{
			return false;
		}

		//计算基准值
		int benchmark = 0;
		Node* cur = root;
		while (cur)
		{
			if(cur->_col==BLACK)
			benchmark++;
			//选择一条路径来计算黑节点数量
			cur = cur->_left;
		}

		return CheckColur(root, 0, benchmark);
	}

	int Height()
	{
		return Height(_root);
	}
	
	int Height(Node* root)
	{
		if (root == nullptr)
		{
			return 0;
		}
		int leftheight = Height(root->_left);
		int rightheight = Height(root->_right);
		return leftheight > rightheight ?
			leftheight + 1 : rightheight + 1;
	}

	bool IsBST() {
		vector<int> result;
		_Inorder(_root,result);
		for (int i = 1; i < result.size(); i++) {
			if (result[i] <= result[i - 1]) {
				return false;
			}
		}
		return true;
	}
	void _Inorder(Node* root, vector<K>& result) {
		if (root == nullptr) {
			return;
		}
		_Inorder(root->_left, result);
		result.push_back(root->_kv.first);
		_Inorder(root->_right, result);
	}

	public:
		int _rotateCount = 0;
	private:
		Node* _root = nullptr;
};
  • 测试代码
#include<iostream>
using namespace std;
#include"MyMap.h"
#include"MySet.h"


int main()
{
	ee::set<int> s;
	auto iterator = s.insert(10);
	s.insert(24);
	s.insert(3);
	s.insert(4);
	s.insert(34);
	s.insert(1024);
	ee::set<int>::iterator it = s.begin();
	while (it != s.end())
	{
		/*if (*it % 2 == 0)
		{
			*it += 10;
		}*/
		cout << *it <<" ";
		++it;
	}
	cout << endl;

	ee::map<int, int> m;
	m.insert(make_pair(1, 1));
	m.insert(make_pair(2, 2));
	m.insert(make_pair(3, 3));
	for (const auto& kv : m)
	{
		cout << kv.first << ":" << kv.second << endl;
	}
	cout << endl;
	//测试map的[]重载
	ee::map<string, string>dict;
	dict.insert(make_pair("学习", "nice"));
	dict["study"];//插入

	for (const auto& kv : dict)
	{
		cout << kv.first << " " << kv.second << endl;
	}
	cout << endl;

	dict["study"] = "好好学习";//修改
	dict["学习"] = "deserve";
	dict["mint"] = "薄荷";//插入+修改
	for (const auto& kv : dict)
	{
		cout << kv.first << " " << kv.second << endl;
	}
	cout << endl;
	return 0;
}