一:入门

 1、基本用法

 (1)、自动交叉验证

  Surprise有一套内置的 算法数据集供您使用。在最简单的形式中,只需几行代码即可运行交叉验证程序:

from surprise import SVD
from surprise import Dataset
from surprise.model_selection import cross_validate # Load the movielens-100k dataset (download it if needed),
# 加载movielens-100k数据集(如果需要,下载)
data = Dataset.load_builtin('ml-100k') # #我们将使用着名的SVD算法。
# We'll use the famous SVD algorithm.
algo = SVD() #运行5倍交叉验证并打印结果
# Run 5-fold cross-validation and print results
cross_validate(algo, data, measures=['RMSE', 'MAE'], cv=5, verbose=True)

输出结果:

Evaluating RMSE, MAE of algorithm SVD on 5 split(s).

                  Fold 1  Fold 2  Fold 3  Fold 4  Fold 5  Mean    Std
RMSE (testset) 0.9398 0.9321 0.9413 0.9349 0.9329 0.9362 0.0037
MAE (testset) 0.7400 0.7351 0.7400 0.7364 0.7370 0.7377 0.0020
Fit time 5.66 5.47 5.46 5.60 5.77 5.59 0.12
Test time 0.24 0.14 0.18 0.15 0.15 0.17 0.04

load_builtin()方法将提供下载movielens-100k数据集(如果尚未下载),并将其保存.surprise_data在主目录的文件夹中(您也可以选择将其保存在其他位置)。

我们在这里使用众所周知的 SVD 算法,但是有许多其他算法可用。

cross_validate() 函数根据cv参数运行交叉验证过程,并计算一些accuracy度量。我们在这里使用经典的5倍交叉验证,但可以使用更高级的迭代器

(2)、测试集分解和fit()方法

如果您不想运行完整的交叉验证程序,可以使用对 train_test_split() 给定大小的训练集和测试集进行采样,并使用您的选择。您将需要使用将在列车集上训练算法的方法,以及将返回从testset进行的预测的方法:accuracy metricfit()test()

from surprise import SVD
from surprise import Dataset
from surprise import accuracy
from surprise.model_selection import train_test_split # Load the movielens-100k dataset (download it if needed),
data = Dataset.load_builtin('ml-100k') # sample random trainset and testset # 随机测试集和训练集
# test set is made of 25% of the ratings. # 将25%的数据设置成测试集
trainset, testset = train_test_split(data, test_size=.25) # We'll use the famous SVD algorithm.
algo = SVD() # Train the algorithm on the trainset, and predict ratings for the testset # 在训练集中训练算法,并预测数据
algo.fit(trainset)
predictions = algo.test(testset) # Then compute RMSE
accuracy.rmse(predictions)

执行结果:

RMSE: 0.9461

(3)、训练整个训练集和predict()方法

显然,我们也可以简单地将算法拟合到整个数据集,而不是运行交叉验证。这可以通过使用build_full_trainset()将构建trainset对象的方法来完成 :

from surprise import KNNBasic
from surprise import Dataset # Load the movielens-100k dataset
data = Dataset.load_builtin('ml-100k') # Retrieve the trainset.
# 检索训练集
trainset = data.build_full_trainset() # Build an algorithm, and train it.
# 构建算法并训练
algo = KNNBasic()
algo.fit(trainset) uid = str(196) # raw user id (as in the ratings file). They are **strings**!
iid = str(302) # raw item id (as in the ratings file). They are **strings**! # get a prediction for specific users and items.
# #获取特定用户和项目的预测。
pred = algo.predict(uid, iid, r_ui=4, verbose=True)

预测结果:

user: 196        item: 302        r_ui = 4.00   est = 4.06   {'actual_k': 40, 'was_impossible': False}
# est表示预测值

以上都是使用内置的数据集。

2、使用自定义数据集

Surprise有一组内置 数据集,但您当然可以使用自定义数据集。加载评级数据集可以从文件(例如csv文件)或从pandas数据帧完成。无论哪种方式,您都需要ReaderSurprise定义一个对象,以便能够解析文件或数据帧。

# 要从文件(例如csv文件)加载数据集,您将需要以下 load_from_file()方法:

from surprise import BaselineOnly
from surprise import Dataset
from surprise import Reader
from surprise.model_selection import cross_validate
import os # path to dataset file
# 数据集路径
file_path = os.path.expanduser(r'C:/Users/FELIX/.surprise_data/ml-100k/ml-100k/u.data') # As we're loading a custom dataset, we need to define a reader. In the
# movielens-100k dataset, each line has the following format:
# 'user item rating timestamp', separated by '\t' characters.
# #当我们加载自定义数据集时,我们需要定义一个reader。在
# #movielens-100k数据集中,每一行都具有以下格式:
# #'user item rating timestamp',以'\ t'字符分隔。
reader = Reader(line_format='user item rating timestamp', sep='\t') data = Dataset.load_from_file(file_path, reader=reader) # We can now use this dataset as we please, e.g. calling cross_validate
# #我们现在可以随意使用这个数据集,例如调用cross_validate
cross_validate(BaselineOnly(), data, verbose=True)
# 要从pandas数据框加载数据集,您将需要该 load_from_df()方法。您还需要一个Reader对象,但只能rating_scale指定参数。数据框必须有三列,对应于用户(原始)ID,项目(原始)ID以及此顺序中的评级。因此,每行对应于给定的评级。这不是限制性的,因为您可以轻松地重新排序数据框的列

import pandas as pd

from surprise import NormalPredictor
from surprise import Dataset
from surprise import Reader
from surprise.model_selection import cross_validate # Creation of the dataframe. Column names are irrelevant.
# #创建数据帧。列名无关紧要。
ratings_dict = {'itemID': [1, 1, 1, 2, 2],
'userID': [9, 32, 2, 45, 'user_foo'],
'rating': [3, 2, 4, 3, 1]}
df = pd.DataFrame(ratings_dict) # A reader is still needed but only the rating_scale param is requiered.
# #仍然需要一个reader,但只需要rating_scale param。
reader = Reader(rating_scale=(1, 5)) # The columns must correspond to user id, item id and ratings (in that order).
# #列必须对应于用户ID,项目ID和评级(按此顺序)。
data = Dataset.load_from_df(df[['userID', 'itemID', 'rating']], reader) # We can now use this dataset as we please, e.g. calling cross_validate
# #我们现在可以随意使用这个数据集,例如调用cross_validate
cross_validate(NormalPredictor(), data, cv=2)

3、使用交叉验证迭代器

对于交叉验证,我们可以使用cross_validate()为我们完成所有艰苦工作的功能。但是为了更好地控制,我们还可以实现交叉验证迭代器,并使用split()迭代器的test()方法和算法的 方法对每个拆分进行预测 。这是一个例子,我们使用经典的K-fold交叉验证程序和3个拆分:

from surprise import SVD
from surprise import Dataset
from surprise import accuracy
from surprise.model_selection import KFold # Load the movielens-100k dataset
data = Dataset.load_builtin('ml-100k') # define a cross-validation iterator
# define一个交叉验证迭代器
kf = KFold(n_splits=3) algo = SVD() for trainset, testset in kf.split(data): # train and test algorithm.
#训练和测试算法。
algo.fit(trainset)
predictions = algo.test(testset) # Compute and print Root Mean Squared Error
# 计算并打印输出
accuracy.rmse(predictions, verbose=True)

可以使用其他交叉验证迭代器,如LeaveOneOut或ShuffleSplit。在这里查看所有可用的迭代器。Surprise的交叉验证工具的设计源于优秀的scikit-learn API。

交叉验证的一个特例是当折叠已经被某些文件预定义时。例如,movielens-100K数据集已经提供了5个训练和测试文件(u1.base,u1.test ... u5.base,u5.test)。惊喜可以通过使用surprise.model_selection.split.PredefinedKFold 对象来处理这种情况:

from surprise import SVD
from surprise import Dataset
from surprise import Reader
from surprise import accuracy
from surprise.model_selection import PredefinedKFold # path to dataset folder
files_dir = os.path.expanduser('~/.surprise_data/ml-100k/ml-100k/') # This time, we'll use the built-in reader.
reader = Reader('ml-100k') # folds_files is a list of tuples containing file paths:
# [(u1.base, u1.test), (u2.base, u2.test), ... (u5.base, u5.test)]
train_file = files_dir + 'u%d.base'
test_file = files_dir + 'u%d.test'
folds_files = [(train_file % i, test_file % i) for i in (1, 2, 3, 4, 5)] data = Dataset.load_from_folds(folds_files, reader=reader)
pkf = PredefinedKFold() algo = SVD() for trainset, testset in pkf.split(data): # train and test algorithm.
algo.fit(trainset)
predictions = algo.test(testset) # Compute and print Root Mean Squared Error
accuracy.rmse(predictions, verbose=True)

当然,也可以对单个文件进行训练和测试。但是folds_files参数仍然要列表的形式。

4、使用GridSearchCV调整算法参数

cross_validate()函数报告针对给定参数集的交叉验证过程的准确度度量。如果你想知道哪个参数组合能产生最好的结果,那么这个 GridSearchCV就可以解决了。给定一个dict参数,该类详尽地尝试所有参数组合并报告任何精度测量的最佳参数(在不同的分裂上取平均值)。它受到scikit-learn的GridSearchCV的启发。

from surprise import SVD
from surprise import Dataset
from surprise.model_selection import GridSearchCV # Use movielens-100K
data = Dataset.load_builtin('ml-100k') param_grid = {'n_epochs': [5, 10], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6]}
gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3) gs.fit(data) # best RMSE score
print(gs.best_score['rmse']) # 输出最高的准确率的值 # combination of parameters that gave the best RMSE score
print(gs.best_params['rmse']) # 输出最好的批次,学习率参数

通过上面操作得到最佳参数后就可以使用该参数的算法:

# We can now use the algorithm that yields the best rmse:
algo = gs.best_estimator['rmse']
algo.fit(data.build_full_trainset())

surprise库官方文档分析(一)的更多相关文章

  1. surprise库官方文档分析(二):使用预测算法

    1.使用预测算法 Surprise提供了一堆内置算法.所有算法都派生自AlgoBase基类,其中实现了一些关键方法(例如predict,fit和test).可以在prediction_algorith ...

  2. surprise库官方文档分析(三):搭建自己的预测算法

    1.基础 创建自己的预测算法非常简单:算法只不过是一个派生自AlgoBase具有estimate 方法的类.这是该方法调用的predict()方法.它接受内部用户ID,内部项ID,并返回估计评级r f ...

  3. webpack官方文档分析(三):Entry Points详解

    1.有很多种方法可以在webpack的配置中定义entry属性,为了解释为什么它对你有用,我们将展现有哪些方法可以配置entry属性. 2.单一条目语法 用法: entry: string|Array ...

  4. webpack官方文档分析(二):概念

    1.概念 webpack的核心是将JavaScript应用程序的静态捆绑模块.当webpack处理您的应用程序时,它会在内部构建一个依赖关系图,它映射您的项目所需的每个模块并生成一个或多个包. 从版本 ...

  5. webpack官方文档分析(一):安装

    一:安装 1.首先要安装Node.js->node.js下载 2.本地安装 要安装最新版本或特定版本,运行如下: npm install --save-dev webpack npm insta ...

  6. Akka源码分析-官方文档说明

    如果有小伙伴在看官方文档的时候,发现有些自相矛盾的地方,不要怀疑,可能是官方文档写错了或写的不清楚,毕竟它只能是把大部分情况描述清楚.开源代码一直在更新,官方文档有没有更新就不知道了,特别是那些官方不 ...

  7. hbase官方文档(转)

    FROM:http://www.just4e.com/hbase.html Apache HBase™ 参考指南  HBase 官方文档中文版 Copyright © 2012 Apache Soft ...

  8. HBase官方文档

    HBase官方文档 目录 序 1. 入门 1.1. 介绍 1.2. 快速开始 2. Apache HBase (TM)配置 2.1. 基础条件 2.2. HBase 运行模式: 独立和分布式 2.3. ...

  9. 比官方文档更易懂的Vue.js教程!包你学会!

    欢迎大家前往腾讯云+社区,获取更多腾讯海量技术实践干货哦~ 本文由蔡述雄发表于云+社区专栏 蔡述雄,现腾讯用户体验设计部QQ空间高级UI工程师.智图图片优化系统首席工程师,曾参与<众妙之门> ...

随机推荐

  1. IOS微信浏览器返回事件监听问题

    业务需求:从主页进入A订单页面,然后经过各种刷新或点标签加载后点左上角的返回直接返回到主页 采取方法:采用onpopstate事件监听url改变,从而跳转到主页 遇到的问题:安卓上测试没问题:苹果手机 ...

  2. 使用Android Studio遇到的问题

    学校这课程安排没明白...又要写安卓了. 这里把使用Android Studio3.1时遇到的问题记录下. Android Studio无法启动模拟器 解决: 控制面板->程序->关闭Hy ...

  3. 用 Scoop 管理你的 Windows 软件

    包管理系统,Homebrew 就是 macOS 上体验最佳的软件包管理,能帮助我们方便快捷.干净利落的管理软件.在Windows平台上也有一个非常棒的包管理软件--Scoop.Scoop 最适合安装那 ...

  4. 《图解HTTP》摘录

    # 图解HTTP 第 1 章 了解Web及网络基础 1.1使用http协议访问web 客户端:通过发送请求获取服务器资源的Web浏览器等. Web使用一种名为 HTTP(HyperText Trans ...

  5. 关于/var/log/maillog 时间和系统时间不对应的问题 -- 我出现的是日志时间比系统时间慢12个小时

    那么让我们来见证奇迹的时刻吧!! 首先你要看下/etc/localtime的软连接,到哪了 一般就是这块出问题了 检查这里就绝对不会错的 对比图 : 这种情况, 删除/etc/localtime : ...

  6. 二、openfeign生成并调用客户端动态代理对象

    所有文章 https://www.cnblogs.com/lay2017/p/11908715.html 正文 上一篇文章中,我们了解到了@FeignClient注解的接口被扫描到以后,会生成一个Fe ...

  7. VBA日期时间函数(十三)

    VBScript日期和时间函数帮助开发人员将日期和时间从一种格式转换为另一种格式,或以适合特定条件的格式表示日期或时间值. 日期函数 编号 函数 描述 1 Date 一个函数,它返回当前的系统日期. ...

  8. python 循环结构(for-in)

    循环结构(for-in) 说明:也是循环结构的一种,经常用于遍历字符串.列表,元组,字典等 格式: for x in y: 循环体 执行流程:x依次表示y中的一个元素,遍历完所有元素循环结束 示例1: ...

  9. 【异常】Maxwell异常 Exception in thread "main" net.sf.jsqlparser.parser.TokenMgrError: Lexical error at line 1, column 596. Encountered: <EOF> after : ""

    1 详细异常 Exception in thread "main" net.sf.jsqlparser.parser.TokenMgrError: Lexical error at ...

  10. TLS1.3 握手过程特性的整理

    1.密码协商 TLS协议中,密码协商的过程中Client在ClientHello中提供四种option 第一:client 支持的加密套件列表,密码套件里面中能出现Client支持的AEAD算法或者H ...