池化层的作用如下-引用《TensorFlow实践》:

池化层的作用是减少过拟合,并通过减小输入的尺寸来提高性能。他们可以用来对输入进行降采样,但会为后续层保留重要的信息。只使用tf.nn.conv2d来减小输入的尺寸也是可以的,但是池化层的效率更高。

常见的TensorFlow提供的激活函数如下:(详细请参考http://www.tensorfly.cn/tfdoc/api_docs/python/nn.html)

1.tf.nn.max_pool(value, ksize, strides, padding, name=None)

Performs the max pooling on the input.

  • value: A 4-D Tensor with shape [batch, height, width, channels] and type float32,float64qint8quint8qint32.
  • ksize: A list of ints that has length >= 4. The size of the window for each dimension of the input tensor.
  • strides: A list of ints that has length >= 4. The stride of the sliding window for each dimension of the input tensor.
  • padding: A string, either 'VALID' or 'SAME'. The padding algorithm.
  • name: Optional name for the operation.

当我们的ksize=[1 3 3 1]式,我们取3X3模板池pool当中最大的数据当做中心点的值,strides作为滑动跳跃的间隔,代码如下所示:

import tensorflow as tf

batch_size = 1
input_height = 3
input_width = 3
input_channels = 1
layer_input = tf.constant([
[
[[1.0],[0.2],[1.5]],
[[0.1],[1.2],[1.4]],
[[1.1],[0.4],[0.4]]
]
])
kernel = [batch_size, input_height, input_width,input_channels]
max_pool = tf.nn.max_pool(layer_input,kernel,[1,1,1,1],padding='VALID',name=None)
sess = tf.Session()
sess.run(max_pool)

输出结果如下(注意max_pool的输入的参数的维数一定要正确,否则会报错):

2.tf.nn.avg_pool(value, ksize, strides, padding, name=None)

Performs the average pooling on the input.

Each entry in output is the mean of the corresponding size ksize window in value.

  • value: A 4-D Tensor of shape [batch, height, width, channels] and type float32,float64qint8quint8, or qint32.
  • ksize: A list of ints that has length >= 4. The size of the window for each dimension of the input tensor.
  • strides: A list of ints that has length >= 4. The stride of the sliding window for each dimension of the input tensor.
  • padding: A string, either 'VALID' or 'SAME'. The padding algorithm.
  • name: Optional name for the operation.

跳跃遍历一个张量,并将被卷积核覆盖的各深度值去平均。当整个卷积核都非常重要时,若需要实现值的缩减,平均池化非常有用,例如输入张量宽度和高度很大,但深度很小的情况。

import tensorflow as tf

batch_size = 1
input_height = 3
input_width = 3
input_channels = 1
layer_input = tf.constant([
[
[[1.0],[1.0],[1.0]],
[[1.0],[0.5],[0.0]],
[[0.0],[0.0],[0.0]]
]
])
kernel = [batch_size, input_height, input_width,input_channels]
avg_pool = tf.nn.mavg_pool(layer_input,kernel,[1,1,1,1],padding='VALID',name=None)
sess = tf.Session()
sess.run(avg_pool)

输出结果如下:(1.0+1.0+1.0+0.5+0.0+0.0+0.0+0.0)/9=0.5

TensorFlow池化层-函数的更多相关文章

  1. TensorFlow 池化层

    在 TensorFlow 中使用池化层 在下面的练习中,你需要设定池化层的大小,strides,以及相应的 padding.你可以参考 tf.nn.max_pool().Padding 与卷积 pad ...

  2. tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)

    池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...

  3. tensorflow的卷积和池化层(二):记实践之cifar10

    在tensorflow中的卷积和池化层(一)和各种卷积类型Convolution这两篇博客中,主要讲解了卷积神经网络的核心层,同时也结合当下流行的Caffe和tf框架做了介绍,本篇博客将接着tenso ...

  4. tensorflow中的卷积和池化层(一)

    在官方tutorial的帮助下,我们已经使用了最简单的CNN用于Mnist的问题,而其实在这个过程中,主要的问题在于如何设置CNN网络,这和Caffe等框架的原理是一样的,但是tf的设置似乎更加简洁. ...

  5. tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图

    tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...

  6. 『TensorFlow』卷积层、池化层详解

    一.前向计算和反向传播数学过程讲解

  7. 第十三节,使用带有全局平均池化层的CNN对CIFAR10数据集分类

    这里使用的数据集仍然是CIFAR-10,由于之前写过一篇使用AlexNet对CIFAR数据集进行分类的文章,已经详细介绍了这个数据集,当时我们是直接把这些图片的数据文件下载下来,然后使用pickle进 ...

  8. 学习笔记TF014:卷积层、激活函数、池化层、归一化层、高级层

    CNN神经网络架构至少包含一个卷积层 (tf.nn.conv2d).单层CNN检测边缘.图像识别分类,使用不同层类型支持卷积层,减少过拟合,加速训练过程,降低内存占用率. TensorFlow加速所有 ...

  9. 基于深度学习和迁移学习的识花实践——利用 VGG16 的深度网络结构中的五轮卷积网络层和池化层,对每张图片得到一个 4096 维的特征向量,然后我们直接用这个特征向量替代原来的图片,再加若干层全连接的神经网络,对花朵数据集进行训练(属于模型迁移)

    基于深度学习和迁移学习的识花实践(转)   深度学习是人工智能领域近年来最火热的话题之一,但是对于个人来说,以往想要玩转深度学习除了要具备高超的编程技巧,还需要有海量的数据和强劲的硬件.不过 Tens ...

随机推荐

  1. css中input框不可点击+首行缩进

    Css 1)text-indent::首行缩进 2)disabled="true"设置input框不可以点击 3)Css:xx!important:声明提前优先级最高..!impo ...

  2. Confluence 6 从一个 XML 备份中导入一个空间

    有下面 2 中方法可以导入一个空间——通过上传一个文件,或者从你 Confluence 服务器上的一个目录中导入.上传文件仅仅是针对一个小站点的情况.为了取得最好的导入结果,我们推荐你从服务器上的目录 ...

  3. laravel 5.6

    compact() 建立一个数组,包括变量名和它们的值 打印结果: starts_with() 函数判断给定的字符串的开头是否是指定值

  4. PLC漏洞问题

    1.PLC采用大多是经过裁剪的实时操作系统,比如像linux RT.QNX.VxWorks等,这些实时操作系统广泛应用在通信.军事.航天.等工程领域,但是随之工业与网络的互连爆发出很多问题,常见的PL ...

  5. python压缩文件

    #coding=utf-8 #压缩文件 import os,os.path import zipfile #压缩:传路径,文件名 def zip_compression(dirname,zipfile ...

  6. C++ Primer 笔记——OOP

    1.基类通常都应该定义一个虚析构函数,即使该函数不执行任何实际操作也是如此. 2.任何构造函数之外的非静态函数都可以是虚函数,关键字virtual只能出现在类内部的声明语句之前而不能用于类外部的函数定 ...

  7. 对象存储服务(Object Storage Service,简称 OSS)

    阿里云对象存储服务(Object Storage Service,简称 OSS),是阿里云提供的海量.安全.低成本.高可靠的云存储服务.它具有与平台无关的RESTful API接口,能够提供99.99 ...

  8. document.getElementsByClassName() 原生方法 通过className 选择DOM节点

    <div id="box"> <div class="box">1</div> <div class="bo ...

  9. Android 第一波

    1. Devik进程,Linux进程,线程的区别 说一说对 SP 频繁操作有什么后果? SP 能存储多少数据? SP 的底层其实是由xml文件来实现的,操作 SP 的过程其实就是xml的序列化和反序列 ...

  10. Two Sum【LeetCode】

    Given an array of integers, return indices of the two numbers such that they add up to a specific ta ...