TensorFlow技术解析与实战学习笔记(13)------Mnist识别和卷积神经网络AlexNet
一、AlexNet:共8层:5个卷积层(卷积+池化)、3个全连接层,输出到softmax层,产生分类。
论文中lrn层推荐的参数:depth_radius = 4,bias = 1.0 , alpha = 0.001 / 9.0 , beta = 0.75
lrn现在仅在AlexNet中使用,主要是别的卷积神经网络模型效果不明显。而LRN在AlexNet中会让前向和后向速度下降,(下降1/3)。
【训练时耗时是预测的3倍】
代码:
- #加载数据
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets("MNIST_data/",one_hot = True)
- #定义卷积操作
- def conv2d(name , input_x , w , b , stride = 1,padding = 'SAME'):
- conv = tf.nn.conv2d(input_x,w,strides = [1,stride,stride,1],padding = padding , name = name)
- return tf.nn.relu(tf.nn.bias_add(conv,b))
- def max_pool(name , input_x , k=2):
- return tf.nn.max_pool(input_x,ksize = [1,k,k,1],strides = [1,k,k,1],padding = 'SAME' , name = name)
- def norm(name , input_x , lsize = 4):
- return tf.nn.lrn(input_x , lsize , bias = 1.0 , alpha = 0.001 / 9.0 , beta = 0.75 , name = name)
- def buildGraph(x,learning_rate,weight,bias,dropout):
- #############前向传播##################
- #定义网络
- x = tf.reshape(x , [-1,28,28,1])
- #第一层卷积
- with tf.variable_scope('layer1'):
- conv1 = conv2d('conv1',x,weight['wc1'],bias['bc1'])
- pool1 = max_pool('pool1',conv1)
- norm1 = norm('norm1',pool1)
- with tf.variable_scope('layer2'):
- conv2 = conv2d('conv2',norm1,weight['wc2'],bias['bc2'])
- pool2 = max_pool('pool2',conv2)
- norm2 = norm('norm2',pool2)
- with tf.variable_scope('layer3'):
- conv3 = conv2d('conv3',norm2,weight['wc3'],bias['bc3'])
- pool3 = max_pool('pool3',conv3)
- norm3 = norm('norm3',pool3)
- with tf.variable_scope('layer4'):
- conv4 = conv2d('conv4',norm3,weight['wc4'],bias['bc4'])
- with tf.variable_scope('layer5'):
- conv5 = conv2d('conv5',conv4,weight['wc5'],bias['bc5'])
- pool5 = max_pool('pool5',conv5)
- norm5 = norm('norm5',pool5)
- with tf.variable_scope('func1'):
- norm5 = tf.reshape(norm5,[-1,4*4*256])
- fc1 = tf.add(tf.matmul(norm5,weight['wf1']) , bias['bf1'])
- fc1 = tf.nn.relu(fc1)
- #dropout
- fc1 = tf.nn.dropout(fc1,dropout)
- with tf.variable_scope('func2'):
- fc2 = tf.reshape(fc1,[-1,weight['wf1'].get_shape().as_list()[0]])
- fc2 = tf.add(tf.matmul(fc1,weight['wf2']),bias['bf2'])
- fc2 = tf.nn.relu(fc2)
- #dropout
- fc2 = tf.nn.dropout(fc2,dropout)
- with tf.variable_scope('outlayer'):
- out = tf.add(tf.matmul(fc2,weight['w_out']),bias['b_out'])
- return out
- def train(mnist):
- #定义网络的超参数
- learning_rate = 0.001
- training_step = 20000
- batch_size = 128
- #定义网络的参数
- n_input = 784
- n_output = 10
- dropout = 0.75
- #x、y的占位
- x = tf.placeholder(tf.float32,[None,784])
- y = tf.placeholder(tf.float32,[None,10])
- keep_prob = tf.placeholder(tf.float32)
- #权重和偏置的设置
- weight = {
- 'wc1':tf.Variable(tf.truncated_normal([11,11,1,96],stddev = 0.1)),
- 'wc2':tf.Variable(tf.truncated_normal([5,5,96,256],stddev = 0.1)),
- 'wc3':tf.Variable(tf.truncated_normal([3,3,256,384],stddev = 0.1)),
- 'wc4':tf.Variable(tf.truncated_normal([3,3,384,384],stddev = 0.1)),
- 'wc5':tf.Variable(tf.truncated_normal([3,3,384,256],stddev = 0.1)),
- 'wf1':tf.Variable(tf.truncated_normal([4*4*256,4096])),
- 'wf2':tf.Variable(tf.truncated_normal([4096,4096])),
- 'w_out':tf.Variable(tf.truncated_normal([4096,10]))
- }
- bias = {
- 'bc1':tf.Variable(tf.constant(0.1,shape = [96])),
- 'bc2':tf.Variable(tf.constant(0.1,shape =[256])),
- 'bc3':tf.Variable(tf.constant(0.1,shape =[384])),
- 'bc4':tf.Variable(tf.constant(0.1,shape =[384])),
- 'bc5':tf.Variable(tf.constant(0.1,shape =[256])),
- 'bf1':tf.Variable(tf.constant(0.1,shape =[4096])),
- 'bf2':tf.Variable(tf.constant(0.1,shape =[4096])),
- 'b_out':tf.Variable(tf.constant(0.1,shape =[10]))
- }
- out = buildGraph(x,learning_rate,weight,bias,keep_prob)
- ####################后向传播####################
- #定义损失函数
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=out))
- optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss)
- #评估函数
- correction = tf.equal(tf.argmax(out,1),tf.argmax(y,1))
- acc = tf.reduce_mean(tf.cast(correction,tf.float32))
- #####################################开始训练##############################
- init = tf.global_variables_initializer()
- with tf.Session() as sess:
- sess.run(init)
- step = 1
- while step <= training_step:
- batch_x , batch_y = mnist.train.next_batch(batch_size)
- sess.run(out,feed_dict = {x:batch_x,y:batch_y,keep_prob:dropout})
- print(out.shape)
- sess.run(optimizer,feed_dict = {x:batch_x,y:batch_y,keep_prob:dropout})
- if step % 500 == 0:
- loss , acc = sess.run([loss,acc],feed_dict = {x:batch_x,y:batch_y,keep_prob:1})
- print(step,loss,acc)
- step += 1
- print(sess.run(acc,feed_dict = {x:mnist.test.images[:256],y:mnist.test.images[:256],keep_prob:1}))
- if __name__=='__main__':
- train(mnist)
TensorFlow技术解析与实战学习笔记(13)------Mnist识别和卷积神经网络AlexNet的更多相关文章
- TensorFlow技术解析与实战学习笔记(15)-----MNIST识别(LSTM)
一.任务:采用基本的LSTM识别MNIST图片,将其分类成10个数字. 为了使用RNN来分类图片,将每张图片的行看成一个像素序列,因为MNIST图片的大小是28*28像素,所以我们把每一个图像样本看成 ...
- 学习TF:《TensorFlow技术解析与实战》PDF+代码
TensorFlow 是谷歌公司开发的深度学习框架,也是目前深度学习的主流框架之一.<TensorFlow技术解析与实战>从深度学习的基础讲起,深入TensorFlow框架原理.模型构建. ...
- TensorFlow+实战Google深度学习框架学习笔记(12)------Mnist识别和卷积神经网络LeNet
一.卷积神经网络的简述 卷积神经网络将一个图像变窄变长.原本[长和宽较大,高较小]变成[长和宽较小,高增加] 卷积过程需要用到卷积核[二维的滑动窗口][过滤器],每个卷积核由n*m(长*宽)个小格组成 ...
- 《Tensorflow技术解析与实战》第四章
Tensorflow基础知识 Tensorflow设计理念 (1)将图的定义和图的运行完全分开,因此Tensorflow被认为是一个"符合主义"的库 (2)Tensorflow中涉 ...
- 学习笔记TF058:人脸识别
人脸识别,基于人脸部特征信息识别身份的生物识别技术.摄像机.摄像头采集人脸图像或视频流,自动检测.跟踪图像中人脸,做脸部相关技术处理,人脸检测.人脸关键点检测.人脸验证等.<麻省理工科技评论&g ...
- 机器学习实战 - 读书笔记(13) - 利用PCA来简化数据
前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习心得,这次是第13章 - 利用PCA来简化数据. 这里介绍,机器学习中的降维技术,可简化样品数据. ...
- SQL反模式学习笔记13 使用索引
目标:优化性能 改善性能最好的技术就是在数据库中合理地使用索引. 索引也是数据结构,它能使数据库将指定列中的某个值快速定位在相应的行. 反模式:无规划的使用索引 1.不使用索引或索引不足 2.使用了 ...
- Elasticsearch技术解析与实战 PDF (内含目录)
Elasticsearch技术解析与实战 介绍: Elasticsearch是一个强[0大0]的搜索引擎,提供了近实时的索引.搜索.分 ...
- elasticsearch技术解析与实战ES
elasticsearch技术解析与实战ES 下载地址: https://pan.baidu.com/s/1NpPX05C0xKx_w9gBYaMJ5w 扫码下面二维码关注公众号回复100008 获取 ...
随机推荐
- 【例题4-2 uva489】Hangman Judge
[链接] 我是链接,点我呀:) [题意] 在这里输入题意 [题解] 水题. 中间就赢了算赢.(重复说,算错 [代码] #include <bits/stdc++.h> using name ...
- 0818基于360开源数据库流量审计MySQL Sniffer
开源数据库流量审计MySQL Sniffer 我最推崇的数据库安全产品就是基于流量的数据库审计,因为它不需要更改网络结构,并且也是最关键的是,不影响数据库服务器性能,不用苦口婆心的劝数据库管理员安装监 ...
- asp.net--webconfg指南
原文链接 花了点时间整理了一下ASP.NET Web.config配置文件的基本使用方法.很适合新手参看,由于Web.config在使用很灵活,可以自定义一些节点.所以这里只介绍一些比较常用的节点. ...
- POJ 1084
WA了好久,第一次用重覆盖的模型做题.感觉这题有个陷阱,那就是当去掉某些边后,若因为这个边去掉而被破环的正方形还存在,那么就会造成覆盖不完全,WA. 所以,在去掉边后,必定有些正方形是不存在的,须重新 ...
- [Javascript Crocks] Safely Access Object Properties with `prop`
In this lesson, we’ll use a Maybe to safely operate on properties of an object that could be undefin ...
- 怎样手动的干净的删除linux上的ORACLE数据库
近期在用VMWARE虚拟机做ORACLE的数据库实验.我们都知道在WINDOWS上,我能够到加入删除程序里去自己主动删除已经安装的全部的应用程序.可是在LINUX上没有这个服务能够进行自己主动的删除. ...
- Notepad++ 设置执行 lua 和 python
Notepad++ 设置执行 lua 和 python 一.设置 run -> 设置 cmd /k lua "$(FULL_CURRENT_PATH)" & PAUS ...
- Codeforces 115A- Party(DFS)
A. Party time limit per test 3 seconds memory limit per test 256 megabytes input standard input outp ...
- WndProc函数参数列表
protected override void WndProc(ref Message m) 实现了这一点. 重写WndProc函数,可以捕捉所有窗口发生的消息.这样,我们就可以"篡改&qu ...
- luogu1120 小木棍【数据加强版】 暴力剪枝
题目大意 乔治有一些同样长的小木棍,他把这些木棍随意砍成几段,直到每段的长都不超过50.现在,他想把小木棍拼接成原来的样子,但是却忘记了自己开始时有多少根木棍和它们的长度.给出每段小木棍的长度,编程帮 ...