YOLOv8改进 | 图像去雾 | 利用图像去雾网络AOD-PONO-Net网络增改进图像物体检测

一、本文介绍 

本文给大家带来的改进机制是利用AODNet图像去雾网络结合PONO机制实现二次增强,我将该网络结合YOLOv8针对图像进行去雾检测(也适用于一些模糊场景,图片不清晰的检测)同时本文的内容不影响其它的模块改进可以作为工作量凑近大家的论文里,非常的适用,图像去雾检测为群友最近提出的需要的改进,在开始之前给大家推荐一下我的专栏,本专栏每周更新3-10篇最新前沿机制 | 包括二次创新全网无重复,以及融合改进(大家拿到之后添加另外一个改进机制在你的数据集上实现涨点即可撰写论文),还有各种前沿顶会改进机制 |,更有包含我所有附赠的文件(文件内集成我所有的改进机制全部注册完毕可以直接运行)和交流群和视频讲解提供给大家。  

👑欢迎大家订阅我的专栏一起学习YOLO👑 

专栏目录:YOLOv8改进有效系列目录 | 包含卷积、主干、检测头、注意力机制、Neck上百种创新机制

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备    

目录

一、本文介绍 

二、原理介绍 

三、核心代码

四、添加教程

4.1 修改一

4.2 修改二 

4.3 修改三 

五、AODNet-PONO-Net的yaml文件和运行记录

5.1 AODNet-PONO-Net的yaml文件

5.2 训练代码 

5.3 训练过程截图 

五、本文总结


二、原理介绍 

官方论文地址: 官方论文地址点击即可跳转

官方代码地址: 官方代码地址点击即可跳转

摘要:这篇论文提出了一种名为全能去雾网络(AOD-Net)的图像去雾模型,该模型是基于重新制定的大气散射模型并利用卷积神经网络(CNN)构建的。与大多数先前的模型不同,AOD-Net不是分别估计传输矩阵和大气光,而是直接通过一个轻量级的CNN生成清晰图像。这种新颖的端到端设计使得将AOD-Net嵌入到其他深度模型中变得简单,例如,用于提升雾霾图像上高级任务性能的Faster R-CNN。在合成和自然雾霾图像数据集上的实验结果证明了我们在峰值信噪比(PSNR)、结构相似性指数(SSIM)和主观视觉质量方面超越了最先进技术的性能。此外,当将AOD-Net与Faster R-CNN结合并从头到尾进行联合训练时,我们见证了雾霾图像上对象检测性能的显著提升。


AOD-Net是一个端到端的可训练去雾模型,直接从有雾图像产生清晰图像,而不是依赖于任何单独和中间参数估计步骤。基于重新公式化的大气散射模型设计,与现有工作共享相同的物理基础,但以一种“更端到端”的方式将其所有参数估计在一个统一模型中完成

主要创新点

  1. 端到端去雾模型:首次提出一个端到端训练的去雾模型,直接从雾图像生成清晰图像,避免了传统方法中估计传输矩阵和大气光的独立步骤​​。
  2. 与高级视觉任务的结合:首次量化研究去雾质量如何影响后续高级视觉任务的性能,为比较去雾结果提供了一种新的客观标准。此外,AOD-Net可以无缝地与其他深度模型嵌入,形成一个在有雾图像上执行高级任务的流水线,通过端到端的联合调优进一步提升性能​​。

个人总结:AOD-Net能够一步到位地把雾气重的照片变清晰,而不像以前的方法那样需要分好几步小心翼翼地处理。简单来说,AOD-Net就是通过学习雾中的图片和清晰图片之间的差别,找到一种直接去除雾气的捷径,使得图片恢复清晰,同时也帮助计算机更好地理解图片内容。 


三、核心代码

核心代码的使用方式看章节四!

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
torch.autograd.set_detect_anomaly(True)
__all__ = ['AOD_pono_net']
class AODnet(nn.Module):
    def __init__(self):
        super(AODnet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = F.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return F.relu(output)

class AOD_pono_net(nn.Module):
    def __init__(self):
        super(AOD_pono_net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, stride=1, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.b = 1

        self.pono = PONO(affine=False)
        self.ms = MS()

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x1, mean1, std1 = self.pono(x1)
        x2, mean2, std2 = self.pono(x2)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x3 = self.ms(x3, mean1, std1)
        x4 = F.relu(self.conv4(cat2))
        x4 = self.ms(x4, mean2, std2)
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        output = F.relu(output)
        return output

class PONO(nn.Module):
    def __init__(self, input_size=None, return_stats=False, affine=True, eps=1e-5):
        super(PONO, self).__init__()
        self.return_stats = return_stats
        self.input_size = input_size
        self.eps = eps
        self.affine = affine

        if affine:
            self.beta = nn.Parameter(torch.zeros(1, 1, *input_size))
            self.gamma = nn.Parameter(torch.ones(1, 1, *input_size))
        else:
            self.beta, self.gamma = None, None

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        std = (x.var(dim=1, keepdim=True) + self.eps).sqrt()
        x = (x - mean) / std
        if self.affine:
            x = x * self.gamma + self.beta
        return x, mean, std

class MS(nn.Module):
    def __init__(self, beta=None, gamma=None):
        super(MS, self).__init__()
        self.gamma, self.beta = gamma, beta

    def forward(self, x, beta=None, gamma=None):
        beta = self.beta if beta is None else beta
        gamma = self.gamma if gamma is None else gamma
        if gamma is not None:
            y = x.mul(gamma)  # 使用非原地操作mul
        else:
            y = x  # 如果不乘gamma,保持y不变
        if beta is not None:
            y = y.add(beta)  # 使用非原地操作add

        return y



if __name__ == "__main__":
    # Generating Sample image
    image_size = (1, 3, 640, 640)
    image = torch.rand(*image_size)
    out = AOD_pono_net()
    out = out(image)
    print(out.size())


四、添加教程

本文的网络结构为无参数网络结构所以我们的注册方式很简单。

4.1 修改一

第一还是建立文件,我们找到如下ultralytics/nn/modules文件夹下建立一个目录名字呢就是'Addmodules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二 

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。


4.3 修改三 

第三步我门中到如下文件'ultralytics/nn/tasks.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!

到此就修改完成了,大家可以复制下面的yaml文件运行,无需修改parse_model方法。。


五、AODNet-PONO-Net的yaml文件和运行记录

5.1 AODNet-PONO-Net的yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, AOD_pono_net, []]  # 0-P1/2
  - [-1, 1, Conv, [64, 3, 2]]  # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)


5.2 训练代码 

大家可以创建一个py文件将我给的代码复制粘贴进去,配置好自己的文件路径即可运行。

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/yolov8-C2f-FasterBlock.yaml')
    # model.load('yolov8n.pt') # loading pretrain weights
    model.train(data=r'替换数据集yaml文件地址',
                # 如果大家任务是其它的'ultralytics/cfg/default.yaml'找到这里修改task可以改成detect, segment, classify, pose
                cache=False,
                imgsz=640,
                epochs=150,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD', # using SGD
                # resume='', # 如过想续训就设置last.pt的地址
                amp=False,  # 如果出现训练损失为Nan可以关闭amp
                project='runs/train',
                name='exp',
                )


5.3 训练过程截图 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

 专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备 


http://www.niftyadmin.cn/n/5424549.html

相关文章

TortoiseSVN 报错:The server unexpectedly closed the connetion

前言 CentOS7Linux 安装subversionmod_dav_svn,搭建subversion(svn)服务器 The server unexpectedly closed the connetion 解决办法 重启Apache服务 shell> systemctl restart httpd

漏洞复现-万户ezOFFICE系列

万户 安全情报,万户ezOFFICE协同管理平台SendFileCheckTemplateEdit-SQL注入漏洞万户OA DocumentEdit_unite.jsp 存在sql注入万户协同办公平台 ezoffice 未授权访问RCExml代码注入 XXE🔪freemarkerService XXE🔪GeneralWeb-xxeofficeserverservlet + attachmentserver RCE…

C#,文字排版的折行问题(Word-wrap problem)的算法与源代码

1、英文的折行问题 给定一个单词序列,以及一行中可以输入的字符数限制(线宽)。 在给定的顺序中放置换行符,以便打印整齐。 假设每个单词的长度小于线宽。 像MS word这样的文字处理程序负责放置换行符。 这个想法是要有平衡的线条。…

网络保护第七次作业

要求: 在FW5和FW3之间建立一条IPSEC通道,保证10.0.2.0/24网段可以正常访问到192.168.1.0/24网段。 实验步骤: 接口配置: FW1: FW4: 建立IPSEC VPN: FW4: IKE参数和IPSEC参数: F…

【C++算法模板】图的存储-邻接矩阵

文章目录 邻接矩阵洛谷3643 图的存储 邻接矩阵 邻接矩阵相比于上一篇博客邻接表的讲解要简单得多 数据结构,如果将二维数组 g g g 定义为全局变量,那默认初始化应该为 0 0 0 ,如果题目中存在自环,可以做特判, m e …

【PHP+代码审计】PHP基础——流程控制

🍬 博主介绍👨‍🎓 博主介绍:大家好,我是 hacker-routing ,很高兴认识大家~ ✨主攻领域:【渗透领域】【应急响应】 【Java、PHP】 【VulnHub靶场复现】【面试分析】 🎉点赞➕评论➕收…

前端基础篇-深入了解用 HTML 与 CSS 实现正文排版、正文布局

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 HTML 与 CSS 概述 2.0 HTML - 正文排版 2.1 视频标签 2.2 音频标签 2.3 段落标签 2.4 文本加粗标签 2.5 换行标签 2.6 CSS 样式 2.7 实现正文排版 3.0 HTML - …

QT进阶---------pro项目文件中的常用命令 (第三天)

1、命令一 决定exe可执行程序的生成路径CONFIG 作用:不使用默认路径,方便移植 CONFIG(debug, debug|release) {DESTDIR $$_PRO_FILE_PWD_/../../../debugXXXsystem } else {DESTDIR $$_PRO_FILE_PWD_/../../../realeaseXXXsystem } 是用于 Qt 项目…