In-batch negatives Embedding模型介绍与实践
语义索引(可通俗理解为向量索引)技术是搜索引擎、推荐系统、广告系统在召回阶段的核心技术之一。语义索引模型的目标是:给定输入文本,模型可以从海量候选召回库中快速、准确地召回一批语义相关文本。语义索引模型的效果直接决定了语义相关的物料能否被成功召回进入系统参与上层排序,从基础层面影响整个系统的效果。
In-batch negatives
我们采用百度paddleNLP里提到的In-batch Negatives方案。
In-batch Negatives 策略的训练数据为语义相似的 Pair 对,策略核心是在 1 个 Batch 内同时基于 N 个负例进行梯度更新,将Batch 内除自身之外其它所有 Source Text 的相似文本 Target Text 作为负例,例如: 上例中“我手机丢了,我想换个手机” 有 1 个正例(”我想买个新手机,求推荐“),3 个负例(1.求秋色之空全集漫画,2.手机学日语的软件,3.侠盗飞车罪恶都市怎么改车)。
具体来说,In-batch negatives策略的实施步骤如下:
- 选择正样本:首先从当前批次中选择出一个正样本,这个样本是模型需要正确识别的目标样本。
- 选择负样本:然后从同一批次中随机选择或根据特定规则选择一些负样本。这些负样本可以是与正样本相似但被错误标记的样本,也可以是完全不相关的样本。
- 模型训练:将正样本和负样本一起输入模型进行训练。模型需要学会区分正样本和负样本,从而提高推荐或检索的准确性。
In-batch negatives策略的优势在于:
- 提高模型的区分能力:通过在每个批次中引入负样本,模型被迫学习如何区分正样本和负样本,这有助于提高模型的泛化能力和区分度。
- 利用现有数据:不需要额外的负样本库,可以直接利用当前批次中的数据作为负样本,这在数据有限的情况下尤其有用。
- 减少计算资源消耗:与从全局样本集中采样负样本相比,In-batch negatives可以减少计算资源的消耗,因为它避免了在整个数据集上进行负采样的需要。
然而,In-batch negatives策略也存在一些潜在的问题,例如:
- 批次大小的限制:如果批次大小较小,可能无法提供足够多样化的负样本,这可能影响模型的学习效果。
- 偏差问题:由于负样本是在同一个批次中选择的,可能会出现某些样本被频繁选为负样本的情况,这可能导致模型学习到的表示存在偏差。
一般通过 Recall@1,Recall@5 ,Recall@10 ,Recall@20 和 Recall@50 指标来评估语义索引模型的召回效果。按照paddleNLP给出的基线:
策略 | 模型 | Recall@1 | Recall@5 | Recall@10 | Recall@20 | Recall@50 |
---|---|---|---|---|---|---|
In-batch Negatives | ernie 1.0 | 51.301 | 65.309 | 69.878 | 73.996 | 78.881 |
In-batch Negatives | rocketqa-zh-base-query-encoder | 59.622 | 75.089 | 79.668 | 83.404 | 87.773 |
rocketqa作为打底transformer模型效果更好。
总结,为什么为采用In-batch negatives,一方面能充分利用现有数据,不用单独准备负样例,减少投入,另外一方面模型的区分能力也比较好。
模型数据方案
流传一句话,用1亿条数据,训练10个epoch,不如用10亿数据训练一个epoch,也就是见多识广,大力出奇迹。
我们要训练一个给搜索用的向量召回模型,核心就是让准备足够多的正样例数据。正样例数据,一方面网上有较多的开源数据,可以直接利用。另外一方面,之间了解SimBERT 时,他们的数据很多也源自于搜索数据,所以可以通过搜索引擎将query和召回结果的doc作为相似句对。
作为试验,我们构造了8000万的一个小训练集,用rocketqa-zh-mini-query-encoder
作为打底模型,训练256维的embedding模型。
root_path=inbatch
python -u -m paddle.distributed.launch --gpus "0" \
train_batch_neg.py \
--device gpu \
--save_dir ./checkpoints/${root_path} \
--batch_size 64 \
--learning_rate 5E-5 \
--epochs 3 \
--output_emb_size 256 \
--model_name_or_path rocketqa-zh-mini-query-encoder \
--save_steps 5000 \
--max_seq_length 128 \
--margin 0.2 \
--train_set_file recall/train.csv \
--recall_result_dir "recall_result_dir" \
--recall_result_file "recall_result.txt" \
--hnsw_m 100 \
--hnsw_ef 100 \
--recall_num 50 \
--similar_text_pair_file "recall/dev.csv" \
--corpus_file "recall/corpus.csv"
训练完成导出onnx模型:
def convert_model(model_path):
try:
import onnx
import onnxruntime as ort
import paddle2onnx
from onnxconverter_common import float16
except ImportError:
print(
"The inference precision is change to 'fp32', please install the dependencies that required for 'fp16' inference, pip install onnxruntime-gpu onnx onnxconverter-common"
)
onnx_dir = os.path.join(model_path, "onnx")
if not os.path.exists(onnx_dir):
os.mkdir(onnx_dir)
float_onnx_file = os.path.join(onnx_dir, "model.onnx")
if not os.path.exists(float_onnx_file):
onnx_model = paddle2onnx.command.c_paddle_to_onnx(
model_file=os.path.join(model_path, "inference.pdmodel"),
params_file=os.path.join(model_path, "inference.pdiparams"),
opset_version=13,
enable_onnx_checker=True,
)
with open(float_onnx_file, "wb") as f:
f.write(onnx_model)
fp16_model_file = os.path.join(onnx_dir, "fp16_model.onnx")
if not os.path.exists(fp16_model_file):
onnx_model = onnx.load_model(float_onnx_file)
trans_model = float16.convert_float_to_float16(onnx_model, keep_io_types=True)
onnx.save_model(trans_model, fp16_model_file)
加载测试:
class MiniRocketQAEmbedding():
def __init__(self, model_file: str = model_file, use_gpu: bool = True):
providers = ['CUDAExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
sess_options = ort.SessionOptions()
self.predictor = ort.InferenceSession(
model_file, sess_options=sess_options, providers=providers)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
def embeding(self, embeding_text):
features = self.tokenizer(embeding_text, max_seq_len=128,
pad_to_max_seq_len=True, truncation_strategy="longest_first")
vecs = self.predictor.run(None, features.data)
return vecs[0]
def similarity(self, pairs):
query = pairs[0][0]
texts = [item[1] for item in pairs]
emdbeding_text = [query]
emdbeding_text.extend(texts)
features = self.tokenizer(emdbeding_text, max_seq_len=128,
pad_to_max_seq_len=True, truncation_strategy="longest_first")
vecs = self.predictor.run(None, features.data)
# print(vecs)
query_embeding = vecs[0][0]
vecs_text1 = query_embeding / (query_embeding**2).sum() ** 0.5
result = []
for i in range(1, len(vecs[0])):
vecs_text2 = vecs[0][i]
vecs_text2 = vecs_text2 / (vecs_text2**2).sum() ** 0.5
similarity = (vecs_text1 * vecs_text2).sum()
result.append({"similarity": float(similarity)})
return result
if __name__ == "__main__":
bert = MiniRocketQAEmbedding(use_gpu=False)
import time
start = time.time()
bert.embeding(["双鱼座性格特点","双鱼座性格特点"])
print((time.time() - start) * 1000)
通过MTEB框架来测试自建搜索测试集效果:
if __name__ == '__main__':
model = MyModel()
task_names = ["SSRetrieval"]
for task in task_names:
model.query_instruction_for_retrieval = None
evaluation = MTEB(tasks=[task], task_langs=['zh', 'zh-CN'])
evaluation.run(model, output_folder=f"zh_results/256_model", batch_size=64)
测试结果:
{
"dataset_revision": null,
"dev": {
"evaluation_time": 251.86,
"map_at_1": 0.13427,
"map_at_10": 0.62859,
"map_at_100": 0.72526,
"map_at_1000": 0.72564,
"map_at_3": 0.31398,
"map_at_5": 0.45025,
"mrr_at_1": 0.71863,
"mrr_at_10": 0.81982,
"mrr_at_100": 0.82077,
"mrr_at_1000": 0.82078,
"mrr_at_3": 0.80707,
"mrr_at_5": 0.81587,
"ndcg_at_1": 0.71803,
"ndcg_at_10": 0.77357,
"ndcg_at_100": 0.83634,
"ndcg_at_1000": 0.83907,
"ndcg_at_3": 0.72048,
"ndcg_at_5": 0.73003,
"precision_at_1": 0.71803,
"precision_at_10": 0.53373,
"precision_at_100": 0.07386,
"precision_at_1000": 0.00747,
"precision_at_3": 0.68889,
"precision_at_5": 0.65699,
"recall_at_1": 0.13427,
"recall_at_10": 0.78675,
"recall_at_100": 0.98082,
"recall_at_1000": 0.99181,
"recall_at_3": 0.35371,
"recall_at_5": 0.53211
},
"mteb_dataset_name": "SSRetrieval",
"mteb_version": "1.1.1"
}
同样的数据集,用peg模型测试:
{
"dataset_revision": null,
"dev": {
"evaluation_time": 1036.11,
"map_at_1": 0.09911,
"map_at_10": 0.42835,
"map_at_100": 0.49497,
"map_at_1000": 0.49681,
"map_at_3": 0.2277,
"map_at_5": 0.31901,
"mrr_at_1": 0.56794,
"mrr_at_10": 0.67111,
"mrr_at_100": 0.6737,
"mrr_at_1000": 0.67386,
"mrr_at_3": 0.65495,
"mrr_at_5": 0.66559,
"ndcg_at_1": 0.56794,
"ndcg_at_10": 0.56275,
"ndcg_at_100": 0.62991,
"ndcg_at_1000": 0.64939,
"ndcg_at_3": 0.55564,
"ndcg_at_5": 0.54815,
"precision_at_1": 0.56794,
"precision_at_10": 0.38468,
"precision_at_100": 0.05755,
"precision_at_1000": 0.00641,
"precision_at_3": 0.53329,
"precision_at_5": 0.49464,
"recall_at_1": 0.09911,
"recall_at_10": 0.55328,
"recall_at_100": 0.7634,
"recall_at_1000": 0.84758,
"recall_at_3": 0.25931,
"recall_at_5": 0.38263
},
"mteb_dataset_name": "SSRetrieval",
"mteb_version": "1.1.1"
}
模型 | Recall@1 | Recall@10 | Recall@100 | Recall@1000 |
---|---|---|---|---|
peg模型 | 9.911 | 55.328 | 76.34 | 84.758 |
微调256模型 | 13.427 | 78.675 | 98.082 | 99.181 |
可以看到,微调的模型,用更小的参数,见多识广后,整体效果明显优于未经历大规模数据训练的更大尺寸的模型。
参考
In-batch negatives Embedding模型介绍与实践的更多相关文章
- 模型介绍之FastText
模型介绍一: 1. FastText原理及实践 前言----来源&特点 fastText是Facebook于2016年开源的一个词向量计算和文本分类工具,在学术上并没有太大创新.但是它的优点也 ...
- (zhuan) 深度学习全网最全学习资料汇总之模型介绍篇
This blog from : http://weibo.com/ttarticle/p/show?id=2309351000224077630868614681&u=5070353058& ...
- OSI七层网络模型与TCP/IP四层模型介绍
目录 OSI七层网络模型与TCP/IP四层模型介绍 1.OSI七层网络模型介绍 2.TCP/IP四层网络模型介绍 3.各层对应的协议 4.OSI七层和TCP/IP四层的区别 5.交换机工作在OSI的哪 ...
- IO模型介绍
先理解几个问题: (1)为什么读取文件的时候,需要用户进程通过系统调用内核完成(系统不能自己调用内核)什么是用户态和内核态?为什么要区分内核态和用户态呢? 在 CPU 的所有指令中,有些指令是非常危险 ...
- 关于Axure RP软件的介绍——软件工程实践第二次个人作业
关于Axure RP软件的介绍——软件工程实践第二次个人作业 Axure RP是一个非常专业的快速原型设计的一个工具,客户提出需求,然后根据需求定义和规格.设计功能和界面的专家能够快速创建应用软件或W ...
- RabbitMQ系列(三)RabbitMQ交换器Exchange介绍与实践
RabbitMQ交换器Exchange介绍与实践 RabbitMQ系列文章 RabbitMQ在Ubuntu上的环境搭建 深入了解RabbitMQ工作原理及简单使用 RabbitMQ交换器Exchang ...
- python 全栈开发,Day44(IO模型介绍,阻塞IO,非阻塞IO,多路复用IO,异步IO,IO模型比较分析,selectors模块,垃圾回收机制)
昨日内容回顾 协程实际上是一个线程,执行了多个任务,遇到IO就切换 切换,可以使用yield,greenlet 遇到IO gevent: 检测到IO,能够使用greenlet实现自动切换,规避了IO阻 ...
- {python之IO多路复用} IO模型介绍 阻塞IO(blocking IO) 非阻塞IO(non-blocking IO) 多路复用IO(IO multiplexing) 异步IO(Asynchronous I/O) IO模型比较分析 selectors模块
python之IO多路复用 阅读目录 一 IO模型介绍 二 阻塞IO(blocking IO) 三 非阻塞IO(non-blocking IO) 四 多路复用IO(IO multiplexing) 五 ...
- 深入理解 Java 内存模型(一)- 内存模型介绍
深入理解 Java 内存模型(一)- 内存模型介绍 深入理解 Java 内存模型(二)- happens-before 规则 深入理解 Java 内存模型(三)- volatile 语义 深入理解 J ...
- IO模型《一》IO模型介绍
IO模型介绍 为了更好地了解IO模型,我们需要事先回顾下:同步.异步.阻塞.非阻塞 同步(synchronous) IO和异步(asynchronous) IO,阻塞(blocking) IO和非阻塞 ...
随机推荐
- 关闭表单验证/关闭控制台async-validator警告
找到util.js node_modules -> async-validator -> es -> util.js 将console.warn(type, errors)注释 如果 ...
- NEMU PA 1 实验报告
课程地址: PA1-1 https://www.bilibili.com/video/BV1JE411J7AK PA1-2 https://www.bilibili.com/video/BV1EE41 ...
- hdparm 常用命令介绍
hdparm命令介绍 通常情况下可以使用fdisk.df等命令查看硬盘的分区情况以及当前已使用空间大小.剩余空间大小等信息.但是如果要查看硬盘的硬件信息如 硬盘型号.序列号.已运行时间等信息该用什么工 ...
- 延时队列 DelayQueue
当用户超时未支付时,给用户发提醒消息.另一种场景是,超时未付款,订单自动取消.通常,订单创建的时候可以向延迟队列种插入一条消息,到时间自动执行.其实,也可以用临时表,把这些未支付的订单放到一个临时表中 ...
- Vue+SpringBoot+ElementUI实战学生管理系统-7.专业管理模块
1.章节介绍 前一篇介绍了院系管理模块,这一篇编写专业管理模块,需要的朋友可以拿去自己定制.:) 2.获取源码 源码是捐赠方式获取,详细请QQ联系我 :)! 3.实现效果 专业列表 修改专业 4.模块 ...
- Spring Boot图书管理系统项目实战-3.用户登录
导航: pre: 2.项目搭建 next:4.基础信息管理 只挑重点的讲,具体的请看项目源码. 1.项目源码 需要源码的朋友,请捐赠任意金额后留下邮箱发送:) 2.登录页设计 <!DOCTYP ...
- Qt+QtWebApp开发笔记(二):http服务器日志系统介绍、添加日志系统至Demo测试
前言 上一篇使用QtWebApp的基于Qt的轻量级http服务器实现了一个静态网页返回的Demo,网页服务器很重要的就是日志,因为在服务器类上并没有直接返回,所以,本篇先把日志加上. Demo ...
- java基础之StringBuilder---03
StringBuilder概述 StringBuilder是一个可变的字符串类,我们可以把它看成是一个容器,这里的可变指的是StringBuilder对象中的内容是可变的. 如果对字符串进行拼接操作, ...
- golang中关于map的value类型定义为函数类型时(方法值)的一点点思考
文章的内容仅仅是自己关于map的value类型定义为函数类型时的一点点思考,如有不对的地方,请不吝赐教. 学习过后才知道叫做 方法值. 1.起因 最近在看老项目代码时,看到了一段类似于下面的定义,最开 ...
- 【Azure 应用服务】App Service for Linux环境中,如何解决字体文件缺失的情况
问题描述 部署在App Service for Linux环境中的Web App.出现了字体文件缺失的问题,页面显示本来时中文的地方,区别变为方框占位. 问题分析 在应用中,通常涉及到显示问题的有两个 ...