引子

在上次的 《word2vector论文笔记》中大致介绍了两种词向量训练方法的原理及优劣,这篇咱们以skip-gram算法为例来代码实践一把。

当前教程参考:A Word2Vec Keras tutorial

导库

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Dot, Embedding
from tensorflow.keras.losses import cosine_similarity
from tensorflow.keras.preprocessing.sequence import skipgrams
from tensorflow.keras.preprocessing import sequence import urllib
import collections
import os
import zipfile import numpy as np
import tensorflow as tf
tf.__version__
'2.0.0'

数据下载与预处理

将数据下载到本地,若本地已有数据则根据文件大小判断文件是否正确

def maybe_download(filename, url, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
filename, _ = urllib.urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', filename)
else:
print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename

读取本地数据,输出为一个单词列表

def read_data(filename):
"""Extract the first file enclosed in a zip file as a list of words."""
with zipfile.ZipFile(filename) as f:
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data

构造数据集

输入:单词列表、词典大小

输出:转为int后的单词列表、词频统计表、word2index字典、index2word字典

def build_dataset(words, n_words):
"""Process raw inputs into a dataset."""
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(n_words - 1))
dictionary = dict()
for word, _ in count:
dictionary[word] = len(dictionary)
data = list()
unk_count = 0
for word in words:
if word in dictionary:
index = dictionary[word]
else:
index = 0 # dictionary['UNK']
unk_count += 1
data.append(index)
count[0][1] = unk_count
reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return data, count, dictionary, reversed_dictionary

借助以上几个功能函数,构造我们需要的数据集。并打印查看案例

def collect_data(vocabulary_size=10000):
url = 'http://mattmahoney.net/dc/'
filename = maybe_download('text8.zip', url, 31344016)
vocabulary = read_data(filename)
print(vocabulary[:7])
data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
vocabulary_size)
del vocabulary # Hint to reduce memory.
return data, count, dictionary, reverse_dictionary vocab_size = 10000
data, count, dictionary, reverse_dictionary = collect_data(vocabulary_size=vocab_size)
print(data[:7])
Found and verified text8.zip
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
[5234, 3081, 12, 6, 195, 2, 3134]

构建模型

参数设定

window_size: 即每次sample样本的时候以当前词为中心向左右取词时的窗口大小。举例当前词(target)为Wi,则上下文词(context)为Wi-3,Wi-2,Wi-1,Wi+1,Wi+2,Wi+3

vector_dim:词向量的大小

epochs:模型要训练的轮数,这个有些夸张。论文中通常训练一轮或两轮下来模型收敛效果就已经很不错。跟语料规模也有很大关系。

window_size = 3
vector_dim = 300
epochs = 100000 valid_size = 16 # Random set of words to evaluate similarity on.
valid_window = 100 # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)

sequence.make_sampling_table函数用于构造一个根据词频进行抽样的抽样分布,不同词的抽样概率按照以下公式进行计算。此分布用于下一步生成skip-gram特定的样本。

P_(word) = min(1, \frac {sqrt{\frac {wordfrequency} {samplingfactor} }} {\frac {wordfrequency} {samplingfactor}})
sampling_table = sequence.make_sampling_table(vocab_size)

skipgrams 是tf提供的用于专门构造skip-gram训练样本的工具函数,只需要提供单词序列、词典大小、窗口大小、抽样表即可。如以下示例,couples对象第一个元素为target词,第二个为context词,labels表示当前词对是否处于同一条语料的指定窗口大小范围内。

couples, labels = skipgrams(data, vocab_size, window_size=window_size, sampling_table=sampling_table)
word_target, word_context = zip(*couples)
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32") print(couples[:10], labels[:10])
[[5964, 1], [99, 379], [1385, 700], [4770, 1991], [3118, 9262], [4488, 6982], [4708, 6615], [9269, 7965], [4374, 5294], [4236, 6354]] [1, 1, 0, 1, 0, 0, 0, 0, 0, 0]

到此,我们用于模型训练的数据已经准备好啦!

关于softmax和负采样

我们看softmax的公式,为了做概率归一化,每次都需要针对词典中每个词计算其e为底的指数计算,这样当词典中的次数较大时计算成本是比较高昂的。

也有其他工作采用hierarchical softmax做为输出层。

skip-gram 则为此提出了下采样的二分类训练方法。

简言之就是根据窗口大小构造许多单词对,有些单词对确实是在目标词窗口范围内有共现的,称为正样本;反之,也根据不同词出现的频率对其他

非上下文的词进行采样,与目标词构成负样本单词对。我们的模型仅仅需要二分类就好。

创建模型

模型构建

模型本身很简单

把两个词分别输入,经过embedding层lookup到各自的词向量,通过向量点乘来衡量相似度,最后接一个单神经元的全连接层,通过sigmoid对输出进行激活,输出就是当前词对是否共现的概率判断,处于0-1之间。

损失函数为二分类损失binary_crossentropy

优化器选择rmsprop

# create some input variables
input_target = Input((1,))
input_context = Input((1,)) embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context) # now perform the dot product operation to get a similarity measure
dot_product = Dot(axes=1)([target, context])
dot_product = Reshape((1,))(dot_product)
# add the sigmoid output layer
output = Dense(1, activation='sigmoid')(dot_product)
# create the primary training model
model = Model(inputs=[input_target, input_context], outputs=output)
optimizer = tf.keras.optimizers.Adam(0.001)
model.compile(loss='binary_crossentropy', optimizer=optimizer)

模型检查

打印模型结构并检查,与我们预期一致。

共300W可训练参数

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
embedding (Embedding) (None, 1, 300) 3000000 input_1[0][0]
input_2[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (None, 300, 1) 0 embedding[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 300, 1) 0 embedding[1][0]
__________________________________________________________________________________________________
dot (Dot) (None, 1, 1) 0 reshape[0][0]
reshape_1[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 1) 0 dot[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1) 2 reshape_2[0][0]
==================================================================================================
Total params: 3,000,002
Trainable params: 3,000,002
Non-trainable params: 0
__________________________________________________________________________________________________

查找相似词,用于验证词向量训练效果

这里构造一个相似词查找回调类,它会针对最常见的几个词,寻找当前词向量训练结果下与其相似度最高的top_k个词。

class SimilarityCallback:
def run_sim(self):
for i in range(valid_size):
valid_word = reverse_dictionary[valid_examples[i]]
top_k = 8 # number of nearest neighbors
sim = self._get_sim(valid_examples[i])
nearest = (-sim).argsort()[1:top_k + 1]
log_str = 'Nearest to %s:' % valid_word
for k in range(top_k):
close_word = reverse_dictionary[nearest[k]]
log_str = '%s %s,' % (log_str, close_word)
print(log_str) @staticmethod
def _get_sim(valid_word_idx):
sim = np.zeros((vocab_size,))
in_arr1 = np.zeros((1,))
in_arr2 = np.zeros((1,))
in_arr1[0,] = valid_word_idx
for i in range(vocab_size):
in_arr2[0,] = i
out = cosine_similarity(in_arr1, in_arr2)
sim[i] = out
return sim
sim_cb = SimilarityCallback()

训练并验证模型效果

每过1000个batch打印一次loss

每过10000个batch打印一次相似词查找结果

可以看到刚开始时由于模型权重随机初始化,针对高频词找出的相似词基本都没什么意义。

经过几万个batch的训练,找出的相似词已经与相应的高频词具有较强的相关性了。

arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
idx = np.random.randint(0, len(labels)-1)
arr_1[0,] = word_target[idx]
arr_2[0,] = word_context[idx]
arr_3[0,] = labels[idx]
loss = model.train_on_batch([arr_1, arr_2], arr_3)
if cnt % 1000 == 0:
print("Iteration {}, loss={}".format(cnt, loss))
if cnt % 10000 == 0:
sim_cb.run_sim()

word2vector代码实践的更多相关文章

  1. ReactiveCocoa代码实践之-更多思考

    三.ReactiveCocoa代码实践之-更多思考 1. RACObserve()宏形参写法的区别 之前写代码考虑过 RACObserve(self.timeLabel , text) 和 RACOb ...

  2. ReactiveCocoa代码实践之-RAC网络请求重构

    前言 RAC相比以往的开发模式主要有以下优点:提供了统一的消息传递机制:提供了多种奇妙且高效的信号操作方法:配合MVVM设计模式和RAC宏绑定减少多端依赖. RAC的理论知识非常深厚,包含有FRP,高 ...

  3. 深刻理解Python中的元类(metaclass)--代码实践

    根据http://blog.jobbole.com/21351/所作的代码实践. 这篇讲得不错,但以我现在的水平,用到的机会是很少的啦... #coding=utf-8 class ObjectCre ...

  4. Java的BIO和NIO很难懂?用代码实践给你看,再不懂我转行!

    本文原题“从实践角度重新理解BIO和NIO”,原文由Object分享,为了更好的内容表现力,收录时有改动. 1.引言 这段时间自己在看一些Java中BIO和NIO之类的东西,也看了很多博客,发现各种关 ...

  5. TextCNN代码实践

    在上文<TextCNN论文解读>中已经介绍了TextCNN的原理,本文通过tf2.0来做代码实践. 数据集:来自中文任务基准测评的数据集IFLYTEK 导库 import os impor ...

  6. 机器学习(四):通俗理解支持向量机SVM及代码实践

    上一篇文章我们介绍了使用逻辑回归来处理分类问题,本文我们讲一个更强大的分类模型.本文依旧侧重代码实践,你会发现我们解决问题的手段越来越丰富,问题处理起来越来越简单. 支持向量机(Support Vec ...

  7. ReactiveCocoa代码实践之-UI组件的RAC信号操作

    上一节是自己对网络层的一些重构,本节是自己一些代码小实践做出的一些demo程序,基本涵盖大多数UI控件操作. 一.用UISlider实现调色板 假设我们现在做一个demo,上面有一个View用来展示颜 ...

  8. iOS代码实践总结

    转载地址:http://mobile.51cto.com/hot-492236.htm 最近一个月除了专门抽时间和精力重构之外,还有就是遇到需要添加功能的模块的时候,由于项目中的代码历史因素比较多,第 ...

  9. 使用 DartPad 制作代码实践教程

    DartPad 是一个开源的.在浏览器中体验和运行 Dart 编程语言的线上编辑器,目标是为了帮助开发者更好地了解 Dart 编程语言以及 Flutter 应用开发. DartPad 项目起始于 20 ...

随机推荐

  1. 【Oracle】win7安装报错

    在WIN7上安装oracle 10g时,提示如下信息: 正在检查操作系统要求... 要求的结果: 5.0,5.1,5.2,6.0 之一 实际结果: 6.1 检查完成.此次检查的总体结果为: 失败 &l ...

  2. Arduino—学习笔记—基础语法

    图解 函数具体讲解 pinMode(工作接脚,模式) 工作接脚 工作接脚编号(0--13与A0--A5) 模式 工作模式:INPUT或OUTPUT 例子 将8接口设置为输出模式 pinMode(8,O ...

  3. Ice系列--强大如我IceGrid

    前言 IceGrid是一个提供服务定位和服务激活的组件,但它的功能远不止于此.从它的命名可以看出它的设计理念-网格计算(grid computing).网格计算被定义为由一系列关联的廉价计算机组成的计 ...

  4. 大数据系列4:Yarn以及MapReduce 2

    系列文章: 大数据系列:一文初识Hdfs 大数据系列2:Hdfs的读写操作 大数据谢列3:Hdfs的HA实现 通过前文,我们对Hdfs的已经有了一定的了解,本文将继续之前的内容,介绍Yarn与Yarn ...

  5. 将连续增长 N 次字符串所需的内存重分配次数从必定 N 次降低为最多 N 次 二进制安全

    SDS 与 C 字符串的区别 - Redis 设计与实现 http://redisbook.com/preview/sds/different_between_sds_and_c_string.htm ...

  6. wireshark使用手册

    Wireshark的过滤器 使用Wireshark时最常见的问题,是当您使用默认设置时,会得到大量冗余信息,以至于很难找到自己需要的部分. 过犹不及. 这就是为什么过滤器会如此重要.它们可以帮助我们在 ...

  7. 分别简述computed和watch的使用场景

    computed: 当一个属性受多个属性影响的时候就需要用到computed 最典型的栗子: 购物车商品结算的时候watch: 当一条数据影响多条数据的时候就需要用watch 栗子:搜索数据

  8. mysql本地中127.0.0.1连接不上数据库怎么办

    首先在本地使用Navicat for MySQL建立一个bai数据库.在dreamweaver中建立一个PHP格式的网页,方便链接测试.测试发du现,如果zhi无法使用localhost链接mysql ...

  9. LOJ10078

    CQOI 2005 重庆城里有 n 个车站,m 条双向公路连接其中的某些车站.每两个车站最多用一条公路连接,从任何一个车站出发都可以经过一条或者多条公路到达其他车站,但不同的路径需要花费的时间可能不同 ...

  10. LOJ1036

    AHOI 2008 聚会 Y 岛风景美丽宜人,气候温和,物产丰富.Y 岛上有 N 个城市,有 N-1 条城市间的道路连接着它们.每一条道路都连接某两个城市.幸运的是,小可可通过这些道路可以走遍 Y 岛 ...