这一节,介绍TensorFlow中的一个封装好的高级库,里面有前面讲过的很多函数的高级封装,使用这个高级库来开发程序将会提高效率。

我们改写第十三节的程序,卷积函数我们使用tf.contrib.layers.conv2d(),池化函数使用tf.contrib.layers.max_pool2d()和tf.contrib.layers.avg_pool2d(),全连接函数使用tf.contrib.layers.fully_connected()。

一 tf.contrib.layers中的具体函数介绍

1.tf.contrib.layers.conv2d()函数的定义如下:

  1. def convolution(inputs,
  2. num_outputs,
  3. kernel_size,
  4. stride=1,
  5. padding='SAME',
  6. data_format=None,
  7. rate=1,
  8. activation_fn=nn.relu,
  9. normalizer_fn=None,
  10. normalizer_params=None,
  11. weights_initializer=initializers.xavier_initializer(),
  12. weights_regularizer=None,
  13. biases_initializer=init_ops.zeros_initializer(),
  14. biases_regularizer=None,
  15. reuse=None,
  16. variables_collections=None,
  17. outputs_collections=None,
  18. trainable=True,
  19. scope=None):

常用的参数说明如下:

  • inputs:形状为[batch_size, height, width, channels]的输入。
  • num_outputs:代表输出几个channel。这里不需要再指定输入的channel了,因为函数会自动根据inpus的shpe去判断。
  • kernel_size:卷积核大小,不需要带上batch和channel,只需要输入尺寸即可。[5,5]就代表5x5的卷积核,如果长和宽都一样,也可以只写一个数5.
  • stride:步长,默认是长宽都相等的步长。卷积时,一般都用1,所以默认值也是1.如果长和宽都不相等,也可以用一个数组[1,2]。
  • padding:填充方式,'SAME'或者'VALID'。
  • activation_fn:激活函数。默认是ReLU。也可以设置为None
  • weights_initializer:权重的初始化,默认为initializers.xavier_initializer()函数。
  • weights_regularizer:权重正则化项,可以加入正则函数。biases_initializer:偏置的初始化,默认为init_ops.zeros_initializer()函数。
  • biases_regularizer:偏置正则化项,可以加入正则函数。
  • trainable:是否可训练,如作为训练节点,必须设置为True,默认即可。如果我们是微调网络,有时候需要冻结某一层的参数,则设置为False。

2.tf.contrib.layers.max_pool2d()函数的定义如下:

  1. def max_pool2d(inputs,
  2. kernel_size,
  3. stride=2,
  4. padding='VALID',
  5. data_format=DATA_FORMAT_NHWC,
  6. outputs_collections=None,
  7. scope=None):

参数说明如下:

  • inputs: A 4-D tensor of shape `[batch_size, height, width, channels]` if`data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`.
  • kernel_size: A list of length 2: [kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same.
  • stride: A list of length 2: [stride_height, stride_width].Can be an int if both strides are the same. Note that presently both strides must have the same value.
  • padding: The padding method, either 'VALID' or 'SAME'.
  • data_format: A string. `NHWC` (default) and `NCHW` are supported.
  • outputs_collections: The collections to which the outputs are added.
  • scope: Optional scope for name_scope.

3.tf.contrib.layers.avg_pool2d()函数定义

  1. def avg_pool2d(inputs,
  2. kernel_size,
  3. stride=2,
  4. padding='VALID',
  5. data_format=DATA_FORMAT_NHWC,
  6. outputs_collections=None,
  7. scope=None):

参数说明如下:

  • inputs: A 4-D tensor of shape `[batch_size, height, width, channels]` if`data_format` is `NHWC`, and `[batch_size, channels, height, width]` if `data_format` is `NCHW`.
  • kernel_size: A list of length 2: [kernel_height, kernel_width] of the pooling kernel over which the op is computed. Can be an int if both values are the same.
  • stride: A list of length 2: [stride_height, stride_width].Can be an int if both strides are the same. Note that presently both strides must have the same value.
  • padding: The padding method, either 'VALID' or 'SAME'.
  • data_format: A string. `NHWC` (default) and `NCHW` are supported.
  • outputs_collections: The collections to which the outputs are added.
  • scope: Optional scope for name_scope.

4.tf.contrib.layers.fully_connected()函数的定义如下:

  1. def fully_connected(inputs,
  2. num_outputs,
  3. activation_fn=nn.relu,
  4. normalizer_fn=None,
  5. normalizer_params=None,
  6. weights_initializer=initializers.xavier_initializer(),
  7. weights_regularizer=None,
  8. biases_initializer=init_ops.zeros_initializer(),
  9. biases_regularizer=None,
  10. reuse=None,
  11. variables_collections=None,
  12. outputs_collections=None,
  13. trainable=True,
  14. scope=None):

参数说明如下:

  • inputs: A tensor of at least rank 2 and static value for the last dimension; i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
  • num_outputs: Integer or long, the number of output units in the layer.
  • activation_fn: Activation function. The default value is a ReLU function.Explicitly set it to None to skip it and maintain a linear activation.
  • normalizer_fn: Normalization function to use instead of `biases`. If `normalizer_fn` is provided then `biases_initializer` and
  • `biases_regularizer` are ignored and `biases` are not created nor added.default set to None for no normalizer function
  • normalizer_params: Normalization function parameters.
  • weights_initializer: An initializer for the weights.
  • weights_regularizer: Optional regularizer for the weights.
  • biases_initializer: An initializer for the biases. If None skip biases.
  • biases_regularizer: Optional regularizer for the biases.
  • reuse: Whether or not the layer and its variables should be reused. To be able to reuse the layer scope must be given.
  • variables_collections: Optional list of collections for all the variables or a dictionary containing a different list of collections per variable.
  • outputs_collections: Collection to add the outputs.
  • trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).如果我们是微调网络,有时候需要冻结某一层的参数,则设置为False。
  • scope: Optional scope for variable_scope.

二 改写cifar10分类

代码如下:

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Thu May 3 12:29:16 2018
  4.  
  5. @author: zy
  6. """
  7.  
  8. '''
  9. 建立一个带有全连接层的卷积神经网络 并对CIFAR-10数据集进行分类
  10. 1.使用2个卷积层的同卷积操作,滤波器大小为5x5,每个卷积层后面都会跟一个步长为2x2的池化层,滤波器大小为2x2
  11. 2.对输出的64个feature map进行全局平均池化,得到64个特征
  12. 3.加入一个全连接层,使用softmax激活函数,得到分类
  13. '''
  14.  
  15. import cifar10_input
  16. import tensorflow as tf
  17. import numpy as np
  18.  
  19. def print_op_shape(t):
  20. '''
  21. 输出一个操作op节点的形状
  22. '''
  23. print(t.op.name,'',t.get_shape().as_list())
  24.  
  25. '''
  26. 一 引入数据集
  27. '''
  28. batch_size = 128
  29. learning_rate = 1e-4
  30. training_step = 15000
  31. display_step = 200
  32. #数据集目录
  33. data_dir = './cifar10_data/cifar-10-batches-bin'
  34. print('begin')
  35. #获取训练集数据
  36. images_train,labels_train = cifar10_input.inputs(eval_data=False,data_dir = data_dir,batch_size=batch_size)
  37. print('begin data')
  38.  
  39. '''
  40. 二 定义网络结构
  41. '''
  42.  
  43. #定义占位符
  44. input_x = tf.placeholder(dtype=tf.float32,shape=[None,24,24,3]) #图像大小24x24x
  45. input_y = tf.placeholder(dtype=tf.float32,shape=[None,10]) #0-9类别
  46.  
  47. x_image = tf.reshape(input_x,[batch_size,24,24,3])
  48.  
  49. #1.卷积层 ->池化层
  50.  
  51. h_conv1 = tf.contrib.layers.conv2d(inputs=x_image,num_outputs=64,kernel_size=5,stride=1,padding='SAME', activation_fn=tf.nn.relu) #输出为[-1,24,24,64]
  52. print_op_shape(h_conv1)
  53. h_pool1 = tf.contrib.layers.max_pool2d(inputs=h_conv1,kernel_size=2,stride=2,padding='SAME') #输出为[-1,12,12,64]
  54. print_op_shape(h_pool1)
  55.  
  56. #2.卷积层 ->池化层
  57.  
  58. h_conv2 =tf.contrib.layers.conv2d(inputs=h_pool1,num_outputs=64,kernel_size=[5,5],stride=[1,1],padding='SAME', activation_fn=tf.nn.relu) #输出为[-1,12,12,64]
  59. print_op_shape(h_conv2)
  60. h_pool2 = tf.contrib.layers.max_pool2d(inputs=h_conv2,kernel_size=[2,2],stride=[2,2],padding='SAME') #输出为[-1,6,6,64]
  61. print_op_shape(h_pool2)
  62.  
  63. #3全连接层
  64.  
  65. nt_hpool2 = tf.contrib.layers.avg_pool2d(inputs=h_pool2,kernel_size=6,stride=6,padding='SAME') #输出为[-1,1,1,64]
  66. print_op_shape(nt_hpool2)
  67. nt_hpool2_flat = tf.reshape(nt_hpool2,[-1,64])
  68. y_conv = tf.contrib.layers.fully_connected(inputs=nt_hpool2_flat,num_outputs=10,activation_fn=tf.nn.softmax)
  69. print_op_shape(y_conv)
  70.  
  71. '''
  72. 三 定义求解器
  73. '''
  74.  
  75. #softmax交叉熵代价函数
  76. cost = tf.reduce_mean(-tf.reduce_sum(input_y * tf.log(y_conv),axis=1))
  77.  
  78. #求解器
  79. train = tf.train.AdamOptimizer(learning_rate).minimize(cost)
  80.  
  81. #返回一个准确度的数据
  82. correct_prediction = tf.equal(tf.arg_max(y_conv,1),tf.arg_max(input_y,1))
  83. #准确率
  84. accuracy = tf.reduce_mean(tf.cast(correct_prediction,dtype=tf.float32))
  85.  
  86. '''
  87. 四 开始训练
  88. '''
  89. sess = tf.Session();
  90. sess.run(tf.global_variables_initializer())
  91. # 启动计算图中所有的队列线程 调用tf.train.start_queue_runners来将文件名填充到队列,否则read操作会被阻塞到文件名队列中有值为止。
  92. tf.train.start_queue_runners(sess=sess)
  93.  
  94. for step in range(training_step):
  95. #获取batch_size大小数据集
  96. image_batch,label_batch = sess.run([images_train,labels_train])
  97.  
  98. #one hot编码
  99. label_b = np.eye(10,dtype=np.float32)[label_batch]
  100.  
  101. #开始训练
  102. train.run(feed_dict={input_x:image_batch,input_y:label_b},session=sess)
  103.  
  104. if step % display_step == 0:
  105. train_accuracy = accuracy.eval(feed_dict={input_x:image_batch,input_y:label_b},session=sess)
  106. print('Step {0} tranining accuracy {1}'.format(step,train_accuracy))

第十六节,使用函数封装库tf.contrib.layers的更多相关文章

  1. 第三百三十六节,web爬虫讲解2—urllib库中使用xpath表达式—BeautifulSoup基础

    第三百三十六节,web爬虫讲解2—urllib库中使用xpath表达式—BeautifulSoup基础 在urllib中,我们一样可以使用xpath表达式进行信息提取,此时,你需要首先安装lxml模块 ...

  2. centos shell脚本编程2 if 判断 case判断 shell脚本中的循环 for while shell中的函数 break continue test 命令 第三十六节课

    centos  shell脚本编程2 if 判断  case判断   shell脚本中的循环  for   while   shell中的函数  break  continue  test 命令   ...

  3. ASP.NET MVC深入浅出系列(持续更新) ORM系列之Entity FrameWork详解(持续更新) 第十六节:语法总结(3)(C#6.0和C#7.0新语法) 第三节:深度剖析各类数据结构(Array、List、Queue、Stack)及线程安全问题和yeild关键字 各种通讯连接方式 设计模式篇 第十二节: 总结Quartz.Net几种部署模式(IIS、Exe、服务部署【借

    ASP.NET MVC深入浅出系列(持续更新)   一. ASP.NET体系 从事.Net开发以来,最先接触的Web开发框架是Asp.Net WebForm,该框架高度封装,为了隐藏Http的无状态模 ...

  4. 第一百二十六节,JavaScript,XPath操作xml节点

    第一百二十六节,JavaScript,XPath操作xml节点 学习要点: 1.IE中的XPath 2.W3C中的XPath 3.XPath跨浏览器兼容 XPath是一种节点查找手段,对比之前使用标准 ...

  5. 第四百一十六节,Tensorflow简介与安装

    第四百一十六节,Tensorflow简介与安装 TensorFlow是什么 Tensorflow是一个Google开发的第二代机器学习系统,克服了第一代系统DistBelief仅能开发神经网络算法.难 ...

  6. 第三百四十六节,Python分布式爬虫打造搜索引擎Scrapy精讲—Requests请求和Response响应介绍

    第三百四十六节,Python分布式爬虫打造搜索引擎Scrapy精讲—Requests请求和Response响应介绍 Requests请求 Requests请求就是我们在爬虫文件写的Requests() ...

  7. 第三百二十六节,web爬虫,scrapy模块,解决重复ur——自动递归url

    第三百二十六节,web爬虫,scrapy模块,解决重复url——自动递归url 一般抓取过的url不重复抓取,那么就需要记录url,判断当前URL如果在记录里说明已经抓取过了,如果不存在说明没抓取过 ...

  8. 大白话5分钟带你走进人工智能-第二十六节决策树系列之Cart回归树及其参数(5)

                                                    第二十六节决策树系列之Cart回归树及其参数(5) 上一节我们讲了不同的决策树对应的计算纯度的计算方法, ...

  9. m_Orchestrate learning system---二十六、动态给封装好的控件添加属性

    m_Orchestrate learning system---二十六.动态给封装好的控件添加属性 一.总结 一句话总结:比如我现在封装好了ueditor控件,我外部调用这个控件,因为要写数据到数据库 ...

随机推荐

  1. jackson使用问题:mapper.readValue()将JSON字符串转反序列化为对象失败或异常

    问题根源:转化目标实体类的属性要与被转JSON字符串总的字段 一 一对应!字符串里可以少字段,但绝对不能多字段. 先附上我这段出现了问题的源码: // 1.接收并转化相应的参数.需要在pom.xml中 ...

  2. python 三目运算符

    格式: true_res if condition else false_res Meto 1: Meto 2: >>> x = 2 >>> x+1 if x!=1 ...

  3. AMS工作原理—— App启动概要

    说明: 1. 通过Launcher或者startActivity启动最终的流程都是和上面的一致的. 2. AMP是AMS在App端(client端)的代理, ATP是ApplicationThread ...

  4. 把当前ubuntu系统做成镜像

    把当前ubuntu系统做成镜像 2018年06月19日 15:24:51 还需要再学习一个 阅读数:9720 原文地址: http://community.bwbot.org/topic/167/%E ...

  5. 数据库 -- pymysql

    pythen3连接mysql pymsql介绍 PyMySQL 是在 Python3.x 版本中用于连接 MySQL 服务器的一个库,Python2中则使用mysqldb. Django中也可以使用P ...

  6. Android路径之Javascript基础-笔记

    一.Javascript概述(知道) a.一种基于对象和事件驱动的脚本语言 b.作用: 给页面添加动态效果 c.历史: 原名叫做livescript.W3c组织开发的标准叫ECMAscipt. d.特 ...

  7. 【XSY2535】整数 NTT

    题目描述 问有多少个满足以下要求的\(k\)进制数: 1.每个数字出现的次数不超过\(n\) 2.\(0\)没有出现过 3.若\(g_{i,j}=0\),则\(i\)不能出现恰好\(j\)次. 两次询 ...

  8. Google Apps的单点登录-谷歌使用的单点登录

    简述: Customer :客户 Google:谷歌 Identity Provider:身份提供者安全断言标记语言(英语:Security Assertion Markup Language,简称S ...

  9. SPHINX 文档写作工具安装简要指南 - windows 版 - 基于python

    此教程基于本地己安装好 PYTHON 并配置过全局变量:一定具备相应的基础再操作: 上传图片以免产生误导,以下为文字描述,按下列操作即可: 下载 get-pip.py脚本; python get-pi ...

  10. Android stadio 生成项目 Plugin with id 'com.android.application' not found

    buildscript { repositories { jcenter() } dependencies { classpath 'com.android.tools.build:gradle:2. ...