Spark2.0协同过滤与ALS算法介绍
ALS矩阵分解
一个 的打分矩阵 A 可以用两个小矩阵和的乘积来近似,描述一个人的喜好经常是在一个抽象的低维空间上进行的,并不需要把其喜欢的事物一一列出。再抽象一些,把人们的喜好和电影的特征都投到这个低维空间,一个人的喜好映射到了一个低维向量,一个电影的特征变成了纬度相同的向量,那么这个人和这个电影的相似度就可以表述成这两个向量之间的内积。
我们把打分理解成相似度,那么“打分矩阵A(m*n)”就可以由“用户喜好特征矩阵U(m*k)”和“产品特征矩阵V(n*k)”的乘积。
矩阵分解过程中所用的优化方法分为两种:交叉最小二乘法(alternative least squares)和随机梯度下降法(stochastic gradient descent)。
损失函数包括正则化项(setRegParam)。
参数选取
分块数:分块是为了并行计算,默认为10。 正则化参数:默认为1。 秩:模型中隐藏因子的个数显示偏好信息-false,隐式偏好信息-true,默认false(显示) alpha:只用于隐式的偏好数据,偏好值可信度底线。 非负限定 numBlocks is the number of blocks the users and items will be
partitioned into in order to parallelize computation (defaults to
10). rank is the number of latent factors in the model (defaults to 10). maxIter is the maximum number of iterations to run (defaults to 10). regParam specifies the regularization parameter in ALS (defaults to 1.0). implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false
which means using explicit feedback). alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference
observations (defaults to 1.0). nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false).
ALS als = new ALS()
.setMaxIter(10)//最大迭代次数,设置太大发生java.lang.StackOverflowError
.setRegParam(0.16)//正则化参数
.setAlpha(1.0)
.setImplicitPrefs(false)
.setNonnegative(false)
.setNumBlocks(10)
.setRank(10)
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating");
需要注意的问题:
对于用户和物品项ID ,基于DataFrame API 只支持integers,因此最大值限定在integers范围内。
The DataFrame-based API for ALS currently only supports integers for
user and item ids. Other numeric types are supported for the user and
item id columns, but the ids must be within the integer value range.
//循环正则化参数,每次由Evaluator给出RMSError
List RMSE=new ArrayList();//构建一个List保存所有的RMSE
for(int i=0;i<20;i++){//进行20次循环
double lambda=(i*5+1)*0.01;//RegParam按照0.05增加
ALS als = new ALS()
.setMaxIter(5)//最大迭代次数
.setRegParam(lambda)//正则化参数
.setUserCol("userId")
.setItemCol("movieId")
.setRatingCol("rating");
ALSModel model = als.fit(training);
// Evaluate the model by computing the RMSE on the test data
Dataset predictions = model.transform(test);
//RegressionEvaluator.setMetricName可以定义四种评估器
//"rmse" (default): root mean squared error
//"mse": mean squared error
//"r2": R^2^ metric
//"mae": mean absolute error
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")//RMS Error
.setLabelCol("rating")
.setPredictionCol("prediction");
Double rmse = evaluator.evaluate(predictions);
RMSE.add(rmse);
System.out.println("RegParam "+0.01*i+" RMSE " + rmse+"\n");
}
//输出所有结果
for (int j = 0; j < RMSE.size(); j++) {
Double lambda=(j*5+1)*0.01;
System.out.println("RegParam= "+lambda+" RMSE= " + RMSE.get(j)+"\n");
}
通过设计一个循环,可以研究最合适的参数,部分结果如下:
RegParam= 0.01 RMSE= 1.956
RegParam= 0.06 RMSE= 1.166
RegParam= 0.11 RMSE= 0.977
RegParam= 0.16 RMSE= 0.962//具备最小的RMSE,参数最合适
RegParam= 0.21 RMSE= 0.985
RegParam= 0.26 RMSE= 1.021
RegParam= 0.31 RMSE= 1.061
RegParam= 0.36 RMSE= 1.102
RegParam= 0.41 RMSE= 1.144
RegParam= 0.51 RMSE= 1.228
RegParam= 0.56 RMSE= 1.267
RegParam= 0.61 RMSE= 1.300
//将RegParam固定在0.16,继续研究迭代次数的影响
输出如下的结果,在单机环境中,迭代次数设置过大,会出现一个java.lang.StackOverflowError异常。是由于当前线程的栈满了引起的。
numMaxIteration= 1 RMSE= 1.7325
numMaxIteration= 4 RMSE= 1.0695
numMaxIteration= 7 RMSE= 1.0563
numMaxIteration= 10 RMSE= 1.055
numMaxIteration= 13 RMSE= 1.053
numMaxIteration= 16 RMSE= 1.053
//测试Rank隐含语义个数
Rank =1 RMSErr = 1.1584
Rank =3 RMSErr = 1.1067
Rank =5 RMSErr = 0.9366
Rank =7 RMSErr = 0.9745
Rank =9 RMSErr = 0.9440
Rank =11 RMSErr = 0.9458
Rank =13 RMSErr = 0.9466
Rank =15 RMSErr = 0.9443
Rank =17 RMSErr = 0.9543
//可以用SPARK-SQL自己定义评估算法(如下面定义了一个平均绝对值误差计算过程)
// Register the DataFrame as a SQL temporary view
predictions.createOrReplaceTempView("tmp_predictions");
Dataset absDiff=spark.sql("select abs(prediction-rating) as diff from tmp_predictions");
absDiff.createOrReplaceTempView("tmp_absDiff");
spark.sql("select mean(diff) as absMeanDiff from tmp_absDiff").show();
完整代码
public class Rating implements Serializable{...}
可以在 http://spark.apache.org/docs/latest/ml-collaborative-filtering.html找到:
package my.spark.ml.practice.classification; import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession; public class myCollabFilter2 { public static void main(String[] args) {
SparkSession spark=SparkSession
.builder()
.appName("CoFilter")
.master("local[4]")
.config("spark.sql.warehouse.dir","file///:G:/Projects/Java/Spark/spark-warehouse" )
.getOrCreate(); String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"
+ "data/mllib/als/sample_movielens_ratings.txt"; //屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
//-------------------------------1.0 准备DataFrame----------------------------
//..javaRDD()函数将DataFrame转换为RDD
//然后对RDD进行Map 每一行String->Rating
JavaRDD ratingRDD=spark.read().textFile(path).javaRDD()
.map(new Function() { @Override
public Rating call(String str) throws Exception {
return Rating.parseRating(str);
}
});
//System.out.println(ratingRDD.take(10).get(0).getMovieId()); //由JavaRDD(每一行都是一个实例化的Rating对象)和Rating Class创建DataFrame
Dataset ratings=spark.createDataFrame(ratingRDD, Rating.class);
//ratings.show(30); //将数据随机分为训练集和测试集
double[] weights=new double[] {0.8,0.2};
long seed=1234;
Dataset [] split=ratings.randomSplit(weights, seed);
Dataset training=split[0];
Dataset test=split[1]; //------------------------------2.0 ALS算法和训练数据集,产生推荐模型-------------
for(int rank=1;rank<20;rank++)
{
//定义算法
ALS als=new ALS()
.setMaxIter(5)////最大迭代次数,设置太大发生java.lang.StackOverflowError
.setRegParam(0.16)
.setUserCol("userId")
.setRank(rank)
.setItemCol("movieId")
.setRatingCol("rating");
//训练模型
ALSModel model=als.fit(training);
//---------------------------3.0 模型评估:计算RMSE,均方根误差---------------------
Dataset predictions=model.transform(test);
//predictions.show();
RegressionEvaluator evaluator=new RegressionEvaluator()
.setMetricName("rmse")
.setLabelCol("rating")
.setPredictionCol("prediction");
Double rmse=evaluator.evaluate(predictions);
System.out.println("Rank =" + rank+" RMSErr = " + rmse);
}
}
}
Spark2.0协同过滤与ALS算法介绍的更多相关文章
- 机器学习(十三)——机器学习中的矩阵方法(3)病态矩阵、协同过滤的ALS算法(1)
http://antkillerfarm.github.io/ 向量的范数(续) 范数可用符号∥x∥λ表示. 经常使用的有: ∥x∥1=|x1|+⋯+|xn| ∥x∥2=x21+⋯+x2n−−−−−− ...
- Spark2.0 协同过滤推荐
ALS矩阵分解 http://blog.csdn.net/oucpowerman/article/details/49847979 http://www.open-open.com/lib/view/ ...
- 协同过滤 CF & ALS 及在Spark上的实现
使用Spark进行ALS编程的例子可以看:http://www.cnblogs.com/charlesblc/p/6165201.html ALS:alternating least squares ...
- 原创:协同过滤之ALS
推荐系统的算法,在上个世纪90年代成型,最早应用于UserCF,基于用户的协同过滤算法,标志着推荐系统的形成.首先,要明白以下几个理论:①长尾理论②评判推荐系统的指标.之所以需要推荐系统,是要挖掘冷门 ...
- [Recommendation System] 推荐系统之协同过滤(CF)算法详解和实现
1 集体智慧和协同过滤 1.1 什么是集体智慧(社会计算)? 集体智慧 (Collective Intelligence) 并不是 Web2.0 时代特有的,只是在 Web2.0 时代,大家在 Web ...
- CF(协同过滤算法)
1 集体智慧和协同过滤 1.1 什么是集体智慧(社会计算)? 集体智慧 (Collective Intelligence) 并不是 Web2.0 时代特有的,只是在 Web2.0 时代,大家在 Web ...
- 协同过滤(CF)算法
1 集体智慧和协同过滤 1.1 什么是集体智慧(社会计算)? 集体智慧 (Collective Intelligence) 并不是 Web2.0 时代特有的,只是在 Web2.0 时代,大家在 Web ...
- spark-MLlib之协同过滤ALS
协同过滤与推荐 协同过滤是一种根据用户对各种产品的交互与评分来推荐新产品的推荐系统技术. 协同过滤引入的地方就在于它只需要输入一系列用户/产品的交互记录: 无论是显式的交互(例如在购物网站 ...
- 基于Python协同过滤算法的认识
Contents 1. 协同过滤的简介 2. 协同过滤的核心 3. 协同过滤的实现 4. 协同过滤的应用 1. 协同过滤的简介 关于协同过滤的一个最经典的例子就是看电影,有时候 ...
随机推荐
- ARP、Proxy ARP、Gratuitous ARP
Proxy ARP 什么是Proxy ARP? 一个主机A(通常是路由器)有意应答另一个主机B的ARP请求(ARP requests).主机A通过伪装其身份,承担起将分组路由到真实目的地的责任.代理A ...
- C 全局变量 本地变量
- PHP 获取上传文件的实际类型
方案一: mime_content_type ( string $filename ) : string (PHP 4 >= 4.3.0, PHP 5, PHP 7) mime_content_ ...
- linux实操_权限管理
rwx权限详解 作用到文件: [r]代表可读(read):可以读取,查看 [w]代表可写(write):可以修改,但是不代表可以删除文件,删除一个文件的前提条件时对该文件所在的目录有写权限,才能删除该 ...
- 指定js文件不使用 ESLint 语法检查
整个文件范围内禁止规则出现警告 将/* eslint-disable */放置于文件最顶部 /* eslint-disable */ alert('foo'); 在文件中临时禁止规则出现警告 将需要忽 ...
- axios并行请求
有些操作需要在几个异步请求都完成之后再执行,虽然一个Ajax可以放到另一个Ajax完成的回调里面,但这样很容易导致回调地狱,且代码也极其不美观. 幸运的是axios提供了并行请求的方法, 使用方法: ...
- hash 跟B+tree的区别
1.hash只支持in跟=,不支持范围查询,时间复杂度:O(1) 2.B+tree支持范围查询,时间复杂度:O(log n) 3. B+tree 的优点:1.磁盘读取代价更低 ...
- Python 练习实例2
Python 练习实例2 题目:企业发放的奖金根据利润提成.利润(I)低https://www.xuanhe.net/于或等于10万元时,奖金可提10%:利润高于10万元,低于20万元时,低于10万元 ...
- python以下划线开头的变量和函数的作用
在python中,我们经常能看到很多变量名以_下划线开头,而且下划线的数量还不一样,那么这些变量的作用到底是什么? 变量名分类: # 以数字.字母开头: 正常的公有变量名a = 1def aa(): ...
- 【算法进阶-康托展开】-C++
目录 引入 这位老爷子就是康托 基本概念 康托展开是一个全排列到一个自然数的双射,常用于构建hash表时的空间压缩.设有n个数(1,2,3,4,-,n),可以有组成不同(n!种)的排列组合,康托展开表 ...