tensorflow在文本处理中的使用——CBOW词嵌入模型
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理
代码地址:https://github.com/nfmcclure/tensorflow-cookbook
数据:http://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz
CBOW概念图:

步骤如下:
- 必要包
- 声明模型参数
- 读取数据集
- 创建单词字典,转换句子列表为单词索引列表
- 生成批量数据
- 构建图
- 训练
step1:必要包
参考:tensorflow在文本处理中的使用——skip-gram模型
step2:声明模型参数
# Declare model parameters
batch_size = 500
embedding_size = 200
vocabulary_size = 2000
generations = 50000
model_learning_rate = 0.001 num_sampled = int(batch_size/2) # Number of negative examples to sample.
window_size = 3 # How many words to consider left and right. # Add checkpoints to training
save_embeddings_every = 5000
print_valid_every = 5000
print_loss_every = 100 # Declare stop words
stops = stopwords.words('english') # We pick some test words. We are expecting synonyms to appear
valid_words = ['love', 'hate', 'happy', 'sad', 'man', 'woman']
step3:读取数据集
step4:创建单词字典,转换句子列表为单词索引列表
step5:生成批量数据
看一下单步执行的中间结果,利于更好理解处理过程:
>>> rand_sentence=[, , , , , , , , , ]
>>> window_size = #类似skip-gram
>>> window_sequences = [rand_sentence[max((ix-window_size),):(ix+window_size+)] for ix, x in enumerate(rand_sentence)]
>>> label_indices = [ix if ix<window_size else window_size for ix,x in enumerate(window_sequences)]
>>> window_sequences
[[, , , ], [, , , , ], [, , , , , ], [, , , , , , ], [, , , , , , ], [, , , , , , ], [, , , , , , ], [, , , , , ], [, , , , ], [, , , ]]
>>> label_indices
[, , , , , , , , , ] #生成input和label
>>> batch_and_labels = [(x[:y] + x[(y+):], x[y]) for x,y in zip(window_sequences, label_indices)]
>>> batch_and_labels = [(x,y) for x,y in batch_and_labels if len(x)==*window_size]
>>> batch, labels = [list(x) for x in zip(*batch_and_labels)]
>>> batch_and_labels
[([, , , , , ], ), ([, , , , , ], ), ([, , , , , ], ), ([, , , , , ], )]
>>> batch
[[, , , , , ], [, , , , , ], [, , , , , ], [, , , , , ]]
>>> labels
[, , , ]
step6:构建图
# Define Embeddings:
embeddings = tf.Variable(tf.random_uniform([vocabulary_size, embedding_size], -1.0, 1.0)) # NCE loss parameters
nce_weights = tf.Variable(tf.truncated_normal([vocabulary_size, embedding_size], stddev=1.0 / np.sqrt(embedding_size)))
nce_biases = tf.Variable(tf.zeros([vocabulary_size])) # Create data/target placeholders
x_inputs = tf.placeholder(tf.int32, shape=[batch_size, 2*window_size])
y_target = tf.placeholder(tf.int32, shape=[batch_size, 1])
valid_dataset = tf.constant(valid_examples, dtype=tf.int32) # Lookup the word embedding
# Add together window embeddings:CBOW模型将上下文窗口内的单词嵌套叠加在一起
embed = tf.zeros([batch_size, embedding_size])
for element in range(2*window_size):
embed += tf.nn.embedding_lookup(embeddings, x_inputs[:, element]) # Get loss from prediction
loss = tf.reduce_mean(tf.nn.nce_loss(nce_weights, nce_biases, embed, y_target, num_sampled, vocabulary_size)) # Create optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate=model_learning_rate).minimize(loss) # Cosine similarity between words计算验证单词集
norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
normalized_embeddings = embeddings / norm
valid_embeddings = tf.nn.embedding_lookup(normalized_embeddings, valid_dataset)
similarity = tf.matmul(valid_embeddings, normalized_embeddings, transpose_b=True) # Create model saving operation该方法默认会保存整个计算图会话,本例中指定参数只保存嵌套变量并设置名字
saver = tf.train.Saver({"embeddings": embeddings})
step7:训练
#Add variable initializer.
init = tf.initialize_all_variables()
sess.run(init) # Run the skip gram model.
print('Starting Training')
loss_vec = []
loss_x_vec = []
for i in range(generations):
batch_inputs, batch_labels = text_helpers.generate_batch_data(text_data, batch_size, window_size, method='cbow')
feed_dict = {x_inputs : batch_inputs, y_target : batch_labels} # Run the train step
sess.run(optimizer, feed_dict=feed_dict) # Return the loss
if (i+1) % print_loss_every == 0:
loss_val = sess.run(loss, feed_dict=feed_dict)
loss_vec.append(loss_val)
loss_x_vec.append(i+1)
print('Loss at step {} : {}'.format(i+1, loss_val)) # Validation: Print some random words and top 5 related words
if (i+1) % print_valid_every == 0:
sim = sess.run(similarity, feed_dict=feed_dict)
for j in range(len(valid_words)):
valid_word = word_dictionary_rev[valid_examples[j]]
top_k = 5 # number of nearest neighbors
nearest = (-sim[j, :]).argsort()[1:top_k+1]
log_str = "Nearest to {}:".format(valid_word)
for k in range(top_k):
close_word = word_dictionary_rev[nearest[k]]
log_str = '{} {},' .format(log_str, close_word)
print(log_str) # Save dictionary + embeddings
if (i+1) % save_embeddings_every == 0:
# Save vocabulary dictionary
with open(os.path.join(data_folder_name,'movie_vocab.pkl'), 'wb') as f:
pickle.dump(word_dictionary, f) # Save embeddings
model_checkpoint_path = os.path.join(os.getcwd(),data_folder_name,'cbow_movie_embeddings.ckpt')
save_path = saver.save(sess, model_checkpoint_path)
print('Model saved in file: {}'.format(save_path))
运行结果:

工作原理:Word2Vec嵌套的CBOW模型和skip-gram模型非常相似。主要不同点是生成数据和单词嵌套的处理。加载文本数据,归一化文本,创建词汇字典,使用词汇字典查找嵌套,组合嵌套并训练神经网络模型预测目标单词。
延伸学习:CBOW方法是在上下文窗口内单词嵌套叠加上进行训练并预测目标单词的。Word2Vec的CBOW方法更平滑,更适用于小文本数据集。
tensorflow在文本处理中的使用——CBOW词嵌入模型的更多相关文章
- tensorflow在文本处理中的使用——Doc2Vec情感分析
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——Word2Vec预测
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——skip-gram模型
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——TF-IDF算法
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——skip-gram & CBOW原理总结
摘自:http://www.cnblogs.com/pinard/p/7160330.html 先看下列三篇,再理解此篇会更容易些(个人意见) skip-gram,CBOW,Word2Vec 词向量基 ...
- tensorflow在文本处理中的使用——辅助函数
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——词袋
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- TensorFlow NMT的词嵌入(Word Embeddings)
本文转载自:http://blog.stupidme.me/2018/08/05/tensorflow-nmt-word-embeddings/,本站转载出于传递更多信息之目的,版权归原作者或者来源机 ...
- TensorFlow实现文本情感分析详解
http://c.biancheng.net/view/1938.html 前面我们介绍了如何将卷积网络应用于图像.本节将把相似的想法应用于文本. 文本和图像有什么共同之处?乍一看很少.但是,如果将句 ...
随机推荐
- 关于android SDK安装Failed to fetch URL http://dl-ssl.google.com/android/repository/addons_list-1.xml出错
最近SDK出问题了,然后在google下载了一个android-sdk-windows.rar,然后点击SDK Manager,结果一直不能刷新API Level,然后就开始在网上找了好多 ...
- Leetcode840.Magic Squares In Grid矩阵中的幻方
3 x 3 的幻方是一个填充有从 1 到 9 的不同数字的 3 x 3 矩阵,其中每行,每列以及两条对角线上的各数之和都相等. 给定一个由整数组成的 N × N 矩阵,其中有多少个 3 × 3 的 & ...
- Python发送邮件1(带附件的)
普通的发邮件(不使用类)
- AtCoder Beginner Contest 084 C - Special Trains
Special Trains Problem Statement A railroad running from west to east in Atcoder Kingdom is now comp ...
- 正则表达式问题:如何理解/href\s*=\s*(?:"(?<1>[^"]*)"|(?<1>\S+))/(转载)
ms-help://MS.VSCC/MS.MSDNVS.2052/jscript7/html/jsjsgrpregexpsyntax.htm 该文虽有解释, 但没有样例,对我这样的初学者来说很难理解 ...
- importError: DLL load failed when import matplotlib.pyplot as plt
importError: DLL load failed when import matplotlib.pyplot as plt 出现这种情况的原因, 大多是matplotlib的版本与python ...
- 基本数据类型的值传递 和引用数据类型的引用传递 Day06
ValueTest1.java package com.sxt.valuetest; /* * 基本数据类型的传递:传递的是值得副本 */ public class ValueTest1 { publ ...
- Java练习 SDUT-2746_大小写转换
大小写转换 Time Limit: 1000 ms Memory Limit: 65536 KiB Problem Description X现在要学习英文以及各种稀奇古怪的字符的了.现在他想把一串字 ...
- Laravel 的HTTP控制器
简介# 除了在路有文件中以闭包的形式定义所有的请求处理逻辑外,还可以使用控制器类来组织此类行为,控制器能够将相关 的请求处理逻辑组成的一个单独的类,控制器被存放在app/Http/Controller ...
- NoSQL之简介
简介 NoSQL(NoSQL=Not Only SQL),意即'不仅仅是"SQL".泛指非关系型的数据库.是一项全新的数据库革命性运动. 在现代的计算系统上每天网络上会产生庞大的数 ...