Transformer的一点理解,附一个简单例子理解attention中的QKV

news/2024/7/24 10:28:14 标签: transformer, 深度学习, 人工智能

Transformer用于目标检测的开山之作DETR,论文作者在附录最后放了一段简单的代码便于理解DETR模型。

DETR的backbone用的是resnet-50去掉了最后的AdaptiveAvgPool2d和Linear这两层。

self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])

经过一次卷积加上position embedding,输入到transformer,position embedding是直接加和,不是像叠盘子一样的concat。

这里回顾一下transformer

transformer中最重要的attention,这篇文章Attention Is All You Need (Transformer) 论文精读 - 知乎

举了一个简单的例子,去解释attention中的QKV到底是什么含义。 这里

引用上述文章作者的例子:

如果我们有这样姓名和年龄一个数据库

张三:18
李四:22
张伟:19
张三:20

如果查询『所有叫张三的人的平均年龄』,Key==“张三”,可以得到Key对应的两个Value,算出(18+20)/2=19。我们把『所有叫张三的人的平均年龄』这句话称为一个查询(Query)

如果有另一个查询Query‘:『所有姓张的人的平均年龄』, Key[0]==“张”,得到三个Value:(18+20+19)/3=19

这样查询很低效,为了高效,将Query,Key转为向量vector。

将姓名(Key)汉字编码为向量

张三:[1, 2, 0]
李四:[0, 0, 2]
张伟:[1, 4, 0]

如果一个Quary是查询所有姓张的人的平均年龄,那么Quary可以写成向量  [1, 0, 0],将Quary向量和Key向量做点积

dot([1, 0, 0], [1, 2, 0]) = 1
dot([1, 0, 0], [1, 2, 0]) = 1
dot([1, 0, 0], [0, 0, 2]) = 0
dot([1, 0, 0], [1, 4, 0]) = 1

将结果softmax归一化

softmax([1, 1, 0, 1]) = [1/3, 1/3, 0, 1/3]

再将归一化后的结果与Value做点积

dot([1/3, 1/3, 0, 1/3], [18, 20, 22, 19]) = 19

就得到了想要的结果。(说句题外话,这样查询感觉跟布隆过滤器Bloom Filter有点相似的感觉,将文字编码成位数组)

这个计算就是Attention is all you need论文里Scaled Dot-Product Attention

 在transformer中,query key value关系如下图所示,(reference:The Illustrated Transformer – Jay Alammar – Visualizing machine learning one concept at a time.)

 将文字编码为向量x,x与矩阵W相乘,得到q,q与k做点乘,再除8(the square root of the dimension of the key vectors used in the paper – 64),再softmax,再成v,得到z

 

 

如果是多头注意力,就会得到多个注意力头的z

在RNN中,是按顺序输入,所以网络是知道每个输入的位置次序,但是transformer不是这样,因此还要加一个positional encoding,告诉网络输入的每个词在句子中的位置

 Transformer也使用了和resnet相似的残差连接。

将编码器得到的K,V矩阵输入到解码器

在解码的第一步中,输入K V,得到一个output,而在后续的解码中,将前一部的结果也一起输入解码器。比如第二步中,将第一步的结果 “I”也输入decoder,直到decoder给出 end of sentence为止。

transformer的损失函数,通过交叉熵,使两个分布相同

Output Vocabulary是提前建好的词库,网络输出的是词库中所有词 出现在这个位置的概率。

回到DETR,DETR中叠了6个transformer的encoder和decoder,将transformer输出再分别输入两个Linear,就得到了class和bbox。


http://www.niftyadmin.cn/n/5201099.html

相关文章

姓氏情侣家庭亲子谐音顽梗头像分销流量主微信抖音小程序开发

姓氏情侣家庭亲子谐音顽梗头像分销流量主微信抖音小程序开发 姓氏情侣头像:提供各种姓氏的情侣头像模板,用户可根据自己的姓氏选择合适的头像进行定制。 家庭头像:为家庭成员提供多种形式的头像模板,让用户可以选择合适的家庭头像…

Python数据结构基础教学,从零基础小白到实战大佬!

文章目录 前言 Python有那几种数据结构?1)列表(list)1.1 什么是列表?1.2列表的增删改查 2)字典(Dictionary)2.1 什么是字典?2.2 字典的增删改查 3)元组(Tuple)4)集合(Set…

Android codec2 视频框架之输出端的内存管理

文章目录 前言setSurfacestart从哪个pool中申请buffer解码后框架的处理流程renderOutbuffer 输出显示 前言 输出buffer整体的管理流程主要可以分为三个部分: MediaCodc 和 应用之间的交互 包括设置Surface、解码输出回调到MediaCodec。将输出buffer render或者rele…

C语言--判断年月日是否合理

一.题目描述 比如输入2001,2,29,输出: 不合理 。因为平年的二月只有28天 比如输入2000,6,31,输出:不合理。因为6月是小月,只有30天。 二.思路分析 本题主要注意两个问…

现货白银MACD实战分析例子

MACD这个技术指标的全称是平滑异同移动平均线,主要表示经过平滑处理后均线的差异程度,一般用来研判现货白银价格变化的方向、强度和趋势。MT4中的MACD指标,主要是由信号线、(上升/下跌)动能柱、0轴这三部分组成。 MACD…

Memcpy运行时内存增加

结论:Memcpy不会导致内存增加。 原因:所需内存过大,动态申请时系统并未分配空间,而是边使用边分配,导致出现该现象,在所有内存均使用后,内存不会增长。

力扣 622.设计循环队列

目录 1.解题思路2.代码实现 1.解题思路 首先,该题是设计循环队列,因此我们有两种实现方法,即数组和链表,但具体考虑后,发现数组实现要更容易一些,因此使用数组实现,因此我们要给出头和尾变量&a…