- 导入必要的库
python
import math
import torch
import torch.nn as nn
from LabmL_helpers.module import Module
from labml_n.utils import clone_module_List
from typing import Optional, List
from torch.utils.data import DataLoader, TensorDataset
from torch import optim
import torch.nn.functional as F
- Transformer 模型概述
Transformer 是一种序列到序列的模型,通过自注意力机制并行处理整个序列,能同时考虑序列中的所有元素,并学习上下文之间的关系。其架构包括编码器和解码器部分,每部分都由多个相同的层组成,这些层包含自注意力机制、前馈神经网络,以及归一化和 Dropout 步骤。 - 核心公式
- 自注意力计算:Attention(Q,K,V)=softmax(dkQKT)V,其中,Q、K、V分别是查询(Query)、键(Key)和值(Value)矩阵,dk是键的维度。
- 多头注意力:将输入分割为多个头,分别计算注意力,然后将结果拼接起来。
- 位置编码:由于 Transformer 不使用循环结构,因此引入位置编码来保留序列中的位置信息。
- 自注意力机制
- 核心原理:计算句子在编码过程中每个位置上的注意力权重,然后以权重和的方式来计算整个句子的隐含向量表示。公式中,首先将 query 与 key 的转置做点积,然后将结果除以dk ,再进行 softmax 计算,最后将结果与 value 做矩阵乘法得到 output。除以dk是为了防止QKT过大导致 softmax 计算溢出,且可使QKT结果满足均值为 0,方差 1 的分布。QKT计算本质上是余弦相似度,可表示两个向量在方向上的相似度。
- 实现
python
import numpy as np
from math import sqrt
import torch
from torch import nn
class Self_Attention(nn.Module):
# input : batch_size * seq_len * input_dim
# q : batch_size * input_dim * dim_k
# k : batch_size * input_dim * dim_k
# v : batch_size * input_dim * dim_v
def __init__(self, input_dim, dim_k, dim_v):
super(Self_Attention, self).__init__()
self.q = nn.Linear(input_dim, dim_k)
self.k = nn.Linear(input_dim, dim_k)
self.v = nn.Linear(input_dim, dim_v)
self._norm_fact = 1 / sqrt(dim_k)
def forward(self, x):
Q = self.q(x) # Q: batch_size * seq_len * dim_k
K = self.k(x) # K: batch_size * seq_len * dim_k
V = self.v(x) # V: batch_size * seq_len * dim_v
# Q * K.T() # batch_size * seq_len * seq_len
atten = nn.Softmax(
dim=-1)(torch.bmm(Q, K.permute(0, 2, 1))) * self._norm_fact
# Q * K.T() * V # batch_size * seq_len * dim_v
output = torch.bmm(atten, V)
return output
X = torch.randn(4, 3, 2)
print(X)
self_atten = Self_Attention(2, 4, 5) # input_dim:2, k_dim:4, v_dim:5
res = self_atten(X)
print(res.shape) # [4,3,5]
- 多头注意力机制
不同于只使用一个注意力池化,将输入x拆分为h份,独立计算h组不同的线性投影来得到各自的 QKV,然后并行计算注意力,最后将h个注意力池化拼接起来并通过另一个可学习的线性投影进行变换以产生输出。每个头可能关注输入的不同部分,可表示更复杂的函数。
python
from math import sqrt
import torch
import torch.nn as nn
class Self_Attention_Muti_Head(nn.Module):
# input : batch_size * seq_len * input_dim
# q : batch_size * input_dim * dim_k
# k : batch_size * input_dim * dim_k
# v : batch_size * input_dim * dim_v
def __init__(self, input_dim, dim_k, dim_v, nums_head):
super(Self_Attention_Muti_Head, self).__init__()
assert dim_k % nums_head == 0
assert dim_v % nums_head == 0
self.q = nn.Linear(input_dim, dim_k)
self.k = nn.Linear(input_dim, dim_k)
self.v = nn.Linear(input_dim, dim_v)
self.nums_head = nums_head
self.dim_k = dim_k
self.dim_v = dim_v
self._norm_fact = 1 / sqrt(dim_k)
def forward(self, x):
Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
self.nums_head)
K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k //
self.nums_head)
V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v //
self.nums_head)
print(x.shape)
print(Q.size())
atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0, 1, 3, 2))) # Q * K.T() # batch_size * seq_len * seq_len
output = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1) # Q * K.T() * V # batch_size * seq_len * dim_v
return output
x = torch.rand(1, 3, 4)
print(x)
atten = Self_Attention_Muti_Head(4, 4, 4, 2)
y = atten(x)
print(y.shape)
- 视觉注意力机制
attention 机制本质是利用相关特征图学习权重分布,再用学出来的权重施加在原特征图上最后进行加权求和。计算机视觉上的注意力机制主要分为三种:空间域、通道域、混合域。- 空间域:将图片中的空间域信息做对应的空间变换,提取关键信息,对空间进行掩码的生成并打分,代表是 Spatial attention module。
- 通道域:给每个通道上的信号增加一个权重,代表该通道与关键信息的相关度,权重越大相关度越高。对通道生成掩码 mask 进行打分,代表是 senet、channel attention module。
- 混合域:空间域的注意力忽略了通道域中的信息,将每个通道的图片特征同等处理,这种做法会将空间域变换方法局限在原始特征提取阶段。
- 通道域注意力(SENet)
通过全局池化提取通道权重,然后对特征图进行改变,得到加强后的特征图。
python
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c) # 对应Squeeze操作
y = self.fc(y).view(b, c, 1, 1) # 对应Excitation操作
return x * y.expand_as(x)
- 门控注意力机制(GCT,Gated Channel Transformation)
GCT 是一种简单有效的通道间建模关系体系结构,能显著提高卷积网络在视觉任务的泛化能力。论文发现将门控机制放在 Conv 层前面训练效果最好。GCT 包含三个部分:- Global Context Embedding:设计了一种全局上下文嵌入模块,用于每个通道的全局上下文信息汇聚,公式为sc=αc∥xc∥2=αc{[∑i=1H∑j=1W(xci,j)2]+ϵ}21。
- Channel Normalization:对第一步计算的 L2 进行规范化来构建神经元竞争关系,使用跨通道的特征规范化,公式为s^c=∥s∥2Csc=[(∑c=1Csc2)+ϵ]21Csc。
- Gating Adaptation:加入门限机制,公式为x^c=xc[1+tanh(γcs^c+βc)] 。
python
class GCT(nn.Module):
def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False):
super(GCT, self).__init__()
self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1))
self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1))
self.epsilon = epsilon
self.mode = mode
self.after_relu = after_relu
def forward(self, x):
if self.mode == 'l2':
embedding = (x.pow(2).sum((2, 3), keepdim=True) +
self.epsilon).pow(0.5) * self.alpha
norm = self.gamma / \
(embedding.pow(2).mean(dim=1, keepdim=True) +
self.epsilon).pow(0.5)
elif self.mode == 'l1':
if not self.after_relu:
_x = torch.abs(x)
else:
_x = x
embedding = _x.sum((2, 3), keepdim=True) * self.alpha
norm = self.gamma / \
(torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon)
gate = 1. + torch.tanh(embedding * norm + self.beta)
return x * gate
GCT 建议添加在 Conv 层前,一般可以先冻结原来的模型,来训练 GCT,然后解冻再进行微调。