论文地址为:Cognitive Graph for Multi-Hop Reading Comprehension at Scale

github地址:CogQA

背景

假设你手边有一个维基百科的搜索引擎,可以用来获取实体对应的文本段落,那么如何来回答下面这个复杂的问题呢?

“谁是某部在2003年取景于洛杉矶Quality cafe的电影的导演?”

很自然地,我们将会从例如Quality cafe这样的“相关实体”入手,通过维基百科查询相关介绍,并在其中讲到好莱坞电影的时候迅速定位到“Old School”“Gone in 60 Seconds”这两部电影,通过继续查询两部电影相关的介绍,我们找到他们的导演。最后一步是判断到底是哪位导演,这需要我们自己分析句子的语意和限定词,在了解到电影是2003年之后,我们可以做出最后判断——Todd Phillips是我们想要的答案。

事实上,“快速将注意力定位到相关实体”和“分析句子语意进行推断”是两种不同的思维过程。

在认知学里,著名的“双过程理论(dual process theory)”认为,人的认知分为两个系统,系统一(System 1)是基于直觉的、无知觉的思考系统,其运作依赖于经验和关联;而系统二(System 2)则是人类特有的逻辑推理能力,此系统利用工作记忆(working memory)中的知识进行慢速但是可靠的逻辑推理,系统二是显式的,需要意识控制的,是人类高级智能的体现。

论文详情

因此,本文提出一种新颖的迭代框架:算法使用两个系统来维护一张认知图谱(Cognitive Graph):

  • 系统一在文本中抽取与问题相关的实体名称并扩展节点和汇总语义向量,
  • 系统二利用图神经网络在认知图谱上进行推理计算。

正如之前提到的,人类的系统一是无知觉(unconscious),CogQA中的系统一也是流行的NLP黑盒模型,例如BERT。

在文章的实现中,系统一的输入分为三部分:

  1. 问题本身
  2. 从前面段落中找到的“线索(clues)”
  3. 关于某个实体x的维基百科文档

系统一的目标是抽取文档中的“下一跳实体名称(hop span)”和“答案候选(ans span)”。

这些抽取的到的实体和答案候选将作为节点添加到认知图谱中。此外,系统一还将计算当前实体 x 的语意向量,这将在系统二中用作关系推理的初始值。

模型架构图如下:

 源码解析(主要是model.py文件)分为七大模块

1. 导入相应的库代码,主要是bert模块,有些库model.py没有用到,这块就不做相应解释了。(utils是文章作者写的模块)

 from pytorch_pretrained_bert.modeling import (
BertPreTrainedModel as PreTrainedBertModel,
BertModel,
BertLayerNorm,
gelu,
BertEncoder,
BertPooler,
)
import torch
from torch import nn
import re
import pdb
from pytorch_pretrained_bert.tokenization import (
whitespace_tokenize,
BasicTokenizer,
BertTokenizer,
)
from utils import (
fuzzy_find,
find_start_end_after_tokenized,
find_start_end_before_tokenized,
bundle_part_to_batch,
)

2. MLP模块

该模块较为简单,就是简单的多层感知机,如果大于两层,会加入相应的dropout 和 LayerNorm,并采用了bert所特有的gelu激活函数。

 class MLP(nn.Module):
def __init__(self, input_sizes, dropout_prob=0.2, bias=False):
super(MLP, self).__init__()
self.layers = nn.ModuleList()
for i in range(1, len(input_sizes)):
self.layers.append(nn.Linear(input_sizes[i-1], input_sizes[i], bias=bias))
self.norm_layers = nn.ModuleList()
if len(input_sizes) > 2:
for i in range(1, len(input_sizes) - 1):
self.norm_layers.append(nn.LayerNorm(input_sizes[i]))
self.drop_out = nn.Dropout(p=dropout_prob) def forward(self, x):
for i, layer in enumerate(self.layers):
x = layer(self.drop_out(x))
if i < len(self.layers) - 1:
x = gelu(x)
if len(self.norm_layers):
x = self.norm_layers[i](x)
return x

3. GCN模块

这里采用最基础的GCN,没有使用任何GCN库,速度可能较慢,但是考虑到主要时间限制在bert模型,所以这里的时间效率下降可忽略。

 class GCN(nn.Module):
def init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.05) def __init__(self, input_size):
super(GCN, self).__init__()
self.diffusion = nn.Linear(input_size, input_size, bias=False) # diffusion线性变换
self.retained = nn.Linear(input_size, input_size, bias=False) # retaine线性变换
self.predict = MLP(input_sizes=(input_size, input_size, 1))
self.apply(self.init_weights) # 参数矩阵赋予初始化权重(正态分布) def forward(self, A, x):
layer1_diffusion = A.t().mm(gelu(self.diffusion(x))) # t() 转置
# A为邻接矩阵(n, n) * (n, input_size) ==> (n, input_size)
x = gelu(self.retained(x) + layer1_diffusion) # (n, input_size)
layer2_diffusion = A.t().mm(gelu(self.diffusion(x))) # (n, input_size)
x = gelu(self.retained(x) +layer2_diffusion) # (n, input_size)
return self.predict(x).sqeeze(-1) # (n, )

4. bert embedding模块 (具体见注释)

 class BertEmbeddingsPlus(nn.Module):
""" 构建word embeddings, position embeddings, token_type embeddings.
"""
def __init__(self, config, max_sentence_type=30):
super(BertEmbeddingsPlus, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(
config.max_position_embeddings, config.hidden_size
) # shape (位置embedding的种类,隐层大小)
self.token_type_embeddings = nn.Embedding(
config.type_vocab_size, config.hidden_size
) # (2, hidden_size) A/B segment
self.sentence_type_embeddings = nn.Embedding(
max_sentence_type, config.hidden_size
) # 句子类型embedding (30, hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) # bert LN层
self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None):
"""
:param input_ids: (n, seq_length) n 就是 batch_size
:param token_type_ids: (n, seq_length)
:return:
"""
seq_length = input_ids.size(1) # 文本序列长度
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
# [5] => [0,1,2,3,4]
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# shape变化:(seq_length) => (1, seq_length) => (n, seq_length)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids) word_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings((token_type_ids > 0).long())
# token_type_embeddings, 分为 A/B,segment bert输入模式
sentences_type_embeddings = self.sentence_type_embeddings(token_type_ids)
# 这才是对token_type进行embedding embeddings = (word_embeddings + position_embeddings
+ token_type_embeddings + sentences_type_embeddings)
# 四个embedding相加,充分考虑各种信息
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings

5 bert模型编码模块

输入input_ids,输出bert的最后的编码结果,和指定哪一层的编码结果(下层偏语法,上层偏语义)
 class BertModelPlus(BertModel):
def __init__(self, config):
super(BertModelPlus, self).__init__()
self.embeddings = BertEmbeddingsPlus(config)
self.encoder = BertEncoder(config) # bert 编码器
self.pooler = BertPooler(config) # bert 池化器
self.apply(self.init_bert_weights) # BertModel 的初始权重参数 def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_hidden=-4):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids) # (n, seq_length), n is batch_size
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# (n, seq_length) => (n, 1, 1, seq_length) extended_attention_mask = extended_attention_mask.to(
dtype=next(self.parameters()).dtype
) # fp16 转换数值类型
extended_attention_mask = (1.0 - extended_attention_mask) * (-10000.0)
# 将attention_mask为0的变为-10000, 1变为0, 方便softmax求注意力后,
# mask为0(也就是padding)的值完全消除(=0)
# 例如[1, 1, 1, 0, 0] ==> (0, 0, 0, -10000, -10000) embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(
embedding_output, extended_attention_mask, output_all_encoded_layers=True)
# sequence_output = encoded_layers[-1]
# pooled_output = self.pooler(sequence_output)
encoded_layers, hidden_layers = (
encoded_layers[-1], encoded_layers[output_hidden] # -4 倒数第四层
)
return encoded_layers, hidden_layers # shape (batch_size, hidden_size), (batch_size, hidden_size)

6. 多跳阅读理解模块

也就是论文中提到的系统一,即“下一跳实体名称”和“答案候选”的抽取,是通过预测每个位置是否是span开始或者结束的概率来确定,与BERT原文中的做法相同;

其中几个值得注意的细节,比如之所以将“下一跳实体名称”和“答案候选”分开,是因为前者更多关注语意相关性而后者则需要匹配疑问词;

第0个位置的输出被用来产生一个阈值,判断段落内是否包含有意义的“下一跳实体名称”或者“答案候选”。

该模块就是常规的抽取式阅读理解,解码出span的代码部分较为复杂~

注:这里没有用到GCN,GCN是用来作为系统二进行推理的模块。

 class BertForMultiHopQuestionAnswering(PreTrainedBertModel):
def __init__(self, config):
super(BertForMultiHopQuestionAnswering, self).__init__()
self.bert = BertModelPlus(config)
self.qa_outputs = nn.Linear(config.hidden_size, 4)
self.apply(self.init_bert_weights) # PreTrainedBertModel 初始化权重 def forward(self, input_ids,
token_type_ids=None,
attention_mask=None,
sep_positions=None,
hop_start_weights=None,
hop_end_weights=None,
ans_start_weights=None,
ans_end_weights=None,
B_starts=None,
allow_limit=(0, 0),
):
"""
从系统1抽取span (分为两个系统,具体看原文)
:param input_ids: LongTensor
(batch_size, max_len)
:param token_type_ids: LongTensor
The A/B Segmentation in BERTs. (batch, maxlen)
:param attention_mask: LongTensor
指示该位置是token还是padding (batch_size, maxlen)
:param sep_positions: LongTensor
[SEP]的具体位置 主要用来发现支持段落的句子 (batch_size, max_seps)
:param hop_start_weights: Tensor(默认为FloatTensor)
hop开始位置的标注情况
:param hop_end_weights: Tensor
hop结束位置的标注情况 (ground truth)
:param ans_start_weights: Tensor
答案标注开始位置的可能性(概率)
:param ans_end_weights: Tensor
答案标注结束位置的可能性(概率)
:param B_starts:
句子B的开始位置
:param allow_limit:
An Offset for negative threshold (负阈值的偏移量)
:return:
"""
batch_size = input_ids.size()[0]
device = input_ids.get_device() if input_ids.is_cuda else torch.device('cpu')
sequence_output, hidden_output = self.bert(input_ids, token_type_ids, attention_mask)
# 上面两者的shape都为: (batch_size, max_len, hidden_size)
semantics = hidden_output[:, 0] # shape: (batch_size, hidden_size) if sep_positions is None:
return semantics # 仅仅语义信息
else:
max_sep = sep_positions.size()[-1] # max_seps
if max_sep == 0:
empty = torch.zeros(batch_size, 0, dtype=torch.long, device=device) # mistake
return (
empty,
empty,
semantics,
empty,
) # Only semantics, used in eval, the same ``empty'' variable is a mistake in general cases but simple # 预测span
logits = self.qa_outputs(sequence_output)
hop_start_logits, hop_end_logits, ans_start_logits, ans_end_logits = logits.split(
split_size=1, dim=-1 # 前面的1代表单个分块的形状大小
) # 每个的形状为 (batch_size, max_len, 1)
hop_start_logits = hop_start_logits.squeeze(-1)
hop_end_logits = hop_end_logits.squeeze(-1)
ans_start_logits = ans_start_logits.squeeze(-1)
ans_end_logits = ans_end_logits.squeeze(-1) # Shape: [batch_size, max_len] if hop_start_weights is not None: # train mode (因为提供了标签信息:hop_start_weights等)
lgsf = nn.LogSoftmax(dim=1)
# 如果句子中没有目标span,start_weights = end_weights = 0(tensor)
# 以下四个求二元交叉熵loss
hop_start_loss = -torch.sum(hop_start_weights * lgsf(hop_start_logits), dim=-1)
hop_end_loss = -torch.sum(hop_end_weights * lgsf(hop_end_logits), dim=-1)
ans_start_loss = -torch.sum(ans_start_weights * lgsf(ans_start_logits), dim=-1)
ans_end_loss = -torch.sum(ans_end_weights * lgsf(ans_end_logits), dim=-1) hop_loss = torch.mean((hop_start_loss + hop_end_loss)) / 2
ans_loss = torch.mean((ans_start_loss + ans_end_loss)) / 2 else:
K_hop, K_ans = 3, 1
hop_preds = torch.zeros(batch_size, K_hop, 3, dtype=torch.long, device=device)
# (batch_size, 3, 3)
ans_preds = torch.zeros(batch_size, K_ans, 3, dtype=torch.long, device=device)
# (batch_size, 1, 3) ans_start_gap = torch.zeros(batch_size, device=device)
for u, (start_logits, end_logits, preds, K, allow) in enumerate(
(
(
hop_start_logits, # (batch_size, max_len)
hop_end_logits,
hop_preds, # (batch_size, 3, 3)
K_hop, #
allow_limit[0],
),
(
ans_start_logits,
ans_end_logits,
ans_preds, # (batch_size, 1, 3)
K_ans, #
allow_limit[1],
),
)
):
for i in range(batch_size):
# 对于batch_size里的每个样本,即每个文本
if sep_positions[i, 0] > 0:
values, indices = start_logits[i, B_starts[i]:].topk(K)
# B是文档,QA所对应的paragraph
# 取出前K大的概率值以及对应的位置index
for k, index in enumerate(indices): # 3个 或 1个(answer)
if values[k] <= start_logits[i, 0] - allow: # not golden
# 小tip: start_logits[i, 0] 代表一个置信度或叫阈值
# 来判断段落内是否有有意义的“下一跳实体名称”或者“答案候选”。
if u == 1: # For ans spans
ans_start_gap[i] = start_logits[i, 0] - values[k]
break
start = index + B_starts[i] # 输入文本中span所在的开始位置
# find ending 找到span的结束位置
for j, ending in enumerate(sep_positions[i]):
if ending > start or ending <= 0:
break # 找到ending所对应的支撑句子sep位置
if ending <= start:
break
ending = min(ending, start + 10)
end = torch.argmax(end_logits[i, start:ending]) + start
# 得到end span在文本中的结束位置
preds[i, k, 0] = start
preds[i, k, 1] = end
preds[i, k, 2] = j
return ((hop_loss, ans_loss, semantics)
if hop_start_weights is not None
else (hop_preds, ans_preds, semantics, ans_start_gap))

7. 认知图网络模块

 class CognitiveGCN(nn.Module):
"""
在认知图谱上进行推理计算,使用GCN实现隐式推理计算——每一步迭代,前续节点将变换过的信息传递到下一跳节点,
并更新目前的隐层表示。 在认知图谱扩展过程中,如果某被访问节点出现新的父节点(环状结构或汇集状结构),
表明此点获得新的线索信息(clues),需要重新扩展计算。最终算法流程借助前沿点(frontier nodes)队列形式实现。
"""
def __init__(self, hidden_size):
super(CognitiveGCN, self).__init__()
self.gcn = GCN(hidden_size)
self.both_net = MLP((hidden_size, hidden_size, 1))
self.select_net = MLP((hidden_size, hidden_size, 1)) def forward(self, bundle, model, device):
batch = bundle_part_to_batch(bundle)
batch = tuple(t.to(device) for t in batch)
hop_loss, ans_loss, semantics = model(
*batch
) # Shape of semantics: [num_para, hidden_size]
num_additional_nodes = len(bundle.additional_nodes)
if num_additional_nodes > 0:
max_length_additional = max([len(x) for x in bundle.additional_nodes])
# 取出最大长度——max_len
ids = torch.zeros(
(num_additional_nodes, max_length_additional),
dtype=torch.long,
device=device,
)
segment_ids = torch.zeros(
(num_additional_nodes, max_length_additional),
dtype=torch.long,
device=device,
)
input_mask = torch.zeros(
(num_additional_nodes, max_length_additional),
dtype=torch.long,
device=device,
)
# 得到对应的ids, segment_ids, input_mask
for i in range(num_additional_nodes):
length = len(bundle.additional_nodes[i]) # 对于邻接结点
ids[i, :length] = torch.tensor(
bundle.additional_nodes[i], dtype=torch.long
)
input_mask[i, :length] = 1 # mask为1 padding段相应变为0
additional_semantics = model(ids, segment_ids, input_mask) semantics = torch.cat((semantics, additional_semantics), dim=0) # 二者相拼接 assert semantics.size()[0] == bundle.adj.size()[0] # 等于邻接矩阵的结点数 if bundle.question_type == 0: # Wh-
pred = self.gcn(bundle.adj.to(device), semantics)
ce = torch.nn.CrossEntropyLoss()
final_loss = ce(
pred.unsqueeze(0),
torch.tensor([bundle.answer_id], dtype=torch.long, device=device),
)
else:
x, y, ans = bundle.answer_id
ans = torch.tensor(ans, dtype=torch.float, device=device)
diff_sem = semantics[x] - semantics[y]
classifier = self.both_net if bundle.question_type == 1 else self.select_net
final_loss = 0.2 * torch.nn.functional.binary_cross_entropy_with_logits(
classifier(diff_sem).squeeze(-1), ans.to(device)
)
return hop_loss, ans_loss, final_loss

具体详情待进一步补充。

Cognitive Graph for Multi-Hop Reading Comprehension at Scale(ACL2019) 阅读笔记与源码解析的更多相关文章

  1. 机器阅读理解综述Neural Machine Reading Comprehension Methods and Trends(略读笔记)

    标题:Neural Machine Reading Comprehension: Methods and Trends 作者:Shanshan Liu, Xin Zhang, Sheng Zhang, ...

  2. Reading Face, Read Health论文阅读笔记

    摘要: 随着计算技术觉.人工智能和移动技术的发展,利用计算机读脸技术去识别每个人每天的健康是可行的.怎么去设计一个基于FRT(face reading technologies)的用于得到每天的保健实 ...

  3. HDU4990 Reading comprehension —— 递推、矩阵快速幂

    题目链接:https://vjudge.net/problem/HDU-4990 Reading comprehension Time Limit: 2000/1000 MS (Java/Others ...

  4. hdu-4990 Reading comprehension(快速幂+乘法逆元)

    题目链接: Reading comprehension Time Limit: 2000/1000 MS (Java/Others)     Memory Limit: 32768/32768 K ( ...

  5. 论文选读二:Multi-Passage Machine Reading Comprehension with Cross-Passage Answer Verification

    论文选读二:Multi-Passage Machine Reading Comprehension with Cross-Passage Answer Verification 目前,阅读理解通常会给出 ...

  6. Attention-over-Attention Neural Networks for Reading Comprehension论文总结

    Attention-over-Attention Neural Networks for Reading Comprehension 论文地址:https://arxiv.org/pdf/1607.0 ...

  7. Deep Learning of Graph Matching 阅读笔记

    Deep Learning of Graph Matching 阅读笔记 CVPR2018的一篇文章,主要提出了一种利用深度神经网络实现端到端图匹配(Graph Matching)的方法. 该篇文章理 ...

  8. [图解tensorflow源码] [原创] Tensorflow 图解分析 (Session, Graph, Kernels, Devices)

    TF Prepare [图解tensorflow源码] 入门准备工作 [图解tensorflow源码] TF系统概述篇 Session篇 [图解tensorflow源码] Session::Run() ...

  9. openfalcon源码分析之graph

    openfalcon源码分析之graph 本节内容 graph功能 graph源码分析 2.1 graph中重要的数据结构 2.2 graph的简要流程图 2.3 graph处理数据过程 2.4 gr ...

随机推荐

  1. 《如何学习基于ARM嵌入式系统》笔记整理

    author:Peong time:20190603 如何学习基于ARM嵌入式系统 一.嵌入式系统的概念 从硬件上讲,将外围器件,与CPU集成在一起. 从操作系统上讲,定制符合要求的系统内核 从应用上 ...

  2. js初学者循环经典题目

    1.根据一个数字日期,判断这个日期是这一年的第几天例如: 2016和02和11,计算后结果为42 var y = 2016;//+prompt("请输入年份") ;         ...

  3. 百万年薪python之路 -- 面向对象初始

    面向对象初始 1.1 面向过程编程vs函数式编程 函数编程较之面向过程编程最明显的两个特点: 1,减少代码的重用性. 2,增强代码的可读性. 1.2 函数式编程vs面向对象编程 面向对象编程:是一类相 ...

  4. 手写一个简单的ElasticSearch SQL转换器(一)

    一.前言 之前有个需求,是使ElasticSearch支持使用SQL进行简单查询,较新版本的ES已经支持该特性(不过貌似还是实验性质的?) ,而且git上也有elasticsearch-sql 插件, ...

  5. OptimalSolution(1)--递归和动态规划(2)矩阵的最小路径和与换钱的最少货币数问题

    一.矩阵的最小路径和 1 3 5 9 1 4 9 18 1 4 9 18 8 1 3 4 9 9 5 8 12 5 0 6 1 14 14 5 11 12 8 8 4 0 22 22 13 15 12 ...

  6. MySQL基础篇(3)常用函数和运算符

    一.字符串函数(索引位置都从1开始) CONCAT(S1,S2,...Sn): 连接S1,S2,...Sn为一个字符串,任何字符串与NULL进行连接的结果都是NULL INSERT(str,x,y,i ...

  7. Mysql数据库(三)Mysql表结构管理

    一.MySQL数据类型 1.数字类型 (1)整数数据类型包括TINYINT/BIT/BOOL/SMALLINT/MEDIUMINT/INT/BIGINT (2)浮点数据类型包括FLOAT/DOUBLE ...

  8. UART中的硬件流控RTS与CTS

    最近太忙了,没时间写对Ucos-II的移植,先将工作中容易搞错的一个知识点记录下来,关于CTS与RTS的. 在RS232中本来CTS 与RTS 有明确的意义,但自从贺氏(HAYES ) 推出了聪明猫( ...

  9. springboot(3)——配置文件和自动配置原理详细讲解

    原文地址 目录 概述 1. 配置文件作用 2.配置文件位置 3.配置文件的定义 3.1如果是定义普通变量(数字 字符串 布尔) 3.2如果是定义对象.Map 3.3如果是定义数组 4.配置文件的使用 ...

  10. 20190723_C的三个小实现

    1. 有一个字符串开头或结尾含有n个空格(“    abcdefgdddd   ”),欲去掉前后的空格,返回一个新的字符串.a) 要求1:请自己定义一个接口(函数),并实现功能:b) 要求2:编写测试 ...