pytorch代码实现之分布偏移卷积DSConv

DSConv卷积模块

DSConv(分布偏移卷积)可以容易地替换进标准神经网络体系结构并且实现较低的存储器使用和较高的计算速度。 DSConv将传统的卷积内核分解为两个组件:可变量化内核(VQK)和分布偏移。 通过在VQK中仅存储整数值来实现较低的存储器使用和较高的速度,同时通过应用基于内核和基于通道的分布偏移来保持与原始卷积相同的输出。 我们在ResNet50和34以及AlexNet和MobileNet上对ImageNet数据集测试了DSConv。 我们通过将浮点运算替换为整数运算,在卷积内核中实现了高达14x的内存使用量减少,并将运算速度提高了10倍。 此外,与其他量化方法不同,我们的工作允许对新任务和数据集进行一定程度的再训练。

原文地址:DSConv: Efficient Convolution Operator

DSConv结构图

代码实现:

import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair

class DSConv(_ConvNd):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=None, dilation=1, groups=1, padding_mode='zeros', bias=False, block_size=32, KDSBias=False, CDS=False):
        padding = _pair(autopad(kernel_size, padding))
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        dilation = _pair(dilation)

        blck_numb = math.ceil(((in_channels)/(block_size*groups)))
        super(DSConv, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        # KDS weight From Paper
        self.intweight = torch.Tensor(out_channels, in_channels, *kernel_size)
        self.alpha = torch.Tensor(out_channels, blck_numb, *kernel_size)

        # KDS bias From Paper
        self.KDSBias = KDSBias
        self.CDS = CDS

        if KDSBias:
            self.KDSb = torch.Tensor(out_channels, blck_numb, *kernel_size)
        if CDS:
            self.CDSw = torch.Tensor(out_channels)
            self.CDSb = torch.Tensor(out_channels)

        self.reset_parameters()

    def get_weight_res(self):
        # Include expansion of alpha and multiplication with weights to include in the convolution layer here
        alpha_res = torch.zeros(self.weight.shape).to(self.alpha.device)

        # Include KDSBias
        if self.KDSBias:
            KDSBias_res = torch.zeros(self.weight.shape).to(self.alpha.device)

        # Handy definitions:
        nmb_blocks = self.alpha.shape[1]
        total_depth = self.weight.shape[1]
        bs = total_depth//nmb_blocks

        llb = total_depth-(nmb_blocks-1)*bs

        # Casting the Alpha values as same tensor shape as weight
        for i in range(nmb_blocks):
            length_blk = llb if i==nmb_blocks-1 else bs

            shp = self.alpha.shape # Notice this is the same shape for the bias as well
            to_repeat=self.alpha[:, i, ...].view(shp[0],1,shp[2],shp[3]).clone()
            repeated = to_repeat.expand(shp[0], length_blk, shp[2], shp[3]).clone()
            alpha_res[:, i*bs:(i*bs+length_blk), ...] = repeated.clone()

            if self.KDSBias:
                to_repeat = self.KDSb[:, i, ...].view(shp[0], 1, shp[2], shp[3]).clone()
                repeated = to_repeat.expand(shp[0], length_blk, shp[2], shp[3]).clone()
                KDSBias_res[:, i*bs:(i*bs+length_blk), ...] = repeated.clone()

        if self.CDS:
            to_repeat = self.CDSw.view(-1, 1, 1, 1)
            repeated = to_repeat.expand_as(self.weight)
            print(repeated.shape)

        # Element-wise multiplication of alpha and weight
        weight_res = torch.mul(alpha_res, self.weight)
        if self.KDSBias:
            weight_res = torch.add(weight_res, KDSBias_res)
        return weight_res

    def forward(self, input):
        # Get resulting weight
        #weight_res = self.get_weight_res()

        # Returning convolution
        return F.conv2d(input, self.weight, self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)

class DSConv2D(Conv):
    def __init__(self, inc, ouc, k=1, s=1, p=None, g=1, act=True):
        super().__init__(inc, ouc, k, s, p, g, act)
        self.conv = DSConv(inc, ouc, k, s, p, g)

http://www.niftyadmin.cn/n/5023626.html

相关文章

用户权限数据转换为用户组列表(3/3) - Excel PY公式

最近Excel圈里的大事情就是微软把PY塞进了Excel单元格,可以作为公式使用,轻松用PY做数据分析。系好安全带,老司机带你玩一把。 实例需求:如下是AD用户的列表,每个用户拥有该应用程序的只读或读写权限,现在需要创建新的…

使用Python编写高效程序

在当今竞争激烈的互联网时代,搜索引擎优化(SEO)成为了各类网站提升曝光度和流量的关键策略。而要在SEO领域中脱颖而出,掌握高效的网络抓取程序编写技巧是至关重要的。本文将分享一些宝贵的知识和技巧,帮助你使用Python…

群狼调研(长沙银行神秘顾客公司)|专业的银行网点神秘顾客公司

本文由群狼调研(长沙神秘顾客调查)出品,欢迎转载,请注明出处。选择专业的银行网点神秘顾客公司需要考虑多个因素,以确保您合作的公司能够为您提供高质量的服务和有价值的反馈。以下是选择银行网点神秘顾客公司时应考虑的关键因素:…

软件评测师之逻辑运算

目录 一、逻辑运算二、考法逻辑运算的应用 一、逻辑运算 或:||、、∪、V、OR 与:&&、*、、∩、Λ、AND 异或:⊕、XOR 同或:⊙ 非:!、﹃、~、NOT、— 下表中的0表示假,1表示真。 !A与A相…

软件自动化测试对软件产品起到什么作用?有什么注意事项?

自动化测试是把以人为驱动的测试行为转化为机器执行的一种过程。通常,在设计了测试用例并通过评审之后,由测试人员根据测试用例中描述的规程一步步执行测试,得到实际结果与期望结果的比较。 一、软件自动化测试的作用:   1、提…

论文解读 | MVSNet:非结构化多视图立体的深度推理

原创 | 文 BFT机器人 这篇论文的题目是《MVSNet: Depth Inference for Unstructured Multi-view Stereo》。这是一篇关于深度学习在多视角立体视觉(MVS)中的应用的研究论文。MVS任务的目标是从多个视角的图像中还原出三维场景的深度信息,从而…

力扣第37天----第322题、第279题

力扣第37天----第322题、第279题 文章目录 力扣第37天----第322题、第279题一、第322题--零钱兑换二、第279题--组合总和 Ⅳ 一、第322题–零钱兑换 ​ 整体思路,跟前面的几道完全背包差不多,就不具体解释了。有一些细节要注意,见代码注释。…

thymeleaf中 和作用域 相关的 内置对象

域内置对象 thymeleaf中内置了以下几个和作用域相关的内置对象: #ctx:模板引擎的全局上下文对象;#locale:在全局上下文中维护的java.util.Locale对象;#request:表示HttpServletRequest对象,只…