文章目录
前言
词嵌入(Word Embeddings)是自然语言处理(NLP)领域中的基石技术之一。它们将词语从稀疏的、高维的独热编码(one-hot encoding)表示转换为稠密的、低维的向量表示。这些向量能够捕捉词语之间的语义和句法关系,使得相似的词在向量空间中距离更近。Word2Vec是其中一种非常流行且有效的词嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含两种模型架构:CBOW(Continuous Bag-of-Words,连续词袋模型)和Skip-gram(跳字模型)。
本篇博客将聚焦于Skip-gram模型,并结合**负采样(Negative Sampling)**这一重要的优化技巧,通过PyTorch框架从零开始实现一个Word2Vec模型。我们将详细探讨数据预处理的每一个步骤,如何构建模型,如何进行训练,以及训练完成后如何应用得到的词向量来寻找相似词。通过深入代码细节,我们希望能帮助读者更好地理解Word2Vec的内部工作原理及其在PyTorch中的实现。
我们将依赖一系列辅助脚本来处理数据、可视化训练过程以及进行模型训练。让我们一步步揭开Word2Vec的神秘面纱。
完整代码:下载链接
辅助工具
在构建和训练Word2Vec模型之前,我们首先介绍一下项目中用到的一些辅助Python脚本。这些脚本提供了数据加载、预处理、可视化以及训练监控等常用功能。
1. 绘图工具 (utils_for_huitu.py
)
这个脚本主要封装了使用matplotlib
进行绘图的常用函数,特别是在Jupyter Notebook环境中,它包含了一个Animator
类,可以动态地展示训练过程中的损失变化。
# 导入必要的包
import matplotlib.pyplot as plt # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display # 用于后续动态显示(如 Animator)
import torch # 导入PyTorch库,用于处理张量类型的图像
import numpy as np # 导入NumPy,可能用于数据处理
import matplotlib as mpl # 导入Matplotlib主模块,用于设置图像属性
def set_figsize(figsize=(3.5, 2.5)):
"""
设置matplotlib图形的大小
参数:
figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸
输出:
无返回值
"""
plt.rcParams['figure.figsize'] = figsize # 设置图形默认大小
def use_svg_display():
"""
使用 SVG 格式在 Jupyter 中显示绘图
输入:
无
输出:
无返回值
"""
backend_inline.set_matplotlib_formats('svg') # 设置 Matplotlib 使用 SVG 格式
def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
"""设置 Matplotlib 的轴
输入:
axes: Matplotlib 的轴对象 # 输入参数:轴对象
xlabel: x 轴标签 # 输入参数:x 轴标签
ylabel: y 轴标签 # 输入参数:y 轴标签
xlim: x 轴范围 # 输入参数:x 轴范围
ylim: y 轴范围 # 输入参数:y 轴范围
xscale: x 轴刻度类型 # 输入参数:x 轴刻度类型
yscale: y 轴刻度类型 # 输入参数:y 轴刻度类型
legend: 图例标签列表 # 输入参数:图例标签
输出:
无返回值 # 函数无显式返回值
"""
axes.set_xlabel(xlabel) # 设置 x 轴标签
axes.set_ylabel(ylabel) # 设置 y 轴标签
axes.set_xscale(xscale) # 设置 x 轴刻度类型
axes.set_yscale(yscale) # 设置 y 轴刻度类型
axes.set_xlim(xlim) # 设置 x 轴范围
axes.set_ylim(ylim) # 设置 y 轴范围
if legend: # 检查是否提供了图例标签
axes.legend(legend) # 如果有图例,则设置图例
axes.grid() # 为轴添加网格线
class Animator:
"""在动画中绘制数据,仅针对一张图的情况
"""
def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
figsize=(3.5, 2.5)):
"""初始化 Animator 类
输入:
xlabel: x 轴标签,默认为 None # 输入参数:x 轴标签
ylabel: y 轴标签,默认为 None # 输入参数:y 轴标签
legend: 图例标签列表,默认为 None # 输入参数:图例标签
xlim: x 轴范围,默认为 None # 输入参数:x 轴范围
ylim: y 轴范围,默认为 None # 输入参数:y 轴范围
xscale: x 轴刻度类型,默认为 'linear' # 输入参数:x 轴刻度类型
yscale: y 轴刻度类型,默认为 'linear' # 输入参数:y 轴刻度类型
fmts: 绘图格式元组,默认为 ('-', 'm--', 'g-.', 'r:') # 输入参数:线条格式
nrows: 子图行数,默认为 1 # 输入参数:子图行数
ncols: 子图列数,默认为 1 # 输入参数:子图列数
figsize: 图像大小元组,默认为 (3.5, 2.5) # 输入参数:图像大小
输出:
无返回值 # 方法无显式返回值
定义位置::numref:`sec_softmax_scratch` # 指明定义的参考位置
"""
if legend is None: # 检查 legend 是否为 None
legend = [] # 如果为 None,则初始化为空列表
use_svg_display() # 设置绘图显示为 SVG 格式
self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize) # 创建绘图对象和子图
if nrows * ncols == 1: # 判断是否只有一个子图
self.axes = [self.axes, ] # 如果是单个子图,将 axes 转为列表
self.config_axes = lambda: set_axes( # 定义 lambda 函数配置坐标轴
self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend) # 调用 set_axes 设置参数
self.X, self.Y, self.fmts = None, None, fmts # 初始化数据和格式属性
def add(self, x, y):
"""向图表中添加多个数据点
输入:
x: x 轴数据点 # 输入参数:x 轴数据
y: y 轴数据点 # 输入参数:y 轴数据
输出:
无返回值 # 方法无显式返回值
"""
if not hasattr(y, "__len__"): # 检查 y 是否具有长度属性(是否可迭代)
y = [y] # 如果不可迭代,将 y 转为单元素列表
n = len(y) # 获取 y 的长度
if not hasattr(x, "__len__"): # 检查 x 是否具有长度属性
x = [x] * n # 如果不可迭代,将 x 扩展为与 y 同长度的列表
if not self.X: # 检查 self.X 是否已初始化
self.X = [[] for _ in range(n)] # 如果未初始化,为每条线创建空列表
if not self.Y: # 检查 self.Y 是否已初始化
self.Y = [[] for _ in range(n)] # 如果未初始化,为每条线创建空列表
for i, (a, b) in enumerate(zip(x, y)): # 遍历 x 和 y 的数据对
if a is not None and b is not None: # 检查数据点是否有效
self.X[i].append(a) # 将 x 数据点添加到对应列表
self.Y[i].append(b) # 将 y 数据点添加到对应列表
self.axes[0].cla() # 清除当前轴的内容
for x, y, fmt in zip(self.X, self.Y, self.fmts): # 遍历所有数据和格式
self.axes[0].plot(x, y, fmt) # 绘制每条线
self.config_axes() # 调用 lambda 函数配置坐标轴
display.display(self.fig) # 显示当前图形
display.clear_output(wait=True) # 标记当前输出为待清除,但由于 wait=True,它不会立即清除,而是等待下一次 display.display()。
def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
"""
绘制列表长度对的直方图,用于比较两组列表中元素长度的分布
参数:
legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签
xlabel: str - x轴标签
ylabel: str - y轴标签
xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)
ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)
输出:
无返回值,但会显示生成的直方图
"""
set_figsize() # 设置图形大小
# plt.hist返回的三个值:
# n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)
# bins: array - bin的边界值,形状为 (bin数量+1,)
# patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)
_, _, patches = plt.hist(
[[len(l) for l in xlist], [len(l) for l in ylist]]) # 绘制两组数据长度的直方图
plt.xlabel(xlabel) # 设置x轴标签
plt.ylabel(ylabel) # 设置y轴标签
# 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据
for patch in patches[1].patches: # patches[1]是ylist对应的矩形对象列表
patch.set_hatch('/') # 设置填充图案为斜线
plt.legend(legend) # 添加图例
解读:
set_figsize
和use_svg_display
用于基础的Matplotlib绘图设置。set_axes
是一个通用的函数,用于配置图表的坐标轴标签、范围、刻度类型和图例。Animator
类是实现动态绘图的关键。在训练循环中,我们可以周期性地调用其add
方法,传入当前的训练轮次(或迭代次数)和对应的损失值(或其他指标)。Animator
会清除旧的图像并重新绘制,从而在Jupyter Notebook中形成动画效果,直观地展示训练趋势。show_list_len_pair_hist
函数用于绘制两个列表集合中,各子列表长度分布的直方图,方便进行数据分析和比较。
2. 数据处理工具 (utils_for_data.py
)
这个脚本是Word2Vec数据预处理的核心,包含了从读取原始文本、构建词汇表、下采样、生成中心词-上下文词对、负采样到最终打包成PyTorch DataLoader
的完整流程。
from collections import Counter # 导入 Counter 类
from collections import Counter # 用于词频统计
import torch # PyTorch 核心库
from torch.utils import data # PyTorch 数据加载工具
import numpy as np # NumPy 用于数组操作
import random # 导入随机模块,用于下采样和负采样
import math # 导入数学函数模块,用于概率计算
import os
def count_corpus(tokens):
"""
统计词元的频率
参数:
tokens: 词元列表,可以是:
- 一维列表,例如 ['a', 'b']
- 二维列表,例如 [['a', 'b'], ['c']]
返回值:
Counter: Counter 对象,统计每个词元的出现次数
"""
# 如果输入为空列表,直接返回空计数器
if not tokens: # 等价于 len(tokens) == 0
return Counter()
# 检查输入是否为二维列表
if isinstance(tokens[0], list):
# 将二维列表展平为一维列表
flattened_tokens = [token for sublist in tokens for token in sublist]
else:
# 如果是一维列表,直接使用原列表
flattened_tokens = tokens
# 使用 Counter 统计词频并返回
return Counter(flattened_tokens)
class Vocab:
"""文本词表类,用于管理词元及其索引的映射关系"""
def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
"""初始化词表
Args:
tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表
min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0
reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表
"""
# 处理默认参数
self.tokens = tokens if tokens is not None else []
self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []
# 统计词元频率并按频率降序排序
counter = self._count_corpus(self.tokens)
self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
# 初始化词表,'<unk>'为未知词元,索引为0
self.idx_to_token = ['<unk>'] + self.reserved_tokens
self.token_to_idx = {
token: idx for idx, token in enumerate(self.idx_to_token)}
# 添加满足最小频率要求的词元到词表
for token, freq in self._token_freqs:
if freq < min_freq:
break
if token not in self.token_to_idx:
self.idx_to_token.append(token)
self.token_to_idx[token] =