紧接上篇Tensorflow学习教程------tfrecords数据格式生成与读取,本篇将数据读取、建立网络以及模型训练整理成一个小样例,完整代码如下。

  1. #coding:utf-8
  2. import tensorflow as tf
  3. import os
  4. def read_and_decode(filename):
  5. #根据文件名生成一个队列
  6. filename_queue = tf.train.string_input_producer([filename])
  7. reader = tf.TFRecordReader()
  8. _, serialized_example = reader.read(filename_queue) #返回文件名和文件
  9. features = tf.parse_single_example(serialized_example,
  10. features={
  11. 'label': tf.FixedLenFeature([], tf.int64),
  12. 'img_raw' : tf.FixedLenFeature([], tf.string),
  13. })
  14.  
  15. img = tf.decode_raw(features['img_raw'], tf.uint8)
  16. img = tf.reshape(img, [227, 227, 3])
  17. img = (tf.cast(img, tf.float32) * (1. / 255) - 0.5)*2
  18. label = tf.cast(features['label'], tf.int32)
  19. print img,label
  20. return img, label
  21.  
  22. def get_batch(image, label, batch_size,crop_size):
  23. #数据扩充变换
  24. distorted_image = tf.random_crop(image, [crop_size, crop_size, 3])#随机裁剪
  25. distorted_image = tf.image.random_flip_up_down(distorted_image)#上下随机翻转
  26. distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)#亮度变化
  27. distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)#对比度变化
  28.  
  29. #生成batch
  30. #shuffle_batch的参数:capacity用于定义shuttle的范围,如果是对整个训练数据集,获取batch,那么capacity就应该够大
  31. #保证数据打的足够乱
  32. images, label_batch = tf.train.shuffle_batch([distorted_image, label],batch_size=batch_size,
  33. num_threads=1,capacity=2000,min_after_dequeue=1000)
  34.  
  35. return images, label_batch
  36.  
  37. class network(object):
  38. #构造函数初始化 卷积层 全连接层
  39. def __init__(self):
  40. with tf.variable_scope("weights"):
  41. self.weights={
  42.  
  43. 'conv1':tf.get_variable('conv1',[4,4,3,20],initializer=tf.contrib.layers.xavier_initializer_conv2d()),
  44.  
  45. 'conv2':tf.get_variable('conv2',[3,3,20,40],initializer=tf.contrib.layers.xavier_initializer_conv2d()),
  46.  
  47. 'conv3':tf.get_variable('conv3',[3,3,40,60],initializer=tf.contrib.layers.xavier_initializer_conv2d()),
  48.  
  49. 'fc1':tf.get_variable('fc1',[3*3*60,120],initializer=tf.contrib.layers.xavier_initializer()),
  50. 'fc2':tf.get_variable('fc2',[120,2],initializer=tf.contrib.layers.xavier_initializer()),
  51.  
  52. }
  53. with tf.variable_scope("biases"):
  54. self.biases={
  55. 'conv1':tf.get_variable('conv1',[20,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
  56. 'conv2':tf.get_variable('conv2',[40,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
  57. 'conv3':tf.get_variable('conv3',[60,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
  58.  
  59. 'fc1':tf.get_variable('fc1',[120,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
  60. 'fc2':tf.get_variable('fc2',[2,],initializer=tf.constant_initializer(value=0.0, dtype=tf.float32)),
  61.  
  62. }
  63.  
  64. def buildnet(self,images):
  65. #向量转为矩阵
  66. images = tf.reshape(images, shape=[-1, 39,39, 3])# [batch, in_height, in_width, in_channels]
  67. images=(tf.cast(images,tf.float32)/255.-0.5)*2#归一化处理
  68.  
  69. #第一层
  70. conv1=tf.nn.bias_add(tf.nn.conv2d(images, self.weights['conv1'], strides=[1, 1, 1, 1], padding='SAME'),
  71. self.biases['conv1'])
  72. relu1= tf.nn.relu(conv1)
  73. pool1=tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
  74.  
  75. #第二层
  76. conv2=tf.nn.bias_add(tf.nn.conv2d(pool1, self.weights['conv2'], strides=[1, 1, 1, 1], padding='VALID'),
  77. self.biases['conv2'])
  78. relu2= tf.nn.relu(conv2)
  79. pool2=tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
  80.  
  81. # 第三层
  82. conv3=tf.nn.bias_add(tf.nn.conv2d(pool2, self.weights['conv3'], strides=[1, 1, 1, 1], padding='VALID'),
  83. self.biases['conv3'])
  84. relu3= tf.nn.relu(conv3)
  85. pool3=tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
  86.  
  87. # 全连接层1,先把特征图转为向量
  88. flatten = tf.reshape(pool3, [-1, self.weights['fc1'].get_shape().as_list()[0]])
  89. drop1=tf.nn.dropout(flatten,0.5)
  90. fc1=tf.matmul(drop1, self.weights['fc1'])+self.biases['fc1']
  91. fc_relu1=tf.nn.relu(fc1)
  92. fc2=tf.matmul(fc_relu1, self.weights['fc2'])+self.biases['fc2']
  93. return fc2
  94.  
  95. #计算softmax交叉熵损失函数
  96. def softmax_loss(self,predicts,labels):
  97. predicts=tf.nn.softmax(predicts)
  98. labels=tf.one_hot(labels,self.weights['fc2'].get_shape().as_list()[1])
  99. loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = predicts, labels =labels))
  100. self.cost= loss
  101. return self.cost
  102. #梯度下降
  103. def optimer(self,loss,lr=0.01):
  104. train_optimizer = tf.train.GradientDescentOptimizer(lr).minimize(loss)
  105.  
  106. return train_optimizer
  107.  
  108. def train():
  109. image,label=read_and_decode("./train.tfrecords")
  110. batch_image,batch_label=get_batch(image,label,batch_size=30,crop_size=39)
  111. #建立网络,训练所用
  112. net=network()
  113. inf=net.buildnet(batch_image)
  114. loss=net.softmax_loss(inf,batch_label) #计算loss
  115. opti=net.optimer(loss) #梯度下降
  116.  
  117. init=tf.global_variables_initializer()
  118. with tf.Session() as session:
  119. with tf.device("/gpu:0"):
  120. session.run(init)
  121. coord = tf.train.Coordinator()
  122. threads = tf.train.start_queue_runners(coord=coord)
  123. max_iter=1000
  124. iter=0
  125. if os.path.exists(os.path.join("model",'model.ckpt')) is True:
  126. tf.train.Saver(max_to_keep=None).restore(session, os.path.join("model",'model.ckpt'))
  127. while iter<max_iter:
  128. loss_np,_,label_np,image_np,inf_np=session.run([loss,opti,batch_image,batch_label,inf])
  129. if iter%50==0:
  130. print 'trainloss:',loss_np
  131. iter+=1
  132. coord.request_stop()#queue需要关闭,否则报错
  133. coord.join(threads)
  134. if __name__ == '__main__':
  135. train()

结果如下:

  1. Total memory: 10.91GiB
  2. Free memory: 10.16GiB
  3. 2018-02-02 10:13:24.462286: I tensorflow/core/common_runtime/gpu/gpu_device.cc:961] DMA: 0
  4. 2018-02-02 10:13:24.462294: I tensorflow/core/common_runtime/gpu/gpu_device.cc:971] 0: Y
  5. 2018-02-02 10:13:24.462303: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Creating TensorFlow device (/gpu:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:03:00.0)
  6. trainloss: 0.745739
  7. trainloss: 0.330364
  8. trainloss: 0.317668
  9. trainloss: 0.314964
  10. trainloss: 0.314613
  11. trainloss: 0.314483
  12. trainloss: 0.314132
  13. trainloss: 0.313661

Tensorflow学习教程------读取数据、建立网络、训练模型,小巧而完整的代码示例的更多相关文章

  1. Tensorflow学习教程------过拟合

    Tensorflow学习教程------过拟合   回归:过拟合情况 / 分类过拟合 防止过拟合的方法有三种: 1 增加数据集 2 添加正则项 3 Dropout,意思就是训练的时候隐层神经元每次随机 ...

  2. ASP.NET MVC 5 学习教程:数据迁移之添加字段

    原文 ASP.NET MVC 5 学习教程:数据迁移之添加字段 起飞网 ASP.NET MVC 5 学习教程目录: 添加控制器 添加视图 修改视图和布局页 控制器传递数据给视图 添加模型 创建连接字符 ...

  3. Tensorflow学习教程------代价函数

    Tensorflow学习教程------代价函数   二次代价函数(quadratic cost): 其中,C表示代价函数,x表示样本,y表示实际值,a表示输出值,n表示样本的总数.为简单起见,使用一 ...

  4. TensorFlow queue多线程读取数据

    一.tensorflow读取机制图解 我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率. 解决 ...

  5. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  6. Tensorflow机器学习入门——读取数据

    TensorFlow 中可以通过三种方式读取数据: 一.通过feed_dict传递数据: input1 = tf.placeholder(tf.float32) input2 = tf.placeho ...

  7. Tensorflow学习教程------实现lenet并且进行二分类

    #coding:utf-8 import tensorflow as tf import os def read_and_decode(filename): #根据文件名生成一个队列 filename ...

  8. TensorFlow从0到1之TensorFlow csv文件读取数据(14)

    大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 TensorFlow 中读取文件,本节将重点介绍如何从 CSV ...

  9. tensorflow 学习教程

    tensorflow 学习手册 tensorflow 学习手册1:https://cloud.tencent.com/developer/section/1475687 tensorflow 学习手册 ...

随机推荐

  1. oracle 使用触发器实现id自增

    前提:存在数据库di_test,主键为id.1.创建一个索引sequence create sequence di_test_id minvalue 1 nomaxvalue start with 1 ...

  2. linux之 文本编辑 的基础知识点

    第一步 打开终端 创建文件命令 touch 文件名.后缀名 打开文件命令 vi 文件名.后缀名 (此时进去txt文件之后为一般模式,你无法对文件进行增删改) 之后按 i    或 a    或o  都 ...

  3. SHELL学习笔记三

    SHELL学习笔记一 SHELL学习笔记二 SHELL学习笔记三 for 命令 读取列表中的复杂值 从变量读取列表 从命令读取值 更改字段分隔符 用通配符读取目录 which 使用多个测试命令 unt ...

  4. 超低功耗Sub-1GHz性价比首选方案:CMT2300

    关于超低功耗Sub-1GHz射频收发器,目前性价比方面CMT2300是一款大多客户的首选方案,不管是成本方面还是性能方面,都能大大的满足客户的需求.下面为大家讲解下CMT2300 这款Sub-1GHz ...

  5. 使用NtQueryInformationFile函数获得不到完整路径

    #include <windows.h> #include <iostream> using namespace std; typedef struct _OBJECT_NAM ...

  6. code force 1228C

    算是一题普通数论+思维题吧. 大概很多人是被题意绕晕了. 思路: 首先常规操作求出X的质因子. 然后题目要求的是,X的每个质因子p,在g(i,p)的连乘.i∈[1,n]: 我们转换下思维,不求每一个g ...

  7. Vue-router(3)之 router-link 和 router-view 使用

    router 导入 import Vue from 'vue' import Router from 'vue-router' import order from '@/view/New/order. ...

  8. 代码杂谈-python函数

    发现函数可以设置属性变量, 如下 newfunc.func , newfunc.args def partial(func, *args, **keywords): """ ...

  9. C# 创建Windows服务。服务功能:定时操作数据库

      一.创建window服务 1.新建项目-->选择Windows服务.默认生成文件包括Program.cs,Service1.cs 2.在Service1.cs添加如下代码: System.T ...

  10. bootstrap 网格

    实现原理 网格系统的实现原理非常简单,仅仅是通过定义容器大小,平分12份(也有平分成24份或32份,但12份是最常见的),再调整内外边距,最后结合媒体查询,就制作出了强大的响应式网格系统.Bootst ...