基于yolov5模型的目标检测蒸馏(LD+KD)

news/2024/7/10 1:05:57 标签: YOLO, 目标检测, 驱动开发, 人工智能

文章目录

  • 前言
  • 一、Distillation理解
    • 1、Knowlege distillation
    • 2、Feature distillation
    • 3、Location distillation
    • 4、其它蒸馏
  • 二、yolov5蒸馏模型构建
    • 1、构建teacher预测模型
    • 2、构建蒸馏loss
    • 3、蒸馏模型代码图示
      • 模型初始化
      • 模型蒸馏
  • 三、蒸馏模型实验
    • 1、工程数据测试
    • 2、voc2012开源数据测试


前言

最近在看有关蒸馏(Distillation)相关的内容,也就是需要大量的计算资源及庞大的数据集去支撑大模型,以蒸馏方式转为小模型,加速推理时间与降低模型内存,有利于部署。为此,我基于yolov5模型框架,修改代码,构建一个LD+KD的蒸馏模型,并公开源码于github,供读者学习。同时,我也正在构建多头蒸馏,后期将公开源码与文章解读。

源码链接:点击这里


一、Distillation理解

蒸馏是模型压缩方法,是通过教师模型知识传授学生模型的方法。一般教师模型是较大模型,效果较好,学生模型是较小模型,直接训练效果较差,
使用蒸馏模型传授教师知识,帮助提高学生模型性能。

1、Knowlege distillation

知识蒸馏(Knowledge Distillation,简记为 KD)是一种经典的模型压缩方法,分类模型论文较多,实际是蒸馏类型信息,通过teacher模型给出软标签给学生更多信息。如下图示意:
在这里插入图片描述

2、Feature distillation

特征蒸馏也是一种经典的模型压缩方法,实际是特征图的知识传递,通过teacher模型给出特征图给学生更多特征提取约束或信息量。如下图示意:

在这里插入图片描述

3、Location distillation

位置蒸馏也是一种经典的模型压缩方法,实际是位置点(如box)的知识传递,通过teacher模型给出预测位置给学生位置信息。该方法学术不多,比较新,如下图示意:
在这里插入图片描述

4、其它蒸馏

也有很多其它蒸馏方式,如通道蒸馏、无监督、对比等蒸馏方式,或最近bert蒸馏等。当然,介于我后期会出多头蒸馏文章,我引入论文图,如下:

在这里插入图片描述

二、yolov5蒸馏模型构建

我是基于yolov5模型蒸馏的,教师模型使用大尺寸模型m,学生模型使用小尺寸模型s。同时,我修改源码构建蒸馏模型结构,接下来我介绍如何构建基于yolov5模型构建蒸馏模型。其结构如下:
在这里插入图片描述

1、构建teacher预测模型

yolov5只需使用训练后的best.pt文件,通过attempt_load即可加载完预测模型初始化,至于attempt_load函数解析,相信很多博客已有说明,我不在解释,其teacher模型构建如下:

def create_teacher_model(weights,device):
    # device = torch.device('cuda:0')
    model=attempt_load(weights, map_location=device).eval()
    stride = int(model.stride.max())  # model stride
    names = model.module.names if hasattr(model, 'module') else model.names  # get class names
    teacher_model={'model':model,
                   'stride':stride,
                   'names':names
                   }
    return teacher_model

2、构建蒸馏loss

我基于yolov5模型构建LD+KD的蒸馏方式,借用yolov5原始模型loss计算方法,teacher模型输出为类的一个序列作为target类别,而原始yolov5模型
gt的target为类别数字非序列。为此,我们修改类别表示方式,使用序列替换数字,该位置在build_targets函数中,我做了大量修改,也将对应解释写在对应代码
中,其详情如下代码:

    def build_targets(self, p, targets):
        # Build targets for compute_loss(), input targets(image_id,class,x,y,w,h)
        na, nt = self.na, targets.shape[0]  #  每个点anchor数量(3), targets(每个batch中的标签个数)
        tcls, tbox, indices, anch ,tconf = [], [], [], [], [] # tcls表示类别,tbox表示box的坐标(x,y,w,h),indices表示图像索引,anch表示选取的anchor的索引
        gain = torch.ones(targets.shape[-1]+1, device=targets.device)  # normalized to gridspace gain
        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # [na,nt] same as .repeat_interleave(nt)
        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices
        # targets[image_id,x,y,w,h,conf,...cls,anchor_id]
        g = 0.5  # bias
        off = torch.tensor([[0, 0],
                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                            ], device=targets.device).float() * g  # offsets

        for i in range(self.nl):  # 循环3个特征层
            anchors, shape = self.anchors[i], p[i].shape
            gain[1:5] = torch.tensor(shape)[[3, 2, 3, 2]]  # xyxy gain

            # Match targets to anchors
            t = targets * gain  # shape(3,n,7),在特征图中恢复gt尺寸,[img_id,x,y,w,h,conf,...cls,anchor_id]
            if nt:
                # Matches,选择正负样本方法,通过gt与anchor的wh比列筛选
                r = t[..., 3:5] / anchors[:, None]  # wh ratio
                j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']  # compare
                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
                t = t[j]  # filter,通过筛除后获得正样本

                # Offsets 获取选择完成的box的*中心点*坐标-gxy(以图像左上角为坐标原点),并转换为以特征图右下角为坐标原点的坐标-gxi
                gxy = t[:, 1:3]  # grid xy
                gxi = gain[[1, 2]] - gxy  # inverse 特征图右下角为坐标原点
                # 分别判断box的(x,y)坐标是否大于1,并距离网格左上角的距离(准确的说是y距离网格上边或x距离网格左边的距离)小于0.5,
                # 如果(x,y)中满足上述两个条件,则选中.gxy.shape=[182,2],包含x,y,所以判别后转置得到j,k,2个结果
                # 对转换之后的box的(x,y)坐标分别进行判断是否大于1,并距离网格右下角的距离(准确的说是y距离网格下边或x距离网格右边的距离)距离小于0.5,
                # 如果(x,y)中满足上述两个条件,为Ture,
                j, k = ((gxy % 1 < g) & (gxy > 1)).T    # gxy>1,以左上角为坐标原点,表示排除上边与左边边缘格子
                l, m = ((gxi % 1 < g) & (gxi > 1)).T    # gxi>1同理,以右下角为坐标原点,排除右边与下边边缘格子
                j = torch.stack((torch.ones_like(j), j, k, l, m))  # 第一行为自己本身正样本值
                t = t.repeat((5, 1, 1))[j]  # 根据j挑选正样本,但未移动相邻网格
                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]  # 根据j处理对应正样本偏置(确定移动相邻网格)
            else:
                t = targets[0]
                offsets = 0

            # Define  b=img_id,c=[...cls],conf=conf-->预测置信度 gxy=grid xy, gwh=grid wh, a=anchors_id
            b=t[:,0].long()
            c=t[:,6:-1]
            conf=t[:,5]
            gxy= t[:,1:3]
            gwh=t[:,3:5]
            a=t[:,-1].long()

            gij = (gxy - offsets).long()  # xy与offsets对应
            gi, gj = gij.T  # grid indices

            # Append
            indices.append((b, a, gj.clamp_(0, shape[2] - 1), gi.clamp_(0, shape[3] - 1)))  # image_id, anchor_id,与网格坐标grid_x,grid_y
            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box 获取(x,y)相对于网格点的偏置,以及box的宽高
            anch.append(anchors[a])  # anchors  获得对应的anchor
            tcls.append(c)  # class 获得对应类别
            tconf.append(conf)

        return tcls, tbox, indices, anch,tconf

同时,我们也修改计算类别loss位置的one shot方式,yolov5原模型的target为数字需要转换one shot编码,而teacher模型给的target本身为序列标签,无需转换,因此修改内容如下:

原代码:

# Classification
if self.nc > 1:  # cls loss (only if multiple classes)
    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
    t[range(n), tcls[i]] = self.cp   # 这里将其one-short编码-->也说明类从0开始
    lcls += self.BCEcls(ps[:, 5:], t)  # BCE

修改代码:

lcls += self.BCEcls(ps[:, 5:], tcls[i])  # BCE

3、蒸馏模型代码图示

模型初始化

模型与loss的初始化,如下图示:
在这里插入图片描述

模型蒸馏

学生模型硬标签loss计算、teacher-student的软标签loss计算,如此实现yolov5的KD+LD蒸馏方式,如下图示列:
在这里插入图片描述

三、蒸馏模型实验

1、工程数据测试

教师模型使用yolov5m模型-学生与蒸馏模型使用yolov5s模型,测试结果如下:
在这里插入图片描述
PR曲线图:

在这里插入图片描述
map0.5与map0.5:0.95均比学生模型高一点点。

2、voc2012开源数据测试

进一步实验测试,采用开源数据测试。
教师模型使用yolov5m模型-学生与蒸馏模型使用yolov5s模型,测试结果如下:
在这里插入图片描述
PR曲线图:

在这里插入图片描述
蒸馏模型在map0.5表现较差0.007个点,但map0.5:0.95却高了0.004个点。


http://www.niftyadmin.cn/n/5005835.html

相关文章

自动驾驶中间件

自动驾驶中间件 1. 什么是中间件2. 中间件的分类3. 自动驾驶为什么需要中间件4. 通信中间件 Reference&#xff1a; 自动驾驶中间件&#xff1a;量产落地的关键技术通俗易懂的告诉你什么是中间件 对于初入自动驾驶行业的人来说&#xff0c;各色各样的新型传感器、线控系统、芯…

centos离线安装包

一、配置阿里云yum镜像源 # 删除原yum源 rm -rf /etc/yum.repos.d/*# wget命令下载: wget [options] [url] wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo# 清除yum缓存 yum clean all# 缓存阿里云镜像 yum makecache 二、yum离线…

[HDCTF 2023]YamiYami

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言涉及知识点解题详细过程session伪造反弹shell 前言 从暑假末尾一直搁置&#xff0c;当时卡在反弹shell搞得离flag就差一步。不过最近一两天学习完反弹shell的知…

为安全带来光明:光耦继电器的 10 种救生应用

在安全性和可靠性至关重要的世界中&#xff0c;光耦继电器已成为推动各行业进步的关键技术。这些卓越的设备经常在主流新闻中被忽视&#xff0c;但它们一直在默默地为保障生命和提高整体运营效率的关键系统提供动力。今天&#xff0c;我们重点介绍光耦继电器的十种救生应用&…

sentinel加密狗使用及规则配置

Sentinel加密狗是一种硬件加密设备&#xff0c;用于保护软件应用程序免受未经授权的访问和复制。它可以提供软件许可管理、访问控制和数据保护等功能。下面是Sentinel加密狗的使用及规则配置的相关介绍。 Sentinel加密狗的使用 插入加密狗&#xff1a;将Sentinel加密狗插入计算…

【Ptyhon】关于自定义对象的Json序列化和反序列化

背景 最近使用Ptyon爬虫数据时&#xff0c;遇到对象无法转换为JSON序列化对象问题TypeError: Object of type Main is not JSON serializable 意思&#xff1a; 就是对象不能转换为JSON序列化对象 原因&#xff1a; 是对象没有自定义实现转换为JSON序列化对象的方法没有。 实…

SV中的随机化约束

1.多个变量相互依赖&#xff0c;如何进行随机化&#xff1f; 采用solve before &#xff08;1&#xff09;constraint c_valid_write_data {solve access_type before data;solve burst_length before data;(access_type TVIP_AXI_WRITE_ACCESS) -> data.size() burst…

npm修改为国内镜像

npm config get registry 查看下载地址&#xff0c;默认是https://registry.npmjs.org/ 这是国外的地址&#xff0c;npm下载很慢 npm config set registry https://registry.npmmirror.com/ 使用此指令&#xff0c;修改为淘宝镜像&#xff0c;输入指令后检查是否修改成功 …