这次的mnist学习加入了测试集,看看学习的准确率,代码如下

# encoding: utf-8

import tensorflow as tf
import matplotlib.pyplot as plt #加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28)
path = r'G:\2019\python\mnist.npz'
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path) #第一层输入256, 第二次输出128, 第三层输出10
#第一,二,三层参数w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1)) #正态分布的一种
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) #两种数据预处理的方法
#(一)预处理训练数据
x = tf.convert_to_tensor(x_train, dtype = tf.float32)/255. #0:1 ; -1:1(不适合训练,准确度不高)
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
#将60000组训练数据切分为600组,每组100个数据
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(60000) #尽量与样本空间一样大
train_db = train_db.batch(100) # #(二)自定义预处理测试函数
def preprocess(x, y):
x = tf.cast(x, dtype=tf.float32) / 255. #先将类型转化为float32,再归一到0-1
x = tf.reshape(x, [-1, 28*28]) #不知道x数量,用-1代替,转化为一维784个数据
y = tf.cast(y, dtype=tf.int32) #转化为整型32
y = tf.one_hot(y, depth=10) #训练数据所需的one-hot编码
return x, y #将10000组测试数据预处理
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(10000)
test_db = test_db.batch(100) #
test_db = test_db.map(preprocess) lr = 0.001 #学习率
losses = [] #储存每epoch的loss值,便于观察学习情况
acc = [] #准确率 for epoch in range(30): #
#一次性处理100组(x, y)数据
for step, (x, y) in enumerate(train_db): #遍历切分好的数据step:0->599
with tf.GradientTape() as tape:
#向前传播第一,二,三层
h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256]) #可以直接写成 +b1
h1 = tf.nn.relu(h1)
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
out = h2@w3 + b3 #计算mse
loss = tf.square(y - out)
loss = tf.reduce_mean(loss)
#计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
#更新w,b
w1.assign_sub(lr*grads[0]) #原地减去给定的值,实现参数的自我更新
b1.assign_sub(lr*grads[1])
w2.assign_sub(lr*grads[2])
b2.assign_sub(lr*grads[3])
w3.assign_sub(lr*grads[4])
b3.assign_sub(lr*grads[5])
#观察学习情况
if step%100 == 0:
print('训练第 ',epoch,'轮',', 第',step,'步, ','loss:', float(loss))
losses.append(float(loss)) #将每100step后的loss情况储存起来,最后观察 if step%500 == 0:
total, total_correct = 0., 0.
for x, y in test_db:
h1 = x @ w1 + b1
h1 = tf.nn.relu(h1)
h2 = h1 @ w2 + b2
h2 = tf.nn.relu(h2)
out = h2 @ w3 + b3 pred = tf.argmax(out, axis=1) # 选取概率最大的类别
y = tf.argmax(y, axis=1) # 类似于one-hot逆编码
correct = tf.equal(pred, y) # 比较真实值和预测值是否相等
total += x.shape[0]
# 统计正确的个数
total_correct += tf.reduce_sum(tf.cast(correct, dtype=tf.int32)).numpy()
print('训练第 ',epoch,'轮',', 第',step,'步, ', 'Evaluate Acc:', total_correct/total)
acc.append(total_correct/total) #plt.subplot(121)
x1 = [i*100 for i in range(len(losses))]
plt.plot(x1, losses, marker='s', label='training')
plt.xlabel('Step')
plt.ylabel('MSE')
plt.legend()
#plt.savefig('exam_mnist_forward.png')
#plt.show() #plt.subplot(122)
plt.figure()
x2 = [i for i in range(len(acc))]
plt.plot(x2, acc, 'r',marker='d', label='testing')
plt.xlabel('Step')
plt.ylabel('Accuracy')
plt.legend()
#plt.savefig('test_mnist_forward.png')
plt.show()

误差何准确率如下

发现和书中类似,但要注意的如下:

(1)数据预处理时,打散值选择和数据空间一样大;

(2)数据处理选择0-1之间,而不用(-1 :1),是因为后者学习效率不理想!

(3)代码还可以进行优化处理!

总的来说,代码还是容易理解,使用也更加简洁!

下一次更新,全连接网络,关于汽车油耗的预测。

tensorflow 2.0 学习(四)的更多相关文章

  1. tensorflow 1.0 学习:用CNN进行图像分类

    tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化. 任务:花卉分类 版本:tensorflow 1 ...

  2. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  3. tensorflow 1.0 学习:参数和特征的提取

    在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如: #取出所有参与训练的参数 params=tf.trainable_variables() print(&q ...

  4. tensorflow 1.0 学习:参数初始化(initializer)

    CNN中最重要的就是参数了,包括W,b. 我们训练CNN的最终目的就是得到最好的参数,使得目标函数取得最小值.参数的初始化也同样重要,因此微调受到很多人的重视,那么tf提供了哪些初始化参数的方法呢,我 ...

  5. Tensorflow 2.0 学习资源

    我从换了新工作才开始学习使用Tensorflow,感觉实在太难用了,sess和graph对 新手很不友好,各种API混乱不堪,这些在tf2.0都有了重大改变,2.0大量使用keras的 api,初步使 ...

  6. tensorflow 1.0 学习:用别人训练好的模型来进行图像分类

    谷歌在大型图像数据库ImageNet上训练好了一个Inception-v3模型,这个模型我们可以直接用来进来图像分类. 下载地址:https://storage.googleapis.com/down ...

  7. tensorflow 1.0 学习:模型的保存与恢复(Saver)

    将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf. ...

  8. tensorflow 1.0 学习:池化层(pooling)和全连接层(dense)

    池化层定义在 tensorflow/python/layers/pooling.py. 有最大值池化和均值池化. 1.tf.layers.max_pooling2d max_pooling2d( in ...

  9. tensorflow 1.0 学习:卷积层

    在tf1.0中,对卷积层重新进行了封装,比原来版本的卷积层有了很大的简化. 一.旧版本(1.0以下)的卷积函数:tf.nn.conv2d conv2d( input, filter, strides, ...

随机推荐

  1. day12——生成器、推导式、简单内置函数

    day12 生成器 迭代器:python中内置的一种节省空间的工具 生成器的本质就是一个迭代器 迭代器和生成器的区别:一个是pyhton自带的,一个是程序员自己写的 写一个生成器 基于函数 在函数中将 ...

  2. centos7中mysql的rpm包安装

    解决依赖 yum remove mysql-libs 执行命令:yum -y install autoconf 安装依赖 yum -y install autoconf 安装mysql rpm -iv ...

  3. (十二)一个简单的pdf文件体

    %PDF-1.0                     % 文件头,说明符合PDF1.0规范 1 0 obj                          %对象号     产生号(修改次数)  ...

  4. git 学习笔记 --多人协作

    当你从远程仓库克隆时,实际上Git自动把本地的master分支和远程的master分支对应起来了,并且,远程仓库的默认名称是origin. 要查看远程库的信息,用git remote: $ git r ...

  5. 一个Java程序员该有的良好品质

    一.前言 多年来,在IT领域,从一个普通的程序员到一个技术主管,再到一个技术经理,再到一个技术主管,他们践踏了许多坑,劳累了许多课程,还背着许多罐子.在提高他们的技术和管理能力的同时,他们一直在考虑如 ...

  6. yii框架里DetailView视图和GridView的区别

    1,首先从语义上分析 DetailView是数据视图,用于显示一条记录的数据,相当于网页中的详情页 GridView是网格视图,用于显示数据表里的所有记录,相当于网页里的列表页 2.用法上的区别 首先 ...

  7. java之mybatis之配置文件讲解

    1.核心配置文件 <configuration> <!-- 它们都是外部化,可替代的属性.可以配置在一个典型的Java 属性文件中,或者通过 properties 元素的子元素进行配 ...

  8. Java自学-数组 初始化数组

    Java 如何初始化数组 步骤 1 : 分配空间与赋值分步进行 public class HelloWorld { public static void main(String[] args) { i ...

  9. windows nvlddmkm、DRIVER_POWER_STATE_FAILURE 蓝屏问题的解决资料

    背景与现象描述 博主在最近购买了 机械革命 Z2-R (MECHREVO Z2-R Series GK5CP02) 笔记本电脑后,几乎每天均有不下3次的蓝屏,而且机器热时,更甚,达到每天10次以上,简 ...

  10. 【转载】C#中Convert.ToInt32方法将字符串转换为Int32类型

    在C#编程过程中,可以使用Convert.ToInt32方法将字符串或者其他可转换为数字的对象变量转换为ToInt32类型,Convert.ToInt32方法有多个重载方法,最常使用的一个方法将字符串 ...