引子

在上次的 《word2vector论文笔记》中大致介绍了两种词向量训练方法的原理及优劣,这篇咱们以skip-gram算法为例来代码实践一把。

当前教程参考:A Word2Vec Keras tutorial

导库

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Dot, Embedding
from tensorflow.keras.losses import cosine_similarity
from tensorflow.keras.preprocessing.sequence import skipgrams
from tensorflow.keras.preprocessing import sequence import urllib
import collections
import os
import zipfile import numpy as np
import tensorflow as tf
tf.__version__
'2.0.0'

数据下载与预处理

将数据下载到本地,若本地已有数据则根据文件大小判断文件是否正确

def maybe_download(filename, url, expected_bytes):
"""Download a file if not present, and make sure it's the right size."""
if not os.path.exists(filename):
filename, _ = urllib.urlretrieve(url + filename, filename)
statinfo = os.stat(filename)
if statinfo.st_size == expected_bytes:
print('Found and verified', filename)
else:
print(statinfo.st_size)
raise Exception(
'Failed to verify ' + filename + '. Can you get to it with a browser?')
return filename

读取本地数据,输出为一个单词列表

def read_data(filename):
"""Extract the first file enclosed in a zip file as a list of words."""
with zipfile.ZipFile(filename) as f:
data = tf.compat.as_str(f.read(f.namelist()[0])).split()
return data

构造数据集

输入:单词列表、词典大小

输出:转为int后的单词列表、词频统计表、word2index字典、index2word字典

def build_dataset(words, n_words):
"""Process raw inputs into a dataset."""
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(n_words - 1))
dictionary = dict()
for word, _ in count:
dictionary[word] = len(dictionary)
data = list()
unk_count = 0
for word in words:
if word in dictionary:
index = dictionary[word]
else:
index = 0 # dictionary['UNK']
unk_count += 1
data.append(index)
count[0][1] = unk_count
reversed_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return data, count, dictionary, reversed_dictionary

借助以上几个功能函数,构造我们需要的数据集。并打印查看案例

def collect_data(vocabulary_size=10000):
url = 'http://mattmahoney.net/dc/'
filename = maybe_download('text8.zip', url, 31344016)
vocabulary = read_data(filename)
print(vocabulary[:7])
data, count, dictionary, reverse_dictionary = build_dataset(vocabulary,
vocabulary_size)
del vocabulary # Hint to reduce memory.
return data, count, dictionary, reverse_dictionary vocab_size = 10000
data, count, dictionary, reverse_dictionary = collect_data(vocabulary_size=vocab_size)
print(data[:7])
Found and verified text8.zip
['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse']
[5234, 3081, 12, 6, 195, 2, 3134]

构建模型

参数设定

window_size: 即每次sample样本的时候以当前词为中心向左右取词时的窗口大小。举例当前词(target)为Wi,则上下文词(context)为Wi-3,Wi-2,Wi-1,Wi+1,Wi+2,Wi+3

vector_dim:词向量的大小

epochs:模型要训练的轮数,这个有些夸张。论文中通常训练一轮或两轮下来模型收敛效果就已经很不错。跟语料规模也有很大关系。

window_size = 3
vector_dim = 300
epochs = 100000 valid_size = 16 # Random set of words to evaluate similarity on.
valid_window = 100 # Only pick dev samples in the head of the distribution.
valid_examples = np.random.choice(valid_window, valid_size, replace=False)

sequence.make_sampling_table函数用于构造一个根据词频进行抽样的抽样分布,不同词的抽样概率按照以下公式进行计算。此分布用于下一步生成skip-gram特定的样本。

P_(word) = min(1, \frac {sqrt{\frac {wordfrequency} {samplingfactor} }} {\frac {wordfrequency} {samplingfactor}})
sampling_table = sequence.make_sampling_table(vocab_size)

skipgrams 是tf提供的用于专门构造skip-gram训练样本的工具函数,只需要提供单词序列、词典大小、窗口大小、抽样表即可。如以下示例,couples对象第一个元素为target词,第二个为context词,labels表示当前词对是否处于同一条语料的指定窗口大小范围内。

couples, labels = skipgrams(data, vocab_size, window_size=window_size, sampling_table=sampling_table)
word_target, word_context = zip(*couples)
word_target = np.array(word_target, dtype="int32")
word_context = np.array(word_context, dtype="int32") print(couples[:10], labels[:10])
[[5964, 1], [99, 379], [1385, 700], [4770, 1991], [3118, 9262], [4488, 6982], [4708, 6615], [9269, 7965], [4374, 5294], [4236, 6354]] [1, 1, 0, 1, 0, 0, 0, 0, 0, 0]

到此,我们用于模型训练的数据已经准备好啦!

关于softmax和负采样

我们看softmax的公式,为了做概率归一化,每次都需要针对词典中每个词计算其e为底的指数计算,这样当词典中的次数较大时计算成本是比较高昂的。

也有其他工作采用hierarchical softmax做为输出层。

skip-gram 则为此提出了下采样的二分类训练方法。

简言之就是根据窗口大小构造许多单词对,有些单词对确实是在目标词窗口范围内有共现的,称为正样本;反之,也根据不同词出现的频率对其他

非上下文的词进行采样,与目标词构成负样本单词对。我们的模型仅仅需要二分类就好。

创建模型

模型构建

模型本身很简单

把两个词分别输入,经过embedding层lookup到各自的词向量,通过向量点乘来衡量相似度,最后接一个单神经元的全连接层,通过sigmoid对输出进行激活,输出就是当前词对是否共现的概率判断,处于0-1之间。

损失函数为二分类损失binary_crossentropy

优化器选择rmsprop

# create some input variables
input_target = Input((1,))
input_context = Input((1,)) embedding = Embedding(vocab_size, vector_dim, input_length=1, name='embedding')
target = embedding(input_target)
target = Reshape((vector_dim, 1))(target)
context = embedding(input_context)
context = Reshape((vector_dim, 1))(context) # now perform the dot product operation to get a similarity measure
dot_product = Dot(axes=1)([target, context])
dot_product = Reshape((1,))(dot_product)
# add the sigmoid output layer
output = Dense(1, activation='sigmoid')(dot_product)
# create the primary training model
model = Model(inputs=[input_target, input_context], outputs=output)
optimizer = tf.keras.optimizers.Adam(0.001)
model.compile(loss='binary_crossentropy', optimizer=optimizer)

模型检查

打印模型结构并检查,与我们预期一致。

共300W可训练参数

model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
input_2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
embedding (Embedding) (None, 1, 300) 3000000 input_1[0][0]
input_2[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (None, 300, 1) 0 embedding[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 300, 1) 0 embedding[1][0]
__________________________________________________________________________________________________
dot (Dot) (None, 1, 1) 0 reshape[0][0]
reshape_1[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 1) 0 dot[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1) 2 reshape_2[0][0]
==================================================================================================
Total params: 3,000,002
Trainable params: 3,000,002
Non-trainable params: 0
__________________________________________________________________________________________________

查找相似词,用于验证词向量训练效果

这里构造一个相似词查找回调类,它会针对最常见的几个词,寻找当前词向量训练结果下与其相似度最高的top_k个词。

class SimilarityCallback:
def run_sim(self):
for i in range(valid_size):
valid_word = reverse_dictionary[valid_examples[i]]
top_k = 8 # number of nearest neighbors
sim = self._get_sim(valid_examples[i])
nearest = (-sim).argsort()[1:top_k + 1]
log_str = 'Nearest to %s:' % valid_word
for k in range(top_k):
close_word = reverse_dictionary[nearest[k]]
log_str = '%s %s,' % (log_str, close_word)
print(log_str) @staticmethod
def _get_sim(valid_word_idx):
sim = np.zeros((vocab_size,))
in_arr1 = np.zeros((1,))
in_arr2 = np.zeros((1,))
in_arr1[0,] = valid_word_idx
for i in range(vocab_size):
in_arr2[0,] = i
out = cosine_similarity(in_arr1, in_arr2)
sim[i] = out
return sim
sim_cb = SimilarityCallback()

训练并验证模型效果

每过1000个batch打印一次loss

每过10000个batch打印一次相似词查找结果

可以看到刚开始时由于模型权重随机初始化,针对高频词找出的相似词基本都没什么意义。

经过几万个batch的训练,找出的相似词已经与相应的高频词具有较强的相关性了。

arr_1 = np.zeros((1,))
arr_2 = np.zeros((1,))
arr_3 = np.zeros((1,))
for cnt in range(epochs):
idx = np.random.randint(0, len(labels)-1)
arr_1[0,] = word_target[idx]
arr_2[0,] = word_context[idx]
arr_3[0,] = labels[idx]
loss = model.train_on_batch([arr_1, arr_2], arr_3)
if cnt % 1000 == 0:
print("Iteration {}, loss={}".format(cnt, loss))
if cnt % 10000 == 0:
sim_cb.run_sim()

word2vector代码实践的更多相关文章

  1. ReactiveCocoa代码实践之-更多思考

    三.ReactiveCocoa代码实践之-更多思考 1. RACObserve()宏形参写法的区别 之前写代码考虑过 RACObserve(self.timeLabel , text) 和 RACOb ...

  2. ReactiveCocoa代码实践之-RAC网络请求重构

    前言 RAC相比以往的开发模式主要有以下优点:提供了统一的消息传递机制:提供了多种奇妙且高效的信号操作方法:配合MVVM设计模式和RAC宏绑定减少多端依赖. RAC的理论知识非常深厚,包含有FRP,高 ...

  3. 深刻理解Python中的元类(metaclass)--代码实践

    根据http://blog.jobbole.com/21351/所作的代码实践. 这篇讲得不错,但以我现在的水平,用到的机会是很少的啦... #coding=utf-8 class ObjectCre ...

  4. Java的BIO和NIO很难懂?用代码实践给你看,再不懂我转行!

    本文原题“从实践角度重新理解BIO和NIO”,原文由Object分享,为了更好的内容表现力,收录时有改动. 1.引言 这段时间自己在看一些Java中BIO和NIO之类的东西,也看了很多博客,发现各种关 ...

  5. TextCNN代码实践

    在上文<TextCNN论文解读>中已经介绍了TextCNN的原理,本文通过tf2.0来做代码实践. 数据集:来自中文任务基准测评的数据集IFLYTEK 导库 import os impor ...

  6. 机器学习(四):通俗理解支持向量机SVM及代码实践

    上一篇文章我们介绍了使用逻辑回归来处理分类问题,本文我们讲一个更强大的分类模型.本文依旧侧重代码实践,你会发现我们解决问题的手段越来越丰富,问题处理起来越来越简单. 支持向量机(Support Vec ...

  7. ReactiveCocoa代码实践之-UI组件的RAC信号操作

    上一节是自己对网络层的一些重构,本节是自己一些代码小实践做出的一些demo程序,基本涵盖大多数UI控件操作. 一.用UISlider实现调色板 假设我们现在做一个demo,上面有一个View用来展示颜 ...

  8. iOS代码实践总结

    转载地址:http://mobile.51cto.com/hot-492236.htm 最近一个月除了专门抽时间和精力重构之外,还有就是遇到需要添加功能的模块的时候,由于项目中的代码历史因素比较多,第 ...

  9. 使用 DartPad 制作代码实践教程

    DartPad 是一个开源的.在浏览器中体验和运行 Dart 编程语言的线上编辑器,目标是为了帮助开发者更好地了解 Dart 编程语言以及 Flutter 应用开发. DartPad 项目起始于 20 ...

随机推荐

  1. 【Oracle】substr()函数详解

    Oracle的substr函数简单用法 substr(字符串,截取开始位置,截取长度) //返回截取的字 substr('Hello World',0,1) //返回结果为 'H'  *从字符串第一个 ...

  2. mysql 1449 : The user specified as a definer ('usertest'@'%') does not exist 解决方法 (grant 授予权限)

    从服务器上迁移数据库到本地localhost 执行  函数  时报错, mysql 1449 : The user specified as a definer ('usertest'@'%') do ...

  3. USB限流芯片,4.8A最大,过压关闭6V

    PW1503,PW1502是超低RDS(ON)开关,具有可编程的电流限制,以保护电源源于过电流和短路保护.它具有超温保护以及反向闭锁功能. PW1503,PW1502采用薄型(1毫米)5针薄型SOT2 ...

  4. PAT Advanced 1004 Counting Leaves

    题目与翻译 1004 Counting Leaves 数树叶 (30分) A family hierarchy is usually presented by a pedigree tree. You ...

  5. ElasticSearch-命令行客户端操作

    1.引言 实际开发中,主要有三种方式可以作为elasticsearch服务的客户端: 第一种,elasticsearch-head插件(可视化工具) 第二种,使用elasticsearch提供的Res ...

  6. MariaDB(selec的使用)

      --查询基本使用 -- 查询所有列 --select * from 表名 select * from students;   --一定条件查询 select * from students whe ...

  7. QUIC协议分析-基于quic-go

    quic协议分析 QUIC是由谷歌设计的一种基于UDP的传输层网络协议,并且已经成为IETF草案.HTTP/3就是基于QUIC协议的.QUIC只是一个协议,可以通过多种方法来实现,目前常见的实现有Go ...

  8. IE双击打不开解决办法

    方法1 [百度电脑专家]一键修复 建议下载并安装[百度电脑专家],官网:http://zhuanjia.baidu.com .打开[百度电脑专家],在搜索框内输入"IE修复",在搜 ...

  9. 关于BI测试

    BI测试: BI是从数据接入.数据准备.数据分析.数据可视化到数bai据分发应用的一系列过程,目的是为了辅助企业高效决策.而报表虽然最终也实现了数据可视化,但是对于数据分析的维度.深度.颗粒度.实时性 ...

  10. (Oracle)取当前日期的最近工作日

      描述:现有一需求,日期表中存放了日期和是否节假日(0-工作日,1-节假日),现在需要取日期表中的最近的工作日.如2017/07/23(周日)最近的工作日应该是2017/07/21(周五).     ...