混合注意力机制 -- Convolutional Block Attention Module(CBAM)

发布于:2024-07-02 ⋅ 阅读:(15) ⋅ 点赞:(0)

CBAM

CBAM 模块概述 通道注意力模块(Channel Attention Mechanism)和空间注意力模块(Spatial Attention Mechanism)是注意力机制的两种主要形式,它们分别通过对通道维度和空间维度的特征图进行加权,从而使网络更加关注重要的特征。CBAM模块结合了这两种注意力机制,可以在保留空间信息的同时,有效地提取关键通道特征,提高了网络在处理复杂图像任务上的性能表现。
在这里插入图片描述
在这里插入图片描述

代码实现

import torch
from torch import nn
from torchsummary import summary
 
class ChannelModule(nn.Module):
    def __init__(self, inputs, ratio=16):
        super(ChannelModule, self).__init__()
        _, c, _, _ = inputs.size()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        x = self.maxpool(inputs).view(inputs.size(0), -1)#nc
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#nchw
        y = self.avgpool(inputs).view(inputs.size(0), -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(maxout + avgout)
 
 
class SpatialModule(nn.Module):
    def __init__(self):
        super(SpatialModule, self).__init__()
        self.maxpool = torch.max
        self.avgpool = torch.mean
        self.concat = torch.cat
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self, inputs):
        maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#n1hw
        avgout = self.avgpool(inputs, dim=1, keepdim=True)#n1hw
        outs = self.concat([maxout, avgout], dim=1)#n2hw
        outs = self.conv(outs)#n1hw
        return self.sigmoid(outs)
 
 
class CBAM(nn.Module):
    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)
        self.spatial_out = SpatialModule()
 
    def forward(self, inputs):
        outs = self.channel_out(inputs) * inputs
        return self.spatial_out(outs) * outs

网站公告

今日签到

点亮在社区的每一天
去签到