在线难例挖掘:Online Hard Example Mining (OHEM)

详细链接:https://erogol.com/online-hard-example-mining-pytorch/

OHEM通过减少计算成本来选择难例,提高网络性能。它主要用于目标检测。假设你想训练一个汽车检测器,并且你有正样本图像(图像中有汽车)和负样本图像(图像中没有汽车)。现在你想训练你的网络。实际上,你会发现负样本数量会远远大于正样本数量,而且大量的负样本是比较简单的。因此,比较明智的做法是选择一部分对网络最有帮助的负样本(比较有难度的,容易被识别为正样本的)参与训练。难例挖掘就是用于选择对网络最有帮助的负样本的。

通常来说,通过对网络训练进行一定的迭代后得到临时模型,使用临时模型对所有的负样本进行测试,便可以发现那些loss很大的负样本实例,这些实例就是所谓的难例。但是这种查找难例的方法,需要很大的计算量,因为负样本图像可能会很多;另外这一方法可能是次优的,当你进行难例挖掘的时候,模型的权重是固定的,当前权重下的难例未必适用于接下了的迭代(这句话不太理解)。也就是说,这里假设你选择的所有难例负样本对下一次迭代都是有用的,直到下一次难例选择。这是一个不完美的假设,尤其是对于大型数据集而言。

OHEM通过批量难例选择选择来解决上述两个问题。给定batch-size K,前向传播保持不变,计算损失。然后,选择M(M<K)个高损失值的实例,仅使用这M个实例的损失进行反向传播。

OHEM的具体pytorch实现代码如下:

python">import torch as th                                                                 
                                                                                   
                                                                                   
class NLL_OHEM(th.nn.NLLLoss):                                                     
    """ Online hard example mining. 
    Needs input from nn.LogSotmax() """                                             
                                                                                   
    def __init__(self, ratio):      
        super(NLL_OHEM, self).__init__(None, True)                                 
        self.ratio = ratio                                                         
                                                                                   
    def forward(self, x, y, ratio=None):                                           
        if ratio is not None:                                                      
            self.ratio = ratio                                                     
        num_inst = x.size(0)                                                       
        num_hns = int(self.ratio * num_inst)                                       
        x_ = x.clone()                                                             
        inst_losses = th.autograd.Variable(th.zeros(num_inst)).cuda()              
        for idx, label in enumerate(y.data):                                       
            inst_losses[idx] = -x_.data[idx, label]                                 
        #loss_incs = -x_.sum(1)                                                    
        _, idxs = inst_losses.topk(num_hns)                                        
        x_hn = x.index_select(0, idxs)                                             
        y_hn = y.index_select(0, idxs)                                             
        return th.nn.functional.nll_loss(x_hn, y_hn)     


http://www.niftyadmin.cn/n/1212629.html

相关文章

pdf缩略图上传组件

之前仿造uploadify写了一个HTML5版的文件上传插件&#xff0c;没看过的朋友可以点此先看一下~得到了不少朋友的好评&#xff0c;我自己也用在了项目中&#xff0c;不论是用户头像上传&#xff0c;还是各种媒体文件的上传&#xff0c;以及各种个性的业务需求&#xff0c;都能得到…

龙芯怎么办?给龙芯的建议-goofegg

2019独角兽企业重金招聘Python工程师标准>>> 老杳师兄以这个标题给龙芯的发展提了些建议&#xff0c;咱也胡整一把&#xff0c;也给个建议。 龙芯发展到这一步&#xff0c;整机、笔记本等已经说明我们自己造出来的cpu能用&#xff0c;现在的关键就是打破这个联盟的时…

使用matplotlib在同一图像内绘制两个直方图

为了更为直观地看到两幅灰度图直方图分布的不同&#xff0c;想到把两幅图像的直方图放在同一图像中显示。 解决方案主要参考至&#xff1a;使用matplotlib同时绘制两个直方图 这里直接贴代码&#xff1a; bins np.linspace(0, 255, 256)plt.hist(img1.flatten(), bins, densi…

psd缩略图生成上传解决方案

第一点&#xff1a;Java代码实现文件上传 FormFile file manform.getFile(); String newfileName null; String newpathname null; String fileAddre "/numUp"; try { InputStream stream file.getInputStream();// 把文件读入 String filePath request.…

使用simulink做图像裁剪时由于矩阵大小可变导致出现的错误及其解决方法

报错及问题描述 报错部分设置及描述 在本例中&#xff0c;find bounding box 模块会得到四个变量r1, r2, c1, c2&#xff0c;img_crop 则根据这四个变量对图像进行裁剪。img_crop 中裁剪代码如下&#xff1a; function img_cropped img_crop(img, r1, r2, c1, c2)img_croppe…

windows dos 命令行 常用命令

gpedit.msc&#xff0d;&#xff0d;&#xff0d;&#xff0d;&#xff0d;--------------------组策略 sndrec32&#xff0d;&#xff0d;&#xff0d;&#xff0d;&#xff0d;&#xff0d;&#xff0d;------------------录音机 nslookup&#xff0d;&#xff0d;&#xf…

matlab中interp2双线性插值算法的实现原理及使用python简单实现双线性插值interp2算法

双线性插值算法基本原理 双线性插值算法的基本原理&#xff1a; 图1 双线性插值示意图图中绿色的点P为待插值得到的点&#xff0c;对点P进行插值需要用到Q11(x1,y1), Q12(x1,y2), Q21(x2, y1), Q22(x2, y2)的值&#xff0c;需要先在x方向线性插值得到R1(x,y1)与R2(x,y2)&…

unable to contact IP driver的错误解决办法

次方法适用xp 2003 server 当ping 127.0.0.1 出现 unable to contact IP driver ,error code2 当ipconfig /all 出现无法显示IP 地址 an internal error occureed: no supported 当在事件查看器中看到&#xff1a; 与 Network Location Awareness (NLA) 服务相依的 TCP/…