详细解读简单的lstm的实例
http://blog.csdn.net/zjm750617105/article/details/51321889
本文是初学keras这两天来,自己仿照addition_rnn.py,写的一个实例,数据处理稍微有些不同,但是准确性相比addition_rnn.py 差一点,下面直接贴代码,
解释和注释都在代码里边。
- <span style="font-family: Arial, Helvetica, sans-serif;">#coding:utf-8</span>
- from keras.models import Sequential
- from keras.layers.recurrent import LSTM
- from utils import log
- from numpy import random
- import numpy as np
- from keras.layers.core import RepeatVector, TimeDistributedDense, Activation
- '''''
- 先用lstm实现一个计算加法的keras版本, 根据addition_rnn.py改写
- size: 500
- 10次: test_acu = 0.3050 base_acu= 0.3600
- 30次: rest_acu = 0.3300 base_acu= 0.4250
- size: 50000
- 10次: test_acu: loss: 0.4749 - acc: 0.8502 - val_loss: 0.4601 - val_acc: 0.8539
- base_acu: loss: 0.3707 - acc: 0.9008 - val_loss: 0.3327 - val_acc: 0.9135
- 20次: test_acu: loss: 0.1536 - acc: 0.9505 - val_loss: 0.1314 - val_acc: 0.9584
- base_acu: loss: 0.0538 - acc: 0.9891 - val_loss: 0.0454 - val_acc: 0.9919
- 30次: test_acu: loss: 0.0671 - acc: 0.9809 - val_loss: 0.0728 - val_acc: 0.9766
- base_acu: loss: 0.0139 - acc: 0.9980 - val_loss: 0.0502 - val_acc: 0.9839
- '''
- log = log()
- #defination the global variable
- training_size = 50000
- hidden_size = 128
- batch_size = 128
- layers = 1
- maxlen = 7
- single_digit = 3
- def generate_data():
- log.info("generate the questions and answers")
- questions = []
- expected = []
- seen = set()
- while len(seen) < training_size:
- num1 = random.randint(1, 999) #generate a num [1,999]
- num2 = random.randint(1, 999)
- #用set来存储又有排序,来保证只有不同数据和结果
- key = tuple(sorted((num1,num2)))
- if key in seen:
- continue
- seen.add(key)
- q = '{}+{}'.format(num1,num2)
- query = q + ' ' * (maxlen - len(q))
- ans = str(num1 + num2)
- ans = ans + ' ' * (single_digit + 1 - len(ans))
- questions.append(query)
- expected.append(ans)
- return questions, expected
- class CharacterTable():
- '''''
- encode: 将一个str转化为一个n维数组
- decode: 将一个n为数组转化为一个str
- 输入输出分别为
- character_table = [' ', '+', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
- 如果一个question = [' 123+23']
- 那个改question对应的数组就是(7,12):
- 同样expected最大是一个四位数[' 146']:
- 那么ans对应的数组就是[4,12]
- '''
- def __init__(self, chars, maxlen):
- self.chars = sorted(set(chars))
- '''''
- >>> b = [(c, i) for i, c in enumerate(a)]
- >>> dict(b)
- {' ': 0, '+': 1, '1': 3, '0': 2, '3': 5, '2': 4, '5': 7, '4': 6, '7': 9, '6': 8, '9': 11, '8': 10}
- 得出的结果是无序的,但是下面这种方式得出的结果是有序的
- '''
- self.char_index = dict((c, i) for i, c in enumerate(self.chars))
- self.index_char = dict((i, c) for i, c in enumerate(self.chars))
- self.maxlen = maxlen
- def encode(self, C, maxlen):
- X = np.zeros((maxlen, len(self.chars)))
- for i, c in enumerate(C):
- X[i, self.char_index[c]] = 1
- return X
- def decode(self, X, calc_argmax=True):
- if calc_argmax:
- X = X.argmax(axis=-1)
- return ''.join(self.index_char[x] for x in X)
- chars = '0123456789 +'
- character_table = CharacterTable(chars,len(chars))
- questions , expected = generate_data()
- log.info('Vectorization...') #失量化
- inputs = np.zeros((len(questions), maxlen, len(chars))) #(5000, 7, 12)
- labels = np.zeros((len(expected), single_digit+1, len(chars))) #(5000, 4, 12)
- log.info("encoding the questions and get inputs")
- for i, sentence in enumerate(questions):
- inputs[i] = character_table.encode(sentence, maxlen=len(sentence))
- #print("questions is ", questions[0])
- #print("X is ", inputs[0])
- log.info("encoding the expected and get labels")
- for i, sentence in enumerate(expected):
- labels[i] = character_table.encode(sentence, maxlen=len(sentence))
- #print("expected is ", expected[0])
- #print("y is ", labels[0])
- log.info("total inputs is %s"%str(inputs.shape))
- log.info("total labels is %s"%str(labels.shape))
- log.info("build model")
- model = Sequential()
- '''''
- LSTM(output_dim, init='glorot_uniform', inner_init='orthogonal',
- forget_bias_init='one', activation='tanh',
- inner_activation='hard_sigmoid',
- W_regularizer=None, U_regularizer=None, b_regularizer=None,
- dropout_W=0., dropout_U=0., **kwargs)
- output_dim: 输出层的维数,或者可以用output_shape
- init:
- uniform(scale=0.05) :均匀分布,最常用的。Scale就是均匀分布的每个数据在-scale~scale之间。此处就是-0.05~0.05。scale默认值是0.05;
- lecun_uniform:是在LeCun在98年发表的论文中基于uniform的一种方法。区别就是lecun_uniform的scale=sqrt(3/f_in)。f_in就是待初始化权值矩阵的行。
- normal:正态分布(高斯分布)。
- Identity :用于2维方阵,返回一个单位阵.
- Orthogonal:用于2维方阵,返回一个正交矩阵. lstm默认
- Zero:产生一个全0矩阵。
- glorot_normal:基于normal分布,normal的默认 sigma^2=scale=0.05,而此处sigma^2=scale=sqrt(2 / (f_in+ f_out)),其中,f_in和f_out是待初始化矩阵的行和列。
- glorot_uniform:基于uniform分布,uniform的默认scale=0.05,而此处scale=sqrt( 6 / (f_in +f_out)) ,其中,f_in和f_out是待初始化矩阵的行和列。
- W_regularizer , b_regularizer and activity_regularizer:
- 官方文档: http://keras.io/regularizers/
- from keras.regularizers import l2, activity_l2
- model.add(Dense(64, input_dim=64, W_regularizer=l2(0.01), activity_regularizer=activity_l2(0.01)))
- 加入规则项主要是为了在小样本数据下过拟合现象的发生,我们都知道,一半在训练过程中解决过拟合现象的方法主要中两种,一种是加入规则项(权值衰减), 第二种是加大数据量
- 很显然,加大数据量一般是不容易的,而加入规则项则比较容易,所以在发生过拟合的情况下,我们一般都采用加入规则项来解决这个问题.
- '''
- model.add(LSTM(hidden_size, input_shape=(maxlen, len(chars)))) #(7,12) 输入层
- '''''
- keras.layers.core.RepeatVector(n)
- 把1维的输入重复n次。假设输入维度为(nb_samples, dim),那么输出shape就是(nb_samples, n, dim)
- inputshape: 任意。当把这层作为某个模型的第一层时,需要用到该参数(元组,不包含样本轴)。
- outputshape:(nb_samples,nb_input_units)
- '''
- model.add(RepeatVector(single_digit + 1))
- #表示有多少个隐含层
- for _ in range(layers):
- model.add(LSTM(hidden_size, return_sequences=True))
- '''''
- TimeDistributedDense:
- 官方文档:http://keras.io/layers/core/#timedistributeddense
- keras.layers.core.TimeDistributedDense(output_dim,init='glorot_uniform', activation='linear', weights=None
- W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None,
- input_dim=None, input_length=None)
- 这是一个基于时间维度的全连接层。主要就是用来构建RNN(递归神经网络)的,但是在构建RNN时需要设置return_sequences=True。
- for example:
- # input shape: (nb_samples, timesteps,10)
- model.add(LSTM(5, return_sequences=True, input_dim=10)) # output shape: (nb_samples, timesteps, 5)
- model.add(TimeDistributedDense(15)) # output shape:(nb_samples, timesteps, 15)
- W_constraint:
- from keras.constraints import maxnorm
- model.add(Dense(64, W_constraint =maxnorm(2))) #限制权值的各个参数不能大于2
- '''
- model.add(TimeDistributedDense(len(chars)))
- model.add(Activation('softmax'))
- '''''
- 关于目标函数和优化函数,参考另外一片博文: http://blog.csdn.net/zjm750617105/article/details/51321915
- '''
- model.compile(loss='categorical_crossentropy',
- optimizer='adam',
- metrics=['accuracy'])
- # Train the model each generation and show predictions against the validation dataset
- for iteration in range(1, 3):
- print()
- print('-' * 50)
- print('Iteration', iteration)
- model.fit(inputs, labels, batch_size=batch_size, nb_epoch=2,
- validation_split = 0.1)
- # Select 10 samples from the validation set at random so we can visualize errors
- model.get_config()
详细解读简单的lstm的实例的更多相关文章
- Paxos协议超级详细解释+简单实例
转载自: https://blog.csdn.net/cnh294141800/article/details/53768464 Paxos协议超级详细解释+简单实例 Basic-Paxos算法 ...
- MemCache超详细解读
MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于动态Web应用以减轻数据库的负载.它通过在内存中缓存数据和对象来减少读取数据库的次数,从而提高 ...
- MemCache超详细解读 图
http://www.cnblogs.com/xrq730/p/4948707.html MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于 ...
- MemCache详细解读
MemCache是什么 MemCache是一个自由.源码开放.高性能.分布式的分布式内存对象缓存系统,用于动态Web应用以减轻数据库的负载.它通过在内存中缓存数据和对象来减少读取数据库的次数,从而提高 ...
- 【Python】【Web.py】详细解读Python的web.py框架下的application.py模块
详细解读Python的web.py框架下的application.py模块 这篇文章主要介绍了Python的web.py框架下的application.py模块,作者深入分析了web.py的源码, ...
- SpringMVC 原理 - 设计原理、启动过程、请求处理详细解读
SpringMVC 原理 - 设计原理.启动过程.请求处理详细解读 目录 一. 设计原理 二. 启动过程 三. 请求处理 一. 设计原理 Servlet 规范 SpringMVC 是基于 Servle ...
- NLP突破性成果 BERT 模型详细解读 bert参数微调
https://zhuanlan.zhihu.com/p/46997268 NLP突破性成果 BERT 模型详细解读 章鱼小丸子 不懂算法的产品经理不是好的程序员 关注她 82 人赞了该文章 Goo ...
- springmvc 项目完整示例01 需求与数据库表设计 简单的springmvc应用实例 web项目
一个简单的用户登录系统 用户有账号密码,登录ip,登录时间 打开登录页面,输入用户名密码 登录日志,可以记录登陆的时间,登陆的ip 成功登陆了的话,就更新用户的最后登入时间和ip,同时记录一条登录记录 ...
- Android BLE蓝牙详细解读
代码地址如下:http://www.demodashi.com/demo/15062.html 随着物联网时代的到来,越来越多的智能硬件设备开始流行起来,比如智能手环.心率检测仪.以及各式各样的智能家 ...
随机推荐
- 马士兵hadoop第二课:hdfs集群集中管理和hadoop文件操作
马士兵hadoop第一课:虚拟机搭建和安装hadoop及启动 马士兵hadoop第二课:hdfs集群集中管理和hadoop文件操作 马士兵hadoop第三课:java开发hdfs 马士兵hadoop第 ...
- 使用 IntraWeb (15) - 基本控件之 TIWEdit、TIWMemo、TIWText
TIWEdit //单行文本框, 通过 PasswordPrompt 属性可以作为密码框 TIWMemo //多行文本框 TIWText //相当于多行的 TIWLabel 或不能编辑的 TIWMem ...
- net自定义安装程序快捷方式
创建快捷方式对于绝大多数 Windows 用户来说都是小菜一碟了,然而,这项工作却为程序员带来不少麻烦..NET 没有提供简便直接的创建快捷方式的方法,那么在 .NET 中我们如何为应用程序创建快捷方 ...
- J1850 Implement
http://avrobdii.googlecode.com/svn/trunk/code/J1850.c /* Copyright (C) Trampas Stern name of author ...
- The .NET weak event pattern in C#
Introduction As you may know event handlers are a common source of memory leaks caused by the persis ...
- ASP.NET MVC与Sql Server交互,把字典数据插入数据库
在"ASP.NET MVC与Sql Server交互, 插入数据"中,在Controller中拼接sql语句.比如: _db.InsertData("insert int ...
- Unity3D实践系列08, MonoBehaviour类的各种触发事件
在脚本的生命周期中,有Awake, Start, FixedUpdate, Update, LateUpdate等方法,其实这些属于MonoBehaviour类的事件响应方法,是MonoBehavio ...
- 对ORM的支持 之 8.4 集成JPA ——跟我学spring3
8.4 集成JPA JPA全称为Java持久性API(Java Persistence API),JPA是Java EE 5标准之一,是一个ORM规范,由厂商来实现该规范,目前有Hibernate. ...
- 技术人生:special considerations that are very important
For the most part, a lot of what we know about software development can be applied to different envi ...
- 用自定义的RoundImageView来实现圆形图片(可加边框)
本文的控件来自:http://blog.csdn.net/alan_biao/article/details/17379925 这个控件不同于之前介绍过的那个框架,这个仅仅能过将图片裁剪为圆形,没能弄 ...