本文细述上文引出的RAECost和SoftmaxCost两个类。

SoftmaxCost

我们已经知道。SoftmaxCost类在给定features和label的情况下(超參数给定),衡量给定权重(hidden×catSize)的误差值cost,并指出当前的权重梯度。看代码。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
@Override
    public double valueAt(double[]
x)
    {
        if(
!requiresEvaluation(x) )
            return value;
        int numDataItems
= Features.columns;
         
        int[]
requiredRows = ArraysHelper.makeArray(
0,
CatSize-
2);
        ClassifierTheta
Theta =
new ClassifierTheta(x,FeatureLength,CatSize);
        DoubleMatrix
Prediction = getPredictions (Theta, Features);
         
        double MeanTerm
=
1.0 /
(
double)
numDataItems;
        double Cost
= getLoss (Prediction, Labels).sum() * MeanTerm;
        double RegularisationTerm
=
0.5 *
Lambda * DoubleMatrixFunctions.SquaredNorm(Theta.W);
         
        DoubleMatrix
Diff = Prediction.sub(Labels).muli(MeanTerm);
        DoubleMatrix
Delta = Features.mmul(Diff.transpose());
     
        DoubleMatrix
gradW = Delta.getColumns(requiredRows);
        DoubleMatrix
gradb = ((Diff.rowSums()).getRows(requiredRows));
         
        //Regularizing.
Bias does not have one.
        gradW
= gradW.addi(Theta.W.mul(Lambda));
         
        Gradient
=
new ClassifierTheta(gradW,gradb);
        value
= Cost + RegularisationTerm;
        gradient
= Gradient.Theta;
        return value;
    }<br><br>public DoubleMatrix
getPredictions (ClassifierTheta Theta, DoubleMatrix Features)<br>    {<br>        
int numDataItems
= Features.columns;<br>        DoubleMatrix Input = ((Theta.W.transpose()).mmul(Features)).addColumnVector(Theta.b);<br>        Input = DoubleMatrix.concatVertically(Input, DoubleMatrix.zeros(
1,numDataItems));<br>  
     
return Activation.valueAt(Input);
<br>    }

是个典型的2层神经网络,没有隐层,首先依据features预測labels,预測结果用softmax归一化,然后依据误差反向传播算出权重梯度。

此处添加200字。

这个典型的2层神经网络,label为一列向量,目标label置1,其余为0;转换函数为softmax函数,输出为每一个label的概率。

计算cost的函数为getLoss。如果目标label的预測输出为p∗,则每一个样本的cost也即误差函数为:

cost=E(p∗)=−log(p∗)

依据前述的神经网络后向传播算法,我们得到(j为目标label时,否则为0):

∂E∂wij=∂E∂pj∂hj∂netjxi=−1pjpj(1−pj)xi=−(1−pj)xi=−(labelj−pj)featurei

因此我们便理解了以下代码的含义:

1
DoubleMatrix
Delta = Features.mmul(Diff.transpose());

RAECost

先看实现代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@Override
    public double valueAt(double[]
x)
    {
        if(!requiresEvaluation(x))
            return value;
         
        Theta
Theta1 =
new Theta(x,hiddenSize,visibleSize,dictionaryLength);
        FineTunableTheta
Theta2 =
new FineTunableTheta(x,hiddenSize,visibleSize,catSize,dictionaryLength);
        Theta2.setWe(
Theta2.We.add(WeOrig) );
         
        final RAEClassificationCost
classificationCost =
new RAEClassificationCost(
                catSize,
AlphaCat, Beta, dictionaryLength, hiddenSize, Lambda, f, Theta2);
        final RAEFeatureCost
featureCost =
new RAEFeatureCost(
                AlphaCat,
Beta, dictionaryLength, hiddenSize, Lambda, f, WeOrig, Theta1);
     
        Parallel.For(DataCell,
            new Parallel.Operation<LabeledDatum<Integer,Integer>>()
{
                public void perform(int index,
LabeledDatum<Integer,Integer> Data)
                {
                    try {
                        LabeledRAETree
Tree = featureCost.Compute(Data);
                        classificationCost.Compute(Data,
Tree);                
                    }
catch (Exception
e) {
                        System.err.println(e.getMessage());
                    }
                }
        });
         
        double costRAE
= featureCost.getCost();
        double[]
gradRAE = featureCost.getGradient().clone();
             
        double costSUP
= classificationCost.getCost();
        gradient
= classificationCost.getGradient();
             
        value
= costRAE + costSUP;
        for(int i=0;
i<gradRAE.length; i++)
            gradient[i]
+= gradRAE[i];
         
        System.gc();   
System.gc();
        System.gc();   
System.gc();
        System.gc();   
System.gc();
        System.gc();   
System.gc();
         
        return value;
    }

cost由两部分组成,featureCost和classificationCost。程序遍历每一个样本,用featureCost.Compute(Data)生成一个递归树,同一时候累加cost和gradient。然后用classificationCost.Compute(Data, Tree)依据生成的树计算并累加cost和gradient。因此关键类为RAEFeatureCost和RAEClassificationCost。

RAEFeatureCost类在Compute函数中调用RAEPropagation的ForwardPropagate函数生成一棵树。然后调用BackPropagate计算梯度并累加。详细的算法过程。下一章分解。

jrae源代码解析(二)的更多相关文章

  1. Spring源代码解析

    Spring源代码解析(一):IOC容器:http://www.iteye.com/topic/86339 Spring源代码解析(二):IoC容器在Web容器中的启动:http://www.itey ...

  2. Spring源代码解析(收藏)

    Spring源代码解析(收藏)   Spring源代码解析(一):IOC容器:http://www.iteye.com/topic/86339 Spring源代码解析(二):IoC容器在Web容器中的 ...

  3. C#使用zxing,zbar,thoughtworkQRcode解析二维码,附源代码

    最近做项目需要解析二维码图片,找了一大圈,发现没有人去整理下开源的几个库案例,花了点时间 做了zxing,zbar和thoughtworkqrcode解析二维码案例,希望大家有帮助. zxing是谷歌 ...

  4. NIO框架之MINA源代码解析(二):mina核心引擎

    NIO框架之MINA源代码解析(一):背景 MINA的底层还是利用了jdk提供了nio功能,mina仅仅是对nio进行封装.包含MINA用的线程池都是jdk直接提供的. MINA的server端主要有 ...

  5. SDWebImage源代码解析(二)

    上一篇:SDWebImage源代码解析(一) 2.缓存 为了降低网络流量的消耗.我们都希望下载下来的图片缓存到本地.下次再去获取同一张图片时.能够直接从本地获取,而不再从远程server获取.这样做的 ...

  6. redis之字符串命令源代码解析(二)

    形象化设计模式实战             HELLO!架构                     redis命令源代码解析 在redis之字符串命令源代码解析(一)中讲了get的简单实现,并没有对 ...

  7. asp.net C#生成和解析二维码代码

    类库文件我们在文件最后面下载 [ThoughtWorks.QRCode.dll 就是类库] 使用时需要增加: using ThoughtWorks.QRCode.Codec;using Thought ...

  8. Fixflow引擎解析(二)(模型) - BPMN2.0读写

    Fixflow引擎解析(四)(模型) - 通过EMF扩展BPMN2.0元素 Fixflow引擎解析(三)(模型) - 创建EMF模型来读写XML文件 Fixflow引擎解析(二)(模型) - BPMN ...

  9. Arrays.sort源代码解析

    Java Arrays.sort源代码解析 Java Arrays中提供了对所有类型的排序.其中主要分为Primitive(8种基本类型)和Object两大类. 基本类型:采用调优的快速排序: 对象类 ...

随机推荐

  1. Hex、bin、axf、elf格式文件小结

    转自Hex.bin.axf.elf格式文件小结 一.HEX Hex文件,一般是指Intel标准的十六进制文件.Intelhex 文件常用来保存单片机或其他处理器的目标程序代码.它保存物理程序存储区中的 ...

  2. tomcat 端口被占用

    找到占用8080端口的是PID为 2392的进程,于是 ctrl +shift+esc ,然后将这个进程结束掉.

  3. win7安装IIS及将网站发布到IIS上

    1. WIN7安装IIS:  控制面板----程序和功能-----打开或关闭windows功能,如图 展开Internet信息服务,按照下图方式进行选择,然后单击"确定",等待几分 ...

  4. 交叉编译 小米路由器mini 的 python(MIPS)

    看了很多文章,要么说的是用opkg安装python,要么说的是小米路由器的交叉编译,就是没有mini的.学习了这篇文章(http://me.deepgully.com/post/56389167868 ...

  5. Delphi实现多个图像相互覆盖时无内容处点击穿透

    http://www.52delphi.com/list.asp?ID=1405 CM_HITTEST是Delphi自定义的控件消息

  6. 使用Sqoop从mysql向hdfs或者hive导入数据时出现的一些错误

    1.原表没有设置主键,出现错误提示: ERROR tool.ImportTool: Error during import: No primary key could be found for tab ...

  7. Good vs Evil

    Good vs Evil Description Middle Earth is about to go to war. The forces of good will have many battl ...

  8. [hadoop源代码解读] 【SequenceFile】

    SequeceFile是Hadoop API提供的一种二进制文件支持.这种二进制文件直接将<key, value>对序列化到文件中.一般对小文件可以使用这种文件合并,即将文件名作为key, ...

  9. cat ,more, Less区别

    使用cat more less都可以查看文本内容,但是它们三者有什么区别呢?more和less的功能完全重复吗?以下是我个人的总结,欢迎大家一起来分享 cat        连续显示.查看文件内容mo ...

  10. 物联网操作系统HelloX V1.77(beta)版本发布

    物联网操作系统HelloX V1.77发布 经过近半年的努力,物联网操作系统HelloX V1.77版本正式完成,源代码已上载到github(github.com/hellox-project/Hel ...