【目标检测】SSD损失函数详解

最近看看这个古早的目标检测网络,看了好多文章,感觉对损失函数的部分讲得都是不很清楚得样子,所以自己捋一下。

首先,SSD 得损失函数由两部分得加权和,一部分是定位损失 L l o c L_{loc} Lloc,另一部分是分类置信率损失 L c o n f L_{conf} Lconf。公式一般写成: L ( x , c , l , g ) = 1 N ( L c o n f ( x , c ) + α L l o c ( x , l , g ) ) L(x,c,l,g) = \frac{1}{N}\left( L_{conf}(x, c) + \alpha L_{loc}(x, l, g) \right) L(x,c,l,g)=N1(Lconf(x,c)+αLloc(x,l,g))总结来说,定位损失 L l o c L_{loc} Lloc 是使用 smooth L1 loss 来计算的,而分类置信率损失 L c o n f L_{conf} Lconf 是通过 softmax loss 来计算的。


定位损失 L l o c L_{loc} Lloc

接下来就分开讲讲两部分,首先是定位损失 L l o c L_{loc} Lloc, 它的公式可以写成: L l o c ( x , l , g ) = ∑ i ∈ P o s N ∑ m ∈ { c x , c y , w , h } x i j p ⋅ s m o o t h L 1 ( l i m − g ^ j m ) L_{loc}(x, l, g) = \sum_{i \in Pos}^{N} \sum_{m \in \left\{ cx, cy, w, h \right\}} x_{ij}^p \cdot \mathbf{smooth_{L1}}(l_i^m - \hat{g}_j^m) Lloc(x,l,g)=iPosNm{cx,cy,w,h}xijpsmoothL1(limg^jm)下昂西介绍一下里面的各种参数的含义:

  • i ∈ P o s i \in Pos iPos:第一个求和下的 P o s Pos Pos 是一个集合,我们知道在训练的时候,会根据 IOU(SSD 里好像是大于 0.5) 对 Default box(其实和 anchor 的含义一样)与 Ground truth box(后面统称 gt box)进行匹配,如果 第 i i i 个 default box 与 第 j j j 个 gt box 匹配上了,那么这个 default box i i i 就会被放入 P o s Pos Pos 的集合中,表示 positive,也就是被标记成了正样本。
  • N N N:是正样本集合 P o s Pos Pos 的总数,表示有 N N N 个 default box 与 gt box 匹配上了。
  • m ∈ { c x , c y , w , h } m \in \left\{ cx, cy, w, h \right\} m{cx,cy,w,h}:这四个值是 anchor 的位置参数,表示中心点的坐标和 anchor 的尺寸。
  • x i j p x_{ij}^p xijp:可以理解为唯一标识 flag,如果 default box i i i 与 gt box j j j 是匹配的,gt box 的类别是 p p p,则为1,否则 0。
  • l i m l_i^m lim:是预测值,也就是 bounding box 与 default box 的偏移值,不是真实的坐标,具体的转换在下面给出。
  • g ^ j m \hat{g}_j^m g^jm:是真实值,是 gt box 与 default box 的偏移值。

偏移值的计算

回到前面提到的偏移值,如果 default box i i i 的位置参数是 { d i c x , d i c y , d i w , d i h } \{ d_i^{cx}, d_i^{cy}, d_i^{w}, d_i^{h} \} {dicx,dicy,diw,dih},gt box j j j 的位置参数是 { g j c x , g j c y , g j w , g j h } \{ g_j^{cx}, g_j^{cy}, g_j^{w}, g_j^{h} \} {gjcx,gjcy,gjw,gjh},我们就可以算出真实值对应的偏移量 { g ^ j c x , g ^ j c y , g ^ j w , g ^ j h } \{ \hat{g}_j^{cx}, \hat{g}_j^{cy}, \hat{g}_j^{w}, \hat{g}_j^{h} \} {g^jcx,g^jcy,g^jw,g^jh}
g ^ j c x = ( g j c x − d i c x ) d i w           g ^ j w = log ⁡ ( g j w d i w ) g ^ j c y = ( g j c y − d i c y ) d i h           g ^ j h = log ⁡ ( g j h d i h ) \begin{align*} \hat{g}_j^{cx} = \frac{(g_j^{cx} - d_i^{cx})}{d_i^w} & \space\space\space\space\space\space\space\space\space \hat{g}_j^{w} = \log\left( \frac{g_j^w}{d_i^w} \right)\\ \hat{g}_j^{cy} =\frac{ (g_j^{cy} - d_i^{cy})}{d_i^h} & \space\space\space\space\space\space\space\space\space \hat{g}_j^{h} = \log\left( \frac{g_j^h}{d_i^h} \right) \end{align*} g^jcx=diw(gjcxdicx)g^jcy=dih(gjcydicy)         g^jw=log(diwgjw)         g^jh=log(dihgjh)同理,我们也可以通过预测得到的 bounding box 参数 { b i c x , b i c y , b i w , b i h } \{b_i^{cx}, b_i^{cy}, b_i^{w}, b_i^{h} \} {bicx,bicy,biw,bih} ,来计算得到 bounding box 的偏移量,也就是预测值 l l l l i c x = ( b i c x − d i c x ) d i w           l i w = log ⁡ ( b i w d i w ) l i c y = ( b i c y − d i c y ) d i h           l i h = log ⁡ ( b i h d i h ) \begin{align*} l_i^{cx} = \frac{(b_i^{cx} - d_i^{cx})}{d_i^w} & \space\space\space\space\space\space\space\space\space l_i^{w} = \log\left( \frac{b_i^w}{d_i^w} \right)\\ l_i^{cy} =\frac{ (b_i^{cy} - d_i^{cy})}{d_i^h} & \space\space\space\space\space\space\space\space\space l_i^{h} = \log\left( \frac{b_i^h}{d_i^h} \right) \end{align*} licx=diw(bicxdicx)licy=dih(bicydicy)         liw=log(diwbiw)         lih=log(dihbih)

中心点的偏移量计算是很好理解的,为什么宽高的偏移量要用 log 函数来算呢?

smooth L1 loss

SSD 是用到 smooth L1 loss 来计算真实值与预测值之间的差异: s m o o t h L 1 ( x ) = { 0.5 x 2  if  ∣ x ∣ < 1 ∣ x ∣ − 0.5  otherwise  \mathbf{smooth_{L1}}(x) =\begin{cases} 0.5x^2 & \text{ if } |x|< 1 \\ |x| -0.5 & \text{ otherwise } \end{cases} smoothL1(x)={0.5x2x0.5 if x<1 otherwise 

这篇文章我觉得解释得挺清晰的,比较了 L1 loss, L2 loss 和 smooth L1 loss 三者之间的优劣。也提到了 loc loss 的演进。

bounding box 回归损失函数,也就是用于定位边界框的损失函数,其演进线路如下:

Smooth L1 loss
IOU loss
GIOU loss
DIOU loss
CIOU loss

我记得 YOLOv1 的损失函数,在计算位置损失的时候,还是使用的欧式距离,也就是所谓的 L2 loss。

L1 loss 是求两个数之间的绝对值距离,导数是常数(小于 0 则为 -1,大于等于 0 则为 1),在零点处是不平滑的;

L2 loss 是两个数之间差的平方,导数是 2 x 2x 2x(也可以看出,受到 x x x 的影响很大),但是在零点处是平滑的。多个 L2 loss 求和再平均也叫做 MSE loss (Mean Square Error)。

而我们的主角 smooth L1 loss,如名字所见,是平滑版的 L1 loss,导数为: d   s m o o t h L 1 ( x ) d x = { x  if  ∣ x ∣ < 1 ± 1  otherwise  \frac{ \mathrm{d} \space \mathbf{smooth_{L1}}(x)}{\mathrm{d}x} =\begin{cases} x & \text{ if } |x|< 1 \\ \pm1 & \text{ otherwise } \end{cases} dxd smoothL1(x)={x±1 if x<1 otherwise  在零点附近都是平滑的,而且在其它区间都是常数,也不会出现 L2 loss 随着 x x x 的增大而在损失函数中占据主导地位。


置信率损失 L c o n f L_{conf} Lconf

下面就讲一下分类的置信率损失 L c o n f L_{conf} Lconf,完整的公式如下: L c o n f ( x , c ) = − ∑ i ∈ P o s N x i j p log ⁡ c ^ i p − ∑ i ∈ N e g log ⁡ c ^ i 0     where     c ^ i p = exp ⁡ ( c i p ) ∑ p exp ⁡ ( c i p ) L_{conf}(x, c) = - \sum_{i \in Pos}^N x_{ij}^p\log{\hat{c}_i^p} - \sum_{i\in Neg} \log{\hat{c}_i^0} \space\space\space\space \text{where} \space\space\space\space \hat{c}_i^p=\frac{\exp{(c_i^p)}}{\sum_p \exp{(c_i^p)}} Lconf(x,c)=iPosNxijplogc^ipiNeglogc^i0    where    c^ip=pexp(cip)exp(cip),从公式的形态,可以看出来是二元交叉熵。没错,其实 softmax loss 就相当于交叉熵和 softmax 的组合,先看看最后的 softmax 公式: c ^ i p = exp ⁡ ( c i p ) ∑ p exp ⁡ ( c i p ) \hat{c}_i^p=\frac{\exp{(c_i^p)}}{\sum_p \exp{(c_i^p)}} c^ip=pexp(cip)exp(cip)

  • c i p c_i^p cip: 对于分类的部分,一般网络的全连接层会输出 P P P 个类别的向量,在 SSD 因为要考虑背景(背景是分类 0),这个长度为 P + 1 P+1 P+1 的向量经过 softmax 之后,所有值的和会被限制为 1。其中置信率最大的值即是 c i p c_i^p cip(目前这个值还没经过归一化),这表示 anchor i i i 是的类别是 p p p 的可能性最大。
  • c ^ i p \hat{c}_i^p c^ip:是通过对 c i p c_i^p cip 进行 softmax 而得到的,表示 anchor i i i 是分类 p p p 的概率,值是位于 0~1 之间的。

然后就是主公式的各个参数的具体含义:

  • x i j p x_{ij}^p xijp:含义同上面位置损失提到的,如果 anchor i i i 与 gt box j j j 是匹配的,gt box 的类别是 p p p,则为1,否则 0。
  • c ^ i 0 \hat{c}_i^0 c^i0:就是背景的分类概率(负样本)。

其实这个公式也没什么难理解的。


大致就是这么多了,如果大家有什么不清楚的地方或者是文章哪里写错了,欢迎评论留言。


http://www.niftyadmin.cn/n/5164070.html

相关文章

实现http请求-hutool

hutool工具HttpUtil 使用hutool就能实现http请求&#xff0c;官方案例 // 最简单的HTTP请求&#xff0c;可以自动通过header等信息判断编码&#xff0c;不区分HTTP和HTTPS String result1 HttpUtil.get("https://www.baidu.com");// 当无法识别页面编码的时候&…

2023年的低代码:数字化、人工智能、趋势及未来展望

本文由葡萄城技术团队发布。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 前言 正如许多专家预测的那样&#xff0c;低代码平台在2023年将展现更加强劲的势头。越来越多的企业正在纷纷转向低代…

vue2中的mixins混入

目录 引言&#xff1a; 一、什么是混入&#xff1f; mixins 基础 选项合并 全局混入 自定义选项合并策略 二、mixins混入的优势 三、mixins混入的最佳实践 结论&#xff1a; 引言&#xff1a; 在Vue.js开发中&#xff0c;我们经常会遇到一些场景&#xff0c;多个组件…

To create the 45th Olympic logo by using CSS

You are required to create the 45th Olympic logo by using CSS. The logo is composed of five rings and three rectangles with rounded corners. The HTML code has been given. It is not allowed to add, edit, or delete any HTML elements. 私信完整源码 <!DOCT…

基于SSM的旅游管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

事项要素模型

环境问题 普通Linux环境没什么问题 arm&#xff0b;麒麟系统 python3.7tensorflow1.14.0(tensorflow-1.14.0-cp37-none-linux_aarch64.whl)keras2.3.0h5py2.10.0(conda install h5py2.10.0)protobuf3.20.0 docker操作&#xff1a;docker cp xxx 6a8:/root docker exec -it 6…

【ES专题】Logstash与FileBeat详解以及ELK整合详解

目录 前言阅读对象阅读导航前置知识笔记正文一、ELK架构1.1 经典的ELK1.2 整合消息队列Nginx架构 二、LogStash介绍2.1 Logstash核心概念2.1.1 Pipeline2.1.2 Event2.1.3 Codec (Code / Decode)2.1.4 Queue 2.2 Logstash数据传输原理2.3 Logstash的安装&#xff08;以windows为…

研发项目管理改进方法有哪些

研发项目管理改进方法有哪些 1.多项目协同管理 有可视化的项目进度管理环境&#xff0c;可通过表格视图或施工进度表&#xff0c;项目成员可以共同进行实时项目计划的编制。可以修改项目的任务约束、重大事件以及开始结束日期。还可进行任务分解、任务约束、不限层任务树、任务…