在深度学习的图像识别领域中,我们经常使用卷积神经网络CNN来对图像进行特征提取,当我们使用TensorFlow搭建自己的CNN时,一般会使用TensorFlow中的卷积函数和池化函数来对图像进行卷积和池化操作,而这两种函数中都存在参数padding,该参数的设置很容易引起错误,所以在此总结下。

1.为什么要使用padding

在弄懂padding规则前得先了解拥有padding参数的函数,在TensorFlow中,主要使用tf.nn.conv2d()进行(二维数据)卷积操作,tf.nn.max_pool()、tf.nn.avg_pool来分别实现最大池化和平均池化,通过查阅官方文档我们知道其需要的参数如下:

  1. tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None,name=None)
  2. tf.nn.max_pool_with_argmax(input, ksize, strides, padding, Targmax=None, name=None)
  3. tf.nn.max_pool(value, ksize, strides, padding, name=None)

这三个函数中都含有padding参数,我们在使用它们的时候需要传入所需的值,padding的值为字符串,可选值为'SAME' 和 'VALID' ;

padding参数的作用是决定在进行卷积或池化操作时,是否对输入的图像矩阵边缘补0,'SAME' 为补零,'VALID' 则不补,其原因是因为在这些操作过程中过滤器可能不能将某个方向上的数据刚好处理完,如下所示:

当步长为5,卷积核尺寸为6×6时,当padding为VALID时,则可能造成数据丢失(如左图),当padding为SAME时,则对其进行补零(如右图),

2. padding公式

首先,定义变量:

输入图片的宽和高:i_w 和 i_h

输出特征图的宽和高:o_w 和 o_h

过滤器的宽和高:f_w 和 f_h

宽和高方向的步长:s_w 和 s_h

宽和高方向总的补零个数:pad_w 和 pad_h

顶部和底部的补零个数:pad_top 和 pad_bottom

左部和右部的补零个数:pad_left 和 pad_right

1.VALID模式

输出的宽和高为

  1. o_w = i_w - f_w + 1)/ s_w #(结果向上取整)
  2. o_h = i_h - f_h + 1)/ s_h #(结果向上取整)

2. SAME模式

输出的宽和高为

  1. o_w = i_w / s_w#(结果向上取整)
  2. o_h = i_h / s_h#(结果向上取整)

各个方向的补零个数为:max()为取较大值,

  1. pad_h = max(( o_h -1 ) × s_h + f_h - i_h 0
  2. pad_top = pad_h / 2 # 注意此处向下取整
  3. pad_bottom = pad_h - pad_top
  4. pad_w = max(( o_w -1 ) × s_w + f_w - i_w 0
  5. pad_left = pad_w / 2 # 注意此处向下取整
  6. pad_right = pad_w - pad_left

3.卷积padding的实战分析

接下来我们通过在TensorFlow中使用卷积和池化函数来分析padding参数在实际中的应用,代码如下:

  1. # -*- coding: utf-8 -*-
  2. import tensorflow as tf
  3.  
  4. # 首先,模拟输入一个图像矩阵,大小为5*5
  5. # 输入图像矩阵的shape为[批次大小,图像的高度,图像的宽度,图像的通道数]
  6. input = tf.Variable(tf.constant(1.0, shape=[1, 5, 5, 1]))
  7.  
  8. # 定义卷积核,大小为2*2,输入和输出都是单通道
  9. # 卷积核的shape为[卷积核的高度,卷积核的宽度,图像通道数,卷积核的个数]
  10. filter1 = tf.Variable(tf.constant([-1.0, 0, 0, -1], shape=[2, 2, 1, 1]))
  11.  
  12. # 卷积操作 strides为[批次大小,高度方向的移动步长,宽度方向的移动步长,通道数]
  13. # SAME
  14. op1_conv_same = tf.nn.conv2d(input, filter1, strides=[1,2,2,1],padding='SAME')
  15. # VALID
  16. op2_conv_valid = tf.nn.conv2d(input, filter1, strides=[1,2,2,1],padding='VALID')
  17.  
  18. init = tf.global_variables_initializer()
  19. with tf.Session() as sess:
  20. sess.run(init)
  21. print("op1_conv_same:\n", sess.run(op1_conv_same))
  22. print("op2_conv_valid:\n", sess.run(op2_conv_valid))

VALID模式的分析:

SAME模式分析:

  1. o_w = i_w / s_w = 5/2 = 3
  2. o_h = i_h / s_h = 5/2 = 3
  3.  
  4. pad_w = max ( (o_w - 1 ) × s_w + f_w - i_w , 0 )
  5. = max ( (3 - 1 ) × 2 + 2 - 5 , 0 ) = 1
  6. pad_left = 1 / 2 =0
  7. pad_right = 1 - 0 =0
  8. # 同理
  9. pad_top = 0
  10. pad_bottom = 1

运行代码后的结果如下:

4.池化padding的实战分析

这里主要分析最大池化和平均池化两个函数,函数中padding参数设置和矩阵形状计算都与卷积一样,但需要注意的是:

1. 当padding='SAME',计算avg_pool时,每次的计算是除以图像被filter框出的非零元素的个数,而不是filter元素的个数,如下图,第一行第三列我们计算出的结果是除以2而非4,第三行第三列计算出的结果是除以1而非4;

2. 当计算全局池化时,即与图像矩阵形状相同的过滤器进行一次池化,此情况下无padding,即在边缘没有补0,我们直接除以整个矩阵的元素个数,而不是除以非零元素个数(注意与第一点进行区分)

池化函数的代码示例如下:

  1. # -*- coding: utf-8 -*-
  2. import tensorflow as tf
  3.  
  4. # 首先,模拟输入一个特征图,大小为5*5
  5. # 输入图像矩阵的shape为[批次大小,图像的高度,图像的宽度,图像的通道数]
  6. input = tf.Variable(tf.constant(1.0, shape=[1, 5, 5, 1]))
  7.  
  8. # 最大池化操作 strides为[批次大小,高度方向的移动步长,宽度方向的移动步长,通道数]
  9. # ksize为[1, 池化窗口的高,池化窗口的宽度,1]
  10. # SAME
  11. op1_max_pooling_same = tf.nn.max_pool(input, [1,2,2,1], strides=[1,2,2,1],padding='SAME')
  12. # VALID
  13. op2_max_pooling_valid = tf.nn.max_pool(input, [1,2,2,1], strides=[1,2,2,1],padding='VALID')
  14.  
  15. # 平均池化
  16. op3_avg_pooling_same = tf.nn.avg_pool(input, [1,2,2,1], strides=[1,2,2,1],padding='SAME')
  17. # 全局池化,filter是一个与输入矩阵一样大的过滤器
  18. op4_global_pooling_same = tf.nn.avg_pool(input, [1,5,5,1], strides=[1,5,5,1],padding='SAME')
  19.  
  20. init = tf.global_variables_initializer()
  21. with tf.Session() as sess:
  22. sess.run(init)
  23. print("op1_max_pooling_same:\n", sess.run(op1_max_pooling_same))
  24. print("op2_max_pooling_valid:\n", sess.run(op2_max_pooling_valid))
  25. print("op3_max_pooling_same:\n", sess.run(op3_avg_pooling_same))
  26. print("op4_global_pooling_same:\n", sess.run(op4_global_pooling_same))

运行结果如下:

 

5.总结

在搭建CNN时,我们输入的图像矩阵在网络中需要经过多层卷积和池化操作,在这个过程中,feature map的形状会不断变化,如果不清楚padding参数引起的这些变化,程序在运行过程中会发生错误,当然在实际写代码时,可以将每一层feature map的形状打印出来,了解每一层Tensor的变化。

转载请注明出处:https://www.cnblogs.com/White-xzx/p/9497029.html

【TensorFlow】一文弄懂CNN中的padding参数的更多相关文章

  1. 一文弄懂神经网络中的反向传播法——BackPropagation【转】

    本文转载自:https://www.cnblogs.com/charlotte77/p/5629865.html 一文弄懂神经网络中的反向传播法——BackPropagation   最近在看深度学习 ...

  2. 基于TensorFlow理解CNN中的padding参数

    1 TensorFlow中用到padding的地方 在TensorFlow中用到padding的地方主要有tf.nn.conv2d(),tf.nn.max_pool(),tf.nn.avg_pool( ...

  3. [转] 一文弄懂神经网络中的反向传播法——BackPropagation

    在看CNN和RNN的相关算法TF实现,总感觉有些细枝末节理解不到位,浮在表面.那么就一点点扣细节吧. 这个作者讲方向传播也是没谁了,666- 原文地址:https://www.cnblogs.com/ ...

  4. 一文弄懂神经网络中的反向传播法——BackPropagation

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  5. 一文弄懂神经网络中的反向传播法(Backpropagation algorithm)

    最近在看深度学习的东西,一开始看的吴恩达的UFLDL教程,有中文版就直接看了,后来发现有些地方总是不是很明确,又去看英文版,然后又找了些资料看,才发现,中文版的译者在翻译的时候会对省略的公式推导过程进 ...

  6. 彻底弄懂AngularJS中的transclusion

    点击查看AngularJS系列目录 彻底弄懂AngularJS中的transclusion AngularJS中指令的重要性是不言而喻的,指令让我们可以创建自己的HTML标记,它将自定义元素变成了一个 ...

  7. 一文弄懂-Netty核心功能及线程模型

    目录 一. Netty是什么? 二. Netty 的使用场景 三. Netty通讯示例 1. Netty的maven依赖 2. 服务端代码 3. 客户端代码 四. Netty线程模型 五. Netty ...

  8. 一文弄懂-《Scalable IO In Java》

    目录 一. <Scalable IO In Java> 是什么? 二. IO架构的演变历程 1. Classic Service Designs 经典服务模型 2. Event-drive ...

  9. 一文弄懂-BIO,NIO,AIO

    目录 一文弄懂-BIO,NIO,AIO 1. BIO: 同步阻塞IO模型 2. NIO: 同步非阻塞IO模型(多路复用) 3.Epoll函数详解 4.Redis线程模型 5. AIO: 异步非阻塞IO ...

随机推荐

  1. Dubbo、Zookeeper集群搭建及Rose使用心得(二)

    上篇讲了一下配置,这次主要写一下这个框架开发的大概流程.这里以实现 登陆 功能为例. 一.准备工作 1.访问拦截器 用户在进行网站访问的时候,有可能访问到不存在的网页,所以,我们需要把这些链接重新定向 ...

  2. JAVA 加密算法初探DES&AES

    开发项目中需要将重要数据缓存在本地以便在离线是读取,如果不对数据进行处理,很容易造成损失.所以,我们一般对此类数据进行加密处理.这里,主要介绍两种简单的加密算法:DES&AES. 先简单介绍一 ...

  3. Betsy Ross Problem

    Matlab学习中的betsy ross 问题.用matlab函数画1777年的美国国旗. 五角星绘制部分是自己想出来的方法去画上的.具体代码参考如下. 先是绘制矩形的函数 function Draw ...

  4. pymc

    sklearn实战-乳腺癌细胞数据挖掘 https://study.163.com/course/introduction.htm?courseId=1005269003&utm_campai ...

  5. Redis实战(一)CentOS 7上搭建redis-3.0.2

    1.安装redis wget http://download.redis.io/releases/redis-3.0.2.tar.gz tar zxvf redis-3.0.2.tar.gz cd   ...

  6. iframe元素的学习(笔记)

    什么是iframe:iframe元素即内联框架,iframe是内联的并且承前启后,对于外围的页面,iframe是一个普通的元素,对于iframe里面的内容,又是一个五脏俱全的页面.重下面的写法可以看出 ...

  7. Web开发中的18个关键性错误

    前几年,我有机会能参与一些有趣的项目,并且独立完成开发.升级.重构以及新功能的开发等工作. 本文总结了一些PHP程序员在Web开发中经常 忽略的关键错误,尤其是在处理中大型的项目上问题更为突出.典型的 ...

  8. cin.get()和cin.getline()之间的区别

    cin.getline()和cin.get()都是对输入的面向行的读取,即一次读取整行而不是单个数字或字符,但是二者有一定的区别. cin.get()每次读取一整行并把由Enter键生成的换行符留在输 ...

  9. 10种CSS3实现的Loading效果

    原文链接:http://www.cnblogs.com/jr1993/p/4622039.html 第一种效果: 代码如下: <div class="loading"> ...

  10. 2016.5.14——leetcode-HappyNumber,House Robber

    leetcode:HappyNumber,House Robber 1.Happy Number 这个题中收获2点: 1.拿到题以后考虑特殊情况,代码中考虑1和4,或者说<6的情况,动手算下.( ...