[Tensorflow] RNN - 01. Spam Prediction with BasicRNNCell
Ref: http://blog.csdn.net/mebiuw/article/details/60780813
Ref: https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767 [Nice]
Ref: https://medium.com/@erikhallstrm/tensorflow-rnn-api-2bb31821b185 [Nice]
Code Analysis
Download and pre-preprocess
# Implementing an RNN in Tensorflow
#----------------------------------
#
# We implement an RNN in Tensorflow to predict spam/ham from texts
#
# Jeffrey: the data process for nlp here is advanced. import os
import re
import io
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from zipfile import ZipFile
import urllib.request from tensorflow.python.framework import ops
ops.reset_default_graph() # Start a graph
sess = tf.Session() # Set RNN parameters
epochs = 30
batch_size = 250
max_sequence_length = 40
rnn_size = 10
embedding_size = 50
min_word_frequency = 10
learning_rate = 0.0005
dropout_keep_prob = tf.placeholder(tf.float32) # Download or open data
data_dir = 'temp'
data_file = 'text_data.txt'
if not os.path.exists(data_dir):
os.makedirs(data_dir) if not os.path.isfile(os.path.join(data_dir, data_file)):
zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
page = urllib.request.urlopen(zip_url)
html_content = page.read()
z = ZipFile(io.BytesIO(html_content)) file = z.read('SMSSpamCollection') # Format Data
text_data = file.decode()
text_data = text_data.encode('ascii',errors='ignore')
text_data = text_data.decode().split('\n') # Save data to text file
with open(os.path.join(data_dir, data_file), 'w') as file_conn:
for text in text_data:
file_conn.write("{}\n".format(text))
else:
# Open data from text file
text_data = []
with open(os.path.join(data_dir, data_file), 'r') as file_conn:
for row in file_conn:
text_data.append(row)
text_data = text_data[:-1] text_data = [x.split('\t') for x in text_data if len(x)>=1]
[text_data_target, text_data_train] = [list(x) for x in zip(*text_data)] # Create a text cleaning function
def clean_text(text_string):
text_string = re.sub(r'([^\s\w]|_|[0-9])+', '', text_string)
text_string = " ".join(text_string.split())
text_string = text_string.lower()
return(text_string) # Clean texts
text_data_train = [clean_text(x) for x in text_data_train] #Jeffrey
#print("[x]:", text_data_train[:10][:10])
#print("[y]:", text_data_target[:10])
Stage result:
print("[x]:", text_data_train[:10])
print("[y]:", text_data_target[:10])
[x]: ['go until jurong point crazy available only in bugis n great world la e buffet cine there got amore wat', 'ok lar joking wif u oni', 'free entry in a wkly comp to win fa cup final tkts st may text fa to to receive entry questionstd txt ratetcs apply overs', 'u dun say so early hor u c already then say', 'nah i dont think he goes to usf he lives around here though', 'freemsg hey there darling its been weeks now and no word back id like some fun you up for it still tb ok xxx std chgs to send to rcv', 'even my brother is not like to speak with me they treat me like aids patent', 'as per your request melle melle oru minnaminunginte nurungu vettam has been set as your callertune for all callers press to copy your friends callertune', 'winner as a valued network customer you have been selected to receivea prize reward to claim call claim code kl valid hours only', 'had your mobile months or more u r entitled to update to the latest colour mobiles with camera for free call the mobile update co free on']
[y]: [1 1 0 1 1 0 1 1 0 0]
Change texts into numeric vectors
# Change texts into numeric vectors
vocab_processor = tf.contrib.learn.preprocessing.VocabularyProcessor(max_sequence_length, min_frequency=min_word_frequency)
text_processed = np.array(list(vocab_processor.fit_transform(text_data_train))) # Shuffle and split data
text_processed = np.array(text_processed)
text_data_target = np.array([1 if x=='ham' else 0 for x in text_data_target])
Stage result: one-hotting encoding
#Jeffrey
#print("[text_processed]:", text_processed.shape)
#print("[text_data_target]:", text_data_target.shape)
##[text_processed]: (5574, 40)
##[text_data_target]: (5574,) #print("[text_processed]:", text_processed)
#print("[text_data_target]:", text_data_target)
[text_processed]:
[[ 44 455 0 ..., 0 0 0]
[ 47 315 0 ..., 0 0 0]
[ 46 465 9 ..., 0 0 0]
...,
[ 0 59 9 ..., 0 0 0]
[ 5 493 108 ..., 0 0 0]
[ 0 40 474 ..., 0 0 0]] [text_data_target]:
[1 1 0 ..., 1 1 1]
Term statistics
shuffled_ix = np.random.permutation(np.arange(len(text_data_target)))
x_shuffled = text_processed[shuffled_ix]
y_shuffled = text_data_target[shuffled_ix] # Split train/test set
ix_cutoff = int(len(y_shuffled)*0.80)
x_train, x_test = x_shuffled[:ix_cutoff], x_shuffled[ix_cutoff:]
y_train, y_test = y_shuffled[:ix_cutoff], y_shuffled[ix_cutoff:] print(vocab_processor.vocabulary_) vocab_size = len(vocab_processor.vocabulary_)
print("Vocabulary Size: {:d}".format(vocab_size))
print("80-20 Train Test split: {:d} -- {:d}".format(len(y_train), len(y_test)))
[text_processed]: (5574, 40)
[text_data_target]: (5574,)
Build Graph
############################################################################### # Create placeholders
x_data = tf.placeholder(tf.int32, [None, max_sequence_length])
y_output = tf.placeholder(tf.int32, [None]) # Create embedding
embedding_mat = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0))
embedding_output = tf.nn.embedding_lookup(embedding_mat, x_data) # Here, this x_data is ids! <---- [termIdx1, termIdx2, ...]
#embedding_output_expanded = tf.expand_dims(embedding_output, -1)
Create our embedding matrix and embedding lookup operation for the x-input data:embedding_mat.
# Define the RNN cell
cell = tf.nn.rnn_cell.BasicRNNCell(num_units = rnn_size)
output, state = tf.nn.dynamic_rnn(cell, embedding_output, dtype=tf.float32)
output = tf.nn.dropout(output, dropout_keep_prob) # parameters are variables, waiting for constant later. # Get output of RNN sequence
output = tf.transpose(output, [1, 0, 2])
last = tf.gather(output, int(output.get_shape()[0]) - 1)
API: rnn_cell
from Tensorflow RNN源代码解析笔记1:RNNCell的基本实现
在Tensorflow中,定义了一个RNNCell的抽象类,具体的所有不同类型的RNN Cell都是基于这个类的.
在Tensorflow中,将会基于整个RNNCell实现一系列常用的RNNCell,比如LSTM和GRU,并且将会支持包含Dropout等在内的特性,同时也支持构建多层的RNN网络。
class BasicRNNCell(RNNCell):
"""The most basic RNN cell. Args:
num_units: int, The number of units in the RNN cell.
activation: Nonlinearity to use. Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
""" def __init__(self, num_units, activation=None, reuse=None):
super(BasicRNNCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._activation = activation or math_ops.tanh @property
def state_size(self):
return self._num_units @property
def output_size(self):
return self._num_units def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
学习参数:
From: YJango的循环神经网络——介绍
所有时刻的权重矩阵都是共享的。这是递归网络相对于前馈网络而言最为突出的优势。
递归神经网络是在时间结构上存在共享特性的神经网络变体。时间结构共享是递归网络的核心中的核心。
h_state:
# Variables
weight = tf.Variable(tf.truncated_normal([rnn_size, 2], stddev=0.1))
bias = tf.Variable(tf.constant(0.1, shape=[2]))
logits_out = tf.nn.softmax(tf.matmul(last, weight) + bias)
# Loss function
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits_out, labels=y_output) # logits=float32, labels=int32
loss = tf.reduce_mean(losses) accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits_out, 1), tf.cast(y_output, tf.int64)), tf.float32)) optimizer = tf.train.RMSPropOptimizer(learning_rate)
train_step = optimizer.minimize(loss) init = tf.initialize_all_variables()
sess.run(init)
###############################################################################
############################################################################### train_loss = []
test_loss = []
train_accuracy = []
test_accuracy = [] # Start training
for epoch in range(epochs): # Shuffle training data
shuffled_ix = np.random.permutation(np.arange(len(x_train)))
# Sort x_train and y_train based on shuffled_ix
x_train = x_train[shuffled_ix]
y_train = y_train[shuffled_ix] num_batches = int(len(x_train)/batch_size) + 1
# TO DO CALCULATE GENERATIONS ExACTLY
# For each batch.
for i in range(num_batches):
# Select train data
min_ix = i * batch_size
max_ix = np.min([len(x_train), ((i+1) * batch_size)])
x_train_batch = x_train[min_ix:max_ix]
y_train_batch = y_train[min_ix:max_ix] # Run train step
train_dict = {x_data: x_train_batch, y_output: y_train_batch, dropout_keep_prob:0.5}
sess.run(train_step, feed_dict=train_dict) # Run loss and accuracy for training
temp_train_loss, temp_train_acc = sess.run([loss, accuracy], feed_dict=train_dict)
train_loss.append(temp_train_loss)
train_accuracy.append(temp_train_acc) # Run Eval Step
test_dict = {x_data: x_test, y_output: y_test, dropout_keep_prob:1.0}
temp_test_loss, temp_test_acc = sess.run([loss, accuracy], feed_dict=test_dict)
test_loss.append(temp_test_loss)
test_accuracy.append(temp_test_acc)
print('Epoch: {}, Test Loss: {:.2}, Test Acc: {:.2}'.format(epoch+1, temp_test_loss, temp_test_acc)) # Plot loss over time
epoch_seq = np.arange(1, epochs+1)
plt.plot(epoch_seq, train_loss, 'k--', label='Train Set')
plt.plot(epoch_seq, test_loss, 'r-', label='Test Set')
plt.title('Softmax Loss')
plt.xlabel('Epochs')
plt.ylabel('Softmax Loss')
plt.legend(loc='upper left')
plt.show() # Plot accuracy over time
plt.plot(epoch_seq, train_accuracy, 'k--', label='Train Set')
plt.plot(epoch_seq, test_accuracy, 'r-', label='Test Set')
plt.title('Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend(loc='upper left')
plt.show()
[Tensorflow] RNN - 01. Spam Prediction with BasicRNNCell的更多相关文章
- tensorflow rnn 最简单实现代码
tensorflow rnn 最简单实现代码 #!/usr/bin/env python # -*- coding: utf-8 -*- import tensorflow as tf from te ...
- TensorFlow (RNN)深度学习 双向LSTM(BiLSTM)+CRF 实现 sequence labeling 序列标注问题 源码下载
http://blog.csdn.net/scotfield_msn/article/details/60339415 在TensorFlow (RNN)深度学习下 双向LSTM(BiLSTM)+CR ...
- TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架
TensorFlow RNN MNIST字符识别演示快速了解TF RNN核心框架 http://blog.sina.com.cn/s/blog_4b0020f30102wv4l.html
- [Tensorflow] RNN - 02. Movie Review Sentiment Prediction with LSTM
From: Predicting Movie Review Sentiment with TensorFlow and TensorBoard Ref: http://www.cnblogs.com/ ...
- [Tensorflow] RNN - 03. MultiRNNCell for Digit Prediction
Ref: http://blog.csdn.net/u014595019/article/details/52759104 Time: 2min Successfully downloaded tra ...
- AI - TensorFlow - 示例01:基本分类
基本分类 基本分类(Basic classification):https://www.tensorflow.org/tutorials/keras/basic_classification Fash ...
- tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】
之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...
- [Tensorflow] RNN - 04. Work with CNN for Text Classification
Ref: Combining CNN and RNN for spoken language identification Ref: Convolutional Methods for Text [1 ...
- TensorFlow RNN 教程和代码
分析: 看 TensorFlow 也有一段时间了,准备按照 GitHub 上的教程,敲出来,顺便整理一下思路. RNN部分 定义参数,包括数据相关,训练相关. 定义模型,损失函数,优化函数. 训练,准 ...
随机推荐
- mysql如何查看数据库的存放位置
使用如下命令: mysql> show global variables like "%datadir%";法一: 数据库文件存放在这个位置, C:\ProgramData\ ...
- NDArray自动求导
NDArray可以很方便的求解导数,比如下面的例子:(代码主要参考自https://zh.gluon.ai/chapter_crashcourse/autograd.html) 用代码实现如下: im ...
- android:碎片的使用方式
介绍了这么多抽象的东西,也是时候应该学习一下碎片的具体用法了.你已经知道,碎 片通常都是在平板开发中才会使用的,因此我们首先要做的就是新建一个平板电脑的模拟 器.由于 4.0 系统的平板模拟器好像存在 ...
- F800上的CPU有多少个core?
本来想看看F800的一个节点上,每个core的cpu的使用情况.使用下面的命令,可以进行查看: top –P 从命令输出来看,好像是有32个core嘛. 使用下面的命令可以看到具体的CPU的信息. d ...
- 子级div相对于父级div位置不变
设置父级div为相对位置 设置子级div为绝对位置 代码如下: <!DOCTYPE html> <html> <head> <meta charset=&qu ...
- Android性能优化-减小图片下载大小
原文链接 https://developer.android.com/topic/performance/network-xfer.html 内容概要 理解图片的格式 PNG JPG WebP 如何选 ...
- Android架构初探
#一 背景点评美团合并之后,业务需要整合,我们部门的几条业务需要往美团平台迁移,为了降低迁移成本,开发和维护成本,以及将来可能要做的单元测试,需要对架构进行相应的调整.之前的代码都堆在Activity ...
- windows多线程同步--事件
推荐参考博客:秒杀多线程第六篇 经典线程同步 事件Event 事件是内核对象,多用于线程间通信,可以跨进程同步 事件主要用到三个函数:CreateEvent,OpenEvent,SetEvent, ...
- 使用VirtualBox在Ubuntu下虚拟Windows XP共享文件夹设置方法
1.首先保证虚拟的Windows XP有虚拟光驱(正常安装的都是有的,因为在ubuntu下一般都是用硬盘虚拟安装的),然后在已经运行起来的Windows XP菜单栏上选择“设备-安装增强功能“,如果不 ...
- 【PMP】项目目标的SMART原则
详细解读 Specific 具体的 用具体的语言清楚的说明要达成的标准. Measureable 可测量的 目标应该是明确的,而不是模糊的.应该有一组明确的数据,作为衡量是否达成目标的依据. Achi ...