从代码学习深度学习 - 来自Transformers的双向编码器表示(BERT) PyTorch版

发布于:2025-06-19 ⋅ 阅读:(16) ⋅ 点赞:(0)


前言

在自然语言处理(NLP)的世界里,词嵌入技术是基石。从早期的 Word2Vec、GloVe 等上下文无关(context-independent)模型,到后来能够根据上下文动态调整词表示的 ELMo、GPT 等上下文敏感(context-sensitive)模型,我们见证了 NLP 领域表示学习的飞速发展。

上下文无关模型,如 Word2Vec,为每个词分配一个固定的向量,无法区分多义词。例如,“bank”在“river bank”(河岸)和“investment bank”(投资银行)中会被赋予完全相同的表示。

为了解决这个问题,ELMo 和 GPT 等模型应运而生。ELMo 使用双向 LSTM 来编码上下文,但其下游任务通常需要一个特定于任务的模型架构。而 GPT 使用强大的 Transformer 解码器,是任务无关的,但其自回归的特性使其只能“从左到右”地编码上下文,无法同时利用左右两侧的信息。

2018年,BERT(Bidirectional Encoder Representations from Transformers)横空出世,它集众家之长,革命性地改变了 NLP 领域。BERT 不仅实现了真正的双向上下文编码,而且其任务无关的设计使其能够以最小的架构改动,在众多 NLP 任务中取得顶尖(SOTA)的性能。

在这里插入图片描述

本篇文章将以 PyTorch 为工具,深入剖析 BERT 的内部结构和实现细节。我们将从输入表示开始,一步步构建 BERT 编码器,并实现其两个核心的预训练任务:掩蔽语言模型(MLM)和下一句预测(NSP)。让我们通过代码,揭开 BERT 的神秘面纱。

完整代码:下载链接

输入表示

BERT 的一个精妙之处在于其能够灵活处理单个文本和文本对。为了实现这一点,BERT 对输入序列进行了特殊格式化。

  • 单个文本输入:格式为 [CLS] 文本序列A [SEP]
  • 文本对输入:格式为 [CLS] 文本序列A [SEP] 文本序列B [SEP]

其中:

  • [CLS]:一个特殊的分类标记。它不对应任何真实词元,但其在 BERT 输出中的最终表示被设计为聚合整个输入序列的信息,通常用于分类任务。
  • [SEP]:一个特殊的分隔标记,用于分隔不同的文本片段。

为了让模型能够区分文本对中的两个句子(例如,在问答任务中区分问题和上下文),BERT 引入了 片段嵌入(Segment Embeddings)。第一个句子的所有词元会加上片段嵌入 A,第二个句子的所有词元会加上片段嵌入 B。

下面的 get_tokens_and_segments 函数清晰地展示了这一过程。它接收一个或两个文本序列(已分词),并返回符合 BERT 格式的词元列表及其对应的片段索引。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT模型输入序列处理工具
用于获取输入序列的词元及其对应的片段索引
"""

from typing import List, Tuple, Optional


def get_tokens_and_segments(tokens_a: List[str], tokens_b: Optional[List[str]] = None) -> Tuple[List[str], List[int]]:
    """
    获取输入序列的词元及其片段索引
    
    该函数用于处理BERT模型的输入序列,将单个或两个文本序列转换为包含特殊标记的词元列表,
    并生成对应的片段索引,用于区分不同的输入片段。
    
    参数:
        tokens_a (List[str]): 第一个输入序列的词元列表
                             维度: [seq_len_a] - seq_len_a为第一个序列的长度
        tokens_b (Optional[List[str]], 可选): 第二个输入序列的词元列表,默认为None
                                           维度: [seq_len_b] - seq_len_b为第二个序列的长度
    
    返回:
        Tuple[List[str], List[int]]: 包含两个元素的元组
            - tokens (List[str]): 处理后的完整词元序列
                                维度: [total_len] - total_len为最终序列总长度
                                单序列时: total_len = seq_len_a + 2 (包含<cls>和<sep>)
                                双序列时: total_len = seq_len_a + seq_len_b + 3 (包含<cls>和两个<sep>)
            - segments (List[int]): 对应的片段索引列表
                                  维度: [total_len] - 与tokens长度相同
                                  0表示第一个片段(包括<cls>和第一个<sep>)
                                  1表示第二个片段(包括第二个<sep>)
    
    示例:
        >>> tokens_a = ['hello', 'world']
        >>> tokens, segments = get_tokens_and_segments(tokens_a)
        >>> print(tokens)    # ['<cls>', 'hello', 'world', '<sep>']
        >>> print(segments)  # [0, 0, 0, 0]
        
        >>> tokens_b = ['good', 'morning']
        >>> tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
        >>> print(tokens)    # ['<cls>', 'hello', 'world', '<sep>', 'good', 'morning', '<sep>']
        >>> print(segments)  # [0, 0, 0, 0, 1, 1, 1]
    """
    # 构建第一个片段的词元序列
    # tokens维度: [seq_len_a + 2] - 包含<cls>标记 + tokens_a + <sep>标记
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    
    # 构建第一个片段的索引序列,全部标记为0
    # segments维度: [seq_len_a + 2] - 与tokens长度对应
    segments = [0] * (len(tokens_a) + 2)
    
    # 如果存在第二个输入序列
    if tokens_b is not None:
        # 将第二个序列的词元添加到tokens中,并在末尾添加<sep>标记
        # tokens维度更新为: [seq_len_a + seq_len_b + 3]
        tokens += tokens_b + ['<sep>']
        
        # 为第二个序列生成片段索引,全部标记为1,并添加到segments中
        # segments维度更新为: [seq_len_a + seq_len_b + 3] - 与tokens长度对应
        segments += [1] * (len(tokens_b) + 1)
    
    return tokens, segments

# 调用示例
if __name__ == "__main__":
    print("=" * 60)
    print("BERT输入序列处理示例")
    print("=" * 60)
    
    # 示例1: 单个序列处理
    print("\n【示例1】单个序列处理:")
    print("-" * 30)
    tokens_a = ['我', '喜欢', '自然', '语言', '处理']
    print(f"输入序列A: {
     tokens_a}")
    print(f"序列A长度: {
     len(tokens_a)}")
    
    tokens, segments = get_tokens_and_segments(tokens_a)
    print(f"处理后词元: {
     tokens}")
    print(f"片段索引:   {
     segments}")
    print(f"最终长度: {
     len(tokens)} (原长度{
     len(tokens_a)} + 2个特殊标记)")
    
    # 示例2: 两个序列处理(句子对分类任务)
    print("\n【示例2】两个序列处理(句子对任务):")
    print("-" * 30)
    tokens_a = ['今天', '天气', '很好']
    tokens_b = ['适合', '外出', '游玩']
    print(f"输入序列A: {
     tokens_a}")
    print(f"输入序列B: {
     tokens_b}")
    print(f"序列A长度: {
     len(tokens_a)}, 序列B长度: {
     len(tokens_b)}")
    
    tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
    print(f"处理后词元: {
     tokens}")
    print(f"片段索引:   {
     segments}")
    print(f"最终长度: {
     len(tokens)} (A:{
     len(tokens_a)} + B:{
     len(tokens_b)} + 3个特殊标记)")
    
    # 示例3: 问答任务示例
    print("\n【示例3】问答任务示例:")
    print("-" * 30)
    question = ['什么', '是', 'BERT', '模型']
    context = ['BERT', '是', '一种', '预训练', '语言', '模型']
    print(f"问题: {
     question}")
    print(f"上下文: {
     context}")
    
    tokens, segments = get_tokens_and_segments(question, context)
    print(f"处理后词元: {
     tokens}")
    print(f"片段索引:   {
     segments}")
    
    # 分析片段索引的含义
    print("\n【片段索引说明】:")
    print("-" * 30)
    print("索引0: 问题部分(包括<cls>和第一个<sep>)")
    print("索引1: 上下文部分(包括第二个<sep>)")
    
    # 示例4: 英文示例
    print("\n【示例4】英文文本处理:")
    print("-" * 30)
    tokens_a = ['hello', 'world']
    tokens_b = ['good', 'morning']
    print(f"English A: {
     tokens_a}")
    print(f"English B: {
     tokens_b}")
    
    tokens, segments = get_tokens_and_segments(tokens_a, tokens_b)
    print(f"Processed tokens: {
     tokens}")
    print(f"Segment indices:  {
     segments}")
    
    print("\n" + "=" * 60)
    print("示例运行完成!")

运行结果:

============================================================
BERT输入序列处理示例
============================================================

【示例1】单个序列处理:
------------------------------
输入序列A: ['我', '喜欢', '自然', '语言', '处理']
序列A长度: 5
处理后词元: ['<cls>', '我', '喜欢', '自然', '语言', '处理', '<sep>']
片段索引:   [0, 0, 0, 0, 0, 0, 0]
最终长度: 7 (原长度5 + 2个特殊标记)

【示例2】两个序列处理(句子对任务):
------------------------------
输入序列A: ['今天', '天气', '很好']
输入序列B: ['适合', '外出', '游玩']
序列A长度: 3, 序列B长度: 3
处理后词元: ['<cls>', '今天', '天气', '很好', '<sep>', '适合', '外出', '游玩', '<sep>']
片段索引:   [0, 0, 0, 0, 0, 1, 1, 1, 1]
最终长度: 9 (A:3 + B:3 + 3个特殊标记)

【示例3】问答任务示例:
------------------------------
问题: ['什么', '是', 'BERT', '模型']
上下文: ['BERT', '是', '一种', '预训练', '语言', '模型']
处理后词元: ['<cls>', '什么', '是', 'BERT', '模型', '<sep>', 'BERT', '是', '一种', '预训练', '语言', '模型', '<sep>']
片段索引:   [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

【片段索引说明】:
------------------------------
索引0: 问题部分(包括<cls>和第一个<sep>)
索引1: 上下文部分(包括第二个<sep>)

【示例4】英文文本处理:
------------------------------
English A: ['hello', 'world']
English B: ['good', 'morning']
Processed tokens: ['<cls>', 'hello', 'world', '<sep>', 'good', 'morning', '<sep>']
Segment indices:  [0, 0, 0, 0, 1, 1, 1]

============================================================
示例运行完成!

除了片段嵌入,BERT 还使用了 位置嵌入(Position Embeddings) 来让模型感知到词元在序列中的顺序。与原始 Transformer 使用固定的正弦/余弦位置编码不同,BERT 使用的是可学习的位置嵌入。

最终,每个输入词元的表示是其 词元嵌入片段嵌入位置嵌入 三者之和。

在这里插入图片描述

BERTEncoder类实现

BERT 的核心是一个多层的双向 Transformer 编码器。接下来,我们将实现这个编码器。我们的 BERTEncoder 类将包含词嵌入层、片段嵌入层、可学习的位置嵌入参数,以及堆叠的多层 EncoderBlock

下面的代码块包含了构建 BERTEncoder 所需的所有组件,从底层的缩放点积注意力到完整的多头注意力和编码器块。代码注释非常详细,解释了每个模块的功能、参数和张量维度的变化。

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT编码器的完整实现
包含缩放点积注意力、多头注意力、编码器块和BERT编码器的完整定义
"""

import math
import torch
import torch.nn as nn
from typing import Optional


def masked_softmax(X: torch.Tensor, valid_lens: Optional[torch.Tensor]) -> torch.Tensor:
    """
    通过在最后一个轴上遮蔽元素来执行softmax操作
    
    参数:
        X (torch.Tensor): 输入张量,维度: [batch_size, seq_len, seq_len] 或 [batch_size*num_heads, seq_len, seq_len]
        valid_lens (Optional[torch.Tensor]): 有效长度,维度: [batch_size] 或 [batch_size, seq_len] 或 None
    
    返回:
        torch.Tensor: 经过遮蔽的softmax结果,维度与输入X相同
    """
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 在最后的轴上,被遮蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = X.reshape(-1, shape[-1])
        for batch_idx, valid_len in enumerate(valid_lens):
            X[batch_idx, valid_len:] = -1e6
        return nn.functional.softmax(X, dim=-1).reshape(shape)


def transpose_qkv(X: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    为了多头注意力的并行计算而变换形状
    
    参数:
        X (torch.Tensor): 输入张量,维度: [batch_size, seq_len, num_hiddens]
        num_heads (int): 注意力头的数量
    
    返回:
        torch.Tensor: 变换后的张量,维度: [batch_size*num_heads, seq_len, num_hiddens/num_heads]
    """
    # 输入X的形状: [batch_size, seq_len, num_hiddens]
    # 输出形状: [batch_size, seq_len, num_heads, num_hiddens/num_heads]
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
    
    # 输出形状: [batch_size, num_heads, seq_len, num_hiddens/num_heads]
    X = X.permute(0, 2, 1, 3)
    
    # 最终输出形状: [batch_size*num_heads, seq_len, num_hiddens/num_heads]
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    逆转transpose_qkv函数的操作
    
    参数:
        X (torch.Tensor): 输入张量,维度: [batch_size*num_heads, seq_len, num_hiddens/num_heads]
        num_heads (int): 注意力头的数量
    
    返回:
        torch.Tensor: 逆转换后的张量,维度: [batch_size, seq_len, num_hiddens]
    """
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)


class PositionWiseFFN(nn.Module):
    """
    基于位置的前馈网络
    
    参数:
        ffn_num_input (int): 输入特征维度
        ffn_num_hiddens (int): 隐藏层特征维度
        ffn_num_outputs (int): 输出特征维度
    """
    def __init__(self, ffn_num_input: int, ffn_num_hiddens: int, ffn_num_outputs: int, **kwargs):
        super

网站公告

今日签到

点亮在社区的每一天
去签到