tensorflow2.0 学习(三)
用tensorflow2.0 版回顾了一下mnist的学习
代码如下,感觉这个版本下的mnist学习更简洁,更方便
关于tensorflow的基础知识,这里就不更新了,用到什么就到网上搜索相关的知识
# encoding: utf-8 import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt #加载下载好的mnist数据库 60000张训练 10000张测试 每一张维度(28,28)
path = r'G:\2019\python\mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
f.close() #预处理输入数据
x = 2*tf.convert_to_tensor(x_train, dtype = tf.float32)/255. - 1
x = tf.reshape(x, [-1, 28*28])
y = tf.convert_to_tensor(y_train, dtype=tf.int32)
y = tf.one_hot(y, depth=10) #第一层输入256, 第二次输出128, 第三层输出10
#第一,二,三层参数w,b
w1 = tf.Variable(tf.random.truncated_normal([784, 256], stddev=0.1)) #正态分布的一种
b1 = tf.Variable(tf.zeros([256]))
w2 = tf.Variable(tf.random.truncated_normal([256, 128], stddev=0.1))
b2 = tf.Variable(tf.zeros([128]))
w3 = tf.Variable(tf.random.truncated_normal([128, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) #将60000组数据切分为600组,每组100个数据
train_db = tf.data.Dataset.from_tensor_slices((x, y)).batch(100)
lr = 0.001 #学习率
losses = [] #储存每epoch的loss值,便于观察学习情况 for epoch in range(20):
#一次性处理100组(x, y)数据
for step, (x, y) in enumerate(train_db): #遍历切分好的数据step:0->599
with tf.GradientTape() as tape:
#向前传播第一,二,三层
h1 = x@w1 + tf.broadcast_to(b1, [x.shape[0], 256]) #可以直接写成 +b1
h1 = tf.nn.relu(h1)
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2)
out = h2@w3 + b3
#计算mse
loss = tf.square(y - out)
loss = tf.reduce_mean(loss)
#计算参数的梯度,tape.gradient为自动求导函数,loss为目标数据,目的使它越来越接近真实值
grads = tape.gradient(loss, [w1, b1, w2, b2, w3, b3])
#更新w,b
w1.assign_sub(lr*grads[0]) #原地减去给定的值,实现参数的自我更新
b1.assign_sub(lr*grads[1])
w2.assign_sub(lr*grads[2])
b2.assign_sub(lr*grads[3])
w3.assign_sub(lr*grads[4])
b3.assign_sub(lr*grads[5])
#观察学习情况
if step%500 == 0:
print(epoch, step, 'loss:', float(loss))
#将每epoch的loss情况储存起来,最后观察
losses.append(float(loss)) plt.plot(losses, marker='s', label='training')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.legend()
plt.savefig('exam_mnist_forward.png')
plt.show()
观察结果:

可由注释理解代码的含义!下一次更新mnist数据集训练的进阶!
tensorflow2.0 学习(三)的更多相关文章
- tensorflow2.0 学习(一)
虽说是按<TensorFlow深度学习>这本书来学习的,但是总会碰到新的问题!记录下这些问题,有利于巩固知新. 之前学过一些tensorflow1.0的知识,到RNN这章节,后面没有再继续 ...
- Tensorflow2.0学习(一)
站长资讯平台:今天学习一下Tensorflow2.0 的基础 核心库,@tf.function ,可以方便的将动态图的语言,变成静态图,在某种程度上进行计算加速 TensorFlow Lite Ten ...
- tensorflow2.0学习笔记
今天我们开始学习tensorflow2.0,用一种简单和循循渐进的方式,带领大家亲身体验深度学习.学习的目录如下图所示: 1.简单的神经网络学习过程 1.1张量生成 1.2常用函数 1.3鸢尾花数据读 ...
- TensorFlow2.0(三):排序及最大、最小、平均值
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- tensorflow2.0 学习(二)
线性回归问题 # encoding: utf-8 import numpy as np import matplotlib.pyplot as plt data = [] for i in range ...
- tensorflow2.0学习笔记第一章第四节
1.4神经网络实现鸢尾花分类 import tensorflow as tf from sklearn import datasets import pandas as pd import numpy ...
- tensorflow2.0学习笔记第一章第一节
一.简单的神经网络实现过程 1.1张量的生成 # 创建一个张量 #tf.constant(张量内容,dtpye=数据类型(可选)) import tensorflow as tf import num ...
- tensorflow2.0学习笔记第一章第二节
1.2常用函数 本节目标:掌握在建立和操作神经网络过程中常用的函数 # 常用函数 import tensorflow as tf import numpy as np # 强制Tensor的数据类型转 ...
- tensorflow2.0学习笔记第一章第三节
1.3鸢尾花数据读入 # 从sklearn包datasets读入数据 from sklearn import datasets from pandas import DataFrame import ...
随机推荐
- 通过werkzeug了解wsgi
Django有wsgi当做socket,而flask自身是没有wsgi的,他是通过werkzeug来实现的. 看源码 下面看下源码是如何实现的: #这是我们写的flask代码from flask im ...
- (转).Net Core控制台生成exe能独立运行
原文介绍了两种方式,方式一经测试可用(生成exe在开发机器上可运行),但是因为服务器是windows server2012 r2,没有安装补丁,造成了困难,尚未在服务器上运行成功. (提示 api-m ...
- C# 多线程与高并发处理并且具备暂停、继续、停止功能
--近期有一个需要运用多线程的项目,会有并发概率,所以写了一份代码,可能有写地方还不完善,后续有需求在改 1 /// <summary> /// 并发对象 /// </summary ...
- 自学Python编程的第四天----------来自苦逼的转行人
2019-09-14 21:15:24 今天是学习Python的第四天,也是写博客的第四天 今天的内容是有关'列表'.'元组'.'range'的用法 列表:增删改查.列表的嵌套 元组:元组的嵌套 ra ...
- 五 查询数据SELECT 一、单表查询
一 单表查询的语法 二 关键字的执行优先级 三 简单查询 四 WHERE约束 五 分组查询:GROUP BY 六 HAVING过滤 七 查询排序:ORDER BY 八 限制查询的记录数:LIMIT 九 ...
- Node.js到底是什么
接触前端也有一段时间了,逐渐开始接触Node.js,刚刚接触Node.js的时候一直都以为Node.js就是JavaScript,当对Node.js有一定的了解之后,其实并不然两者之间有关系,其中的关 ...
- Maven父子工程,子项目变灰,提示该项目已被移除出maven父工程
最近使用maven的父子工程结构搭建微服务架构时,不知道什么原因, 子工程总是被莫名移除出父工程,然后打包处的项目名变成了灰色, 重启该项目时会提示,“该子项目已被移除,是否删除该项目”,这个 当然不 ...
- 源码解析-url状态检测神器ping-url
前言 ping-url是我最近开源的一个小工具,这篇文章也是专门写它设计理念的科普文. 为什么会做这个ping-url开源工具呢? 起因是:本小哥在某天接到一个特殊的需求,要用前端的方式判断任意一个u ...
- MySQL MGR--MGR部署
MGR部署 场景描述: 使用三台服务器搭建一个简单MGR集群,使用MySQL 5.7.24版本,服务器列表为: 192.168.1.147 192.168.1.148 192.168.1.149 1. ...
- MySQL Index--平衡树结构
树结构 ==================================================B树,即平衡二叉树,每个非叶子节点最多拥有两个子节点.所有键值出现在叶子节点和非叶子节点. ...