Caffe源代码中Solver文件分析
Caffe源代码(caffe version commit: 09868ac , date: 2015.08.15)中有一些重要的头文件,这里介绍下include/caffe/solver.hpp文件的内容:
1. include文件:
<caffe/solver.hpp>:此文件的介绍能够參考: http://blog.csdn.net/fengbingchun/article/details/62423060
2. 模板类Solver:虚基类
3. 模板类WorkerSolver:继承父类Solver,用于多GPU训练时仅计算梯度
4. 模板类SGDSolver:继承父类Solver
5. 模板类NesterovSolver:继承SGDSolver
6. 模板类AdaGradSolver:继承SGDSolver
7. 模板类RMSPropSolver:继承SGDSolver
8. 模板类AdaDeltaSolver:继承SGDSolver
9. 模板类AdamSolver:继承SGDSolver
10. 函数GetSolver:new solver对象
Solver通过协调Net的前向判断计算和反向梯度计算(forward inference and backward gradients),来对參数进行更新。从而达到降低loss的目的。Caffe模型的学习被分为两个部分:由Solver进行优化、更新參数。由Net计算出loss和gradient。
solver.prototxt是一个配置文件用来告知Caffe如何对网络进行训练。
有了Net就能够进行神经网络的前后向传播计算了。可是还缺少神经网络的训练和预測功能,Solver类进一步封装了训练和预測相关的一些功能。Solver定义了针对Net网络模型的求解方法,记录神经网络的训练过程,保存神经网络模型參数,中断并恢复网络的训练过程。自己定义Solver能够实现不同的神经网络求解方式。
Caffe支持的solvers包含:
(1)、Stochastic Gradient Descent(type: “SGD”)即随机梯度下降:利用负梯度和上一次权重的更新值的线性组合来更新权重。学习率(learning rate)是负梯度的权重。
动量是上一次更新值的权重。一般将学习速率初始化为0.01。然后在训练(training)中当loss达到稳定时,将学习速率除以一个常数(比如10),将这个过程重复多次。
对于动量一般设置为0.9,动量使weight得更新更为平缓,使学习过程更为稳定、高速。
(2)、AdaDelta(type:“AdaDelta”):是一种”鲁棒的学习率方法”,同SGD一样是一种基于梯度的优化方法。
(3)、Adaptive Gradient(type: “AdaGrad”)即自适应梯度下降,与随机梯度下降一样是基于梯度的优化方法。
(4)、Adam(type:“Adam”):也是一种基于梯度的优化方法。
它包含一对自适应时刻预计变量,能够看做是AdaGrad的一种泛化形式。
(5)、Nesterov’s Accelerated Gradient(type: “Nesterov”):Nesterov提出的加速梯度下降(Nesterov’s accelerated gradient)是凸优化的一种最优算法,其收敛速度能够达到O(1/t^2),而不是O(1/t)。
虽然在使用Caffe训练深度神经网络时非常难满足O(1/t^2)收敛条件。但实际中NAG对于某些特定结构的深度学习模型仍是一个非常有效的方法。
(6)、RMSprop(type:“RMSProp”):是一种基于梯度的优化方法(同SGD相似)。
Solver:
(1)、用于优化过程的记录、创建训练网络(用于学习)和測试网络(用于评估);
(2)、通过forward和backward过程来迭代地优化和更新參数;
(3)、周期性地用測试网络评估模型性能;
(4)、在优化过程中记录模型和solver状态的快照(snapshot)。
每一次迭代过程中:
(1)、调用Net的前向过程计算出输出和loss。
(2)、调用Net的反向过程计算出梯度(loss对每层的权重w和偏置b求导)。
(3)、依据以下所讲的Solver方法。利用梯度更新參数;
(4)、依据学习率(learning rate)。历史数据和求解方法更新solver的状态。使权重从初始化状态逐步更新到终于的学习到的状态。
Solvers的运行模式有CPU/GPU两种模式。
Solver方法:用于最小化损失(loss)值。
给定一个数据集D,优化的目标是D中全部数据损失的均值,即平均损失。取得最小值。
注:以上关于Solver内容的介绍主要摘自由CaffeCN社区翻译的《Caffe官方教程中译本》。
<caffe/solver.hpp>文件的具体介绍例如以下:
#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_
#define CAFFE_OPTIMIZATION_SOLVER_HPP_ #include <string>
#include <vector> #include "caffe/net.hpp" namespace caffe { /**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver { // Solver模板类,虚基类
public:
// 显示构造函数, 内部会调用Init函数
explicit Solver(const SolverParameter& param, const Solver* root_solver = NULL);
explicit Solver(const string& param_file, const Solver* root_solver = NULL);
// 成员变量赋值,包含param_、iter_、current_step_,并调用InitTrainNet和InitTestNets函数
void Init(const SolverParameter& param);
// 为成员变量net_赋值
void InitTrainNet();
// 为成员变量test_nets_赋值
void InitTestNets();
// The main entry of the solver function. In default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
// 依次调用函数Restore、Step、Snapshot,然后运行net_的前向传播函数ForwardPrefilled,最后调用TestAll函数
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
// 重复运行net前向传播反向传播计算,期间会调用函数TestAll、ApplyUpdate、Snapshot及类Callback两个成员函数
void Step(int iters);
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
// 载入已有的模型
void Restore(const char* resume_file);
// 虚析构函数
virtual ~Solver() {}
// 获得slover parameter
inline const SolverParameter& param() const { return param_; }
// 获得train Net
inline shared_ptr<Net<Dtype> > net() { return net_; }
// 获得test Net
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
// 获得当前的迭代数
int iter() { return iter_; }
// Invoked at specific points during an iteration
// 内部Callback类,仅在多卡GPU模式下使用
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0; template <typename T>
friend class Solver;
};
// 获得Callback
const vector<Callback*>& callbacks() const { return callbacks_; }
// 加入一个Callback
void add_callback(Callback* value) { callbacks_.push_back(value); } protected:
// Make and apply the update value for the current iteration.
// 更新net的权值和偏置
virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
// 快照,内部会调用SnapshotToBinaryProto或SnapshotToHDF5、SnapshotSolverState函数
void Snapshot();
// 获取快照文件名称
string SnapshotFilename(const string extension);
// 写proto到.caffemodel
string SnapshotToBinaryProto();
// 写proto到HDF5文件
string SnapshotToHDF5();
// The test routine
// 内部会循环调用Test函数
void TestAll();
// 运行測试网络。net前向传播
void Test(const int test_net_id = 0);
// 存储snapshot solver state
virtual void SnapshotSolverState(const string& model_filename) = 0;
// 读HDF5文件到solver state
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
// 读二进制文件.solverstate到solver state
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
// dummy function,仅仅有声明没有实现
void DisplayOutputBlobs(const int net_id); // Caffe中类的成员变量名都带有后缀"_"。这样就easy区分暂时变量和类成员变量
SolverParameter param_; // solver parameter
int iter_; // 当前的迭代数
int current_step_; //
shared_ptr<Net<Dtype> > net_; // train net
vector<shared_ptr<Net<Dtype> > > test_nets_; // test net
vector<Callback*> callbacks_; // Callback // The root solver that holds root nets (actually containing shared layers)
// in data parallelism
const Solver* const root_solver_; // 禁止使用Solver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(Solver);
}; /**
* @brief Solver that only computes gradients, used as worker
* for multi-GPU training.
*/
template <typename Dtype>
class WorkerSolver : public Solver<Dtype> { // 模板类WorkerSolver。继承父类Solver
public:
// 显示构造函数
explicit WorkerSolver(const SolverParameter& param, const Solver<Dtype>* root_solver = NULL)
: Solver<Dtype>(param, root_solver) {} protected:
void ApplyUpdate() {}
void SnapshotSolverState(const string& model_filename) {
LOG(FATAL) << "Should not be called on worker solver.";
}
void RestoreSolverStateFromBinaryProto(const string& state_file) {
LOG(FATAL) << "Should not be called on worker solver.";
}
void RestoreSolverStateFromHDF5(const string& state_file) {
LOG(FATAL) << "Should not be called on worker solver.";
}
}; /**
* @brief Optimizes the parameters of a Net using
* stochastic gradient descent (SGD) with momentum.
*/
template <typename Dtype>
class SGDSolver : public Solver<Dtype> { // 模板类SGDSolver,继承父类Solver
public:
// 显示构造函数,调用PreSolve函数
explicit SGDSolver(const SolverParameter& param) : Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file) : Solver<Dtype>(param_file) { PreSolve(); }
// 获取history数据
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; } protected:
// 成员变量history_, update_, temp_初始化
void PreSolve();
// 获取学习率
Dtype GetLearningRate();
// 内部会调用ClipGradients、Normalize、Regularize、ComputeUpdateValue,更新net权值和偏置
virtual void ApplyUpdate();
// 调用caffe_scal函数
virtual void Normalize(int param_id);
// 调用caffe_axpy函数
virtual void Regularize(int param_id);
// 计算并更新对应Blob值,调用caffe_cpu_axpby和caffe_copy函数
virtual void ComputeUpdateValue(int param_id, Dtype rate);
// clip parameter gradients to that L2 norm,假设梯度值过大,就会对梯度做一个修剪。
// 对全部的參数乘以一个缩放因子,使得全部參数的平方和不超过參数中设定的梯度总值
virtual void ClipGradients();
// 存储snapshot solver state,内部会掉用SnapshotSolverStateToBinaryProto或SnapshotSolverStateToHDF5函数
virtual void SnapshotSolverState(const string& model_filename);
// 写solver state到二进制文件.solverstate
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
// 写solver state到HDF5
virtual void SnapshotSolverStateToHDF5(const string& model_filename);
// 读HDF5文件到solver state
virtual void RestoreSolverStateFromHDF5(const string& state_file);
// 读二进制文件.solverstate到solver state
virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
// history maintains the historical momentum data.
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
// of gradients/updates and is not needed in snapshots
// Caffe中类的成员变量名都带有后缀"_",这样就easy区分暂时变量和类成员变量
vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_; // 禁止使用SGDSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(SGDSolver);
}; template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> { // 模板类NesterovSolver,继承SGDSolver
public:
// 显示构造函数
explicit NesterovSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) {}
explicit NesterovSolver(const string& param_file) : SGDSolver<Dtype>(param_file) {} protected:
// 计算并更新对应Blob值,调用caffe_cpu_axpby和caffe_copy函数
virtual void ComputeUpdateValue(int param_id, Dtype rate); // 禁止使用NesterovSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(NesterovSolver);
}; template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> { // 模板类AdaGradSolver,继承SGDSolver
public:
// 显示构造函数,调用constuctor_sanity_check函数
explicit AdaGradSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit AdaGradSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } protected:
// 计算并更新对应Blob值
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
} // 禁止使用AdaGradSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
}; template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> { // 模板类RMSPropSolver,继承SGDSolver
public:
// 显示构造函数。调用constructor_sanity_check函数
explicit RMSPropSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } protected:
// 计算并更新对应Blob值
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
} // 禁止使用RMSPropSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
}; template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> { // 模板类AdaDeltaSolver。继承SGDSolver
public:
// 显示构造函数,调用AdaDeltaPreSolve函数
explicit AdaDeltaSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); } protected:
void AdaDeltaPreSolve();
// 计算并更新对应Blob值
virtual void ComputeUpdateValue(int param_id, Dtype rate); // 禁止使用AdaDeltaSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
}; /**
* @brief AdamSolver, an algorithm for first-order gradient-based optimization
* of stochastic objective functions, based on adaptive estimates of
* lower-order moments. Described in [1].
*
* [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
* arXiv preprint arXiv:1412.6980v8 (2014).
*/
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> { // 模板类AdamSolver。继承SGDSolver
public:
// 显示构造函数,调用AdamPreSolve函数
explicit AdamSolver(const SolverParameter& param) : SGDSolver<Dtype>(param) { AdamPreSolve();}
explicit AdamSolver(const string& param_file) : SGDSolver<Dtype>(param_file) { AdamPreSolve(); } protected:
void AdamPreSolve();
// 计算并更新对应Blob值
virtual void ComputeUpdateValue(int param_id, Dtype rate); // 禁止使用AdamSolver类的拷贝和赋值操作
DISABLE_COPY_AND_ASSIGN(AdamSolver);
}; // new一个指定的solver方法对象
template <typename Dtype>
Solver<Dtype>* GetSolver(const SolverParameter& param) {
SolverParameter_SolverType type = param.solver_type(); switch (type) {
case SolverParameter_SolverType_SGD:
return new SGDSolver<Dtype>(param);
case SolverParameter_SolverType_NESTEROV:
return new NesterovSolver<Dtype>(param);
case SolverParameter_SolverType_ADAGRAD:
return new AdaGradSolver<Dtype>(param);
case SolverParameter_SolverType_RMSPROP:
return new RMSPropSolver<Dtype>(param);
case SolverParameter_SolverType_ADADELTA:
return new AdaDeltaSolver<Dtype>(param);
case SolverParameter_SolverType_ADAM:
return new AdamSolver<Dtype>(param);
default:
LOG(FATAL) << "Unknown SolverType: " << type;
}
return (Solver<Dtype>*) NULL;
} } // namespace caffe #endif // CAFFE_OPTIMIZATION_SOLVER_HPP_
在caffe.proto文件里。主要有一个message是与solver相关的,例如以下:
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 40 (last added: momentum2)
message SolverParameter { // Solver參数
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
//
// Exactly one train net must be specified using one of the following fields:
// train_net_param, train_net, net_param, net
// One or more test nets may be specified using any of the following fields:
// test_net_param, test_net, net_param, net
// If more than one test net field is specified (e.g., both net and
// test_net are specified), they will be evaluated in the field order given
// above: (1) test_net_param, (2) test_net, (3) net_param/net.
// A test_iter must be specified for each test_net.
// A test_level and/or a test_stage may also be specified for each test_net.
////////////////////////////////////////////////////////////////////////////// // Proto filename for the train net, possibly combined with one or more test nets.
optional string net = 24; // .prototxt文件名称, train or test net
// Inline train net param, possibly combined with one or more test nets.
optional NetParameter net_param = 25; // net parameter类 optional string train_net = 1; // Proto filename for the train net, .prototxt文件名称,train net
repeated string test_net = 2; // Proto filenames for the test nets, .prototxt文件名称,test net
optional NetParameter train_net_param = 21; // Inline train net params, train net parameter类
repeated NetParameter test_net_param = 22; // Inline test net params, test net parameter类 // The states for the train/test nets. Must be unspecified or
// specified once per net.
//
// By default, all states will have solver = true;
// train_state will have phase = TRAIN,
// and all test_state's will have phase = TEST.
// Other defaults are set according to the NetState defaults.
optional NetState train_state = 26; // train net state
repeated NetState test_state = 27; // test net state // The number of iterations for each test net.
repeated int32 test_iter = 3; // 对于測试网络(用于评估)运行一次须要迭代的次数, test_iter * batch_size = 測试图像总数量 // The number of iterations between two testing phases.
optional int32 test_interval = 4 [default = 0]; // 指定运行多少次训练网络运行一次測试网络
optional bool test_compute_loss = 19 [default = false]; // 运行測试网络时是否计算loss
// If true, run an initial test pass before the first iteration,
// ensuring memory availability and printing the starting value of the loss.
optional bool test_initialization = 32 [default = true]; // 在总的開始前,是否先运行一次測试网络
optional float base_lr = 5; // The base learning rate,基础学习率
// the number of iterations between displaying info. If display = 0, no info
// will be displayed.
optional int32 display = 6; // 指定迭代多少次显示一次结果信息
// Display the loss averaged over the last average_loss iterations
optional int32 average_loss = 33 [default = 1]; //
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1]; // // The learning rate decay policy. The currently implemented learning rate
// policies are as follows: // 学习率衰减策略
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmod decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
optional string lr_policy = 8; // 学习策略,可取的值包含:fixed、step、exp、inv、multistep、poly、sigmoid
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value, 动量
optional float weight_decay = 12; // The weight decay. //
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [default = "L2"]; // L1 or L2
// the stepsize for learning rate policy "step"
optional int32 stepsize = 13; //
// the stepsize for learning rate policy "multistep"
repeated int32 stepvalue = 34; // // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm,
// whenever their actual L2 norm is larger.
optional float clip_gradients = 35 [default = -1]; // optional int32 snapshot = 14 [default = 0]; // The snapshot interval, 迭代多少次保存下结果(如权值、偏置)
optional string snapshot_prefix = 15; // The prefix for the snapshot,指定保存文件名称的前缀
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false]; //
enum SnapshotFormat {
HDF5 = 0;
BINARYPROTO = 1;
}
optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; // HDF5 or BINARYPROTO
// the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default.
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [default = GPU]; // 指定solve mode是CPU还是GPU
// the device_id will that be used in GPU mode. Use device_id = 0 in default.
optional int32 device_id = 18 [default = 0]; // GPU mode下使用
// If non-negative, the seed with which the Solver will initialize the Caffe
// random number generator -- useful for reproducible results. Otherwise,
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1]; // // Solver type
enum SolverType { // solver优化方法
SGD = 0;
NESTEROV = 1;
ADAGRAD = 2;
RMSPROP = 3;
ADADELTA = 4;
ADAM = 5;
}
optional SolverType solver_type = 30 [default = SGD]; // 指定solver优化方法
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
optional float delta = 31 [default = 1e-8]; //
// parameters for the Adam solver
optional float momentum2 = 39 [default = 0.999]; // // RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38; // // If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false]; // // If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true]; //
}
solver的測试代码例如以下:
#include "funset.hpp"
#include <string>
#include <vector>
#include <map>
#include "common.hpp" int test_caffe_solver()
{
caffe::Caffe::set_mode(caffe::Caffe::CPU); // set run caffe mode const std::string solver_prototxt{ "E:/GitCode/Caffe_Test/test_data/model/mnist/lenet_solver.prototxt" }; caffe::SolverParameter solver_param;
if (!caffe::ReadProtoFromTextFile(solver_prototxt.c_str(), &solver_param)) {
fprintf(stderr, "parse solver.prototxt fail\n");
return -1;
} boost::shared_ptr<caffe::Solver<float> > solver(caffe::GetSolver<float>(solver_param)); caffe::SolverParameter param = solver->param(); if (param.has_net())
fprintf(stderr, "net: %s\n", param.net().c_str());
if (param.has_net_param()) {
fprintf(stderr, "has net param\n");
caffe::NetParameter net_param = param.net_param();
if (net_param.has_name())
fprintf(stderr, "net param name: %s\n", net_param.name().c_str());
}
if (param.has_train_net())
fprintf(stderr, "train_net: %s\n", param.train_net());
if (param.test_net().size() > 0) {
for (auto test_net : param.test_net())
fprintf(stderr, "test_net: %s\n", test_net);
}
if (param.has_train_net_param()) {
fprintf(stderr, "has train net param\n");
caffe::NetParameter train_net_param = param.train_net_param();
}
if (param.test_net_param().size() > 0) {
fprintf(stderr, "has test net param\n");
std::vector<caffe::NetParameter> test_net_param;
for (auto net_param : param.test_net_param())
test_net_param.push_back(net_param);
} if (param.has_train_state()) {
fprintf(stderr, "has train state\n");
caffe::NetState state = param.train_state();
}
if (param.test_state().size()) {
fprintf(stderr, "has test state\n");
} if (param.test_iter_size() > 0) {
fprintf(stderr, "has test iter\n");
for (auto iter : param.test_iter())
fprintf(stderr, " %d ", iter);
fprintf(stderr, "\n");
} if (param.has_test_interval())
fprintf(stderr, "test interval: %d\n", param.test_interval());
bool test_compute_loss = param.test_compute_loss();
fprintf(stderr, "test compute loss: %d\n", test_compute_loss);
bool test_initialization = param.test_initialization();
fprintf(stderr, "test initializtion: %d\n", test_initialization);
if (param.has_base_lr()) {
float base_lr = param.base_lr();
fprintf(stderr, "base lr: %f\n", base_lr);
}
if (param.has_display()) {
int display = param.display();
fprintf(stderr, "display: %d\n", display);
}
int average_loss = param.average_loss();
fprintf(stderr, "average loss: %d\n", average_loss);
if (param.has_max_iter()) {
int max_iter = param.max_iter();
fprintf(stderr, "max iter: %d\n", max_iter);
}
int iter_size = param.iter_size();
fprintf(stderr, "iter size: %d\n", iter_size); if (param.has_lr_policy())
fprintf(stderr, "lr policy: %s\n", param.lr_policy().c_str());
if (param.has_gamma())
fprintf(stderr, "gamma: %f\n", param.gamma());
if (param.has_power())
fprintf(stderr, "power: %f\n", param.power());
if (param.has_momentum())
fprintf(stderr, "momentum: %f\n", param.momentum());
if (param.has_weight_decay())
fprintf(stderr, "weight decay: %f\n", param.weight_decay());
std::string regularization_type = param.regularization_type();
fprintf(stderr, "regularization type: %s\n", param.regularization_type().c_str());
if (param.has_stepsize())
fprintf(stderr, "stepsize: %d\n", param.stepsize());
if (param.stepvalue_size() > 0) {
fprintf(stderr, "has stepvalue\n");
for (auto value : param.stepvalue())
fprintf(stderr, " %d ", value);
fprintf(stderr, "\n");
} fprintf(stderr, "clip gradients: %f\n", param.clip_gradients()); fprintf(stderr, "snapshot: %d\n", param.snapshot());
if (param.has_snapshot_prefix())
fprintf(stderr, "snapshot prefix: %s\n", param.snapshot_prefix().c_str());
fprintf(stderr, "snapshot diff: %d\n", param.snapshot_diff());
caffe::SolverParameter_SnapshotFormat snapshot_format = param.snapshot_format();
fprintf(stderr, "snapshot format: %s\n", snapshot_format == 0 ? "HDF5" : "BINARYPROTO");
caffe::SolverParameter_SolverMode solver_mode = param.solver_mode();
fprintf(stderr, "solver mode: %s\n", solver_mode == 0 ? "CPU" : "GPU");
if (param.has_device_id())
fprintf(stderr, "device id: %d\n", param.device_id());
fprintf(stderr, "random seed: %d\n", param.random_seed()); caffe::SolverParameter_SolverType solver_type = param.solver_type();
std::string solver_method[] {"SGD", "NESTEROV", "ADAGRAD", "RMSPROP", "ADADELTA", "ADAM"};
fprintf(stderr, "solver type: %s\n", solver_method[solver_type].c_str());
fprintf(stderr, "delta: %f\n", param.delta());
fprintf(stderr, "momentum2: %f\n", param.momentum2()); if (param.has_rms_decay())
fprintf(stderr, "rms decy: %f\n", param.rms_decay()); fprintf(stderr, "debug info: %d\n", param.debug_info());
fprintf(stderr, "snapshot after train: %d\n", param.snapshot_after_train()); boost::shared_ptr<caffe::Net<float>> net = solver->net();
std::vector<boost::shared_ptr<caffe::Net<float>>> test_nets = solver->test_nets();
fprintf(stderr, "test nets size: %d\n", test_nets.size());
fprintf(stderr, "iter: %d\n", solver->iter()); return 0;
}
部分输出结果例如以下:
Caffe源代码中Solver文件分析的更多相关文章
- Omapl138中AIS文件分析(参照Using the OMAP-L138 Bootloader)(转)
Omapl138中AIS文件分析(参照Using the OMAP-L138 Bootloader) 转载链接:https://blog.csdn.net/qq_40788950/article/de ...
- Caffe源码中common文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中的一些重要头文件如caffe.hpp.blob.hpp等或者外部调用Caffe库使用时,一般都会in ...
- Caffe源码中math_functions文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下math_functions文件. 1. include文件: ...
- JVM中 Class 文件分析
Java 虚拟机中定义的 Class 文件格式.每一个 Class 文件都对应着唯一一个类 或接口的定义信息,但是相对地,类或接口并不一定都得定义在文件里(譬如类或接口也可以通过 类加载器直接生成). ...
- NS2中trace文件分析
ns中模拟出来的时间最终会以trace文件的形式告诉我们,虽然说一般都是用awk等工具分析trace文件,但是了解trace文件的格式也是必不可少的.下面就介绍一下无线网络模拟中trace文件的格式. ...
- Caffe源码中syncedmem文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下syncedmem文件. 1. include文件: (1).& ...
- Maven项目中pom文件分析
pom英文全称: project object model 1.概述 pom.xml文件描述了maven项目的基本信息,比如groupId,artifactId,version等.也可以对maven项 ...
- Caffe源码中caffe.proto文件分析
Caffe源码(caffe version:09868ac , date: 2015.08.15)中有一些重要文件,这里介绍下caffe.proto文件. 在src/caffe/proto目录下有一个 ...
- 【caffe】三种文件类别:solver,model和weights
@tags: caffe 文件类别 solver文件 是一堆超参数,比如迭代次数,是否用GPU,多少次迭代暂存一次训练所得参数,动量项,权重衰减(即正则化参数),基本的learning rate,多少 ...
随机推荐
- intellij idea 部署项目的时候 图中application context 写不写有什么关系?有什么作用?
这个就是你部署之后访问的路径,比如你写一个/test,那反问就是127.0.0.1:8080/test,没有写的话就是127.0.0.1:8080
- 平时常用的Visual Studio操作技巧,持续更新中……
移除未使用的命名空间--方法1:右键--"组织using"--"移除未使用的using"--方法2:Shift+F10--"O"-" ...
- C# 中的动态创建技术
[转载]原文出处 http://blog.csdn.net/baiyun789/article/details/6156694 第一部分 WinForm控件在窗体中动态居中创建.删除控件及对其赋值 ...
- unity 脚本执行顺序设置 Script Execution Order Settings
通过Edit->Project Settings->Script Execution Order打开MonoManager面板 或者选择任意脚本在Inspector视图中点击Execu ...
- 使用 SQLiteManager 操作 sqlite3 数据库
SQLiteManager https://github.com/misato/SQLiteManager4iOS 本人以前从事过嵌入式开发,后来转职为iOS开发,即使如此,也绝不想去碰C语言级别的面 ...
- struts2 select 默认选中
jsp: <s:select list="#{'1':'男','2':'女'}" name="sex"/> action: private Stri ...
- JTable常见用法细则
JTable是Swing编程中很常用的控件,这里总结了一些常用方法以备查阅.欢迎补充,转载请注明作者与出处. 一.创建表格控件的各种方式:1) 调用无参构造函数. JTable table = ne ...
- 阿里云96页报告详解《云上转型》(10个案例、10大趋势/完整版PPT)
阿里云96页报告详解<云上转型>(10个案例.10大趋势/完整版PPT) 2017-12-29 14:20阿里云/云计算/技术 ﹃产业前沿超级干货﹄ ﹃数据观○重磅速递﹄ 阿里云研究中心云 ...
- 【POJ】【3525】Most Distant Point from the Sea
二分+计算几何/半平面交 半平面交的学习戳这里:http://blog.csdn.net/accry/article/details/6070621 然而这题是要二分长度r……用每条直线的距离为r的平 ...
- Eclipse 中java跨工程调用类
在Eclipse中,有时候需要跨工程调用其他工程中的方法.如下面有两个Java Project : 如果要在A工程中调用B工程中的类,可以将B工程添加到A工程中: A---- >Build Pa ...