A Loss Function for Learning Region Proposals

训练RPN时,只对两种anchor给予正标签:和gt_box有着最高的IoU && IoU超过0.7。如果对于

所有的gt_box,其IoU都小于0.3,则标记为负。损失函数定义如下:

其中i为一个mini-batch中某anchor的索引,pi表示其为目标的预测概率,pi*表示gt_box(正为1,否则为0)。

ti和ti*分别表示预测框的位置和gt_box框的位置。Lreg如下:

bound-box regression中各参数的计算方式为:

   (4)

其对应的SmoothL1LossLayer代码如下,整个过程分为两部分:前向计算以及后向计算(1)式的后半部分:

// ------------------------------------------------------------------
// Fast R-CNN
// Copyright (c) 2015 Microsoft
// Licensed under The MIT License [see fast-rcnn/LICENSE for details]
// Written by Ross Girshick
// ------------------------------------------------------------------ #include "caffe/fast_rcnn_layers.hpp" namespace caffe {
//SmoothL1前向计算(3)式
template <typename Dtype>
__global__ void SmoothL1Forward(const int n, const Dtype* in, Dtype* out,
Dtype sigma2) {
// f(x) = 0.5 * (sigma * x)^2 if |x| < 1 / sigma / sigma
// |x| - 0.5 / sigma / sigma otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1.0 / sigma2) {
out[index] = 0.5 * val * val * sigma2;
} else {
out[index] = abs_val - 0.5 / sigma2;
}
}
}
//
template <typename Dtype>
void SmoothL1LossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
int count = bottom[0]->count();
caffe_gpu_sub(
count,
bottom[0]->gpu_data(), //ti
bottom[1]->gpu_data(), //ti*
diff_.mutable_gpu_data()); // d := ti-ti*
if (has_weights_) { //乘上相关的权重,对应于(1)式中的pi*,有目标时为1
// apply "inside" weights
caffe_gpu_mul(
count,
bottom[2]->gpu_data(), //pi*
diff_.gpu_data(),
diff_.mutable_gpu_data()); // d := w_in * (b0 - b1)
}
//代入计算SmoothL1
SmoothL1Forward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, diff_.gpu_data(), errors_.mutable_gpu_data(), sigma2_);
CUDA_POST_KERNEL_CHECK; if (has_weights_) { //乘上相关的权重
// apply "outside" weights
caffe_gpu_mul(
count,
bottom[3]->gpu_data(), // 1/Nreg
errors_.gpu_data(),
errors_.mutable_gpu_data()); // d := w_out * SmoothL1(w_in * (b0 - b1))
} Dtype loss;
caffe_gpu_dot(count, ones_.gpu_data(), errors_.gpu_data(), &loss);
top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num();
}
//反向计算,对smoothLoss求导
template <typename Dtype>
__global__ void SmoothL1Backward(const int n, const Dtype* in, Dtype* out,
Dtype sigma2) {
// f'(x) = sigma * sigma * x if |x| < 1 / sigma / sigma
// = sign(x) otherwise
CUDA_KERNEL_LOOP(index, n) {
Dtype val = in[index];
Dtype abs_val = abs(val);
if (abs_val < 1.0 / sigma2) {
out[index] = sigma2 * val;
} else {
out[index] = (Dtype(0) < val) - (val < Dtype(0));
}
}
}
//
template <typename Dtype>
void SmoothL1LossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
// after forwards, diff_ holds w_in * (b0 - b1)
int count = diff_.count();
//调用反向smoothloss,diff_.gpu_data()表示x,diff_.mutable_gpu_data()表示smoothloss的导数
SmoothL1Backward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
count, diff_.gpu_data(), diff_.mutable_gpu_data(), sigma2_); //类似于前向
CUDA_POST_KERNEL_CHECK;
for (int i = 0; i < 2; ++i) {
if (propagate_down[i]) {
const Dtype sign = (i == 0) ? 1 : -1;
const Dtype alpha = sign * top[0]->cpu_diff()[0] / bottom[i]->num();
caffe_gpu_axpby(
count, // count
alpha, // alpha
diff_.gpu_data(), // x
Dtype(0), // beta
bottom[i]->mutable_gpu_diff()); // y
if (has_weights_) {
// Scale by "inside" weight
caffe_gpu_mul(
count,
bottom[2]->gpu_data(),
bottom[i]->gpu_diff(),
bottom[i]->mutable_gpu_diff());
// Scale by "outside" weight
caffe_gpu_mul(
count,
bottom[3]->gpu_data(),
bottom[i]->gpu_diff(),
bottom[i]->mutable_gpu_diff());
}
}
}
} INSTANTIATE_LAYER_GPU_FUNCS(SmoothL1LossLayer); } // namespace caffe
 

r-cnn学习(五):SmoothL1LossLayer论文与代码的结合理解的更多相关文章

  1. 10K+,深度学习论文、代码最全汇总!

    我们大部分人是如何查询和搜集深度学习相关论文的?绝大多数情况是根据关键字在谷歌.百度搜索.想寻找相关论文的复现代码又会去 GitHub 上搜索关键词.浪费了很多时间不说,论文.代码通常也不够完整.怎么 ...

  2. R语言学习笔记(五)绘图(1)

      R是一个惊艳的图形构建平台,这也是R语言的强大之处.本文将分享R语言简单的绘图命令.   本文所使用的数据或者来自R语言自带的数据(mtcars)或者自行创建.   首先,让我们来看一个简单例子: ...

  3. [ZZ]计算机视觉、机器学习相关领域论文和源代码大集合

    原文地址:[ZZ]计算机视觉.机器学习相关领域论文和源代码大集合作者:计算机视觉与模式 注:下面有project网站的大部分都有paper和相应的code.Code一般是C/C++或者Matlab代码 ...

  4. Context Encoder论文及代码解读

    经过秋招和毕业论文的折磨,提交完论文終稿的那一刻总算觉得有多余的时间来搞自己的事情. 研究论文做的是图像修复相关,这里对基于深度学习的图像修复方面的论文和代码进行整理,也算是研究生方向有一个比较好的结 ...

  5. Android JNI学习(五)——Demo演示

    本系列文章如下: Android JNI(一)——NDK与JNI基础 Android JNI学习(二)——实战JNI之“hello world” Android JNI学习(三)——Java与Nati ...

  6. TweenMax动画库学习(五)

    目录            TweenMax动画库学习(一)            TweenMax动画库学习(二)            TweenMax动画库学习(三)            Tw ...

  7. R基础学习

    R基础学习 The Art of R Programming 1.seq 产生等差数列:seq(from,to,by) seq(from,to,length) for(i in 1:length(x) ...

  8. SVG 学习<五> SVG动画

    目录 SVG 学习<一>基础图形及线段 SVG 学习<二>进阶 SVG世界,视野,视窗 stroke属性 svg分组 SVG 学习<三>渐变 SVG 学习<四 ...

  9. R语言学习 第四篇:函数和流程控制

    变量用于临时存储数据,而函数用于操作数据,实现代码的重复使用.在R中,函数只是另一种数据类型的变量,可以被分配,操作,甚至把函数作为参数传递给其他函数.分支控制和循环控制,和通用编程语言的风格很相似, ...

随机推荐

  1. 【C#】【Thread】SpinWait

    System.Threading.SpinWait 是一个轻量同步类型,可以在低级别方案中使用它来避免内核事件所需的高开销的上下文切换和内核转换. 在多核计算机上,当预计资源不会保留很长一段时间时,如 ...

  2. [转]用Whois获得电信运营商的IP地址是如何分配的?

    [转]用Whois获得电信运营商的IP地址是如何分配的? Linux下获得一些中国电信运营商的IP地址分配情况: APNIC是管理亚太地区IP地址分配的机构,它有着丰富准确的IP地址分配库,同时这些信 ...

  3. 2016网络大事记 mark

    记录2016年每天的大事件. 2016年01月07日     快播庭审.辩护人各种出彩. 2016年01月09日     乐视多个贴吧被爆.百度出面平息. 2016年01月10日     斗鱼TV造人 ...

  4. CSS强制性换行

    一般情况下,元素拥有默认的white-space:normal(自动换行,PS:不 换行是white-space:nowrap),当录入的文字超过定义的宽度后会自动换行,但当录入的数据是一堆没有空格的 ...

  5. Centos7下安装python,查看python版本

    安装Centos的时候,一般会自带默认安装python2.x 一般用python -V可以查看python版本. 我当时安装的时候,运行了那个语句,但是却显示了一大堆出来,虽然里面也带有版本信息,但是 ...

  6. UI: 窗口全屏, 窗口尺寸

    窗口全屏 窗口尺寸 示例1.窗口全屏UI/FullScreen.xaml <Page x:Class="Windows10.UI.FullScreen" xmlns=&quo ...

  7. 点击某个按钮在tableView某个位置动态插入一行cell

    实现步骤: 1.修改数据模型数组 给模型数组的某个位置增加一个模型 2.执行以下代码 NSIndexPath *indexPath = [NSIndexPath indexPathForRow: in ...

  8. ASP.NET MVC中viewData、viewBag和templateData的使用与区别

    一:类型比较 1.1)ViewBag是动态类型(dynamic). 1.2)ViewData是一个字典型的(Dictionary)-->ViewDataDictionary. 1.3)TempD ...

  9. POJ2763 Housewife Wind

      Time Limit: 4000MS   Memory Limit: 65536K Total Submissions: 9701   Accepted: 2661 Description Aft ...

  10. 移动端设置-----rem

    对于现在不同尺寸的移动端屏幕,如果设置px来说实在有点影响用户体验,在小屏幕上太大,大屏幕上太小,不能实现响应式,所以就引进了rem的概念. rem是相对于根元素<html> 在我的项目中 ...