tf.contrib.seq2seq.sequence_loss example:seqence loss 实例代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
params=np.random.normal(loc=0.0,scale=1.0,size=[10,10])
encoder_inputs=tf.placeholder(dtype=tf.int32,shape=[10,10])
decoder_inputs=tf.placeholder(dtype=tf.int32,shape=[10,10])
logits=tf.placeholder(dtype=tf.float32,shape=[10,10,10])
targets=tf.placeholder(dtype=tf.int32,shape=[10,10])
weights=tf.placeholder(dtype=tf.float32,shape=[10,10])
train_encoder_inputs=np.ones(shape=[10,10],dtype=np.int32)
train_decoder_inputs=np.ones(shape=[10,10],dtype=np.int32)
train_weights=np.ones(shape=[10,10],dtype=np.float32)
num_encoder_symbols=10
num_decoder_symbols=10
embedding_size=10
cell=tf.nn.rnn_cell.BasicLSTMCell(10)
def seq2seq(encoder_inputs,decoder_inputs,cell,num_encoder_symbols,num_decoder_symbols,embedding_size):
encoder_inputs = tf.unstack(encoder_inputs, axis=0)
decoder_inputs = tf.unstack(decoder_inputs, axis=0)
results,states=tf.contrib.legacy_seq2seq.embedding_rnn_seq2seq(
encoder_inputs,
decoder_inputs,
cell,
num_encoder_symbols,
num_decoder_symbols,
embedding_size,
output_projection=None,
feed_previous=False,
dtype=None,
scope=None
)
return results
def get_loss(logits,targets,weights):
loss=tf.contrib.seq2seq.sequence_loss(
logits,
targets=targets,
weights=weights
)
return loss
results=seq2seq(encoder_inputs,decoder_inputs,cell,num_encoder_symbols,num_decoder_symbols,embedding_size)
logits=tf.stack(results,axis=0)
print(logits)
loss=get_loss(logits,targets,weights)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
results_value=sess.run(results,feed_dict={encoder_inputs:train_encoder_inputs,decoder_inputs:train_decoder_inputs})
print(type(results_value[0]))
print(len(results_value))
cost = sess.run(loss, feed_dict={encoder_inputs: train_encoder_inputs, targets: train_decoder_inputs,
weights:train_weights,decoder_inputs:train_decoder_inputs})
print(cost)
更多教程:http://www.tensorflownews.com/
tf.contrib.seq2seq.sequence_loss example:seqence loss 实例代码的更多相关文章
- 学习笔记TF044:TF.Contrib组件、统计分布、Layer、性能分析器tfprof
TF.Contrib,开源社区贡献,新功能,内外部测试,根据反馈意见改进性能,改善API友好度,API稳定后,移到TensorFlow核心模块.生产代码,以最新官方教程和API指南参考. 统计分布.T ...
- 深度学习原理与框架-递归神经网络-RNN网络基本框架(代码?) 1.rnn.LSTMCell(生成单层LSTM) 2.rnn.DropoutWrapper(对rnn进行dropout操作) 3.tf.contrib.rnn.MultiRNNCell(堆叠多层LSTM) 4.mlstm_cell.zero_state(state初始化) 5.mlstm_cell(进行LSTM求解)
问题:LSTM的输出值output和state是否是一样的 1. rnn.LSTMCell(num_hidden, reuse=tf.get_variable_scope().reuse) # 构建 ...
- tensorflow笔记3:CRF函数:tf.contrib.crf.crf_log_likelihood()
在分析训练代码的时候,遇到了,tf.contrib.crf.crf_log_likelihood,这个函数,于是想简单理解下: 函数的目的:使用crf 来计算损失,里面用到的优化方法是:最大似然估计 ...
- TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用
一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...
- 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题
这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...
- 第十六节,使用函数封装库tf.contrib.layers
这一节,介绍TensorFlow中的一个封装好的高级库,里面有前面讲过的很多函数的高级封装,使用这个高级库来开发程序将会提高效率. 我们改写第十三节的程序,卷积函数我们使用tf.contrib.lay ...
- tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be replaced by tf.contrib.rnn.BasicLSTMCell.
For Tensorflow 1.2 and Keras 2.0, the line tf.contrib.rnn.core_rnn_cell.BasicLSTMCell should be repl ...
- tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别
tf.contrib.rnn.static_rnn与tf.nn.dynamic_rnn区别 https://blog.csdn.net/u014365862/article/details/78238 ...
- TensorFlow高层次机器学习API (tf.contrib.learn)
TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格 ...
随机推荐
- Oracle RAC环境下定位并杀掉最终阻塞的会话
实验环境:Oracle RAC 11.2.0.4 (2节点) 1.模拟故障:会话被级联阻塞 2.常规方法:梳理找出最终阻塞会话 3.改进方法:立即找出最终阻塞会话 之前其实也写过一篇相关文章: 如何定 ...
- java语法基础(总结)
1,关键字:其实就是某种语言赋予了特殊含义的单词. 保留字:其实就是还没有赋予特殊含义,但是准备日后要使用过的单词. 2,标示符:其实就是在程序中自定义的名词.比如类名,变量名,函数名.包含 0-9. ...
- GIT入门笔记(18)- 标签创建和管理
git tag <name>用于新建一个标签,默认为HEAD,也可以指定一个commit id: git tag -a <tagname> -m "blablabla ...
- 新概念英语(1-27)Mrs. Smtih's living room
Where are the books? Mrs. Smtih's living room is large. There is a television in the room. The telev ...
- 清除session信息
session.removeAttribute("sessionname")是清除SESSION里的某个属性. session.invalidate()是让SESSION失 ...
- [洛谷P1197/BZOJ1015][JSOI2008]星球大战Starwar - 并查集,离线,联通块
Description 很久以前,在一个遥远的星系,一个黑暗的帝国靠着它的超级武器统治者整个星系.某一天,凭着一个偶然的机遇,一支反抗军摧毁了帝国的超级武器,并攻下了星系中几乎所有的星球.这些星球通过 ...
- Linux:nohub启动后台永久进程
nohup 命令运行由 Command参数和任何相关的 Arg参数指定的命令,忽略所有挂断(SIGHUP)信号.在注销后使用 nohup 命令运行后台中的程序.要运行后台中的 nohup 命令,添加 ...
- Hadoop API:遍历文件分区目录,并根据目录下的数据进行并行提交spark任务
hadoop api提供了一些遍历文件的api,通过该api可以实现遍历文件目录: import java.io.FileNotFoundException; import java.io.IOExc ...
- Struts(十四):通用标签-form表单
form标签是struts2标签中一个重要标签: 可以生成html标签,使用起来和html的form标签差不多: Strut2的form标签会生成一个table,进行自动布局: 可以对表单提交的值进行 ...
- Extensions in UWP Community Toolkit - Visual Extensions
概述 UWP Community Toolkit Extensions 中有一个为可视元素提供的扩展 - VisualExtensions,本篇我们结合代码详细讲解 VisualExtensions ...