#!/usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf

class TRNNConfig(object):
"""RNN配置参数"""

# 模型参数
embedding_dim = 64 # 词向量维度
seq_length = 600 # 序列长度
num_classes = 10 # 类别数
vocab_size = 5000 # 词汇表达小

num_layers= 2 # 隐藏层层数
hidden_dim = 128 # 隐藏层神经元
rnn = 'gru' # lstm 或 gru

dropout_keep_prob = 0.8 # dropout保留比例
learning_rate = 1e-3 # 学习率

batch_size = 128 # 每批训练大小
num_epochs = 10 # 总迭代轮次

print_per_batch = 100 # 每多少轮输出一次结果
save_per_batch = 10 # 每多少轮存入tensorboard

class TextRNN(object):
"""文本分类,RNN模型"""
def __init__(self, config):
self.config = config

# 三个待输入的数据
self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

self.rnn()

def rnn(self):
"""rnn模型"""

def lstm_cell(): # lstm核
return tf.contrib.rnn.BasicLSTMCell(self.config.hidden_dim, state_is_tuple=True)

def gru_cell(): # gru核
return tf.contrib.rnn.GRUCell(self.config.hidden_dim)

def dropout(): # 为每一个rnn核后面加一个dropout层
if (self.config.rnn == 'lstm'):
cell = lstm_cell()
else:
cell = gru_cell()
return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=self.keep_prob)

# 动作映射
with tf.device('/cpu:0'):
embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

with tf.name_scope("rnn"):
# 多层rnn网络
cells = [dropout() for _ in range(self.config.num_layers)]
rnn_cell = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

_outputs, _ = tf.nn.dynamic_rnn(cell=rnn_cell, inputs=embedding_inputs, dtype=tf.float32)
last = _outputs[:, -1, :] # 取最后一个时序输出作为结果

with tf.name_scope("score"):
# 全连接层,后面接dropout以及relu激活
fc = tf.layers.dense(last, self.config.hidden_dim, name='fc1')
fc = tf.contrib.layers.dropout(fc, self.keep_prob)
fc = tf.nn.relu(fc)

# 分类器
self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
# 预测类别
self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)

with tf.name_scope("optimize"):
# 损失函数,交叉熵
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
#求输入的所有行的预测值的均值
self.loss = tf.reduce_mean(cross_entropy)
# 优化器
self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

with tf.name_scope("accuracy"):
# 准确率 其中 self.y_pred_cls为预测的类别
correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
#
self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

LSTM_Model的更多相关文章

  1. 基于双向BiLstm神经网络的中文分词详解及源码

    基于双向BiLstm神经网络的中文分词详解及源码 基于双向BiLstm神经网络的中文分词详解及源码 1 标注序列 2 训练网络 3 Viterbi算法求解最优路径 4 keras代码讲解 最后 源代码 ...

  2. (转) Using the latest advancements in AI to predict stock market movements

    Using the latest advancements in AI to predict stock market movements 2019-01-13 21:31:18 This blog ...

  3. NLP入门(五)用深度学习实现命名实体识别(NER)

    前言   在文章:NLP入门(四)命名实体识别(NER)中,笔者介绍了两个实现命名实体识别的工具--NLTK和Stanford NLP.在本文中,我们将会学习到如何使用深度学习工具来自己一步步地实现N ...

  4. Reading | 《TensorFlow:实战Google深度学习框架》

    目录 三.TensorFlow入门 1. TensorFlow计算模型--计算图 I. 计算图的概念 II. 计算图的使用 2.TensorFlow数据类型--张量 I. 张量的概念 II. 张量的使 ...

  5. Tensorflow[LSTM]

    0.背景 通过对<tensorflow machine learning cookbook>第9章第3节"implementing_lstm"进行阅读,发现如下形式可以 ...

  6. 使用TensorFlow的递归神经网络(LSTM)进行序列预测

    本篇文章介绍使用TensorFlow的递归神经网络(LSTM)进行序列预测.作者在网上找到的使用LSTM模型的案例都是解决自然语言处理的问题,而没有一个是来预测连续值的. 所以呢,这里是基于历史观察数 ...

  7. Tensorflow LSTM实现

    Tensorflow[LSTM]   0.背景 通过对<tensorflow machine learning cookbook>第9章第3节"implementing_lstm ...

  8. 在TensorFlow中基于lstm构建分词系统笔记

    在TensorFlow中基于lstm构建分词系统笔记(一) https://www.jianshu.com/p/ccb805b9f014 前言 我打算基于lstm构建一个分词系统,通过这个例子来学习下 ...

  9. 『TensotFlow』RNN/LSTM古诗生成

    往期RNN相关工程实践文章 『TensotFlow』基础RNN网络分类问题 『TensotFlow』RNN中文文本_上 『TensotFlow』基础RNN网络回归问题 『TensotFlow』RNN中 ...

随机推荐

  1. Linux高级网络设置——给网卡绑定多个IP

    假设这样一种场景: 某运营商的Linux服务器上装配了2家互联网公司的Web服务,每个Web服务分配了一个公网IP地址.但是运营商的Linux服务器只有一块网卡.这就需要在一块网卡上绑定多个IP地址. ...

  2. idou老师教你学Istio 22 : 如何用istio实现调用链跟踪

    大家都知道istio可以帮助我们实现灰度发布.流量监控.流量治理等一些功能. 每一个功能都帮助我们在不同场景中实现不同的业务.那么其中比如流量监控这种复杂的功能Istio是如何让我们在不同的应用中实现 ...

  3. windows下内存检测工具

    1.Intel的Parallel Inspector工具,和vs集成超好, 而且还带了线程检测工具. 2.Purifyhttps://www.cnblogs.com/hehehaha/archive/ ...

  4. [Visual Studio] 问题:VS下运行项目时,检测到在集成的托管管道模式下不适用的 ASP.NET 设置。

    vs2012调试时默认会是集成模式,vs2012调试时怎么使用传统模式哪? 这个时候只要选中启动项目按F4,在托管管道模式里选传统模式即可!

  5. Nginx入门(一)——安装和配置

    1.下载地址 http://nginx.org/en/download.html 2.启动Nginx 进入window的cmd窗口,输入如下图所示的命令,进入到nginx目录(F:/nginx-1.8 ...

  6. JAVA遇见HTML——JSP篇(JSP内置对象下)

    request.getSession() 网上资料解释: request只能存在于一次访问里 session对象的作用域为一次会话 session长驻在服务器内存里,session有id标识,一个se ...

  7. JVM之Java运行时数据区(线程共享区)

    JVM运行时区域各线程共享的区域包括堆区和方法区. 堆区 堆区最最主要的功能是存储对象实例[上篇也提到过],因此Java垃圾回收的主要战场就是在堆区,因此也有称为GC堆区.如果堆区的内存不够会出现Ou ...

  8. Java集合--Stack

    转载请注明出处:http://www.cnblogs.com/skywang12345/p/3308852.html 第1部分 Stack介绍 Stack简介 Stack是栈.它的特性是:先进后出(F ...

  9. 组件化网页开发 / 步骤一 · 4-4 匹配HTML标签

    组件化网页开发 / 步骤一 · 4-4 匹配HTML标签

  10. PHP mysqli_connect_errno() 函数

    返回上一次连接错误的错误代码: <?php $con=mysqli_connect("localhost","wrong_user","my_p ...