多任务学习模型之ESMM介绍与实现
简介:本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
多任务学习背景
目前工业中使用的推荐算法已不只局限在单目标(ctr)任务上,还需要关注后续的转换链路,如是否评论、收藏、加购、购买、观看时长等目标。
本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。
论文介绍
CVR预估面临两个关键问题:
1. Sample Selection Bias (SSB)
转化是在点击之后才“有可能”发生的动作,传统CVR模型通常以点击数据为训练集,其中点击未转化为负例,点击并转化为正例。但是训练好的模型实际使用时,则是对整个空间的样本进行预估,而非只对点击样本进行预估。即训练数据与实际要预测的数据来自不同分布,这个偏差对模型的泛化能力构成了很大挑战,导致模型上线后,线上业务效果往往一般。
2. Data Sparsity (DS)
CVR预估任务的使用的训练数据(即点击样本)远小于CTR预估训练使用的曝光样本。仅使用数量较小的样本进行训练,会导致深度模型拟合困难。
一些策略可以缓解这两个问题,例如从曝光集中对unclicked样本抽样做负例缓解SSB,对转化样本过采样缓解DS等。但无论哪种方法,都没有从实质上解决上面任一个问题。
由于点击=>转化,本身是两个强相关的连续行为,作者希望在模型结构中显示考虑这种“行为链关系”,从而可以在整个空间上进行训练及预测。这涉及到CTR与CVR两个任务,因此使用多任务学习(MTL)是一个自然的选择,论文的关键亮点正在于“如何搭建”这个MTL。
首先需要重点区分下,CVR预估任务与CTCVR预估任务。
- CVR = 转化数/点击数。是预测“假设item被点击,那么它被转化”的概率。CVR预估任务,与CTR没有绝对的关系。一个item的ctr高,cvr不一定同样会高,如标题党文章的浏览时长往往较低。这也是不能直接使用全部样本训练CVR模型的原因,因为无法确定那些曝光未点击的样本,假设他们被点击了,是否会被转化。如果直接使用0作为它们的label,会很大程度上误导CVR模型的学习。
- CTCVR = 转换数/曝光数。是预测“item被点击,然后被转化”的概率。
其中x,y,z分别表示曝光,点击,转换。注意到,在全部样本空间中,CTR对应的label为click,而CTCVR对应的label为click & conversion,这两个任务是可以使用全部样本的。因此,ESMM通过学习CTR,CTCVR两个任务,再根据上式隐式地学习CVR任务。具体结构如下:
网络结构上有两点值得强调:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
具体地,反映在目标函数中:
代码实现
基于EasyRec推荐算法框架,我们实现了ESMM算法,具体实现可移步至github:EasyRec-ESMM。
EasyRec介绍:EasyRec是阿里云计算平台机器学习PAI团队开源的大规模分布式推荐算法框架,EasyRec 正如其名字一样,简单易用,集成了诸多优秀前沿的推荐系统论文思想,并且有在实际工业落地中取得优良效果的特征工程方法,集成训练、评估、部署,与阿里云产品无缝衔接,可以借助 EasyRec 在短时间内搭建起一套前沿的推荐系统。作为阿里云的拳头产品,现已稳定服务于数百个企业客户。
模型前馈网络:
def build_predict_graph(self):
"""Forward function. Returns:
self._prediction_dict: Prediction result of two tasks.
"""
# 此处从Concatenate后的tensor(all_fea)开始,省略其生成逻辑 cvr_tower_name = self._cvr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._cvr_tower_cfg.dnn,
self._l2_reg,
name=cvr_tower_name,
is_training=self._is_training)
cvr_tower_output = dnn_model(all_fea)
cvr_tower_output = tf.layers.dense(
inputs=cvr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % cvr_tower_name) ctr_tower_name = self._ctr_tower_cfg.tower_name
dnn_model = dnn.DNN(
self._ctr_tower_cfg.dnn,
self._l2_reg,
name=ctr_tower_name,
is_training=self._is_training)
ctr_tower_output = dnn_model(all_fea)
ctr_tower_output = tf.layers.dense(
inputs=ctr_tower_output,
units=1,
kernel_regularizer=self._l2_reg,
name='%s/dnn_output' % ctr_tower_name) tower_outputs = {
cvr_tower_name: cvr_tower_output,
ctr_tower_name: ctr_tower_output
}
self._add_to_prediction_dict(tower_outputs)
return self._prediction_dict
loss计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_loss_graph(self):
"""Build loss graph. Returns:
self._loss_dict: Weighted loss of ctr and cvr.
"""
cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name] ctcvr_label = tf.cast(
self._labels[cvr_label_name] * self._labels[ctr_label_name],
tf.float32)
cvr_loss = tf.keras.backend.binary_crossentropy(
ctcvr_label, self._prediction_dict['probs_ctcvr'])
cvr_loss = tf.reduce_sum(cvr_losses, name="ctcvr_loss") # The weight defaults to 1.
self._loss_dict['weighted_cross_entropy_loss_%s' %
cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_loss ctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.cast(self._labels[ctr_label_name], tf.float32),
logits=self._prediction_dict['logits_%s' % ctr_tower_name]
), name="ctr_loss") self._loss_dict['weighted_cross_entropy_loss_%s' %
ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_loss
return self._loss_dict
note: 这里loss是 weighted_cross_entropy_loss_ctr + weighted_cross_entropy_loss_cvr, EasyRec框架会自动对self._loss_dict中的内容进行加和。
metric计算:
注意:计算CVR的指标时需要mask掉曝光数据。
def build_metric_graph(self, eval_config):
"""Build metric graph. Args:
eval_config: Evaluation configuration. Returns:
metric_dict: Calculate AUC of ctr, cvr and ctrvr.
"""
metric_dict = {} cvr_tower_name = self._cvr_tower_cfg.tower_name
ctr_tower_name = self._ctr_tower_cfg.tower_name
cvr_label_name = self._label_name_dict[cvr_tower_name]
ctr_label_name = self._label_name_dict[ctr_tower_name]
for metric in self._cvr_tower_cfg.metrics_set:
# CTCVR metric
ctcvr_label_name = cvr_label_name + '_ctcvr'
cvr_dtype = self._labels[cvr_label_name].dtype
self._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(
self._labels[ctr_label_name], cvr_dtype)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=ctcvr_label_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_ctcvr')) # CVR metric
cvr_label_masked_name = cvr_label_name + '_masked'
ctr_mask = self._labels[ctr_label_name] > 0
self._labels[cvr_label_masked_name] = tf.boolean_mask(
self._labels[cvr_label_name], ctr_mask)
pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'
pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)
self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(
self._prediction_dict[pred_name], ctr_mask)
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._cvr_tower_cfg.loss_type,
label_name=cvr_label_masked_name,
num_class=self._cvr_tower_cfg.num_class,
suffix='_%s_masked' % cvr_tower_name)) for metric in self._ctr_tower_cfg.metrics_set:
# CTR metric
metric_dict.update(
self._build_metric_impl(
metric,
loss_type=self._ctr_tower_cfg.loss_type,
label_name=ctr_label_name,
num_class=self._ctr_tower_cfg.num_class,
suffix='_%s' % ctr_tower_name))
return metric_dict
实验及不足
我们基于开源AliCCP数据,进行了大量实验,实验部分请期待下一篇文章。实验发现,ESMM的跷跷板现象较为明显,CTR与CVR任务的效果较难同时提升。
参考文献
- Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate
- 阿里CVR预估模型之ESMM
- EasyRec-ESMM使用介绍多任务学习模型之ESMM介绍与实现
本文为阿里云原创内容,未经允许不得转载。
多任务学习模型之ESMM介绍与实现的更多相关文章
- 【论文笔记】多任务学习(Multi-Task Learning)
1. 前言 多任务学习(Multi-task learning)是和单任务学习(single-task learning)相对的一种机器学习方法.在机器学习领域,标准的算法理论是一次学习一个任务,也就 ...
- [译]深度神经网络的多任务学习概览(An Overview of Multi-task Learning in Deep Neural Networks)
译自:http://sebastianruder.com/multi-task/ 1. 前言 在机器学习中,我们通常关心优化某一特定指标,不管这个指标是一个标准值,还是企业KPI.为了达到这个目标,我 ...
- 使用深度学习的超分辨率介绍 An Introduction to Super Resolution using Deep Learning
使用深度学习的超分辨率介绍 关于使用深度学习进行超分辨率的各种组件,损失函数和度量的详细讨论. 介绍 超分辨率是从给定的低分辨率(LR)图像恢复高分辨率(HR)图像的过程.由于较小的空间分辨率(即尺寸 ...
- 推荐中的多任务学习-ESMM
本文将介绍阿里发表在 SIGIR'18 的论文ESMM<Entire Space Multi-Task Model: An Effective Approach for Estimating Po ...
- 牛亚男:基于多Domain多任务学习框架和Transformer,搭建快精排模型
导读: 本文主要介绍了快手的精排模型实践,包括快手的推荐系统,以及结合快手业务展开的各种模型实战和探索,全文围绕以下几大方面展开: 快手推荐系统 CTR模型--PPNet 多domain多任务学习框架 ...
- 多任务学习(MTL)在转化率预估上的应用
今天主要和大家聊聊多任务学习在转化率预估上的应用. 多任务学习(Multi-task learning,MTL)是机器学习中的一个重要领域,其目标是利用多个学习任务中所包含的有用信息来帮助每个任务学习 ...
- 推荐中的多任务学习-YouTube视频推荐
本文将介绍Google发表在RecSys'19 的论文<Recommending What Video to Watch Next: A Multitask Ranking System> ...
- 分布式多任务学习论文阅读(四):去偏lasso实现高效通信
1.难点-如何实现高效的通信 我们考虑下列的多任务优化问题: \[ \underset{\textbf{W}}{\min} \sum_{t=1}^{T} [\frac{1}{m_t}\sum_{i=1 ...
- 【NLP】蓦然回首:谈谈学习模型的评估系列文章(一)
统计角度窥视模型概念 作者:白宁超 2016年7月18日17:18:43 摘要:写本文的初衷源于基于HMM模型序列标注的一个实验,实验完成之后,迫切想知道采用的序列标注模型的好坏,有哪些指标可以度量. ...
- Stanford机器学习笔记-6. 学习模型的评估和选择
6. 学习模型的评估与选择 Content 6. 学习模型的评估与选择 6.1 如何调试学习算法 6.2 评估假设函数(Evaluating a hypothesis) 6.3 模型选择与训练/验证/ ...
随机推荐
- JS(数组)
一 数组的概念 问:之前学习的数据类型,只能存储一个值.如果我们想存储班级中所有学生的姓名,那么该如何存储呢?答:可以使用数组(Array).数组可以把一组相关的数据一起存放,并提供方便的访问(获取) ...
- 工作记录:8个有用的JS技巧
这里给大家分享我最近学习到的8个有用的js小技巧,废话不多说,我们上代码 1. 确保数组值 使用 grid ,需要重新创建原始数据,并且每行的列长度可能不匹配, 为了确保不匹配行之间的长度相等,可以使 ...
- 阿里二面:Java中锁的分类有哪些?你能说全吗?
引言 在多线程并发编程场景中,锁作为一种至关重要的同步工具,承担着协调多个线程对共享资源访问秩序的任务.其核心作用在于确保在特定时间段内,仅有一个线程能够对资源进行访问或修改操作,从而有效地保护数据的 ...
- 绘制三元图、颜色空间图:R语言代码
本文介绍基于R语言中的Ternary包,绘制三元图(Ternary Plot)的详细方法:其中,我们就以RGB三色分布图为例来具体介绍. 三元图可以从三个不同的角度反映数据的特征,因此在很多领 ...
- .NET Emit 入门教程:第六部分:IL 指令:1:概要介绍
前言: 在之前的文章中,我们完成了前面五个部分的内容学习,包括: 第一部分:Emit介绍 第二部分:构建动态程序集 第三部分:构建模块(Module) 第四部分:构建类型(Type) 第五部分:动态生 ...
- java基础 韩顺平老师的 面向对象(高级) 自己记的部分笔记
373,类变量引出 代码就提到了问题分析里的3点 package com.hspedu.static_; public class ChildGame { public static void mai ...
- #最大公约数#CF346A Alice and Bob
题目传送门 CF346A 分析 可以发现其所能表示的数就是能被最大公约数整除的数,且这些数不能超过最大值, 于是判断一下取数的奇偶性即可 代码 #include <cstdio> #inc ...
- #回滚莫队#AT1219 歴史の研究
洛谷题目 AT1219 分析 不满足区间减性质的运算,如最值,就不能用普通莫队求, 考虑回滚莫队,它的核心思想就是若区间在块内直接暴力, 否则将右端点从小到大排序,右端点按普通莫队求,那么左端点由于只 ...
- Node 项目通过 .npmrc 文件指定依赖安装源
背景 npm 命令运行时,往往通过命令行指定相关配置,最常用的便是使用 --registry 来指定依赖的安装源. npm install --registry=https://registry.np ...
- OpenHarmony有氧拳击之设备端开发
一.简介 在一个风和日丽,阳光明媚的下午,码农们都像往常一样正在专注地码代码.突然前面的小哥哥站起来,手握开发板,来回出拳.这是怎么回事? 原来这是一款拳击互动游戏,本文将带你一同解开其中的奥秘.开发 ...