YOLO系列正负样本分配策略

news/2024/7/10 1:56:46 标签: 计算机视觉, 人工智能, 目标检测, YOLO

1、YOLOv3

使用MaxIoUAssigner策略来给gt分配样本,基本上保证每个gt都有唯一的anchor对应,匹配的原则是该anchor与gt的IOU最大且大于FG_THRESH,这种分配制度会导致正样本比较少,cls和bbox分支训练起来可能比较慢。在剩余的anchor中,如果有anchor跟所有gt的IOU都小于BG_THRESH,则将此类anchor设为负样本,如果有anchor跟所有gt的IOU大于BG_THRESH且小于FG_THRESH,则忽视掉此类anchor。

下面以Towards-Realtime-MOT/utils/utils.py中的代码为例:

def build_targets_thres(target, anchor_wh, nA, nC, nGh, nGw):
    ID_THRESH = 0.5
    FG_THRESH = 0.5
    BG_THRESH = 0.4
    nB = len(target)  # number of images in batch
    assert(len(anchor_wh)==nA)

    tbox = torch.zeros(nB, nA, nGh, nGw, 4).cuda()  # batch size, anchors, grid size
    tconf = torch.LongTensor(nB, nA, nGh, nGw).fill_(0).cuda()
    tid = torch.LongTensor(nB, nA, nGh, nGw, 1).fill_(-1).cuda() 
    for b in range(nB):
        t = target[b]
        t_id = t[:, 1].clone().long().cuda()
        t = t[:,[0,2,3,4,5]]
        nTb = len(t)  # number of targets
        if nTb == 0:
            continue

        gxy, gwh = t[: , 1:3].clone() , t[:, 3:5].clone()
        gxy[:, 0] = gxy[:, 0] * nGw
        gxy[:, 1] = gxy[:, 1] * nGh
        gwh[:, 0] = gwh[:, 0] * nGw
        gwh[:, 1] = gwh[:, 1] * nGh
        gxy[:, 0] = torch.clamp(gxy[:, 0], min=0, max=nGw -1)
        gxy[:, 1] = torch.clamp(gxy[:, 1], min=0, max=nGh -1)

        gt_boxes = torch.cat([gxy, gwh], dim=1)
        
        anchor_mesh = generate_anchor(nGh, nGw, anchor_wh)
        anchor_list = anchor_mesh.permute(0,2,3,1).contiguous().view(-1, 4)
        iou_pdist = bbox_iou(anchor_list, gt_boxes)
        iou_max, max_gt_index = torch.max(iou_pdist, dim=1)  ## 取出每个pre与gt的IOU最大值

        iou_map = iou_max.view(nA, nGh, nGw)       
        gt_index_map = max_gt_index.view(nA, nGh, nGw)
        
        id_index = iou_map > ID_THRESH
        fg_index = iou_map > FG_THRESH  ## 若IOU大于FG_THRESH,则为foreground
        bg_index = iou_map < BG_THRESH  ## 若IOU小于BG_THRESH,则为background
        ign_index = (iou_map < FG_THRESH) * (iou_map > BG_THRESH)  ## 若IOU大于BG_THRESH并小于FG_THRESH,则ignore
        tconf[b][fg_index] = 1
        tconf[b][bg_index] = 0
        tconf[b][ign_index] = -1

        gt_index = gt_index_map[fg_index]
        gt_box_list = gt_boxes[gt_index]
        gt_id_list = t_id[gt_index_map[id_index]]
        if torch.sum(fg_index) > 0:
            tid[b][id_index] =  gt_id_list.unsqueeze(1)
            fg_anchor_list = anchor_list.view(nA, nGh, nGw, 4)[fg_index] 
            delta_target = encode_delta(gt_box_list, fg_anchor_list)
            tbox[b][fg_index] = delta_target
    return tconf, tbox, tid

2、YOLOv4

yolov4为了增加正样本,采用multi anchor策略,只要大于IoU阈值的anchor,都视为正样本

3、YOLOv5

确定gt是否匹配当前特征图的anchors

因为yolov5是多尺度预测,所以首先需要确定gt应该跟哪个尺度的特征图上的anchor进行匹配。规则为:gt的宽高分别与当前尺度下的anchor的宽高进行比较,如果它们的比例在[1/4,4]之间,则当前gt可以与当前尺度下的anchor进行匹配。

下面以yolov5/utils/loss.py代码为例:

# wh ratio
r = t[..., 4:6] / anchors[:, None]

# compare
j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']

# filter
t = t[j]

将与gt中心点邻近的两个点也作为正样本点(即总共有3个正样本点) 

将gt所在的中心点视作一个ceil,并将该ceil划分成4个象限,如果gt的中心点位于该ceil中的第四象限,则将该ceil右边的单元格以及下边的单元格也视为正样本点(如下图的黄色单元格)

anchors, shape = self.anchors[i], p[i].shape
gain[2:6] = torch.tensor(shape)[[3, 2, 3, 2]]  ## 取当前尺度特征图的w、h,如果原图尺寸为[640,640],下采样8倍后,特征图的尺寸就变成[80,80],则gain[2:6]=[80,80,80,80]

t = targets * gain  ## 将[0,1]之间的坐标映射到[0,80]上
g = 0.5  ## 偏移量,用于判断gt的x、y坐标在单元格的哪个象限

if nt:
    ## 求gt的宽高与anchor的宽高的比值
    r = t[..., 4:6] / anchors[:, None]
    ## 判断比值是否在[1/4,4]这个范围内
    j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']
    ## 挑选符合条件的gt
    t = t[j]  

    gxy = t[:, 2:4]  ## gt的x、y坐标
    gxi = gain[[2, 3]] - gxy  ## gt的x、y坐标到w、h的距离

    ## 由于gxy+gxi=[80,80],因此j和l互斥,k和m互斥
    j, k = ((gxy % 1 < g) & (gxy > 1)).T
    l, m = ((gxi % 1 < g) & (gxi > 1)).T
    j = torch.stack((torch.ones_like(j), j, k, l, m))

    ## 将t复制成5份(gt中心点所在单元格加上该单元格的左、上、右、下的单元格)
    t = t.repeat((5, 1, 1))[j]
    offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]

4、YOLOX

yolox的正负样本分配策略的代码可以参考:YOLOX中的SimOTA_Cassiel_cx的博客-CSDN博客

这里简单记录下SimOTA(simple optimal transport assignment)的步骤:

  1. 确定候选正样本计算anchor_box的中心点,若anchor_box的中心点落在gt内或者在以gt的中心点为圆心,以center_radius为半径的圆内,就将该anchor视为候选正样本。如下图所示,红色框为gt,绿色框为蓝点预测的anchor_box,假设该anchor_box的中心点为绿点,由于绿点在红框内,因此该anchor_box便作为候选正样本
  2. 计算候选正样本跟gt之间的cls loss和iou loss并以一定比例加权求和,作为cost,计算公式如下(如果anchor_box的中心点不在gt内,该anchor_box与gt的cost就会很大,对应代码中的100000.0 * (~is_in_boxes_and_center)
    cost = (
               pair_wise_cls_loss
               + 3.0 * pair_wise_ious_loss
               + 100000.0 * (~is_in_boxes_and_center)
           )
  3. 对候选正样本与gt之间的iou进行大小排序,挑选前10个(可自己调整)最大iou,对这些iou求和并取整,该数值作为当前gt的dynamic_k
  4. 对于该gt,取前dynamic_k个最小的cost的候选正样本,作为正样本
  5. 如果存在一个正样本匹配多个gt的情况,则选cost较小的gt来匹配(如,某anchor_box与gt_1和gt_2的cost分别为0.8和0.2,则取消该anchor_box与gt_1的匹配,只匹配gt_2

5、YOLOv7

yolov7的正负样本分配策略为yolov5和yolox的结合体。首先使用yolov5的策略去挑选候选正样本,再使用yolox中的SimOTA策略去从候选正样本中挑选正样本。

【参考文章】

目标检测正负样本区分策略和平衡策略总结(一) - 知乎

yolov7正负样本分配详解 - 知乎 

深入浅出Yolo系列之Yolox核心基础完整讲解 - 知乎


http://www.niftyadmin.cn/n/698627.html

相关文章

pytorch 中 dim 的-1,0,1,2 的意义 详解

对于3维矩阵&#xff0c;dim为-1时 与 dim为2时 的效果是一样的。dim为0时 从0维度&#xff0c; 下图 是三维实例 图的目的是 可以由一个想象的空间。 下面代码 与上图关系不大 >>> ab torch.tensor([[[0,1,2,3],[1,2,3,4]],[[2,3,4,5],[4,5,6,7]],[[5,6,7,8],…

网工内推 | 互联网大厂,字节跳动招资深网工,最高40k*15薪

01 北京字节跳动 招聘岗位&#xff1a;资深无线网络工程师 职责描述&#xff1a; 1、负责字节跳动全球办公室-无线网络运维保障工作&#xff1b; 2、负责字节跳动所属线下门店和电商库房的无线网络运维保障工作&#xff1b; 3、主导集中性无线网络问题治理&#xff1b; 4、负责…

系统架构设计师 6:数据库设计

一、数据库系统 数据库系统&#xff08;DataBase System, DBS&#xff09;是一个采用了数据库技术&#xff0c;有组织地、动态地存储大量相关联数据&#xff0c;从而方便多用户访问的计算机系统。广义上讲&#xff0c;DBS包括了数据库管理系统&#xff08;DBMS&#xff09;。 …

车载 Android开发面试习题

随着车联网技术的不断发展和普及&#xff0c;越来越多的汽车厂商开始使用 Android 操作系统作为车载娱乐和信息娱乐系统的核心。在这个趋势下&#xff0c;车载 Android 应用开发程序员的需求也日益增加。 像一些车企大厂不惜给出 30K~60K的高资&#xff0c;去广招这方面的技术人…

2023-06-29:redis中什么是热点Key?该如何解决?

2023-06-29&#xff1a;redis中什么是热点Key&#xff1f;该如何解决&#xff1f; 答案2023-06-29&#xff1a; 在Redis中&#xff0c;经常被访问的key被称为热点key。 产生原因和危害 原因 热点key问题产生的原因可以归纳为以下两种情况&#xff1a; 用户对于某些数据的…

leecode 数据库:1070. 产品销售分析 III

导入数据&#xff1a; Create table If Not Exists Sales (sale_id int, product_id int, year int, quantity int, price int); Create table If Not Exists Product (product_id int, product_name varchar(10)); Truncate table Sales; insert into Sales (sale_id, product…

利用QtRO解决QSerialPort跨线程调用问题

目录 一、关于QT多线程串口通信问题 二、QtRo介绍 三、串口控制光源 3.1接口 3.2 服务端 3.3 客户端 一、关于QT串口跨线程使用问题 在机器视觉项目中经常使用串口通信&#xff0c;如光源控制。当串口的创建和读写不在同一个线程时&#xff0c;qt会提示不能跨线程使用的警…

记一次 .NET 某埋线管理系统 崩溃分析

一&#xff1a;背景 1. 讲故事 经常有朋友跟我反馈&#xff0c;说看你的文章就像看天书一样&#xff0c;有没有一些简单入手的dump 让我们先找找感觉&#xff0c;哈哈&#xff0c;今天就给大家带来一篇入门级的案例&#xff0c;这里的入门是从 WinDbg 的角度来阐述的&#xf…