(1) softmax loss

<1> softmax loss的函数形式为:

    (1)

zi为softmax的输入,f(zi)为softmax的输出。

<2> softmax loss对其输入zj求导:

     (2)

如果j==k,则zk是变量,否则zj是变量。

和的导数等于导数的和,对和中某个元素求导的话有:

(2) softmax_loss_layer.cpp中的Forward_cpu()函数:

 template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Forward_cpu(
const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
// The forward pass computes the softmax prob values.
//调用softmax层的forward函数,得到对应的输出,存到prob_中
softmax_layer_->Forward(softmax_bottom_vec_, softmax_top_vec_);
const Dtype* prob_data = prob_.cpu_data();
//一般loss层有两个输入blob,网络的predict blob(bottom[0])和label blob(bottom[1])
const Dtype* label = bottom[]->cpu_data();
//dim = N*C*H*W / N = C*H*W
int dim = prob_.count() / outer_num_;
//count变量是计算loss时的有效样本数
int count = ;
Dtype loss = ;
for (int i = ; i < outer_num_; ++i) {
for (int j = ; j < inner_num_; j++) {
//读取label
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
//如果该样本的label等于deploy中softmaxWithLoss中设定的参数ignore_label_,则该样本不参与前向和后向计算
if (has_ignore_label_ && label_value == ignore_label_) {
continue;
}
//判断label_value是否大于等于0
DCHECK_GE(label_value, );
//判断label_value是否小于prob_.shape(softmax_axis_)=C
DCHECK_LT(label_value, prob_.shape(softmax_axis_));
//对于softmax的输出channel,计算label_value索引对应的channel中prob的log.对应公式(1)
loss -= log(std::max(prob_data[i * dim + label_value * inner_num_ + j],
Dtype(FLT_MIN)));
//有效样本数加一
++count;
}
}
//最终在训练日志中显示的loss为计算的总loss除以有效样本数
top[]->mutable_cpu_data()[] = loss / get_normalizer(normalization_, count);
if (top.size() == ) {
top[]->ShareData(prob_);
}
}

(3) softmax_loss_layer.cpp中的Backward_cpu函数:

 template <typename Dtype>
void SoftmaxWithLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
if (propagate_down[]) {
LOG(FATAL) << this->type()
<< " Layer cannot backpropagate to label inputs.";
}
if (propagate_down[]) {
Dtype* bottom_diff = bottom[]->mutable_cpu_diff();
const Dtype* prob_data = prob_.cpu_data();
//将softmax的输出prob_复制给bottom[0]的diff(梯度) blob
caffe_copy(prob_.count(), prob_data, bottom_diff);
const Dtype* label = bottom[]->cpu_data();
int dim = prob_.count() / outer_num_;
int count = ;
for (int i = ; i < outer_num_; ++i) {
for (int j = ; j < inner_num_; ++j) {
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
if (has_ignore_label_ && label_value == ignore_label_) {
for (int c = ; c < bottom[]->shape(softmax_axis_); ++c) {
bottom_diff[i * dim + c * inner_num_ + j] = ;
}
} else {
//对应公式(2),在反传梯度时,label索引对应的diff减1,其他不变。
bottom_diff[i * dim + label_value * inner_num_ + j] -= ;
++count;
}
}
}
// Scale gradient
//top[0]->cpu_diff()[0] = N
//N / count
Dtype loss_weight = top[]->cpu_diff()[] /
get_normalizer(normalization_, count);
caffe_scal(prob_.count(), loss_weight, bottom_diff);
}
}

caffe中softmax loss源码阅读的更多相关文章

  1. caffe中batch norm源码阅读

    1. batch norm 输入batch norm层的数据为[N, C, H, W], 该层计算得到均值为C个,方差为C个,输出数据为[N, C, H, W]. <1> 形象点说,均值的 ...

  2. 【源码阅读】Java集合之三 - ArrayDeque源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,本文是第三篇ArrayDeque. ---@pdai JDK版本 ...

  3. 【源码阅读】Java集合之二 - LinkedList源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章; 本文是第二篇LinkedList. ---@pdai JDK版 ...

  4. 【源码阅读】Java集合之一 - ArrayList源码深度解读

    Java 源码阅读的第一步是Collection框架源码,这也是面试基础中的基础: 针对Collection的源码阅读写一个系列的文章,从ArrayList开始第一篇. ---@pdai JDK版本 ...

  5. Caffe源码阅读(1) 全连接层

    Caffe源码阅读(1) 全连接层 发表于 2014-09-15   |   今天看全连接层的实现.主要看的是https://github.com/BVLC/caffe/blob/master/src ...

  6. caffe-windows中classification.cpp的源码阅读

    caffe-windows中classification.cpp的源码阅读 命令格式: usage: classification string(模型描述文件net.prototxt) string( ...

  7. 源码阅读笔记 - 1 MSVC2015中的std::sort

    大约寒假开始的时候我就已经把std::sort的源码阅读完毕并理解其中的做法了,到了寒假结尾,姑且把它写出来 这是我的第一篇源码阅读笔记,以后会发更多的,包括算法和库实现,源码会按照我自己的代码风格格 ...

  8. 源码阅读经验谈-slim,darknet,labelimg,caffe(1)

    本文首先谈自己的源码阅读体验,然后给几个案例解读,选的例子都是比较简单.重在说明我琢磨的点线面源码阅读方法.我不是专业架构师,是从一个深度学习算法工程师的角度来谈的,不专业的地方请大家轻拍. 经常看别 ...

  9. SpringMVC源码阅读:Controller中参数解析

    1.前言 SpringMVC是目前J2EE平台的主流Web框架,不熟悉的园友可以看SpringMVC源码阅读入门,它交代了SpringMVC的基础知识和源码阅读的技巧 本文将通过源码(基于Spring ...

随机推荐

  1. 004:CSS三大重点之二:浮动(拖标、不占位置、转行内块)

    目录 1:拖标.不占位.转行内块 2:首先浮动的盒子需要和标准流的父级搭配使用,其次 特别的注意浮动可以使元素显示模式体现为行内块特性. 3:清除浮动 前言 CSS的定位机制有3种:普通流(标准流). ...

  2. 浅谈 Vector

    目录 浅谈Vector 1.容器基本操作 2.vector 初始化 3.vector的赋值与swap 4.vector的增删改除 1.增加元素 2.访问元素 3.删除元素 4.元素的大小 浅谈Vect ...

  3. 在wxml中直接写js代码(wxs)

    我们在h5开发中,很多时候要在html中写到js代码,这个很容易实现.但是在微信小程序开发中,是不能直接在wxml中写js代码的,因此就有了wxs.在wxml中用wxs代码,有以下几种方式(在小程序文 ...

  4. PyCharm设置自己的默认模板

    1.File-Settings 2.Editor- Code Style - File and Code Templates - Python Script 需要设置什么内容,现在就可以写入了,相关变 ...

  5. 【面试题】Java集合部分面试题

    集合与数组? 数组:(可以存储基本数据类型)是用来存储对象的一种容器,但是数组的长度固定,不适合在对象数量未知的情况下使用 集合:(只能存储对象,对象类型可以不一样)集合的长度可变,可在多数情况下使用 ...

  6. redis 主从复制和哨兵模式(二)

    Redis 主从复制 为了分担单机 redis 的数据服务压力,需要进行读写分离,所以搭建 redis 的主从结构,主节点负责写,从节点负责读,主节点定期把数据同步到从节点. 配置主从 # 配置文件中 ...

  7. 一个简单的MyBatis项目

    1.log4j.properties,我们把它设为debug级别,以便于调试.生产环境可以设为INFO,本项目放在src下面: # Global logging configuration log4j ...

  8. 从壹开始 [Admin] 之五 ║ 实现『按钮』级别权限配置

    一.前情回顾 哈喽大家好,在这个欢庆的日子里,老张祝大家工作都能蒸蒸日上!今天正好也是社团成立的第一天,我也是希望今天能是个纪念日,沾沾这个大喜庆! 放假这两天,倒是学到了很多东西,我这个也是承认的, ...

  9. layui-table 对表格数据进行处理之后的排序问题

    使用layui table过程中,将某一列的数据格式进行转换,或者将0/1状态改为是/否,或者将数字改为星星评分显示的时候都会遇到一个问题,我的表格数据转换成其他形式,同时设置了sort:true,此 ...

  10. profile文件的错误加载与基本命令间的映射

    一.绪论 [因为这篇心得是原创的,所以如果有哪处总结或者意见不足的地方,欢迎各位大神的批评和意见,共同学习,谢谢了!] 早些时候,需要在centos6.4系统中配置单机版和集群版单节点的hadoop ...