用tensorflow2.0 版回顾了一下mnist的学习

代码如下,感觉这个版本下的mnist学习更简洁,更方便

关于tensorflow的基础知识,这里就不更新了,用到什么就到网上搜索相关的知识

# encoding: utf-8

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt #加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28)
path = r'G:\2019\python\mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
f.close() #预处理输入数据
x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10) #第一层输入256, 第二次输出128, 第三层输出10
#第一,二,三层参数w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1)) #正态分布的一种
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) #将60000组数据切分为600组,每组100个数据
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)
lr = 0.001 #学习率
losses = [] #储存每epoch的loss值,便于观察学习情况 for epoch in range(20):
#一次性处理100组(x, y)数据
for step, (x, y) in enumerate(train_db): #遍历切分好的数据step:0->599
with tf.GradientTape() as tape:
#向前传播第一,二,三层
h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256]) #可以直接写成 +b1
h1 = tf.nn.relu(h1)
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
out = h2@w3 + b3
#计算mse
loss = tf.square(y - out)
loss = tf.reduce_mean(loss)
#计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
#更新w,b
w1.assign_sub(lr*grads[0]) #原地减去给定的值,实现参数的自我更新
b1.assign_sub(lr*grads[1])
w2.assign_sub(lr*grads[2])
b2.assign_sub(lr*grads[3])
w3.assign_sub(lr*grads[4])
b3.assign_sub(lr*grads[5])
#观察学习情况
if step%500 == 0:
print(epoch, step, 'loss:', float(loss))
#将每epoch的loss情况储存起来,最后观察
losses.append(float(loss)) plt.plot(losses, marker='s', label='training')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.legend()
plt.savefig('exam_mnist_forward.png')
plt.show()

观察结果:

可由注释理解代码的含义!下一次更新mnist数据集训练的进阶!

tensorflow2.0 学习(三)的更多相关文章

  1. tensorflow2.0 学习(一)

    虽说是按<TensorFlow深度学习>这本书来学习的,但是总会碰到新的问题!记录下这些问题,有利于巩固知新. 之前学过一些tensorflow1.0的知识,到RNN这章节,后面没有再继续 ...

  2. Tensorflow2.0学习(一)

    站长资讯平台:今天学习一下Tensorflow2.0 的基础 核心库,@tf.function ,可以方便的将动态图的语言,变成静态图,在某种程度上进行计算加速 TensorFlow Lite Ten ...

  3. tensorflow2.0学习笔记

    今天我们开始学习tensorflow2.0,用一种简单和循循渐进的方式,带领大家亲身体验深度学习.学习的目录如下图所示: 1.简单的神经网络学习过程 1.1张量生成 1.2常用函数 1.3鸢尾花数据读 ...

  4. TensorFlow2.0(三):排序及最大、最小、平均值

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  5. tensorflow2.0 学习(二)

    线性回归问题 # encoding: utf-8 import numpy as np import matplotlib.pyplot as plt data = [] for i in range ...

  6. tensorflow2.0学习笔记第一章第四节

    1.4神经网络实现鸢尾花分类 import tensorflow as tf from sklearn import datasets import pandas as pd import numpy ...

  7. tensorflow2.0学习笔记第一章第一节

    一.简单的神经网络实现过程 1.1张量的生成 # 创建一个张量 #tf.constant(张量内容,dtpye=数据类型(可选)) import tensorflow as tf import num ...

  8. tensorflow2.0学习笔记第一章第二节

    1.2常用函数 本节目标:掌握在建立和操作神经网络过程中常用的函数 # 常用函数 import tensorflow as tf import numpy as np # 强制Tensor的数据类型转 ...

  9. tensorflow2.0学习笔记第一章第三节

    1.3鸢尾花数据读入 # 从sklearn包datasets读入数据 from sklearn import datasets from pandas import DataFrame import ...

随机推荐

  1. 【MySQL】各种小坑-持续更新

    中文乱码问题 在建表的时候额外执行 ALTER TABLE camera CONVERT TO CHARACTER SET utf8; 如果还是不行注意看一下precision,为2的时候容易出现?? ...

  2. python_socket (套接字)

    socket是计算机网络通信的基本的技术之一.如今大多数基于网络的软件,如浏览器,即时通讯工具甚至是P2P下载都是基于Socket实现的. 网络上两个程序通过一个双向的通信连接实现数据的交换,这个连接 ...

  3. windows桌面远程连接突然不能双向复制文件

    远程桌面连接windows 2008,突然无法在本地和服务器之间互相复制文件.根据微软的说明,由rdpclip.exe进程来控制,打开远程服务器的任务管理器,看到rdpclip.exe进程存在,即可进 ...

  4. C#——零散学习1

    C#——零散学习1 //结构体(与C语言相似) struct Position { public float x; public float y;         //不一定需要把结构体成员设置为pu ...

  5. 关于如何控制一个页面的Ajax读数据只读一次的简单解决办法!

    例如:一个页面有一个按钮,点击的时候用ajax去后台获取数据,获取成功以后返回.下次再点击的时候就不要去获取数据了. 解决办法有很多: 1.用Get方法去读数据,会缓存. 2.用jquery的data ...

  6. 必须掌握的Linux用户组

    在 Linux 系统中用户组起着重要作用.用户组提供了一种简单方法供一组用户互相共享文件.用户组也允许系统管理员更加有效地管理用户权限,因为管理员可以将权限分配给用户组而不是逐一分配给单个用户. 尽管 ...

  7. Java GC的工作原理详解

    JVM学习笔记之JVM内存管理和JVM垃圾回收的概念,JVM内存结构由堆.栈.本地方法栈.方法区等部分组成,另外JVM分别对新生代下载地址  和旧生代采用不同的垃圾回收机制. 首先来看一下JVM内存结 ...

  8. 【转载】 C#中decimal.TryParse方法和decimal.Parse方法的异同之处

    在C#编程过程中,decimal.TryParse方法和decimal.Parse方法都可以将字符串string转换为decimal类型,但两者还是有区别,最重要的区别在于decimal.TryPar ...

  9. SpringBoot上传文件报错,临时路径不存在

    异常信息 报错日志: The temporary upload location [/tmp/tomcat.7957874575370093230.8088/work/Tomcat/localhost ...

  10. 图说jdk1.8新特性(1)--- 函数式接口

    函数式接口 总结起来就以下几点: 如果一个接口要想成为函数接口(函数接口可以直接用lambda方式简化),则必须有且仅有一个抽象的方法(非default和static) 可以通过注解@Function ...