【知识蒸馏】什么是知识蒸馏、方法解读

news/2024/7/10 1:51:50 标签: 深度学习, 人工智能, 目标检测

【知识蒸馏】什么是知识蒸馏、方法解读

文章目录

  • 【知识蒸馏】什么是知识蒸馏、方法解读
    • 1. 前言
      • 1.1 由来
      • 1.2 定义
      • 1.3 可蒸馏(迁移)的知识
    • 2. 蒸馏方法介绍
      • 2.1 知识的种类、蒸馏的种类
      • 2.2 “知识”的种类
        • 2.2.1 基于响应的知识__Distilling the Knowledge in a Neural Network(提出 vanilla knowledge)
        • 2.2.2 基于feature的知识
        • 2.2.3 基于relation的知识
      • 2.3 “蒸馏”的种类
        • 2.3.1 离线蒸馏
        • 2.3.2 在线蒸馏
        • 2.3.3 自蒸馏
    • 3. 参考

1. 前言

1.1 由来

科学理论只有进行实践才能转化成产品,我们往往更看重知识转化为经济效益。轻量化网络就是将我们模型能够在现有的有限硬件条件下实现落地。

  • 假如你用深度学习模型在服务器上达到了很好的预测效果,实际上是很多网络(Resnet,Vgg等)需要的计算量和计算资源很大,这对硬件的要求很高
  • 然而,你可能只有一个1050(悲
  • 在应用服务上,我们很容易见到这些智能产品,要直接把模型算法部署到这些小的设备上是困难的
  • 于是,我们希望从一个大的模型上得到知识转移给小的模型,而小模型能达到跟大模型相当的效果,因此 知识蒸馏 就诞生了。

1.2 定义

知识蒸馏:就是把一个大模型的知识迁移到小模型上,因为大模型虽然能达到较高的精度,但它的训练往往需要大量的资源和时间,小模型的训练需要的资源少,训练速度快,但它的精度往往不如大模型。显然,不是每个人都拥有足够的资源训练大模型,为了使用更少的资源、更快的速度,并且精度不能太差,不如让小模型Student学习大模型Teacher的知识,用更少的资源就能达到不错的精度。

在这里插入图片描述

1.3 可蒸馏(迁移)的知识

需要明确的是,教师网络或给定的预训练模型中包含哪些可迁移的知识?基于常见的深度学习任务,可迁移知识可以列举为:

  • 中间层特征:浅层特征注重纹理细节,深层特征注重抽象语义;
  • 任务相关知识:如分类概率分布,目标检测涉及的实例语义、位置回归信息等;
  • 表征相关知识:强调特征表征能力的迁移,相对通用、任务无关(Task-agnostic);表征间相关性,如相似度、Relation等;

2. 蒸馏方法介绍

2.1 知识的种类、蒸馏的种类

在这里插入图片描述

2.2 “知识”的种类

2.2.1 基于响应的知识__Distilling the Knowledge in a Neural Network(提出 vanilla knowledge)

教师网络最后一层的输出,直接模仿教师最后的预测,该方法简单高效。

  • Hinton的文章 “Distilling the Knowledge in a Neural Network” 首次提出了知识蒸馏(暗知识提取)的概念,通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。
    在这里插入图片描述
  • 如上图所示,教师网络(左侧)的预测输出除以温度参数(Temperature)之后、再做Softmax计算,可以获得软化的概率分布(软目标或软标签),数值介于0~1之间,取值分布较为缓和。
    • Temperature数值越大,分布越缓和;
    • 而Temperature数值减小,容易放大错误分类的概率,引入不必要的噪声。
    • 针对较困难的分类或检测任务,Temperature通常取1,确保教师网络中正确预测的贡献。
  • 硬目标则是样本的真实标注,可以用One-hot矢量表示。
  • Total loss设计为软目标与硬目标所对应的交叉熵的加权平均(表示为KD loss与CE loss),其中软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。

联合训练:教师网络与学生网络也可以联合训练(对应论文)。
在这里插入图片描述

  • 此时教师网络的暗知识及学习方式都会影响学生网络的学习,具体如下(式中三项分别为教师网络Softmax输出的交叉熵loss、学生网络Softmax输出的交叉熵loss、以及教师网络数值输出与学生网络Softmax输出的交叉熵loss):
    在这里插入图片描述

2.2.2 基于feature的知识

损失函数:
在这里插入图片描述
Φ t Φ_t Φt 表示如果教师和学生模型的feature map的shape不一样时,把shape变成一样。
L F L_F LF:相似性函数,用于匹配教师和学生模型的feature map;

问题:

  • (1)怎么选择合适的hint层;
  • (2)由于hint层和guided层的尺寸不一样,需要研究怎么去研究匹配两者之间的特征表征。
    在这里插入图片描述

2.2.3 基于relation的知识

基于relation的知识:在不同层或者数据样本的关系。

  • FSP 矩阵(Gram 矩阵):通过两个层之间的特征图做内积,总结特征图之间的关系,使用特征图之间的联系作为知识。(2017)
  • 奇异值分解(SVD)KD用来提取键值信息
  • 多教师网络的知识用每个教师模型的logits和feature作为节点做了两个图,通过logits和表征图作为KD的知识(Multi-head graph-based KD)
    在这里插入图片描述
    在这里插入图片描述

在这里插入图片描述

2.3 “蒸馏”的种类

根据学生是不是和教师网络同时更新可以分三种:离线蒸馏、在线蒸馏、自蒸馏。
在这里插入图片描述

2.3.1 离线蒸馏

vanilla 蒸馏,两个步骤:

  • 蒸馏前,在大数据集上预先训练好教师模型
  • 蒸馏时,教师模型以logits或者中间features的形式提取知识,然后指导学生模型进行训练
  • 方法
    • 离线蒸馏方法通常采用单向知识转移和两阶段训练程序。
    • 然而,复杂的大容量教师模型训练时间很长是无法避免的,而离线蒸馏的学生模型训练通常在教师模型的指导下是高效的。而且,大老师和小学生之间的能力差距一直存在,学生往往很大程度上依赖于老师。

2.3.2 在线蒸馏

在大容量高性能的教师模型不存在的时候,使用在线蒸馏可以提高学生网络性能。

  • 方法
    • 在线蒸馏是一种具有高效并行计算的单阶段端到端训练方案。
    • 然而,现有的在线方法(例如,相互学习)通常不能解决在线设置中的高容量教师,使得进一步探索在线设置中教师和学生模型之间的关系成为有趣的话题。

2.3.3 自蒸馏

在自蒸馏中,教师和学生模型使用相同的网络。

  • 方法:
    • 从更深层蒸馏到更浅层
    • 把自己层的注意力图作为蒸馏目标蒸馏到更低层
    • 把前epoch得到的网络当作监督的训练过程转移到后面层,后面层模仿前一层
    • 标签平滑正则化

3. 参考

【1】https://blog.csdn.net/wj113149/article/details/116142902
【2】https://blog.csdn.net/nature553863/article/details/80568658


http://www.niftyadmin.cn/n/147568.html

相关文章

旋转机械设备故障诊断的轴心轨迹总结

旋转机械设备故障诊断的轴心轨迹分析 1. 转子不对中故障2. 转子不平衡2.1 文献1轨迹2.2 文献2轨迹4. 轴弯曲故障5. 转子部件松动![在这里插入图片描述](https://img-blog.csdnimg.cn/41dbae05da724251bd771c97cb4caf63.png)6. 碰摩故障3. 典型转子故障的原因1. 转子不对中故障 …

Gem5模拟器,源码调用溯源记录(十三)

救命,我这个🐋🧠,七秒钟的记忆,看完就忘了,每次又去找太麻烦了(我在visual studio中打开,没有装相应的环境,没法通过快捷键找到,只能在整个解决方案中搜索&am…

a-tree-select 基本使用,下拉框高度和宽度设置、回显时滚动条定位解决。

目录一、基本使用1. 界面效果2. 代码实现3. 问题1:下拉框占满整个屏幕4. 问题4:菜单内容过长时,下拉菜单宽度无限变宽。二、数据回显、滚动条定位1. 界面效果2. 代码实现2.1 获取默认展开节点2.1.1 代码实现2.1.2 说明2.2 设置滚动条定位2.2.…

Android Framework 面试集合——Handler篇

Handler属于非常经典的一个考题了,导致这个知识点很多时候,考官都懒得问了;这玩意很久之前就看过,但是过了一段时间,就很容易忘记,但是处理内存泄漏,IdleHandler之类的考点答案肯定很难忘。。。…

华为OD机试真题Java实现【最长连续子串】真题+解题思路+代码(20222023)

最长连续子串 题目 给定一个字符串 只包含字母和数字 按要求找出字符串中的最长连续子串的长度 字符串本身是其最长的子串 子串要求 只包含一个字母(a~z A~Z)其余必须是数字字母可以在子串中的任意位置 如果找不到满足要求的子串 比如说,全是字母或数字则返回-1🔥🔥🔥…

MySQL之case...when...then...end的详细使用

目录一、简介二、简单Case函数2.1、语法定义2.2、简单函数形式三、Case搜索函数3.1、语法定义3.2、简单用法3.3、分组3.4、分组计数3.5、分组汇总3.6、更新语句3.7、子查询结语一、简介 今天我们主要是讲讲case…when…then…end的用法,它主要分成两类: …

接口自动化神器appium进阶操作

01 变量提取和引用 变量提取和引用主要是为了解决接口之间的参数依赖问题。 使用场景:接口 A 的参数中需要使用接口 B 返回的某个数据,那么就要在请求 B 接口之后,提取数据保存,给请求 A 接口时使用。 1 变量提取 在用例集或用…

【SpringMVC】SpringMVC面试题总结

SpringMVC面试题1、SpringMVC的理解2、SpringMVC的工作流程3、SpringMVC的常用注解4、SpringMVC的九大内置组件5、Spring、SpringMVC、SpringBoot的区别1、SpringMVC的理解 2、SpringMVC的工作流程 3、SpringMVC的常用注解 RequestMapping: 用于处理请求url映射的注解&#xf…