Transformer Masked loss原理精讲及其PyTorch逐行实现

发布于:2025-07-25 ⋅ 阅读:(22) ⋅ 点赞:(0)

Masked Loss 的核心原理是:在计算损失函数时,只考虑真实有意义的词元(token),而忽略掉为了数据对齐而填充的无意义的填充词元(padding token)。

这是重要的技术,可以确保模型专注于学习有意义的任务,并得到一个正确的性能评估。

1.原理精讲

为什么需要 Masked Loss?

在训练神经网络时,我们通常会用一个批次(batch)的数据进行训练,而不是一次只用一个样本。对于自然语言处理任务,我们会一次性处理多句话。但这些句子的长度都几乎不一样。

例如,我们有一个包含两个句子的批次:

["我", "是", "学生"] (长度为 3)

["今天", "天气", "真", "好"] (长度为 4)

为了将它们放入一个统一的张量(tensor)中进行高效的并行计算,我们必须将较短的句子“填充”到一个统一的长度(通常是这个批次中最长句子的长度)。我们会使用一个特殊的 <pad> 词元来完成这个任务。

填充后的数据就变成了:

["我", "是", "学生", "<pad>"]

["今天", "天气", "真", "好"]

现在,问题来了。当模型在训练时,它会为每个位置都生成一个预测。对于第一句话,它也会尝试在第4个位置预测 <pad>。如果我们不加处理,损失函数就会计算模型预测 <pad> 的准确度,并把这个“误差”也算进总的损失里。

这样做有两个坏处:

  1. 浪费计算资源:强迫模型去学习一个无意义的任务——“在句子末尾预测填充符”。

  2. 评估指标失真:这个无意义任务的损失会“稀释”我们真正关心的、对真实词元的预测损失,导致我们无法准确评估模型的真实性能。

Masked Loss 就是为了解决这个问题而生的。它的目标就是创建一个“掩码(mask)”,告诉损失函数不计算PAD。


PyTorch 逐行实现

在 PyTorch 中,实现 Masked Loss 非常简单,因为 nn.CrossEntropyLoss 已经内置了处理它的高效方法。

我们将一步步模拟这个过程。

第零步:准备工作

我们先导入库,并设定一些基本参数。

import torch
import torch.nn as nn

#设定参数


BATCH_SIZE = 2      # 一个批次里有2句话

SEQ_LEN = 5         # 统一填充后的句子长度是5

VOCAB_SIZE = 10     # 假设我们的词汇表很小,只有10个词

PADDING_IDX = 0     # 我们约定,ID为0的词元就是 <pad> 填充符

代码解释: 我们设定了一个场景:一个批次包含2个句子,每个句子被填充到长度5,词汇表共10个词,并且我们用 0 来代表 <pad>

第一步:模拟模型输出和真实标签

我们创建两个张量:一个是模型预测的 logits,另一个是带填充的真实标签 target

# 模拟模型的原始输出 (logits)
# 形状: (批量大小, 序列长度, 词汇表大小)
logits = torch.randn(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)

# 模拟真实的标签 (ground truth)
# 注意其中包含了 PADDING_IDX (0)
target = torch.tensor([
    [1, 5, 4, 2, PADDING_IDX],  # 第1句话,最后一个是padding
    [3, 8, 7, PADDING_IDX, PADDING_IDX]   # 第2句话,最后两个是padding
])

print("模型预测 Logits 的形状:", logits.shape)
print("真实标签 Target 的形状:", target.shape)
print("真实标签内容:\n", target)

代码解释logits 是模型对每个位置、每个词的预测得分。target 是我们的“标准答案”,可以看到,为了对齐,较短的句子末尾被填充了 0

第二步:定义损失函数 

# 定义交叉熵损失函数
# 关键:告诉损失函数,所有标签值为 PADDING_IDX 的位置都被忽略

criterion = nn.CrossEntropyLoss(ignore_index=PADDING_IDX)

ignore_index=PADDING_IDX 这个参数就是实现 Masked Loss 的方法。当我们把 padding_idx (这里是0) 传给它,CrossEntropyLoss 在内部计算时,会自动跳过所有目标标签是 0 的位置。

第三步:调整张量形状

CrossEntropyLoss 期望的输入形状是:Input: (N, C)Target: (N),其中 N 是样本总数,C 是类别数。而我们现在的 logitstarget 都是二维的批次数据,需要调整一下。

# CrossEntropyLoss 需要的输入形状是 (N, C)
# N 是总的需要计算的元素数量, C是类别数 (即词汇表大小)
# 我们用 .view() 来重塑张量

# 将 logits 从 (2, 5, 10) 变为 (10, 10)
reshaped_logits = logits.view(-1, VOCAB_SIZE)

# 将 target 从 (2, 5) 变为 (10)
reshaped_target = target.view(-1)

print("\n重塑后的 Logits 形状:", reshaped_logits.shape)
print("重塑后的 Target 形状:", reshaped_target.shape)

代码解释: 我们把 (BATCH_SIZE, SEQ_LEN) 这两个维度“压平”成一个维度。-1 是一个占位符,告诉 PyTorch 自动计算这个维度的大小(在这里就是 2 * 5 = 10)。

第四步:计算损失

现在,所有准备工作都已就绪,我们可以直接计算损失。

# 计算损失
# criterion 会自动使用我们设置的 ignore_index=0 来忽略填充位置
loss = criterion(reshaped_logits, reshaped_target)

print(f"\n计算出的 Masked Loss 是: {loss.item()}")

代码解释: 尽管 reshaped_target 中仍然包含 0,但由于我们在第二步中设置了 ignore_index=0,这些位置的损失不会被计算和累加。最终得到的 loss 值,是只基于那 7 个真实词元([1, 5, 4, 2][3, 8, 7])计算出来的平均损失。

这样,我们就用非常简洁的方式实现了 Masked Loss。


网站公告

今日签到

点亮在社区的每一天
去签到