免费阅读篇 | 芒果YOLOv8改进110:注意力机制GAM:用于保留信息以增强渠道空间互动

news/2024/7/10 2:54:54 标签: YOLO, 人工智能, 目标检测

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程

该篇博客为免费阅读内容,直接改进即可🚀🚀🚀

文章目录

      • 1. GAM论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 修改部分
      • 2.3 YOLOv8-gam 网络配置文件
      • 2.4 运行代码
      • 改进说明


1. GAM论文

在这里插入图片描述

研究了多种注意力机制来提高各种计算机视觉任务的性能。然而,现有的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息缩减和放大全局交互式表示来提高深度神经网络的性能。我们引入了带有多层感知器的 3D 排列,用于通道注意力以及卷积空间注意力子模块。对CIFAR-100和ImageNet-1K上图像分类任务的所提机制的评估表明,我们的方法在ResNet和轻量级MobileNet上都稳定地优于最近的几种注意力机制。

在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/2112.05561v1.pdf


YOLOv8__26">2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 gam.py文件,新增以下代码

import numpy as np
import torch
from torch import nn
from torch.nn import init

class GAMAttention(nn.Module):
       #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True,rate=4):
        super(GAMAttention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), 
            nn.BatchNorm2d(int(c1 /rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), 
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle 
        out = x * x_spatial_att
        return out  

def channel_shuffle(x, groups=2):
        B, C, H, W = x.size()
        out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
        out=out.view(B, C, H, W) 
        return out   

2.2 修改部分

在ultralytics/nn/modules/init.py中导入 定义在 gam.py 里面的模块

from .gam import GAMAttention

'GAMAttention' 加到 __all__ = [...] 里面

第一步:
ultralytics/nn/tasks.py文件中,新增

from ultralytics.nn.modules import GAMAttention

然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:
        args = [ch[f]]

在这句上面加一个

        elif m is GAMAttention:
            c1, c2 = ch[f], args[0]
            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]

YOLOv8gam__107">2.3 YOLOv8-gam 网络配置文件

新增YOLOv8-gam.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 3, GAMAttention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 运行代码

直接替换YOLOv8-gam.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics


http://www.niftyadmin.cn/n/5437208.html

相关文章

【Flask开发实战】防火墙配置文件解析(二)之shell读取内容

一、前言 上一篇文章中,介绍了防火墙配置文件包含的基本元素和格式样式,并模拟了几组有代表性的规则内容,作为基础测试数据。在拿到基础测试数据后,关于我们最终想解析成的数据是什么样式的,其实不难看出,…

java的成员变量和局部变量

1、什么是成员变量和局部变量? 2、成员变量和局部变量区别 区别 成员变量 局部变量 类中位置不同 类中方法外 方法内或者方法声明上 内存中位置不同 堆内存 栈内存 生命周期不同 随着对象的存在而存在,随着对象的消失而消失 随着方法的调用而…

Vue3、element-plus和Vue2、elementUI的一些转换

插槽 Vue3<template #default"scope"></template> <template #footer></template>Vue2<template slot-scope"scope"></template> <template slot"footer"></template>JS定义 Vue3 <script…

深入理解mysql 从入门到精通

1. MySQL结构 由下图可得MySQL的体系构架划分为&#xff1a;1.网络接入层 2.服务层 3.存储引擎层 4.文件系统层 1.网络接入层 提供了应用程序接入MySQL服务的接口。客户端与服务端建立连接&#xff0c;客户端发送SQL到服务端&#xff0c;Java中通过JDBC来实现连接数据库。 …

qt+ffmpeg 实现音视频播放(二)之音频播放

一、音频播放流程 1、打开音频文件 通过 avformat_open_input() 打开媒体文件并分配和初始化 AVFormatContext 结构体。 函数原型如下&#xff1a; int avformat_open_input(AVFormatContext **ps, const char *url, AVInputFormat *fmt, AVDictionary **options); 参数说…

OpenAI Q-Star:AGI距离自我意识越来越近

最近硅谷曝出一份54页的内部文件&#xff0c;揭露了去年OpenAI宫斗&#xff0c;导致Altman&#xff08;奥特曼&#xff09;差点离职的神秘项目——Q-Star&#xff08;神秘代号Q*&#xff09;。 根据该文件显示&#xff0c;Q-Star多模态大模型拥有125万亿个参数&#xff0c;比现…