Yolov8优化: 多分支卷积模块RFB,扩大感受野提升小目标检测精度

1.RFB-Net介绍

 论文:https://arxiv.org/pdf/1711.07767.pdf

 代码:GitHub - GOATmessi7/RFBNet: Receptive Field Block Net for Accurate and Fast Object Detection, ECCV 2018

         受启发于人类视觉的Receptive Fields结构,本文提出RFB,将RFs的尺度、离心率纳入考虑范围,使用轻量级主干网也能提取到高判别性特征,使得检测器速度快、精度高;具体地,RFB基于RFs的不同尺度,使用不同的卷积核,设计了多分支的conv、pooling操作(makes use of multi-branch pooling with varying kernels),并通过虫洞卷积(dilated conv)来控制感受野的离心率,最后一步reshape操作后,形成生成的特征

 RFs也已被深入研究,如Inception、ASPP、Deformable CNN:

RFB模块是一个多分支的卷积模块,它的内部结构被划分为两部分:

1.多分支卷积层:根据RF的定义,使用多种尺寸的卷积核来实现比固定尺寸更好。具体设计:1.瓶颈结构,1x1-s2的卷积减少通道特征,然后加上一个nxn卷积。2.用5x5卷积替换为2个3x3的卷积去减少参数,这样可得到非线性结构更好的层。3.为了输出,卷积经常有stride=2或者是减少通道,所有直连层为了匹配维度用一个不带激活函数的1x1卷积层。
2.dilated 卷积层:在保持参数量可扩大感受野,用来获取更高分辨率的特征。下图展示了两种RFB结构:RFB和RFB-s。每个分支都是一个正常卷积后面加一个dilated卷积,主要尺寸和dilated因子不同。(a)RFB整体上借鉴了Inception的思想,主要不同点在于引入了3个dilated卷积层。(b)RFB-s和RFB相比主要有两个改进,一方面用3x3的卷积层代替5x5卷积层,另一方面用1x3和3x1的卷积来代替3x3卷积,主要目的是为了减少计算量,类似Inception后期版本对Inception结构的改进。

 实验结果

RFB模块:在table 2中,原始的SSD300实现了77.2%的mAP,通过简单的用RFB-max Pooling替代最后一个卷积层,我们将结果提升到了79.1%,获得了1.9%的提高,这表明了RFB模块的高效性。

2. RFB引入到yolov8

2.1修改modules.py

class BasicRFB(nn.Module):

    def __init__(self, in_planes, out_planes, stride=1, scale=0.1, map_reduce=8, vision=1, groups=1):
        super(BasicRFB, self).__init__()
        self.scale = scale
        self.out_channels = out_planes
        inter_planes = in_planes // map_reduce

        self.branch0 = nn.Sequential(
            BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
            BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
            BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 1,
                      dilation=vision + 1, relu=False, groups=groups)
        )
        self.branch1 = nn.Sequential(
            BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
            BasicConv(inter_planes, 2 * inter_planes, kernel_size=(3, 3), stride=stride, padding=(1, 1), groups=groups),
            BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 2,
                      dilation=vision + 2, relu=False, groups=groups)
        )
        self.branch2 = nn.Sequential(
            BasicConv(in_planes, inter_planes, kernel_size=1, stride=1, groups=groups, relu=False),
            BasicConv(inter_planes, (inter_planes // 2) * 3, kernel_size=3, stride=1, padding=1, groups=groups),
            BasicConv((inter_planes // 2) * 3, 2 * inter_planes, kernel_size=3, stride=stride, padding=1,
                      groups=groups),
            BasicConv(2 * inter_planes, 2 * inter_planes, kernel_size=3, stride=1, padding=vision + 4,
                      dilation=vision + 4, relu=False, groups=groups)
        )

        self.ConvLinear = BasicConv(6 * inter_planes, out_planes, kernel_size=1, stride=1, relu=False)
        self.shortcut = BasicConv(in_planes, out_planes, kernel_size=1, stride=stride, relu=False)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)

        out = torch.cat((x0, x1, x2), 1)
        out = self.ConvLinear(out)
        short = self.shortcut(x)
        out = out * self.scale + short
        out = self.relu(out)

        return out

2.2 修改tasks.py

from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                    Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,BasicRFB)

修改def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3) 

 if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
                 BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x,BasicRFB)

2.3 yolov8_BasicRFB.yaml

# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 1, BasicRFB, [256]]  # 16 

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 1, BasicRFB, [512]]  # 20 

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 1, BasicRFB, [1024]]  # 24 

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)


http://www.niftyadmin.cn/n/1003913.html

相关文章

JavaScript中的运行时间统计方法

在开发中,我们经常需要对代码的执行时间进行统计和分析。JavaScript 提供了多种方法来测量代码的运行时间,本文将介绍使用 Date 对象、console.time() 和 performance.now() 来实现这个目标。 使用 Date 对象进行运行时间统计 Date 对象是 JavaScript …

2023上半年软考系统分析师科目一整理-23

2023上半年软考系统分析师科目一整理-23 对于如下所示的序列图所描述的场景,最适合于采用的设计模式是(30);该模式适用的场合是(31)。 A.Visitor B.Strategy C.Observe…

animate.css 动画

Animate.css | A cross-browser library of CSS animations. class"animate__bounce" 1. bounce 弹跳 2. flash 闪烁 3. pulse 放大,缩小 4. rubberBand 放大,缩小,弹…

Vue + TS封装全局搜索组件

本文介绍了如何使用Vue和TypeScript封装一个全局搜索组件。该组件可以方便地在Vue项目中使用,帮助用户快速定位所需信息。 组件功能 该全局搜索组件具有以下功能: 可以搜索指定数据源中的数据支持模糊搜索和精确搜索可以自定义搜索结果的展示方式支持…

【C++初阶】C++入门——缺省参数、函数重载

目录 一、缺省参数1.1 定义1.2 缺省参数分类1.3 缺省参数只能出现在函数声明中 二、函数重载2.1 定义2.2 构成重载的几种情况2.3 C支持函数重载的原理 一、缺省参数 1.1 定义 缺省参数是声明或定义函数时为函数的参数指定一个缺省值。在调用该函数时,如果没有指定实…

RK3288安卓7.1系统定制屏幕上面从底部往上滑显示状态栏,并且添加一个虚拟按键再次显示状态栏

实现功能:安卓系统屏幕上任意位置连续点击5次后系统自动隐藏导航栏 现场环境:导航栏状态栏隐藏,谷歌浏览器作为launcher启动并且进入 难点:任意位置点击5下这个事件如何捕捉 参考apk捕捉点击5下事件代码: public cla…

① RESTful API

1.API(Application Programming Interface) API就是一个接口;例如玩某一款游戏,你不必知道游戏具体的实现细节,只需要知道点哪个键是哪个技能就够了,而这个键之所以能实现玩家与游戏的交互,是因…

超级实用!详解Node.js中的http模块和fs模块

文章目录 1. http 模块创建 HTTP 服务器处理 HTTP 请求发送 HTTP 请求 2. fs 模块读取文件写入文件删除文件创建目录 以下是 Node.js中的http模块和fs模块 1. http 模块 用于创建和处理 HTTP 服务器和客户端,可用于构建 Web 应用程序。 const http require(http)…