神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

news/2024/7/24 8:05:22 标签: 神经网络, keras, CA, 注意力机制, 深度学习

神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

学习前言

CA注意力机制是最近提出的一种注意力机制,全面关注特征层的空间信息和通道信息。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-keras

复制该路径到地址栏跳转。

CA_10">CA注意力机制的概念与实现

请添加图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[H, W, C],在经过宽方向的平均池化后,获得的特征层shape为[H, 1, C],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[1, W, C],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[1, H+W, C],利用卷积+标准化+激活函数获得特征。需要注意的是,这里的卷积通道数一般会小一点,做一个缩放,可以减少参数量。卷积后的特征层的shape为[1, H+W, C/r],其中r为缩放系数。

之后再次分开为两个并行阶段,再将宽高分开成为:[1, H, C/r]和[1, W, C/r],之后进行转置。获得两个特征层[H, 1, C/r]和[1, W, C/r]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况,前者在宽上拓展一下,后者在高上拓展一下,然后一起乘上原有的特征就是CA注意力机制

实现的python代码为:

def ca_block(input_feature, ratio=16, name=""):
	channel = input_feature._keras_shape[-1]
	h		= input_feature._keras_shape[1]
	w		= input_feature._keras_shape[2]
 
	x_h = Lambda(lambda x: K.mean(x, axis=2, keepdims=True))(input_feature)
	x_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_h)
	x_w = Lambda(lambda x: K.max(x, axis=1, keepdims=True))(input_feature)
	
	x_cat_conv_relu = Concatenate(axis=2)([x_w, x_h])
	x_cat_conv_relu = Conv2D(channel // ratio, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv1_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = BatchNormalization(name = "ca_block_bn_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = Activation('relu')(x_cat_conv_relu)
 
	x_cat_conv_split_h, x_cat_conv_split_w = Lambda(lambda x: tf.split(x, num_or_size_splits=[h, w], axis=2))(x_cat_conv_relu)
	x_cat_conv_split_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_cat_conv_split_h)
	x_cat_conv_split_h = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv2_"+str(name))(x_cat_conv_split_h)
	x_cat_conv_split_h = Activation('sigmoid')(x_cat_conv_split_h)
 
	x_cat_conv_split_w = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv3_"+str(name))(x_cat_conv_split_w)
	x_cat_conv_split_w = Activation('sigmoid')(x_cat_conv_split_w)
 
	output = multiply([input_feature, x_cat_conv_split_h])
	output = multiply([output, x_cat_conv_split_w])
	return output

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention = [se_block, cbam_block, eca_block, ca_block]

#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
def yolo_body(input_shape, anchors_mask, num_classes, phi = 0):
    inputs = Input(input_shape)
    #---------------------------------------------------#
    #   生成CSPdarknet53_tiny的主干模型
    #   feat1的shape为26,26,256
    #   feat2的shape为13,13,512
    #---------------------------------------------------#
    feat1, feat2 = darknet_body(inputs)
    if phi >= 1 and phi <= 4:
        feat1 = attention[phi - 1](feat1, name='feat1')
        feat2 = attention[phi - 1](feat2, name='feat2')

    # 13,13,512 -> 13,13,256
    P5 = DarknetConv2D_BN_Leaky(256, (1,1))(feat2)
    # 13,13,256 -> 13,13,512 -> 13,13,255
    P5_output = DarknetConv2D_BN_Leaky(512, (3,3))(P5)
    P5_output = DarknetConv2D(len(anchors_mask[0]) * (num_classes+5), (1,1))(P5_output)
    
    # 13,13,256 -> 13,13,128 -> 26,26,128
    P5_upsample = compose(DarknetConv2D_BN_Leaky(128, (1,1)), UpSampling2D(2))(P5)
    if phi >= 1 and phi <= 4:
        P5_upsample = attention[phi - 1](P5_upsample, name='P5_upsample')

    # 26,26,256 + 26,26,128 -> 26,26,384
    P4 = Concatenate()([P5_upsample, feat1])
    
    # 26,26,384 -> 26,26,256 -> 26,26,255
    P4_output = DarknetConv2D_BN_Leaky(256, (3,3))(P4)
    P4_output = DarknetConv2D(len(anchors_mask[1]) * (num_classes+5), (1,1))(P4_output)
    
    return Model(inputs, [P5_output, P4_output])

http://www.niftyadmin.cn/n/5379320.html

相关文章

鸿蒙开发系列教程(二十三)--List 列表操作(2)

列表样式 1、设置内容间距 在列表项之间添加间距&#xff0c;可以使用space参数&#xff0c;主轴方向 List({ space: 10 }) { … } 2、添加分隔线 分隔线用来将界面元素隔开&#xff0c;使单个元素更加容易识别。 startMargin和endMargin属性分别用于设置分隔线距离列表侧…

【编程】Rust语言入门第4篇 字符串

Rust 中的字符是 Unicode 类型&#xff0c;因此每个字符占据 4 个字节内存空间&#xff0c;但字符串不一样&#xff0c;字符串是 UTF-8 编码&#xff0c;也就是字符串中的字符所占的字节数是变化的(1 - 4)。 常见的字符串有两种: str&#xff0c;通常是引用类型&#xff0c;&a…

vivim复习

vi/vim常用命令 vi&vim常用命令 set nu 显示行号 gg 跳转到文件开头 / 向后搜索 ? 向前搜索 n 查找下一处N 查找上一处 | 光标所在行行首L 屏幕所显示的底行{ 段首} 段尾- 前一行行首 后一行行首 ( 句首 ) 下一句首 $ 行末 M 屏…

nlp中如何数据增强

在自然语言处理&#xff08;NLP&#xff09;中&#xff0c;数据增强是一种常用的技术&#xff0c;旨在通过对原始文本进行一系列变换和扩充&#xff0c;生成更多多样化的训练数据。这有助于提高模型的泛化能力和鲁棒性。下面是一些常见的数据增强方法在NLP中的应用&#xff1a;…

详解结构体内存对齐及结构体如何实现位段~

目录 ​编辑 一&#xff1a;结构体内存对齐 1.1对齐规则 1.2.为什么存在内存对齐 1.3修改默认对齐数 二.结构体实现位段 2.1什么是位段 2.2位段的内存分配 2.3位段的跨平台问题 2.4位段的应用 2.5位段使用的注意事项 三.完结散花 悟已往之不谏&#xff0c;知来者犹可…

【JavaEE】_CSS选择器

目录 1. 基本语法格式 2. 引入方式 2.1 内部样式 2.2 内联样式 2.3 外部样式 3. 基础选择器 3.1 标签选择器 3.2 类选择器 3.3 ID选择器 4. 复合选择器 4.1 后代选择器 4.2 子选择器 4.3 并集选择器 4.4 伪类选择器 1. 基本语法格式 选择器若干属性声明 2. 引入…

程序员与电脑不关机现象的探讨

在当今信息化社会&#xff0c;程序员作为数字世界的建设者和创新者&#xff0c;其工作性质和习惯引发了关于他们为何不喜欢关电脑这一现象的讨论。本文将从个人习惯、工作效率以及技术特性等方面对此进行深入剖析。 首先&#xff0c;从个人习惯角度看&#xff0c;程序员的工作具…

力扣 123. 买卖股票的最佳时机 III

题目来源&#xff1a;https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-iii/description/ C题解&#xff1a;动态规划。至多买卖两次&#xff0c;这意味着可以买卖一次&#xff0c;可以买卖两次&#xff0c;也可以不买卖。 一天一共就有四个状态&#xff1a; 第…