每日Attention学习26——Dynamic Weighted Feature Fusion

发布于:2025-03-18 ⋅ 阅读:(27) ⋅ 点赞:(0)
模块出处

[ACM MM 23] [link] [code] Efficient Parallel Multi-Scale Detail and Semantic Encoding
Network for Lightweight Semantic Segmentation


模块名称

Dynamic Weighted Feature Fusion (DWFF)


模块作用

双级特征融合


模块结构

在这里插入图片描述


模块思想

我们提出了 DWFF 策略,选择性地关注特征图中信息量最大的部分,以有效地结合浅层和深层特征,提高分割精度。DWFF 可用于在具有细粒度细节的区域中更重地加权浅层特征,在具有较高语义信息的区域中更重地加权深层特征,从而实现更好的特征组合和准确的分割。


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class DWFF(nn.Module):
    def __init__(self,
                 in_channels: int,
                 height: int = 2,
                 reduction: int = 8,
                 bias: bool = False) -> None:
        super(DWFF, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),
            nn.BatchNorm2d(d),
            nn.LeakyReLU(0.2)
        )
        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
        return feats_V
    

if __name__ == '__main__':
    dwff = DWFF(in_channels=64)
    x1 = torch.randn([2, 64, 16, 16])
    x2 = torch.randn([2, 64, 16, 16])
    out = dwff([x1, x2])
    print(out.shape)  # 2, 64, 16, 16