机器学习算法实现解析——libFM之libFM的训练过程之SGD的方法
本节主要介绍的是libFM源码分析的第五部分之一——libFM的训练过程之SGD的方法。
5.1、基于梯度的模型训练方法
在libFM中,提供了两大类的模型训练方法,一类是基于梯度的训练方法,另一类是基于MCMC的模型训练方法。对于基于梯度的训练方法,其类为fm_learn_sgd
类,其父类为fm_learn
类,主要关系为:
fm_learn_sgd
类是所有基于梯度的训练方法的父类,其具体的代码如下所示:
#include "fm_learn.h"
#include "../../fm_core/fm_sgd.h"
// 继承自fm_learn
class fm_learn_sgd: public fm_learn {
protected:
//DVector<double> sum, sum_sqr;
public:
int num_iter;// 迭代次数
double learn_rate;// 学习率
DVector<double> learn_rates;// 多个学习率
// 初始化
virtual void init() {
fm_learn::init();
learn_rates.setSize(3);// 设置学习率
// sum.setSize(fm->num_factor);
// sum_sqr.setSize(fm->num_factor);
}
// 利用梯度下降法进行更新,具体的训练的过程在其子类中
virtual void learn(Data& train, Data& test) {
fm_learn::learn(train, test);// 该函数并没有具体实现
// 输出运行时的参数,包括:学习率,迭代次数
std::cout << "learnrate=" << learn_rate << std::endl;
std::cout << "learnrates=" << learn_rates(0) << "," << learn_rates(1) << "," << learn_rates(2) << std::endl;
std::cout << "#iterations=" << num_iter << std::endl;
if (train.relation.dim > 0) {// 判断relation
throw "relations are not supported with SGD";
}
std::cout.flush();// 刷新
}
// SGD重新修正fm模型的权重
void SGD(sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {
fm_SGD(fm, learn_rate, x, multiplier, sum);// 调用fm_sgd中的fm_SGD函数
}
// debug函数,主要用于打印中间结果
void debug() {
std::cout << "num_iter=" << num_iter << std::endl;
fm_learn::debug();
}
// 对数据进行预测
virtual void predict(Data& data, DVector<double>& out) {
assert(data.data->getNumRows() == out.dim);// 判断样本个数是否相等
for (data.data->begin(); !data.data->end(); data.data->next()) {
double p = predict_case(data);// 得到线性项和交叉项的和,调用的是fm_learn中的方法
if (task == TASK_REGRESSION ) {// 回归任务
p = std::min(max_target, p);
p = std::max(min_target, p);
} else if (task == TASK_CLASSIFICATION) {// 分类任务
p = 1.0/(1.0 + exp(-p));// Sigmoid函数处理
} else {// 异常处理
throw "task not supported";
}
out(data.data->getRowIndex()) = p;
}
}
};
在fm_learn_sgd
类中,主要包括五个函数,分别为:初始化init
函数,训练learn
函数,SGD训练SGD
函数,debug的debug
函数和预测predict
函数。
5.1.1、初始化init
函数
在初始化中,对学习率的大小进行了初始化,同时继承了父类中的初始化方法。
5.1.2、训练learn
函数
在learn
函数中,没有具体的训练的过程,只是对训练中需要用到的参数进行输出,具体的训练的过程在其对应的子类中定义,如fm_learn_sgd_element
类和fm_learn_sgd_element_adapt_reg
类。
5.1.3、SGD训练SGD
函数
SGD
函数使用的是fm_sgd.h
文件中的fm_SGD
函数。fm_SGD
函数是利用梯度下降法对模型中的参数进行调整,以得到最终的模型中的参数。在利用梯度下降法对模型中的参数进行调整的过程中,假设损失函数为l,那么,对于回归问题来说,其损失函数为:
对于二分类问题,其损失函数为:
其中,σ为Sigmoid函数:
对于σ(x),其导函数为:
在可用SGD更新的过程中,首先需要计算损失函数的梯度,因此,对应于上述的回归问题和二分类问题,其中回归问题的损失函数的梯度为:
分类问题的损失函数的梯度为:
其中,λ称为正则化参数,在具体的应用中,通常加上L2正则,即:
在定义好上述的计算方法后,其核心的问题是如何计算∂y^(i)∂θ,在“机器学习算法实现解析——libFM之libFM的模型处理部分”中已知:
因此,当y^分别对w0,wi以及vi,f求偏导时,其结果分别为:
在利用梯度的方法中,其参数θ的更新方法为:
其中,η为学习率,在libFM中,其具体的代码如下所示:
// 利用SGD更新模型的参数
void fm_SGD(fm_model* fm, const double& learn_rate, sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {
// 1、常数项的修正
if (fm->k0) {
double& w0 = fm->w0;
w0 -= learn_rate * (multiplier + fm->reg0 * w0);
}
// 2、一次项的修正
if (fm->k1) {
for (uint i = 0; i < x.size; i++) {
double& w = fm->w(x.data[i].id);
w -= learn_rate * (multiplier * x.data[i].value + fm->regw * w);
}
}
// 3、交叉项的修正
for (int f = 0; f < fm->num_factor; f++) {
for (uint i = 0; i < x.size; i++) {
double& v = fm->v(f,x.data[i].id);
double grad = sum(f) * x.data[i].value - v * x.data[i].value * x.data[i].value;
v -= learn_rate * (multiplier * grad + fm->regv * v);
}
}
}
以上的更新的过程分别对应着上面的更新公式,其中multiplier变量分别对应着回归中的(y^(i)−y(i))和分类中的(σ(y^(i)y(i))−1)⋅y(i)。
5.1.4、预测predict
函数
predict
函数用于对样本进行预测,这里使用到了predict_case
函数,该函数在“机器学习算法实现解析——libFM之libFM的训练过程概述”中有详细的说明,得到值后,分别对回归问题和分类问题做处理,在回归问题中,主要是防止超出最大值和最小值,在分类问题中,将其值放入Sigmoid函数,得到最终的结果。
5.2、SGD的训练方法
随机梯度下降法(Stochastic Gradient Descent ,SGD)是一种简单有效的优化方法。对于梯度下降法的更多内容,可以参见“梯度下降优化算法综述”。在利用SGD对FM模型训练的过程如下图所示:
在libFM中,SGD的实现在fm_learn_sgd_element.h
文件中。在该文件中,定义了fm_learn_sgd_element
类,fm_learn_sgd_element
类继承自fm_learn_sgd
类,主要实现了fm_learn_sgd
类中的learn
方法,具体的程序代码如下所示:
#include "fm_learn_sgd.h"
// 继承了fm_learn_sgd
class fm_learn_sgd_element: public fm_learn_sgd {
public:
// 初始化
virtual void init() {
fm_learn_sgd::init();
// 日志输出
if (log != NULL) {
log->addField("rmse_train", std::numeric_limits<double>::quiet_NaN());
}
}
// 利用SGD训练FM模型
virtual void learn(Data& train, Data& test) {
fm_learn_sgd::learn(train, test);// 输出参数信息
std::cout << "SGD: DON'T FORGET TO SHUFFLE THE ROWS IN TRAINING DATA TO GET THE BEST RESULTS." << std::endl;
// SGD
for (int i = 0; i < num_iter; i++) {// 开始迭代,每一轮的迭代过程
double iteration_time = getusertime();// 记录开始的时间
for (train.data->begin(); !train.data->end(); train.data->next()) {// 对于每一个样本
double p = fm->predict(train.data->getRow(), sum, sum_sqr);// 得到样本的预测值
double mult = 0;// 损失函数的导数
if (task == 0) {// 回归
p = std::min(max_target, p);
p = std::max(min_target, p);
// loss=(y_ori-y_pre)^2
mult = -(train.target(train.data->getRowIndex())-p);// 对损失函数求导
} else if (task == 1) {// 分类
// loss
mult = -train.target(train.data->getRowIndex())*(1.0-1.0/(1.0+exp(-train.target(train.data->getRowIndex())*p)));
}
// 利用梯度下降法对参数进行学习
SGD(train.data->getRow(), mult, sum);
}
iteration_time = (getusertime() - iteration_time);// 记录时间差
// evaluate函数是调用的fm_learn类中的方法
double rmse_train = evaluate(train);// 对训练结果评估
double rmse_test = evaluate(test);// 将模型应用在测试数据上
std::cout << "#Iter=" << std::setw(3) << i << "\tTrain=" << rmse_train << "\tTest=" << rmse_test << std::endl;
// 日志输出
if (log != NULL) {
log->log("rmse_train", rmse_train);
log->log("time_learn", iteration_time);
log->newLine();
}
}
}
};
在learn
函数中,实现了SGD训练FM模型的主要过程,在实现的过程中,分别调用了SGD
函数和evaluate
函数,其中SGD
函数如上面的5.1.3、SGD训练SGD函数
小节所示,利用SGD
函数对FM模型中的参数进行更新,evaluate
函数如“机器学习算法实现解析——libFM之libFM的训练过程概述”中所示,evaluate
函数用于评估学习出的模型的效果。其中mult变量分别对应着回归中的(y^(i)−y(i))和分类中的(σ(y^(i)y(i))−1)⋅y(i)。
参考文献
- Rendle S. Factorization Machines[C]// IEEE International Conference on Data Mining. IEEE Computer Society, 2010:995-1000.
- Rendle S. Factorization Machines with libFM[M]. ACM, 2012.
机器学习算法实现解析——libFM之libFM的训练过程之SGD的方法的更多相关文章
- 机器学习算法实现解析——libFM之libFM的训练过程之Adaptive Regularization
本节主要介绍的是libFM源码分析的第五部分之二--libFM的训练过程之Adaptive Regularization的方法. 5.3.Adaptive Regularization的训练方法 5. ...
- 机器学习算法实现解析——libFM之libFM的训练过程概述
本节主要介绍的是libFM源码分析的第四部分--libFM的训练. FM模型的训练是FM模型的核心的部分. 4.1.libFM中训练过程的实现 在FM模型的训练过程中,libFM源码中共提供了四种训练 ...
- 机器学习算法实现解析——libFM之libFM的模型处理部分
本节主要介绍的是libFM源码分析的第三部分--libFM的模型处理. 3.1.libFM中FM模型的定义 libFM模型的定义过程中主要包括模型中参数的设置及其初始化,利用模型对样本进行预测.在li ...
- 机器学习算法实现解析——word2vec源代码解析
在阅读本文之前,建议首先阅读"简单易学的机器学习算法--word2vec的算法原理"(眼下还没公布).掌握例如以下的几个概念: 什么是统计语言模型 神经概率语言模型的网络结构 CB ...
- 【机器学习算法-python实现】协同过滤(cf)的三种方法实现
(转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景 协同过滤(collaborative filtering)是推荐系统经常使用的一种方法.c ...
- 机器学习算法与Python实践之(四)支持向量机(SVM)实现
机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...
- 机器学习算法与Python实践之(五)k均值聚类(k-means)
机器学习算法与Python实践这个系列主要是参考<机器学习实战>这本书.因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学 ...
- 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression)
http://blog.csdn.net/zouxy09/article/details/20319673 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression) z ...
- 机器学习算法( 五、Logistic回归算法)
一.概述 这会是激动人心的一章,因为我们将首次接触到最优化算法.仔细想想就会发现,其实我们日常生活中遇到过很多最优化问题,比如如何在最短时间内从A点到达B点?如何投入最少工作量却获得最大的效益?如何设 ...
随机推荐
- vmvare11克隆centos虚拟机
一.现在的虚拟机软件已经很强大了,基本上能省的操作配置,都能给用户考虑到 用vmvare安装虚拟机很简单,安装完成之后,对于不了解情况的人可能会发现虚拟机无法上网(共享主机ip的方式) 为了能够上网, ...
- Java集合(5):HashSet
存入Set的每个元素必须是惟一的,因为Set不保存重复元素.加入Set的元素必须定义equals()方法以确保对象的唯一性.Set不保证维护元素的次序.Set与Collection有完全一样的接口. ...
- iOS 和服务端交互 数据加密策略
总体逻辑: 客户端:对称加密数据,上传...回执对称解密 同理服务端:获取上传数据 对称解密 ...下发:对称加密 当且仅当登录接口和 拉新(更新nonce 和 key的接口)是对称加密上传 非对称解 ...
- gitlab + jenkins + docker + k8s
总体流程: 在开发机开发代码后提交到gitlab 之后通过webhook插件触发jenkins进行构建,jenkins将代码打成docker镜像,push到docker-registry 之后将在k8 ...
- C#基础--应用程序域(Appdomain)
AppDomain理解 为了保证代码的键壮性CLR希望不同服务功能的代码之间相互隔离,这种隔离可以通过创建多个进程来实现,但操作系统中创建进程是即耗时又耗费资源的一件事,所以在CLR中引入了AppDo ...
- 【Head First Servlets and JSP】笔记
1.谈到服务器的时候,可能是指物理主机(硬件),也可能是指Web服务应用(软件). 2.谈到客户的时候,通常指人类用户,或者是浏览器应用,或者两者都包括,浏览器应用做些什么?发送请求.解释HTML和呈 ...
- 课堂测试Mysort
课上没有做出来的原因 因为自己平时很少动手敲代码,所以在自己写代码的时候往往会比较慢,而且容易出现一些低级错误,再加上基础没有打牢,对于老师课上所讲的知识不能及时的理解消化,所以可能以后的课上测试都要 ...
- Cisco、HUAWEI、H3c、Firewall等设备配置snmp
配置HUAWEI交换机S1720.S2700.S5700.S6720等型号设备的snmp v3配置 注:此配置来源自官方配置文档 操作步骤 配置交换机的接口IP地址,使其和网管站之间路由可达 (图1) ...
- 因为swap分区无法启动
用户启动时停在如下截图
- Python基础笔记系列一:基本工具与表达式
本系列教程供个人学习笔记使用,如果您要浏览可能需要其它编程语言基础(如C语言),why?因为我写得烂啊,只有我自己看得懂!! 工具基础(Windows系统下)传送门:Python基础笔记系列四:工具的 ...