SEAttention
摘要
卷积神经网络(CNNs)的核心构建模块是卷积算子,它使网络能够通过在每一层的局部感受野内融合空间和通道信息来构建有价值的特征。此前大量研究聚焦于这种关系中的空间成分,试图通过在整个特征层级中提升空间编码质量来增强 CNN 的表征能力。在这项工作中,我们将重点放在通道关系上,并提出一种新颖的架构单元,称为 “挤压与激励
”(Squeeze-and-Excitation,简称 SE)模块。该模块通过显式建模通道间的相互依赖关系,自适应地重新校准通道维度的特征响应
。我们证明,这些模块可以堆叠在一起形成 SENet 架构,在不同数据集上都具有极高的泛化能力。我们进一步展示了,SE 模块在仅增加少量计算成本的情况下,就能显著提升现有最先进 CNN 的性能。挤压与激励网络是我们参加 2017 年 ImageNet 大规模视觉识别挑战赛(ILSVRC 2017)分类任务提交成果的基础,我们凭借该成果获得了第一名,并将 top-5 错误率降至 2.251%,相比 2016 年的冠军成果有 25% 的相对提升。模型和代码可在https://github.com/hujie-frank/SENet获取。
模型结构
简单描述:
- 压缩:通过全局平均池化,压缩空间特征
- 提取:通过线性层来进行压缩以及扩张得到注意力权重
- 融合:将权重与特征图相乘进行输出
模型代码
import numpy as np
import torch
from torch import nn
from torch.nn import init
class SEAttention(nn.Module):
def __init__(self, channel=32,reduction=8):
super().__init__()
# 初始化代码
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 自适应平均池化(H,W->1,1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False), # 线性层进行通道压缩
nn.ReLU(inplace=True),# 激活函数,引入非线性
nn.Linear(channel // reduction, channel, bias=False), # 线性层进行通道扩张
nn.Sigmoid()# 激活函数,将输出值压缩到0-1之间,生成权重
)
# 初始化权重,针对不同的网络层使用不同的初始化策略
def init_weights(self):
for m in self.modules(): # 遍历所有子模块
if isinstance(m, nn.Conv2d): # 如果子模块是卷积层
init.kaiming_normal_(m.weight, mode='fan_out') # 使用kaiming初始化方法
if m.bias is not None: # 如果卷积层有偏置
init.constant_(m.bias, 0) # 将偏置初始化为0
elif isinstance(m, nn.BatchNorm2d): # 如果子模块是批归一化层
init.constant_(m.weight, 1) # 将权重初始化为1
init.constant_(m.bias, 0) # 将偏置初始化为0
elif isinstance(m, nn.Linear): # 如果子模块是线性层
init.normal_(m.weight, std=0.001) # 使用正态分布初始化权重
if m.bias is not None: # 如果线性层有偏置
init.constant_(m.bias, 0) # 将偏置初始化为0
def forward(self, x):
b, c, _, _ = x.size() # [batch_size, channel, height, width]-->[1,64,32,32]
y = self.avg_pool(x).view(b, c) # 自适应平均池化[1,64,32,32]-->[1,64,1,1]
y = self.fc(y).view(b, c, 1, 1) # 线性层进行通道压缩和扩张[1,64,1,1]-->[1,64,1,1]
return x * y.expand_as(x) # 将权重与特征图相乘[1,64,1,1]-->[1,64,32,32]
if __name__ == '__main__':
input=torch.randn(1,64,32,32).cuda()
se = SEAttention(channel=64,reduction=8).cuda()
output=se(input)
print('input_size:', input.size())
print('output_size:', output.size())
print("最大内存占用:", torch.cuda.max_memory_allocated() // 1024 // 1024, "MB")