使用句子中出现单词的Vector加权平均进行文本相似度分析虽然简单,但也有比较明显的缺点:没有考虑词序且词向量区别不明确。如下面两个句子:

  • “北京的首都是中国”与“中国的首都是北京”的相似度为1。
  • “学习容易”和“学习困难”的相似度很容易也非常高。
    为解决这类问题,需要用其他方法对句子进行表示,LSTM是常用的一种方式,本文简单使用单层LSTM对句子重新表示,并通过若干全连接层对句子相似度进行衡量。
数据准备

训练和测试数据包括两个待比较句子以及其相似度(0-1):

测试数据格式相似。

语料编码

自然语言无法直接作为神经网络输入,需进行编码该部分包括以下步骤:

  • 读人训练和测试数据,分词,并给每个词编号。
  • 根据词编号,进一步生成每个句子的编号向量,句子采用固定长度,不足的位置补零。
  • 保存词编号到文件,保存词向量矩阵方便预测使用。

中文分词使用jieba分词工具,词的编号则使用Keras的Tokenizer:

1
2
3
4
5
6
7
8
print("Fit tokenizer...")
tokenizer = Tokenizer(num_words=MAX_NB_WORDS, lower=False)
tokenizer.fit_on_texts(texts_1 + texts_2 + test_texts_1 + test_texts_2)
if save:
print("Save tokenizer...")
if not os.path.exists(save_path):
os.makedirs(save_path)
cPickle.dump(tokenizer, open(os.path.join(save_path, tokenizer_name), "wb"))

其中texts_1 、texts_2 、test_texts_1 、 test_texts_2的元素分别为训练数据和测试数据的分词后的列表,如:

1
["我", "是", "谁"]

经过上面的过程 tokenizer保存了语料中出现过的词的编号映射。

1
2
> print tokenizer.word_index
{"我": 2, "是":1, "谁":3}

利用tokenizer对语料中的句子进行编号

1
2
3
> sequences_1 = tokenizer.texts_to_sequences(texts_1)
> print sequences_1
[[2 1 3], ...]

最终生成固定长度(假设为10)的句子编号列表

1
2
3
> data_1 = pad_sequences(sequences_1, maxlen=MAX_SEQUENCE_LENGTH)
> print data_1
[[0 0 0 0 0 0 0 2 1 3], ...]

data_1即可作为神经网络的输入。

词向量映射

在对句子进行编码后,需要准备句子中词的词向量映射作为LSTM层的输入。这里使用预训练的词向量(这里)参数,生成词向量映射矩阵:

1
2
3
4
5
6
word2vec = Word2Vec.load(EMBEDDING_FILE)
embedding_matrix = np.zeros((nb_words, EMBEDDING_DIM))
for word, i in word_index.items():
if word in word2vec.wv.vocab:
embedding_matrix[i] = word2vec.wv.word_vec(word)
np.save(embedding_matrix_path, embedding_matrix)
网络结构

该神经网络采用简单的单层LSTM+全连接层对数据进行训练,网络结构图:

网络由Keras实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 大专栏  LSTM 句子相似度分析
def ():
embedding_layer = Embedding(nb_words,
EMBEDDING_DIM,
weights=[embedding_matrix],
input_length=MAX_SEQUENCE_LENGTH,
trainable=False)
lstm_layer = LSTM(num_lstm, dropout=rate_drop_lstm, recurrent_dropout=rate_drop_lstm) sequence_1_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences_1 = embedding_layer(sequence_1_input)
y1 = lstm_layer(embedded_sequences_1) sequence_2_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences_2 = embedding_layer(sequence_2_input)
y2 = lstm_layer(embedded_sequences_2) merged = concatenate([y1, y2])
merged = Dropout(rate_drop_dense)(merged)
merged = BatchNormalization()(merged) merged = Dense(num_dense, activation=act)(merged)
merged = Dropout(rate_drop_dense)(merged)
merged = BatchNormalization()(merged)
preds = Dense(1, activation='sigmoid')(merged) model = Model(inputs=[sequence_1_input, sequence_2_input],
outputs=preds)
model.compile(loss='binary_crossentropy',
optimizer='nadam',
metrics=['acc'])
model.summary()
return model

该部分首先定义embedding_layer作为输入层和LSTM层的映射层,将输入的句子编码映射为词向量列表作为LSTM层的输入。两个LSTM的输出拼接后作为全连接层的输入,经过Dropout和BatchNormalization正则化,最终输出结果进行训练。

训练与预测

训练采用nAdam以及EarlyStopping,保存训练过程中验证集上效果最好的参数。最终对测试集进行预测。

1
2
3
4
5
6
7
8
9
10
11
model = get_model()
early_stopping = EarlyStopping(monitor='val_loss', patience=3)
bst_model_path = STAMP + '.h5'
model_checkpoint = ModelCheckpoint(bst_model_path, save_best_only=True, save_weights_only=True) hist = model.fit([data_1, data_2], labels,
validation_data=([val_1, val_2], labels),
epochs=100, batch_size=10, shuffle=True, callbacks=[early_stopping, model_checkpoint])
predicts = model.predict([data_1, data_2], batch_size=10, verbose=1)
for i in range(len(test_ids)):
print "t1: %s, t2: %s, score: %s" % (test_1[i], test_2[i], predicts[i])
小结

该网络在Kaggle Quora数据集val验证可达到80%左右的准确率,应用于中文,由于数据集有限,产生了较大的过拟合。此外在Tokenizer.fit_on_texts应用于中文时,不支持Unicode编码,可以对其源码方法进行重写,加入Ascii字符和Unicode的转换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
'''
this part is solve keras.preprocessing.text can not process unicode
'''
def text_to_word_sequence(text,
filters='!"#$%&()*+,-./:;<=>?@[\]^_`{|}~tn',
lower=True, split=" "):
if lower: text = text.lower()
if type(text) == unicode:
translate_table = {ord(c): ord(t) for c, t in zip(filters, split * len(filters))}
else:
translate_table = keras.maketrans(filters, split * len(filters))
text = text.translate(translate_table)
seq = text.split(split)
return [i for i in seq if i] keras.preprocessing.text.text_to_word_sequence = text_to_word_sequence
项目源码https://github.com/zqhZY/semanaly/

更多关注公众号:

LSTM 句子相似度分析的更多相关文章

  1. 机器学习 - LSTM应用之情感分析

    1. 概述 在情感分析的应用领域,例如判断某一句话是positive或者是negative的案例中,咱们可以通过传统的standard neuro network来作为解决方案,但是传统的神经网络在应 ...

  2. 相似度分析,循环读入文件(加入了HanLP,算法第四版的库)

    相似度分析的,其中的分词可以采用HanLP即可: http://www.open-open.com/lib/view/open1421978002609.htm /****************** ...

  3. 文本离散表示(三):TF-IDF结合n-gram进行关键词提取和文本相似度分析

    这是文本离散表示的第二篇实战文章,要做的是运用TF-IDF算法结合n-gram,求几篇文档的TF-IDF矩阵,然后提取出各篇文档的关键词,并计算各篇文档之间的余弦距离,分析其相似度. TF-IDF与n ...

  4. Java利用hanlp完成语句相似度分析的案例详解

    分享一篇hanlp分词工具使用的小案例,即利用hanlp分词工具分析两个中文语句的相似度的案例.供大家一起学习参考! 在做考试系统需求时,后台题库系统提供录入题目的功能.在录入题目的时候,由于题目来源 ...

  5. 文本相似度分析(基于jieba和gensim)

    基础概念 本文在进行文本相似度分析过程分为以下几个部分进行, 文本分词 语料库制作 算法训练 结果预测 分析过程主要用两个包来实现jieba,gensim jieba:主要实现分词过程 gensim: ...

  6. 八大排序算法详解(动图演示 思路分析 实例代码java 复杂度分析 适用场景)

    一.分类 1.内部排序和外部排序 内部排序:待排序记录存放在计算机随机存储器中(说简单点,就是内存)进行的排序过程. 外部排序:待排序记录的数量很大,以致于内存不能一次容纳全部记录,所以在排序过程中需 ...

  7. 八大排序算法——堆排序(动图演示 思路分析 实例代码java 复杂度分析)

    一.动图演示 二.思路分析 先来了解下堆的相关概念:堆是具有以下性质的完全二叉树:每个结点的值都大于或等于其左右孩子结点的值,称为大顶堆:或者每个结点的值都小于或等于其左右孩子结点的值,称为小顶堆.如 ...

  8. 八大排序算法——希尔(shell)排序(动图演示 思路分析 实例代码java 复杂度分析)

    一.动图演示 二.思路分析 希尔排序是把记录按下标的一定增量分组,对每组使用直接插入排序算法排序:随着增量逐渐减少,每组包含的关键词越来越多,当增量减至1时,整个文件恰被分成一组,算法便终止. 简单插 ...

  9. 八大排序算法——基数排序(动图演示 思路分析 实例代码java 复杂度分析)

    一.动图演 二.思路分析 基数排序第i趟将待排数组里的每个数的i位数放到tempj(j=1-10)队列中,然后再从这十个队列中取出数据,重新放到原数组里,直到i大于待排数的最大位数. 1.数组里的数最 ...

随机推荐

  1. 2019ICPC 上海网络赛 L. Digit sum(二维树状数组+区间求和)

    https://nanti.jisuanke.com/t/41422 题目大意: 给出n和b,求1到n,各数在b进制下各位数之和的总和. 直接暴力模拟,TLE.. 没想到是要打表...还是太菜了. # ...

  2. 计算KS值的标准代码

    计算KS值的标准代码 from scipy.stats import ks_2samp get_ks = lambda y_pred,y_true: ks_2samp(y_pred[y_true==1 ...

  3. 讯飞语音的中的bug用户校验失败

    用户校验失败:原因是目录没有复制粘贴正确. 下面是刚刚下载的SDK目录: 下面的是自己Android工程中的目录:注意复制粘贴的文件路径要正确

  4. Linux常见指令x-mind

  5. [单调队列]XKC's basketball team

    XKC's basketball team 题意:给定一个序列,从每一个数后面比它大至少 \(m\) 的数中求出与它之间最大的距离.如果没有则为 \(-1\). 题解:从后向前维护一个递增的队列,从后 ...

  6. redhat下libreoffice 的安装

    1.第一次安装libreoffic时是用网络yum源安装的,但是装好之后不能用,找了好久没有找出问题,后来从官网下载安装包后安装就可以了. 下载地址:https://zh-cn.libreoffice ...

  7. rsync配置文件

    vim /etc/rsyncd.conf motd file = /etc/rsyncd.motd #设置服务器信息提示文件,在该文件中编写提示信息 transfer logging = yes #开 ...

  8. nginx应用geoip模块,实现不同地区访问不同页面的需求(实践版)

    https://www.52os.net/articles/configure-nginx-using-geoip-allow-whitelist.html       搞了几天没有搞定,这篇文章一下 ...

  9. DateTimePicket jQuery 日期插件,开始时间和结束时间示例

    需要引入的js文件: <input type="text" id="startTime" placeholder="开始时间"/> ...

  10. [LC] 106. Construct Binary Tree from Inorder and Postorder Traversal

    Given inorder and postorder traversal of a tree, construct the binary tree. Note:You may assume that ...