Softmax

一.Softmax回归简介

  案例:MNIST手写数字识别

  1.为了得到一张给定图片属于某个特定数字类的证据【evidence】,对图片像素进行加权求和。如果这个像素具有很强的证据说明这张图片不属于该类,那么相应的权值为负值相反如果这个像素拥有有利的证据支持这张图片属于这个类,那么权值即为正数。

    

  如下图,红色代表负数值,蓝色代表正数值:

    

  2.这里的softmax可以看做一个激励【activation】函数或者链接【link】函数,把我们定义的线性函数的输出转化成我们想要的格式,也就是关于10个数字类别的概率分布。因此,给定一张图片,它对于每一个数字的吻合度可以被softmax函数转化成一个概率值。

    

    

  展开等式右边的子式:

    

  3.softmax把输入值当成幂指数求值,再正则化这些结果值。这个幂运算表示,更大的证据对应更大的假设模型【hypothesis】里面的乘数权重值。反之拥有更少的证据意味着在假设模型里面拥有更小的乘数系数。假设模型里面的权值不可以是小于0的数值。Softmax会正则化这些权重值,使它们的总和等于1,以此构造一个有效的概率分布。

    

  如果把它写成一个等式:

    

  转化为矩阵乘和向量加:

    

  转化为公式:

    

二.代码实现

 1 # -*- coding: utf-8 -*-
2 """
3 Created on Thu Oct 18 18:02:26 2018
4
5 @author: zhen
6 """
7
8 from tensorflow.examples.tutorials.mnist import input_data
9 import tensorflow as tf
10
11 # mn.SOURCE_URL = "http://yann.lecun.com/exdb/mnist/"
12 my_mnist = input_data.read_data_sets("C:/Users/zhen/MNIST_data_bak/", one_hot=True)
13
14 # The MNIST data is split into three parts:
15 # 55,000 data points of training data (mnist.train)
16 # 10,000 points of test data (mnist.test), and
17 # 5,000 points of validation data (mnist.validation).
18
19 # Each image is 28 pixels by 28 pixels
20
21 # 输入的是一堆图片,None表示不限输入条数,784表示每张图片都是一个784个像素值的一维向量
22 # 所以输入的矩阵是None乘以784二维矩阵
23 x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
24 # 初始化都是0,二维矩阵784乘以10个W值
25 W = tf.Variable(tf.zeros([784, 10]))
26 b = tf.Variable(tf.zeros([10]))
27
28 y = tf.nn.softmax(tf.matmul(x, W) + b)
29
30 # 训练
31 # labels是每张图片都对应一个one-hot的10个值的向量
32 y_ = tf.placeholder(dtype=tf.float32, shape=(None, 10))
33 # 定义损失函数,交叉熵损失函数
34 # 对于多分类问题,通常使用交叉熵损失函数
35 # reduction_indices等价于axis,指明按照每行加,还是按照每列加
36 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),
37 reduction_indices=[1]))
38 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
39
40 # 评估
41
42 # tf.argmax()是一个从tensor中寻找最大值的序号,tf.argmax就是求各个预测的数字中概率最大的那一个
43
44 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
45
46 # 用tf.cast将之前correct_prediction输出的bool值转换为float32,再求平均
47 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
48
49 # 初始化变量
50 sess = tf.InteractiveSession()
51 tf.global_variables_initializer().run()
52 # 创建Saver节点,用于保存训练的模型
53 saver = tf.train.Saver()
54 for i in range(100):
55 batch_xs, batch_ys = my_mnist.train.next_batch(100)
56 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
57 # 每隔一段时间保存一次中间结果
58 if i % 10 == 0:
59 save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_middle_model.ckpt")
60
61 # print("TrainSet batch acc : %s " % accuracy.eval({x: batch_xs, y_: batch_ys}))
62 # print("ValidSet acc : %s" % accuracy.eval({x: my_mnist.validation.images, y_: my_mnist.validation.labels}))
63
64 # 测试
65 print("TestSet acc : %s" % accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
66 # 保存最终的模型
67 save_path = saver.save(sess, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
68
69 # 使用训练好的模型直接进行预测
70 with tf.Session() as sess_back:
71 saver.restore(sess_back, "C:/Users/zhen/MNIST_data_bak/saver/softmax_final_model.ckpt")
72 # 评估
73 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
74 accruary = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
75 # 测试
76 print(accuracy.eval({x : my_mnist.test.images, y_ : my_mnist.test.labels}))
77 # 总结
78 # 1,定义算法公式,也就是神经网络forward时的计算
79 # 2,定义loss,选定优化器,并指定优化器优化loss
80 # 3,迭代地对数据进行训练
81 # 4,在测试集或验证集上对准确率进行评测

三.结果

    

  

四.解析

  把训练好的模型存储落地磁盘,有利于多次使用和共享,也便于当训练出现异常时能恢复模型而不是重新训练!

TensorFlow实现Softmax回归(模型存储与加载)的更多相关文章

  1. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  2. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  3. 转 tensorflow模型保存 与 加载

    使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...

  4. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  5. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  6. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  7. sklearn模型保存与加载

    sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...

  8. tensorflow之逻辑回归模型实现

    前面一篇介绍了用tensorflow实现线性回归模型预测sklearn内置的波士顿房价,现在这一篇就记一下用逻辑回归分类sklearn提供的乳腺癌数据集,该数据集有569个样本,每个样本有30维,为二 ...

  9. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

随机推荐

  1. tomcat-四种运行模式和三种部署模式(优化)

    四中运行模式如下: 1-bio: 传统的Java I/O操作,同步且阻塞IO. 2-nio: JDK1.4开始支持,同步阻塞或同步非阻塞IO 3-aio(nio.2): JDK7开始支持,异步非阻塞I ...

  2. Django--分页器(paginator)

    1 Django的分页器(paginator)简介 在页面显示分页数据,需要用到Django分页器组件 from django.core.paginator import Paginator Pagi ...

  3. spring security 实践 + 源码分析

    前言 本文将从示例.原理.应用3个方面介绍 spring data jpa. 以下分析基于spring boot 2.0 + spring 5.0.4版本源码 概述 Spring Security 是 ...

  4. SQL 必知必会·笔记<19>使用游标

    游标(cursor)是一个存储在DBMS服务器上的数据库查询,它不是一条SELECT语句,而是被该语句检索出来的结果集.在存储了游标之后,应用程序可以根据需要滚动或浏览其中的数据. 使用游标 使用游标 ...

  5. 音频标签化1:audioset与训练模型 | 音频特征样本

    随着机器学习的发展,很多"历史遗留"问题有了新的解决方案.这些遗留问题中,有一个是音频标签化,即如何智能地给一段音频打上标签的问题,标签包括"吉他"." ...

  6. redis-scala链接redis集群

    代码: package com.wenbronk.sparkstreaming.scala.commons import java.time.Duration import io.lettuce.co ...

  7. spring-boot-2.0.3应用篇 - shiro集成

    前言 上一篇:spring-boot-2.0.3源码篇 - 国际化,讲了如何实现国际化,实际上我工作用的模版引擎是freemaker,而不是thymeleaf,不过原理都是相通的. 接着上一篇,这一篇 ...

  8. JavaWeb学习(二十九)———— 事务

    一.事务的概念 事务指逻辑上的一组操作,组成这组操作的各个单元,要不全部成功,要不全部不成功. 例如:A——B转帐,对应于如下两条sql语句  update from account set mone ...

  9. [TJOI 2018]智力竞赛

    Description 题库链接 给出一张 \(m\) 个点的有向图.问可重最小路径覆盖是否 \(\leq n+1\) .若不,求最多用 \(n+1\) 条路径去覆盖,最大化未覆盖点点权最小值. \( ...

  10. 四层和七层负载均衡的特点及常用负载均衡Nginx、Haproxy、LVS对比

    一.四层与七层负载均衡在原理上的区别 图示: 四层负载均衡与七层负载均衡在工作原理上的简单区别如下图: 概述: 1.四层负载均衡工作在OSI模型中的四层,即传输层.四层负载均衡只能根据报文中目标地址和 ...