(1) softmax loss

<1> softmax loss的函数形式为:

    (1)

zi为softmax的输入,f(zi)为softmax的输出。

<2> softmax loss对其输入zj求导:

     (2)

如果j==k,则zk是变量,否则zj是变量。

和的导数等于导数的和,对和中某个元素求导的话有:

(2) softmax_loss_layer.cpp中的Forward_cpu()函数:

  1. template <typename Dtype>
  2. void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
  3. const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  4. // The forward pass computes the softmax prob values.
    //调用softmax层的forward函数,得到对应的输出,存到prob_中
  5. softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
  6. const Dtype* prob_data = prob_.cpu_data();
    //一般loss层有两个输入blob,网络的predict blob(bottom[0])和label blob(bottom[1])
  7. const Dtype* label = bottom[]->cpu_data();
    //dim = N*C*H*W / N = C*H*W
  8. int dim = prob_.count() / outer_num_;
    //count变量是计算loss时的有效样本数
  9. int count = ;
  10. Dtype loss = ;
  11. for (int i = ; i < outer_num_; ++i) {
  12. for (int j = ; j < inner_num_; j++) {
    //读取label
  13. const int label_value = static_cast<int>(label[i * inner_num_ + j]);
    //如果该样本的label等于deploy中softmaxWithLoss中设定的参数ignore_label_,则该样本不参与前向和后向计算
  14. if (has_ignore_label_ && label_value == ignore_label_) {
  15. continue;
  16. }
    //判断label_value是否大于等于0
  17. DCHECK_GE(label_value, );
    //判断label_value是否小于prob_.shape(softmax_axis_)=C
  18. DCHECK_LT(label_value, prob_.shape(softmax_axis_));
    //对于softmax的输出channel,计算label_value索引对应的channel中prob的log.对应公式(1)
  19. loss -= log(std::max(prob_data[i * dim + label_value * inner_num_ + j],
  20. Dtype(FLT_MIN)));
    //有效样本数加一
  21. ++count;
  22. }
  23. }
    //最终在训练日志中显示的loss为计算的总loss除以有效样本数
  24. top[]->mutable_cpu_data()[] = loss / get_normalizer(normalization_, count);
  25. if (top.size() == ) {
  26. top[]->ShareData(prob_);
  27. }
  28. }

(3) softmax_loss_layer.cpp中的Backward_cpu函数:

  1. template <typename Dtype>
  2. void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
  3. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  4. if (propagate_down[]) {
  5. LOG(FATAL) << this->type()
  6. << " Layer cannot backpropagate to label inputs.";
  7. }
  8. if (propagate_down[]) {
  9. Dtype* bottom_diff = bottom[]->mutable_cpu_diff();
  10. const Dtype* prob_data = prob_.cpu_data();
    //将softmax的输出prob_复制给bottom[0]的diff(梯度) blob
  11. caffe_copy(prob_.count(), prob_data, bottom_diff);
  12. const Dtype* label = bottom[]->cpu_data();
  13. int dim = prob_.count() / outer_num_;
  14. int count = ;
  15. for (int i = ; i < outer_num_; ++i) {
  16. for (int j = ; j < inner_num_; ++j) {
  17. const int label_value = static_cast<int>(label[i * inner_num_ + j]);
  18. if (has_ignore_label_ && label_value == ignore_label_) {
  19. for (int c = ; c < bottom[]->shape(softmax_axis_); ++c) {
  20. bottom_diff[i * dim + c * inner_num_ + j] = ;
  21. }
  22. } else {
    //对应公式(2),在反传梯度时,label索引对应的diff减1,其他不变。
  23. bottom_diff[i * dim + label_value * inner_num_ + j] -= ;
  24. ++count;
  25. }
  26. }
  27. }
  28. // Scale gradient
    //top[0]->cpu_diff()[0] = N
    //N / count
  29. Dtype loss_weight = top[]->cpu_diff()[] /
  30. get_normalizer(normalization_, count);
  31. caffe_scal(prob_.count(), loss_weight, bottom_diff);
  32. }
  33. }

caffe中softmax loss源码阅读的更多相关文章

  1. caffe中batch norm源码阅读

    1. batch norm 输入batch norm层的数据为[N, C, H, W], 该层计算得到均值为C个,方差为C个,输出数据为[N, C, H, W]. <1> 形象点说,均值的 ...

  2. 【源码阅读】Java集合之三 - ArrayDeque源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,本文是第三篇ArrayDeque. ---@pdai JDK版本 ...

  3. 【源码阅读】Java集合之二 - LinkedList源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章; 本文是第二篇LinkedList. ---@pdai JDK版 ...

  4. 【源码阅读】Java集合之一 - ArrayList源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,从ArrayList开始第一篇. ---@pdai JDK版本 ...

  5. Caffe源码阅读(1) 全连接层

    Caffe源码阅读(1) 全连接层 发表于 2014-09-15   |   今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...

  6. caffe-windows中classification.cpp的源码阅读

    caffe-windows中classification.cpp的源码阅读 命令格式: usage: classification string(模型描述文件net.prototxt) string( ...

  7. 源码阅读笔记 - 1 MSVC2015中的std::sort

    大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...

  8. 源码阅读经验谈-slim,darknet,labelimg,caffe(1)

    本文首先谈自己的源码阅读体验,然后给几个案例解读,选的例子都是比较简单.重在说明我琢磨的点线面源码阅读方法.我不是专业架构师,是从一个深度学习算法工程师的角度来谈的,不专业的地方请大家轻拍. 经常看别 ...

  9. SpringMVC源码阅读:Controller中参数解析

    1.前言 SpringMVC是目前J2EE平台的主流Web框架,不熟悉的园友可以看SpringMVC源码阅读入门,它交代了SpringMVC的基础知识和源码阅读的技巧 本文将通过源码(基于Spring ...

随机推荐

  1. 任务分线程实现(java)

    1.创建一个类,用户存储信息 public class Users { private String userid; private String username; public Users() { ...

  2. PLC与上位机的socket通讯——上位机C#程序(二)

    C#的网口通信 一.命令行 客户端程序:using System;using System.Collections.Generic;using System.Linq;using System.Tex ...

  3. 微信小程序商城构建全栈应用 Thinkphp5

    课程——微信小程序商城构建全栈应用[目录]第1章 前言:不同的时代,不同的Web第2章 环境,工具与准备工作第3章 模块,路由与获取请求参数第4章 构建验证层第5章 REST与RESTFul第6章 A ...

  4. 006:CSS高级技巧

    目录 前言 理论 CSS高级技巧 一:元素的显示与隐藏 在CSS中有三个显示和隐藏的单词比较常见,我们要区分开,他们分别是 display visibility 和 overflow. 他们的主要目的 ...

  5. Containers vs Serverless:你选择谁,何时选择?

    两者都是当今技术时代的热门话题,也都被视为是开发技术的竞争对手. 首先,还有相当多的好奇和担心.此外,两者都是可供工程师使用的.高效的.机器无关的抽象. 但是,在冠军之间,有一个不可逾越的鸿沟.你要么 ...

  6. node连接数据库

    一.在package.json依赖模块添加: "mysql" : "latest",执行npm install: 二.module目录下新建mysql.js: ...

  7. ETL-Kettle学习笔记(入门,简介,简单操作)

    KETTLE Kettle:简介 ETL:简介 ETL(Extract-Transform-Load的缩写,即数据抽取.转换.装载的过程),对于企业或行业应用来说,我们经常会遇到各种数据的处理,转换, ...

  8. hihttps教你在Wireshark中提取旁路https解密源码

    大家好,我是hihttps,专注SSL web安全研究,今天本文就是教大家怎样从wireshark源码中,提取旁路https解密的源码,非常值得学习和商业应用. 一.旁路https解密条件 众所周知, ...

  9. 23种设计模式之原型模式(Prototype Pattern)

    原型模式 使用原型实例指定待创建对象的类型,并且通过复制这个原型来创建新的对象 分析: 孙悟空:根据自己的形状复制(克隆)出多个身外身 软件开发:通过复制一个原型对象得到多个与原型对象一模一样的新对象 ...

  10. <q> 与 <blockquote> 的区别

    <q> 标签在本质上与 <blockquote> 是一样的.不同之处在于它们的显示和应用.<q> 标签用于简短的行内引用.如果需要从周围内容分离出来比较长的部分(通 ...