先看一下CLASS有哪些参数:

  1. torch.nn.Conv2d(
  2. in_channels,
  3. out_channels,
  4. kernel_size,
  5. stride=1,
  6. padding=0,
  7. dilation=1,
  8. groups=1,
  9. bias=True,
  10. padding_mode='zeros'
  11. )

可以对输入的张量进行 2D 卷积。

in_channels: 输入图片的 channel 数。

out_channels: 输出图片的 channel 数。

kernel_size: 卷积核的大小。

stride: 滑动的步长。

bias: 若设为 True,则对输出图像每个元素加上一个可以学习的 bias。

dilation: 核间点距。

padding: 控制补 $0$ 的数目。padding 是在卷积之前补 $0$,如果愿意的话,可以通过使用 torch.nn.Functional.pad 来补非 $0$ 的内容。padding 补 $0$ 的策略是四周都补,如果 padding 输入是一个二元组的话,则第一个参数表示高度上面的 padding,第2个参数表示宽度上面的 padding。

关于 padding 策略的例子:

  1. x = torch.tensor([[[[-1.0, 2.0], [3.5, -4.0]]]])
  2. print(x, x.shape) # N = 1, C = 1, (H,W) = (2,2)
  3. layer1 = torch.nn.Conv2d(1, 1, kernel_size=(1, 1), padding=0)
  4. layer2 = torch.nn.Conv2d(1, 1, kernel_size=(1, 1), padding=(1, 2))
  5. y = layer1(x)
  6. print(y, y.shape)
  7. z = layer2(x)
  8. print(z, z.shape)

结果:

  1. tensor([[[[-1.0000, 2.0000],
  2. [ 3.5000, -4.0000]]]]) torch.Size([1, 1, 2, 2])
  3. tensor([[[[-0.3515, 0.4479],
  4. [ 0.8476, -1.1510]]]], grad_fn=<ThnnConv2DBackward>) torch.Size([1, 1, 2, 2])
  5. tensor([[[[-0.6553, -0.6553, -0.6553, -0.6553, -0.6553, -0.6553],
  6. [-0.6553, -0.6553, 0.2367, -2.4393, -0.6553, -0.6553],
  7. [-0.6553, -0.6553, -3.7772, 2.9127, -0.6553, -0.6553],
  8. [-0.6553, -0.6553, -0.6553, -0.6553, -0.6553, -0.6553]]]],
  9. grad_fn=<ThnnConv2DBackward>) torch.Size([1, 1, 4, 6])

可以看到 padding 为 $(1,2)$ 时,在高度上两边各增加了 $1$ 行,总共增加 $2$ 行。在宽度上两边各增加 $2$ 列,总共增加 $4$ 列。至于为什么增加的行列不是 $0$,这是因为有参数 bias 存在的缘故,此时 bias 值为 $-0.6553$(这个 bias 值初始值应该是一个随机数)。

关于 dilation:

默认情况下 dilation 为 $(1,1)$,就是正常的紧密排布的卷积核。

下图是 dilation 为 $(2,2)$ 的情况(没有 padding,stride 为 $(1,1)$),蓝色的是输入图像,绿色的是输出图像。

输入图像的 shape 是 $(N, C_{in}, H_{in}, W_{in})$,$N$ 是 batch size,$C_{in}$ 表示 channel 数,$H,W$ 分别表示高和宽。

输出图像的 shape $(N, C_{out}, H_{out}, W_{out})$ 可以通过计算得到:

这个式子很好理解,由于宽高的计算类似,所以只以高为例子来讲:

$H_{in} + 2 \times \rm{padding}[0]$ 即输入图像补完 $0$ 之后的高度,一个卷积核在图像上所能覆盖的高度为 $(\rm{kernel\_size}[0] - 1) \times \rm{dilation}[0] + 1$(例如上面动图就是 $(3 - 1) \times 2 + 1 = 5$),这两个值相减即为,步长为 $1$ 时,卷积核在图像高度上能滑动的次数。而这个次数除去实际步长 $stride[0]$ 再向下取整,即卷积核在图像高度上实际能滑动的次数。这个实际滑动次数加上 $1$ 即输出图像的高度。

需要注意的是:kernel_size, stride, padding, dilation 不但可以是一个单个的 int ——表示在高度和宽度使用同一个 int 作为参数,也可以使用一个 (int1, int2) 的二元组(其实本质上单个的 int 也可以看作一个二元组 (int, int))。在元组中,第1个参数对应高度维度,第2个参数对应宽度维度。

另外,对于卷积核,它其实并不是二维的,它具有长宽深三个维度;实际上它的 channel 数等于输入图像的 channel 数 $C_{in}$,而卷积核的个数即输出图像的 channel 数 $C_{out}$。

以上图为例,输入图像的 shape 是 $(C = 3, H = 6, W = 6)$,这里略去 batch size,第一个卷积核是 $(C = 3, H = 3, W = 3)$,他在输入图像上滑动并卷积后得到一张 $(C = 1, H = 4, W = 4)$ 的特征图(feature map),第二个卷积核类似得到第二张 $(C = 1, H = 4, W = 4)$ 特征图,那么输出图像就是把这两张特征图叠在一块儿,shape 即为 $(C = 2, H = 4, W = 4)$。


这里顺带记录一下 Batch norm 2D 是怎么做的:

如果把一个 shape 为 $(N, C, H, W)$ 类比为一摞书,这摞书总共有 N 本,每本均有 C 页,每页有 H 行,每行 W 个字符。BN 求均值时,相当于把这 $N$ 本书都选同一个页码加起来(例如第1本书的第36页,第2本书的第36页......),再除以每本书的该页上的字符的总数 $N \times H \times W$,因此可以把 BN 看成求“平均书”的操作(注意这个“平均书”每页只有一个字),求标准差时也是同理。

例如下图,输入的张量 shape 为 $(4, 3, 2, 2)$,对于所有 batch 中的同一个 channel 的元素进行求均值与方差,比如对于所有的 batch,都拿出来最后一个channel,一共有 $f_1 + f_2 + f_3 + f_4 = 4 + 4 + 4 + 4 = 16$ 个元素,然后去求这 $16$ 个元素的均值与方差。

求取完了均值与方差之后,对于这 $16$ 个元素中的每个元素分别进行归一化,然后乘以 $\gamma$ 加上 $\beta$,公式如下

batch norm层能够学习到的参数,对于一个特定的 channel 而言实际上是两个参数 $\gamma, beta$,而对于所有的channel而言实际上就是 channel 数的两倍。

关于其他的 Normalization 做法的形象理解可以参考https://zhuanlan.zhihu.com/p/69659844

关于torch.nn.Conv2d的笔记的更多相关文章

  1. torch.nn.Conv2d()使用

    API 输入:[ batch_size, channels, height_1, width_1 ] Conv2d输入参数:[ channels, output, height_2, width_2 ...

  2. 关于torch.nn.Linear的笔记

    关于该类: torch.nn.Linear(in_features, out_features, bias=True) 可以对输入数据进行线性变换: $y  = x A^T + b$ in_featu ...

  3. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  4. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  5. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  6. PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx

    PyTorch : torch.nn.xxx 和 torch.nn.functional.xxx 在写 PyTorch 代码时,我们会发现一些功能重复的操作,比如卷积.激活.池化等操作.这些操作分别可 ...

  7. nn.Conv2d 参数及输入输出详解

    Torch.nn.Conv2d(in_channels,out_channels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=Tru ...

  8. Pytorch中nn.Conv2d的用法

    Pytorch中nn.Conv2d的用法 nn.Conv2d是二维卷积方法,相对应的还有一维卷积方法nn.Conv1d,常用于文本数据的处理,而nn.Conv2d一般用于二维图像. 先看一下接口定义: ...

  9. Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别

    在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...

随机推荐

  1. Vue + Webpack 根据不同环境打包

    修改 prod.env.js // 当前正在运行的脚本名称 const TARGET = process.env.npm_lifecycle_event // 第一个参数 let argv = pro ...

  2. 获取QQ群中的所有群友QQ

    package com.jm.mail.tools; import java.io.BufferedReader; import java.io.IOException; import java.io ...

  3. Python测试进阶——(3)编写Python程序监控计算机的服务是否正常运行

    用python写了个简单的监控进程的脚本,当发现进程消失的时候,立即调用服务,开启服务. 脚本的工作原理是这样的:脚本读取配置文件,读取预先配置好的调用系统服务的路径和所要监控的服务在进程管理器中的进 ...

  4. Easy_Re

    这题比较简单,一波常规的操作之后直接上ida(小白的常规操作在以前的博客里都有所以这里不在赘述了),ida打开之后查看一下, 这里应该就是一个入口点了,接着搜索flag字符串, 上面的黄色的部分转换成 ...

  5. 在linux7(centos)中安装python3.7.2

    一般情况下linux上都默认安装了python,检查一下我的版本 没有安装python3,但是目前已经是python3了,所以为了方便,还是要在系统上安装一下比较好. 上面的命令,直接输入python ...

  6. 动态设置html根字体大小(随着设备屏幕的大小而变化,从而实现响应式)

    代码如下:如果设置了根字体大小,font-size必须是rem var html =document.querySelector('html'); html.style.fontSize = docu ...

  7. Docker 学习之部署php + nginx(一)

    博主电脑系统是window 10 专业版的,所以在此记录下docker的基本使用方法. 参考地址: https://www.runoob.com/docker/docker-install-php.h ...

  8. 查看oracle单签session

    转自 https://blog.csdn.net/alexsong123/article/details/51858092 怎样查看oracle当前的连接数呢?只需要用下面的SQL语句查询一下就可以了 ...

  9. pt-archiver 归档数据

    pt-archiver 参数说明pt-archiver是Percona-Toolkit工具集中的一个组件,是一个主要用于对MySQL表数据进行归档和清除工具.它可以将数据归档到另一张表或者是一个文件中 ...

  10. Vuex源码分析(转)

    当我们用vue在开发的过程中,经常会遇到以下问题 多个vue组件共享状态 Vue组件间的通讯 在项目不复杂的时候,我们会利用全局事件bus的方式解决,但随着复杂度的提升,用这种方式将会使得代码难以维护 ...