表达式模板是Eigen、GSL和boost.uBLAS等高性能C++矩阵库的核心技术。本文基于MXNet给出的教程文档来阐述MXNet所依赖的高性能矩阵库MShadow背后的原理。

编写高效的机器学习代码

我们先来思考一个问题:如何才能编写出高效的机器学习代码?假设DNN模型按照下面的代码进行权重更新,其中weightgrad都是长度为n的vector:

  1. weight = -eta * (grad + lambda * weight)

既然我们选择C++来实现矩阵计算,那么性能肯定是最先要考虑的因素。在C++编程中,一个重要的原则就是——预先分配好所需的内存,不要在运行时申请分配临时内存。因此,我们可能会实现一个名为UpdateWeight的函数,其形参gradweight均为指针:

  1. void UpdateWeight(const float *grad, float eta, float lambda,
  2. int n, float *weight) {
  3. for (int i = 0; i < n; ++i) {
  4. weight[i] = -eta * (grad[i] + lambda * weight[i]);
  5. }
  6. }

指针gradweight所指向的内存空间都是预先分配好的,函数在运行期只需要执行计算。显然,上面的代码非常简单直观,但是,当我们反复编写它们时可能会很烦(指编写循环处理每个元素)。因此问题是,我们能否像下面这样编写代码,同时又能获得上述代码的性能呢?答案是肯定的,但是方法可能并没有那么直观。

  1. void UpdateWeight(const Vec& grad, float eta, float lambda, Vec& weight) {
  2. weight = -eta * (grad + lambda * weight);
  3. }

运算符重载

运算符重载是一种非常容易想到的解决方案。通过重载相应的运算符,我们可以将元素处理的细节隐藏在运算符中,简单地调用运算符就可以实现相应的操作。

  1. // Naive solution for vector operation overloading
  2. struct Vec {
  3. int len;
  4. float* dptr;
  5. Vec(int len) : len(len) {
  6. dptr = new float[len];
  7. }
  8. Vec(const Vec& src) : len(src.len) {
  9. dptr = new float[len];
  10. memcpy(dptr, src.dptr, sizeof(float)*len );
  11. }
  12. ~Vec(void) {
  13. delete [] dptr;
  14. }
  15. };
  16. inline Vec operator+(const Vec &lhs, const Vec &rhs) {
  17. Vec res(lhs.len);
  18. for (int i = 0; i < lhs.len; ++i) {
  19. res.dptr[i] = lhs.dptr[i] + rhs.dptr[i];
  20. }
  21. return res;
  22. }

然而,这种方法并不高效,原因是每次调用运算符时都会有内存空间的申请和释放。另一种更高效的方法是仅重载运算符+=和-=,他们无需临时内存分配即可实现, 但这又限制了我们可以调用的运算符的数量,得不偿失。下一小节,我们将介绍如何利用表达式模板实现延迟计算。

延迟计算

在调用operator+时,因为我们不知道运算符的结果要赋值给哪个变量,所以需要申请一块临时内存空间把结果保存下来。否则,如果我们能提前直到运算结果要存放在哪个变量中,那么就可以直接将结果存储到相应的内存空间。下面的代码说明了这一情况:

  1. // Example Lazy evaluation code
  2. // for simplicity, we use struct and make all members public
  3. #include <cstdio>
  4. struct Vec;
  5. // expression structure holds the expression
  6. struct BinaryAddExp {
  7. const Vec &lhs;
  8. const Vec &rhs;
  9. BinaryAddExp(const Vec &lhs, const Vec &rhs)
  10. : lhs(lhs), rhs(rhs) {}
  11. };
  12. // no constructor and destructor to allocate and de-allocate memory,
  13. // allocation done by user
  14. struct Vec {
  15. int len;
  16. float* dptr;
  17. Vec(void) {}
  18. Vec(float *dptr, int len)
  19. : len(len), dptr(dptr) {}
  20. // here is where evaluation happens
  21. inline Vec &operator=(const BinaryAddExp &src) {
  22. for (int i = 0; i < len; ++i) {
  23. dptr[i] = src.lhs.dptr[i] + src.rhs.dptr[i];
  24. }
  25. return *this;
  26. }
  27. };
  28. // no evaluation happens here
  29. inline BinaryAddExp operator+(const Vec &lhs, const Vec &rhs) {
  30. return BinaryAddExp(lhs, rhs);
  31. }
  32. const int n = 3;
  33. int main(void) {
  34. float sa[n] = {1, 2, 3};
  35. float sb[n] = {2, 3, 4};
  36. float sc[n] = {3, 4, 5};
  37. Vec A(sa, n), B(sb, n), C(sc, n);
  38. // run expression
  39. A = B + C;
  40. for (int i = 0; i < n; ++i) {
  41. printf("%d:%f==%f+%f\\n", i, A.dptr[i], B.dptr[i], C.dptr[i]);
  42. }
  43. return 0;
  44. }

在这段代码中,运算符operator+不进行实际的计算,只返回一个表达式结构BinaryAddExp,它里面保存了进行向量加法的两个操作数。当重载operator=时,我们就可以知道向量加法的目标变量以及对应的两个操作数,因此,在这种情况下,不需要任何(运行时)内存分配就可以执行计算操作!类似的,我们可以定义一个DotExp并在operator=处进行惰性求值,然后在内部调用BLAS实现矩阵乘法。

更复杂的表达式以及表达式模板

使用延迟计算能够避免运行期的临时内存分配。但是,上一小节的代码仍然面临以下两个问题:

  • 只能编写形如A=B+C的表达式,不能编写类似A=B+C+D等更加复杂的表达式
  • 如果想添加更多的表达式,就需要编写更多的operator=来执行相应的计算

上述问题的解决方法就是使用模板编程。我们将BinaryAddExp实现成一个模板类,它保存的两个操作数都是模板,这样就能够实现任意长度的加法表达式,具体代码如下。

  1. // Example code, expression template, and more length equations
  2. // for simplicity, we use struct and make all members public
  3. #include <cstdio>
  4. // this is expression, all expressions must inheritate it,
  5. // and put their type in subtype
  6. template<typename SubType>
  7. struct Exp {
  8. // returns const reference of the actual type of this expression
  9. inline const SubType& self(void) const {
  10. return *static_cast<const SubType*>(this);
  11. }
  12. };
  13. // binary add expression
  14. // note how it is inheritates from Exp
  15. // and put its own type into the template argument
  16. template<typename TLhs, typename TRhs>
  17. struct BinaryAddExp: public Exp<BinaryAddExp<TLhs, TRhs> > {
  18. const TLhs &lhs;
  19. const TRhs &rhs;
  20. BinaryAddExp(const TLhs& lhs, const TRhs& rhs)
  21. : lhs(lhs), rhs(rhs) {}
  22. // evaluation function, evaluate this expression at position i
  23. inline float Eval(int i) const {
  24. return lhs.Eval(i) + rhs.Eval(i);
  25. }
  26. };
  27. // no constructor and destructor to allocate
  28. // and de-allocate memory, allocation done by user
  29. struct Vec: public Exp<Vec> {
  30. int len;
  31. float* dptr;
  32. Vec(void) {}
  33. Vec(float *dptr, int len)
  34. :len(len), dptr(dptr) {}
  35. // here is where evaluation happens
  36. template<typename EType>
  37. inline Vec& operator= (const Exp<EType>& src_) {
  38. const EType &src = src_.self();
  39. for (int i = 0; i < len; ++i) {
  40. dptr[i] = src.Eval(i);
  41. }
  42. return *this;
  43. }
  44. // evaluation function, evaluate this expression at position i
  45. inline float Eval(int i) const {
  46. return dptr[i];
  47. }
  48. };
  49. // template add, works for any expressions
  50. template<typename TLhs, typename TRhs>
  51. inline BinaryAddExp<TLhs, TRhs>
  52. operator+(const Exp<TLhs> &lhs, const Exp<TRhs> &rhs) {
  53. return BinaryAddExp<TLhs, TRhs>(lhs.self(), rhs.self());
  54. }
  55. const int n = 3;
  56. int main(void) {
  57. float sa[n] = {1, 2, 3};
  58. float sb[n] = {2, 3, 4};
  59. float sc[n] = {3, 4, 5};
  60. Vec A(sa, n), B(sb, n), C(sc, n);
  61. // run expression, this expression is longer:)
  62. A = B + C + C;
  63. for (int i = 0; i < n; ++i) {
  64. printf("%d:%f == %f + %f + %f\\n", i,
  65. A.dptr[i], B.dptr[i],
  66. C.dptr[i], C.dptr[i]);
  67. }
  68. return 0;
  69. }

代码的关键思想是模板Exp<SubType>将其派生类SubType的类型作为模板参数,因此它可以通过self()将其自身转换为SubType,这种模式被称为奇异递归模板模式,简称CRTPBinaryAddExp现在是一个模板类,可以将多个表达式复合在一起。真正的计算操作是通过函数Eval完成的,该函数在BinaryAddExp中以递归方式实现,operator=中的函数调用src.Eval(i)会被编译成B.dptr[i] + C.dptr[i] + C.dptr[i]

灵活性

前面的示例让我们领略到模板编程的强大功能,而这最后一个示例则与MShadow的实现更为接近,它允许用户自定义二元操作符。

  1. // Example code, expression template
  2. // with binary operator definition and extension
  3. // for simplicity, we use struct and make all members public
  4. #include <cstdio>
  5. // this is expression, all expressions must inheritate it,
  6. // and put their type in subtype
  7. template<typename SubType>
  8. struct Exp{
  9. // returns const reference of the actual type of this expression
  10. inline const SubType& self(void) const {
  11. return *static_cast<const SubType*>(this);
  12. }
  13. };
  14. // binary operators
  15. struct mul{
  16. inline static float Map(float a, float b) {
  17. return a * b;
  18. }
  19. };
  20. // binary add expression
  21. // note how it is inheritates from Exp
  22. // and put its own type into the template argument
  23. template<typename OP, typename TLhs, typename TRhs>
  24. struct BinaryMapExp: public Exp<BinaryMapExp<OP, TLhs, TRhs> >{
  25. const TLhs& lhs;
  26. const TRhs& rhs;
  27. BinaryMapExp(const TLhs& lhs, const TRhs& rhs)
  28. :lhs(lhs), rhs(rhs) {}
  29. // evaluation function, evaluate this expression at position i
  30. inline float Eval(int i) const {
  31. return OP::Map(lhs.Eval(i), rhs.Eval(i));
  32. }
  33. };
  34. // no constructor and destructor to allocate and de-allocate memory
  35. // allocation done by user
  36. struct Vec: public Exp<Vec>{
  37. int len;
  38. float* dptr;
  39. Vec(void) {}
  40. Vec(float *dptr, int len)
  41. : len(len), dptr(dptr) {}
  42. // here is where evaluation happens
  43. template<typename EType>
  44. inline Vec& operator=(const Exp<EType>& src_) {
  45. const EType &src = src_.self();
  46. for (int i = 0; i < len; ++i) {
  47. dptr[i] = src.Eval(i);
  48. }
  49. return *this;
  50. }
  51. // evaluation function, evaluate this expression at position i
  52. inline float Eval(int i) const {
  53. return dptr[i];
  54. }
  55. };
  56. // template binary operation, works for any expressions
  57. template<typename OP, typename TLhs, typename TRhs>
  58. inline BinaryMapExp<OP, TLhs, TRhs>
  59. F(const Exp<TLhs>& lhs, const Exp<TRhs>& rhs) {
  60. return BinaryMapExp<OP, TLhs, TRhs>(lhs.self(), rhs.self());
  61. }
  62. template<typename TLhs, typename TRhs>
  63. inline BinaryMapExp<mul, TLhs, TRhs>
  64. operator*(const Exp<TLhs>& lhs, const Exp<TRhs>& rhs) {
  65. return F<mul>(lhs, rhs);
  66. }
  67. // user defined operation
  68. struct maximum{
  69. inline static float Map(float a, float b) {
  70. return a > b ? a : b;
  71. }
  72. };
  73. const int n = 3;
  74. int main(void) {
  75. float sa[n] = {1, 2, 3};
  76. float sb[n] = {2, 3, 4};
  77. float sc[n] = {3, 4, 5};
  78. Vec A(sa, n), B(sb, n), C(sc, n);
  79. // run expression, this expression is longer:)
  80. A = B * F<maximum>(C, B);
  81. for (int i = 0; i < n; ++i) {
  82. printf("%d:%f == %f * max(%f, %f)\\n",
  83. i, A.dptr[i], B.dptr[i], C.dptr[i], B.dptr[i]);
  84. }
  85. return 0;
  86. }

这段代码与上一小节代码的主要区别是模板类BinaryMapExp可以接受任意类型的二元操作符,要求是该操作符必须要实现一个Map函数。个人理解,第62行实现的那个F函数主要是为了编写代码方便,如果没有它,那么第69行就要写成BinaryMapExp<mul, TLhs, TRhs>(lhs.self(), rhs.self());,写起来就比较麻烦。其他的地方基本上与前一小节的代码差不多,稍微一看就能明白。

小结

综上所述,表达式模板基本工作原理包括以下几点:

  • 延迟计算,允许我们提前知道运算符和目标变量
  • 组合模板与递归计算,允许我们执行任意element-wise操作的复合表达式
  • 由于模板和内联的存在,表达式模板能够像编写for循环那样高效的实现element-wise的计算

MShadow中的表达式模板

MShadow中的表达式模板的原理与文中介绍的基本一致,但实现上还是有一些微小的差别:

  • 将计算代码和表达式构造相分离

    • 没有把Eval函数实现在Exp类中,而是根据表达式创建一个Plan类,并用它计算结果
    • 这样做的一个目的是减少Plan类中的私有变量数量,比如不需要知道数组的长度就可以计算结果
    • 另一个原因是CUDA kernel不能处理包含const reference的类
    • 这种设计值得商榷,但是目前很有用
  • 延迟计算支持复杂的表达式,例如矩阵点乘
    • 除了element-wise的表达式,MShadow还计算实现形如A = dot(B.T(), C)的语法糖。
  • 支持类型检查和数组长度检查

后记

C++11中引入了移动构造函数,可用于保存重复分配的内存,从而消除了一些需要用到表达式模板的情况。然而,内存空间仍然至少需要被分配一次。

  • This only removes the need of expression template then expression generate space, say dst = A + B + C, dst does not contain space allocated before assignment. (这句话没有理解它的意思,先把原文放这里吧)
  • 如果想要保留一切都是预先分配的这种syntax,并且表达式无需内存分配即可执行(这就是MShadow所做的事情),那么仍然需要表达式模板。

MShadow中的表达式模板的更多相关文章

  1. Expression Template(表达式模板,ET)

    1.前言 在前一篇文章自己实现简单的string类中提到在实现+操作符重载函数时,为了防止返回时生成的临时对象调用拷贝构造函数动态申请内存空间,使用了一个叫move的函数,它是C++0x新增的特性.既 ...

  2. Win 10 开发中Adaptive磁贴模板的XML文档结构,Win10 应用开发中自适应Toast通知的XML文档结构

    分享两篇Win 10应用开发的XML文档结构:Win 10 开发中Adaptive磁贴模板的XML文档结构,Win10 应用开发中自适应Toast通知的XML文档结构. Win 10 开发中Adapt ...

  3. C++ template —— 表达式模板(十)

    表达式模板解决的问题是:对于一个数值数组类,它需要为基于整个数组对象的数值操作提供支持,如对数组求和或放大: Array<), y(); ... x = 1.2 * x + x * y; 对效率 ...

  4. Spring中使用Velocity模板

    使用Velocity模板 Velocity是一种针对Java应用的易用的模板语言.Velocity模板中没有任何 Java代码,这使得它能够同时被非开发人员和开发人员轻松地理解.Velocity的用户 ...

  5. VS编译环境中TBB配置和C++中lambda表达式

    TBB(Thread Building Blocks),线程构建模块,是由Intel公司开发的并行编程开发工具,提供了对Windows,Linux和OSX平台的支持. TBB for Windows ...

  6. 解决Affter Effect汉化版(cc2015之后的版本)中出现表达式错误的一种常用方法

    解决Affter Effect出现表达式错误的一种常用方法 问题:汉化版的AE中,使用模板会出现表达式错误之类的提示,可能会导致某些设置或者效果失效 解决办法: 方法一.将配置文件中的zh_CN 改为 ...

  7. MVC中使用T4模板

    参考博文 http://www.cnblogs.com/heyuquan/archive/2012/07/26/2610959.html 图片释义 1.简单示例,对基本的模块标记 2.根据上图生成的类 ...

  8. VS2013中的MVC5模板部署到mono上的艰辛历程

    部署环境:CentOS7 + Mono 3.10 + Jexus 5.6 在Xamarin.Studio创建的asp.net项目,部署过程非常顺利,没有遇到什么问题:但在VS2013中创建的asp.n ...

  9. ThinkPHP Where 条件中使用表达式

    本文转自:这里 Where 条件表达式格式为: $map['字段名'] = array('表达式', '操作条件'); 其中 $map 是一个普通的数组变量,可以根据自己需求而命名.上述格式中的表达式 ...

随机推荐

  1. 【Java】静态与非静态

    文章目录 静态与非静态 static关键字 使用static修饰属性:静态变量(或类变量) 类变量与实例变量的内存解析 使用static修饰方法:静态方法 使用static的注意点 开发中,如何确定一 ...

  2. idea环境下SpringBoot Web应用引入JSP

    1. 环境 开发环境:idea2019.3 jkd版本:1.8 springboot版本:2.6.2 2. 引入JSP的步骤 2.1 新建工程,引入依赖 这里只是解析jsp,因此只需要引入spring ...

  3. JavaScript DOM 基础操作

    JavaScript DOM 基础操作 一.获取元素的六方式 document.getElementById('id名称') //根据id名称获取 document.getElementsByclas ...

  4. [Altium Designer 学习]怎样输出Gerber文件和钻孔文件

    为了资料保密和传输方便,交给PCB厂商打样的资料一般以Gerber和钻孔文件为主,换句话说,只要有前面说的两种文件,就能制作出你想要的PCB了. 一般来说,交给PCB厂商的Gerber有以下几层: G ...

  5. cesium加载gltf模型点击以及列表点击定位弹窗

    前言 cesium 官网的api文档介绍地址cesium官网api,里面详细的介绍 cesium 各个类的介绍,还有就是在线例子:cesium 官网在线例子,这个也是学习 cesium 的好素材. 之 ...

  6. Cesium入门3 - Cesium目录框架结构

    Cesium入门3 - Cesium目录框架结构 Cesium中文网:http://cesiumcn.org/ | 国内快速访问:http://cesium.coinidea.com/ app目录 下 ...

  7. java关键字final

    //继承弊端:打破了封装性 /* * final关键字: * 1,final是一个修饰符,可以修饰类,方法,变量. * 2,final修饰的类不可以被继承. * 3,final修饰的方法不可以被覆盖. ...

  8. 集合框架-工具类-JDK5.0特性-函数可变参数

    1 package cn.itcast.p4.news.demo; 2 3 public class ParamterDemo { 4 5 public static void main(String ...

  9. PyTorch 介绍 | DATSETS & DATALOADERS

    用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...

  10. 学习Java第6天

    今天所做的工作: 1.完成学生信息管理系统样卷 2.核心技术接口继承,多态 明天工作安排: 1.类的高级特性(Java类包) 2.异常处理 今天做一个小小的总结,Java程序是完全面向对象的,它的所有 ...