pytorch代码实现之CoordConv卷积

CoordConv卷积

深度学习领域,几乎没有什么想法能像卷积那样产生如此大的影响。对于任何涉及像素或空间表示的问题,普遍的直觉认为卷积神经网络可能是合适的。在本文中,我们通过看似平凡的坐标变换问题展示了一个惊人的反例,该问题只需要学习(x, y)笛卡尔空间中的坐标与单热像素空间中的坐标之间的映射。虽然卷积网络似乎适合这项任务,但我们表明它们失败得很明显。
CoordConv的工作原理是通过使用额外的坐标通道让卷积访问自己的输入坐标。在不牺牲普通卷积的计算和参数效率的情况下,CoordConv允许网络根据最终任务的需要学习完全的平移不变性或不同程度的平移依赖性。CoordConv解决了坐标变换问题,具有很好的泛化性,比convolution的参数少10-100倍,速度快150倍。

原文地址:An intriguing failing of convolutional neural networks and the CoordConv solution

CoordConv结构原理图

代码实现:

class AddCoords(nn.Module):
    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        """
        Args:
            input_tensor: shape(batch, channel, x_dim, y_dim)
        """
        batch_size, _, x_dim, y_dim = input_tensor.size()

        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)

        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        return ret

class CoordConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, with_r=False):
        super().__init__()
        self.addcoords = AddCoords(with_r=with_r)
        in_channels += 2
        if with_r:
            in_channels += 1
        self.conv = Conv(in_channels, out_channels, k=kernel_size, s=stride)

    def forward(self, x):
        x = self.addcoords(x)
        x = self.conv(x)
        return x

http://www.niftyadmin.cn/n/5020639.html

相关文章

C++-day4

仿照string类&#xff0c;完成myString 类 #include <iostream> #include <cstring> using namespace std; class myString { private:char *str; //记录c风格的字符串int size; //记录字符串的实际长度 public://无参构造myString():size(10…

Python爬虫技巧:使用代理IP和User-Agent应对反爬虫机制

在当今的网络环境中&#xff0c;反爬虫机制广泛应用于各个网站&#xff0c;为爬虫程序增加了困难。然而&#xff0c;作为一名Python爬虫开发者&#xff0c;我们可以利用一些技巧应对这些反爬虫措施。本文将分享一个重要的爬虫技巧&#xff1a;使用代理IP和User-Agent来应对反爬…

机器学习(9)---线性回归中的公式推导(手推)、闭式解和数值解

文章目录 一、闭式解&#xff08;解析解&#xff09;二、数值解三、一元线性回归中w和b的推导四、多元线性回归中w的推导 一、闭式解&#xff08;解析解&#xff09; 1. 在机器学习中&#xff0c;闭式解也被称为解析解&#xff08;analytical solution&#xff09;&#xff0c;…

d3.js 的使用

这篇文章相当于之前 svg 的补充。 因为 svg 代码肯定不是人为去专门写的。 在这里推荐制作 svg 的第三方库 - D3.js 用于定制数据可视化的JavaScript库 - D3 官网地址&#xff1a; D3 by Observable | The JavaScript library for bespoke data visualization 简单使用 画…

文字表达细化的提示词的学习笔记

《帮助你写一段文案的ChatGPT Prompt (提示&#xff09;应该是最能写出好文章的提示了&#xff01;》 Given some text, make it clearer. 对于给定的一些文本&#xff0c;使其更清晰。 Do not rewrite it entirely. Just make it clearer and more readable. 不要完全重写…

17. 线性代数 - 矩阵的逆

文章目录 矩阵的转置矩阵的逆Hi, 您好。我是茶桁。 我们已经学习过很多关于矩阵的知识点,今天依然还是矩阵的相关知识。我们来学一个相关操作「矩阵的转置」,更重要的是我们需要认识「矩阵的逆」 矩阵的转置 关于矩阵的转置,咱们导论课里有提到过。转置实际上还是蛮简单…

GO语言篇之CGO

GO语言篇之CGO 文章目录 GO语言篇之CGO前言C代码嵌入GO代码C文件嵌入GO代码缺点 前言 Go语言可以通过内置的CGO调用C语言接口&#xff0c;从而实现C语言代码的交互&#xff0c;CGO提供了一种将Go代码嵌入到C代码中&#xff0c;或者从Go代码中调用C函数的方法 C代码嵌入GO代码…

DC/DC开关电源学习笔记(五)开关电源的主要技术指标

(五)开关电源的主要技术指标 1.输入参数2.输出参数3.效率4.电压调整率和负载调整率5.动态特性:负载突变时输出电压的变化6.电源启动时间(Set-Up Time)与保持时间(Hold-Up Time)1.输入参数 输入电压大小,交流还是直流,相数,频率等。 2.输出参数 输出功率,输出电压,输出…