Pytorch之Spatial-Shift-Operation的5种实现策略

本文已授权极市平台, 并首发于极市平台公众号. 未经允许不得二次转载.

原始文档(可能会进一步更新): https://www.yuque.com/lart/ugkv9f/nnor5p

前言

之前看了一些使用空间偏移操作来替代区域卷积运算的论文:

看完这些论文后, 通过参考他们提供的核心代码(主要是后面那些MLP方法), 让我对于实现空间偏移有了一些想法.

通过整合现有的知识, 我归纳总结了五种实现策略.

由于我个人使用pytorch, 所以这里的展示也可能会用到pytorch自身提供的一些有用的函数.

问题描述

在提供实现之前, 我们应该先明确目的以便于后续的实现.

这些现有的工作都可以简化为:

给定tensor \(X \in \mathbb{R}^{1 \times 8 \times 5 \times 5}\), 这里遵循pytorch默认的数据格式, 即 B, C, H, W .

通过变换操作\(\mathcal{T}: x \rightarrow \tilde{x}\), 将\(X\)转换为\(\tilde{X}\).

这里tensor \(\tilde{X} \in \mathbb{R}^{1 \times 8 \times 5 \times 5}\), 为了提供合理的对比, 这里统一使用后面章节中基于"切片索引"策略的结果作为\(\tilde{X}\)的值.

import torch

xs = torch.meshgrid(torch.arange(5), torch.arange(5))
x = torch.stack(xs, dim=0)
x = x.unsqueeze(0).repeat(1, 4, 1, 1).float()
print(x) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

方法1: 切片索引

这是最直接和简单的策略了. 这也是S2-MLP系列中使用的策略.

我们将其作为其他所有策略的参考对象. 后续的实现中同样会得到这个结果.

direct_shift = torch.clone(x)
direct_shift[:, 0:2, :, 1:] = torch.clone(direct_shift[:, 0:2, :, :4])
direct_shift[:, 2:4, :, :4] = torch.clone(direct_shift[:, 2:4, :, 1:])
direct_shift[:, 4:6, 1:, :] = torch.clone(direct_shift[:, 4:6, :4, :])
direct_shift[:, 6:8, :4, :] = torch.clone(direct_shift[:, 6:8, 1:, :])
print(direct_shift) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

方法2: 特征图偏移—— torch.roll

pytorch提供了一个直接对特征图进行偏移的函数, 即 torch.roll . 这一操作在最近的transformer论文和mlp中有一些工作已经开始使用, 例如SwinTransformer和AS-MLP.

这里展示下AS-MLP论文中提供的伪代码:

其主要作用就是将特征图沿着某个轴向进行偏移, 并支持同时沿着多个轴向偏移, 从而构造更多样的偏移方向.

为了实现与前面相同的结果, 我们需要首先对输入进行padding.

因为直接切片索引有个特点就是边界值是会重复出现的, 而若是直接roll操作, 会导致所有的值整体移动.

所以为了实现类似的效果, 先对四周各padding一个网格的数据.

注意这里选择使用重复模式(replicate)以实现最终的边界重复值的效果.

import torch.nn.functional as F

pad_x = F.pad(x, pad=[1, 1, 1, 1], mode="replicate")  # 这里需要借助padding来保留边界的数据

接下来开始处理, 沿着四个方向各偏移一个单位的长度:

roll_shift = torch.cat(
[
torch.roll(pad_x[:, c * 2 : (c + 1) * 2, ...], shifts=(shift_h, shift_w), dims=(2, 3))
for c, (shift_h, shift_w) in enumerate([(0, 1), (0, -1), (1, 0), (-1, 0)])
],
dim=1,
) '''
tensor([[[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4., 4., 4.]], [[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.],
[4., 0., 0., 1., 2., 3., 4.]], [[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.],
[0., 1., 2., 3., 4., 4., 0.]], [[4., 4., 4., 4., 4., 4., 4.],
[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4., 4., 4.],
[0., 0., 0., 0., 0., 0., 0.]], [[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.]]]])
'''

接下来只需要剪裁一下即可:

roll_shift = roll_shift[..., 1:6, 1:6]
print(roll_shift) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

方法3: 1x1 Deformable Convolution—— ops.deform_conv2d

在阅读Cycle FC的过程中, 了解到了Deformable Convolution在实现空间偏移操作上的妙用.

由于torchvision最新版已经集成了这一操作, 所以我们只需要导入函数即可:

from torchvision.ops import deform_conv2d

为了使用它实现空间偏移, 我在对Cycle FC的解读中, 对相关代码添加了一些注释信息:

要想理解这一函数的操作, 需要首先理解后面使用的deform_conv2d_tv的具体用法.

具体可见:https://pytorch.org/vision/0.10/ops.html#torchvision.ops.deform_conv2d

这里对于offset参数的要求是:

offset (Tensor[batch_size, 2 _ offset_groups _ kernel_height * kernel_width, out_height, out_width])

offsets to be applied for each position in the convolution kernel.

也就是说, 对于样本 s 的输出特征图的通道 c 中的位置 (x, y) , 这个函数会从 offset 中取出, 形状为 kernel_height*kernel_width 的卷积核所对应的偏移参数, 其为 offset[s, 0:2*offset_groups*kernel_height*kernel_width, x, y] . 也就是这一系列参数都是对应样本 s 的单个位置 (x, y) 的.

针对不同的位置可以有不同的 offset , 也可以有相同的 (下面的实现就是后者).

对于这 2*offset_groups*kernel_height*kernel_width 个数, 涉及到对于输入特征通道的分组.

将其分成 offset_groups 组, 每份单独拥有一组对应于卷积核中心位置的相对偏移量, 共 2*kernel_height*kernel_width 个数.

对于每个核参数, 使用两个量来描述偏移, 即h方向和w方向相对中心位置的偏移, 即对应于后面代码中的减去 kernel_height//2 或者 kernel_width//2 .

需要注意的是, 当偏移位置位于 padding 后的 tensor 的边界之外, 则是将网格使用0补齐. 如果网格上有边界值, 则使用边界值和用0补齐的网格顶点来计算双线性插值的结果.

该策略需要我们去构造特定的相对偏移值offset来对1x1卷积核在不同通道的采样位置进行调整.

我们先构造我们需要的offset \(\Delta \in \mathbb{R}^{1 \times 2C_iK_hK_w \times 1 \times 1}\). 这里之所以将 out_height & out_width 两个维度设置为1, 是因为我们对整个空间的偏移是一致的, 所以只需要简单的重复数值即可.

offset = torch.empty(1, 2 * 8 * 1 * 1, 1, 1)
for c, (rel_offset_h, rel_offset_w) in enumerate([(0, -1), (0, -1), (0, 1), (0, 1), (-1, 0), (-1, 0), (1, 0), (1, 0)]):
offset[0, c * 2 + 0, 0, 0] = rel_offset_h
offset[0, c * 2 + 1, 0, 0] = rel_offset_w
offset = offset.repeat(1, 1, 7, 7).float() # 针对空间偏移重复偏移量

在构造offset的时候, 我们要明确, 其通道中的数据都是两两一组的, 每一组包含着沿着H轴和W轴的相对偏移量 (这一相对偏移量应该是以其作用的卷积权重位置为中心 —— 这一结论我并没有验证, 只是个人的推理, 因为这样可能在源码中实现起来更加方便, 可以直接作用权重对应位置的坐标. 在不读源码的前提下理解函数的功能, 那就需要自行构造数据来验证性的理解了).

为了更好的理解offset的作用的原理, 我们可以想象对于采样位置\((h, w)\), 使用相对偏移量\((\delta_h, \delta_w)\)作用后, 采样位置变成了\((h+\delta_h, w+\delta_w)\). 即原来作用于\((h, w)\)的权重, 偏移后直接作用到了位置\((h+\delta_h, w+\delta_w)\)上.

对于我们的前面描述的沿着四个轴向各自一个单位偏移, 可以通过对\(\delta_h\)和\(\delta_w\)分别赋予\(\{-1, 0, 1\}\)中的值即可实现.

由于这里仅需要体现通道特定的空间偏移作用, 而并不需要Deformable Convolution的卷积功能, 我们需要将卷积核设置为单位矩阵, 并转换为分组卷积对应的卷积核的形式:

weight = torch.eye(8).reshape(8, 8, 1, 1).float()
# 输入8通道,输出8通道,每个输入通道只和一个对应的输出通道有映射权值1

接下来将权重和偏移送入导入的函数中.

由于该函数对于偏移超出边界的位置是使用0补齐的网格计算的, 所以为了实现前面边界上的重复值的效果, 这里同样需要使用重复模式下的padding后的输入.

并对结果进行一下修剪:

deconv_shift = deform_conv2d(pad_x, offset=offset, weight=weight)
deconv_shift = deconv_shift[..., 1:6, 1:6]
print(deconv_shift) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

方法4: 3x3 Depthwise Convolution—— F.conv2d

在S2MLP中提到了空间偏移操作可以通过使用特殊构造的3x3 Depthwise Convolution来实现.

由于基于3x3卷积操作, 所以为了实现边界值的重复效果仍然需要对输入进行重复padding.

首先构造对应四个方向的卷积核:

k1 = torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k2 = torch.FloatTensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]]).reshape(1, 1, 3, 3)
k3 = torch.FloatTensor([[0, 1, 0], [0, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k4 = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 1, 0]]).reshape(1, 1, 3, 3)
weight = torch.cat([k1, k1, k2, k2, k3, k3, k4, k4], dim=0) # 每个输出通道对应一个输入通道

接下来将卷积核和数据送入 F.conv2d 中计算即可, 输入在四边各padding了1个单位, 所以输出形状不变:

conv_shift = F.conv2d(pad_x, weight=weight, groups=8)
print(conv_shift) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

方法5: 网格采样—— F.grid_sample

最后这里提到的基于 F.grid_sample , 该操作是pytorch提供的用于构建STN的一个函数, 但是其在光流预测任务以及最近的一些分割任务中开始出现:

  • AlignSeg: Feature-Aligned Segmentation Networks
  • Semantic Flow for Fast and Accurate Scene Parsing

针对4Dtensor, 其主要作用就是根据给定的网格采样图grid\(\Gamma = \mathbb{R}^{B \times H_o \times W_o \times 2}\)来对数据点\((\gamma_h, \gamma_w)\)进行采样以放置到输出的位置\((h, w)\)中.

要注意的是, 该函数对限制了采样图grid的取值范围是对输入的尺寸归一化后的结果, 并且\(\Gamma\)的最后一维度分别是在索引W轴、H轴. 即对于输入tensor的布局 B, C, H, W 的四个维度从后往前索引. 实际上, 这一规则在pytorch的其他函数的设计中广泛遵循. 例如pytorch中的pad函数的规则也是一样的.

首先根据需求构造基于输入数据的原始坐标数组 (左上角为\((h_{coord}[0, 0], w_{coord}[0, 0])\), 右上角为\((h_{coord}[0, 5], w_{coord}[0, 5])\)):

h_coord, w_coord = torch.meshgrid(torch.arange(5), torch.arange(5))
print(h_coord)
print(w_coord)
h_coord = h_coord.reshape(1, 5, 5, 1)
w_coord = w_coord.reshape(1, 5, 5, 1) '''
tensor([[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
[2, 2, 2, 2, 2],
[3, 3, 3, 3, 3],
[4, 4, 4, 4, 4]])
tensor([[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4],
[0, 1, 2, 3, 4]])
'''

针对每一个输出\(\tilde{x}\), 计算对应的输入\(x\)的的坐标 (即采样位置):

            torch.cat(
[ # 请注意这里的堆叠顺序,先放靠后的轴的坐标
2 * torch.clamp(w_coord + w, 0, 4) / (5 - 1) - 1,
2 * torch.clamp(h_coord + h, 0, 4) / (5 - 1) - 1,
],
dim=-1,
)

这里的参数\(w\&h\)表示基于原始坐标系的偏移量.

由于这里直接使用clamp限制了采样区间, 靠近边界的部分会重复使用, 所以后续直接使用原始的输入即可.

将新坐标送入函数的时候, 需要将其转换为\([-1, 1]\)范围内的值, 即针对输入的形状W和H进行归一化计算.

        F.grid_sample(
x,
torch.cat(
[
2 * torch.clamp(w_coord + w, 0, 4) / (5 - 1) - 1,
2 * torch.clamp(h_coord + h, 0, 4) / (5 - 1) - 1,
],
dim=-1,
),
mode="bilinear",
align_corners=True,
)

要注意, 这里使用的是 align_corners=True , 关于pytorch中该参数的介绍可以查看https://www.yuque.com/lart/idh721/ugwn46.

True :

False :

所以可以看到, 这里前者更符合我们的需求, 因为这里提到的涉及双线性插值的算法(例如前面的Deformable Convolution)的实现都是将像素放到网格顶点上的 (按照这一思路理解比较符合实验现象, 我就姑且这样描述).

grid_sampled_shift = torch.cat(
[
F.grid_sample(
x,
torch.cat(
[
2 * torch.clamp(w_coord + w, 0, 4) / (5 - 1) - 1,
2 * torch.clamp(h_coord + h, 0, 4) / (5 - 1) - 1,
],
dim=-1,
),
mode="bilinear",
align_corners=True,
)
for x, (h, w) in zip(x.chunk(4, dim=1), [(0, -1), (0, 1), (-1, 0), (1, 0)])
],
dim=1,
)
print(grid_sampled_shift) '''
tensor([[[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.],
[0., 0., 1., 2., 3.]], [[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.]], [[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.],
[1., 2., 3., 4., 4.]], [[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]], [[1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4.]], [[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.],
[0., 1., 2., 3., 4.]]]])
'''

另外的一些思考

关于 F.grid_sample 的误差问题

由于 F.grid_sample 涉及到归一化操作, 自然而然存在精度损失.

所以实际上如果想要实现精确控制的话, 不太建议使用这个方法.

如果位置恰好在但单元格角点上, 倒是可以使用最近邻插值的模式来获得一个更加整齐的结果.

下面是一个例子:

h_coord, w_coord = torch.meshgrid(torch.arange(7), torch.arange(7))
h_coord = h_coord.reshape(1, 7, 7, 1)
w_coord = w_coord.reshape(1, 7, 7, 1)
grid = torch.cat(
[
2 * torch.clamp(w_coord, 0, 6) / (7 - 1) - 1,
2 * torch.clamp(h_coord, 0, 6) / (7 - 1) - 1,
],
dim=-1,
)
print(grid)
print(pad_x[:, :2]) print("mode=bilinear\n", F.grid_sample(pad_x[:, :2], grid, mode="bilinear", align_corners=True))
print("mode=nearest\n", F.grid_sample(pad_x[:, :2], grid, mode="nearest", align_corners=True)) '''
tensor([[[[-1.0000, -1.0000],
[-0.6667, -1.0000],
[-0.3333, -1.0000],
[ 0.0000, -1.0000],
[ 0.3333, -1.0000],
[ 0.6667, -1.0000],
[ 1.0000, -1.0000]], [[-1.0000, -0.6667],
[-0.6667, -0.6667],
[-0.3333, -0.6667],
[ 0.0000, -0.6667],
[ 0.3333, -0.6667],
[ 0.6667, -0.6667],
[ 1.0000, -0.6667]], [[-1.0000, -0.3333],
[-0.6667, -0.3333],
[-0.3333, -0.3333],
[ 0.0000, -0.3333],
[ 0.3333, -0.3333],
[ 0.6667, -0.3333],
[ 1.0000, -0.3333]], [[-1.0000, 0.0000],
[-0.6667, 0.0000],
[-0.3333, 0.0000],
[ 0.0000, 0.0000],
[ 0.3333, 0.0000],
[ 0.6667, 0.0000],
[ 1.0000, 0.0000]], [[-1.0000, 0.3333],
[-0.6667, 0.3333],
[-0.3333, 0.3333],
[ 0.0000, 0.3333],
[ 0.3333, 0.3333],
[ 0.6667, 0.3333],
[ 1.0000, 0.3333]], [[-1.0000, 0.6667],
[-0.6667, 0.6667],
[-0.3333, 0.6667],
[ 0.0000, 0.6667],
[ 0.3333, 0.6667],
[ 0.6667, 0.6667],
[ 1.0000, 0.6667]], [[-1.0000, 1.0000],
[-0.6667, 1.0000],
[-0.3333, 1.0000],
[ 0.0000, 1.0000],
[ 0.3333, 1.0000],
[ 0.6667, 1.0000],
[ 1.0000, 1.0000]]]])
tensor([[[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.]]]])
mode=bilinear
tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00],
[1.1921e-07, 1.1921e-07, 1.1921e-07, 1.1921e-07, 1.1921e-07,
1.1921e-07, 1.1921e-07],
[1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
1.0000e+00, 1.0000e+00],
[2.0000e+00, 2.0000e+00, 2.0000e+00, 2.0000e+00, 2.0000e+00,
2.0000e+00, 2.0000e+00],
[3.0000e+00, 3.0000e+00, 3.0000e+00, 3.0000e+00, 3.0000e+00,
3.0000e+00, 3.0000e+00],
[4.0000e+00, 4.0000e+00, 4.0000e+00, 4.0000e+00, 4.0000e+00,
4.0000e+00, 4.0000e+00],
[4.0000e+00, 4.0000e+00, 4.0000e+00, 4.0000e+00, 4.0000e+00,
4.0000e+00, 4.0000e+00]], [[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00],
[0.0000e+00, 1.1921e-07, 1.0000e+00, 2.0000e+00, 3.0000e+00,
4.0000e+00, 4.0000e+00]]]])
mode=nearest
tensor([[[[0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 1., 1.],
[2., 2., 2., 2., 2., 2., 2.],
[3., 3., 3., 3., 3., 3., 3.],
[4., 4., 4., 4., 4., 4., 4.],
[4., 4., 4., 4., 4., 4., 4.]], [[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.],
[0., 0., 1., 2., 3., 4., 4.]]]])
'''

F.grid_sample 与Deformable Convolution的关系

虽然二者都实现了对于输入与输出位置映射关系的调整, 但是二者调整的方式有着明显的差别.

  • 参考坐标系不同

    • 前者的坐标系是基于整体输入的一个归一化坐标系, 原点为输入的HW平面的中心位置, H轴和W轴分别以向下和向右为正向. 而在坐标系WOH中, 输入数据的左上角为\((-1, -1)\), 右上角为\((1, -1)\).
    • 后者的坐标系是相对于权重初始作用位置的相对坐标系. 但是实际上, 这里其实理解为沿着H轴和W轴的_相对偏移量_更为合适. 例如, 将权重作用位置向左偏移一个单位, 实际上让其对应的偏移参数组\((\delta_h, \delta_w)\)取值为\((0, -1)\)即可, 即将作用位置相对于原始作用位置的\(w\)坐标加上个\(-1\).
  • 作用效果不同
    • 前者直接对整体输入进行坐标调整, 对于输入的所有通道具有相同的调整效果.
    • 后者由于构建于卷积操作之上, 所以可以更加方便的处理不同通道( offset_groups )、不同的实际上可能有重叠的局部区域( kernel_height * kernel_width ). 所以实际功能更加灵活和可调整.

Shift操作的第二春

虽然在之前的工作中已经探索了多种空间shift操作的形式, 但是却并没有引起太多的关注.

  • (CVPR 2018) [Grouped Shift] Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions:
  • (ICCV 2019) 4-Connected Shift Residual Networks
  • (NIPS 2018) [Active Shift] Constructing Fast Network through Deconstruction of Convolution
  • (CVPR 2019) [Sparse Shift] All You Need Is a Few Shifts: Designing Efficient Convolutional Neural Networks for Image Classification

这些工作大多专注于轻量化网络的设计, 而现在的这些基于shift的方法, 则结合了MLP这一快船, 好像又激起了一些新的水花.

当前的这些方法, 往往会采用更有效的训练设定, 这些模型之外的策略在一定程度上也极大的提升了模型的表现. 这其实也会让人疑惑, 如果直接迁移之前的那些shift操作到这里的MLP框架中, 或许性能也不会差吧?

这一想法其实也适用于传统的CNN方法, 之前的那些结构如果使用相同的训练策略, 相比现在, 到底能差多少? 这估计只能那些有卡有时间有耐心的大佬们能够一探究竟了.

实际上综合来看, 现有的这些基于空间偏移的MLP的方法, 更可以看作是 (NIPS 2018) [Active Shift] Constructing Fast Network through Deconstruction of Convolution 这篇工作的特化版本.

也就是将原本这篇工作中的自适应学习的偏移参数改成了固定的偏移参数.

Pytorch之Spatial-Shift-Operation的5种实现策略的更多相关文章

  1. The performance between the 'normal' operation and the 'shift' operation.

    First, I gonna post my test result with some code: //test the peformance of the <normal operation ...

  2. 第十七节: EF的CodeFirst模式的四种初始化策略和通过Migration进行数据的迁移

    一. 四种初始化策略 EF的CodeFirst模式下数据库的初始化有四种策略: 1. CreateDatabaseIfNotExists:EF的默认策略,数据库不存在,生成数据库:一旦model发生变 ...

  3. SqlServer 更改复制代理配置文件参数及两种冲突策略设置

    原文:SqlServer 更改复制代理配置文件参数及两种冲突策略设置 由于经常需要同步测试并更改代理配置文件属性,所以总结成脚本,方便测试. 可更新订阅的冲突策略有两种情况:一是在发布中冲突,即订阅数 ...

  4. Redis两种持久化策略分析

    Redis专题地址:https://www.cnblogs.com/hello-shf/category/1615909.html SpringBoot读源码系列:https://www.cnblog ...

  5. Java-五种线程池,四种拒绝策略,三种阻塞队列(转)

    Java-五种线程池,四种拒绝策略,三种阻塞队列 三种阻塞队列:    BlockingQueue<Runnable> workQueue = null;    workQueue = n ...

  6. 树的三种DFS策略(前序、中序、后序)遍历

    之前刷leetcode的时候,知道求排列组合都需要深度优先搜索(DFS), 那么前序.中序.后序遍历是什么鬼,一直傻傻的分不清楚.直到后来才知道,原来它们只是DFS的三种不同策略. N = Node( ...

  7. Java中的策略模式,完成一个简单地购物车,两种付款策略实例教程

    策略模式是一种行为模式.用于某一个具体的项目有多个可供选择的算法策略,客户端在其运行时根据不同需求决定使用某一具体算法策略. 策略模式也被称作政策模式.实现过程为,首先定义不同的算法策略,然后客户端把 ...

  8. JUC之线程池-三大方法-七大参数-四种拒绝策略

    线程池:重点 三大方法 七大参数 四种拒绝策略 使用池化技术的理由: 我们的程序伴随着创建销毁线程十分浪费资源, 所以使用线程池,先创建线程,随用随取,用完归还 简单来说就是节约了资源. 使用线程池的 ...

  9. sharding-jdbc 分库分表的 4种分片策略,还蛮简单的

    上文<快速入门分库分表中间件 Sharding-JDBC (必修课)>中介绍了 sharding-jdbc 的基础概念,还搭建了一个简单的数据分片案例,但实际开发场景中要远比这复杂的多,我 ...

随机推荐

  1. java 数据类型:枚举类enum、对比方法compreTo()、获取名字.name()、获取对应值的枚举类Enum.valueOf()、包含构造方法和抽象方法的enum;实现接口;

    问题引入 为了将某一数据类型的值限定在可选的合理范围内,比如季节只有四个:春夏秋冬. 什么是枚举类 Java5之后新增了enum关键字(他与class,interface关键字地位相同)用来定义枚举类 ...

  2. 解决Tomcat9打印台乱码问题

    问题描述: Tomcat打印台.打印出来的字体全是乱码后的显示.影响视觉体验,不利于bug查找和错误排查.故寻找方法去修改. 解决方法: 1.找到目录 2.对日志参数进行修改 3.改动编码 4.修改成 ...

  3. 【LeetCode】723. Candy Crush 解题报告 (C++)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 暴力 日期 题目地址:https://leetcode ...

  4. 【LeetCode】677. Map Sum Pairs 解题报告(Python & C++)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 字典 前缀树 日期 题目地址:https://lee ...

  5. 【LeetCode】883. Projection Area of 3D Shapes 解题报告(Python)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 题目大意 解题方法 数学计算 日期 题目地址:https://leetc ...

  6. Soldier and Traveling

    B. Soldier and Traveling Time Limit: 1000ms Memory Limit: 262144KB 64-bit integer IO format: %I64d   ...

  7. 1119 - Pimp My Ride

    1119 - Pimp My Ride    PDF (English) Statistics Forum Time Limit: 2 second(s) Memory Limit: 32 MB To ...

  8. 为什么我的 WordPress 网站被封了?

    今年以来,一系列 "清朗" "护苗" "净网" 专项整治行动重拳出击,"清朗·春节网络环境"取消备案网站平台2300余家 ...

  9. Java的generator工具类,数据库生成实体类和映射文件

    首先需要几个jar包: freemarker-2.3.23.jar log4j-1.2.16.jar mybatis-3.2.3.jar mybatis-generator-core-1.3.2.ja ...

  10. Java初学者作业——定义管理员类(Admin),管理员类中的属性包括:姓名、账号、密码、电话;方法包括:登录、显示自己的信息。

    返回本章节 返回作业目录 需求说明: 定义管理员类(Admin),管理员类中的属性包括:姓名.账号.密码.电话:方法包括:登录.显示自己的信息. 实现思路: 分析类的属性及其变量类型. 分析类的方法及 ...