转载自https://www.jianshu.com/p/e5b03cf22c80

Ceres solver 是谷歌开发的一款用于非线性优化的库,在谷歌的开源激光雷达slam项目cartographer中被大量使用。

Ceres简易例程

使用Ceres求解非线性优化问题,一共分为三个部分:
1、 第一部分:构建cost fuction,即代价函数,也就是寻优的目标式。这个部分需要使用仿函数(functor)这一技巧来实现,做法是定义一个cost function的结构体,在结构体内重载()运算符,具体实现方法后续介绍。
2、 第二部分:通过代价函数构建待求解的优化问题。
3、 第三部分:配置求解器参数并求解问题,这个步骤就是设置方程怎么求解、求解过程是否输出等,然后调用一下Solve方法。

好了,此时你应该对ceres的大概使用流程有了一个基本的认识。下面我就基于ceres官网上的教程中的一个例程来详细介绍一下ceres的用法。
Ceres官网教程给出的例程中,求解的问题是求x使得1/2*(10-x)^2取到最小值。(很容易心算出x的解应该是10)
好,来看代码:

#include<iostream>
#include<ceres/ceres.h>
using namespace std;
using namespace ceres; //第一部分:构建代价函数,重载()符号,仿函数的小技巧
struct CostFunctor
{
template <typename T> bool operator()(const T* const x, T* residual) const
{
residual[] = T(10.0) - x[]; return true;
}
};
//主函数
int main(int argc, char** argv)
{
google::InitGoogleLogging(argv[]);
// 寻优参数x的初始值,为5
double initial_x = 5.0;
double x = initial_x;
// 第二部分:构建寻优问题
Problem problem;
CostFunction* cost_function = new AutoDiffCostFunction<CostFunctor, , >(new CostFunctor);
//使用自动求导,将之前的代价函数结构体传入,第一个1是输出维度,即残差的维度,第二个1是输入维度,即待寻优参数x的维度。
problem.AddResidualBlock(cost_function, NULL, &x);
//向问题中添加误差项,本问题比较简单,添加一个就行。
//第三部分: 配置并运行求解器
Solver::Options options; options.linear_solver_type = ceres::DENSE_QR;
//配置增量方程的解法
options.minimizer_progress_to_stdout = true;
//输出到cout
Solver::Summary summary;
//优化信息
Solve(options, &problem, &summary);
//求解!!!
std::cout << summary.BriefReport() << "\n";
//输出优化的简要信息
//最终结果
std::cout << "x : " << initial_x << " -> " << x << "\n";
return ;
}

第一部分:构造代价函数结构体

这里的使用了仿函数的技巧,即在CostFunction结构体内,对()进行重载,这样的话,该结构体的一个实例就能具有类似一个函数的性质,在代码编写过程中就能当做一个函数一样来使用。
关于仿函数,这里再多说几句,对结构体、类的一个实例,比如Myclass类的一个实例Obj1,如果Myclass里对()进行了重载,那Obj1被创建之后,就可以将Obj1这个实例当做函数来用,比如Obj(x)这样,为了方便读者理解,下面随便编一段简单的示例代码,凑活看看吧。

//仿函数的示例代码
#include<iostream>
using namespace std;
class Myclass
{
public:
Myclass(int x):_x(x){};
int operator()(const int n)const
{ return n*_x; }
private: int _x;
};
int main()
{
Myclass Obj1();
cout<<Obj1()<<endl;
system("pause");
return ;
}

在我随便写的示教代码中,可以看到我将Myclass的()符号的功能定义成了将括号内的数n乘以隐藏参数x倍,其中x是Obj1对象的一个私有成员变量,是是在构造Obj1时候赋予的。因为重载了()符号,所以在主函数中Obj1这个对象就可以当做一个函数来使用,使用方法为Obj1(n),如果Obj1的内部成员变量_x是5,则此函数功能就是将输入参数扩大5倍,如果这个成员变量是50,Obj1()函数的功能就是将输入n扩大50倍,这也是仿函数技巧的一个优点,它能利用对象的成员变量来储存更多的函数内部参数。

了解了仿函数技巧的使用方法后,再回过头来看看ceres使用中构造CostFuction 的具体方法:
CostFunction结构体中,对括号符号重载的函数中,传入参数有两个,一个是待优化的变量x,另一个是残差residual,也就是代价函数的输出。
重载了()符号之后,CostFunction就可以传入AutoDiffCostFunction方法来构建寻优问题了。

第二部分:通过代价函数构建待求解的优化问题

Problem problem;
CostFunction* cost_function = new AutoDiffCostFunction<CostFunctor, , >(new CostFunctor);
problem.AddResidualBlock(cost_function, NULL, &x);
这一部分就是待求解的优化问题的构建过程,使用之前结构体创建一个实例,由于使用了仿函数技巧,该实例在使用上可以当做一个函数。基于该实例new了一个CostFunction结构体,这里使用的自动求导,将之前的代价函数结构体传入,第一个1是输出维度,即残差的维度,第二个1是输入维度,即待寻优参数x的维度。分别对应之前结构体中的residual和x。
向问题中添加误差项,本问题比较简单,添加一次就行(有的问题要不断多次添加ResidualBlock以构建最小二乘求解问题)。这里的参数NULL是指不使用核函数,&x表示x是待寻优参数。

第三部分:配置问题并求解问题

Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
Solver::Summary summary;
Solve(options, &problem, &summary);
std::cout << summary.BriefReport() << "\n";
std::cout << "x : " << initial_x << " -> " << x << "\n";

这一部分很好理解,创建一个Option,配置一下求解器的配置,创建一个Summary。最后调用Solve方法,求解。
最后输出结果:

iter      cost      cost_change  |gradient|   |step|    tr_ratio  tr_radius  ls_iter  iter_time  total_time
1.250000e+01 0.00e+00 5.00e+00 0.00e+00 0.00e+00 1.00e+04 3.41e-05 1.36e-04
1.249750e-07 1.25e+01 5.00e-04 5.00e+00 1.00e+00 3.00e+04 5.89e-05 2.66e-04
1.388518e-16 1.25e-07 1.67e-08 5.00e-04 1.00e+00 9.00e+04 1.91e-05 3.01e-04
x : ->

读者们看到这里相信已经对Ceres库的使用已经有了一个大概的认识,现在可以试着将代码实际运行一下来感受一下,加深一下理解。
博主的使用环境为Ubuntu 16.04,所以在此附上CMakeList.txt

附:CMakeLists.txt代码:

cmake_minimum_required(VERSION 2.8)
project(ceres) find_package(Ceres REQUIRED)
include_directories( ${CERES_INCLUDE_DIRS} ) add_executable(use_ceres l_2.cpp)
target_link_libraries(use_ceres ${CERES_LIBRARIES})

进阶-更多的求导法

在上面的例子中,使用的是自动求导法(AutoDiffCostFunction),Ceres库中其实还有更多的求导方法可供选择(虽然自动求导的确是最省心的,而且一般情况下也是最快的。。。)。这里就简要介绍一下其他的求导方法:
数值求导法(一般比自动求导法收敛更慢,且更容易出现数值错误):
数值求导法的代价函数结构体构建和自动求导中的没有区别,只是在第二部分的构建求解问题中稍有区别,下面是官网给出的数值求导法的问题构建部分代码:

CostFunction* cost_function = new NumericDiffCostFunction<NumericDiffCostFunctor, ceres::CENTRAL, , >( new NumericDiffCostFunctor); 
problem.AddResidualBlock(cost_function, NULL, &x);
乍一看和自动求导法中的代码没区别,除了代价函数结构体的名字定义得稍有不同,使用的是NumericDiffCostFunction而非AutoDiffCostFunction,改动的地方只有在模板参数设置输入输出维度前面加了一个模板参数ceres::CENTRAL,表明使用的是数值求导法。
还有其他一些更多更复杂的求导法,不详述。

再进阶-曲线拟合

趁热打铁,阅读到这里想必读者们应该已经对Ceres库的使用已经比较了解了(如果前面认真看了的话),现在就来尝试解决一个更加复杂的问题来检验一下成果,顺便进阶一下。
问题:
拟合非线性函数的曲线(和官网上的例子不一样,稍微复杂一丢丢):
y=e{3x{2}+2x+1}
依然,先上代码:
代码之前先啰嗦几句,整个代码的思路还是先构建代价函数结构体,然后在[0,1]之间均匀生成待拟合曲线的1000个数据点,加上方差为1的白噪声,数据点用两个vector储存(x_data和y_data),然后构建待求解优化问题,最后求解,拟合曲线参数。
(PS. 本段代码中使用OpenCV的随机数产生器,要跑代码的同学可能要先装一下OpenCV)

#include <iostream>
#include <opencv2/core/core.hpp>
#include <ceres/ceres.h> using namespace std;
using namespace cv; //构建代价函数结构体,abc为待优化参数,residual为残差
struct CURVE_FITTING_COST
{
CURVE_FITTING_COST(double x, double y): _x(x), _y(y){}
template<typename T>
bool operator()(const T* const abc, T* residual) const
{
residual[] = _y - ceres::exp(abc[] * _x * _x + abc[] * _x + abc[]);
return true;
}
const double _x, _y;
}; int main()
{
//参数初始化设置,abc初始化为0,白噪声方差为1(使用opencv的随机数产生器)
double a = , b = , c = ;
double w = ;
RNG rng;
double abc[] = {, , };
//生成待拟合曲线的数据散点,存储在Vector里,x_data, y_data
vector<double> x_data, y_data;
for(int i = ; i < ; i++)
{
double x = i / 1000.0;
x_data.push_back(x);
y_data.push_back(exp(a * x * x + b * x + c) + rng.gaussian(w));
} //反复使用AddResidualBlock方法(逐个散点,反复1000次)
//将每个点的残差累积求和构建最小二乘优化式
//不使用核函数,待优化参数是abc
ceres::Problem problem;
for(int i = ; i < ; i++)
{
problem.AddResidualBlock(
new ceres::AutoDiffCostFunction<CURVE_FITTING_COST, , >(
new CURVE_FITTING_COST(x_data[i], y_data[i])),
nullptr,
abc
);
}
//配置求解器并求解,输出结果
ceres::Solver::Options options;
options.linear_solver_type = ceres::DENSE_QR;
options.minimizer_progress_to_stdout = true;
ceres::Solver::Summary summary;
ceres::Solve( options, &problem, &summary); cout << "a = " << abc[] << endl;
cout << "b = " << abc[] << endl;
cout << "c = " << abc[] << endl;
return ; }

对应的CMakeLists.txt

cmake_minimum_required(VERSION 2.8)
project(ceres) find_package(Ceres REQUIRED)
include_directories( ${CERES_INCLUDE_DIRS} )
find_package(OpenCV REQUIRED)
include_directories( ${OpenCV_INCLUDE_DIRS} )
add_executable(curve curve.cpp)
target_link_libraries(curve ${CERES_LIBRARIES} ${OpenCV_LIBS})

代码解读:
代码的整体流程还是之前的流程,四个部分大致相同。比之前稍微复杂一点的地方就在于计算单个点的残差时需要输入该点的x,y坐标,而且需要反复多次累计单点的残差以构建总体的优化目标。
先看代价函数结构体的构建:

struct CURVE_FITTING_COST {
CURVE_FITTING_COST(double x,double y):_x(x),_y(y){}
template <typename T> bool operator()(const T* const abc,T* residual)const
{
residual[]=_y-ceres::exp(abc[]*_x*_x+abc[]*_x+abc[]);
return true;
}
const double _x,_y;
};

这里依然使用仿函数技巧,与之前不同的是结构体内部有_x,_y成员变量,用于储存散点的坐标。
再看优化问题的构建:

 ceres::Problem problem;
for(int i=;i<;i++) {
//自动求导法,输出维度1,输入维度3,
problem.AddResidualBlock( new ceres::AutoDiffCostFunction<CURVE_FITTING_COST,,>( new CURVE_FITTING_COST(x_data[i],y_data[i]) ),
nullptr,
abc );
}
这里由于有1000个点,所以需要对每个点计算一次残差,将所有残差累积在一起构成问题的总体优化目标,所以for循环1000次。
这里与前例不同的是需要输入散点的坐标x,y,由于_x,_y是结构体成员变量,所以可以通过构造函数直接对这两个值赋值。本代码里也是这么用的。
最终的运行结果是:
iter      cost      cost_change  |gradient|   |step|    tr_ratio  tr_radius  ls_iter  iter_time  total_time
5.277388e+06 0.00e+00 5.58e+04 0.00e+00 0.00e+00 1.00e+04 1.68e-02 1.71e-02
4.287886e+238 -4.29e+238 0.00e+00 7.39e+02 -8.79e+231 5.00e+03 4.92e-04 1.77e-02
1.094203e+238 -1.09e+238 0.00e+00 7.32e+02 -2.24e+231 1.25e+03 3.86e-04 1.82e-02
5.129910e+234 -5.13e+234 0.00e+00 6.96e+02 -1.05e+228 1.56e+02 3.69e-04 1.86e-02
1.420558e+215 -1.42e+215 0.00e+00 4.91e+02 -2.97e+208 9.77e+00 3.58e-04 1.89e-02
9.607928e+166 -9.61e+166 0.00e+00 1.85e+02 -2.23e+160 3.05e-01 3.59e-04 1.93e-02
7.192680e+60 -7.19e+60 0.00e+00 4.59e+01 -2.94e+54 4.77e-03 3.70e-04 1.97e-02
5.061060e+06 2.16e+05 2.68e+05 1.21e+00 2.52e+00 1.43e-02 1.82e-02 3.79e-02
4.342234e+06 7.19e+05 9.34e+05 8.84e-01 2.08e+00 4.29e-02 1.58e-02 5.38e-02
2.876001e+06 1.47e+06 2.06e+06 6.42e-01 1.66e+00 1.29e-01 1.12e-02 6.50e-02
1.018645e+06 1.86e+06 2.58e+06 4.76e-01 1.38e+00 3.86e-01 1.95e-02 8.46e-02
1.357731e+05 8.83e+05 1.30e+06 2.56e-01 1.13e+00 1.16e+00 9.16e-03 9.38e-02
2.142986e+04 1.14e+05 2.71e+05 8.60e-02 1.03e+00 3.48e+00 7.91e-03 1.02e-01
1.636436e+04 5.07e+03 5.94e+04 3.01e-02 1.01e+00 1.04e+01 7.07e-03 1.09e-01
1.270381e+04 3.66e+03 3.96e+04 6.21e-02 9.96e-01 3.13e+01 5.77e-03 1.15e-01
6.723500e+03 5.98e+03 2.68e+04 1.30e-01 9.89e-01 9.39e+01 5.15e-03 1.20e-01
1.900795e+03 4.82e+03 1.24e+04 1.76e-01 9.90e-01 2.82e+02 5.10e-03 1.25e-01
5.933860e+02 1.31e+03 3.45e+03 1.23e-01 9.96e-01 8.45e+02 5.34e-03 1.30e-01
5.089437e+02 8.44e+01 3.46e+02 3.77e-02 1.00e+00 2.53e+03 5.44e-03 1.36e-01
5.071157e+02 1.83e+00 4.47e+01 1.63e-02 1.00e+00 7.60e+03 5.43e-03 1.41e-01
5.056467e+02 1.47e+00 3.03e+01 3.13e-02 1.00e+00 2.28e+04 5.37e-03 1.47e-01
5.046313e+02 1.02e+00 1.23e+01 3.82e-02 1.00e+00 6.84e+04 5.48e-03 1.52e-01
5.044403e+02 1.91e-01 2.23e+00 2.11e-02 9.99e-01 2.05e+05 5.60e-03 1.58e-01
5.044338e+02 6.48e-03 1.38e-01 4.35e-03 9.98e-01 6.16e+05 6.58e-03 1.65e-01
a = 3.01325
b = 1.97599
c = 1.01113

可以看到,最终的拟合结果与真实值非常接近。

再再进阶-鲁棒曲线拟合

求解优化问题中(比如拟合曲线),数据中往往会有离群点、错误值什么的,最终得到的寻优结果很容易受到影响,此时就可以使用一些损失核函数来对离群点的影响加以消除。要使用核函数,只需要把上述代码中的NULL或nullptr换成损失核函数结构体的实例。
Ceres库中提供的核函数主要有:TrivialLoss 、HuberLoss、 SoftLOneLoss 、 CauchyLoss。
比如此时要使用CauchyLoss,只需要将nullptr换成new CauchyLoss(0.5)就行(0.5为参数)。
下面两图别为Ceres官网上的例程的结果,可以明显看出使用损失核函数之后的曲线收离群点的影响更小。

不使用鲁棒核函数的拟合
使用鲁棒核函数的拟合

参考资料:
[1]. http://www.ceres-solver.org/
[2].《视觉slam十四讲》 高翔
[3]. http://www.cnblogs.com/xiaowangba/p/6313933.html

 
 
 
 
 
 
 
 
 
 
 
 

ceres入门学习的更多相关文章

  1. vue入门学习(基础篇)

    vue入门学习总结: vue的一个组件包括三部分:template.style.script. vue的数据在data中定义使用. 数据渲染指令:v-text.v-html.{{}}. 隐藏未编译的标 ...

  2. Hadoop入门学习笔记---part4

    紧接着<Hadoop入门学习笔记---part3>中的继续了解如何用java在程序中操作HDFS. 众所周知,对文件的操作无非是创建,查看,下载,删除.下面我们就开始应用java程序进行操 ...

  3. Hadoop入门学习笔记---part3

    2015年元旦,好好学习,天天向上.良好的开端是成功的一半,任何学习都不能中断,只有坚持才会出结果.继续学习Hadoop.冰冻三尺,非一日之寒! 经过Hadoop的伪分布集群环境的搭建,基本对Hado ...

  4. PyQt4入门学习笔记(三)

    # PyQt4入门学习笔记(三) PyQt4内的布局 布局方式是我们控制我们的GUI页面内各个控件的排放位置的.我们可以通过两种基本方式来控制: 1.绝对位置 2.layout类 绝对位置 这种方式要 ...

  5. PyQt4入门学习笔记(一)

    PyQt4入门学习笔记(一) 一直没有找到什么好的pyqt4的教程,偶然在google上搜到一篇不错的入门文档,翻译过来,留以后再复习. 原始链接如下: http://zetcode.com/gui/ ...

  6. Hadoop入门学习笔记---part2

    在<Hadoop入门学习笔记---part1>中感觉自己虽然总结的比较详细,但是始终感觉有点凌乱.不够系统化,不够简洁.经过自己的推敲和总结,现在在此处概括性的总结一下,认为在准备搭建ha ...

  7. Retrofit 入门学习

    Retrofit 入门学习官方RetrofitAPI 官方的一个例子 public interface GitHubService { @GET("users/{user}/repos&qu ...

  8. MyBatis入门学习教程-使用MyBatis对表执行CRUD操作

    上一篇MyBatis学习总结(一)--MyBatis快速入门中我们讲了如何使用Mybatis查询users表中的数据,算是对MyBatis有一个初步的入门了,今天讲解一下如何使用MyBatis对use ...

  9. opengl入门学习

    OpenGL入门学习 说起编程作图,大概还有很多人想起TC的#include <graphics.h>吧? 但是各位是否想过,那些画面绚丽的PC游戏是如何编写出来的?就靠TC那可怜的640 ...

随机推荐

  1. python远程操作服务器

    python远程控制 标签(空格分隔): 远程Linux python远程控制:方案: Paramiko Pexpect(主要Linux机器) 安装Paramiko pip install param ...

  2. angular记录

    1. <h1>{{title}}</h1> 双花括号语法是 Angular 的插值绑定语法. 这个插值绑定的意思是把组件的 title 属性的值绑定到 HTML 中的 h1 标 ...

  3. BOM 对象--location、navigator、screen、history

    1.location 对象 location提供了与当前窗口中加载的文档有关的信息,还有一些导航功能.需要注意的是,window.location 和 document.location 引用的是同一 ...

  4. kali安装显卡驱动

    由于我们使用cpu一般最多也就是4到16核,而一块不错的gpu可以多大上千核,在并行复杂运算能力上GPU的运算速度远远超过CPU的运算速度,所以很多场合比如暴力穷举破解,挖矿更多地使用GPU,所以有必 ...

  5. django models返回数据根据某字段倒序排列

    例如有一个models表叫做report,report表中有一个endtime,想将结果按照endtime倒序排列   正序排列的方法:[models对象.objects.order_by(“字段名& ...

  6. STL::forward_list

    forward_list(c++11): 内部是一个单链表的实现:但是为了效率的考虑,故意没有 size 这个内置函数. Constructor 六种构造方式default; fill; range; ...

  7. Gym - 101911C Bacteria (规律题)

    传送门:点我 Time limit2000 ms Memory limit262144 kB Recently Monocarp has created his own mini-laboratory ...

  8. vcenter或workstation12导入ovf出错:硬件系列vmx 14不受支持

    原因是因为导出ovf的虚拟机版本太高. 两个方法,一个强制,一个推荐. 强制 1. 打开ovf后缀文件,把<vssd:VirtualSystemType>vmx-14</vssd:V ...

  9. c++实现中的一些注意 事项

    1,尽可能延后对象中的变量定义式的出现,这样可以增加程序的清晰度,尽量少的调用构造,如果有定义变量最好在末尾定义并给予初值,这样就避免了默认构造函数的调用. 2 尽量少做转型操作. const_cas ...

  10. sqlserver2017 +SSMS+ VS2017+SSDT 安装要点及相关组件下载地址

    1.sqlserver2017安装PolyBase需要安装jdk7 ,注意必须是7  jdk10是不行的. 下载地址:http://dl-t1.wmzhe.com/30/30117/jdk_7u_1. ...