https://github.com/chenghuige/tensorflow-exp/blob/master/examples/sparse-tensor-classification/

 
 

tensorflow-exp/example/sparse-tensor-classification/train-validate.py

当你需要train的过程中validate的时候,如果用placeholder来接收输入数据
那么一个compute graph可以完成这个任务。如果你用的是TFRecord的方式
输入嵌入到compute graph,那么对应input(for train), input_1(for validate),就会产生两个compute graph,但是要注意的是validate过程中需要share使用等同于train过程的w_h等变量,如果直接build两次graph就回阐释下面的示意图

 
 

这种并没有共享 w_h等数据,因此validate 会有问题(注意Input_1里面对应的w_h_1)

cost, accuracy = build_graph(X, label)

_, accuracy_test = build_graph((index_test, value_test), label_test)

train_op = gen_optimizer(cost, FLAGS.learning_rate)

#train_op_test = gen_optimizer(cost_test, FLAGS.learning_rate)

 
 

来自 <http://git.oschina.net/chenghuige/tensorflow-exp/blob/master/example/sparse-tensor-classification/train-validate.py?dir=0&filepath=example%2Fsparse-tensor-classification%2Ftrain-validate.py&oid=04e0aca92d157121cac257125e2c6a66f68c1e4c&sha=b5f3b6b833ddbb99cc2c9ea763a59a3ab5c564b7>

这里
tf.get_variable_scope().reuse_variables()并不起作用,因为build_graph里面并没有使用ge_variable机制

 
 

第一种解决方案
用类 self.w_h

解决此类问题的方法之一就是使用类来创建模块,在需要的地方使用类来小心地管理他们需要的变量. 一个更高明的做法,不用调用类,而是利用TensorFlow 提供了变量作用域 机制,当构建一个视图时,很容易就可以共享命名过的变量.

 
 

来自 <http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/variable_scope/index.html>

使用类的方式,共享w_h等变量

class Mlp(object):

def __init__(self):

hidden_size = 200

num_features = NUM_FEATURES

num_classes = NUM_CLASSES

with tf.device('/cpu:0'):

self.w_h = init_weights([num_features, hidden_size], name = 'w_h')

self.b_h = init_bias([hidden_size], name = 'b_h')

self.w_o = init_weights([hidden_size, num_classes], name = 'w_o')

self.b_o = init_bias([num_classes], name = 'b_o')

 
 

def model(self, X, w_h, b_h, w_o, b_o):

h = tf.nn.relu(matmul(X, w_h) + b_h)

return tf.matmul(h, w_o) + b_o

 

def forward(self, X):

py_x = self.model(X, self.w_h, self.b_h, self.w_o, self.b_o)

return py_x

 
 

X = (index, value)

algo = Mlp()

cost, accuracy = build_graph(X, label, algo)

cost_test, accuracy_test = build_graph((index_test, value_test), label_test, algo)

train_op = gen_optimizer(cost, FLAGS.learning_rate)

 
 

类似这种做法的例子tensorflow/tensorflow/models/embedding/word2vec.py

第二中
变量共享

 
 

 

变量作用域机制在TensorFlow中主要由两部分组成:

  • tf.get_variable(<name>, <shape>, <initializer>): 通过所给的名字创建或是返回一个变量.
  • tf.variable_scope(<scope_name>): 通过 tf.get_variable()为变量名指定命名空间.

方法 tf.get_variable() 用来获取或创建一个变量,而不是直接调用tf.Variable.它采用的不是像`tf.Variable这样直接获取值来初始化的方法.一个初始化就是一个方法,创建其形状并且为这个形状提供一个张量.这里有一些在TensorFlow中使用的初始化变量:

 
 

代码

https://github.com/chenghuige/tensorflow-exp/blob/master/examples/sparse-tensor-classification/train-validate-share.py

 
 

来自 <http://wiki.jikexueyuan.com/project/tensorflow-zh/how_tos/variable_scope/index.html>

 
 

 
 

https://github.com/chenghuige/tensorflow-exp/blob/master/examples/sparse-tensor-classification/的更多相关文章

  1. https://github.com/tensorflow/models/blob/master/research/slim/datasets/preprocess_imagenet_validation_data.py 改编版

    #!/usr/bin/env python # Copyright 2016 Google Inc. All Rights Reserved. # # Licensed under the Apach ...

  2. iOS - 解决Unable to add a source with url `https://github.com/CocoaPods/Specs.git` named

    1  本来cocopods没有问题,最近创建项目,利用cocopods导入第三方库的时候,出现如下错误: [!] Unable to add a source with url `https://gi ...

  3. Git - could not read Username for 'https://github.com',push报错解决办法

    执行git push命令异常,如下: git -c diff.mnemonicprefix=false -c core.quotepath=false -c credential.helper=sou ...

  4. fatal: unable to access 'https://github.com/Homebrew/brew/'

    最近安装 Homebrew 遇到的坑,总结一下. 我的 Mac 版本是 10.13.6. 首先安装 Homebrew /usr/bin/ruby -e "$(curl -fsSL https ...

  5. git 解决 error: failed to push some refs to 'https://github.com/xxxx.git'

    在github远程创建仓库后, 利用gitbash进行提交本地文件的时候出现如下错误 [root@foundation38 demo]# git push -u origin master Usern ...

  6. 结对项目https://github.com/bxoing1994/test/blob/master/源代码

    所选项目名称:文本替换      结对人:曲承玉 github地址 :https://github.com/bxoing1994/test/blob/master/源代码 结对人github地址:ht ...

  7. https://github.com/python/cpython/blob/master/Doc/library/contextlib.rst 被同一个线程多次获取的同步基元组件

    # -*- coding: utf-8 -*- import time from threading import Lock, RLock from datetime import datetime ...

  8. https://github.com/golang/crypto/blob/master/bcrypt/bcrypt.go

    https://github.com/golang/crypto/blob/master/bcrypt/bcrypt.go

  9. https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/connections.py

    # Python implementation of the MySQL client-server protocol # http://dev.mysql.com/doc/internals/en/ ...

随机推荐

  1. django创建新项目anministrator问题

    1.app中models的class可以通过migrations命令生成相应的数据表 2.此时并未写入数据库,migrate命令可以把相应的改动更新到数据库中 3.createsuperuser命令创 ...

  2. Android中使用ShareSDK集成分享功能

    引言      现在APP开发集成分享功能已经是非常普遍的需求了.其他集成分享技术我没有使用过,今天我就来介绍下使用ShareSDK来进行分享功能开发的一些基本步骤和注意点,帮助朋友们避免一些坑.好了 ...

  3. Bubble Cup 8 finals I. Robots protection (575I)

    题意: 有一个正方形区域, 要求支持两个操作: 1.放置三角形,给定放置方向(有4个方向,直角边与坐标轴平行),直角顶点坐标,边长 2.查询一个点被覆盖了多少次 1<=正方形区域边长n<= ...

  4. AssetBundle

    AssetBundle是Unity推荐的一种资源打包方式,与不使用AssetBundle相比,它有如下优点: 1.AssetBundle是经过LZMA压缩过的,所以体积更小. 2.可以将AssetBu ...

  5. (一)SQL Server分区详解Partition(目录)

    一.SQL Server分区介绍 在SQL Server中,数据库的所有表和索引都视为已分区表和索引,默认这些表和索引值包含一个分区:也就是说表或索引至少包含一个分区.SQL Server中数据是按水 ...

  6. SQL 从指定表筛选指定行信息 获取表行数

    1.获取指定表的行数 --获取表中数据行数 --select max([列名]) from 表名 2.筛选指定表的指定行数据(数据表分页获取) http://www.cnblogs.com/morni ...

  7. 【IOS】自定义可点击的多文本跑马灯YFRollingLabel

    需求 项目中需要用到跑马灯来仅展示一条消息,长度合适则不滚动,过长则循环滚动. 虽然不是我写的,但看了看代码,是在一个UIView里面放入两个UILabel, 在前一个快结束的时候,另一个显示.然而点 ...

  8. [cocos] ( 01 ) cocos2d-x 3.x 开发 环境配置

    有几个需要注意的问题 Windows上使用时, Unable to execute dex: Multiple dex files define 在eclipse中libcoco2dx的Java Bu ...

  9. 读取图像,LUT以及计算耗时

    使用LUT(lookup table)检索表的方法,提高color reduce时对像素读取的速度. 实现对Mat对象中数据的读取,并计算color reduce的速度. 方法一:使用Mat的ptr( ...

  10. 小白 安装和配置Tomcat 局域网内访问网页

    1.官网http://tomcat.apache.org/  ,下载tomcat,解压就好 2.官网www.oracle.com, 下载javaJDK,截图如下,点击黄色荧光笔