1. #coding:utf-8
  2. import tensorflow as tf
  3. import os
  4. def read_and_decode(filename):
  5. #根据文件名生成一个队列
  6. filename_queue = tf.train.string_input_producer([filename])
  7. reader = tf.TFRecordReader()
  8. _, serialized_example = reader.read(filename_queue) #返回文件名和文件
  9. features = tf.parse_single_example(serialized_example,
  10. features={
  11. 'label': tf.FixedLenFeature([], tf.int64),
  12. 'img_raw' : tf.FixedLenFeature([], tf.string),
  13. })
  14.  
  15. img = tf.decode_raw(features['img_raw'], tf.uint8)
  16. img = tf.reshape(img, [227, 227, 3])
  17. img = (tf.cast(img, tf.float32) * (1. / 255) - 0.5)*2
  18. label = tf.cast(features['label'], tf.int32)
  19. print img,label
  20. return img, label
  21.  
  22. def get_batch(image, label, batch_size,crop_size):
  23. #数据扩充变换
  24. distorted_image = tf.random_crop(image, [crop_size, crop_size, 3])#随机裁剪
  25. distorted_image = tf.image.random_flip_up_down(distorted_image)#上下随机翻转
  26. distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)#亮度变化
  27. distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)#对比度变化
  28.  
  29. #生成batch
  30. #shuffle_batch的参数:capacity用于定义shuttle的范围,如果是对整个训练数据集,获取batch,那么capacity就应该够大
  31. #保证数据打的足够乱
  32. images, label_batch = tf.train.shuffle_batch([distorted_image, label],batch_size=batch_size,
  33. num_threads=1,capacity=2000,min_after_dequeue=1000)
  34.  
  35. return images, label_batch
  36.  
  37. class network(object):
  38.  
  39. def lenet(self,images,keep_prob):
  40.  
  41. '''
  42. 根据tensorflow中的conv2d函数,我们先定义几个基本符号
  43. 输入矩阵 W×W,这里只考虑输入宽高相等的情况,如果不相等,推导方法一样,不多解释。
  44. filter矩阵 F×F,卷积核
  45. stride值 S,步长
  46. 输出宽高为 new_height、new_width
  47. 在Tensorflow中对padding定义了两种取值:VALID、SAME。下面分别就这两种定义进行解释说明。
  48. VALID
  49. new_height = new_width = (W – F + 1) / S #结果向上取整
  50. SAME
  51. new_height = new_width = W / S #结果向上取整
  52. '''
  53.  
  54. images = tf.reshape(images,shape=[-1,32,32,3])
  55. #images = (tf.cast(images,tf.float32)/255.0-0.5)*2
  56. #第一层,卷积层 32,32,3--->5,5,3,6--->28,28,6
  57. #卷积核大小为5*5 输入层深度为3即三通道图像 卷积核深度为6即卷积核的个数
  58. conv1_weights = tf.get_variable("conv1_weights",[5,5,3,6],initializer = tf.truncated_normal_initializer(stddev=0.1))
  59. conv1_biases = tf.get_variable("conv1_biases",[6],initializer = tf.constant_initializer(0.0))
  60. #移动步长为1 不使用全0填充
  61. conv1 = tf.nn.conv2d(images,conv1_weights,strides=[1,1,1,1],padding='VALID')
  62. #激活函数Relu去线性化
  63. relu1 = tf.nn.relu(tf.nn.bias_add(conv1,conv1_biases))
  64.  
  65. #第二层 最大池化层 28,28,6--->1,2,2,1--->14,14,6
  66. #池化层过滤器大小为2*2 移动步长为2 使用全0填充
  67. pool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
  68.  
  69. #第三层 卷积层 14,14,6--->5,5,6,16--->10,10,16
  70. #卷积核大小为5*5 当前层深度为6 卷积核的深度为16
  71. conv2_weights = tf.get_variable("conv_weights",[5,5,6,16],initializer = tf.truncated_normal_initializer(stddev=0.1))
  72. conv2_biases = tf.get_variable("conv2_biases",[16],initializer = tf.constant_initializer(0.0))
  73.  
  74. conv2 = tf.nn.conv2d(pool1,conv2_weights,strides=[1,1,1,1],padding='VALID') #移动步长为1 不使用全0填充
  75. relu2 = tf.nn.relu(tf.nn.bias_add(conv2,conv2_biases))
  76.  
  77. #第四层 最大池化层 10,10,16--->1,2,2,1--->5,5,16
  78. #池化层过滤器大小为2*2 移动步长为2 使用全0填充
  79. pool2 = tf.nn.max_pool(relu2,ksize = [1,2,2,1],strides=[1,2,2,1],padding='SAME')
  80.  
  81. #第五层 全连接层
  82. fc1_weights = tf.get_variable("fc1_weights",[5*5*16,1024],initializer = tf.truncated_normal_initializer(stddev=0.1))
  83. fc1_biases = tf.get_variable("fc1_biases",[1024],initializer = tf.constant_initializer(0.1)) #[1,1024]
  84. pool2_vector = tf.reshape(pool2,[-1,5*5*16]) #特征向量扁平化 原始的每一张图变成了一行9×9*64列的向量
  85. fc1 = tf.nn.relu(tf.matmul(pool2_vector,fc1_weights)+fc1_biases)
  86.  
  87. #为了减少过拟合 加入dropout层
  88.  
  89. fc1_dropout = tf.nn.dropout(fc1,keep_prob)
  90.  
  91. #第六层 全连接层
  92. #神经元节点数为1024 分类节点2
  93. fc2_weights = tf.get_variable("fc2_weights",[1024,2],initializer=tf.truncated_normal_initializer(stddev=0.1))
  94. fc2_biases = tf.get_variable("fc2_biases",[2],initializer = tf.constant_initializer(0.1))
  95. fc2 = tf.matmul(fc1_dropout,fc2_weights) + fc2_biases
  96.  
  97. return fc2
  98. def lenet_loss(self,fc2,y_):
  99.  
  100. #第七层 输出层
  101. #softmax
  102. y_conv = tf.nn.softmax(fc2)
  103. labels=tf.one_hot(y_,2)
  104. #定义交叉熵损失函数
  105. #cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv),reduction_indices=[1]))
  106. loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = y_conv, labels =labels))
  107. self.cost = loss
  108. return self.cost
  109.  
  110. def lenet_optimer(self,loss):
  111. train_optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
  112. return train_optimizer
  113.  
  114. def train():
  115. image,label=read_and_decode("./train.tfrecords")
  116. batch_image,batch_label=get_batch(image,label,batch_size=30,crop_size=32)
  117. #建立网络,训练所用
  118. x = tf.placeholder("float",shape=[None,32,32,3],name='x-input')
  119. y_ = tf.placeholder("int32",shape=[None])
  120. keep_prob = tf.placeholder(tf.float32)
  121.  
  122. net=network()
  123. #inf=net.buildnet(batch_image)
  124. inf = net.lenet(x,keep_prob)
  125. loss=net.lenet_loss(inf,y_) #计算loss
  126. opti=net.optimer(loss) #梯度下降
  127.  
  128. correct_prediction = tf.equal(tf.cast(tf.argmax(inf,1),tf.int32),batch_label)
  129. accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  130.  
  131. init=tf.global_variables_initializer()
  132. with tf.Session() as session:
  133. with tf.device("/gpu:0"):
  134. session.run(init)
  135. coord = tf.train.Coordinator()
  136. threads = tf.train.start_queue_runners(coord=coord)
  137. max_iter=10000
  138. iter=0
  139. if os.path.exists(os.path.join("model",'model.ckpt')) is True:
  140. tf.train.Saver(max_to_keep=None).restore(session, os.path.join("model",'model.ckpt'))
  141. while iter<max_iter:
  142. #loss_np,_,label_np,image_np,inf_np=session.run([loss,opti,batch_image,batch_label,inf])
  143. b_batch_image,b_batch_label = session.run([batch_image,batch_label])
  144. loss_np,_=session.run([loss,opti],feed_dict={x:b_batch_image,y_:b_batch_label,keep_prob:0.6})
  145. if iter%50==0:
  146. print 'trainloss:',loss_np
  147. if iter%500==0:
  148. #accuracy_np = session.run([accuracy])
  149. accuracy_np = session.run([accuracy],feed_dict={x:b_batch_image,y_:b_batch_label,keep_prob:1.0})
  150. print 'xxxxxxxxxxxxxxxxxxxxxx',accuracy_np
  151. iter+=1
  152. coord.request_stop()#queue需要关闭,否则报错
  153. coord.join(threads)
  154. if __name__ == '__main__':
  155. train()

Tensorflow学习教程------实现lenet并且进行二分类的更多相关文章

  1. Tensorflow学习教程------普通神经网络对mnist数据集分类

    首先是不含隐层的神经网络, 输入层是784个神经元 输出层是10个神经元 代码如下 #coding:utf-8 import tensorflow as tf from tensorflow.exam ...

  2. Tensorflow学习教程------过拟合

    Tensorflow学习教程------过拟合   回归:过拟合情况 / 分类过拟合 防止过拟合的方法有三种: 1 增加数据集 2 添加正则项 3 Dropout,意思就是训练的时候隐层神经元每次随机 ...

  3. Tensorflow学习教程------代价函数

    Tensorflow学习教程------代价函数   二次代价函数(quadratic cost): 其中,C表示代价函数,x表示样本,y表示实际值,a表示输出值,n表示样本的总数.为简单起见,使用一 ...

  4. Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例

    紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取.建立网络以及模型训练整理成一个小样例,完整代码如下. #coding:utf-8 import t ...

  5. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  6. tensorflow 学习教程

    tensorflow 学习手册 tensorflow 学习手册1:https://cloud.tencent.com/developer/section/1475687 tensorflow 学习手册 ...

  7. Tensorflow学习教程------创建图启动图

    Tensorflow作为目前最热门的机器学习框架之一,受到了工业界和学界的热门追捧.以下几章教程将记录本人学习tensorflow的一些过程. 在tensorflow这个框架里,可以讲是若数据类型,也 ...

  8. Tensorflow学习教程------非线性回归

    自己搭建神经网络求解非线性回归系数 代码 #coding:utf-8 import tensorflow as tf import numpy as np import matplotlib.pypl ...

  9. Tensorflow学习教程------tensorboard网络运行和可视化

    tensorboard可以将训练过程中的一些参数可视化,比如我们最关注的loss值和accuracy值,简单来说就是把这些值的变化记录在日志里,然后将日志里的这些数据可视化. 首先运行训练代码 #co ...

随机推荐

  1. 十三、react-router 4.x的基本配置

    路由的定义及作用 根组件根据客户端不同的请求网址显示时,要卸载上一个组件,再挂载下一个组件,如果手动操作话将是一个巨大麻烦.具体过程如下图: [根组件] ↑ [首页组件] [新闻组件] [商品组件] ...

  2. Linux命令笔记一

    #查看文件大小[root@elegant-codes-3 py]# ls -lh total 1.1M -rw-r--r-- 1 root root 5.0K Feb 21 08:18 Crawl_W ...

  3. 支持 UTF-8 中文的串口调试工具

    最近使用 mdk526,编辑设置使用 utf-8,编辑窗口中文正常,但是编译的时候提示 warning: #870-D: invalid multibyte character sequence,解决 ...

  4. div自动适应高度

    div高度100%<!DOCTYPE html> <html> <head></head> <body> <div id=" ...

  5. springboot-jar-web

    预览 与springboot-jar的区别是: 1.pom.xml 将 <dependency> <groupId>org.springframework.boot</g ...

  6. 7.CSRF攻击和文件上传漏洞攻击

    一.CSRF攻击及防范措施 1.概念 请求来源于其他网站,请求并不是用户的意愿,而是伪造的请求,诱导用户发起的请求 2.场景 攻击者盗用了你的身份,以你的名义发送恶意请求.CSRF能够做的事情包括:以 ...

  7. [XNUCA2019Qualifier]EasyPHP

    0x00 知识点 预期解中知识点: htaccess生效 如果尝试上传htaccess文件会发现出现响应500的问题,因为文件尾有Just one chance 这里采用# \的方式将换行符转义成普通 ...

  8. Q4:Median of Two Sorted Arrays

    4. Median of Two Sorted Arrays 官方的链接:4. Median of Two Sorted Arrays Description : There are two sort ...

  9. INNER JOIN & OUTER JOIN

    INNER JOIN & OUTER JOIN 参考:sql

  10. rabbit-mq cluster安装

    Centos6.5 安装 RabbitMQ3.6.5 一.安装编译工具 yum -y install make gcc gcc-c++ kernel-devel m4 ncurses-devel op ...