一.在实体识别中,bert+lstm+crf也是近来常用的方法。这里的bert可以充当固定的embedding层,也可以用来和其它模型一起训练fine-tune。大家知道输入到bert中的数据需要一定的格式,如在单个句子的前后需要加入"[CLS]"和“[SEP]”,需要mask等。下面使用pad_sequences对句子长度进行截断以及padding填充,使每个输入句子的长度一致。构造训练集后,下载中文的预训练模型并加载相应的模型和词表vocab以参数配置,最后并利用albert抽取句子的embedding,这个embedding可以作为一个下游任务和其它模型进行组合完成特定任务的训练。

  1. import torch
  2. from configs.base import config
  3. from model.modeling_albert import BertConfig, BertModel
  4. from model.tokenization_bert import BertTokenizer
  5. from keras.preprocessing.sequence import pad_sequences
  6. from torch.utils.data import TensorDataset, DataLoader, RandomSampler
  7.  
  8. import os
  9.  
  10. device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
  11. MAX_LEN = 10
  12. if __name__ == '__main__':
  13. bert_config = BertConfig.from_pretrained(str(config['albert_config_path']), share_type='all')
  14. base_path = os.getcwd()
  15. VOCAB = base_path + '/configs/vocab.txt' # your path for model and vocab
  16. tokenizer = BertTokenizer.from_pretrained(VOCAB)
  17.  
  18. # encoder text
  19. tag2idx={'[SOS]':101, '[EOS]':102, '[PAD]':0, 'B_LOC':1, 'I_LOC':2, 'O':3}
  20. sentences = ['我是中华人民共和国国民', '我爱祖国']
  21. tags = ['O O B_LOC I_LOC I_LOC I_LOC I_LOC I_LOC O O', 'O O O O']
  22.  
  23. tokenized_text = [tokenizer.tokenize(sent) for sent in sentences]
  24. #利用pad_sequence对序列长度进行截断和padding
  25. input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_text], #没法一条一条处理,只能2-d的数据,即多于一条样本,但是如果全部加载到内存是不是会爆
  26. maxlen=MAX_LEN-2,
  27. truncating='post',
  28. padding='post',
  29. value=0)
  30.  
  31. tag_ids = pad_sequences([[tag2idx.get(tok) for tok in tag.split()] for tag in tags],
  32. maxlen=MAX_LEN-2,
  33. padding="post",
  34. truncating="post",
  35. value=0)
  36.  
  37. #bert中的句子前后需要加入[CLS]:101和[SEP]:102
  38. input_ids_cls_sep = []
  39. for input_id in input_ids:
  40. linelist = []
  41. linelist.append(101)
  42. flag = True
  43. for tag in input_id:
  44. if tag > 0:
  45. linelist.append(tag)
  46. elif tag == 0 and flag:
  47. linelist.append(102)
  48. linelist.append(tag)
  49. flag = False
  50. else:
  51. linelist.append(tag)
  52. if tag > 0:
  53. linelist.append(102)
  54. input_ids_cls_sep.append(linelist)
  55.  
  56. tag_ids_cls_sep = []
  57. for tag_id in tag_ids:
  58. linelist = []
  59. linelist.append(101)
  60. flag = True
  61. for tag in tag_id:
  62. if tag > 0:
  63. linelist.append(tag)
  64. elif tag == 0 and flag:
  65. linelist.append(102)
  66. linelist.append(tag)
  67. flag = False
  68. else:
  69. linelist.append(tag)
  70. if tag > 0:
  71. linelist.append(102)
  72. tag_ids_cls_sep.append(linelist)
  73.  
  74. attention_masks = [[int(tok > 0) for tok in line] for line in input_ids_cls_sep]
  75.  
  76. print('---------------------------')
  77. print('input_ids:{}'.format(input_ids_cls_sep))
  78. print('tag_ids:{}'.format(tag_ids_cls_sep))
  79. print('attention_masks:{}'.format(attention_masks))
  80.  
  81. # input_ids = torch.tensor([tokenizer.encode('我 是 中 华 人 民 共 和 国 国 民', add_special_tokens=True)]) #为True则句子首尾添加[CLS]和[SEP]
  82. # print('input_ids:{}, size:{}'.format(input_ids, len(input_ids)))
  83. # print('attention_masks:{}, size:{}'.format(attention_masks, len(attention_masks)))
  84.  
  85. inputs_tensor = torch.tensor(input_ids_cls_sep)
  86. tags_tensor = torch.tensor(tag_ids_cls_sep)
  87. masks_tensor = torch.tensor(attention_masks)
  88.  
  89. train_data = TensorDataset(inputs_tensor, tags_tensor, masks_tensor)
  90. train_sampler = RandomSampler(train_data)
  91. train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=2)
  92.  
  93. model = BertModel.from_pretrained(config['bert_dir'],config=bert_config)
  94. model.to(device)
  95. model.eval()
  96. with torch.no_grad():
  97. '''
  98. note:
  99. 一.
  100. 如果设置:"output_hidden_states":"True"和"output_attentions":"True"
  101. 输出的是: 所有层的 sequence_output, pooled_output, (hidden_states), (attentions)
  102. 则 all_hidden_states, all_attentions = model(input_ids)[-2:]
  103.  
  104. 二.
  105. 如果没有设置:output_hidden_states和output_attentions
  106. 输出的是:最后一层 --> (output_hidden_states, output_attentions)
  107. '''
  108. for index, batch in enumerate(train_dataloader):
  109. batch = tuple(t.to(device) for t in batch)
  110. b_input_ids, b_input_mask, b_labels = batch
  111. last_hidden_state = model(input_ids = b_input_ids,attention_mask = b_input_mask)
  112. print(len(last_hidden_state))
  113. all_hidden_states, all_attentions = last_hidden_state[-2:] #这里获取所有层的hidden_satates以及attentions
  114. print(all_hidden_states[-2].shape)#倒数第二层hidden_statesshape
             print(all_hidden_states[-2])

二.打印结果

input_ids:[[101, 2769, 3221, 704, 1290, 782, 3696, 1066, 1469, 102], [101, 2769, 4263, 4862, 1744, 102, 0, 0, 0, 0]]
tag_ids:[[101, 3, 3, 1, 2, 2, 2, 2, 2, 102], [101, 3, 3, 3, 3, 102, 0, 0, 0, 0]]
attention_masks:[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]
4
torch.Size([2, 10, 768])
tensor([[[-1.1074, -0.0047,  0.4608,  ..., -0.1816, -0.6379,  0.2295],
         [-0.1930, -0.4629,  0.4127,  ..., -0.5227, -0.2401, -0.1014],
         [ 0.2682, -0.6617,  0.2744,  ..., -0.6689, -0.4464,  0.1460],
         ...,
         [-0.1723, -0.7065,  0.4111,  ..., -0.6570, -0.3490, -0.5541],
         [-0.2028, -0.7025,  0.3954,  ..., -0.6566, -0.3653, -0.5655],
         [-0.2026, -0.6831,  0.3778,  ..., -0.6461, -0.3654, -0.5523]],

[[-1.3166, -0.0052,  0.6554,  ..., -0.2217, -0.5685,  0.4270],
         [-0.2755, -0.3229,  0.4831,  ..., -0.5839, -0.1757, -0.1054],
         [-1.4941, -0.1436,  0.8720,  ..., -0.8316, -0.5213, -0.3893],
         ...,
         [-0.7022, -0.4104,  0.5598,  ..., -0.6664, -0.1627, -0.6270],
         [-0.7389, -0.2896,  0.6083,  ..., -0.7895, -0.2251, -0.4088],
         [-0.0351, -0.9981,  0.0660,  ..., -0.4606,  0.4439, -0.6745]]])

关于bert+lstm+crf实体识别训练数据的构建的更多相关文章

  1. 基于bert的命名实体识别,pytorch实现,支持中文/英文【源学计划】

    声明:为了帮助初学者快速入门和上手,开始源学计划,即通过源代码进行学习.该计划收取少量费用,提供有质量保证的源码,以及详细的使用说明. 第一个项目是基于bert的命名实体识别(name entity ...

  2. BiLSTM+CRF 实体识别

    https://www.cnblogs.com/Determined22/p/7238342.html 这篇博客 里面这个公式表示抽象的含义,表示的是最后的分数由他们影响,不是直观意义上的相加. 为什 ...

  3. 『深度应用』NLP命名实体识别(NER)开源实战教程

    近几年来,基于神经网络的深度学习方法在计算机视觉.语音识别等领域取得了巨大成功,另外在自然语言处理领域也取得了不少进展.在NLP的关键性基础任务—命名实体识别(Named Entity Recogni ...

  4. 基于keras实现的中文实体识别

    1.简介 NER(Named Entity Recognition,命名实体识别)又称作专名识别,是自然语言处理中常见的一项任务,使用的范围非常广.命名实体通常指的是文本中具有特别意义或者指代性非常强 ...

  5. 抛弃模板,一种Prompt Learning用于命名实体识别任务的新范式

    原创作者 | 王翔 论文名称: Template-free Prompt Tuning for Few-shot NER 文献链接: https://arxiv.org/abs/2109.13532 ...

  6. 基于BERT预训练的中文命名实体识别TensorFlow实现

    BERT-BiLSMT-CRF-NERTensorflow solution of NER task Using BiLSTM-CRF model with Google BERT Fine-tuni ...

  7. 用IDCNN和CRF做端到端的中文实体识别

    实体识别和关系抽取是例如构建知识图谱等上层自然语言处理应用的基础.实体识别可以简单理解为一个序列标注问题:给定一个句子,为句子序列中的每一个字做标注.因为同是序列标注问题,除去实体识别之外,相同的技术 ...

  8. 基于条件随机场(CRF)的命名实体识别

    很久前做过一个命名实体识别的模块,现在有时间,记录一下. 一.要识别的对象 人名.地名.机构名 二.主要方法 1.使用CRF模型进行识别(识别对象都是最基础的序列,所以使用了好评率较高的序列识别算法C ...

  9. 基于双向LSTM和迁移学习的seq2seq核心实体识别

    http://spaces.ac.cn/archives/3942/ 暑假期间做了一下百度和西安交大联合举办的核心实体识别竞赛,最终的结果还不错,遂记录一下.模型的效果不是最好的,但是胜在“端到端”, ...

随机推荐

  1. Springboot使用外置tomcat的同时使用websocket通信遇到的坑

    随意门:https://blog.csdn.net/qq_43323720/article/details/99660430 另外,使用了nginx的话,需要注意开放websocket支持 serve ...

  2. Machine概念和获取帮助 【翻译】

    Machine概念和获取帮助 Docker Machine 允许您在各种环境中预配 Docker 计算机,包括驻留在本地系统.云提供商或裸机服务器(物理计算机)上的虚拟机.Docker Machine ...

  3. Web API 接口版本控制 SDammann.WebApi.Versioning

    前言 在设计对外 Web API 时,实务上可能会有新旧版本 API 并存的情况,例如开放 Web API 给厂商串接,但同一个服务更新版本时,不一定所有厂商可以在同一时间都跟着更新他们的系统,但如果 ...

  4. pm2 常用操作

    PM2全局安装 npm i pm2 -g PM2启动.net core pm2 start "dotnet xxx.dll" --name api //name后面跟你要取的名字 ...

  5. QT开发小技巧-窗口处理(一)

    this->setWindowFlags(Qt::WindowCloseButtonHint); // 仅保留关闭按钮 this->setAttribute(Qt::WA_DeleteOn ...

  6. html中正则匹配img

    1.正则匹配html中的img标签,取出img的url并进行图片文件下载: /// <summary> /// 将image标签的src属性的url替换为base64 /// </s ...

  7. arcgis js之调用wms服务

    arcgis js之调用wms服务 定义: export const tdtlayer = async () => { let WMSLayer = await arcgisPackage.WM ...

  8. Nginx如何配置基础缓存

    // /path/to/cache/:用于缓存的本地磁盘目录 // levels :在 /path/to/cache/ 设置了一个两级层次结构的目录. // 将大量的文件放置在单个目录中会导致文件访问 ...

  9. Image Processing and Computer Vision_Review:A survey of recent advances in visual feature detection—2014.08

    翻译 一项关于视觉特征检测的最新进展概述——http://tongtianta.site/paper/56761 摘要 -特征检测是计算机视觉和图像处理中的基础和重要问题.这是一个低级处理步骤,它是基 ...

  10. Django:forms局部函数、cookie、sesion、auth模块

    一.forms组件 二.cookie和session组件 三.auth组件 一.forms组件 1.校验字段功能 针对一个实例:注册用户讲解 模型:models class UserInfo(mode ...