YOLOv8 Improvements in Practice | Attention | Adding EMA, an Efficient Multi-Scale Attention Based on Cross-Spatial Learning, with Clear Gains on Small Objects

Published: 2024-09-05


Preface

YOLOv8 is the latest YOLO release from Ultralytics, the publishers of YOLOv5. It can be used for object detection, segmentation, and classification, trained on large datasets, and run on a wide range of hardware, from CPUs to GPUs.

YOLOv8 is a cutting-edge, state-of-the-art (SOTA) model that builds on the success of previous YOLO versions and introduces new features and improvements to further boost performance and flexibility. It is designed to be fast, accurate, and easy to use, which makes it an excellent choice for object detection, image segmentation, and image classification. Its main innovations are a new backbone network, a new anchor-free detection head, and a new loss function; it also retains support for earlier YOLO versions, making it easy to switch between versions and compare their performance.


1. Introduction to EMA


Paper: Efficient Multi-Scale Attention Module with Cross-Spatial Learning


The paper proposes a novel Efficient Multi-Scale Attention (EMA) module. EMA is designed to preserve the information in every channel while reducing computational overhead. It reshapes part of the channel dimension into the batch dimension and groups the channels into multiple sub-features, so that spatial semantic features are distributed evenly within each feature group. In addition, EMA encodes global information to recalibrate the channel weights in each parallel branch, and captures pixel-level pairwise relationships through cross-dimension interaction.


The main innovations include:

  1. Efficient Multi-Scale Attention (EMA): a new attention mechanism that reduces computational overhead while preserving the key information in every channel.

  2. Reorganizing channels into the batch dimension: regrouping the channel and batch dimensions improves how the model processes features (see the shape sketch after this list).

  3. Cross-dimension interaction: the module uses cross-dimension interaction to capture pixel-level relationships.

  4. Global information encoding and channel-weight recalibration: global information encoded in the parallel branches is used to recalibrate the channel weights, strengthening the feature representation.
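
To make the channel-to-batch regrouping in point 2 concrete, here is a small illustrative snippet (not taken from the paper's code) showing how a feature map is reshaped so that each of the G sub-feature groups is processed as if it were its own sample:

import torch

b, c, h, w, g = 2, 64, 32, 32, 8          # batch, channels, height, width, groups
x = torch.randn(b, c, h, w)
group_x = x.reshape(b * g, c // g, h, w)  # fold the group axis into the batch axis
print(group_x.shape)                      # torch.Size([16, 8, 32, 32])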

2. Code Implementation

Code layout

  • Create the files following the folder structure below; keeping them separate from the existing files under ultralytics/nn/modules makes the additions easier to manage (a small helper script follows the tree).
    - ultralytics
    	- nn
    		- extra_modules
    			- __init__.py
    			- attention.py
    		- modules
    

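If you prefer creating the files from a script instead of by hand, a small helper run from the repository root could look like this (just a sketch; the paths mirror the tree above):

from pathlib import Path

extra = Path("ultralytics/nn/extra_modules")
extra.mkdir(parents=True, exist_ok=True)      # create ultralytics/nn/extra_modules
for name in ("__init__.py", "attention.py"):
    (extra / name).touch(exist_ok=True)       # create the (initially empty) module files
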
Add the following to ultralytics/nn/extra_modules/__init__.py:

from .attention import *

Add the following to ultralytics/nn/extra_modules/attention.py:

import torch
from torch import nn

__all__ = ['EMA']


class EMA(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        # Fold the group axis into the batch axis: (b, c, h, w) -> (b*g, c//g, h, w)
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g, c//g, h, w
        # 1x1 branch: encode directional context along H and W, then split back
        x_h = self.pool_h(group_x)                      # b*g, c//g, h, 1
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)  # b*g, c//g, w, 1
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        # 3x3 branch: capture local multi-scale spatial context
        x2 = self.conv3x3(group_x)
        # Cross-spatial learning: each branch's pooled descriptor attends over the other's pixels
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        # Apply the spatial attention map and restore the original shape
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)
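
As a quick sanity check (optional, e.g. appended to the bottom of attention.py), note that EMA preserves the input tensor's shape, so it can be inserted after any layer without changing downstream channel counts:

if __name__ == '__main__':
    # e.g. a P3 feature map of the n scale at 640x640 input: 64 channels, 80x80
    x = torch.randn(1, 64, 80, 80)
    ema = EMA(64)
    print(ema(x).shape)  # expected: torch.Size([1, 64, 80, 80])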

Register the module

Add the following at the top of ultralytics/nn/tasks.py:

from ultralytics.nn.extra_modules import *

In the parse_model function of ultralytics/nn/tasks.py, add:

elif m in {EMA}:
    args = [ch[f], *args]
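
To see what this line does for the EMA entries used below: each YAML entry is [-1, 1, EMA, []], so args arrives empty and ch[f] is the previous layer's output channel count. A tiny stand-in illustration (the value 64 matches the first EMA layer in the model summary of section 3):

f, args = -1, []  # YAML entry: [-1, 1, EMA, []]
ch = [64]         # stand-in for parse_model's channel bookkeeping (hypothetical)
args = [ch[f], *args]
print(args)       # [64] -> the module is built as EMA(64)

Because EMA preserves its input channel count and is always placed directly after the previous layer (from = -1) in the YAML below, the channel bookkeeping in parse_model stays correct without further changes.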

Configure the YAML file

yolov8-ema.yaml (loading it later as "yolov8n-ema.yaml" automatically selects the n scale defined in this file)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, EMA, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, EMA, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 1, EMA, []]

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


3. Model Test

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

model = YOLO("yolov8n-ema.yaml")  # build a new model from scratch
                   from  n    params  module                                       arguments
  0                  -1  1       464  ultralytics.nn.modules.conv.Conv             [3, 16, 3, 2]
  1                  -1  1      4672  ultralytics.nn.modules.conv.Conv             [16, 32, 3, 2]
  2                  -1  1      7360  ultralytics.nn.modules.block.C2f             [32, 32, 1, True]
  3                  -1  1     18560  ultralytics.nn.modules.conv.Conv             [32, 64, 3, 2]
  4                  -1  2     49664  ultralytics.nn.modules.block.C2f             [64, 64, 2, True]
  5                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]
  6                  -1  2    197632  ultralytics.nn.modules.block.C2f             [128, 128, 2, True]
  7                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]
  8                  -1  1    460288  ultralytics.nn.modules.block.C2f             [256, 256, 1, True]
  9                  -1  1    164608  ultralytics.nn.modules.block.SPPF            [256, 256, 5]
 10                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 11             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 12                  -1  1    148224  ultralytics.nn.modules.block.C2f             [384, 128, 1]
 13                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']
 14             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 15                  -1  1     37248  ultralytics.nn.modules.block.C2f             [192, 64, 1]
 16                  -1  1       672  ultralytics.nn.extra_modules.attention.EMA   [64]
 17                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]
 18            [-1, 12]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 19                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]
 20                  -1  1      2624  ultralytics.nn.extra_modules.attention.EMA   [128]
 21                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]
 22             [-1, 9]  1         0  ultralytics.nn.modules.conv.Concat           [1]
 23                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]
 24                  -1  1     10368  ultralytics.nn.extra_modules.attention.EMA   [256]
 25        [16, 20, 24]  1    897664  ultralytics.nn.modules.head.Detect           [80, [64, 128, 256]]
YOLOv8n-ema summary: 249 layers, 3170864 parameters, 3170848 gradients, 9.1 GFLOPs
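
For reference, the stock YOLOv8n quoted in the YAML comments has 3,157,200 parameters and 8.9 GFLOPs, so the three EMA layers add only about 13.7k parameters and 0.2 GFLOPs. A quick way to reproduce the comparison yourself (yolov8n.yaml ships with Ultralytics):

from ultralytics import YOLO

YOLO("yolov8n.yaml").info()      # baseline YOLOv8n summary
YOLO("yolov8n-ema.yaml").info()  # EMA-augmented model summary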

4. Model Training

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

# Load a model
model = YOLO("yolov8n-ema.yaml")  # build a new model from scratch

# Use the model
model.train(
    data="./mydata/data.yaml",
    epochs=300,
    batch=32,
    imgsz=640,
    workers=8,
    device=0,
    project="runs/train",
    name='exp')  # train the model
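
After training, validation and prediction follow the standard Ultralytics API. A minimal sketch, assuming the best checkpoint ends up at runs/train/exp/weights/best.pt (the default location for the project/name used above) and with "path/to/test_images" standing in for your own test data:

from ultralytics import YOLO

model = YOLO("runs/train/exp/weights/best.pt")  # load the trained weights
metrics = model.val(data="./mydata/data.yaml")  # evaluate on the validation split
print(metrics.box.map50, metrics.box.map)       # mAP@0.5 and mAP@0.5:0.95
model.predict("path/to/test_images", imgsz=640, conf=0.25, save=True)  # run inference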

5. Summary

  • Training involves a good deal of randomness; you may need some luck and additional training runs to reach the best mAP.
