一.

  LTR(learning to rank)经常用于搜索排序中,开源工具中比较有名的是微软的ranklib,但是这个好像是单机版的,也有好长时间没有更新了。所以打算想利用lightgbm进行排序,但网上关于lightgbm用于排序的代码很少,关于回归和分类的倒是一堆。这里我将贴上python版的lightgbm用于排序的代码,里面将包括训练、获取叶结点、ndcg评估、预测以及特征重要度等处理代码,有需要的朋友可以参考一下或进行修改。

  其实在使用时,本人也对比了ranklib中的lambdamart和lightgbm,令人映像最深刻的是lightgbm的训练速度非常快,快的起飞。可能lambdamart训练需要几个小时,而lightgbm只需要几分钟,但是后面的ndcg测试都差不多,不像论文中所说的lightgbm精度高一点。lightgbm的训练速度快,我想可能最大的原因要可能是:a.节点分裂用到了直方图,而不是预排序方法;b.基于梯度的单边采样,即行采样;c.互斥特征绑定,即列采样;d.其于leaf-wise决策树生长策略;e.类别特征的支持等

二.代码

第一部分代码块是主代码,后面三个代码块是用到的加载数据和ndcg。运行主代码使用命令如训练模型使用:python lgb.py -train等

完成代码和数据格式放在https://github.com/jiangnanboy/learning_to_rank上面,大家可以参考一下!!!!!

 import os
import lightgbm as lgb
from sklearn import datasets as ds
import pandas as pd import numpy as np
from datetime import datetime
import sys
from sklearn.preprocessing import OneHotEncoder def split_data_from_keyword(data_read, data_group, data_feats):
'''
利用pandas
转为lightgbm需要的格式进行保存
:param data_read:
:param data_save:
:return:
'''
with open(data_group, 'w', encoding='utf-8') as group_path:
with open(data_feats, 'w', encoding='utf-8') as feats_path:
dataframe = pd.read_csv(data_read,
sep=' ',
header=None,
encoding="utf-8",
engine='python')
current_keyword = ''
current_data = []
group_size = 0
for _, row in dataframe.iterrows():
feats_line = [str(row[0])]
for i in range(2, len(dataframe.columns) - 1):
feats_line.append(str(row[i]))
if current_keyword == '':
current_keyword = row[1]
if row[1] == current_keyword:
current_data.append(feats_line)
group_size += 1
else:
for line in current_data:
feats_path.write(' '.join(line))
feats_path.write('\n')
group_path.write(str(group_size) + '\n') group_size = 1
current_data = []
current_keyword = row[1]
current_data.append(feats_line) for line in current_data:
feats_path.write(' '.join(line))
feats_path.write('\n')
group_path.write(str(group_size) + '\n') def save_data(group_data, output_feature, output_group):
'''
group与features分别进行保存
:param group_data:
:param output_feature:
:param output_group:
:return:
'''
if len(group_data) == 0:
return
output_group.write(str(len(group_data)) + '\n')
for data in group_data:
# 只包含非零特征
# feats = [p for p in data[2:] if float(p.split(":")[1]) != 0.0]
feats = [p for p in data[2:]]
output_feature.write(data[0] + ' ' + ' '.join(feats) + '\n') # data[0] => level ; data[2:] => feats def process_data_format(test_path, test_feats, test_group):
'''
转为lightgbm需要的格式进行保存
'''
with open(test_path, 'r', encoding='utf-8') as fi:
with open(test_feats, 'w', encoding='utf-8') as output_feature:
with open(test_group, 'w', encoding='utf-8') as output_group:
group_data = []
group = ''
for line in fi:
if not line:
break
if '#' in line:
line = line[:line.index('#')]
splits = line.strip().split()
if splits[1] != group: # qid => splits[1]
save_data(group_data, output_feature, output_group)
group_data = []
group = splits[1]
group_data.append(splits)
save_data(group_data, output_feature, output_group) def load_data(feats, group):
'''
加载数据
分别加载feature,label,query
'''
x_train, y_train = ds.load_svmlight_file(feats)
q_train = np.loadtxt(group)
return x_train, y_train, q_train def load_data_from_raw(raw_data):
with open(raw_data, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
return test_X, test_y, test_qids, comments def train(x_train, y_train, q_train, model_save_path):
'''
模型的训练和保存
'''
train_data = lgb.Dataset(x_train, label=y_train, group=q_train)
params = {
'task': 'train', # 执行的任务类型
'boosting_type': 'gbrt', # 基学习器
'objective': 'lambdarank', # 排序任务(目标函数)
'metric': 'ndcg', # 度量的指标(评估函数)
'max_position': 10, # @NDCG 位置优化
'metric_freq': 1, # 每隔多少次输出一次度量结果
'train_metric': True, # 训练时就输出度量结果
'ndcg_at': [10],
'max_bin': 255, # 一个整数,表示最大的桶的数量。默认值为 255。lightgbm 会根据它来自动压缩内存。如max_bin=255 时,则lightgbm 将使用uint8 来表示特征的每一个值。
'num_iterations': 500, # 迭代次数
'learning_rate': 0.01, # 学习率
'num_leaves': 31, # 叶子数
# 'max_depth':6,
'tree_learner': 'serial', # 用于并行学习,‘serial’: 单台机器的tree learner
'min_data_in_leaf': 30, # 一个叶子节点上包含的最少样本数量
'verbose': 2 # 显示训练时的信息
}
gbm = lgb.train(params, train_data, valid_sets=[train_data])
gbm.save_model(model_save_path) def predict(x_test, comments, model_input_path):
'''
预测得分并排序
'''
gbm = lgb.Booster(model_file=model_input_path) # 加载model ypred = gbm.predict(x_test) predicted_sorted_indexes = np.argsort(ypred)[::-1] # 返回从大到小的索引 t_results = comments[predicted_sorted_indexes] # 返回对应的comments,从大到小的排序 return t_results def test_data_ndcg(model_path, test_path):
'''
评估测试数据的ndcg
'''
with open(test_path, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile) gbm = lgb.Booster(model_file=model_path)
test_predict = gbm.predict(test_X) average_ndcg, _ = ndcg.validate(test_qids, test_y, test_predict, 60)
# 所有qid的平均ndcg
print("all qid average ndcg: ", average_ndcg)
print("job done!") def plot_print_feature_importance(model_path):
'''
打印特征的重要度
'''
#模型中的特征是Column_数字,这里打印重要度时可以映射到真实的特征名
feats_dict = {
'Column_0': '特征0名称',
'Column_1': '特征1名称',
'Column_2': '特征2名称',
'Column_3': '特征3名称',
'Column_4': '特征4名称',
'Column_5': '特征5名称',
'Column_6': '特征6名称',
'Column_7': '特征7名称',
'Column_8': '特征8名称',
'Column_9': '特征9名称',
'Column_10': '特征10名称',
}
if not os.path.exists(model_path):
print("file no exists! {}".format(model_path))
sys.exit(0) gbm = lgb.Booster(model_file=model_path) # 打印和保存特征重要度
importances = gbm.feature_importance(importance_type='split')
feature_names = gbm.feature_name() sum = 0.
for value in importances:
sum += value for feature_name, importance in zip(feature_names, importances):
if importance != 0:
feat_id = int(feature_name.split('_')[1]) + 1
print('{} : {} : {} : {}'.format(feat_id, feats_dict[feature_name], importance, importance / sum)) def get_leaf_index(data, model_path):
'''
得到叶结点并进行one-hot编码
'''
gbm = lgb.Booster(model_file=model_path)
ypred = gbm.predict(data, pred_leaf=True) one_hot_encoder = OneHotEncoder()
x_one_hot = one_hot_encoder.fit_transform(ypred)
print(x_one_hot.toarray()[0]) if __name__ == '__main__':
model_path = "保存模型的路径" if len(sys.argv) != 2:
print("Usage: python main.py [-process | -train | -predict | -ndcg | -feature | -leaf]")
sys.exit(0) if sys.argv[1] == '-process':
# 训练样本的格式与ranklib中的训练样本是一样的,但是这里需要处理成lightgbm中排序所需的格式
# lightgbm中是将样本特征和group分开保存为txt的,什么意思呢,看下面解释
'''
feats:
1 1:0.2 2:0.4 ...
2 1:0.2 2:0.4 ...
1 1:0.2 2:0.4 ...
3 1:0.2 2:0.4 ...
group:
2
4
这里group中2表示前2个是一个qid,4表示后两个是一个qid
'''
raw_data_path = '训练样本集路径'
data_feats = '特征保存路径'
data_group = 'group保存路径'
process_data_format(raw_data_path, data_feats, data_group) elif sys.argv[1] == '-train':
# train
train_start = datetime.now()
data_feats = '特征保存路径'
data_group = 'group保存路径'
x_train, y_train, q_train = load_data(data_feats, data_group)
train(x_train, y_train, q_train, model_path)
train_end = datetime.now()
consume_time = (train_end - train_start).seconds
print("consume time : {}".format(consume_time)) elif sys.argv[1] == '-predict':
train_start = datetime.now()
raw_data_path = '需要预测的数据路径'#格式如ranklib中的数据格式
test_X, test_y, test_qids, comments = load_data_from_raw(raw_data_path)
t_results = predict(test_X, comments, model_path)
train_end = datetime.now()
consume_time = (train_end - train_start).seconds
print("consume time : {}".format(consume_time)) elif sys.argv[1] == '-ndcg':
# ndcg
test_path = '测试的数据路径'#评估测试数据的平均ndcg
test_data_ndcg(model_path, test_path) elif sys.argv[1] == '-feature':
plot_print_feature_importance(model_path) elif sys.argv[1] == '-leaf':
#利用模型得到样本叶结点的one-hot表示
raw_data = '测试数据路径'#
with open(raw_data, 'r', encoding='utf-8') as testfile:
test_X, test_y, test_qids, comments = letor.read_dataset(testfile)
get_leaf_index(test_X, model_path)

lightgbm用于排序的更多相关文章

  1. java中的类实现comparable接口 用于排序

    import java.util.Arrays; public class SortApp { public static void main(String[] args) { Student[] s ...

  2. Treemap 有序的hashmap。用于排序

    TreeMap:有固定顺序的hashmap.在需要排序的Map时候才用TreeMap. Map.在数组中我们是通过数组下标来对其内容索引的,键值对. HashMap HashMap 用哈希码快速定位一 ...

  3. C++11新特性应用--介绍几个新增的便利算法(用于排序的几个算法)

    继续C++11在头文件algorithm中添加的算法. 至少我认为,在stl的算法中,用到最多的就是sort了,我们不去探索sort的源代码.就是介绍C++11新增的几个关于排序的函数. 对于一个序列 ...

  4. XGBoost、LightGBM的详细对比介绍

    sklearn集成方法 集成方法的目的是结合一些基于某些算法训练得到的基学习器来改进其泛化能力和鲁棒性(相对单个的基学习器而言)主流的两种做法分别是: bagging 基本思想 独立的训练一些基学习器 ...

  5. LightGBM大战XGBoost,谁将夺得桂冠?

    引 言 如果你是一个机器学习社区的活跃成员,你一定知道 提升机器(Boosting Machine)以及它们的能力.提升机器从AdaBoost发展到目前最流行的XGBoost.XGBoost实际上已经 ...

  6. LightGBM调参笔记

    本文链接:https://blog.csdn.net/u012735708/article/details/837497031. 概述在竞赛题中,我们知道XGBoost算法非常热门,是很多的比赛的大杀 ...

  7. XGBoost、LightGBM、Catboost总结

    sklearn集成方法 bagging 常见变体(按照样本采样方式的不同划分) Pasting:直接从样本集里随机抽取的到训练样本子集 Bagging:自助采样(有放回的抽样)得到训练子集 Rando ...

  8. 【小程序分享篇 一 】开发了个JAVA小程序, 用于清除内存卡或者U盘里的垃圾文件非常有用

    有一种场景, 手机内存卡空间被用光了,但又不知道哪个文件占用了太大,一个个文件夹去找又太麻烦,所以我开发了个小程序把手机所有文件(包括路径下所有层次子文件夹下的文件)进行一个排序,这样你就可以找出哪个 ...

  9. MS SQL 排序规则总结

    排序规则术语        什么是排序规则呢? 排序规则是根据特定语言和区域设置标准指定对字符串数据进行排序和比较的规则.SQL Server 支持在单个数据库中存储具有不同排序规则的对象.MSDN解 ...

随机推荐

  1. Tomcat一闪而过的调试方法

    很少用tomcat来部署,都是用springboot微服务.只是以前学的时候搞demo试过而已. 软件测试的期末作业要求要测一个Javaweb的项目,给了一个包然后要求部署在tomcat中并启动. 然 ...

  2. 火狐浏览器 访问所有HTTPS网站显示连接不安全解决办法

    当 Firefox 连接到一个安全的网站时(网址最开始为“https://”),它必须确认该网站出具的证书有效且使用足够高的加密强度.如果证书无法通过验证,或加密强度过低,Firefox 会中止连接到 ...

  3. C#正则表达式根据分组命名取值

    string[] regexList = new string[] { @"^(?<TickerPart1>[0-9A-Z])[ 0_]?(?<TickerPart2> ...

  4. Hadoop Local(本地)模式搭建

    1. 下载压缩包 2. 配置环境变量 3. 配置Hadoop的JAVA_HOME路径 4. WordCount 1. 下载压缩包 下载Hadoop binary二进制压缩包 https://hadoo ...

  5. 【转载】Sqlserver使用Group By进行分组并计算每个组的数量

    在SQL语句查询中,Group By语句时常用来进行分组操作,有时候在分组的同时还需要计算出每个组的数量多少.在Sqlserver数据库中可以使用Group By加Count聚合函数来实现此功能,即通 ...

  6. Redis 知识 整理

    简介 安装 启动 注意事项 使用命令 通用命令 数据结构 字符串(string) 哈希(hash) 队列(list) 集合(set) 有序集合(zset) 位图(bitcount) 事务 订阅与发布 ...

  7. Flutter中的按钮组件介绍

    Flutter 里有很多的 Button 组件很多,常见的按钮组件有:RaisedButton.FlatButton.IconButton.OutlineButton.ButtonBar.Floati ...

  8. 某阅读多word整理自动化脚本

    版权声明:本文为博主原创文章,转载 请注明出处:https://blog.csdn.net/sc2079/article/details/101055192 - 写在前面 最近想练习英语,发现电脑磁盘 ...

  9. leetcode-2-重复的DNA序列

    所有 DNA 都由一系列缩写为 A,C,G 和 T 的核苷酸组成,例如:"ACGAATTCCG".在研究 DNA 时,识别 DNA 中的重复序列有时会对研究非常有帮助. 编写一个函 ...

  10. 牛客练习赛48 C 小w的糖果 (数学,多项式,差分)

    牛客练习赛48 C 小w的糖果 (数学,多项式) 链接:https://ac.nowcoder.com/acm/contest/923/C来源:牛客网 题目描述 小w和他的两位队友teito.toki ...