import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #this is data
mnist = input_data.read_data_sets("MNIST_data",one_hot=True) lr = 0.001
train_iters = 10000
batch_size = 128
display_step = 10 n_inputs = 28
n_steps = 28
n_hidden_unis = 128
n_classes = 10 x = tf.placeholder(tf.float32,[None,n_steps,n_inputs])
y = tf.placeholder(tf.float32,[None,n_classes]) #define weight
weights = {
#(28,128)
"in":tf.Variable(tf.random_normal([n_inputs,n_hidden_unis])),
#(128,10)
"out":tf.Variable(tf.random_normal([n_hidden_unis,n_classes]))
}
biases = {
#(128,)
"in":tf.Variable(tf.constant(0.1,shape=[n_hidden_unis,])),
#(10,)
"out":tf.Variable(tf.constant(0.1,shape=[n_classes,]))
} def RNN(X,weights,biases):
#形状变换成lstm可以训练的维度
X = tf.reshape(X,[-1,n_inputs]) #(128*28,28)
X_in = tf.matmul(X,weights["in"])+biases["in"] #(128*28,128)
X_in = tf.reshape(X_in,[-1,n_steps,n_hidden_unis]) #(128,28,128) #cell
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden_unis,forget_bias=1.0,state_is_tuple=True)
#lstm cell is divided into two parts(c_state,m_state)
_init_state = lstm_cell.zero_state(batch_size,dtype=tf.float32) outputs,states = tf.nn.dynamic_rnn(lstm_cell,X_in,initial_state=_init_state,time_major = False) #outputs
# results = tf.matmul(states[1],weights["out"])+biases["out"]
#or
outputs = tf.transpose(outputs,[1,0,2])
results = tf.matmul(outputs[-1],weights["out"])+biases["out"] return results pred = RNN(x,weights,biases)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(loss) correct_pred = tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32)) init = tf.initialize_all_variables() with tf.Session() as sess:
sess.run(init)
step = 0
while step*batch_size < train_iters:
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape([batch_size,n_steps,n_inputs])
sess.run(train_op,feed_dict={x:batch_xs,y:batch_ys})
if step%20 ==0:
print(sess.run(accuracy,feed_dict={x:batch_xs,y:batch_ys}))

  

tensorflow1.0 构建lstm做图片分类的更多相关文章

  1. tensorflow1.0 构建神经网络做图片分类

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_dat ...

  2. tensorflow1.0 构建神经网络做非线性归回

    """ Please note, this code is only for python 3+. If you are using python 2+, please ...

  3. tensorflow1.0 构建卷积神经网络

    import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import os os.envi ...

  4. 深度学习之神经网络核心原理与算法-caffe&keras框架图片分类

    之前我们在使用cnn做图片分类的时候使用了CIFAR-10数据集 其他框架对于CIFAR-10的图片分类是怎么做的 来与TensorFlow做对比. Caffe Keras 安装 官方安装文档: ht ...

  5. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  6. 5分钟Serverless实践:构建无服务器的图片分类系统

    前言 在过去“5分钟Serverless实践”系列文章中,我们介绍了如何构建无服务器API和Web应用,从本质上来说,它们都属于基于APIG触发器对外提供一个无服务器API的场景.现在本文将介绍一种新 ...

  7. 第二十二节,TensorFlow中的图片分类模型库slim的使用、数据集处理

    Google在TensorFlow1.0,之后推出了一个叫slim的库,TF-slim是TensorFlow的一个新的轻量级的高级API接口.这个模块是在16年新推出的,其主要目的是来做所谓的“代码瘦 ...

  8. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...

  9. 源码分析——迁移学习Inception V3网络重训练实现图片分类

    1. 前言 近些年来,随着以卷积神经网络(CNN)为代表的深度学习在图像识别领域的突破,越来越多的图像识别算法不断涌现.在去年,我们初步成功尝试了图像识别在测试领域的应用:将网站样式错乱问题.无线领域 ...

随机推荐

  1. 仅用200个样本就能得到当前最佳结果:手写字符识别新模型TextCaps

    由于深度学习近期取得的进展,手写字符识别任务对一些主流语言来说已然不是什么难题了.但是对于一些训练样本较少的非主流语言来说,这仍是一个挑战性问题.为此,本文提出新模型TextCaps,它每类仅用200 ...

  2. html-css:浮动_清除浮动

    1.浮动 清除浮动之前我们首先需要了解为什么要清除浮动 1. 假设我们有一个父盒子,不设置高度,其高度有内部子盒子的大小自动撑开,这样是完全可行的,因为有时候我们并不想直接固定死父盒子的大小,而是根据 ...

  3. 已知IP地址和子网掩码求出网络地址、广播地址、地址范围和主机数(转载https://blog.csdn.net/qq_39026548/article/details/78959089)

    假设IP地址为128.11.67.31,子网掩码是255.255.240.0.请算出网络地址.广播地址.地址范围.主机数.方法:将IP地址和子网掩码转化成二进制形式,然后进行后续操作. IP地址和子网 ...

  4. JVM中内存分配策略及堆和栈的比较

    最近愈发对JVM底层的运行 原理产生了兴趣,遂查阅相关资料以备忘. 内存分配策略 根据编译原理的观点,程序运行时的内存分配,有三种策略,分别为静态的.堆式的.栈式的. 静态存储分配指的是在编译时就能确 ...

  5. SQL Server Profiler常用功能

    最近因调研Linq to object 和Linq to Entity的数据组合查询问题,需要用到Sql Server Profiler检测在数据上执行的语句,在调试sql语句时,给了很大的帮助. 这 ...

  6. [转载]Docker容器无法被stop or kill问题

    来源: 问题过程 某环境一个mysql容器无法被stop or kill or rm sudo docker ps | grep mysql 查看该容器 7844250860f8 mysql:5.7. ...

  7. 类实例调用静态方法(Java)

    前言 第一次看到在Java中是可以通过类实例调用静态方法,当然不推荐这么做,接下来会讲到,但是在C#中通过类实例调用静态方法在编译时就不会通过,这里做下记录. 类实例调用静态方法 首先我们来看一个简单 ...

  8. mysql正则匹配中文时存在的问题

    可以看到,目前正则匹配字母没问题,c出现1次,2次,3次匹配的结果都是正常的 接下来我们看看匹配中文的效果 可以看到,当匹配连续出现歪时,结果就开始不正常了 然后我去看了下mysql的中文文档中关于正 ...

  9. redis持久化(RDB、AOF、混合持久化)

    redis持久化(RDB.AOF.混合持久化) 1. RDB快照(snapshot) 在默认情况下, Redis 将内存数据库快照保存在名字为 dump.rdb 的二进制文件中. 你可以对 Redis ...

  10. Centos 7 系统定时重启

    crontab -e         //系统命令 00 08 * * * root systemctl restart docker00 08 * * * root reboot //写入需要重启的 ...