转自:

博客

http://blog.csdn.net/google19890102/article/details/45532745/

github

https://github.com/zhaozhiyong19890102/Python-Machine-Learning-Algorithm/tree/master/Chapter_3%20Factorization%20Machine

一、因子分解机FM的模型

   因子分解机(Factorization Machine, FM)是由Steffen Rendle提出的一种基于矩阵分解的机器学习算法。

1、因子分解机FM的优势

    对于因子分解机FM来说,最大的特点是对于稀疏的数据具有很好的学习能力。现实中稀疏的数据很多,例如作者所举的推荐系统的例子便是一个很直观的具有稀疏特点的例子。

2、因子分解机FM的模型

    对于度为2的因子分解机FM的模型为:
其中,参数表示的是两个大小为的向量和向量的点积:
其中,表示的是系数矩阵的第维向量,且称为超参数。在因子分解机FM模型中,前面两部分是传统的线性模型,最后一部分将两个互异特征分量之间的相互关系考虑进来。
    因子分解机FM也可以推广到高阶的形式,即将更多互异特征分量之间的相互关系考虑进来。

二、因子分解机FM算法

    因子分解机FM算法可以处理如下三类问题:
  1. 回归问题(Regression)
  2. 二分类问题(Binary Classification)
  3. 排序(Ranking)

在这里主要介绍回归问题和二分类问题。

1、回归问题(Regression)

    在回归问题中,直接使用作为最终的预测结果。在回归问题中使用最小均方误差(the least square error)作为优化的标准,即
其中,表示样本的个数。

2、二分类问题(Binary Classification)

    与Logistic回归类似,通过阶跃函数,如Sigmoid函数,将映射成不同的类别。在二分类问题中使用logit loss作为优化的标准,即
其中,表示的是阶跃函数Sigmoid。具体形式为:

三、因子分解机FM算法的求解过程

1、交叉项系数

    在基本线性回归模型的基础上引入交叉项,如下:
  表示共有n个特征:
 
若是这种直接在交叉项的前面加上交叉项系数的方式在稀疏数据的情况下存在一个很大的缺陷,即在对于观察样本中未出现交互的特征分量,不能对相应的参数进行估计。
    对每一个特征分量引入辅助向量,利用对交叉项的系数进行估计,即
这就对应了一种矩阵的分解。对值的限定,对FM的表达能力有一定的影响。

2、模型的求解

这里要求出,主要采用了如公式求出交叉项。具体过程如下:

注:上式中: 

,且,倒数第二行中,将 j 换成 i,原式不变,所以能得到倒数第一行的形式。

3、基于随机梯度的方式求解

对于回归问题:
对于二分类问题:
 
最终交叉项要估计的参数每一个是:Vi,f
有n个特征, 每个特征有k个分量,那交叉项的参数个数就是:n*k。

四、实验(求解二分类问题)

1、实验的代码:

  1. #coding:UTF-8
  2. from __future__ import division
  3. from math import exp
  4. from numpy import *
  5. from random import normalvariate#正态分布
  6. from datetime import datetime
  7. trainData = 'E://data//diabetes_train.txt'
  8. testData = 'E://data//diabetes_test.txt'
  9. featureNum = 8
  10. def loadDataSet(data):
  11. dataMat = []
  12. labelMat = []
  13. fr = open(data)#打开文件
  14. for line in fr.readlines():
  15. currLine = line.strip().split()
  16. #lineArr = [1.0]
  17. lineArr = []
  18. for i in xrange(featureNum):
  19. lineArr.append(float(currLine[i + 1]))
  20. dataMat.append(lineArr)
  21. labelMat.append(float(currLine[0]) * 2 - 1)
  22. return dataMat, labelMat
  23. def sigmoid(inx):
  24. return 1.0 / (1 + exp(-inx))
  25. def stocGradAscent(dataMatrix, classLabels, k, iter):
  26. #dataMatrix用的是mat, classLabels是列表
  27. m, n = shape(dataMatrix)
  28. alpha = 0.01
  29. #初始化参数
  30. w = zeros((n, 1))#其中n是特征的个数
  31. w_0 = 0.    #截距项
  32. v = normalvariate(0, 0.2) * ones((n, k))   #交叉项
  33. for it in xrange(iter):
  34. print it
  35. for x in xrange(m):#随机优化,对每一个样本而言的
  36. inter_1 = dataMatrix[x] * v
  37. inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)#multiply对应元素相乘
  38. #完成交叉项
  39. interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
  40. p = w_0 + dataMatrix[x] * w + interaction#计算预测的输出
  41. loss = sigmoid(classLabels[x] * p[0, 0]) - 1
  42. print loss
  43. w_0 = w_0 - alpha * loss * classLabels[x]
  44. for i in xrange(n):
  45. if dataMatrix[x, i] != 0:
  46. w[i, 0] = w[i, 0] - alpha * loss * classLabels[x] * dataMatrix[x, i]
  47. for j in xrange(k):
  48. v[i, j] = v[i, j] - alpha * loss * classLabels[x] * (dataMatrix[x, i] * inter_1[0, j] - v[i, j] * dataMatrix[x, i] * dataMatrix[x, i])
  49. return w_0, w, v
  50. def getAccuracy(dataMatrix, classLabels, w_0, w, v):
  51. m, n = shape(dataMatrix)
  52. allItem = 0
  53. error = 0
  54. result = []
  55. for x in xrange(m):
  56. allItem += 1
  57. inter_1 = dataMatrix[x] * v
  58. inter_2 = multiply(dataMatrix[x], dataMatrix[x]) * multiply(v, v)#multiply对应元素相乘
  59. #完成交叉项
  60. interaction = sum(multiply(inter_1, inter_1) - inter_2) / 2.
  61. p = w_0 + dataMatrix[x] * w + interaction#计算预测的输出
  62. pre = sigmoid(p[0, 0])
  63. result.append(pre)
  64. if pre < 0.5 and classLabels[x] == 1.0:
  65. error += 1
  66. elif pre >= 0.5 and classLabels[x] == -1.0:
  67. error += 1
  68. else:
  69. continue
  70. print result
  71. return float(error) / allItem
  72. if __name__ == '__main__':
  73. dataTrain, labelTrain = loadDataSet(trainData)
  74. dataTest, labelTest = loadDataSet(testData)
  75. date_startTrain = datetime.now()
  76. print "开始训练"
  77. w_0, w, v = stocGradAscent(mat(dataTrain), labelTrain, 20, 200)
  78. print "训练准确性为:%f" % (1 - getAccuracy(mat(dataTrain), labelTrain, w_0, w, v))
  79. date_endTrain = datetime.now()
  80. print "训练时间为:%s" % (date_endTrain - date_startTrain)
  81. print "开始测试"
  82. print "测试准确性为:%f" % (1 - getAccuracy(mat(dataTest), labelTest, w_0, w, v))

2、实验结果:

五、几点疑问

    在传统的非稀疏数据集上,有时效果并不是很好。在实验中,我有一点处理,即在求解Sigmoid函数的过程中,在有的数据集上使用了带阈值的求法:
  1. def sigmoid(inx):
  2. #return 1.0 / (1 + exp(-inx))
  3. return 1. / (1. + exp(-max(min(inx, 15.), -15.)))

六 图片

fm 讲解加代码的更多相关文章

  1. 简单的自动化使用--使用selenium实现学习通网站的刷慕课程序。注释空格加代码大概200行不到

    简单的自动化使用--使用selenium实现学习通网站的刷慕课程序.注释空格加代码大概200行不到 相见恨晚啊 github地址 环境Python3.6 + pycharm + chrom浏览器 + ...

  2. [洛谷P3376题解]网络流(最大流)的实现算法讲解与代码

    [洛谷P3376题解]网络流(最大流)的实现算法讲解与代码 更坏的阅读体验 定义 对于给定的一个网络,有向图中每个的边权表示可以通过的最大流量.假设出发点S水流无限大,求水流到终点T后的最大流量. 起 ...

  3. [CodeIgniter4]讲解-加载静态页

    讲解 本教程旨在向您介绍CodeIgniter框架和MVC体系结构的基本原理.它将向您展示如何以逐步的方式构造基本的CodeIgniter应用程序. 在本教程中,您将创建一个基本的新闻应用程序.您将从 ...

  4. Java核心技术及面试指南的视频讲解和代码下载位置

    都是百度云盘,均无密码 代码下载位置: https://pan.baidu.com/s/1I44ob0vygMxvmj2BoNioAQ 视频讲解位置: https://pan.baidu.com/s/ ...

  5. 扩展欧几里得(ex_gcd),中国剩余定理(CRT)讲解 有代码

    扩展欧几里得算法 求逆元就不说了. ax+by=c 这个怎么求,很好推. 设d=gcd(a,b) 满足d|c方程有解,否则无解. 扩展欧几里得求出来的解是 x是 ax+by=gcd(a,b)的解. 对 ...

  6. 傻瓜式的go modules的讲解和代码,及gomod能不能引入另一个gomod和gomod的use of internal package xxxx not allowed

    一 国内关于gomod的文章,哪怕是使用了百度 -csdn,依然全是理论,虽然golang的使用者大多是大神但是也有像我这样的的弱鸡是不是? 所以,我就写个傻瓜式教程了. github地址:https ...

  7. Rainbond 对接 Istio 原理讲解和代码实现分析

    一.背景 现有的 ServiceMesh 框架有很多,如 Istio.linkerd等.对于用户而言,在测试环境下,需要达到的效果是快.开箱即用.但在生产环境下,可能又有熔断.延时注入等需求.那么单一 ...

  8. C++工厂方法模式讲解和代码示例

    在C++中使用模式 使用示例: 工厂方法模式在 C++ 代码中得到了广泛使用. 当你需要在代码中提供高层次的灵活性时, 该模式会非常实用. 识别方法: 工厂方法可通过构建方法来识别, 它会创建具体类的 ...

  9. Vue学习之--------组件嵌套以及VueComponent的讲解(代码实现)(2022/7/23)

    欢迎加入刚建立的社区:http://t.csdn.cn/Q52km 加入社区的好处: 1.专栏更加明确.便于学习 2.覆盖的知识点更多.便于发散学习 3.大家共同学习进步 3.不定时的发现金红包(不多 ...

随机推荐

  1. Openwrt 3g模块

    支持Huawei E367 一.编译选项的选择 都选上 都选上 Network目录下 Utiles Luci 二.USB连接3G模块时,显示如下,表示成功 三.没找到:

  2. yii framework config 可以被配置的项目

    http://hi.baidu.com/lossless1009/item/990fdb33a52ffcf1e7bb7a4c <?php002 003 // 取消下行的注释,来定义一个路径别名0 ...

  3. bzoj 2516: 电梯

    Description Input Output 状压dp,状态表示为表示当前在第x层,电梯内有哪些人,哪些人还没到终点 #include<cstdio> #include<cstr ...

  4. java web 程序---javabean代码,出现错误。奇怪,无法解释的运行问题

    深夜吧.这个点11点半了 写了一个简单的javabean实例,发现没有任何代码书写的错误,但是问题就是程序运行会有问题,然后换一个包,重写一个,问题没了? 请问问题出现在哪里了?巧合?还是操作有误?这 ...

  5. Bootstrap:教程、简介、环境安装

    ylbtech-Bootstrap:教程.简介.环境安装 1. Bootstrap 教程返回顶部 1. Bootstrap 教程 Bootstrap,来自 Twitter,是目前最受欢迎的前端框架.B ...

  6. DevExpress 组件

    最近看到 伍华聪 的博客里, DevExpress 组件那个效果很好看,特别是 LookUpEdit GridLookUpEdit 这两个控件,完美改善了 WinForm 里的 ComboBox 今天 ...

  7. RabbitMQ 主题

    RabbitMQ (三) 发布/订阅 RabbitMQ主题 RabbitMQ Tutorials

  8. 当vcenter是linux版本的时候Sysprep存放路径

    为 VMware vCenter Server Appliance 安装 Microsoft Sysprep 工具在从 Microsoft 网站下载并安装 Microsoft Sysprep 工具之后 ...

  9. 【洛谷】P1341 无序字母对(欧拉回路)

    题目 传送门:QWQ 分析 快把欧拉回路忘光了. 欧拉回路大概就是一笔画的问题,可不可以一笔画完全图. 全图有欧拉回路当且仅当全图的奇数度数的点有0或2个. 2个时是一个点是起点,另一个是终点. 本题 ...

  10. Oracle 统计量NO_INVALIDATE参数配置(上)

    转载:http://blog.itpub.net/17203031/viewspace-1067312/ Oracle统计量对于CBO执行是至关重要的.RBO是建立在数据结构的基础上的,DDL结构.约 ...