【关系抽取-R-BERT】模型结构
模型的整体结构
相关代码
import torch
import torch.nn as nn
from transformers import BertModel, BertPreTrainedModel
class FCLayer(nn.Module):
def __init__(self, input_dim, output_dim, dropout_rate=0.0, use_activation=True):
super(FCLayer, self).__init__()
self.use_activation = use_activation
self.dropout = nn.Dropout(dropout_rate)
self.linear = nn.Linear(input_dim, output_dim)
self.tanh = nn.Tanh()
def forward(self, x):
x = self.dropout(x)
if self.use_activation:
x = self.tanh(x)
return self.linear(x)
class RBERT(BertPreTrainedModel):
def __init__(self, config, args):
super(RBERT, self).__init__(config)
self.bert = BertModel(config=config) # Load pretrained bert
self.num_labels = config.num_labels
self.cls_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
self.entity_fc_layer = FCLayer(config.hidden_size, config.hidden_size, args.dropout_rate)
self.label_classifier = FCLayer(
config.hidden_size * 3,
config.num_labels,
args.dropout_rate,
use_activation=False,
)
@staticmethod
def entity_average(hidden_output, e_mask):
"""
Average the entity hidden state vectors (H_i ~ H_j)
:param hidden_output: [batch_size, j-i+1, dim]
:param e_mask: [batch_size, max_seq_len]
e.g. e_mask[0] == [0, 0, 0, 1, 1, 1, 0, 0, ... 0]
:return: [batch_size, dim]
"""
e_mask_unsqueeze = e_mask.unsqueeze(1) # [b, 1, j-i+1]
length_tensor = (e_mask != 0).sum(dim=1).unsqueeze(1) # [batch_size, 1]
# [b, 1, j-i+1] * [b, j-i+1, dim] = [b, 1, dim] -> [b, dim]
sum_vector = torch.bmm(e_mask_unsqueeze.float(), hidden_output).squeeze(1)
avg_vector = sum_vector.float() / length_tensor.float() # broadcasting
return avg_vector
def forward(self, input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask):
outputs = self.bert(
input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
) # sequence_output, pooled_output, (hidden_states), (attentions)
sequence_output = outputs[0]
pooled_output = outputs[1] # [CLS]
# Average
e1_h = self.entity_average(sequence_output, e1_mask)
e2_h = self.entity_average(sequence_output, e2_mask)
# Dropout -> tanh -> fc_layer (Share FC layer for e1 and e2)
pooled_output = self.cls_fc_layer(pooled_output)
e1_h = self.entity_fc_layer(e1_h)
e2_h = self.entity_fc_layer(e2_h)
# Concat -> fc_layer
concat_h = torch.cat([pooled_output, e1_h, e2_h], dim=-1)
logits = self.label_classifier(concat_h)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
# Softmax
if labels is not None:
if self.num_labels == 1:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
代码解析
- 首先我们来看RBERT类,它继承了BertPreTrainedModel类,在类初始化的时候要传入两个参数:config和args,config是模型相关的,args是其它的一些配置。
- 假设输入的input_ids, attention_mask, token_type_ids, labels, e1_mask, e2_mask的维度分别是:(16表示的是batchsize的大小,384表示的是设置的句子的最大长度)
input_ids.shape= torch.Size([16, 384])
attention_mask.shape= torch.Size([16, 384])
token_type_ids.shape= torch.Size([16, 384])
labels.shape= torch.Size([16])
e1_mask.shape= torch.Size([16, 384])
e2_mask.shape= torch.Size([16, 384])
经过原始的bert之后得到output,其中outputs[0]的维度是[16,384,768],也就是每一个句子的表示,outputs[1]表示的是经过池化之后的句子表示,维度是[16,768],意思是将384个字的每个维度的特征通过池化将信息聚合在一起。 - 对于sequence_output, e1_mask或者sequence_output, e2_mask,我们将他们分别传入到entity_averag函数中,针对于e1_mask或者e2_mask,他们的维度都是[16,384],然后进行变换为[16,1,384],通过将[16,1,384]和[16,384,768]进行矩阵相乘,就得到了实体的特征表示,维度是[16,1,768],去除掉第1维再除以实体的长度进行归一化,最终得到一个[16,768]的表示。
- 我们将cls,也就是outputs[1],和实体1以及实体2的特征表示进行拼接,得到一个维度为[16,2304]的张量,再经过一个全连接层映射成[16,19],这里的19是类别的数目,最后使用相关的损失函数计算损失即可。
使用
最后是这么使用的:
定义相关参数以及设置
self.args = args
self.train_dataset = train_dataset
self.dev_dataset = dev_dataset
self.test_dataset = test_dataset
self.label_lst = get_label(args)
self.num_labels = len(self.label_lst)
self.config = BertConfig.from_pretrained(
args.model_name_or_path,
num_labels=self.num_labels,
finetuning_task=args.task,
id2label={str(i): label for i, label in enumerate(self.label_lst)},
label2id={label: i for i, label in enumerate(self.label_lst)},
)
self.model = RBERT.from_pretrained(args.model_name_or_path, config=self.config, args=args)
# GPU or CPU
self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
self.model.to(self.device)
【关系抽取-R-BERT】模型结构的更多相关文章
- 学习笔记CB003:分块、标记、关系抽取、文法特征结构
分块,根据句子的词和词性,按照规则组织合分块,分块代表实体.常见实体,组织.人员.地点.日期.时间.名词短语分块(NP-chunking),通过词性标记.规则识别,通过机器学习方法识别.介词短语(PP ...
- 【关系抽取-R-BERT】定义训练和验证循环
[关系抽取-R-BERT]加载数据集 [关系抽取-R-BERT]模型结构 [关系抽取-R-BERT]定义训练和验证循环 相关代码 import logging import os import num ...
- Bert模型实现垃圾邮件分类
近日,对近些年在NLP领域很火的BERT模型进行了学习,并进行实践.今天在这里做一下笔记. 本篇博客包含下列内容: BERT模型简介 概览 BERT模型结构 BERT项目学习及代码走读 项目基本特性介 ...
- 人工智能论文解读精选 | PRGC:一种新的联合关系抽取模型
NLP论文解读 原创•作者 | 小欣 论文标题:PRGC: Potential Relation and Global Correspondence Based Joint Relational ...
- NLP(二十一)人物关系抽取的一次实战
去年,笔者写过一篇文章利用关系抽取构建知识图谱的一次尝试,试图用现在的深度学习办法去做开放领域的关系抽取,但是遗憾的是,目前在开放领域的关系抽取,还没有成熟的解决方案和模型.当时的文章仅作为笔者的 ...
- 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史(转载)
转载 https://zhuanlan.zhihu.com/p/49271699 首发于深度学习前沿笔记 写文章 从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 张 ...
- 想研究BERT模型?先看看这篇文章吧!
最近,笔者想研究BERT模型,然而发现想弄懂BERT模型,还得先了解Transformer. 本文尽量贴合Transformer的原论文,但考虑到要易于理解,所以并非逐句翻译,而是根据笔者的个人理解进 ...
- zz从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史
从Word Embedding到Bert模型—自然语言处理中的预训练技术发展史 Bert最近很火,应该是最近最火爆的AI进展,网上的评价很高,那么Bert值得这么高的评价吗?我个人判断是值得.那为什么 ...
- 图示详解BERT模型的输入与输出
一.BERT整体结构 BERT主要用了Transformer的Encoder,而没有用其Decoder,我想是因为BERT是一个预训练模型,只要学到其中语义关系即可,不需要去解码完成具体的任务.整体架 ...
随机推荐
- HDU4578 Transformation(多标记线段树)题解
题意: 操作有:\(1\).区间都加\(a\):\(2\).区间都乘\(a\):\(3\).区间都重置成\(a\):\(4\).询问区间幂次和\(\sum_{i=l}^rnum[i]^p(p\in\{ ...
- 5分钟学Go 基础01:初识 Go 的第一印象是薪水可观
本文首发于公众号「5分钟学Go」,一个让你每次花 5 分钟就能掌握一个技能点的公众号.目前在博主连更 5 分钟学Go系列,大家可以关注下,第一时间掌握Go技能.如果想要加群交流,可以在公众号后台回复「 ...
- TypeScript Version 23 Design Patterns
TypeScript Version 23 Design Patterns TypeScript 设计模式 https://refactoring.guru/design-patterns/types ...
- HTML a Tag All In One
HTML a Tag All In One HTML <a> target https://developer.mozilla.org/en-US/docs/Web/HTML/Elemen ...
- GitHub & GitHub Package Registry
GitHub & GitHub Package Registry npm https://github.blog/2019-05-10-introducing-github-package-r ...
- array auto slice
array auto slice https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Ar ...
- SpringBoot进阶教程(七十一)详解Prometheus+Grafana
随着容器技术的迅速发展,Kubernetes已然成为大家追捧的容器集群管理系统.Prometheus作为生态圈Cloud Native Computing Foundation(简称:CNCF)中的重 ...
- JDK环境解析,安装和目的
目录 1. JDK环境解析 1.1 JVM 1.2 JRE 1.3 JDK 2. JDK安装 2.1 为什么使用JDK8 2.1.1 更新 2.1.2 稳定 2.1.3 需求 2.2 安装JDK 2. ...
- HDOJ-3001(TSP+三进制状态压缩)
Traving HDOJ-3001 这题考察的是状态压缩dp和tsp问题的改编 需要和传统tsp问题区分的事,这题每个点最多可以经过两次故状态有3种:0,1,2 这里可以模仿tsp问题的二进制压缩方法 ...
- js导出execl 兼容ie Chrome Firefox各种主流浏览器(js export execl)
第一种导出table布局的表格 1 <html> 2 3 <head> 4 <meta charset="utf-8"> 5 <scrip ...