本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

R中的softmaxreg包,发自2016-09-09,链接:https://cran.r-project.org/web/packages/softmaxreg/index.html

——————————————————————————————————————————————————————————————————

一、介绍

Softmax Regression模型本质还是一个多分类模型,对Logistic Regression 逻辑回归的拓展。如果将Softmax Regression模型和神经网络隐含层结合起来,可以进一步提升模型的性能,构成包含多个隐含层和最后一个Softmax层的多层神经网络模型。之前发现R里面没有特别适合的方法支持多层的Softmax 模型,于是就想直接用R语言写一个softmaxreg 包。可以支持大部分的多分类问题,其中的两个示例:MNIST手写体识别和多文档分类(Multi-Class DocumentClassification) 的文档如下

二、示例文档

2.1 MNIST手写体识别数据集

MNIST手写体识别的数据集是图像识别领域一个基本数据集,很多模型诸如CNN卷积神经网络等模型都经常在这个数据集上测试都能够达到97%以上的准确率。 这里想比较一下包含隐含层的softmaxreg模型,测试结果显示模型的准确率能达到93% 左右。

Part1、下载和Load数据

MNIST手写体识别的数据集可以直接从网站下载http://yann.lecun.com/exdb/mnist/,一共四个文件,分别下载下来并解压。文件格式比较特殊,可以用softmaxreg 包中的load_image_file 和load_label_file 两个函数读取。

训练集有60000幅图片,每个图片都是由16*16个像素构成,代表了0-9中的某一个数字,比如下图。

利用softmaxreg 包训练一个10分类的MNIST手写体识别的模型,用load_image_file 和load_label_file 来分别读取训练集的图像数据和标签的数据 (Reference: brendano'connor - gist.github.com/39760的读取方法)

  1. library(softmaxreg)
  2. path= "D: \\DeepLearning\\MNIST\\"
  3. #10-classclassification, Digit 0-9
  4. x= load_image_file(paste(path,'train-images-idx3-ubyte', sep=""))
  5. y= load_label_file(paste(path,'train-labels-idx1-ubyte', sep=""))
  6. xTest= load_image_file(paste(path,'t10k-images-idx3-ubyte',sep=""))
  7. yTest= load_label_file(paste(path,'t10k-labels-idx1-ubyte', sep=""))

复制代码

可以用show_digit函数来看一个数字的图像,比如查看某一个图片,比如第2副

  1. show_digit(x[2,])

复制代码


Part2、训练模型

利用softmaxReg函数,训练集输入和标签分别为为x和y,maxit 设置最多多少个Epoch, algorithm为优化的算法,rate为学习率,batch参数为SGD随机梯度下降每个Mini-Batch的样本个数。 收敛后用predict方法来看看测试集Test的准确率怎么样

  1. ## Normalize Input Data
  2. x = x/255
  3. xTest = xTest/255
  4. model1= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "sgd", rate = 0.01, batch = 1000)
  5. loss1= model1$loss
  6. #Test Accuracy
  7. yFit= predict(model1, newdata = x)
  8. table(y,yFit)

复制代码

Part3、比较不同优化算法的收敛速度

  1. model2= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "adagrad", rate = 0.01, batch =1000)
  2. loss2= model2$loss
  3. model3= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "rmsprop", rate = 0.01, batch =1000)
  4. loss3= model3$loss
  5. model4= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "momentum", rate = 0.01, batch= 1000)
  6. loss4= model4$loss
  7. model5= softmaxReg(x, y, hidden = c(), funName = 'sigmoid', maxit = 15, rang = 0.1,type = "class", algorithm = "nag", rate = 0.01, batch = 1000)
  8. loss5= model5$loss
  9. #plot the loss convergence
  10. iteration= c(1:length(loss1))
  11. myplot= plot(iteration, loss1, xlab = "iteration", ylab = "loss",ylim = c(0, max(loss1,loss2,loss3,loss4,loss5) + 0.01),
  12. type = "p", col ="black", cex = 0.7)
  13. title("ConvergenceComparision Between Learning Algorithms")
  14. points(iteration,loss2, col = "red", pch = 2, cex = 0.7)
  15. points(iteration,loss3, col = "blue", pch = 3, cex = 0.7)
  16. points(iteration,loss4, col = "green", pch = 4, cex = 0.7)
  17. points(iteration,loss5, col = "magenta", pch = 5, cex = 0.7)
  18. legend("topright",c("SGD", "Adagrad", "RMSprop","Momentum", "NAG"),
  19. col = c("black", "red","blue", "green", "magenta"),pch = c(1,2,3,4,5))
  20. save.image()

复制代码

如果maxit 迭代次数过大,模型运行时间较长,可以保存图像,最后可以看到AdaGrad, rmsprop,momentum, nag 和标准SGD这几种优化算法的收敛速度的比较效果。关于优化算法这个帖子有很好的总结:

http://cs231n.github.io/neural-networks-3/

2.2 多类别的文档分类

Softmax regression模型的每个输入为一个文档,用一个字符串表示。其中每个词word都可以用一个word2vec模型训练的word Embedding低维度的实数词向量表示。在softmaxreg包中有一个预先训练好的模型:长度为20维的英文词向量的字典,直接用data(word2vec) 调用就可以了。

假设我们需要对UCI的C50新闻数据集进行分类,数据集包含多个作者写的新闻报道,每个作者的新闻文件都在一个单独的文件夹中。 我们假设挑选5个作者的文章进行训练softmax regression 模型,然后在测试集中预测任意文档属于哪一个作者,这就构成了一个5分类的问题。

Part1, 载入预先训练好的 英文word2vec 字典表

  1. library(softmaxreg)
  2. data(word2vec) # default 20 dimension word2vec dataset
  3. #### Reuter 50 DataSet UCI Archived Dataset from

复制代码

Part2,利用loadURLData函数从网址下载数据并且解压到folder目录

  1. ## URL: "http://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip"
  2. URL = "http://archive.ics.uci.edu/ml/machine-learning-databases/00217/C50.zip"
  3. folder = getwd()
  4. loadURLData(URL, folder, unzip = TRUE)

复制代码

Part3,利用wordEmbed() 函数作为lookup table,从默认的word2vec数据集中查找每个单词的向量表示,默认20维度,可以自己训练自己的字典数据集来替换。

  1. ##Training Data
  2. subFoler = c('AaronPressman', 'AlanCrosby', 'AlexanderSmith', 'BenjaminKangLim', 'BernardHickey')
  3. docTrain = document(path = paste(folder, "/C50train/",subFoler, sep = ""), pattern = 'txt')
  4. xTrain = wordEmbed(docTrain, dictionary = word2vec)
  5. yTrain = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
  6. # Assign labels to 5 different authors
  7. ##Testing Data
  8. docTest = document(path = paste(folder, "/C50test/",subFoler, sep = ""), pattern = 'txt')
  9. xTest = wordEmbed(docTest, dictionary = word2vec)
  10. yTest = c(rep(1,50), rep(2,50), rep(3,50), rep(4,50), rep(5,50))
  11. samp = sample(250, 50)
  12. xTest = xTest[samp,]
  13. yTest = yTest[samp]

复制代码

Part4,训练模型,构建一个结构为20-10-5的模型,输入层为20维,即词向量的维度,隐含层的节点数为10,最后softmax层输出节点个数为5.

  1. ## Train Softmax Classification Model, 20-10-5
  2. softmax_model = softmaxReg(xTrain, yTrain, hidden = c(10), maxit = 500, type = "class",
  3. algorithm = "nag", rate = 0.05, batch = 10, L2 = TRUE)
  4. summary(softmax_model)
  5. yFit = predict(softmax_model, newdata = xTrain)
  6. table(yTrain, yFit)
  7. ## Testing
  8. yPred = predict(softmax_model, newdata = xTest)
  9. table(yTest, yPred)

复制代码

# 增加embedding的维度到50或者100可以提升模型准确度;

相关资料:

关于Stanford的中英文

http://ufldl.stanford.edu/wiki/index.php/Softmax%E5%9B%9E%E5%BD%92

softmaxregR包的下载地址:

https://cran.r-project.org/web/packages/softmaxreg/index.html

R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)的更多相关文章

  1. 深度学习-mnist手写体识别

    mnist手写体识别 Mnist数据集可以从官网下载,网址: http://yann.lecun.com/exdb/mnist/ 下载下来的数据集被分成两部分:55000行的训练数据集(mnist.t ...

  2. keras入门--Mnist手写体识别

    介绍如何使用keras搭建一个多层感知机实现手写体识别及搭建一个神经网络最小的必备知识 import keras # 导入keras dir(keras) # 查看keras常用的模块 ['Input ...

  3. Tensorflow中使用CNN实现Mnist手写体识别

    本文参考Yann LeCun的LeNet5经典架构,稍加ps得到下面适用于本手写识别的cnn结构,构造一个两层卷积神经网络,神经网络的结构如下图所示: 输入-卷积-pooling-卷积-pooling ...

  4. (六)6.10 Neurons Networks implements of softmax regression

    softmax可以看做只有输入和输出的Neurons Networks,如下图: 其参数数量为k*(n+1) ,但在本实现中没有加入截距项,所以参数为k*n的矩阵. 对损失函数J(θ)的形式有: 算法 ...

  5. CS229 6.10 Neurons Networks implements of softmax regression

    softmax可以看做只有输入和输出的Neurons Networks,如下图: 其参数数量为k*(n+1) ,但在本实现中没有加入截距项,所以参数为k*n的矩阵. 对损失函数J(θ)的形式有: 算法 ...

  6. Exercise : Softmax Regression

    Step 0: Initialize constants and parameters Step 1: Load data Step 2: Implement softmaxCost Implemen ...

  7. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

  8. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  9. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

随机推荐

  1. 101490E Charles in Charge

    题目连接 http://codeforces.com/gym/101490 题目大意 你有一张图,每两点之间有一定距离,计算出比最短路大x%之内的路径中最长边的最小值 分析 先跑一遍最短路,然后二分答 ...

  2. mysql修改root用户密码

    自我总结,欢迎拍砖! 目的:若root用户密码忘记,则需要重新设置root用户的密码. 步骤: 1.找到mysql安装目录下的 my.ini 文件,找到[mysqlId]一行,在下方添加语句:skip ...

  3. C#基础(七)虚函数

    若一个实例方法声明前带有virtual关键字,那么这个方法就是虚方法.虚方法与非虚方法的最大不同是,虚方法的实现可以由派生类所取代,这种取代是通过方法的重写实现的(以后再讲)虚方法的特点:虚方法前不允 ...

  4. MySQL密码重置(root用户)

    分别在Windows下和Linux下重置了MYSQL的root的密码: 在windows下: 1:进入cmd,停止mysql服务:Net stop mysql 到mysql的安装路径启动mysql,在 ...

  5. iOS-硬件授权检测【通讯录、相机、相册、日历、麦克风、定位授权】

    总结下几个常用到的获取手机权限,从iOS8以后,获取手机某种权限需要在info.plist文件中添加权限的描述文件 <key>NSContactsUsageDescription</ ...

  6. android adb shell input各种妙用

    项目中使用一个开发版,预留两个usb接口.类似华硕TinkerBoard. 一个用户连接摄像头,一个用于adb调试.结果就没了鼠标的接口.多次切换鼠标和摄像头插头,非常不方便,带摄像头的app没法调试 ...

  7. 孤立的SQL用户

    问题 最近公司很多数据库在上云,也有一部分在下云.这期间出现了很多问题,其中一个比较恶心的问题就是"孤立用户".当数据库备份还原以后用以前的用户发现不能登录.一开始以为是登录账号没 ...

  8. 七、Selenium与phantomJS----------动态页面模拟点击、网站模拟登录

    每天一个小实例1(动态页面模拟点击,并爬取你想搜索的职位信息) from selenium import webdriver from bs4 import BeautifulSoup # 调用环境变 ...

  9. devexpress entity framework 与 asp.net mvc的坑

    最近在做一个使用ASP.NET MVC DEVEXPRESS和EF的OA模块 遇到不少问题这里记录一下: 1 如果项目中存在多个上下文类(DBContext的派生类),在做数据迁移的时候需要在不同目录 ...

  10. 使用域账号统一管理cisco网络设备

    1.思科设备和微软系统整合的背景: 公司内部有一定数量的客户端,为了实现统一化,在管理内部部署了域架构,这样可以通过组策略对客户端进行批量化管理,提高了管理的效率. 同样公司内部有一定数量的网络设备( ...