从代码学习深度学习 - 预训练word2vec PyTorch版

发布于:2025-05-20 ⋅ 阅读:(12) ⋅ 点赞:(0)


前言

词嵌入(Word Embeddings)是自然语言处理(NLP)领域中的基石技术之一。它们将词语从稀疏的、高维的独热编码(one-hot encoding)表示转换为稠密的、低维的向量表示。这些向量能够捕捉词语之间的语义和句法关系,使得相似的词在向量空间中距离更近。Word2Vec是其中一种非常流行且有效的词嵌入算法,由Google的Tomas Mikolov等人在2013年提出。它主要包含两种模型架构:CBOW(Continuous Bag-of-Words,连续词袋模型)和Skip-gram(跳字模型)。

本篇博客将聚焦于Skip-gram模型,并结合**负采样(Negative Sampling)**这一重要的优化技巧,通过PyTorch框架从零开始实现一个Word2Vec模型。我们将详细探讨数据预处理的每一个步骤,如何构建模型,如何进行训练,以及训练完成后如何应用得到的词向量来寻找相似词。通过深入代码细节,我们希望能帮助读者更好地理解Word2Vec的内部工作原理及其在PyTorch中的实现。

我们将依赖一系列辅助脚本来处理数据、可视化训练过程以及进行模型训练。让我们一步步揭开Word2Vec的神秘面纱。

完整代码:下载链接

辅助工具

在构建和训练Word2Vec模型之前,我们首先介绍一下项目中用到的一些辅助Python脚本。这些脚本提供了数据加载、预处理、可视化以及训练监控等常用功能。

1. 绘图工具 (utils_for_huitu.py)

这个脚本主要封装了使用matplotlib进行绘图的常用函数,特别是在Jupyter Notebook环境中,它包含了一个Animator类,可以动态地展示训练过程中的损失变化。

# 导入必要的包
import matplotlib.pyplot as plt  # 用于创建和操作 Matplotlib 图表
from matplotlib_inline import backend_inline  # 用于在Jupyter中设置Matplotlib输出格式
from IPython import display  # 用于后续动态显示(如 Animator)
import torch  # 导入PyTorch库,用于处理张量类型的图像
import numpy as np  # 导入NumPy,可能用于数据处理
import matplotlib as mpl  # 导入Matplotlib主模块,用于设置图像属性

def set_figsize(figsize=(3.5, 2.5)):
    """
    设置matplotlib图形的大小
    
    参数:
        figsize: tuple[float, float] - 图形大小,形状为 (宽度, 高度),单位为英寸
        
    输出:
        无返回值
    """
    plt.rcParams['figure.figsize'] = figsize  # 设置图形默认大小

def use_svg_display():
    """
    使用 SVG 格式在 Jupyter 中显示绘图
    
    输入:
        无
    输出:
        无返回值
    """
    backend_inline.set_matplotlib_formats('svg')  # 设置 Matplotlib 使用 SVG 格式

def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):
    """设置 Matplotlib 的轴  

    输入:
        axes: Matplotlib 的轴对象  # 输入参数:轴对象
        xlabel: x 轴标签  # 输入参数:x 轴标签
        ylabel: y 轴标签  # 输入参数:y 轴标签
        xlim: x 轴范围  # 输入参数:x 轴范围
        ylim: y 轴范围  # 输入参数:y 轴范围
        xscale: x 轴刻度类型  # 输入参数:x 轴刻度类型
        yscale: y 轴刻度类型  # 输入参数:y 轴刻度类型
        legend: 图例标签列表  # 输入参数:图例标签
    
    输出:
        无返回值  # 函数无显式返回值

    """
    axes.set_xlabel(xlabel)  # 设置 x 轴标签
    axes.set_ylabel(ylabel)  # 设置 y 轴标签
    axes.set_xscale(xscale)  # 设置 x 轴刻度类型
    axes.set_yscale(yscale)  # 设置 y 轴刻度类型
    axes.set_xlim(xlim)  # 设置 x 轴范围
    axes.set_ylim(ylim)  # 设置 y 轴范围
    if legend:  # 检查是否提供了图例标签
        axes.legend(legend)  # 如果有图例,则设置图例
    axes.grid()  # 为轴添加网格线
    
class Animator:
    """在动画中绘制数据,仅针对一张图的情况
    """
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """初始化 Animator 类 
        输入:
            xlabel: x 轴标签,默认为 None  # 输入参数:x 轴标签
            ylabel: y 轴标签,默认为 None  # 输入参数:y 轴标签
            legend: 图例标签列表,默认为 None  # 输入参数:图例标签
            xlim: x 轴范围,默认为 None  # 输入参数:x 轴范围
            ylim: y 轴范围,默认为 None  # 输入参数:y 轴范围
            xscale: x 轴刻度类型,默认为 'linear'  # 输入参数:x 轴刻度类型
            yscale: y 轴刻度类型,默认为 'linear'  # 输入参数:y 轴刻度类型
            fmts: 绘图格式元组,默认为 ('-', 'm--', 'g-.', 'r:')  # 输入参数:线条格式
            nrows: 子图行数,默认为 1  # 输入参数:子图行数
            ncols: 子图列数,默认为 1  # 输入参数:子图列数
            figsize: 图像大小元组,默认为 (3.5, 2.5)  # 输入参数:图像大小
        
        输出:
            无返回值  # 方法无显式返回值

        定义位置::numref:`sec_softmax_scratch`  # 指明定义的参考位置
        """
        if legend is None:  # 检查 legend 是否为 None
            legend = []  # 如果为 None,则初始化为空列表
        use_svg_display()  # 设置绘图显示为 SVG 格式
        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)  # 创建绘图对象和子图
        if nrows * ncols == 1:  # 判断是否只有一个子图
            self.axes = [self.axes, ]  # 如果是单个子图,将 axes 转为列表
        self.config_axes = lambda: set_axes(  # 定义 lambda 函数配置坐标轴
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)  # 调用 set_axes 设置参数
        self.X, self.Y, self.fmts = None, None, fmts  # 初始化数据和格式属性

    def add(self, x, y):
        """向图表中添加多个数据点  

        输入:
            x: x 轴数据点  # 输入参数:x 轴数据
            y: y 轴数据点  # 输入参数:y 轴数据
        
        输出:
            无返回值  # 方法无显式返回值
        """
        if not hasattr(y, "__len__"):  # 检查 y 是否具有长度属性(是否可迭代)
            y = [y]  # 如果不可迭代,将 y 转为单元素列表
        n = len(y)  # 获取 y 的长度
        if not hasattr(x, "__len__"):  # 检查 x 是否具有长度属性
            x = [x] * n  # 如果不可迭代,将 x 扩展为与 y 同长度的列表
        if not self.X:  # 检查 self.X 是否已初始化
            self.X = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表
        if not self.Y:  # 检查 self.Y 是否已初始化
            self.Y = [[] for _ in range(n)]  # 如果未初始化,为每条线创建空列表
        for i, (a, b) in enumerate(zip(x, y)):  # 遍历 x 和 y 的数据对
            if a is not None and b is not None:  # 检查数据点是否有效
                self.X[i].append(a)  # 将 x 数据点添加到对应列表
                self.Y[i].append(b)  # 将 y 数据点添加到对应列表
        self.axes[0].cla()  # 清除当前轴的内容
        for x, y, fmt in zip(self.X, self.Y, self.fmts):  # 遍历所有数据和格式
            self.axes[0].plot(x, y, fmt)  # 绘制每条线
        self.config_axes()  # 调用 lambda 函数配置坐标轴
        display.display(self.fig)  # 显示当前图形
        display.clear_output(wait=True)  # 标记当前输出为待清除,但由于 wait=True,它不会立即清除,而是等待下一次 display.display()。

def show_list_len_pair_hist(legend, xlabel, ylabel, xlist, ylist):
    """
    绘制列表长度对的直方图,用于比较两组列表中元素长度的分布
    
    参数:
        legend: list[str] - 图例标签,形状为 (2,),分别对应xlist和ylist的标签
        xlabel: str - x轴标签
        ylabel: str - y轴标签
        xlist: list[list] - 第一组列表,形状为 (样本数量, 每个样本的元素数)
        ylist: list[list] - 第二组列表,形状为 (样本数量, 每个样本的元素数)
    
    输出:
        无返回值,但会显示生成的直方图
    """
    set_figsize()  # 设置图形大小
    
    # plt.hist返回的三个值:
    # n: list[array] - 每个bin中的样本数量,形状为 (2, bin数量)
    # bins: array - bin的边界值,形状为 (bin数量+1,)
    # patches: list[list[Rectangle]] - 直方图的矩形对象,形状为 (2, bin数量)
    _, _, patches = plt.hist(
        [[len(l) for l in xlist], [len(l) for l in ylist]])  # 绘制两组数据长度的直方图
    
    plt.xlabel(xlabel)  # 设置x轴标签
    plt.ylabel(ylabel)  # 设置y轴标签
    
    # 为第二组数据(ylist)的直方图添加斜线图案,以区分两组数据
    for patch in patches[1].patches:  # patches[1]是ylist对应的矩形对象列表
        patch.set_hatch('/')  # 设置填充图案为斜线
    
    plt.legend(legend)  # 添加图例

解读

  • set_figsizeuse_svg_display 用于基础的Matplotlib绘图设置。
  • set_axes 是一个通用的函数,用于配置图表的坐标轴标签、范围、刻度类型和图例。
  • Animator 类是实现动态绘图的关键。在训练循环中,我们可以周期性地调用其add方法,传入当前的训练轮次(或迭代次数)和对应的损失值(或其他指标)。Animator会清除旧的图像并重新绘制,从而在Jupyter Notebook中形成动画效果,直观地展示训练趋势。
  • show_list_len_pair_hist 函数用于绘制两个列表集合中,各子列表长度分布的直方图,方便进行数据分析和比较。

2. 数据处理工具 (utils_for_data.py)

这个脚本是Word2Vec数据预处理的核心,包含了从读取原始文本、构建词汇表、下采样、生成中心词-上下文词对、负采样到最终打包成PyTorch DataLoader的完整流程。

from collections import Counter  # 导入 Counter 类
from collections import Counter  # 用于词频统计
import torch  # PyTorch 核心库
from torch.utils import data  # PyTorch 数据加载工具
import numpy as np  # NumPy 用于数组操作
import random  # 导入随机模块,用于下采样和负采样
import math  # 导入数学函数模块,用于概率计算
import os

def count_corpus(tokens):
    """
    统计词元的频率
    
    参数:
        tokens: 词元列表,可以是:
            - 一维列表,例如 ['a', 'b']
            - 二维列表,例如 [['a', 'b'], ['c']]
    
    返回值:
        Counter: Counter 对象,统计每个词元的出现次数
    """
    # 如果输入为空列表,直接返回空计数器
    if not tokens:  # 等价于 len(tokens) == 0
        return Counter()
    
    # 检查输入是否为二维列表
    if isinstance(tokens[0], list):
        # 将二维列表展平为一维列表
        flattened_tokens = [token for sublist in tokens for token in sublist]
    else:
        # 如果是一维列表,直接使用原列表
        flattened_tokens = tokens
    
    # 使用 Counter 统计词频并返回
    return Counter(flattened_tokens)


class Vocab:
    """文本词表类,用于管理词元及其索引的映射关系"""

    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        """初始化词表

        Args:
            tokens: 输入的词元列表,可以是1D或2D列表,默认为空列表
            min_freq: 词元最小出现频率,小于此频率的词元将被忽略,默认为0
            reserved_tokens: 预留的特殊词元列表(如'<pad>'),默认为空列表
        """
        # 处理默认参数
        self.tokens = tokens if tokens is not None else []
        self.reserved_tokens = reserved_tokens if reserved_tokens is not None else []

        # 统计词元频率并按频率降序排序
        counter = self._count_corpus(self.tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)

        # 初始化词表,'<unk>'为未知词元,索引为0
        self.idx_to_token = ['<unk>'] + self.reserved_tokens
        self.token_to_idx = {
   token: idx for idx, token in enumerate(self.idx_to_token)}

        # 添加满足最小频率要求的词元到词表
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = 

网站公告

今日签到

点亮在社区的每一天
去签到