
Basic Usage of TensorFlow

1. Verifying the TensorFlow Installation

import tensorflow as tf
a = tf.constant(10)
b = tf.constant(32)
sess = tf.Session()
print(sess.run(a + b))

42
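The snippet above targets the TensorFlow 1.x API. A minimal sketch for running it on TensorFlow 2.x, where tf.Session was removed, is to go through the tf.compat.v1 shim (using the shim at all is an assumption about your install, not part of the original post):

# Assumption: TF 2.x installed; route through the v1 compatibility shim
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

a = tf.constant(10)
b = tf.constant(32)
sess = tf.Session()
print(sess.run(a + b))  # 42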

2. Downloading the MNIST Dataset

# Download and read the MNIST dataset
# Source: https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/examples/tutorials/mnist/input_data.py
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import tensorflow.python.platform
import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    if not os.path.exists(work_directory):
        os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.request.urlretrieve(
            SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    return filepath


def _read32(bytestream):
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data


def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot


def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    print('Extracting', filename)
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels)
        return labels


class DataSet(object):

    def __init__(self, images, labels, fake_data=False, one_hot=False,
                 dtype=tf.float32):
        """Construct a DataSet.

        one_hot arg is used only if fake_data is true. `dtype` can be either
        `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
        `[0, 1]`.
        """
        dtype = tf.as_dtype(dtype).base_dtype
        if dtype not in (tf.uint8, tf.float32):
            raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                            dtype)
        if fake_data:
            self._num_examples = 10000
            self.one_hot = one_hot
        else:
            assert images.shape[0] == labels.shape[0], (
                'images.shape: %s labels.shape: %s' % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1)
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            if dtype == tf.float32:
                # Convert from [0, 255] -> [0.0, 1.0].
                images = images.astype(numpy.float32)
                images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            fake_image = [1] * 784
            if self.one_hot:
                fake_label = [1] + [0] * 9
            else:
                fake_label = 0
            return [fake_image for _ in xrange(batch_size)], [
                fake_label for _ in xrange(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
    class DataSets(object):
        pass
    data_sets = DataSets()
    if fake_data:
        def fake():
            return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
        data_sets.train = fake()
        data_sets.validation = fake()
        data_sets.test = fake()
        return data_sets
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]
    data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
    data_sets.validation = DataSet(validation_images, validation_labels,
                                   dtype=dtype)
    data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
    return data_sets
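A minimal usage sketch, assuming the module above is saved as input_data.py (the 'MNIST_data/' directory name is just an example; the shapes follow from the standard MNIST split with VALIDATION_SIZE = 5000):

from input_data import read_data_sets

mnist = read_data_sets('MNIST_data/', one_hot=True)
print(mnist.train.images.shape)       # (55000, 784)
print(mnist.train.labels.shape)       # (55000, 10)
print(mnist.validation.num_examples)  # 5000
print(mnist.test.num_examples)        # 10000
# next_batch() returns one mini-batch of images and one-hot labels
batch_xs, batch_ys = mnist.train.next_batch(100)
print(batch_xs.shape, batch_ys.shape)  # (100, 784) (100, 10)

If you keep everything in one notebook instead, call read_data_sets directly, as the later sections do.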

3. Training with TensorFlow: Softmax Regression

# Training with TensorFlow: softmax regression
import time

import tensorflow as tf

# Load the MNIST dataset, split into training and test data
# (read_data_sets is defined in Section 2)
mnist = read_data_sets('MNIST_data/', one_hot=True)
# Set up the input x, the connection weights W, and the bias b
x = tf.placeholder('float', [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# Multiply x by W, add b, and pass the result through softmax to get the output y
y = tf.nn.softmax(tf.matmul(x, W) + b)
# Set up the expected output y_
y_ = tf.placeholder('float', [None, 10])
# Cross-entropy cost function
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
# Minimize the cross-entropy cost with gradient descent
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
# Initialize all variables
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
st = time.time()
# Training loop
for i in range(1000):
    # Pick a mini-batch of training data
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # One training step
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test: check whether the actual output matches the expected output
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
softmax_time = time.time() - st
# Compute the accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
print('Accuracy: %s' % sess.run(accuracy, feed_dict={
    x: mnist.test.images, y_: mnist.test.labels}))
softmax_acc = sess.run(accuracy, feed_dict={
    x: mnist.test.images, y_: mnist.test.labels})
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Accuracy: 0.9191
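One caveat about the loss above: -tf.reduce_sum(y_ * tf.log(y)) returns NaN as soon as any softmax output underflows to exactly zero. A sketch of the numerically stable alternative, assuming you keep the pre-softmax logits around (tf.nn.softmax_cross_entropy_with_logits is a standard TF 1.x op, but it is not what this post uses, and the 0.5 learning rate is only a typical choice for the mean-reduced loss):

logits = tf.matmul(x, W) + b
y = tf.nn.softmax(logits)  # still usable for predictions
# Fused softmax + cross-entropy, averaged over the batch
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
# reduce_mean shrinks the loss by the batch size compared with reduce_sum,
# so the learning rate is usually scaled up accordingly
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)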

4. Training with TensorFlow: A Convolutional Neural Network

4.1 Building the Network Components

# Building the network components
import time

import tensorflow as tf


def weight_variable(shape):
    """Initialize connection weights."""
    # truncated_normal() draws random values with the given standard deviation
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    """Initialize biases."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    """Build a convolutional layer.

    x: input data, a 4-D tensor -- [batch size, height, width, channels]
    W: kernel parameters, a 4-D tensor -- [kernel height, kernel width,
       input channels, output channels]
    """
    # strides sets the step of the sliding kernel; strides=[1, 2, 2, 1] would be stride 2
    # padding chooses zero padding: 'SAME' pads, 'VALID' does not
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    """Build a pooling layer.

    x: input data, a 4-D tensor -- [batch size, height, width, channels]
    """
    # ksize sets the pooling window, also 4-D -- [batch size, height, width, channels]
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


# Load the MNIST dataset
mnist = read_data_sets('MNIST_data', one_hot=True)
# Input data: 2-D with shape=[batch size, feature dimension]
x = tf.placeholder('float', shape=[None, 784])
# Expected output
y_ = tf.placeholder('float', shape=[None, 10])
# Reshape the 2-D input into a 4-D tensor [-1, 28, 28, 1] (batch size * 28 * 28 * channels)
x_image = tf.reshape(x, [-1, 28, 28, 1])
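As a quick sanity check of the two helpers (a hypothetical smoke test, not part of the original tutorial): with padding='SAME' and stride 1, conv2d preserves the spatial size, while max_pool_2x2 halves it.

# w_test is an illustrative 5x5 kernel with 8 output channels
w_test = weight_variable([5, 5, 1, 8])
conv_out = conv2d(x_image, w_test)
pool_out = max_pool_2x2(conv_out)
print(conv_out.get_shape())  # (?, 28, 28, 8)
print(pool_out.get_shape())  # (?, 14, 14, 8)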

4.2 Defining the Network Structure

# Defining the network structure
# First convolutional layer: weight_variable([kernel height, kernel width,
# input channels, number of kernels])
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
# Activation and pooling
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool = max_pool_2x2(h_conv1)
# Second convolutional layer
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
# Activation and pooling
h_conv2 = tf.nn.relu(conv2d(h_pool, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
# Parameters of the first fully connected layer
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
# Fully connected layer
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# Dropout
keep_prob = tf.placeholder('float')
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# Parameters of the second fully connected layer
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
# Softmax output
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
# Error function: cross-entropy cost
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
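The 7*7*64 in W_fc1 follows from the pooling arithmetic: each max_pool_2x2 halves the spatial size (28 to 14 to 7), and the second convolution outputs 64 channels. A one-line check (illustrative only):

size = 28 // 2 // 2            # two 2x2 max-pools: 28 -> 14 -> 7
print(size, size * size * 64)  # 7 3136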

4.3 Training the Model

# Training the model
# Training method
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
# Test method
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
# Create a session for training
sess = tf.Session()
# Initialize the variables
sess.run(tf.global_variables_initializer())
st = time.time()
# Training loop
for i in range(1000):
    # Pick a mini-batch of training data
    batch = mnist.train.next_batch(50)
    # One training step
    _, loss_value = sess.run([train_step, cross_entropy], feed_dict={
        x: batch[0], y_: batch[1], keep_prob: 0.5})
    # Test
    if i % 100 == 0:
        acc = sess.run(accuracy, feed_dict={
            x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.})
        print(f'CNN accuracy after {i} iterations: {acc}')
print(f'Softmax regression training time: {softmax_time}')
print(f'CNN training time: {time.time()-st}')
# Test
acc = sess.run(accuracy, feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.})
print(f'Softmax regression accuracy: {softmax_acc}')
print(f'CNN accuracy: {acc}')
CNN accuracy after 0 iterations: 0.08910000324249268
CNN accuracy after 100 iterations: 0.8474000096321106
CNN accuracy after 200 iterations: 0.9085000157356262
CNN accuracy after 300 iterations: 0.9266999959945679
CNN accuracy after 400 iterations: 0.9399999976158142
CNN accuracy after 500 iterations: 0.9430999755859375
CNN accuracy after 600 iterations: 0.953499972820282
CNN accuracy after 700 iterations: 0.9571999907493591
CNN accuracy after 800 iterations: 0.9599999785423279
CNN accuracy after 900 iterations: 0.9613000154495239
Softmax regression training time: 2.030284881591797
CNN training time: 394.48987913131714
Softmax regression accuracy: 0.9190999865531921
CNN accuracy: 0.9670000076293945
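Feeding the full 10,000-image test set in one feed_dict, as above, is fine on CPU but can exhaust GPU memory on larger models. A hedged sketch of evaluating in chunks instead (the batch size of 1000 is an illustrative choice, not from the original):

# Batched evaluation sketch; assumes sess, accuracy, x, y_, keep_prob from above
eval_batch_size = 1000
total = mnist.test.num_examples
acc_sum = 0.0
for start in range(0, total, eval_batch_size):
    end = min(start + eval_batch_size, total)
    # Weight each chunk's accuracy by its size, then average at the end
    acc_sum += (end - start) * sess.run(accuracy, feed_dict={
        x: mnist.test.images[start:end],
        y_: mnist.test.labels[start:end],
        keep_prob: 1.})
print('Batched test accuracy:', acc_sum / total)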

5. Visualization with TensorFlow

# Visualization with TensorFlow
# Reuses the MNIST loading code from Section 2 (maybe_download, extract_images,
# extract_labels, DataSet, read_data_sets) and the layer helpers from
# Section 4.1 (weight_variable, bias_variable, conv2d, max_pool_2x2)
import time

import tensorflow as tf

# Load the MNIST dataset
mnist = read_data_sets('MNIST_data', one_hot=True)
# Build a computation graph via as_default()
with tf.Graph().as_default():
    # Set up the input data and the expected output, named for TensorBoard
    x = tf.placeholder('float', shape=[None, 784], name='Input')
    y_ = tf.placeholder('float', shape=[None, 10], name='GroundTruth')
    # Reshape the 2-D input into a 4-D tensor [-1, 28, 28, 1] (batch size * 28 * 28 * channels)
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    # First convolutional layer: weight_variable([kernel height, kernel width,
    # input channels, number of kernels])
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    # Activation and pooling
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool = max_pool_2x2(h_conv1)
    # Second convolutional layer
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    # Activation and pooling
    h_conv2 = tf.nn.relu(conv2d(h_pool, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
    # Parameters of the first fully connected layer
    W_fc1 = weight_variable([7*7*64, 1024])
    b_fc1 = bias_variable([1024])
    # Fully connected layer
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
    # Dropout
    keep_prob = tf.placeholder('float')
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
    # Parameters of the second fully connected layer
    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    # Softmax output, now wrapped in a name scope for TensorBoard
    with tf.name_scope('Output') as scope:
        y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
    # Error function: cross-entropy cost
    with tf.name_scope('xentropy') as scope:
        cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
        # tf.summary.scalar() logs the training progress
        ce_summ = tf.summary.scalar('cross_entropy', cross_entropy)
    # Training method
    with tf.name_scope('train') as scope:
        train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
    # Test method
    with tf.name_scope('test') as scope:
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
        accuracy_summary = tf.summary.scalar('accuracy', accuracy)
    # Create a session for training
    sess = tf.Session()
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    # Summary output setup (new in this section)
    # Merge all the summary ops defined above into a single op
    summary_op = tf.summary.merge_all()
    # tf.summary.FileWriter() saves the training data; graph_def is the graph (network structure)
    summary_writer = tf.summary.FileWriter('MNIST_data', graph_def=sess.graph_def)
    st = time.time()
    # Training loop
    for i in range(1000):
        # Pick a mini-batch of training data
        batch = mnist.train.next_batch(50)
        # One training step
        _, loss_value = sess.run([train_step, cross_entropy], feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 0.5})
        # Test
        if i % 100 == 0:
            # summary_op logs the training data; accuracy runs the test
            result = sess.run([summary_op, accuracy], feed_dict={
                x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.})
            # The serialized summaries
            summary_str = result[0]
            # The accuracy value
            acc = result[1]
            # add_summary() writes out the contents of summary_str
            summary_writer.add_summary(summary_str, i)
            print(f'CNN accuracy after {i} iterations: {acc}')
    print(f'CNN training time: {time.time()-st}')
    # Test
    acc = sess.run(accuracy, feed_dict={
        x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.})
    print(f'CNN accuracy: {acc}')
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:Passing a `GraphDef` to the SummaryWriter is deprecated. Pass a `Graph` object instead, such as `sess.graph`.
CNN accuracy after 0 iterations: 0.11810000240802765
CNN accuracy after 100 iterations: 0.8456000089645386
CNN accuracy after 200 iterations: 0.9088000059127808
CNN accuracy after 300 iterations: 0.9273999929428101
CNN accuracy after 400 iterations: 0.935699999332428
CNN accuracy after 500 iterations: 0.9404000043869019
CNN accuracy after 600 iterations: 0.9490000009536743
CNN accuracy after 700 iterations: 0.951200008392334
CNN accuracy after 800 iterations: 0.95660001039505
CNN accuracy after 900 iterations: 0.9592999815940857
CNN training time: 374.29131293296814
CNN accuracy: 0.963699996471405

Run in a terminal: tensorboard --logdir ~/Desktop/jupyter/deepLearning/图解深度学习-tensorflow/MNIST_data
Starting TensorBoard on port 6006

  • --logdir must be given the full path to the log directory
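Scalars are not the only summary type TensorBoard can show. A minimal sketch of also recording weight histograms, assuming the lines are added inside the same graph before summary_op = tf.summary.merge_all() (attaching them to W_conv1/b_conv1 is just an illustration):

# Histogram summaries for the first conv layer; merge_all() will pick these up
w_hist = tf.summary.histogram('W_conv1', W_conv1)
b_hist = tf.summary.histogram('b_conv1', b_conv1)
# TensorBoard then shows them under the Histograms / Distributions tabs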
