2017年,Vaswani 等人在论文《Attention Is All You Need》中提出了 Transformer 架构,这可以说是自然语言处理领域的一次“范式转移”。具体可详见之前写的一篇文章:深入解析Transformer架构。
从那以后,无论是BERT、GPT,还是后来的大模型,几乎都建立在它的基础之上。说实话,刚接触这个结构的时候我也觉得有点抽象,尤其是多头注意力和位置编码这些设计,乍一看不太直观。但当你亲手实现一遍,很多疑惑就会慢慢解开。
今天这篇文章,我们就从最基础的模块开始,一步步用 PyTorch 实现一个完整的 Transformer 模型。我们不调用现成的 nn.Transformer
,而是自己动手写每一个组件:位置编码、注意力机制、编码器、解码器……最后还会跑一个简单的训练流程。目的不是为了替代现有的高效实现,而是帮你真正“看见”模型内部是怎么运作的。
整个过程我会尽量讲清楚每一步的设计思路,代码也会配上详细的注释。如果你正在学习深度学习或者准备面试,相信这套“手搓”流程会对理解有很大帮助。
完整代码实现
我们可以参考如下Transformer架构图。
下面是我们将要实现的完整代码。我会在关键部分插入解释,帮助你理解每个模块的作用和背后的逻辑。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
from torch.utils.data import Dataset, DataLoader
# 固定随机种子,确保每次运行结果一致
torch.manual_seed(42)
小贴士:做实验时固定随机种子是个好习惯,不然你改了个小地方,结果波动很大,容易怀疑人生。
1. 位置编码(Positional Encoding)
Transformer 没有像RNN那样的时序结构,所以它不知道词的位置顺序。为了解决这个问题,作者引入了位置编码,把位置信息加到词向量里。
这里用的是正弦和余弦函数交替的形式,好处是能表达相对位置关系,而且即使遇到比训练更长的序列,也能外推。
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维用sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维用cos
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe) # 不参与梯度更新
def forward(self, x):
x = x + self.pe[:, :x.size(1), :]
return x
个人体会:刚开始看公式时总觉得复杂,但其实本质就是构造一个固定的“位置模板”,然后加到每个样本上。这种设计既简单又有效,是Transformer里让我印象很深的一个巧思。
2. 缩放点积注意力(Scaled Dot-Product Attention)
这是整个模型的核心。注意力机制的本质是:给定一组键值对,通过查询来决定应该关注哪些值。
公式是这样的:
其中除以 是为了避免点积过大导致 softmax 梯度太小。
class ScaledDotProductAttention(nn.Module):
def __init__(self, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
d_k = query.size(-1)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9) # 把无效位置设为负无穷
attn_weights = torch.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
output = torch.matmul(attn_weights, value)
return output, attn_weights
注意点:这里的
mask
很关键,后面我们会用它来屏蔽填充符(padding)和防止解码器偷看未来信息。
3. 多头注意力(Multi-Head Attention)
单头注意力只能关注一种模式,而多头允许模型在不同子空间里并行地学习多种表示。你可以把它理解为“多个专家投票”。
实现上,就是把输入投影到多个头,分别做注意力,最后再拼起来。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super(MultiHeadAttention, self).__init__()
assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention(dropout)
def forward(self, query, key, value, mask=None):
batch_size = query.size(0)
Q = self.w_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.w_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.w_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
attn_output, attn_weights = self.attention(Q, K, V, mask)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.w_o(attn_output)
return output
经验分享:
view
和transpose
这些操作初学容易出错,建议打印中间张量的 shape 来调试。比如(batch, seq_len, d_model)
变成(batch, n_heads, seq_len, d_k)
的过程要理清楚。
4. 前馈网络(Position-wise Feed-Forward)
这个模块比较简单,对序列中每一个位置独立地应用相同的两层全连接网络。
class PositionwiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super(PositionwiseFeedForward, self).__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.relu = nn.ReLU()
def forward(self, x):
return self.linear2(self.dropout(self.relu(self.linear1(x))))
虽然叫“前馈”,但它在整个结构中起到了非线性变换和特征增强的作用,不可或缺。
5. 编码器层(Encoder Layer)
每个编码器层包含两个子层:多头自注意力 + 前馈网络。每个子层都有残差连接和层归一化(LayerNorm),这是稳定训练的关键。
class EncoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super(EncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差 + 归一化
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# 前馈网络 + 残差 + 归一化
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
小提醒:残差连接放在归一化之前还是之后?原始论文是“post-norm”,但后来很多工作发现“pre-norm”更稳定。这里我们按原始结构实现。
6. 解码器层(Decoder Layer)
解码器比编码器多了一个“编码器-解码器注意力”模块,用来融合源端的信息。
另外,自注意力部分要加上序列掩码,防止当前位置看到后面的词。
class DecoderLayer(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
# 掩码自注意力
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
# 编码器-解码器注意力
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
# 前馈网络
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
7. 编码器 & 解码器
把多个层堆起来,再加上词嵌入和位置编码,就构成了完整的编码器和解码器。
class Encoder(nn.Module):
def __init__(self, vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
super(Encoder, self).__init__()
self.d_model = d_model
self.embedding = nn.Embedding(vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model, max_len)
self.layers = nn.ModuleList([
EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
x = self.embedding(x) * math.sqrt(self.d_model)
x = self.pos_encoding(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, mask)
return x
注意这里有个小技巧:词嵌入乘以 ,是为了让后续的位置编码不会“淹没”原始信号。
解码器结构类似,就不重复贴了。
8. 完整的 Transformer 模型
把编码器、解码器和输出层组合起来:
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout=0.1):
super(Transformer, self).__init__()
self.encoder = Encoder(src_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
self.decoder = Decoder(tgt_vocab_size, d_model, n_layers, n_heads, d_ff, max_len, dropout)
self.linear = nn.Linear(d_model, tgt_vocab_size)
def forward(self, src, tgt, src_mask, tgt_mask):
enc_output = self.encoder(src, src_mask)
dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
output = self.linear(dec_output)
return output
9. 掩码的生成
掩码是训练过程中的一个重要细节:
• 填充掩码(Padding Mask):忽略
<pad>
标记• 序列掩码(Look-ahead Mask):防止解码器看到未来信息
def create_mask(src, tgt, pad_idx):
src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)
tgt_pad_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2)
tgt_len = tgt.size(1)
tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len))).bool()
tgt_mask = tgt_pad_mask & tgt_sub_mask
return src_mask, tgt_mask
这个
torch.tril
生成下三角矩阵的操作,是实现“只能看到前面”的关键。
10. 数据处理与训练流程
我们定义一个简单的数据集类,并用 DataLoader 批量加载数据。
训练时,目标序列要错开一位:输入是 <s> 你 好
,输出是 你 好 </s>
。
损失函数使用交叉熵,并忽略填充位置。
def train_transformer(model, dataloader, optimizer, criterion, pad_idx, device, n_epochs):
model.train()
for epoch in range(n_epochs):
total_loss = 0
for batch_idx, (src, tgt) in enumerate(dataloader):
src, tgt = src.to(device), tgt.to(device)
src_mask, tgt_mask = create_mask(src, tgt, pad_idx)
src_mask, tgt_mask = src_mask.to(device), tgt_mask.to(device)
optimizer.zero_grad()
output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])
loss = criterion(
output.contiguous().view(-1, output.size(-1)),
tgt[:, 1:].contiguous().view(-1)
)
loss.backward()
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
print(f'Epoch {epoch+1}/{n_epochs}, Average Loss: {avg_loss:.4f}')
最后在 main()
函数中配置超参数、初始化模型并开始训练。
关键模块再梳理
模块 |
作用 |
位置编码 | 给词向量注入位置信息,弥补无时序结构的缺陷 |
缩放点积注意力 | 实现“查询-键-值”机制,动态分配关注权重 |
多头注意力 | 多角度捕捉不同语义模式,提升表达能力 |
残差连接 + LayerNorm | 缓解深层网络梯度问题,加速收敛 |
掩码机制 | 控制信息流动,保证训练合理性 |
这些设计看似独立,实则环环相扣。比如没有残差连接,6层以上的Transformer几乎训不动;没有位置编码,模型就失去了顺序感知能力。
一点总结
这篇文章我们从零实现了一个标准的 Transformer 模型。虽然用的是人工构造的小数据集,无法真正完成翻译任务,但整个流程涵盖了:
• 模型结构搭建
• 数据预处理
• 掩码机制
• 训练逻辑
我已经尽可能让代码简洁明了,方便你理解和修改。如果你打算进一步扩展,可以考虑加入:
• 学习率调度器(如
NoamOpt
)• 梯度裁剪
• Beam Search 解码
• 更真实的双语数据集(如 WMT)
说实话,当我第一次跑通这个模型时,心里还挺激动的。不是因为结果多好,而是终于把那些公式和结构图变成了实实在在能运行的代码。这种“亲手造出来”的感觉,是读论文很难替代的。
附注:完整代码已测试通过,可在 CPU/GPU 上运行。如需进一步优化或扩展,请根据实际任务调整超参数和数据流程。