Yolov8改进CoTAttention注意力机制,效果秒杀CBAM、SE

news/2024/7/9 23:38:55 标签: YOLO, 人工智能, 深度学习, 目标检测

1.CoTAttention

论文地址:2107.12292.pdf (arxiv.org)

CoTAttention网络是一种用于多模态场景下的视觉问答(Visual Question Answering,VQA)任务的神经网络模型。它是在经典的注意力机制(Attention Mechanism)上进行了改进,能够自适应地对不同的视觉和语言输入进行注意力分配,从而更好地完成VQA任务。

CoTAttention网络中的“CoT”代表“Cross-modal Transformer”,即跨模态Transformer。在该网络中,视觉和语言输入分别被编码为一组特征向量,然后通过一个跨模态的Transformer模块进行交互和整合。在这个跨模态的Transformer模块中,Co-Attention机制被用来计算视觉和语言特征之间的交互注意力,从而实现更好的信息交换和整合。在计算机视觉和自然语言处理紧密结合的VQA任务中,CoTAttention网络取得了很好的效果。

新加坡国立大学的Qibin Hou等人提出了一种为轻量级网络设计的新的注意力机制,该机制将位置信息嵌入到了通道注意力中,称为coordinate attention

京东AI Research提出新的主干网络CoTNet,在CVPR上2021获得开放域图像识别竞赛冠军

京东AI Research提出一种Contextual Transformer Networks结构,就是将Transformer捕捉全局信息的能力与CNN捕捉临近局部信息能力相结合,从而提高网络模型的特征表达能力。值得注意的是,该方法可以实现模块的“即插即用”,将ResNet网络中的3x3模块替换成CoTNet的核心模块即可使用,Res2Net网络也是基于这种思想实现的。

在相同深度(50层或101层)下,top-1和top-5结果都表明本文的方法比卷积网络和Attention-based网络性能更好。

2. 基于Yolov8的CoordAttention实现

2.1 CoTAttention 加入 modules.py

核心代码:

######################  CoTAttention   ####     start   by  AI&CV  ###############################
import torch
from torch import flatten, nn
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

######################  CoTAttention   ####     end   by  AI&CV  ###############################

​2.2 yolov8_CoTAttention.yaml

# Ultralytics YOLO  , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 4  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 3, CoTAttention, [256]]   # 16

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 3, CoTAttention, [512]]  

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 3, CoTAttention, [1024]]  

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


http://www.niftyadmin.cn/n/5150339.html

相关文章

WebSocket: 实时通信的理解与认识

介绍: 在现代Web应用程序中,实时通信对于提供即时更新和交互性至关重要。传统的HTTP协议虽然适合请求-响应模式,但对于需要频繁数据交换的场景并不理想。而WebSocket技术的出现填补了这个空白,为Web开发者们带来了一种高效、实时的…

【JMeter】逻辑控制器分类以及功能介绍

常用逻辑控制器的分类以及介绍 If Controller 满足if条件才会执行取样器 Loop Controller 对取样器循环多次 ForEach Controller

CAM模型可视化(可解释)

模型的可解释性问题一直是个关注的热点。注意,本文所说的“解释”,与我们日常说的“解释”内涵不一样:例如我们给孩子一张猫的图片,让他解释为什么这是一只猫,孩子会说因为它有尖耳朵、胡须等。而我们让CNN模型解释为什…

故障诊断模型 | Maltab实现BP神经网络的故障诊断

文章目录 效果一览文章概述模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现BP神经网络的故障诊断 模型描述 BP(Back Propagation) 算法是神经网络深度学习中最重要的算法之一,了解BP算法可以让我们更理解神经网络深度学习模型训练的本质,属于内功修行的…

SSM 整合 JWT

一、前言 SSM 是一个流行的 Java Web 框架,它可以简化开发过程并提高生产率,同时提供了高性能和可扩展性。在 SSM 中整合 JWT 身份验证可以确保用户在访问系统时能够安全地身份验证。 JWT 是一种身份验证和授权协议,它允许用户在访问系统时…

【I.mx6ull】之-----代码的编译过程

本博文记录【I.mx6ull】之-----代码的编译过程 文章目录 I.mx6ull 启动分析汇编语言驱动开发板代码编译过程将汇编语言依次编译为 .bin 文件的过程Makefile 文件的必要性 C语言驱动开发板底层过程 I.mx6ull 启动分析 比如,裸机的例程是在SD卡中,板子上电…

JMeter如何开展性能测试

文章目录 性能测试指标理解透彻以及测算微聊性能测试性能测试流程准备流程 ​👑作者主页:Java冰激凌 性能测试指标理解透彻以及测算 虚拟用户数: 线程 用户并发数:指在某一时间,一定数量的虚拟用户同时对系统的某个功…

uniapp中地图定位功能实现的几种方案

1.uniapp自带uni.getLocation uni.getLocation(options) getlocation | uni-app官网 实现思路:uni.getLocation获取经纬度后调用接口获取城市名 优点:方便快捷,直接调用 缺点:关闭定位后延时很久,无法控制定位延迟…