使用线性回归识别手写阿拉伯数字mnist数据集
学习了tensorflow的线性回归。
首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室
import tensorflow as tf
import numpy as np
class linearRegressionModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self._index_in_epoch=0
self.constructModel()
self.sess=tf.Session()
self.sess.run(tf.global_variables_initializer())
#权重初始化
def weight_variable(self,shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#偏置项初始化
def bais_variable(self,shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#获取数据块,每次选100个样本,如果选完,则重新打乱
def next_batch(self,batch_size):
start=self._index_in_epoch
self._index_in_epoch+=batch_size
if self._index_in_epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self._datas=self._datas[perm]
self._labels=self._labels[perm]
start=0
self._index_in_epoch=batch_size
assert batch_size<=self._num_datas
end=self._index_in_epoch
return self._datas[start:end],self._labels[start:end]
def constructModel(self):
self.x=tf.placeholder(tf.float32,[None,self.x_dimen])
self.y=tf.placeholder(tf.float32,[None,1])
self.w=self.weight_variable([self.x_dimen,1])
self.b=self.bais_variable([1])
self.y_prec=tf.nn.bias_add(tf.matmul(self.x,self.w),self.b)
mse=tf.reduce_mean(tf.squared_difference(self.y_prec,self.y))
l2=tf.reduce_mean(tf.square(self.w))
#self.loss=mse+0.15*l2
self.loss=mse
self.train_step=tf.train.AdamOptimizer(0.1).minimize(self.loss)
def train(self,x_train,y_train,x_test,y_test):
self._datas=x_train
self._labels=y_train
self._num_datas=x_train.shape[0]
for i in range(5000):
batch=self.next_batch(100)
self.sess.run(self.train_step,
feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if i%10==0:
train_loss=self.sess.run(self.loss,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
print("setp %d,test_loss %f"%(i,train_loss))
def predict_batch(self,arr,batchsize):
for i in range(0,len(arr),batchsize):
yield arr[i:i+batchsize]
def predict(self,x_predict):
pred_list=[]
for x_test_batch in self.predict_batch(x_predict,100):
pred =self.sess.run(self.y_prec,{self.x:x_test_batch})
pred_list.append(pred)
return np.vstack(pred_list)
仿照这个代码,联系使用线性回归的方法对mnist进行训练。开始选择学习率为0.1,结果训练失败,调节学习率为0.01.正确率在0.91左右
给出训练类:
import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm,:]
self.lables=self.lables[perm,:]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end,:],self.lables[start:end,:]
def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
self.lables=y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={ self.x: batch[0],
# self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))
然后是调用方法,包括了对这个mnist数据集的下载
from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) if __name__=='__main__': x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
linear = mlm(len(x_train[1]))
linear.train(x_train,y_train,x_test,y_test)
下载方法来自tensorflow的官方文档中文版
使用线性回归识别手写阿拉伯数字mnist数据集的更多相关文章
- stanford coursera 机器学习编程作业 exercise 3(使用神经网络 识别手写的阿拉伯数字(0-9))
本作业使用神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于使用逻辑回归实现多分类问题:识别手写的阿拉伯数字(0-9),请参考:http://www.cnblogs.com ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
- TensorFlow实战之Softmax Regression识别手写数字
关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...
- Tensorflow搭建卷积神经网络识别手写英语字母
更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- 3 TensorFlow入门之识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
- 如何用卷积神经网络CNN识别手写数字集?
前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...
随机推荐
- Linux主要shell命令详解(中)
shell中的特殊字符 shell中除使用普通字符外,还可以使用一些具有特殊含义和功能的特殊字符.在使用它们时应注意其特殊的含义和作用范围.下面分别对这些特殊字符加以介绍. 1. 通配符 通配符用于模 ...
- js跨域问题解释 使用jsonp或jQuery的解决方案
js跨域及解决方案 1.什么是跨域 我们经常会在页面上使用ajax请求访问其他服务器的数据,此时,客户端会出现跨域问题. 跨域问题是由于javascript语言安全限制中的同源策略造成的. 简单来说, ...
- 基于24位AD转换模块HX711的重量称量实验(已补充皮重存储,线性温度漂移修正)
转载:http://www.geek-workshop.com/thread-2315-1-1.html 以前在X宝上买过一个称重放大器,180+大洋.原理基本上就是把桥式拉力传感器输出的mV级信号放 ...
- Mac Apache ZooKeeper 配置
1.配置准备工作 1)配置 ZooKeeper 准备工作 下载相关软件 apache-zookeeper-v3.4.10.zip ZooKeeper 官网 ZooKeeper 配置软件下载地址,密码: ...
- 【C语言】练习3-5
题目来源:<The C programming language>中的习题P51 练习2-1: 编写函数itob(n, s, b),将整数n转换为以b为底的数,并将转换结果以字符的形 ...
- [企业化NET]Window Server 2008 R2[3]-SVN 服务端 和 客户端 基本使用
1. 服务器基本安装即问题解决记录 √ 2. SVN环境搭建和客户端使用 2.1 服务端 和 客户端 安装 √ 2.2 项目建立与基本使用 √ 2.3 基本冲突解决, ...
- SharePoint利用HttpModule的Init方法实现全局初始化
接上篇 我们知道,HttpRuntime中会对每一个Request创建一个HttpApplication对象(HttpApplicationFactory从一个HttpApplication池来拿). ...
- 数据库查询实例(包含所有where条件例子)
查询指定列 [例1] 查询全体学生的学号与姓名. SELECT Sno,Sname FROM Student: [例2] 查询全体学生的姓名.学号.所在系. SELECT Sname,Sno,Sdep ...
- proguard的简单配置说明
#需要转换的jar文件路径-injars 'D:\fs-np.jar'#转换后的jar文件名称-outjars 'D:\fs-np-sec.jar' #关联的第三方jar-libraryjars 'C ...
- Maven pom.xml中的元素modules、parent、properties以及import(转)
前言 项目中用到了maven,而且用到的内容不像利用maven/eclipse搭建ssm(spring+spring mvc+mybatis)用的那么简单:maven的核心是pom.xml,那么我就它 ...