答案来自专栏:机器学习算法与自然语言处理

详解softmax函数以及相关求导过程

这几天学习了一下softmax激活函数,以及它的梯度求导过程,整理一下便于分享和交流。

softmax函数

softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!

假设我们有一个数组,V,Vi表示V中的第i个元素,那么这个元素的softmax值就是

<img src="https://pic3.zhimg.com/50/v2-39eca1f41fe487983f5111f5e5073396_hd.jpg" data-rawwidth="140" data-rawheight="71" class="content_image" width="140">

更形象的如下图表示:

<img src="https://pic1.zhimg.com/50/v2-87b232ab0e292a536e94b73952caadd0_hd.jpg" data-rawwidth="670" data-rawheight="396" class="origin_image zh-lightbox-thumb" width="670" data-original="https://pic1.zhimg.com/v2-87b232ab0e292a536e94b73952caadd0_r.jpg">

softmax直白来说就是将原来输出是3,1,-3通过softmax函数一作用,就映射成为(0,1)的值,而这些值的累和为1(满足概率的性质),那么我们就可以将它理解成概率,在最后选取输出结点的时候,我们就可以选取概率最大(也就是值对应最大的)结点,作为我们的预测目标!

举一个我最近碰到利用softmax的例子:我现在要实现基于神经网络的句法分析器。用到是基于转移系统来做,那么神经网络的用途就是帮我预测我这一个状态将要进行的动作是什么?比如有10个输出神经元,那么就有10个动作,1动作,2动作,3动作...一直到10动作。(这里涉及到nlp的知识,大家不用管,只要知道我现在根据每个状态(输入),来预测动作(得到概率最大的输出),最终得到的一系列动作序列就可以完成我的任务即可

原理图如下图所示:

<img src="https://pic1.zhimg.com/50/v2-892ff35f895772d419c77bdd00984f48_hd.jpg" data-rawwidth="627" data-rawheight="188" class="origin_image zh-lightbox-thumb" width="627" data-original="https://pic1.zhimg.com/v2-892ff35f895772d419c77bdd00984f48_r.jpg">

那么比如在一次的输出过程中输出结点的值是如下:

[0.2,0.1,0.05,0.1,0.2,0.02,0.08,0.01,0.01,0.23]

那么我们就知道这次我选取的动作是动作10,因为0.23是这次概率最大的,那么怎么理解多分类呢?很容易,如果你想选取俩个动作,那么就找概率最大的俩个值即可~(这里只是简单的告诉大家softmax在实际问题中一般怎么应用)

softmax相关求导

当我们对分类的Loss进行改进的时候,我们要通过梯度下降,每次优化一个step大小的梯度,这个时候我们就要求Loss对每个权重矩阵的偏导,然后应用链式法则。那么这个过程的第一步,就是对softmax求导传回去,不用着急,我后面会举例子非常详细的说明。在这个过程中,你会发现用了softmax函数之后,梯度求导过程非常非常方便!

下面我们举出一个简单例子,原理一样,目的是为了帮助大家容易理解!

<img src="https://pic2.zhimg.com/50/v2-284da728c54de850e34573d6fe682965_hd.jpg" data-rawwidth="620" data-rawheight="391" class="origin_image zh-lightbox-thumb" width="620" data-original="https://pic2.zhimg.com/v2-284da728c54de850e34573d6fe682965_r.jpg">

我们能得到下面公式:

z4 = w41*o1+w42*o2+w43*o3

z5 = w51*o1+w52*o2+w53*o3

z6 = w61*o1+w62*o2+w63*o3

z4,z5,z6分别代表结点4,5,6的输出,01,02,03代表是结点1,2,3往后传的输入.

那么我们可以经过softmax函数得到

好了,我们的重头戏来了,怎么根据求梯度,然后利用梯度下降方法更新梯度!

要使用梯度下降,肯定需要一个损失函数,这里我们使用交叉熵作为我们的损失函数,为什么使用交叉熵损失函数,不是这篇文章重点,后面有时间会单独写一下为什么要用到交叉熵函数(这里我们默认选取它作为损失函数)

交叉熵函数形式如下:

<img src="https://pic2.zhimg.com/50/v2-cd3d19d3680993cb9ae62d2fd110e43d_hd.jpg" data-rawwidth="173" data-rawheight="59" class="content_image" width="173">

其中y代表我们的真实值,a代表我们softmax求出的值。i代表的是输出结点的标号!在上面例子,i就可以取值为4,5,6三个结点(当然我这里只是为了简单,真实应用中可能有很多结点

现在看起来是不是感觉复杂了,居然还有累和,然后还要求导,每一个a都是softmax之后的形式!

但是实际上不是这样的,我们往往在真实中,如果只预测一个结果,那么在目标中只有一个结点的值为1,比如我认为在该状态下,我想要输出的是第四个动作(第四个结点),那么训练数据的输出就是a4 = 1,a5=0,a6=0,哎呀,这太好了,除了一个为1,其它都是0,那么所谓的求和符合,就是一个幌子,我可以去掉啦!

为了形式化说明,我这里认为训练数据的真实输出为第j个为1,其它均为0!

那么Loss就变成了,累和已经去掉了,太好了。现在我们要开始求导数了!

我们在整理一下上面公式,为了更加明白的看出相关变量的关系:

其中,那么形式变为

那么形式越来越简单了,求导分析如下:

参数的形式在该例子中,总共分为w41,w42,w43,w51,w52,w53,w61,w62,w63.这些,那么比如我要求出w41,w42,w43的偏导,就需要将Loss函数求偏导传到结点4,然后再利用链式法则继续求导即可,举个例子此时求w41的偏导为:

<img src="https://pic3.zhimg.com/50/v2-b345590cc09867302a22071e69b1f81e_hd.jpg" data-rawwidth="595" data-rawheight="331" class="origin_image zh-lightbox-thumb" width="595" data-original="https://pic3.zhimg.com/v2-b345590cc09867302a22071e69b1f81e_r.jpg">

w51.....w63等参数的偏导同理可以求出,那么我们的关键就在于Loss函数对于结点4,5,6的偏导怎么求,如下:

这里分为俩种情况:

<img src="https://pic3.zhimg.com/50/v2-da652b68fb126c709e69eda674fb42be_hd.jpg" data-rawwidth="433" data-rawheight="284" class="origin_image zh-lightbox-thumb" width="433" data-original="https://pic3.zhimg.com/v2-da652b68fb126c709e69eda674fb42be_r.jpg">

 

j=i对应例子里就是如下图所示:

比如我选定了j为4,那么就是说我现在求导传到4结点这!

<img src="https://pic2.zhimg.com/50/v2-40a344cb96d81d30aec0d91a95a04409_hd.jpg" data-rawwidth="570" data-rawheight="353" class="origin_image zh-lightbox-thumb" width="570" data-original="https://pic2.zhimg.com/v2-40a344cb96d81d30aec0d91a95a04409_r.jpg">

那么由上面求导结果再乘以交叉熵损失函数求导

,它的导数为,与上面相乘为形式非常简单,这说明我只要正向求一次得出结果,然后反向传梯度的时候,只需要将它结果减1即可,后面还会举例子!)那么我们可以得到Loss对于4结点的偏导就求出了了(这里假定4是我们的预计输出

第二种情况为:

<img src="https://pic1.zhimg.com/50/v2-fa3ff8ff1c781e259a15b36b85c65fcc_hd.jpg" data-rawwidth="356" data-rawheight="292" class="content_image" width="356">

这里对应我的例子图如下,我这时对的是j不等于i,往前传:

<img src="https://pic1.zhimg.com/50/v2-a4edb7767bfed190991e86cbdaa472cc_hd.jpg" data-rawwidth="547" data-rawheight="361" class="origin_image zh-lightbox-thumb" width="547" data-original="https://pic1.zhimg.com/v2-a4edb7767bfed190991e86cbdaa472cc_r.jpg">

那么由上面求导结果再乘以交叉熵损失函数求导

,它的导数为,与上面相乘为形式非常简单,这说明我只要正向求一次得出结果,然后反向传梯度的时候,只需要将它结果保存即可,后续例子会讲到这里就求出了除4之外的其它所有结点的偏导,然后利用链式法则继续传递过去即可!我们的问题也就解决了!

下面我举个例子来说明为什么计算会比较方便,给大家一个直观的理解

举个例子,通过若干层的计算,最后得到的某个训练样本的向量的分数是[ 2, 3, 4 ],

那么经过softmax函数作用后概率分别就是=[

,,] = [0.0903,0.2447,0.665],如果这个样本正确的分类是第二个的话,那么计算出来的偏导就是[0.0903,0.2447-1,0.665]=[0.0903,-0.7553,0.665],是不是非常简单!!然后再根据这个进行back propagation就可以了

到这里,这篇文章的内容就讲完了,我希望根据自己的理解,通过列出大量例子,直白的给大家讲解softmax的相关内容,让大家少走弯路,真心希望对大家的理解有帮助!欢迎交流指错!画图整理不易,觉得有帮助的给个赞呗,哈哈!

参考:

softmax的log似然代价函数(公式求导)

交叉熵代价函数(作用及公式推导)

Softmax回归 - Ufldl

部分图片来自于网络!

softmax函数详解的更多相关文章

  1. [转]softmax函数详解

    答案来自专栏:机器学习算法与自然语言处理 详解softmax函数以及相关求导过程 这几天学习了一下softmax激活函数,以及它的梯度求导过程,整理一下便于分享和交流. softmax函数 softm ...

  2. Softmax函数详解与推导

    一.softmax函数 softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类! 假设我们有一个数组,V,Vi表示V中的第i个元素,那么这个 ...

  3. malloc 与 free函数详解<转载>

    malloc和free函数详解   本文介绍malloc和free函数的内容. 在C中,对内存的管理是相当重要.下面开始介绍这两个函数: 一.malloc()和free()的基本概念以及基本用法: 1 ...

  4. NSSearchPathForDirectoriesInDomains函数详解

    NSSearchPathForDirectoriesInDomains函数详解     #import "NSString+FilePath.h" @implementation ...

  5. JavaScript正则表达式详解(二)JavaScript中正则表达式函数详解

    二.JavaScript中正则表达式函数详解(exec, test, match, replace, search, split) 1.使用正则表达式的方法去匹配查找字符串 1.1. exec方法详解 ...

  6. Linux C popen()函数详解

    表头文件 #include<stdio.h> 定义函数 FILE * popen( const char * command,const char * type); 函数说明 popen( ...

  7. kzalloc 函数详解(转载)

    用kzalloc申请内存的时候, 效果等同于先是用 kmalloc() 申请空间 , 然后用 memset() 来初始化 ,所有申请的元素都被初始化为 0. view plain /** * kzal ...

  8. Netsuite Formula > Oracle函数列表速查(PL/SQL单行函数和组函数详解).txt

    PL/SQL单行函数和组函数详解 函数是一种有零个或多个参数并且有一个返回值的程序.在SQL中Oracle内建了一系列函数,这些函数都可被称为SQL或PL/SQL语句,函数主要分为两大类: 单行函数 ...

  9. jQuery.attr() 函数详解

    一,jQuery.attr()  函数详解: http://www.365mini.com/page/jquery-attr.htm 二,jQuery函数attr()和prop()的区别: http: ...

随机推荐

  1. 使用Python分析ELF文件优化Flash和Sram空间的案例

    1. 背景 Zephyr项目Flash和Ram空间比较紧张,有着非常强烈的优化需求. 优化的前提是量化标的,那么如何量化Flash和Ram的使用量呢? 在量化之后,首先要对量化结果进行分析,然后采取措 ...

  2. 使用wordpress搭建自己的独立博客

    最近想要搭建自己的私人博客, 各种百度,完整的搭建步骤如下! 首先得要有自己的vps或者云主机,我这里是自己的云主机,有自己的域名(我这边目前没有买域名)! 搭建步骤! 1,安装lnmp(linux+ ...

  3. grpc.go

    package,,), etcd.WithPrefix(), etcd.WithPrevKV()}     gw.wch = gw.c.Watch(gw.ctx, gw.target, opts... ...

  4. 学习ASP.NET Core Razor 编程系列十——添加新字段

    学习ASP.NET Core Razor 编程系列目录 学习ASP.NET Core Razor 编程系列一 学习ASP.NET Core Razor 编程系列二——添加一个实体 学习ASP.NET ...

  5. BZOJ_1031_[JSOI2007]字符加密Cipher_后缀数组

    BZOJ_1031_[JSOI2007]字符加密Cipher_后缀数组 Description 喜欢钻研问题的JS同学,最近又迷上了对加密方法的思考.一天,他突然想出了一种他认为是终极的加密办法 :把 ...

  6. java quartz 计算近20次执行时间

    /** * * @desc 计算表达式近20次时间 * @auth josnow * @date 2017年5月31日 下午12:16:25 * @param cron * @return */ pu ...

  7. python环境下实现OrangePi Zero寄存器访问及GPIO控制

    最近入手OrangePi Zero一块,程序上需要使用板子上自带的LED灯,在网上一查,不得不说OPi的支持跟树莓派无法相比.自己摸索了一下,实现简单的GPIO控制方法,作者的Zero安装的是Armb ...

  8. H5 新特性之 fileReader 实现本地图片视频资源的预览

    大家好 !!  又见面了, 今天我们来搞一搞   H5的新增API    FileReader     真是一个超级超级方便的API呢!!!很多场景都可以使用.......... 我们先不赘述MDN文 ...

  9. 接口测试心得--签名处理(Python)

    一.背景 最近负责的项目接口签名规则做了调整,第一次接触“2次认证“的方式,正好有时间,记录一下. 测试的服务A有一部分接口需要给第三方调用,这样需要对第三方有个认证,认证是由一个公共服务(API鉴权 ...

  10. RabbitMQ死信队列另类用法之复合死信

    前言 在业务开发过程中,我们常常需要做一些定时任务,这些任务一般用来做监控或者清理任务,比如在订单的业务场景中,用户在创建订单后一段时间内,没有完成支付,系统将自动取消该订单,并将库存返回到商品中,又 ...