因为要修改Caffe crop layer GPU部分的代码,现将自己对这部分GPU代码的理解总结一下,请大家多多指教!

crop layer完成的功能(以matlab的方式表示):A(N,C,H,W),Reference(n,c,h,w),Offsets(o1, o2, o3,o4), croped_A=A[o1:o1+n, o2:o2+c, o3:o3+h, o4:o4+w]

先代码,后解释

  1. #include <vector>
  2.  
  3. #include "caffe/layers/crop_layer.hpp"
  4.  
  5. namespace caffe {
  6.  
  7. __device__ int compute_uncropped_index(
  8. int index,
  9. const int ndims,
  10. const int* src_strides,
  11. const int* dest_strides,
  12. const int* offsets) {
  13. int dest_index = index;
  14. int src_index = ;
  15. for (int i = ; i < ndims; ++i) {
  16. int coord = dest_index / dest_strides[i];
  17. dest_index -= coord * dest_strides[i];
  18. src_index += src_strides[i] * (coord + offsets[i]);
  19. }
  20. return src_index;
  21. }
  22.  
  23. template <typename Dtype>
  24. __global__ void crop_kernel_forward(const int nthreads,
  25. const int ndims,
  26. const int* src_strides,
  27. const int* dest_strides,
  28. const int* offsets,
  29. const Dtype* src, Dtype* dest) {
  30. CUDA_KERNEL_LOOP(index, nthreads) {
  31. int src_index = compute_uncropped_index(
  32. index, ndims, src_strides, dest_strides, offsets);
  33. dest[index] = src[src_index];
  34. }
  35. }
  36.  
  37. template <typename Dtype>
  38. __global__ void crop_kernel_backward(const int nthreads,
  39. const int ndims,
  40. const int* src_strides,
  41. const int* dest_strides,
  42. const int* offsets,
  43. Dtype* src, const Dtype* dest) {
  44. CUDA_KERNEL_LOOP(index, nthreads) {
  45. int src_index = compute_uncropped_index(
  46. index, ndims, src_strides, dest_strides, offsets);
  47. src[src_index] = dest[index];
  48. }
  49. }
  50.  
  51. template <typename Dtype>
  52. void CropLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
  53. const vector<Blob<Dtype>*>& top) {
  54. const Dtype* bottom_data = bottom[]->gpu_data();
  55. Dtype* top_data = top[]->mutable_gpu_data();
  56. int n = top[]->count();
  57. // NOLINT_NEXT_LINE(whitespace/operators)
  58. crop_kernel_forward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n,
  59. bottom[]->num_axes(),
  60. src_strides_.gpu_data(),
  61. dest_strides_.gpu_data(),
  62. offsets.gpu_data(),
  63. bottom_data, top_data);
  64. }
  65.  
  66. template <typename Dtype>
  67. void CropLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
  68. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
  69. const Dtype* top_diff = top[]->gpu_diff();
  70. Dtype* bottom_diff = bottom[]->mutable_gpu_diff();
  71. int n = top[]->count();
  72.  
  73. if (propagate_down[]) {
  74. caffe_gpu_set(bottom[]->count(), static_cast<Dtype>(), bottom_diff);
  75. // NOLINT_NEXT_LINE(whitespace/operators)
  76. crop_kernel_backward<<<CAFFE_GET_BLOCKS(n), CAFFE_CUDA_NUM_THREADS>>>(n,
  77. bottom[]->num_axes(),
  78. src_strides_.gpu_data(),
  79. dest_strides_.gpu_data(),
  80. offsets.gpu_data(),
  81. bottom_diff, top_diff);
  82. }
  83. }
  84.  
  85. INSTANTIATE_LAYER_GPU_FUNCS(CropLayer);
  86.  
  87. } // namespace caffe

我将分析的重点放在Forward_gpu函数上,该函数在获取bottom、top data的指针之后,调用GPU端程序crop_kernel_forward。

其参数含义如下:

  • nthreads: nxcxhxw
  • ndims:4
  • src_strides: (CxHxW,HxW,W,1)
  • dest_strides:(cxhxw,hxw,w,1)
  • offsets:(o1, o2, o3, o4)
  • src:源指针
  • dest:目的指针

可以理解为src是A矩阵,dest就是我们需要的croped_A矩阵

crop_kernel_forward函数将每一个数据影射到一个线程,先计算通过compute_uncropped_index函数计算src_index,然后进行赋值。这里的重点是compute_uncropped_index,下面我通过函数注释的方式解析一下该函数的具体含义。

  1. __device__ int compute_uncropped_index(
  2. int index,
  3. const int ndims,
  4. const int* src_strides,
  5. const int* dest_strides,
  6. const int* offsets) {
  7. int dest_index = index; //将线程号赋给dest_index
  8. int src_index = ; //初始化src_index
  9. for (int i = ; i < ndims; ++i) { //每个维度分别处理
  10. int coord = dest_index / dest_strides[i];//coord表示dest第i个维度的坐标
  11. dest_index -= coord * dest_strides[i];//消除第i维坐标的影响
  12. src_index += src_strides[i] * (coord + offsets[i]);//coord和offsets[i]在src_index引入的偏移
  13. }
  14. return src_index;
  15. }

注释可能解释的比较含糊,可以简单理解为“给定一个index,获取dest对应的坐标(n’,c’,h’,w’),然后加上offsets偏移量,分别乘以不同坐标对应步长获取dest在src中的对应位置索引”。

Caffe代码分析--crop_layer.cu的更多相关文章

  1. caffe源代码分析--math_functions.cu代码研究

    当中用到一个宏定义CUDA_KERNEL_LOOP 在common.hpp中有. #defineCUDA_KERNEL_LOOP(i,n) \ for(inti = blockIdx.x * bloc ...

  2. angular代码分析之异常日志设计

    angular代码分析之异常日志设计 错误异常是面向对象开发中的记录提示程序执行问题的一种重要机制,在程序执行发生问题的条件下,异常会在中断程序执行,同时会沿着代码的执行路径一步一步的向上抛出异常,最 ...

  3. Caffe CommonLayer分析

    Caffe CommonLayer分析 \(Caffe\)中包含了很多通用的功能层,包含了\(concat\),\(slice\),\(split\),\(crop\),\(flip\),\(scal ...

  4. Android代码分析工具lint学习

    1 lint简介 1.1 概述 lint是随Android SDK自带的一个静态代码分析工具.它用来对Android工程的源文件进行检查,找出在正确性.安全.性能.可使用性.可访问性及国际化等方面可能 ...

  5. pmd静态代码分析

    在正式进入测试之前,进行一定的静态代码分析及code review对代码质量及系统提高是有帮助的,以上为数据证明 Pmd 它是一个基于静态规则集的Java源码分析器,它可以识别出潜在的如下问题:– 可 ...

  6. [Asp.net 5] DependencyInjection项目代码分析-目录

    微软DI文章系列如下所示: [Asp.net 5] DependencyInjection项目代码分析 [Asp.net 5] DependencyInjection项目代码分析2-Autofac [ ...

  7. [Asp.net 5] DependencyInjection项目代码分析4-微软的实现(5)(IEnumerable<>补充)

    Asp.net 5的依赖注入注入系列可以参考链接: [Asp.net 5] DependencyInjection项目代码分析-目录 我们在之前讲微软的实现时,对于OpenIEnumerableSer ...

  8. 完整全面的Java资源库(包括构建、操作、代码分析、编译器、数据库、社区等等)

    构建 这里搜集了用来构建应用程序的工具. Apache Maven:Maven使用声明进行构建并进行依赖管理,偏向于使用约定而不是配置进行构建.Maven优于Apache Ant.后者采用了一种过程化 ...

  9. STM32启动代码分析 IAR 比较好

    stm32启动代码分析 (2012-06-12 09:43:31) 转载▼     最近开始使用ST的stm32w108芯片(也是一款zigbee芯片).开始看他的启动代码看的晕晕呼呼呼的. 还好在c ...

随机推荐

  1. win8.1启用ahci后蓝屏

    先简单介绍一下,本应该win7开始,系统安装的时候默认就启用了ahci硬盘模式.但是博主犯了傻,装了win8.1后安装win XP形成双系统.xp并不支持ahci模式,所以将硬盘模式改成了IDE模式, ...

  2. JQuery插件之Animate.css和 jquery-aniview

    Animate.css 一款强大的预设css3动画库 简介 animate.css 是一个来自国外的 CSS3 动画库,它预设了抖动(shake).闪烁(flash).弹跳(bounce).翻转(fl ...

  3. JS组件系列——自己动手封装bootstrap-treegrid组件

    前言:最近产品需要设计一套相对完整的组织架构的解决方案,由于组织架构涉及到层级关系,在表格里面展示层级关系,自然就要用到所谓的treegrid.可惜的是,一些轻量级的表格组件本身并没有自带树形表格的功 ...

  4. 统计数据方面SQL与HQL

    因为HQL是面向对象的,所以对于统计数据方面使用HQL时不合适的,其实HQL最终还是会转化成SQL语句,项目里使用HQL语句应该是为了标准规范化. 统计的数据:同一个表,同一个字段,不同属性,统计不同 ...

  5. 十分钟彻底理解javascript 的 this指向,不懂请砸店

    函数的this指向谁,和函数在哪里被定义的,函数在哪里被执行的没有半毛钱关系,只遵守下面的规律: 在非严格模式中: 1.自执行函数里面,this永远指向window; <script> v ...

  6. TP5常用命令符操作

    ThinkPHP5常用命令符操作   1. 模块自动生成指令:   默认会读取应用目录application下面的build.php作为自动   生成的定义文件,如果你的定义文件位置不同,则需要使用 ...

  7. javascript 六种数据类型(一)

    js的数据类型和常见隐式转化逻辑. 一.六种数据类型 原始类型(基本类型):按值访问,可以操作保存在变量中实际的值.原始类型汇总中null和undefined比较特殊. 引用类型:引用类型的值是保存在 ...

  8. trait技术详解,这次包你学得会

    trait的使用技巧trait是php5.4以后新增加的一个功能,可以将多个类中,共用的一些属性和方法提取出来做来公共trait类,就像是装配汽车的配件,如果你的类中要用到这些配件,就直接用use导入 ...

  9. Java发布一个简单 webservice应用 并发送SOAP请求

    一.创建并发布一个简单的webservice应用 1.webservice 代码: package com.ls.demo; import javax.jws.WebMethod; import ja ...

  10. poj2481 Cows 树状数组

    题目链接:http://poj.org/problem?id=2481 解题思路: 这道题对每组数据进行查询,是树状数组的应用.对于二维的树状数组, 首先想到排序.现在对输入的数据按右值从大到小排序, ...