#!/usr/bin/env python
# coding=utf-8 from keras.models import Sequential
from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
import numpy as np
import string
import random class CharacterTable(object): def __init__(self, maxlen):
self.chars = string.digits + '+ '
self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
self.indice_chars = dict((i, c) for i, c in enumerate(self.chars))
self.maxlen = maxlen def encode(self, strs, maxlen=None):
maxlen = maxlen if maxlen else self.maxlen
vec = np.zeros((maxlen, len(self.chars)))
for i, c in enumerate(strs):
vec[i, self.char_indices[c]] = 1
return vec def decode(self, vec, calc_argmax=True):
if calc_argmax:
vec = vec.argmax(axis=-1)
return ''.join(self.indice_chars[x] for x in vec) def gen_num():
nums = random.sample('', random.randint(1, 3))
return int(''.join(nums)) MAXLEN = 7 # 3+3+1
ctable = CharacterTable(MAXLEN) questions, expected = [], []
seen = set()
i = 0
while i < 50000:
a, b = gen_num(), gen_num()
key = tuple(sorted((a, b)))
if key in seen:
continue
seen.add(key)
q = '{}+{}'.format(a, b)
query = q + ' '*(7-len(q))
ans = str(a+b)
ans += ' ' * (4-len(ans)) questions.append(query)
expected.append(ans)
i += 1
print('total questions', len(questions)) X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=np.bool)
y = np.zeros((len(questions), 4, len(ctable.chars)), dtype=np.bool) for i, sent in enumerate(questions):
X[i] = ctable.encode(sent) for i, sent in enumerate(expected):
y[i] = ctable.encode(sent, 4) model = Sequential()
model.add(recurrent.LSTM(128, input_shape=(7, len(ctable.chars))))
model.add(RepeatVector(4))
model.add(recurrent.LSTM(128, return_sequences=True))
model.add(recurrent.LSTM(128, return_sequences=True)) model.add(TimeDistributed(Dense(len(ctable.chars))))
model.add(Activation('softmax')) model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy']) model.fit(X, y, batch_size=64, nb_epoch=20, validation_split=0.02, verbose=2) # 测试看看
for i in range(10):
ind = np.random.randint(0, len(questions)-5)
x_test, y_test = X[ind:ind+5], y[ind:ind+5]
y_preds = model.predict_classes(x_test, verbose=0)
print('Q', ctable.decode(x_test[0]))
print('T', ctable.decode(y_test[0]))
print('Pred', ctable.decode(y_preds[0], calc_argmax=False)) json_string = model.to_json()
with open('rnn_add_model.json', 'wb') as fw:
fw.write(json_string)
model.save_weights('rnn_add_model.h5')

基本是模仿官网例子,精简了一点,训练约1h, 准确率99.6%

rnn实现三位数加法的训练的更多相关文章

  1. GDUFE-OJ 1203x的y次方的最后三位数 快速幂

    嘿嘿今天学了快速幂也~~ Problem Description: 求x的y次方的最后三位数 . Input: 一个两位数x和一个两位数y. Output: 输出x的y次方的后三位数. Sample ...

  2. 【python】题目:有1、2、3、4个数字,能组成多少个互不相同且无重复数字的三位数?都是多少?

    # encoding:utf-8 # p001_1234threeNums.py def threeNums(): '''题目:有1.2.3.4个数字,能组成多少个互不相同且无重复数字的三位数?都是多 ...

  3. 程序设计入门——C语言 第1周编程练习 1逆序的三位数(5分)

    第1周编程练习 查看帮助 返回   第1周编程练习题,直到课程结束之前随时可以来做.在自己的IDE或编辑器中完成作业后,将源代码的全部内容拷贝.粘贴到题目的代码区,就可以提交,然后可以查看在线编译和运 ...

  4. 题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身。例如:153是一个 "水仙花 数 ",因为153=1的三次方+5的三次方+3的三次方。

    题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身.例如:153是一个 "水仙花 数 ", ...

  5. C++判断对称三位数素数

    题目内容:判断一个数是否为对称三位数素数.所谓“对称”是指一个数,倒过来还是该数.例如,375不是对称数,因为倒过来变成了573. 输入描述:输入数据含有不多于50个的正整数(0<n<23 ...

  6. HDU_2035——求A^B的最后三位数

    Problem Description 求A^B的最后三位数表示的整数.说明:A^B的含义是“A的B次方”   Input 输入数据包含多个测试实例,每个实例占一行,由两个正整数A和B组成(1< ...

  7. 网易云课堂_程序设计入门-C语言_第一周:简单的计算程序_1逆序的三位数

    1 逆序的三位数(5分) 题目内容: 程序每次读入一个正三位数,然后输出逆序的数字.注意,当输入的数字含有结尾的0时,输出不应带有前导的0.比如输入700,输出应该是7. 输入格式: 每个测试是一个3 ...

  8. js求三位数的和

    例如输入508就输出5+0+8的和13: <!DOCTYPE html> <html lang="en"> <head> <meta ch ...

  9. Java求555 555的约数中最大的三位数。

    package org.llh.test; /** * 求555 555的约数中最大的三位数 * @author llh * */ public class Car { //整数j除以整数i(i≠0) ...

随机推荐

  1. zabbix3.4web界面添加第一台被监控服务器图文教程

    zabbix工具监控服务器是以组的形式来管理,创建单个被监控服务器之前需要先创建一个主机组,然后将被监控机添加到这个组中即可 1 创建主机群组: 2 向主机群组中添加主机 3 向主机中添加模板,选择要 ...

  2. BootStrap------之模态框1

    <!DOCTYPE html> <html lang="zh-cn"> <head> <meta charset="utf-8& ...

  3. LG3809 【模板】后缀排序

    题意 题目背景 这是一道模板题. 题目描述 读入一个长度为 $ n $ 的由大小写英文字母或数字组成的字符串,请把这个字符串的所有非空后缀按字典序从小到大排序,然后按顺序输出后缀的第一个字符在原串中的 ...

  4. 鸟哥的linux私房菜第4版--自学笔记

    -----------------------------------第一章 intel芯片架构 PS:升级电脑还得看看主板是不是适合CPU,主板适合CPU的类型是有限的PS: 现在已经没有北桥了,已 ...

  5. webpack执行命令参数

    在webpack执行命令之后可以添加一些参数,这些参数都有自己的作用,下面是参数列表: $ webpack --config XXX.js //使用另一份配置文件(比如webpack.config2. ...

  6. Unity Shader 入门精要学习 (冯乐乐 著)

    第1篇 基础篇 第1章 欢迎来到Shader的世界 第2章 渲染流水线 第3章 Unity Shader 基础 第4章 学习Shader所需的数学基础 第2篇 初级篇 第5章 开始Unity Shad ...

  7. mysql中 where in 用法详解

    这里分两种情况来介绍 1.in 后面是记录集,如: select  *  from  table  where   uname  in(select  uname  from  user); 2.in ...

  8. HI3518E平台ISP调试环境搭建

    海思的SDK提供了ISP调试的相关工具,降低了IPC的ISP调试的难度.初次搭建ISP调试环境,记录一下. SDK版本:Hi3518_MPP_V1.0.A.0 硬件平台:HI3518E_OV9732 ...

  9. day 32 管道 事件 信号量 进程池

    一.管道(多个时数据不安全)   Pipe 类 (像队列一样,数据只能取走一次) conn1,conn2 = Pipe()     建立管道 .send()   发送 .recv()   接收 二.事 ...

  10. Qt开发问答

    Qt开发问答 1, Difference between Dialog and widget and QMainWindow http://www.qtcentre.org/threads/3465- ...