小样本学习--(1)概论

目录

一、概述

二、小样本学习的数据集

1、Omniglot

2、MiniimageNet

三、孪生网络

四、三元组损失函数


一、概述

        小样本学习用于处理训练数据集中样本数量少的情况,一般来说,小样本学习流程是这样的,从一个多种类少量样本的巨大数据集中训练一个Pretrained网络模型(这一步不需要做),之后可以基于预训练模型根据微调、元学习或度量方法进行fine-tune,做到对查询集的一个分类和识别。

小样本学习的LibFewShot库:https://github.com/RL-VIG/LibFewShot

        小样本学习与传统神经网络的区别:

        假设训练猫狗分类问题,传统神经网络会从大量带标签的猫狗训练集中进行充分训练,得到较好的模型,然后测试集也是猫狗数据集,只不过是训练集中没有的图片,模型将对测试集进行分类。

        小样本学习首先在一个较大的较多类别,每个类别较少数据的数据集(即辅助集,不包含猫狗类别)中进行预训练,通过迁移学习对预训练模型进行微调,微调时会利用一个Support set(支持集),支持集包含猫狗的图片和标签,根据支持集的类别共K类和每个类别的图片数量n张,又叫做K-way n-shot小样本问题,通常K取5或10,n取1或5。通过在支持集进行微调,达到少量样本完成对查询集(测试集,猫狗测试集)的分类。

        小样本学习,不需要传统神经网络的过高层数,过多的融合来寻找分类的特征从而知道如何分类,而是通过有限的支持集进行相似度匹配,来达到分类的效果。

        小样本学习例子:

        下图的Query:兔子就是测试集,而辅助集在训练时没有见过兔子类,那么他是如何分类的呢?

        通过依赖支持集Support Set对于预训练模型进行微调,来获得水獭与测试图片相似度最高的标签。

        另外 K-way n-shot的举例如下:

         K-way n-shot与测试集的Accuracy的关系:

(1)支持集类别数越多,测试集Accuracy越低,因为测试图片占测试种类的比例下降了。

(2)支持集图片越多,测试集Accuracy越高,这个很好理解,图片越多学的越好。

二、小样本学习的数据集

1、Omniglot

        Omniglot是全语言文字数据集,包含50种语言的字母表,共计1623个类,每个字母由20个不同的人书写,也就是每个字母仅有20张图片,每个图片的像素为105*105。Omniglot数据集分为训练集和测试集,训练集有30个字母表,964个字符,测试集有20个字母表,659个字符,训练集和测试集类别不同,也就是说预训练也是进行的小样本学习,Omniglot数据集一般用作小样本训练。

2、MiniimageNet

        MiniimageNet是一个从ImageNet数据集中抽取的数据集,一共100个类别,每个类别600张图片,共计6万张图片。MiniimageNet数据集的训练集64个类别,验证集16个类别,测试集20个类别。Miniimagenet用于针对各种生物、物品的小样本学习数据集。

三、孪生网络

        孪生网络,利用相同样本和不同样本之间的区别,训练出一个能够分类的神经网络。

        首先将训练集分成正负样本,且样本数量相等的三元组形式,类别相同的图片为正样本,类别不同的图片(首先选取一张图片a,再找从不属于a的图片中随机取样b图片)为负样本。

        孪生网络前向传播输入两张图片,经过映射得到两个列向量,向量作差得到z层,经过全连接网络和激活函数,与所给target计算损失函数,并进行反向传播修改权重。

        注意这个网络只是简单的一个解释,内部的网络已经更新换代,但大体依旧是输入两张图片与一个Target训练该模型。如下图这种就是图片映射的列向量进入网络层,而没有直接做差。

        测试模型时,根据测试集与支持集的不同类别计算相似度,相似度最大的记为本次测试的类别。

四、三元组损失函数

        三元组损失(Triplet Loss),是基于度量的小样本学习中的损失函数方法。首先从训练集中随机选择一张图片作为anchor,如下图中第一张老虎图片,再根据anchor的类别寻找该类的随机一张图片作为Positive,最后从trainset除去老虎类,随机抽取一张图片记为Negative。

        根据三张图片,正样本和负样本去计算与anchor的2-范数,也就是几何距离,记作d+和d-,d+越小越好,正样本越接近anchor,d-越大越好,负样本越远离anchor。

        如果d+=d-那么相当于随机模型,所以训练好的模型必须满足d^-\geqslant d^++\alpha,我们定义三元组损失为   Loss(x^a,x^+,x^-)=max{0,d^++\alpha-d^-}

        根据三元组损失计算预测图片与支持集中图片的距离dist,通过比较距离中最短的一个,就可以确定预测图片所属的类别。 

相关视频:Siamese Network (孪生网络) (2/3)_哔哩哔哩_bilibili


http://www.niftyadmin.cn/n/5098054.html

相关文章

redis 实现互相关注功能

突然想到平时的设计软件如何实现互相关注这个功能,然后查询后大致思路如下: 可以使用 Redis 数据库来存储关注关系。 在社交网络应用程序中,互相关注功能(也称为双向关注或好友关系)是一种常见的功能,允许…

小程序-uni-app:将页面(html+css)生成图片/海报/名片,进行下载 保存到手机

一、需要描述 本文实现,uniapp微信小程序,把页面内容保存为图片,并且下载到手机上。 说实话网上找了很多资料,但是效果不理想,直到看了一个开源项目,我知道可以实现了。 本文以开源项目uniapp-wxml-to-can…

Vue项目使用svg之svg-sprite-loader详细使用

项目中为了体验好、性能优、资源丰富等原因经常会用svg这种矢量图,但是svg不能直接像image标签一样直接使用,这就需要前端的同学自己处理了。 svg有以下优点: svg放大不失真,png,jpg会出现失真现象svg的体积非常小,对…

SpringBoot整合Activiti

SpringBoot集成Activiti7 SpringBoot版本使用2.7.16 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.16</version><relativePath/> <!-- lookup…

NPM 使用入门

NPM NPM是随同 NodeJS 一起安装的包管理工具&#xff0c;能解决 NodeJS 代码部署上的很多问题&#xff0c;常见的使用场景有以下几种&#xff1a; 允许用户从NPM服务器下载别人编写的第三方包到本地使用。允许用户从NPM服务器下载并安装别人编写的命令行程序到本地使用。允许…

通达OA 2016网络智能办公系统 handle.php SQL注入漏洞

一、漏洞描述 北京通达信科科技有限公司通达OA2016网络智能办公系统 handle.php 存在sql注入漏洞&#xff0c;攻击者可利用此漏洞获取数据库管理员权限&#xff0c;查询数据、获取系统信息&#xff0c;威胁企业单位数据安全。 二、网络空间搜索引擎查询 fofa查询 app"T…

FDTD Solutions笔记

FDTD Solutions笔记 目录使用流程实例 目录 使用流程 实例 材料条件 步骤 基底 2. 添加规则膜层 3. 添加仿真区 解释&#xff1a; 仿真区为&#xff08;0,0&#xff09;&#xff0c;x方向为0.4&#xff0c;y方向是1 解释&#xff1a; 一般先用低精度进行计算 解释&#xff1a…

Kotlin中的比较运算符

在Kotlin中&#xff0c;我们可以使用比较运算符进行值的比较和判断。下面对Kotlin中的等于、不等于、小于、大于、小于等于和大于等于进行详细介绍&#xff0c;并提供示例代码。 等于运算符&#xff08;&#xff09;&#xff1a; 等于运算符用于判断两个值是否相等。如果两个值…