前面在mnist中使用了三个非线性层来增加模型复杂度,并通过最小化损失函数来更新参数,下面实用最底层的方式即张量进行前向传播(暂不采用层的概念)。

主要注意点如下:

  · 进行梯度运算时,tensorflow只对tf.Variable类型的变量进行记录,而不对tf.Tensor或者其他类型的变量记录

  · 进行梯度更新时,如果采用赋值方法更新即w1=w1+x的形式,那么所得的w1是tf.Tensor类型的变量,所以要采用原地更新的方式即assign_sub函数,或者再次使用tf.Variable包起来(不推荐)

代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os os.environ['TF_CPP_MIN_LOG_LEVEL']='' # x:[60k,28,28]
# y:[60k]
(x,y),_=datasets.mnist.load_data() x = tf.convert_to_tensor(x,dtype=tf.float32)/255.0
y = tf.convert_to_tensor(y,dtype=tf.int32) print(x.shape,y.shape,x.dtype,y.dtype)
print(tf.reduce_min(x),tf.reduce_max(x))
print(tf.reduce_min(y),tf.reduce_max(y)) train_db=tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter=iter(train_db)
sample=next(train_iter)
print('batch:',sample[0].shape,sample[1].shape) # [b,784]=>[b,256]=>[b,128]=>[b,10]
# w shape[dim_in,dim_out] b shape[dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1 = tf.Variable(tf.zeros([256])) w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128])) w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) # 设置学习率
lr = 0.001
for epoch in range(10): # 对数据集迭代
for step,(x,y) in enumerate(train_db):
# x:[128,28,28] y:[128]
x = tf.reshape(x,[-1,28*28]) with tf.GradientTape() as tape: # tape只会跟踪tf.Variable
# x:[b,28*28]
# [b,784]@[784,256]+[256]=>[b,256]+[256]
h1 = x@w1 + b1
h1 = tf.nn.relu(h1) # 去线性化
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2) # 去线性化
out = h2@w3 + b3 # 计算损失
y_onehot = tf.one_hot(y,depth=10)
# mse = mean(sum(y-out)^2)
loss = tf.square(y_onehot - out)
# mean:scalar
loss = tf.reduce_mean(loss) # 计算梯度
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
# w1 = w1 -lr * w1_grad
w1.assign_sub(lr * grads[0]) # 原地更新
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5]) if step % 100 == 0:
print('epoch = ',epoch,'step =',step,',loss =',float(loss))

效果如下:

前向传播和反向传播实战(Tensor)的更多相关文章

  1. 机器学习(ML)八之正向传播、反向传播和计算图,及数值稳定性和模型初始化

    正向传播 正向传播的计算图 通常绘制计算图来可视化运算符和变量在计算中的依赖关系.下图绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出.可以看到,图中箭头方向大多是向右和向上,其中 ...

  2. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  3. caffe中 softmax 函数的前向传播和反向传播

    1.前向传播: template <typename Dtype> void SoftmaxLayer<Dtype>::Forward_cpu(const vector< ...

  4. caffe中的前向传播和反向传播

    caffe中的网络结构是一层连着一层的,在相邻的两层中,可以认为前一层的输出就是后一层的输入,可以等效成如下的模型 可以认为输出top中的每个元素都是输出bottom中所有元素的函数.如果两个神经元之 ...

  5. BP原理 - 前向计算与反向传播实例

    Outline 前向计算 反向传播 很多事情不是需要聪明一点,而是需要耐心一点,踏下心来认真看真的很简单的. 假设有这样一个网络层: 第一层是输入层,包含两个神经元i1 i2和截距b1: 第二层是隐含 ...

  6. 反向传播算法(前向传播、反向传播、链式求导、引入delta)

    参考链接: 一文搞懂反向传播算法

  7. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

  8. BP神经网络反向传播之计算过程分解(详细版)

    摘要:本文先从梯度下降法的理论推导开始,说明梯度下降法为什么能够求得函数的局部极小值.通过两个小例子,说明梯度下降法求解极限值实现过程.在通过分解BP神经网络,详细说明梯度下降法在神经网络的运算过程, ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. shell命令之一天一见:awk

    AWK是一种优良的文本处理工具,Linux及Unix环境中现有的功能最强大的数据处理引擎之一. 这种编程及数据操作语言(其名称得自于它的创始人阿尔佛雷德·艾侯.彼得·溫伯格和布萊恩·柯林漢姓氏的首个字 ...

  2. redis命令总结与持久化

    上篇redis文章为大家介绍了redis与它的部署工作.这次我们来说一下redis的操作命令与持久化 一.命令总结 1)String操作 6379> set k1 v1 #设定值 6379> ...

  3. 编译生成protobuf的jar包

    编译生成protobuf的jar包 配置maven 安装maven,并修改maven源为阿里云 下载maven wget http://mirror.bit.edu.cn/apache/maven/m ...

  4. 管理 使用 FastDFS

    启动管理tracker: 1. 启动文件+配置文件+命令 /usr/bin/fdfs_trackerd <config_file> [start | stop | restart] 举例: ...

  5. 一台机器上同时运行两个tomcat

    修改conf/server.xml文件,修改地方有三处 如图

  6. JMeter接口测试-循环读取库的用户信息

    前言 如何实现循环读取数据库的用户信息,并传递到下一个登录请求呢,下面我们一起来学习吧!在之前我们已经学会了利用JMeter连接数据库了,具体操作可以看我之前的随笔JMeter接口测试-JDBC测试 ...

  7. JDK SPI 机制

    一.概述 最早看到 SPI 这个机制是在 dubbo 实现 中,最近发现原来也不是什么新东西,竟然就是 JDK 中内置的玩意,今天就来一探究竟,看看它到底是什么玩意! SPI的全称是 Service ...

  8. PHP将图片base64编码传输

    PHP函数源码 function imgToBase64($img_file) { $img_base64 = ''; if (file_exists($img_file)) { $app_img_f ...

  9. 记一个开发是遇到的坑之Oralce 字符串排序

    简单描述一下情况,就是存储过程中用一个字符串类型的字段作为患者就诊的排序号,结果莫名发现叫完1号后叫了11.12等患者.用户的反馈不一定准确,自己加了日志的,赶紧拷贝日志来观察一下.结果发现实际情况就 ...

  10. 【python数据挖掘】爬取豆瓣影评数据

    概述: 爬取豆瓣影评数据步骤: 1.获取网页请求 2.解析获取的网页 3.提速数据 4.保存文件 源代码: # 1.导入需要的库 import urllib.request from bs4 impo ...