很多正在入门或刚入门TensorFlow机器学习的同学希望能够通过自己指定图片源对模型进行训练,然后识别和分类自己指定的图片。但是,在TensorFlow官方入门教程中,并无明确给出如何把自定义数据输入训练模型的方法。现在,我们就参考官方入门课程《Deep MNIST for Experts》一节的内容(传送门:https://www.tensorflow.org/get_started/mnist/pros),介绍如何将自定义图片输入到TensorFlow的训练模型。

在《Deep MNISTfor Experts》一节的代码中,程序将TensorFlow自带的mnist图片数据集mnist.train.images作为训练输入,将mnist.test.images作为验证输入。当学习了该节内容后,我们会惊叹卷积神经网络的超高识别率,但对于刚开始学习TensorFlow的同学,内心可能会产生一个问号:如何将mnist数据集替换为自己指定的图片源?譬如,我要将图片源改为自己C盘里面的图片,应该怎么调整代码?

我们先看下该节课程中涉及到mnist图片调用的代码:

  1. from tensorflow.examples.tutorials.mnist import input_data
  2. mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
  3. batch = mnist.train.next_batch(50)
  4. train_accuracy = accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
  5. train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
  6. print('test accuracy %g' % accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

对于刚接触TensorFlow的同学,要修改上述代码,可能会较为吃力。我也是经过一番摸索,才成功调用自己的图片集。

要实现输入自定义图片,需要自己先准备好一套图片集。为节省时间,我们把mnist的手写体数字集一张一张地解析出来,存放到自己的本地硬盘,保存为bmp格式,然后再把本地硬盘的手写体图片一张一张地读取出来,组成集合,再输入神经网络。mnist手写体数字集的提取方式详见《如何从TensorFlow的mnist数据集导出手写体数字图片》。

将mnist手写体数字集导出图片到本地后,就可以仿照以下python代码,实现自定义图片的训练:

  1. #!/usr/bin/python3.5
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import numpy as np
  5. import tensorflow as tf
  6. from PIL import Image
  7. # 第一次遍历图片目录是为了获取图片总数
  8. input_count = 0
  9. for i in range(0,10):
  10. dir = './custom_images/%s/' % i                 # 这里可以改成你自己的图片目录,i为分类标签
  11. for rt, dirs, files in os.walk(dir):
  12. for filename in files:
  13. input_count += 1
  14. # 定义对应维数和各维长度的数组
  15. input_images = np.array([[0]*784 for i in range(input_count)])
  16. input_labels = np.array([[0]*10 for i in range(input_count)])
  17. # 第二次遍历图片目录是为了生成图片数据和标签
  18. index = 0
  19. for i in range(0,10):
  20. dir = './custom_images/%s/' % i                 # 这里可以改成你自己的图片目录,i为分类标签
  21. for rt, dirs, files in os.walk(dir):
  22. for filename in files:
  23. filename = dir + filename
  24. img = Image.open(filename)
  25. width = img.size[0]
  26. height = img.size[1]
  27. for h in range(0, height):
  28. for w in range(0, width):
  29. # 通过这样的处理,使数字的线条变细,有利于提高识别准确率
  30. if img.getpixel((w, h)) > 230:
  31. input_images[index][w+h*width] = 0
  32. else:
  33. input_images[index][w+h*width] = 1
  34. input_labels[index][i] = 1
  35. index += 1
  36. # 定义输入节点,对应于图片像素值矩阵集合和图片标签(即所代表的数字)
  37. x = tf.placeholder(tf.float32, shape=[None, 784])
  38. y_ = tf.placeholder(tf.float32, shape=[None, 10])
  39. x_image = tf.reshape(x, [-1, 28, 28, 1])
  40. # 定义第一个卷积层的variables和ops
  41. W_conv1 = tf.Variable(tf.truncated_normal([7, 7, 1, 32], stddev=0.1))
  42. b_conv1 = tf.Variable(tf.constant(0.1, shape=[32]))
  43. L1_conv = tf.nn.conv2d(x_image, W_conv1, strides=[1, 1, 1, 1], padding='SAME')
  44. L1_relu = tf.nn.relu(L1_conv + b_conv1)
  45. L1_pool = tf.nn.max_pool(L1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  46. # 定义第二个卷积层的variables和ops
  47. W_conv2 = tf.Variable(tf.truncated_normal([3, 3, 32, 64], stddev=0.1))
  48. b_conv2 = tf.Variable(tf.constant(0.1, shape=[64]))
  49. L2_conv = tf.nn.conv2d(L1_pool, W_conv2, strides=[1, 1, 1, 1], padding='SAME')
  50. L2_relu = tf.nn.relu(L2_conv + b_conv2)
  51. L2_pool = tf.nn.max_pool(L2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
  52. # 全连接层
  53. W_fc1 = tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1))
  54. b_fc1 = tf.Variable(tf.constant(0.1, shape=[1024]))
  55. h_pool2_flat = tf.reshape(L2_pool, [-1, 7*7*64])
  56. h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
  57. # dropout
  58. keep_prob = tf.placeholder(tf.float32)
  59. h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
  60. # readout层
  61. W_fc2 = tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1))
  62. b_fc2 = tf.Variable(tf.constant(0.1, shape=[10]))
  63. y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
  64. # 定义优化器和训练op
  65. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
  66. train_step = tf.train.AdamOptimizer((1e-4)).minimize(cross_entropy)
  67. correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
  68. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  69. with tf.Session() as sess:
  70. sess.run(tf.global_variables_initializer())
  71. print ("一共读取了 %s 个输入图像, %s 个标签" % (input_count, input_count))
  72. # 设置每次训练op的输入个数和迭代次数,这里为了支持任意图片总数,定义了一个余数remainder,譬如,如果每次训练op的输入个数为60,图片总数为150张,则前面两次各输入60张,最后一次输入30张(余数30)
  73. batch_size = 60
  74. iterations = 100
  75. batches_count = int(input_count / batch_size)
  76. remainder = input_count % batch_size
  77. print ("数据集分成 %s 批, 前面每批 %s 个数据,最后一批 %s 个数据" % (batches_count+1, batch_size, remainder))
  78. # 执行训练迭代
  79. for it in range(iterations):
  80. # 这里的关键是要把输入数组转为np.array
  81. for n in range(batches_count):
  82. train_step.run(feed_dict={x: input_images[n*batch_size:(n+1)*batch_size], y_: input_labels[n*batch_size:(n+1)*batch_size], keep_prob: 0.5})
  83. if remainder > 0:
  84. start_index = batches_count * batch_size;
  85. train_step.run(feed_dict={x: input_images[start_index:input_count-1], y_: input_labels[start_index:input_count-1], keep_prob: 0.5})
  86. # 每完成五次迭代,判断准确度是否已达到100%,达到则退出迭代循环
  87. iterate_accuracy = 0
  88. if it%5 == 0:
  89. iterate_accuracy = accuracy.eval(feed_dict={x: input_images, y_: input_labels, keep_prob: 1.0})
  90. print ('iteration %d: accuracy %s' % (it, iterate_accuracy))
  91. if iterate_accuracy >= 1:
  92. break;
  93. print ('完成训练!')

上述python代码的执行结果截图如下:

对于上述代码中与模型构建相关的代码,请查阅官方《Deep MNIST for Experts》一节的内容进行理解。在本文中,需要重点掌握的是如何将本地图片源整合成为feed_dict可接受的格式。其中最关键的是这两行:

  1. # 定义对应维数和各维长度的数组
  2. input_images = np.array([[0]*784 for i in range(input_count)])
  3. input_labels = np.array([[0]*10 for i in range(input_count)])

它们对应于feed_dict的两个placeholder:

  1. x = tf.placeholder(tf.float32, shape=[None, 784])
  2. y_ = tf.placeholder(tf.float32, shape=[None, 10])

利用Tensorflow训练自定义数据的更多相关文章

  1. 利用tensorflow训练简单的生成对抗网络GAN

    对抗网络是14年Goodfellow Ian在论文Generative Adversarial Nets中提出来的. 原理方面,对抗网络可以简单归纳为一个生成器(generator)和一个判断器(di ...

  2. TF:Tensorflow结构简单应用,随机生成100个数,利用Tensorflow训练使其逼近已知线性直线的效率和截距—Jason niu

    import os os.environ[' import tensorflow as tf import numpy as np x_data = np.random.rand(100).astyp ...

  3. TensorFlow.训练_资料(有视频)

    ZC:自己训练 的文章 貌似 能度娘出来很多,得 自己弄过才知道哪些个是坑 哪些个好用...(在CSDN文章的右侧 也有列出很多相关的文章链接)(貌似 度娘的关键字是"TensorFlow ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  6. TensorFlow下利用MNIST训练模型识别手写数字

    本文将参考TensorFlow中文社区官方文档使用mnist数据集训练一个多层卷积神经网络(LeNet5网络),并利用所训练的模型识别自己手写数字. 训练MNIST数据集,并保存训练模型 # Pyth ...

  7. 【实践】如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统)

    如何利用tensorflow的object_detection api开源框架训练基于自己数据集的模型(Windows10系统) 一.环境配置 1. Python3.7.x(注:我用的是3.7.3.安 ...

  8. 目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练

    将目标检测 的标注数据 .xml 转为 tfrecord 的格式用于 TensorFlow 训练. import xml.etree.ElementTree as ET import numpy as ...

  9. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

随机推荐

  1. 在windows下怎样更新vundle?

    本文出自Svitter的blog 更新Vundle的时候.不管是输出BundleInstall.还是PluginInstall! 都会调用系统的git,所以必须安装git才干达到目的更新插件. git ...

  2. iOS项目开发实战——使用Xcode6设计自己定义控件与图形

    在iOS开发中,有很多控件都是Xcode默认提供的.使用这些控件是很方便的.可是因为某些须要.须要自己设计控件,那么应该怎么做呢?在Xcode6中提供了这种接口,同意开发人员高速开发自己定义控件,而且 ...

  3. Android开发策略:缓存

    1.使用缓存策略时,优先考虑使用sdcard(需先推断有无sd卡及其剩余空间是否足够,够的话就开辟一定空间如10M): 2.获取图片时.先从sdcard上找,有的话使用该图片并更新图片最后被使用的时间 ...

  4. bzoj4240: 有趣的家庭菜园(树状数组+贪心思想)

    4240: 有趣的家庭菜园 题目:传送门 题解: 好题!%%% 一开始不知道在想什么鬼,感觉满足二分性?感觉可以维护一个先单调增再单调减的序列? 然后开始一顿瞎搞...一WA 看一波路牌...树状数组 ...

  5. Spark新愿景:让深度学习变得更加易于使用——见https://github.com/yahoo/TensorFlowOnSpark

    Spark新愿景:让深度学习变得更加易于使用   转自:https://www.jianshu.com/p/07e8200b7cea 前言 Spark成功的实现了当年的承诺,让数据处理变得更容易,现在 ...

  6. Oracle 10G 中的"回收站"

    在Oracle 10g数据库中,引入了一个回收站(Recycle Bin)的数据库对象. 回收站,从原理上来说就是一个数据字典表,放置用户Drop掉的数据库对象信息.用户进行Drop操作的对象并没有被 ...

  7. 如何正确产看API

    看API时,先看的它的父接口自接口,及其相关的抽象类和子类 看完后,看概述的第一段话就行,后面的不用看. 再看构造方法,并到底层去看构造方法里参数的具体含义. 最后,再将包含的方法一个个进行测试. 解 ...

  8. WinForm进程 线程

    进程主要调用另一程序,线程 分配工作. 一.进程: 进程是一个具有独立功能的程序关于某个数据集合的一次运行活动.它可以申请和拥有系统资源,是一个动态的概念,是一个活动的实体.Process 类,用来操 ...

  9. LUA 创建文件和文件夹

    创建文件: os.execute('mkdir e:\\aa') 创建文件夹: os.execute("cd.>e:\\wang.ini")

  10. 第5章分布式系统模式 使用客户端激活对象通过 .NET Remoting 实现 Broker

    正在 .NET 中构建一个需要使用分布式对象的应用程序,并且分布式对象的生存期由客户端控制.您的要求包括能够按值或按引用来传递对象,无论这些对象驻留在同一台计算 机上,还是驻留在同一个局域网 (LAN ...