mnist手写数字问题初体验
上一篇我们提到了回归问题中的梯度下降算法,而且我们知道线性模型只能解决简单的线性回归问题,对于高维图片,线性模型不能完成这样复杂的分类任务。那么是不是线性模型在离散值预测或图像分类问题中就没有用武之地了呢?
本篇我们就套用regression中的部分机制来处理classification中的问题。
在这里首先介绍一下激活函数。
所谓激活函数,实际上就是引入非线性因子,将线性模型去线性化,增强模型的表达能力。ReLU激活函数是我要介绍的第一个激活函数,其定义式为φ(z)=max{0,z},图像表示如下:
简单的说relu就是一个取最大值的函数,在负区间取值为0,正区间取值不变,这种操作被称为单侧抑制(输出为0时代表神经元不会被激活)。单侧抑制的特点就是同一时间只会有一部分神经元被激活(结合函数图像可以看出),也就使得神经元具有了稀疏激活性。加入relu激活函数的神经元被称作整流线性单元,它与线性单元非常相似,唯一的区别就是在一半定义域上输出为0。整流线性单元易于优化,当其处于激活状态时(输出不为0),它的一阶导数能够保持一个较大值(等于1),并且处处一致,它的二阶导数几乎处处为0,这样的好处就是避免了梯度下降时的梯度消失问题(可参考前一篇回归问题的随笔)。
简单介绍了激活函数,那么是不是将激活函数引入我们的线性模型out=X@w+b就能使其解决复杂的图像分类问题了呢?
很显然不是的,虽然加了激活函数,但是我们可以看到模型变为out=relu(X@w+b) 依然还是太简单。那么怎么办呢?
我们可以联系一下零件加工的流程,从原料到成品,零件的加工经历了多个工序,期间每一道工序都是由前一道工序为基础,这时候,原料就相当于神经网络的输入,成品零件就相当于神经网络的输出,他们中间并不是也不能一步到位,而是经过若干“隐藏”的工序一步一步的生成产品。我们的模型同样可以借助于这种思想。即给数据处理多添加几道所谓的“工序”,我们称之为“隐藏层”,因为我们关心的只有模型的输入和输出,隐藏层的数据是我们不可见的(当然也可以在运行过程中打印出来方便调试),下面我们就利用这样的思想来解决mnist手写数字分类问题。
我们使用的是mnist数据集,也是深度学习的基础入门数据集。它一共有70k张不同的手写数字图片,其中60k用来训练模型,10k用来评估模型,且所有图片均为28*28的灰度图。我们首先设计一个稍微复杂的模型
h1=relu(X@w1+b1)
h2=relu(h1@w2+b2)
out=relu(h2@w3+b3)
其中X为输入,out为输出,h1、h2均为隐藏层,且除输入层外每一层的输入均是前一层的输出。
首先我们将输入的28*28*1的图片扁平化,即将每张图片转化成784维的向量(28*28=784),这样的好处是可以以矩阵的形式同时喂入多张图片(每一行向量为一张灰度图的信息),提高效率。对于输出out,我们令其输出一个10维的向量,代表10个数字的概率。模型可以用以下公式概括:
out=relu { relu { relu[ X@w1+b1 ] @w2+b2 }@w3+b3 }
pred=argmax(out)
loss=MSE(out,label) (均方误差损失函数即loss=∑(label-out)2)
minimize loss→[w1',b1',w2',b2',w3',b3']
参数调整完成后,可以对新的输入x进行运算从而得到对应的输出
代码如下:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets # 屏蔽通知和警告信息,减少用处不大的问题输出
os.environ['TF_CPP_MIN_LOG_LEVEL']='' (x, y), (x_val, y_val) = datasets.mnist.load_data()
x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.
y = tf.convert_to_tensor(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
print(x.shape, y.shape)
train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
train_dataset = train_dataset.batch(200) # 搭建网络结构
model = keras.Sequential([
layers.Dense(512, activation='relu'),
layers.Dense(256, activation='relu'),
layers.Dense(10)]) # 初始化优化器为梯度下降优化器
optimizer = optimizers.SGD(learning_rate=0.001) def train_epoch(epoch): # Step4.循环迭代
for step, (x, y) in enumerate(train_dataset): with tf.GradientTape() as tape:
# 将输入数据压平 [b, 28, 28] => [b, 784]
x = tf.reshape(x, (-1, 28*28))
# Step1. 计算输出
# 输入域数据经过神经网络降维 [b, 784] => [b, 10]
out = model(x)
# Step2. 计算损失
loss = tf.reduce_sum(tf.square(out - y)) / x.shape[0] # Step3. 优化更新参数 w1, w2, w3, b1, b2, b3
grads = tape.gradient(loss, model.trainable_variables)
# w' = w - lr * grad
optimizer.apply_gradients(zip(grads, model.trainable_variables)) if step % 100 == 0:
print(epoch, step, 'loss:', loss.numpy()) def train(): for epoch in range(30): train_epoch(epoch) if __name__ == '__main__':
train()
运行结果如下:
可以看到损失从初始的1.65降到0.25,在这里我们先只对mnist进行一个初步探索,测试一下模型的表现,后续会通过一些更好的优化方法来不断改良我们的模型。
mnist手写数字问题初体验的更多相关文章
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 第三节,CNN案例-mnist手写数字识别
卷积:神经网络不再是对每个像素做处理,而是对一小块区域的处理,这种做法加强了图像信息的连续性,使得神经网络看到的是一个图像,而非一个点,同时也加深了神经网络对图像的理解,卷积神经网络有一个批量过滤器, ...
- mnist 手写数字识别
mnist 手写数字识别三大步骤 1.定义分类模型2.训练模型3.评价模型 import tensorflow as tfimport input_datamnist = input_data.rea ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- Tensorflow可视化MNIST手写数字训练
简述] 我们在学习编程语言时,往往第一个程序就是打印“Hello World”,那么对于人工智能学习系统平台来说,他的“Hello World”小程序就是MNIST手写数字训练了.MNIST是一个手写 ...
随机推荐
- HDU_1506_单调栈
http://acm.hdu.edu.cn/showproblem.php?pid=1506 从栈底到栈顶从小到大排序,碰到比栈定小的元素,出栈处理,直到栈顶比元素小或者栈为空. 数组最后多加了个-1 ...
- POJ_1185_状态压缩dp
http://poj.org/problem?id=1185 一次考虑两行,比一行略为复杂.sta保存每种状态炮兵位置,sum保存每种状态当行炮兵总数,a保存地形,dp[i][j][k]表示到第i行当 ...
- Linux密码策略--设置随机密码
#!/bin/bash # @Author: HanWei # @Date: -- :: # @Last Modified by: HanWei # @Last Modified -- :: # @E ...
- Mysql:自动化备份
简介 在这个数据为王的时代,数据的备份十分重要,这里就分享一篇mysql数据库自动备份的脚本(是从网上搜到的),其将配置文件和备份脚本分离,提高了安全性,脚本风格规范严谨,分享给大家希望对需要的小伙伴 ...
- 线索二叉树C++实现
#include<iostream> #include<stdlib.h> #define maxsize 100 using namespace std; typedef s ...
- FFmpeg命令读取RTMP流如何设置超时时间
子标题:FFmpeg命令录制RTMP流为FLV文件时如何设置超时时间 | FFmpeg命令如何解决录制产生阻塞的问题0x001: 前言 今天在测试程序时遇到两个问题.Q1:ffmpeg录制RTMP流并 ...
- rc.local 启动内容不生效
系统版本 CentOS Linux release 7.2.1511 问题 :/etc/rc.local 中的内容 启动机器后不生效 经过检查 /etc/rc.local 是 /etc/rc.d/ ...
- MacBook通过SSH远程访问Parallel中的Ubuntu简明教程
作为一个前端,后端也需要了解,最终选择PHP入手学习,本来想选择Python,思前想后还是PHP作为Web开发比较合适,环境最终选择Ubuntu开发,由于是第一次,遇到不少坑,经过不懈的努力不断Goo ...
- String实例 (练习)
练习题1:用户输入一段字符串,要求统计出在该段字符串中,数字,字母以及其他字符各出现过几次??? 代码实现: 运行结果: 补充:1. 连接符的使用: +用作连接符时,只能连接字符串,即“ ”双 ...
- v-charts x轴字体斜显示
如下图,因为X轴内容太多,放不下,插件默认间隔显示需求:X轴内容要全部显示出来(只有斜显示或固定宽多余的用省略代替,本来需要就是想显示全部内容,所以只能取斜显示的方案) 先看看v-charts的文档: ...