# 22 scope (name_scope / variable_scope)
from __future__ import print_function

import tensorflow as tf


class TrainConfig:
    batch_size = 20
    time_steps = 20
    input_size = 10
    output_size = 2
    cell_size = 11
    learning_rate = 0.01


class TestConfig(TrainConfig):
    time_steps = 1


class RNN(object):
    def __init__(self, config):
        self._batch_size = config.batch_size
        self._time_steps = config.time_steps
        self._input_size = config.input_size
        self._output_size = config.output_size
        self._cell_size = config.cell_size
        self._lr = config.learning_rate
        self._build_RNN()

    def _build_RNN(self):
        with tf.variable_scope('inputs'):
            self._xs = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._input_size], name='xs')
            self._ys = tf.placeholder(tf.float32, [self._batch_size, self._time_steps, self._output_size], name='ys')
        with tf.name_scope('RNN'):
            with tf.variable_scope('input_layer'):
                # flatten to (batch * n_steps, in_size) so one weight matrix serves every time step
                l_in_x = tf.reshape(self._xs, [-1, self._input_size], name='2_2D')
                # Ws (in_size, cell_size)
                Wi = self._weight_variable([self._input_size, self._cell_size])
                print(Wi.name)
                # bs (cell_size, )
                bi = self._bias_variable([self._cell_size, ])
                # l_in_y shape: (batch * n_steps, cell_size)
                with tf.name_scope('Wx_plus_b'):
                    l_in_y = tf.matmul(l_in_x, Wi) + bi
                l_in_y = tf.reshape(l_in_y, [-1, self._time_steps, self._cell_size], name='2_3D')

            with tf.variable_scope('cell'):
                cell = tf.contrib.rnn.BasicLSTMCell(self._cell_size)
                with tf.name_scope('initial_state'):
                    self._cell_initial_state = cell.zero_state(self._batch_size, dtype=tf.float32)

                self.cell_outputs = []
                cell_state = self._cell_initial_state
                for t in range(self._time_steps):
                    # after the first step, reuse the cell's weights instead of creating new ones
                    if t > 0:
                        tf.get_variable_scope().reuse_variables()
                    cell_output, cell_state = cell(l_in_y[:, t, :], cell_state)
                    self.cell_outputs.append(cell_output)
                self._cell_final_state = cell_state

            with tf.variable_scope('output_layer'):
                # cell_outputs_reshaped shape: (batch * time_steps, cell_size)
                cell_outputs_reshaped = tf.reshape(tf.concat(self.cell_outputs, 1), [-1, self._cell_size])
                Wo = self._weight_variable((self._cell_size, self._output_size))
                bo = self._bias_variable((self._output_size,))
                product = tf.matmul(cell_outputs_reshaped, Wo) + bo
                # _pred shape: (batch * time_steps, output_size)
                self._pred = tf.nn.relu(product)  # for displacement

        with tf.name_scope('cost'):
            _pred = tf.reshape(self._pred, [self._batch_size, self._time_steps, self._output_size])
            mse = self.ms_error(_pred, self._ys)
            mse_ave_across_batch = tf.reduce_mean(mse, 0)
            mse_sum_across_time = tf.reduce_sum(mse_ave_across_batch, 0)
            self._cost = mse_sum_across_time
            self._cost_ave_time = self._cost / self._time_steps

        with tf.variable_scope('train'):
            self._lr = tf.convert_to_tensor(self._lr)
            self.train_op = tf.train.AdamOptimizer(self._lr).minimize(self._cost)
    @staticmethod
    def ms_error(y_pre, y_target):
        return tf.square(tf.subtract(y_pre, y_target))

    @staticmethod
    def _weight_variable(shape, name='weights'):
        initializer = tf.random_normal_initializer(mean=0., stddev=0.5)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)

    @staticmethod
    def _bias_variable(shape, name='biases'):
        initializer = tf.constant_initializer(0.1)
        return tf.get_variable(name=name, shape=shape, initializer=initializer)


if __name__ == '__main__':
    train_config = TrainConfig()  # define the training config
    test_config = TestConfig()

    # # the wrong method to reuse parameters in train rnn:
    # # building the two models in different variable scopes creates two
    # # independent sets of weights instead of sharing one.
    # with tf.variable_scope('train_rnn'):
    #     train_rnn1 = RNN(train_config)
    # with tf.variable_scope('test_rnn'):
    #     test_rnn1 = RNN(test_config)

    # the right method to reuse parameters in train rnn:
    # let the training RNN create the variables first, then use
    # scope.reuse_variables() so the test RNN binds to the same variables.
    with tf.variable_scope('rnn') as scope:
        sess = tf.Session()
        train_rnn2 = RNN(train_config)
        scope.reuse_variables()  # tell TF to reuse the RNN's variables
        test_rnn2 = RNN(test_config)

    # tf.initialize_all_variables() is no longer valid as of
    # 2017-03-02 if using tensorflow >= 0.12
    if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
        init = tf.initialize_all_variables()
    else:
        init = tf.global_variables_initializer()
    sess.run(init)
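A quick way to confirm that the two models really share one set of weights (this check is not part of the original script, but follows directly from the scopes defined above) is to list the trainable variables after both models are built; the count should not double when test_rnn2 is created:

# appended after the script above; variable names follow from _build_RNN's scopes
for v in tf.trainable_variables():
    print(v.name)
# Expected names, roughly:
#   rnn/input_layer/weights:0
#   rnn/input_layer/biases:0
#   rnn/cell/... (exact LSTM-internal names depend on the TF 1.x version)
#   rnn/output_layer/weights:0
#   rnn/output_layer/biases:0
# Note that the 'RNN' name_scope does not appear: name_scope does not
# affect names created by tf.get_variable, only variable_scope does.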

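The scope mechanics the script relies on can also be seen in isolation. Here is a minimal self-contained sketch, meant to run in a fresh graph; the scope name 'demo' and variable name 'v' are purely illustrative:

import tensorflow as tf

with tf.variable_scope('demo') as scope:
    with tf.name_scope('a_name_scope'):
        # get_variable ignores the enclosing name_scope, so this is 'demo/v:0'
        v1 = tf.get_variable('v', shape=[1], initializer=tf.constant_initializer(1.0))
    scope.reuse_variables()
    v2 = tf.get_variable('v')  # in reuse mode, returns the existing 'demo/v:0'

print(v1.name)   # demo/v:0
print(v2.name)   # demo/v:0
print(v1 is v2)  # True: one variable, two Python references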
  
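To actually exercise the shared model, a single training step might look like the sketch below. The random batch and the direct use of the private _xs/_ys/_cost attributes are illustrative assumptions, not part of the original code:

import numpy as np

# hypothetical random batch, shaped to match the training placeholders
xs = np.random.randn(train_config.batch_size, train_config.time_steps,
                     train_config.input_size).astype(np.float32)
ys = np.random.randn(train_config.batch_size, train_config.time_steps,
                     train_config.output_size).astype(np.float32)

_, cost = sess.run([train_rnn2.train_op, train_rnn2._cost],
                   feed_dict={train_rnn2._xs: xs, train_rnn2._ys: ys})
print('cost:', cost)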
