#-*-coding:utf8-*-

__author = "buyizhiyou"
__date = "2017-11-21" '''
单步调试,结合汉字的识别学习lstm,ctc loss的tf实现,tensorflow1.4
'''
import tensorflow as tf
import numpy as np
import pdb
import random def create_sparse(batch_size, dtype=np.int32):
'''
创建稀疏张量,ctc_loss中labels要求是稀疏张量,随机生成序列长度在150~180之间的labels
'''
indices = []
values = []
for i in range(batch_size):
length = random.randint(150,180)
for j in range(length):
indices.append((i,j))
value = random.randint(0,779)
values.append(value) indices = np.asarray(indices, dtype=np.int64)
values = np.asarray(values, dtype=dtype)
shape = np.asarray([batch_size, np.asarray(indices).max(0)[1] + 1], dtype=np.int64) #[64,180] return [indices, values, shape] W = tf.Variable(tf.truncated_normal([200,781],stddev=0.1), name="W")#num_hidden=200,num_classes=781(想象成780个汉字+blank),shape (200,781)
b = tf.Variable(tf.constant(0., shape=[781]), name="b")#
global_step = tf.Variable(0, trainable=False)#全局步骤计数 #构造输入
inputs = tf.random_normal(shape=[64,60,3000], dtype=tf.float32)#为了测试,随机batch_size=64张图片,h=60,w=3000,w可以看成lstm的时间步,即lstm输入的time_step=3000,h看成是每一时间步的输入tensor的size
shape = tf.shape(inputs)#array([ 64, 3000, 60], dtype=int32)
batch_s, max_timesteps = shape[0], shape[1] #64,3000
output = create_sparse(64)#创建64张图片对应的labels,稀疏张量,序列长度变长
seq_len = np.ones(64)*180 #180为变长序列的最大值
labels = tf.SparseTensor(values=output[1],indices=output[0],dense_shape=output[2]) pdb.set_trace()
cell = tf.nn.rnn_cell.LSTMCell(200, state_is_tuple=True)
inputs = tf.transpose(inputs,[0,2,1])#转置,因为默认的tf.nn.dynamic_rnn中参数time_major=false,即inputs的shape 是`[batch_size, max_time, ...]`, '''
tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, paralle
l_iterations=None, swap_memory=False, time_major=False, scope=None)
'''
outputs1, _ = tf.nn.dynamic_rnn(cell, inputs, seq_len, dtype=tf.float32)#(64, 3000, 200)动态rnn实现了输入变长问题的解决方案http://blog.csdn.net/u010223750/article/details/71079036 outputs = tf.reshape(outputs1, [-1, 200])#(64×3000,200)
logits0 = tf.matmul(outputs, W) + b
logits1 = tf.reshape(logits0, [batch_s, -1, 781])
logits = tf.transpose(logits1, (1, 0, 2))#(3000, 64, 781) '''
tf.nn.ctc_loss(labels, inputs, sequence_length, preprocess_collapse_repeated=False, ctc_merge
_repeated=True, ignore_longer_outputs_than_inputs=False, time_major=True)
'''
loss = tf.nn.ctc_loss(logits, labels, seq_len)#关于ctc loss解决rnn输出和序列不对齐问题
#http://blog.csdn.net/left_think/article/details/76370453
#https://zhuanlan.zhihu.com/p/23293860
cost = tf.reduce_mean(loss)
optimizer = tf.train.MomentumOptimizer(learning_rate=0.01,
momentum=0.9).minimize(cost, global_step=global_step)
#decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_len, merge_repeated=False)#or "tf.nn.ctc_greedy_decoder"一种解码策略
#acc = tf.reduce_mean(tf.edit_distance(tf.cast(decoded[0], tf.int32), labels))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print (outputs.get_shape())
print (sess.run(loss))

以lstm+ctc对汉字识别为例对tensorflow 中的lstm,ctc loss的调试的更多相关文章

  1. 在TensorFlow中基于lstm构建分词系统笔记

    在TensorFlow中基于lstm构建分词系统笔记(一) https://www.jianshu.com/p/ccb805b9f014 前言 我打算基于lstm构建一个分词系统,通过这个例子来学习下 ...

  2. tensorflow中的lstm的state

        考虑 state_is_tuple     Output, new_state = cell(input, state)     state其实是两个 一个 c state,一个m(对应下图的 ...

  3. tensorflow源码分析——CTC

    CTC是2006年的论文Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurren ...

  4. Python中利用LSTM模型进行时间序列预测分析

    时间序列模型 时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺 ...

  5. tensorflow实现基于LSTM的文本分类方法

    tensorflow实现基于LSTM的文本分类方法 作者:u010223750 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实 ...

  6. 一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)

    雷锋网按:本文作者陆池,原文载于作者个人博客,雷锋网已获授权. 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实例,这个星期就用 ...

  7. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  8. 在Keras中可视化LSTM

    作者|Praneet Bomma 编译|VK 来源|https://towardsdatascience.com/visualising-lstm-activations-in-keras-b5020 ...

  9. LSTM(长短期记忆网络)及其tensorflow代码应用

     本文主要包括: 一.什么是LSTM 二.LSTM的曲线拟合 三.LSTM的分类问题 四.为什么LSTM有助于消除梯度消失 一.什么是LSTM Long Short Term 网络即为LSTM,是一种 ...

随机推荐

  1. Codeforces Round #387 (Div. 2) 747F(数位DP)

    题目大意 给出整数k和t,需要产生一个满足以下要求的第k个十六进制数 即十六进制数每一位上的数出现的次数不超过t 首先我们先这样考虑,如果给你了0~f每个数字可以使用的次数num[i],如何求长度为L ...

  2. 【POJ 2387 Til the Cows Come Home】

    Time Limit: 1000MSMemory Limit: 65536K Total Submissions: 59755Accepted: 20336 Description Bessie is ...

  3. [03] html 中引入与使用css

    1. 使用style属性 <a style="color: red;"> hello ,there use style attribute</a> 2. l ...

  4. a标签打电话

    <a href="tel:0147-88469258"></a> <a href="mailto:bd@pangxiekeji.com&qu ...

  5. 创建型设计模式之建造者模式(Builder)

    结构 意图 将一个复杂对象的构建与它的表示分离,使得同样的构建过程可以创建不同的表示. 适用性 当创建复杂对象的算法应该独立于该对象的组成部分以及它们的装配方式时. 当构造过程必须允许被构造的对象有不 ...

  6. 创建型设计模式之单例模式(Singleton)

     结构 意图 保证一个类仅有一个实例,并提供一个访问它的全局访问点. 适用性 当类只能有一个实例而且客户可以从一个众所周知的访问点访问它时. 当这个唯一实例应该是通过子类化可扩展的,并且客户应该无需更 ...

  7. 小Z爱图论(NOIP信(sang)心(bin)赛)From FallDream

    题目: 小Z最近喜欢上了图论,于是他研究了一下图的连通性问题.但是他遇到了一个难题. 给定一个n个点的有向图,求有多少点对(i,j)满足从i点出发能到达点j ? 小Z仅会简单的朴素算法,所以他想问问你 ...

  8. python的优化机制与垃圾回收与gc模块

    python属于动态语言,我们可以随意的创建和销毁变量,如果频繁的创建和销毁则会浪费cpu,那么python内部是如何优化的呢? python和其他很多高级语言一样,都自带垃圾回收机制,不用我们去维护 ...

  9. 关于多态的理解,有助于理解TStream抽象类的多态机制。

    有的时候 不是很明白流的机制,因为有内存流  文件流 图片流 等等 他们之间的相互转化 靠的就是流的多态性.... unit Unit11; interface uses Winapi.Windows ...

  10. react 如何处理页面加载时无法将获取缓存信息存入全局变量中

    最近在做一个权限功能时,发现在读取用户公司ID进行列表查询 时,无法钭读取到缓存中的数据存入页面全局变量中进行加载查询 将问题代码整理出来 将信息存入缓存: let menuList = Helper ...