使用线性回归识别手写阿拉伯数字mnist数据集
学习了tensorflow的线性回归。
首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室
import tensorflow as tf
import numpy as np
class linearRegressionModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self._index_in_epoch=0
self.constructModel()
self.sess=tf.Session()
self.sess.run(tf.global_variables_initializer())
#权重初始化
def weight_variable(self,shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#偏置项初始化
def bais_variable(self,shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#获取数据块,每次选100个样本,如果选完,则重新打乱
def next_batch(self,batch_size):
start=self._index_in_epoch
self._index_in_epoch+=batch_size
if self._index_in_epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self._datas=self._datas[perm]
self._labels=self._labels[perm]
start=0
self._index_in_epoch=batch_size
assert batch_size<=self._num_datas
end=self._index_in_epoch
return self._datas[start:end],self._labels[start:end]
def constructModel(self):
self.x=tf.placeholder(tf.float32,[None,self.x_dimen])
self.y=tf.placeholder(tf.float32,[None,1])
self.w=self.weight_variable([self.x_dimen,1])
self.b=self.bais_variable([1])
self.y_prec=tf.nn.bias_add(tf.matmul(self.x,self.w),self.b)
mse=tf.reduce_mean(tf.squared_difference(self.y_prec,self.y))
l2=tf.reduce_mean(tf.square(self.w))
#self.loss=mse+0.15*l2
self.loss=mse
self.train_step=tf.train.AdamOptimizer(0.1).minimize(self.loss)
def train(self,x_train,y_train,x_test,y_test):
self._datas=x_train
self._labels=y_train
self._num_datas=x_train.shape[0]
for i in range(5000):
batch=self.next_batch(100)
self.sess.run(self.train_step,
feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if i%10==0:
train_loss=self.sess.run(self.loss,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
print("setp %d,test_loss %f"%(i,train_loss))
def predict_batch(self,arr,batchsize):
for i in range(0,len(arr),batchsize):
yield arr[i:i+batchsize]
def predict(self,x_predict):
pred_list=[]
for x_test_batch in self.predict_batch(x_predict,100):
pred =self.sess.run(self.y_prec,{self.x:x_test_batch})
pred_list.append(pred)
return np.vstack(pred_list)
仿照这个代码,联系使用线性回归的方法对mnist进行训练。开始选择学习率为0.1,结果训练失败,调节学习率为0.01.正确率在0.91左右
给出训练类:
import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm,:]
self.lables=self.lables[perm,:]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end,:],self.lables[start:end,:]
def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
self.lables=y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={ self.x: batch[0],
# self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))
然后是调用方法,包括了对这个mnist数据集的下载
from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) if __name__=='__main__': x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
linear = mlm(len(x_train[1]))
linear.train(x_train,y_train,x_test,y_test)
下载方法来自tensorflow的官方文档中文版
使用线性回归识别手写阿拉伯数字mnist数据集的更多相关文章
- stanford coursera 机器学习编程作业 exercise 3(使用神经网络 识别手写的阿拉伯数字(0-9))
本作业使用神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于使用逻辑回归实现多分类问题:识别手写的阿拉伯数字(0-9),请参考:http://www.cnblogs.com ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
- TensorFlow实战之Softmax Regression识别手写数字
关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...
- Tensorflow搭建卷积神经网络识别手写英语字母
更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- 3 TensorFlow入门之识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
- 如何用卷积神经网络CNN识别手写数字集?
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
随机推荐
- 对IT技术开发职业生涯的思考
对职业生涯的思考 从刚毕业到目前所在公司,差不多6年了,想想这六年里面,自己的能力和刚毕业比有了很大的提升,但是现在在什么能力上,我不知道,毕竟没有去过别的公司.最近也在思考自己未来,算是比较迷茫阶段 ...
- Sentinel 简介与API订阅发布
Sentinel 简介 Redis 的 Sentinel 系统用于管理多个 redis 服务器(instance), 该系统执行以下三个任务: 监控(Monitoring): Sentinel 会不断 ...
- 实体格式化转xml
In the past, I've done the following to control datetime serialization: Ignore the DateTime property ...
- 解决Winform程序在不同分辨率系统下界面混乱
问题分析: 产生界面混乱的主要原因是,winform程序的坐标是基于点(Point)的,而Point又与DPI相关,具体就是 一英寸 =72Points 一英寸 = 96pixel ...
- QQMacMgr for Mac(腾讯电脑管家)安装
1.软件简介 腾讯电脑管家是 macOS 系统上一款由腾讯公司带来到的安全管理软件.功能有垃圾清理.软件仓库.小火箭加速和防钓鱼等.而在视觉 UI 上,导入星空概念,操作过场动画全部以星空为题材 ...
- [aaronyang原创] Mssql 一张表3列的sql面试题,看你sql学的怎么样
文章已经迁移到:http://www.ayjs.net/post/99.html 文章已经迁移到:http://www.ayjs.net/post/99.html 文章已经迁移到:http://www ...
- 在CentOS上编译安装MySQL 5.7.13步骤详解
MySQL 5.7主要特性 更好的性能 对于多核CPU.固态硬盘.锁有着更好的优化,每秒100W QPS已不再是MySQL的追求,下个版本能否上200W QPS才是用户更关心的. 更好的InnoDB存 ...
- 项目记录25--unity-tolua框架 View02---BasePanel.lua
还在,还在. ... . 每天晚上找点时间写点点,多了也不想学到底是什么心理啊. 写完看电影去. 今天写两个算超完毕了BaseUI.lua,UIManager.lua(完好中这个) local Bas ...
- jlink下载不进去程序
- 三角函数 与 JavaScript
三角函数 canvas 和 JavaScript 中所有与角相关的API如Math.sin().Math.cos().Math.tan(),都需要以弧度为单位值.但大部分人还是习惯以角度单位.所以 ...