转自:http://www.jianshu.com/p/fab82fa53e16

1.tf中的nce_loss的API

def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
num_true=1,
sampled_values=None,
remove_accidental_hits=False,
partition_strategy="mod",
name="nce_loss")

假设nce_loss之前的输入数据是K维的,一共有N个类,那么

  • weight.shape = (N, K)
  • bias.shape = (N)
  • inputs.shape = (batch_size, K)
  • labels.shape = (batch_size, num_true)
  • num_true : 实际的正样本个数
  • num_sampled: 采样出多少个负样本
  • num_classes = N
  • sampled_values: 采样出的负样本,如果是None,就会用不同的sampler去采样。待会儿说sampler是什么。
  • remove_accidental_hits: 如果采样时不小心采样到的负样本刚好是正样本,要不要干掉
  • partition_strategy:对weights进行embedding_lookup时并行查表时的策略。TF的embeding_lookup是在CPU里实现的,这里需要考虑多线程查表时的锁的问题。

nce_loss的实现逻辑如下:

  • _compute_sampled_logits: 通过这个函数计算出正样本和采样出的负样本对应的output和label
  • sigmoid_cross_entropy_with_logits: 通过 sigmoid cross entropy来计算output和label的loss,从而进行反向传播。这个函数把最后的问题转化为了num_sampled+num_real个两类分类问题,然后每个分类问题用了交叉熵的损伤函数,也就是logistic regression常用的损失函数。TF里还提供了一个softmax_cross_entropy_with_logits的函数,和这个有所区别。

2.tf中word2vec实现

loss = tf.reduce_mean(
tf.nn.nce_loss(nce_weights, nce_biases, embed, train_labels,
num_sampled, vocabulary_size))

它这里并没有传sampled_values,那么它的负样本是怎么得到的呢?继续看nce_loss的实现,可以看到里面处理sampled_values=None的代码如下:

if sampled_values is None:
sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
true_classes=labels,
num_true=num_true,
num_sampled=num_sampled,
unique=True,
range_max=num_classes)

所以,默认情况下,他会用log_uniform_candidate_sampler去采样。那么log_uniform_candidate_sampler是怎么采样的呢?他的实现在这里

  • 他会在[0, range_max)中采样出一个整数k
  • P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1)

可以看到,k越大,被采样到的概率越小。那么在TF的word2vec里,类别的编号有什么含义吗?看下面的代码:

def build_dataset(words):
count = [['UNK', -1]]
count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
dictionary = dict()
for word, _ in count:
dictionary[word] = len(dictionary)
data = list()
unk_count = 0
for word in words:
if word in dictionary:
index = dictionary[word]
else:
index = 0 # dictionary['UNK']
unk_count += 1
data.append(index)
count[0][1] = unk_count
reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
return data, count, dictionary, reverse_dictionary

可以看到,TF的word2vec实现里,词频越大,词的类别编号也就越小。因此,在TF的word2vec里,负采样的过程其实就是优先采词频高的词作为负样本。

在提出负采样的原始论文中, 包括word2vec的原始C++实现中。是按照热门度的0.75次方采样的,这个和TF的实现有所区别。但大概的意思差不多,就是越热门,越有可能成为负样本。

Tf中的NCE-loss实现学习【转载】的更多相关文章

  1. tf中的run()与eval()【转载】

    转自:https://blog.csdn.net/jiaoyangwm/article/details/79248535  1.eval() 其实就是tf.Tensor的Session.run() 的 ...

  2. tf中计算图 执行流程学习【转载】

    转自:https://blog.csdn.net/dcrmg/article/details/79028003 https://blog.csdn.net/qian99/article/details ...

  3. Java多线程学习(转载)

    Java多线程学习(转载) 时间:2015-03-14 13:53:14      阅读:137413      评论:4      收藏:3      [点我收藏+] 转载 :http://blog ...

  4. 项目中使用Quartz集群分享--转载

    项目中使用Quartz集群分享--转载 在公司分享了Quartz,发布出来,希望大家讨论补充. CRM使用Quartz集群分享  一:CRM对定时任务的依赖与问题  二:什么是quartz,如何使用, ...

  5. 浅谈Java中的深拷贝和浅拷贝(转载)

    浅谈Java中的深拷贝和浅拷贝(转载) 原文链接: http://blog.csdn.net/tounaobun/article/details/8491392 假如说你想复制一个简单变量.很简单: ...

  6. ArcGIS中的坐标系定义与转换 (转载)

    原文:ArcGIS中的坐标系定义与转换 (转载) 1.基准面概念:  GIS中的坐标系定义由基准面和地图投影两组参数确定,而基准面的定义则由特定椭球体及其对应的转换参数确定,因此欲正确定义GIS系统坐 ...

  7. 如何设置Win7系统中的上帝模式GodMode(转载)

    如何设置Win7系统中的上帝模式GodMode(转载) NT6系统中隐藏了一个秘密的“GodMode”,字面上译为“上帝模式”.God Mode其实就是一个简单的文件夹窗口,但包含了几乎所有系统的设置 ...

  8. TF中conv2d和kernel_initializer方法

    conv2d中的padding 在使用TF搭建CNN的过程中,卷积的操作如下 convolution = tf.nn.conv2d(X, filters, strides=[1,2,2,1], pad ...

  9. (原)关于MEPG-2中的TS流数据格式学习

    关于MEPG-2中的TS流数据格式学习 Author:lihaiping1603 原创:http://www.cnblogs.com/lihaiping/p/8572997.html 本文主要记录了, ...

随机推荐

  1. day_5.12 py 老王开枪demo

    ps:2018-7-24 21:00:04 其实这部分主要是面向对象的复习!而不是面向过程 #!/usr/bin/env/python #-*-coding:utf-8-*- ''' 2018-5-1 ...

  2. C#打印标签

    一个复杂的标签包括一个复杂的表格样式和二维码.条形码等内容.所以如果直接绘制的方式将会非常的麻烦,所以采用使用的方案是使用模板的方式:1.使用Excel创建出想要的模板的样式.2.对模板中的动态内容进 ...

  3. Fis3构建迁移Webpack之路

    Webpack从2015年9月第一个版本横空初始至今已逾2载.它的出现,颠覆了一大批主流构建如Ant.Grunt和Gulp等等.腾讯NOW直播IVWEB团队之前一直采用Fis构建,本篇文章主要介绍从F ...

  4. 怎样将M4A音频格式转换成MP3格式

    因为MP3音频格式应用的广泛性,所以很多时候我们都需要将不同的音频格式转换成MP3格式的,那么如果我们需要将M4A音频格式转换成MP3格式,我们应该怎样进行实现呢?下面我们就一起来看一下吧. 操作步骤 ...

  5. ES6常用对象操作

    ES6常用对象操作 一. const 简单类型数据常量 // const实际上保证的,并不是变量的值不得改动,而是变量指向的那个内存地址不得改动.对于简单类型的数据(数值.字符串.布尔值),值就保存在 ...

  6. 【每日一题】 UVA - 1599 Ideal Path 字典序最短路

    题解:给一个1e5个点2e5条边,每个边有一个值,让你输出一条从1到n边的路径使得:条数最短的前提下字典序最小. 题解:bfs一次找最短路(因为权值都是1,不用dijkstra),再bfs一次存一下路 ...

  7. MySQL命令:约束

    约束是一种限制,它通过对表的行或列的数据做出限制,来确保表的数据的完整性.唯一性 约束分类: 约束类型与关键字: 主键      PRIMARY KEY 默认值  DEFAULT 唯一      UN ...

  8. vulnerability test

    vegas ---go--https://zhuanlan.zhihu.com/p/21826478 locust---python jmeter--java---http://www.cnblogs ...

  9. 微博登录报错 sso package orsign error

    https://blog.csdn.net/gao_chun/article/details/41344725 (1)检查应用包名签名信息是否完善 如果你的应用只有一个包名.签名,请在 http:// ...

  10. 转:HashMap实现原理分析(面试问题:两个hashcode相同 的对象怎么存入hashmap的)

    原文地址:https://www.cnblogs.com/faunjoe88/p/7992319.html 主要内容: 1)put   疑问:如果两个key通过hash%Entry[].length得 ...