神经网络模型

每个node包含两种操作:线性变换(仿射变换)和激发函数(activation function)。

其中仿射变换是通用的,而激发函数可以很多种,如下图。

MLLib中实现ANN

使用两层(Layer)来对应模型中的一层:

  • AffineLayer 仿射变换: output = W · input + b
  • 如果是最后一层,使用SoftmaxLayerWithCrossEntropyLoss或者SigmoidLayerWithSquaredError;如果是中间层,则使用functionalLayer(new SigmoidFunction()). 目前MLlib只支持sigmoid函数,实际上ReLU激发函数更普遍

BP算法计算Gradient的四个步骤:

对照BP算法的步骤,可以发现分隔成Affine和Activation的好处。BP1和BP2中的计算,不同的activation函数有不同的计算形式,将affine变换和activation函数解耦方便组合,进而方便形成各种类型的神经网络。

MLLib FeedForward Trainer

训练器重要模块如下:

ANN模型中每层对应AffineLayer + FunctionalLayerModel OR SofrmaxLayerModelWIthCrossEntropyLoss

每个LayerModel实现三个函数:eval, computePrevDelta, grad, 作为输出层的SoftmaxLayerModel有些特殊,额外具有LossFunction特性。

可验证affine+activation LayerModel的计算组合跟BP1-4一致。

AffineLayerModel (仿射变换层)

  • eval

    \(\text{output} = W \cdot \text{input} + b\)

  • computePrevDelta

    \(prev\delta = W * \delta\)

  • grad

    $\dot{W} = input \cdot \delta^l / \text{data size} $

    input is \(a^{l-1}\),前一层的激发函数输出

    \(\dot{b} = \delta^l / \text{data size}\)

FunctionalLayerModel(activate function \(\sigma\))

作为affineModel的activation model,只影响prev\(\delta\) 的计算,grad不计算

  • eval

    \(\text{output} = \sigma (\text{input})\)

  • computePrevDelta

    \(\delta :=\delta * \sigma'(\text{input})\)

  • grad

    pass

SoftmaxLayerModelWithCrossEntropyLoss

作为最后一层激发函数,这一层很特殊。

  • eval

    计算参见手写公式。

  • computePrevDelta

    不计算

  • grad

    不计算

  • loss

    计算\(\delta^L\),公式推导参见手写公式,代码如下:

    ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)

返回loss

Softmax输出层的激发函数:

\(a^L_j = \frac{e^{z^L_j}}{\sum_k e^{z^L_k}}\)

计算BP1:\(\delta^L_j = a^L_j -y_j\)

训练mnist手写数字识别

import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} object ann extends App {
val spark = SparkSession
.builder
.appName("ANN for MNIST")
.master("local[3]")
.getOrCreate()
spark.sparkContext.setLogLevel("ERROR") import spark.implicits._ // Load the data stored in text as a DataFrame.
val dataRdd: DataFrame= spark.sparkContext.textFile("handson-ml/data/train.csv")
.map {
line =>
val linesp = line.split(",")
val linespDouble = linesp.map(f => f.toDouble)
(linespDouble.head, Vectors.dense(linespDouble.takeRight(linespDouble.length - 1)))
}.toDF("label","features") val data = dataRdd
// Split the data into train and test
val splits: Array[DataFrame] = data.randomSplit(Array(0.6, 0.4), seed = 1234L)
val train: Dataset[Row] = splits(0)
val test: Dataset[Row] = splits(1) val layers = Array[Int](28*28, 300, 100, 10) // create the trainer and set its parameters
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(128)
.setSeed(1234L)
.setMaxIter(100)
.setLabelCol("label")
.setFeaturesCol("features") // train the model
val model = trainer.fit(train) // compute accuracy on the test set
val result = model.transform(test)
val predictionAndLabels = result.select("prediction", "label")
val evaluator = new MulticlassClassificationEvaluator()
.setMetricName("accuracy") println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels))
}

后记

测试集结果精度为96.68%。实际上并不高,同样的数据集使用TensorFlow训练,activation function选择ReLU,同样使用Softmax作为输出,结果可以达到98%以上。Sigmoid函数容易带来vanishing gradients问题,导致学习曲线变平。

artificial neural network in spark MLLib的更多相关文章

  1. 人工神经网络 Artificial Neural Network

    2017-12-18 23:42:33 一.什么是深度学习 深度学习(deep neural network)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高 ...

  2. 吴恩达深度学习第1课第4周-任意层人工神经网络(Artificial Neural Network,即ANN)(向量化)手写推导过程(我觉得已经很详细了)

    学习了吴恩达老师深度学习工程师第一门课,受益匪浅,尤其是吴老师所用的符号系统,准确且易区分. 遵循吴老师的符号系统,我对任意层神经网络模型进行了详细的推导,形成笔记. 有人说推导任意层MLP很容易,我 ...

  3. Neural Network and Artificial Neural Network

    神经网络的基本单元为神经元neuron,也称为process unit,可以做一些基本的运算操作.   人脑和动物大脑的发育,依赖于经验的积累和学习.神经网络就是一个用来仿照人脑进行学习的机器,其包含 ...

  4. What is “Neural Network”

    Modern neuroscientists often discuss the brain as a type of computer. Neural networks aim to do the ...

  5. 论文笔记系列-Neural Network Search :A Survey

    论文笔记系列-Neural Network Search :A Survey 论文 笔记 NAS automl survey review reinforcement learning Bayesia ...

  6. 机器学习之Artificial Neural Networks

    人类通过模仿自然界中的生物,已经发明了很多东西,比如飞机,就是模仿鸟翼,但最终,这些东西会和原来的东西有些许差异,artificial neural networks (ANNs)就是模仿动物大脑的神 ...

  7. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  8. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  9. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

随机推荐

  1. MySQL 快速复数据库的方法

    为了方便快速复制一个数据库,可以用以下命令将db1数据库的数据以及表结构复制到newdb数据库 创建新的数据库 #mysql -u root -p123456 mysql>CREATE DATA ...

  2. 20165213 java学习第一周

    20165213 -2018-2<Java程序设计>第一周学习总结 教材学习内容总结 java的四个特点:面向对象.平台无关性.动态性.简单. java编写程序步骤:再有jdk的情况下,先 ...

  3. DataTable表连接

    public static System.Data.DataTable TableJoin(System.Data.DataTable dt, System.Data.DataTable dtDeta ...

  4. Subarray Product Less Than K LT713

    Your are given an array of positive integers nums. Count and print the number of (contiguous) subarr ...

  5. Tomcat+Redis+Nginx实现session共享(Windows版)

    redis安装:xx nginx安装:xx 步骤: 1.下载tomcat-redis-session-manager相应的jar包,主要有三个: wget https://github.com/dow ...

  6. Android中关于使用空格对齐文字

    前言:今日编写新项目UI时,突然遇到文本有长有短无法对齐的问题(汗,以前竟从未遇到也从未考虑过这小小的问题),在资源文件中尝试Tab键.space空格键,发现效果都不能很好的实现,无奈只得请求度娘的协 ...

  7. web API分类

    什么是Web API? Web API是网络应用程序接口.包含了广泛的功能,网络应用通过API接口,可以实现存储服务.消息服务.计算服务等能力,利用这些能力可以进行开发出强大功能的web应用. 分类 ...

  8. setInterval与setTimeout 的区别

    setInterval在执行完一次代码之后,经过了那个固定的时间间隔,它还会自动重复执行代码,而setTimeout只执行一次那段代码     用法: setInterval("alert( ...

  9. 20155312 2006-2007-2 《Java程序设计》第二周学习总结

    20155312 2006-2007-2 <Java程序设计>第二周学习总结 课堂内容总结 git:版本控制 生活中的容灾备份 归纳思维.实验思维.计算思维 计算机:实现自动化 学会使用快 ...

  10. PHP删除空格函数

    删除空格或其他字符的相关函数 ltrim函数 描述:实现删除字符串开始位置的空格或其他字符 语法:string ltrim(string $str [,string $charlist]) 说明:ch ...