前面在mnist中使用了三个非线性层来增加模型复杂度,并通过最小化损失函数来更新参数,下面实用最底层的方式即张量进行前向传播(暂不采用层的概念)。

主要注意点如下:

  · 进行梯度运算时,tensorflow只对tf.Variable类型的变量进行记录,而不对tf.Tensor或者其他类型的变量记录

  · 进行梯度更新时,如果采用赋值方法更新即w1=w1+x的形式,那么所得的w1是tf.Tensor类型的变量,所以要采用原地更新的方式即assign_sub函数,或者再次使用tf.Variable包起来(不推荐)

代码如下:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets
import os os.environ['TF_CPP_MIN_LOG_LEVEL']='' # x:[60k,28,28]
# y:[60k]
(x,y),_=datasets.mnist.load_data() x = tf.convert_to_tensor(x,dtype=tf.float32)/255.0
y = tf.convert_to_tensor(y,dtype=tf.int32) print(x.shape,y.shape,x.dtype,y.dtype)
print(tf.reduce_min(x),tf.reduce_max(x))
print(tf.reduce_min(y),tf.reduce_max(y)) train_db=tf.data.Dataset.from_tensor_slices((x,y)).batch(128)
train_iter=iter(train_db)
sample=next(train_iter)
print('batch:',sample[0].shape,sample[1].shape) # [b,784]=>[b,256]=>[b,128]=>[b,10]
# w shape[dim_in,dim_out] b shape[dim_out]
w1 = tf.Variable(tf.random.truncated_normal([784,256],stddev=0.1))
b1 = tf.Variable(tf.zeros([256])) w2 = tf.Variable(tf.random.truncated_normal([256,128],stddev=0.1))
b2 = tf.Variable(tf.zeros([128])) w3 = tf.Variable(tf.random.truncated_normal([128,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])) # 设置学习率
lr = 0.001
for epoch in range(10): # 对数据集迭代
for step,(x,y) in enumerate(train_db):
# x:[128,28,28] y:[128]
x = tf.reshape(x,[-1,28*28]) with tf.GradientTape() as tape: # tape只会跟踪tf.Variable
# x:[b,28*28]
# [b,784]@[784,256]+[256]=>[b,256]+[256]
h1 = x@w1 + b1
h1 = tf.nn.relu(h1) # 去线性化
h2 = h1@w2 + b2
h2 = tf.nn.relu(h2) # 去线性化
out = h2@w3 + b3 # 计算损失
y_onehot = tf.one_hot(y,depth=10)
# mse = mean(sum(y-out)^2)
loss = tf.square(y_onehot - out)
# mean:scalar
loss = tf.reduce_mean(loss) # 计算梯度
grads = tape.gradient(loss,[w1,b1,w2,b2,w3,b3])
# w1 = w1 -lr * w1_grad
w1.assign_sub(lr * grads[0]) # 原地更新
b1.assign_sub(lr * grads[1])
w2.assign_sub(lr * grads[2])
b2.assign_sub(lr * grads[3])
w3.assign_sub(lr * grads[4])
b3.assign_sub(lr * grads[5]) if step % 100 == 0:
print('epoch = ',epoch,'step =',step,',loss =',float(loss))

效果如下:

前向传播和反向传播实战(Tensor)的更多相关文章

  1. 机器学习(ML)八之正向传播、反向传播和计算图,及数值稳定性和模型初始化

    正向传播 正向传播的计算图 通常绘制计算图来可视化运算符和变量在计算中的依赖关系.下图绘制了本节中样例模型正向传播的计算图,其中左下角是输入,右上角是输出.可以看到,图中箭头方向大多是向右和向上,其中 ...

  2. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  3. caffe中 softmax 函数的前向传播和反向传播

    1.前向传播: template <typename Dtype> void SoftmaxLayer<Dtype>::Forward_cpu(const vector< ...

  4. caffe中的前向传播和反向传播

    caffe中的网络结构是一层连着一层的,在相邻的两层中,可以认为前一层的输出就是后一层的输入,可以等效成如下的模型 可以认为输出top中的每个元素都是输出bottom中所有元素的函数.如果两个神经元之 ...

  5. BP原理 - 前向计算与反向传播实例

    Outline 前向计算 反向传播 很多事情不是需要聪明一点,而是需要耐心一点,踏下心来认真看真的很简单的. 假设有这样一个网络层: 第一层是输入层,包含两个神经元i1 i2和截距b1: 第二层是隐含 ...

  6. 反向传播算法(前向传播、反向传播、链式求导、引入delta)

    参考链接: 一文搞懂反向传播算法

  7. Tensorflow笔记——神经网络图像识别(一)前反向传播,神经网络八股

      第一讲:人工智能概述       第三讲:Tensorflow框架         前向传播: 反向传播: 总的代码: #coding:utf-8 #1.导入模块,生成模拟数据集 import t ...

  8. BP神经网络反向传播之计算过程分解(详细版)

    摘要:本文先从梯度下降法的理论推导开始,说明梯度下降法为什么能够求得函数的局部极小值.通过两个小例子,说明梯度下降法求解极限值实现过程.在通过分解BP神经网络,详细说明梯度下降法在神经网络的运算过程, ...

  9. 深度学习与CV教程(4) | 神经网络与反向传播

    作者:韩信子@ShowMeAI 教程地址:http://www.showmeai.tech/tutorials/37 本文地址:http://www.showmeai.tech/article-det ...

随机推荐

  1. 2020年,手把手教你如何在CentOS7上一步一步搭建LDAP服务器的最新教程

    同步滚动:关 什么是LDAP 什么是LDAP? 要想知道一个概念,最简单的办法就是wikipedia,当然也可以百科. LDAP全称是轻型目录访问协议(Lightweight Directory Ac ...

  2. HBASE手动触发major_compact

    1.定时执行脚本#!/bin/bash source /etc/profile sh ./hbase shell <<EOF major_compact 'table_name' EOF ...

  3. Fastdfs php扩展访问

    一.安装FastDFS client php extension compiled under PHP 5.4 and PHP 7.0   1.安装php扩展,进入fastdfs源码文件夹中的  ph ...

  4. [RHEL8]关闭SELinux(同CentOS7)

    修改配置文件(永久修改) # vi /etc/selinux/config SELINUX=disabled # 关闭 SELINUX=enforcing # 开启 命令方式(临时修改重启失效) # ...

  5. flyway使用简介

    官网 https://flywaydb.org/ 背景 Flyway是独立于数据库的应用.管理并跟踪数据库变更的数据库版本管理工具.用通俗的话讲,Flyway可以像Git管理不同人的代码那样,管理不同 ...

  6. 基于 React 实现一个 Transition 过渡动画组件

    过渡动画使 UI 更富有表现力并且易于使用.如何使用 React 快速的实现一个 Transition 过渡动画组件? 基本实现 实现一个基础的 CSS 过渡动画组件,通过切换 CSS 样式实现简单的 ...

  7. StarUML之九、starUML的一些特殊属性的说明

    UML的扩充性机制允许你在控制的方式下扩充UML语言. 这一类的机制包括:stereotype,标记值.约束. Stereotype扩充了UML的词汇表,允许你创建新的建筑块,这些建筑块从已有的继承而 ...

  8. java工作流系统表单自动 获取数据

    关键词:工作流快速开发平台  工作流流设计  业务流程管理   asp.net 开源工作流  bpm工作流系统  java工作流主流框架  自定义工作流引擎 表单设计器  流程设计器 什么是数据自动获 ...

  9. Android中TimePicker时间选择器的使用和获取选择的时和分

    场景 实现效果如下 注: 博客: https://blog.csdn.net/badao_liumang_qizhi 关注公众号 霸道的程序猿 获取编程相关电子书.教程推送与免费下载. 实现 将布局改 ...

  10. VUE中使用XLSX实现导出excel表格

    简介 项目中经常会用导出数据的场景,这里介绍 VUE 中如何使用插件 xlsx 导出数据 安装 ## 1.使用 npm 或 yarn 安装依赖(三个依赖) npm install -S file-sa ...