1. #!/usr/bin/env python
  2. # coding=utf-8
  3.  
  4. from keras.models import Sequential
  5. from keras.layers import Activation, TimeDistributed, Dense, RepeatVector, recurrent
  6. import numpy as np
  7. import string
  8. import random
  9.  
  10. class CharacterTable(object):
  11.  
  12. def __init__(self, maxlen):
  13. self.chars = string.digits + '+ '
  14. self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
  15. self.indice_chars = dict((i, c) for i, c in enumerate(self.chars))
  16. self.maxlen = maxlen
  17.  
  18. def encode(self, strs, maxlen=None):
  19. maxlen = maxlen if maxlen else self.maxlen
  20. vec = np.zeros((maxlen, len(self.chars)))
  21. for i, c in enumerate(strs):
  22. vec[i, self.char_indices[c]] = 1
  23. return vec
  24.  
  25. def decode(self, vec, calc_argmax=True):
  26. if calc_argmax:
  27. vec = vec.argmax(axis=-1)
  28. return ''.join(self.indice_chars[x] for x in vec)
  29.  
  30. def gen_num():
  31. nums = random.sample('', random.randint(1, 3))
  32. return int(''.join(nums))
  33.  
  34. MAXLEN = 7 # 3+3+1
  35. ctable = CharacterTable(MAXLEN)
  36.  
  37. questions, expected = [], []
  38. seen = set()
  39. i = 0
  40. while i < 50000:
  41. a, b = gen_num(), gen_num()
  42. key = tuple(sorted((a, b)))
  43. if key in seen:
  44. continue
  45. seen.add(key)
  46. q = '{}+{}'.format(a, b)
  47. query = q + ' '*(7-len(q))
  48. ans = str(a+b)
  49. ans += ' ' * (4-len(ans))
  50.  
  51. questions.append(query)
  52. expected.append(ans)
  53. i += 1
  54. print('total questions', len(questions))
  55.  
  56. X = np.zeros((len(questions), MAXLEN, len(ctable.chars)), dtype=np.bool)
  57. y = np.zeros((len(questions), 4, len(ctable.chars)), dtype=np.bool)
  58.  
  59. for i, sent in enumerate(questions):
  60. X[i] = ctable.encode(sent)
  61.  
  62. for i, sent in enumerate(expected):
  63. y[i] = ctable.encode(sent, 4)
  64.  
  65. model = Sequential()
  66. model.add(recurrent.LSTM(128, input_shape=(7, len(ctable.chars))))
  67. model.add(RepeatVector(4))
  68. model.add(recurrent.LSTM(128, return_sequences=True))
  69. model.add(recurrent.LSTM(128, return_sequences=True))
  70.  
  71. model.add(TimeDistributed(Dense(len(ctable.chars))))
  72. model.add(Activation('softmax'))
  73.  
  74. model.compile(loss='categorical_crossentropy',
  75. optimizer='adam',
  76. metrics=['accuracy'])
  77.  
  78. model.fit(X, y, batch_size=64, nb_epoch=20, validation_split=0.02, verbose=2)
  79.  
  80. # 测试看看
  81. for i in range(10):
  82. ind = np.random.randint(0, len(questions)-5)
  83. x_test, y_test = X[ind:ind+5], y[ind:ind+5]
  84. y_preds = model.predict_classes(x_test, verbose=0)
  85. print('Q', ctable.decode(x_test[0]))
  86. print('T', ctable.decode(y_test[0]))
  87. print('Pred', ctable.decode(y_preds[0], calc_argmax=False))
  88.  
  89. json_string = model.to_json()
  90. with open('rnn_add_model.json', 'wb') as fw:
  91. fw.write(json_string)
  92. model.save_weights('rnn_add_model.h5')

基本是模仿官网例子,精简了一点,训练约1h, 准确率99.6%

rnn实现三位数加法的训练的更多相关文章

  1. GDUFE-OJ 1203x的y次方的最后三位数 快速幂

    嘿嘿今天学了快速幂也~~ Problem Description: 求x的y次方的最后三位数 . Input: 一个两位数x和一个两位数y. Output: 输出x的y次方的后三位数. Sample ...

  2. 【python】题目:有1、2、3、4个数字,能组成多少个互不相同且无重复数字的三位数?都是多少?

    # encoding:utf-8 # p001_1234threeNums.py def threeNums(): '''题目:有1.2.3.4个数字,能组成多少个互不相同且无重复数字的三位数?都是多 ...

  3. 程序设计入门——C语言 第1周编程练习 1逆序的三位数(5分)

    第1周编程练习 查看帮助 返回   第1周编程练习题,直到课程结束之前随时可以来做.在自己的IDE或编辑器中完成作业后,将源代码的全部内容拷贝.粘贴到题目的代码区,就可以提交,然后可以查看在线编译和运 ...

  4. 题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身。例如:153是一个 "水仙花 数 ",因为153=1的三次方+5的三次方+3的三次方。

    题目:打印出所有的 "水仙花数 ",所谓 "水仙花数 "是指一个三位数,其各位数字立方和等于该数本身.例如:153是一个 "水仙花 数 ", ...

  5. C++判断对称三位数素数

    题目内容:判断一个数是否为对称三位数素数.所谓“对称”是指一个数,倒过来还是该数.例如,375不是对称数,因为倒过来变成了573. 输入描述:输入数据含有不多于50个的正整数(0<n<23 ...

  6. HDU_2035——求A^B的最后三位数

    Problem Description 求A^B的最后三位数表示的整数.说明:A^B的含义是“A的B次方”   Input 输入数据包含多个测试实例,每个实例占一行,由两个正整数A和B组成(1< ...

  7. 网易云课堂_程序设计入门-C语言_第一周:简单的计算程序_1逆序的三位数

    1 逆序的三位数(5分) 题目内容: 程序每次读入一个正三位数,然后输出逆序的数字.注意,当输入的数字含有结尾的0时,输出不应带有前导的0.比如输入700,输出应该是7. 输入格式: 每个测试是一个3 ...

  8. js求三位数的和

    例如输入508就输出5+0+8的和13: <!DOCTYPE html> <html lang="en"> <head> <meta ch ...

  9. Java求555 555的约数中最大的三位数。

    package org.llh.test; /** * 求555 555的约数中最大的三位数 * @author llh * */ public class Car { //整数j除以整数i(i≠0) ...

随机推荐

  1. 流程控制之if

    流程控制 假如把写程序比做走路,那我们到现在为止,一直走的都是直路,还没遇到过分叉口,想象现实中,你遇到了分叉口,然后你决定往哪拐必然是有所动机的.你要判断那条岔路是你真正要走的路,如果我们想让程序也 ...

  2. 有关O_APPEND标志和lseek()的使用

    编程之路刚刚开始,错误难免,希望大家能够指出. O_APPEND表示以每次写操作都写入文件的末尾.lseek()可以调整文件读写位置. <<Linux/UNIX系统编程手册>> ...

  3. 使用mongoose连接mongodb(转载文章)

    mongodb数据库 MongoDB是一个高效的基于分布式文件存储的数据库,将数据存储为一个文档,数据结构由键值(key=>value)对组成.MongoDB 文档类似于 JSON 对象.字段值 ...

  4. MySQL--NUMA与MySQL

    ============================================================= NUMA(Non-Uniform Memory Access),非一致性内存 ...

  5. MySQL Binlog信息查看

    ##=====================================## ## 在MySQL内部查看binlog文件列表 ## SHOW BINARY LOGS; ##=========== ...

  6. 枚举 Java Enumeration接口

    Enumation 定义了一些方法,通过这些方法可以枚举对象集合中的元素 如: boolean hasMoreElements() 测试此枚举是否包含更多的元素 object nextElement( ...

  7. Netty 学习资料

    Netty 学习资料 Netty 学习资料 链接网址 说明 Netty 4.x 用户指南 http://wiki.jikexueyuan.com/project/netty-4-user-guide/ ...

  8. webpack 相关插件及作用(表格)

    webpack 相关插件及作用: table th:first-of-type { width: 200px; } table th:nth-of-type(2) { width: 140px; } ...

  9. DNS压力测试

    安装 queryperf cd /usr/local/src wget http://ftp.isc.org/isc/bind9/9.12.1/bind-9.12.1.tar.gz 编译querype ...

  10. Logstash的grok以及Ruby

    logstash的grok插件的用途是提取字段,将非格式的内容进行格式化, input { file { path => "/var/log/http.log" } } fi ...