随着预训练模型越来越成熟,预训练模型也会更多的在业务中使用,本文提供了bert和albert的快速训练和部署,实际上目前的预训练模型在用起来时都大致相同。

  基于不久前发布的中文数据集chineseGLUE,将所有任务分成四大类:文本分类,句子对判断,实体识别,阅读理解。同类可以共享代码,除上面四个任务之外,还加了一个learning to rank ,基于pair wise的方式的任务,代码见:https://github.com/jiangxinyang227/bert-for-task

  具体使用见readme

  模型定义在每个项目下的model.py文件中,直接调用bert和albert的源码modeling.py将预训练模型引入,将预训练模型作为encoder部分,也可以只作为embedding层,再自己定义encoder部分,总之可以非常方便的接入下游任务网络层,尤其是当你只想使用预训练模型作为embedding层时,我们需要自己些encoder部分。

  1.      bert_config = modeling.BertConfig.from_json_file(self.__bert_config_path)
  2.  
  3. model = modeling.BertModel(config=bert_config,
  4. is_training=self.__is_training,
  5. input_ids=self.input_ids,
  6. input_mask=self.input_masks,
  7. token_type_ids=self.segment_ids,
  8. use_one_hot_embeddings=False)
  9. output_layer = model.get_pooled_output()
  10.  
  11. hidden_size = output_layer.shape[-1].value
  12. if self.__is_training:
  13. # I.e., 0.1 dropout
  14. output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)
  15.  
  16. with tf.name_scope("output"):
  17. output_weights = tf.get_variable(
  18. "output_weights", [self.__num_classes, hidden_size],
  19. initializer=tf.truncated_normal_initializer(stddev=0.02))
  20.  
  21. output_bias = tf.get_variable(
  22. "output_bias", [self.__num_classes], initializer=tf.zeros_initializer())
  23.  
  24. logits = tf.matmul(output_layer, output_weights, transpose_b=True)
  25. logits = tf.nn.bias_add(logits, output_bias)
  26. self.predictions = tf.argmax(logits, axis=-1, name="predictions")

  在训练时加载预训练的参数值来初始化预训练模型的变量,具体在trainer.py文件中

  1. tvars = tf.trainable_variables()
  2. (assignment_map, initialized_variable_names) = modeling.get_assignment_map_from_checkpoint(
  3. tvars, self.__bert_checkpoint_path)
  4. print("init bert model params")
    tf.train.init_from_checkpoint(self.__bert_checkpoint_path, assignment_map)
  5. print("init bert model params done")
  6. sess.run(tf.variables_initializer(tf.global_variables()))

  在预测时可以直接实例化predict.py文件中的Predictor类就会加载checkpoint模型文件,调用类中的predict方法就可以进行预测,在不需要考虑模型代码加密,模型优化等情况下,可以直接线上部署。

  1. import json
  2.  
  3. from predict import Predictor
  4.  
  5. with open("config/tnews_config.json", "r") as fr:
  6. config = json.load(fr)
  7.  
  8. predictor = Predictor(config)
  9. text = "歼20座舱盖上的两条“花纹”是什么?"
  10. res = predictor.predict(text)
  11. print(res)

bert,albert的快速训练和预测的更多相关文章

  1. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  2. YOLO2 (3) 快速训练自己的目标

    1快速训练自己的目标 在 YOLO2 (2) 测试自己的数据 中记录了完整的训练自己数据的过程. 训练时目标只有一类 car. 如果已经执行过第一次训练,改过一次配置文件,之后仍然训练同样的目标还是只 ...

  3. 机器学习使用sklearn进行模型训练、预测和评价

    cross_val_score(model_name, x_samples, y_labels, cv=k) 作用:验证某个模型在某个训练集上的稳定性,输出k个预测精度. K折交叉验证(k-fold) ...

  4. 初识Sklearn-IrisData训练与预测

    笔记:机器学习入门---鸢尾花分类 Sklearn 本身就有很多数据库,可以用来练习. 以 Iris 的数据为例,这种花有四个属性,花瓣的长宽,茎的长宽,根据这些属性把花分为三类:山鸢尾花Setosa ...

  5. 【HEVC帧间预测论文】P1.1 基于运动特征的HEVC快速帧间预测算法

    基于运动特征的 HEVC 快速帧间预测算法/Fast Inter-Frame Prediction Algorithm for HEVC Based on Motion Features <HE ...

  6. Spark技术在京东智能供应链预测的应用——按照业务进行划分,然后利用scikit learn进行单机训练并预测

    3.3 Spark在预测核心层的应用 我们使用Spark SQL和Spark RDD相结合的方式来编写程序,对于一般的数据处理,我们使用Spark的方式与其他无异,但是对于模型训练.预测这些需要调用算 ...

  7. Tensorflow训练和预测中的BN层的坑

    以前使用Caffe的时候没注意这个,现在使用预训练模型来动手做时遇到了.在slim中的自带模型中inception, resnet, mobilenet等都自带BN层,这个坑在<实战Google ...

  8. fcn训练及预测tgs数据集

    一.背景 kaggle上有这样一个题目,关于盐份预测的语义分割题目.TGS Salt Identification Challenge | Kaggle  https://www.kaggle.com ...

  9. siftflow-fcn32s训练及预测

    一.说明 SIFT Flow 是一个标注的语义分割的数据集,有两个label,一个是语义分类(33类),另一个是场景标签(3类). Semantic and geometric segmentatio ...

随机推荐

  1. 初学JavaScript正则表达式(十三)

    字符串方法 search(reg) search()用于检索字符串中指定的子字符串,或检索与正则表达式相匹配的子字符串 方法返回第一个匹配结果index,查找不到返回-1 search()不执行全局匹 ...

  2. luoguP4151 [WC2011]最大XOR和路径

    题意 这题有点神啊. 首先考虑注意这句话: 路径可以重复经过某些点或边,当一条边在路径中出现了多次时,其权值在计算 XOR 和时也要被计算相应多的次数 也就是说如果出现下面的情况: 我们可以通过异或上 ...

  3. luoguP4069 [SDOI2016]游戏

    题意 显然书剖套李超树. 考虑怎么算函数值: 设\((x,y)\)的\(lca\)为\(z\),我们插一条斜率为\(k\),截距为\(b\)的线段. \((x,z)\)上的点\(u\): \(f(u) ...

  4. ionic4 组件调用的坑

    我们再开发过程中很多模块做成组件,那么调用的时候则需把module.ts中的引入去掉,如下红色框框:

  5. P2186 小Z的函数栈

    有点恶心的模拟(代码写整齐一点不就好了) 以下情况算错: 1.运行中有数的绝对值大于1000000000 2.除以和取模的时候第一个数为0 3.取栈顶元素时栈内元素不够 上代码 #include< ...

  6. MySQL实战45讲学习笔记:第三十五讲

    一.本节概述 在上一篇文章中,我和你介绍了 join 语句的两种算法,分别是 Index Nested-LoopJoin(NLJ) 和 Block Nested-Loop Join(BNL). 我们发 ...

  7. python 机器学习基础教程——第一章,引言

    https://www.cnblogs.com/HolyShine/p/10819831.html # from sklearn.datasets import load_iris import nu ...

  8. 面试被问怎么排查平时遇到的系统CPU飙高和频繁GC,该怎么回答?

    处理过线上问题的同学基本上都会遇到系统突然运行缓慢,CPU 100%,以及Full GC次数过多的问题.当然,这些问题的最终导致的直观现象就是系统运行缓慢,并且有大量的报警.本文主要针对系统运行缓慢这 ...

  9. 物联网架构成长之路(45)-容器管理平台Rancher

    0.前言 按照上一篇博客,我已经把需要下载的rancher docker 依赖镜像下载上传到Harbor了. 1. 安装 执行如下,实现一键安装 docker run -d --restart=unl ...

  10. pymysql 读取大数据内存卡死的解决方案

    背景:目前表中只有5G(后期持续增长),但是其中一个字段(以下称为detail字段)存了2M(不一定2M,部分为0,平均下来就是2M),字段中存的是一个数组,数组中存N个json数据.这个字段如下: ...