模型的整体结构

相关代码

  1. import torch
  2. import torch.nn as nn
  3. from transformers import BertModel, BertPreTrainedModel
  4. class FCLayer(nn.Module):
  5. def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
  6. super(FCLayer, self).__init__()
  7. self.use_activation = use_activation
  8. self.dropout = nn.Dropout(dropout_rate)
  9. self.linear = nn.Linear(input_dim, output_dim)
  10. self.tanh = nn.Tanh()
  11. def forward(self, x):
  12. x = self.dropout(x)
  13. if self.use_activation:
  14. x = self.tanh(x)
  15. return self.linear(x)
  16. class RBERT(BertPreTrainedModel):
  17. def __init__(self, config, args):
  18. super(RBERT, self).__init__(config)
  19. self.bert = BertModel(config=config) # Load pretrained bert
  20. self.num_labels = config.num_labels
  21. self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
  22. self.entity_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
  23. self.label_classifier = FCLayer(
  24. config.hidden_size * 3,
  25. config.num_labels,
  26. args.dropout_rate,
  27. use_activation=False,
  28. )
  29. @staticmethod
  30. def entity_average(hidden_output, e_mask):
  31. """
  32. Average the entity hidden state vectors (H_i ~ H_j)
  33. :param hidden_output: [batch_size, j-i+1, dim]
  34. :param e_mask: [batch_size, max_seq_len]
  35. e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
  36. :return: [batch_size, dim]
  37. """
  38. e_mask_unsqueeze = e_mask.unsqueeze(1) # [b, 1, j-i+1]
  39. length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1) # [batch_size, 1]
  40. # [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
  41. sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
  42. avg_vector = sum_vector.float() / length_tensor.float() # broadcasting
  43. return avg_vector
  44. def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
  45. outputs = self.bert(
  46. input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
  47. ) # sequence_output, pooled_output, (hidden_states), (attentions)
  48. sequence_output = outputs[0]
  49. pooled_output = outputs[1] # [CLS]
  50. # Average
  51. e1_h = self.entity_average(sequence_output, e1_mask)
  52. e2_h = self.entity_average(sequence_output, e2_mask)
  53. # Dropout -> tanh -> fc_layer (Share FC layer for e1 and e2)
  54. pooled_output = self.cls_fc_layer(pooled_output)
  55. e1_h = self.entity_fc_layer(e1_h)
  56. e2_h = self.entity_fc_layer(e2_h)
  57. # Concat -> fc_layer
  58. concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
  59. logits = self.label_classifier(concat_h)
  60. outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
  61. # Softmax
  62. if labels is not None:
  63. if self.num_labels == 1:
  64. loss_fct = nn.MSELoss()
  65. loss = loss_fct(logits.view(-1), labels.view(-1))
  66. else:
  67. loss_fct = nn.CrossEntropyLoss()
  68. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  69. outputs = (loss,) + outputs
  70. return outputs # (loss), logits, (hidden_states), (attentions)

代码解析

  • 首先我们来看RBERT类,它继承了BertPreTrainedModel类,在类初始化的时候要传入两个参数:config和args,config是模型相关的,args是其它的一些配置。
  • 假设输入的input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask的维度分别是:(16表示的是batchsize的大小,384表示的是设置的句子的最大长度)

    input_ids.shape= torch.Size([16, 384])

    attention_mask.shape= torch.Size([16, 384])

    token_type_ids.shape= torch.Size([16, 384])

    labels.shape= torch.Size([16])

    e1_mask.shape= torch.Size([16, 384])

    e2_mask.shape= torch.Size([16, 384])

    经过原始的bert之后得到output,其中outputs[0]的维度是[16,384,768],也就是每一个句子的表示,outputs[1]表示的是经过池化之后的句子表示,维度是[16,768],意思是将384个字的每个维度的特征通过池化将信息聚合在一起。
  • 对于sequence_output, e1_mask或者sequence_output, e2_mask,我们将他们分别传入到entity_averag函数中,针对于e1_mask或者e2_mask,他们的维度都是[16,384],然后进行变换为[16,1,384],通过将[16,1,384]和[16,384,768]进行矩阵相乘,就得到了实体的特征表示,维度是[16,1,768],去除掉第1维再除以实体的长度进行归一化,最终得到一个[16,768]的表示。
  • 我们将cls,也就是outputs[1],和实体1以及实体2的特征表示进行拼接,得到一个维度为[16,2304]的张量,再经过一个全连接层映射成[16,19],这里的19是类别的数目,最后使用相关的损失函数计算损失即可。

使用

最后是这么使用的:

定义相关参数以及设置

  1. self.args = args
  2. self.train_dataset = train_dataset
  3. self.dev_dataset = dev_dataset
  4. self.test_dataset = test_dataset
  5. self.label_lst = get_label(args)
  6. self.num_labels = len(self.label_lst)
  7. self.config = BertConfig.from_pretrained(
  8. args.model_name_or_path,
  9. num_labels=self.num_labels,
  10. finetuning_task=args.task,
  11. id2label={str(i): label for i, label in enumerate(self.label_lst)},
  12. label2id={label: i for i, label in enumerate(self.label_lst)},
  13. )
  14. self.model = RBERT.from_pretrained(args.model_name_or_path, config=self.config, args=args)
  15. # GPU or CPU
  16. self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
  17. self.model.to(self.device)

代码来源:https://github.com/monologg/R-BERT/

【关系抽取-R-BERT】模型结构的更多相关文章

  1. 学习笔记CB003:分块、标记、关系抽取、文法特征结构

    分块,根据句子的词和词性,按照规则组织合分块,分块代表实体.常见实体,组织.人员.地点.日期.时间.名词短语分块(NP-chunking),通过词性标记.规则识别,通过机器学习方法识别.介词短语(PP ...

  2. 【关系抽取-R-BERT】定义训练和验证循环

    [关系抽取-R-BERT]加载数据集 [关系抽取-R-BERT]模型结构 [关系抽取-R-BERT]定义训练和验证循环 相关代码 import logging import os import num ...

  3. Bert模型实现垃圾邮件分类

    近日,对近些年在NLP领域很火的BERT模型进行了学习,并进行实践.今天在这里做一下笔记. 本篇博客包含下列内容: BERT模型简介 概览 BERT模型结构 BERT项目学习及代码走读 项目基本特性介 ...

  4. 人工智能论文解读精选 | PRGC:一种新的联合关系抽取模型

    NLP论文解读 原创•作者 | 小欣   论文标题:PRGC: Potential Relation and Global Correspondence Based Joint Relational ...

  5. NLP(二十一)人物关系抽取的一次实战

      去年,笔者写过一篇文章利用关系抽取构建知识图谱的一次尝试,试图用现在的深度学习办法去做开放领域的关系抽取,但是遗憾的是,目前在开放领域的关系抽取,还没有成熟的解决方案和模型.当时的文章仅作为笔者的 ...

  6. 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)

    转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章   从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...

  7. 想研究BERT模型?先看看这篇文章吧!

    最近,笔者想研究BERT模型,然而发现想弄懂BERT模型,还得先了解Transformer. 本文尽量贴合Transformer的原论文,但考虑到要易于理解,所以并非逐句翻译,而是根据笔者的个人理解进 ...

  8. zz从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史

    从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 Bert最近很火,应该是最近最火爆的AI进展,网上的评价很高,那么Bert值得这么高的评价吗?我个人判断是值得.那为什么 ...

  9. 图示详解BERT模型的输入与输出

    一.BERT整体结构 BERT主要用了Transformer的Encoder,而没有用其Decoder,我想是因为BERT是一个预训练模型,只要学到其中语义关系即可,不需要去解码完成具体的任务.整体架 ...

随机推荐

  1. Java之一个整数的二进制中1的个数

    这是今年某公司的面试题: 一般思路是:把整数n转换成二进制字符数组,然后一个一个数: private static int helper1(int i) { char[] chs = Integer. ...

  2. Linux 驱动框架---dm9000分析

    前面学习了下Linux下的网络设备驱动程序的框架inux 驱动框架---net驱动框架,感觉知道了一个机器的大致结构还是不太清楚具体的细节处是怎么处理的,所以今天就来以dm9000这个网上教程最多的驱 ...

  3. adjust All In One

    adjust All In One 调整 https://www.adjust.com/ Maximize the impact of your mobile marketing Adjust is ...

  4. CSP & CORS

    CSP & CORS 内容安全策略 跨域资源共享 CSP https://developers.google.com/web/fundamentals/security/csp google ...

  5. 从长度为 M 的无序数组中,找出N个最小的数

    从长度为 M 的无序数组中,找出 N个最小的数 在一组长度为 n 的无序的数组中,取最小的 m个数(m < n), 要求时间复杂度 O(m * n) 网易有道面试题 const minTopK ...

  6. ES6 & import * & import default & import JSON

    ES6 & import * & import default & import JSON import json & default value bug api.js ...

  7. GitHub & GitHub Package Registry

    GitHub & GitHub Package Registry npm https://github.blog/2019-05-10-introducing-github-package-r ...

  8. 使用docker mediawiki,搭建网页wiki

    我只是想做一个大家都能访问的wiki,用于成员间共享和维护一些文档.找到了docker的mediawiki,这里记录一下我怎么搭的吧. 首先,如果你在一个局域网里,有公用的可以访问的服务器,那可以直接 ...

  9. Debain 系统U盘安装完全图解

    习惯了使用图形界面的操作,总有一股想要切换到文字界面的Linux的冲动,刚好趁家里的老台式机,没什么用了,就打算用来玩下Linux,在一路安装与使用的过程中,碰到了许多的问题.顺便记录下来,以希望可以 ...

  10. 后端程序员之路 12、K最近邻(k-Nearest Neighbour,KNN)分类算法

    K最近邻(k-Nearest Neighbour,KNN)分类算法,是最简单的机器学习算法之一.由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重 ...