学习了tensorflow的线性回归。

首先是一个sklearn中makeregression数据集,对其进行线性回归训练的例子。来自腾讯云实验室

import tensorflow as tf
import numpy as np
class linearRegressionModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self._index_in_epoch=0
self.constructModel()
self.sess=tf.Session()
self.sess.run(tf.global_variables_initializer())
#权重初始化
def weight_variable(self,shape):
initial=tf.truncated_normal(shape,stddev=0.1)
return tf.Variable(initial)
#偏置项初始化
def bais_variable(self,shape):
initial=tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#获取数据块,每次选100个样本,如果选完,则重新打乱
def next_batch(self,batch_size):
start=self._index_in_epoch
self._index_in_epoch+=batch_size
if self._index_in_epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self._datas=self._datas[perm]
self._labels=self._labels[perm]
start=0
self._index_in_epoch=batch_size
assert batch_size<=self._num_datas
end=self._index_in_epoch
return self._datas[start:end],self._labels[start:end]
def constructModel(self):
self.x=tf.placeholder(tf.float32,[None,self.x_dimen])
self.y=tf.placeholder(tf.float32,[None,1])
self.w=self.weight_variable([self.x_dimen,1])
self.b=self.bais_variable([1])
self.y_prec=tf.nn.bias_add(tf.matmul(self.x,self.w),self.b)
mse=tf.reduce_mean(tf.squared_difference(self.y_prec,self.y))
l2=tf.reduce_mean(tf.square(self.w))
#self.loss=mse+0.15*l2
self.loss=mse
self.train_step=tf.train.AdamOptimizer(0.1).minimize(self.loss)
def train(self,x_train,y_train,x_test,y_test):
self._datas=x_train
self._labels=y_train
self._num_datas=x_train.shape[0]
for i in range(5000):
batch=self.next_batch(100)
self.sess.run(self.train_step,
feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if i%10==0:
train_loss=self.sess.run(self.loss,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
print("setp %d,test_loss %f"%(i,train_loss))
def predict_batch(self,arr,batchsize):
for i in range(0,len(arr),batchsize):
yield arr[i:i+batchsize]
def predict(self,x_predict):
pred_list=[]
for x_test_batch in self.predict_batch(x_predict,100):
pred =self.sess.run(self.y_prec,{self.x:x_test_batch})
pred_list.append(pred)
return np.vstack(pred_list)

仿照这个代码,联系使用线性回归的方法对mnist进行训练。开始选择学习率为0.1,结果训练失败,调节学习率为0.01.正确率在0.91左右

给出训练类:

import tensorflow as tf
import numpy as np
class myLinearModel:
def __init__(self,x_dimen):
self.x_dimen=x_dimen
self.epoch=0
self._num_datas=0
self.datas=None
self.lables=None
self.constructModel()
def get_weiInit(self,shape):
weiInit=tf.truncated_normal(shape)
return tf.Variable(weiInit)
def get_biasInit(self,shape):
biasInit=tf.constant(0.1,shape=shape)
return tf.Variable(biasInit)
def constructModel(self):
self.x = tf.placeholder(dtype=tf.float32,shape=[None,self.x_dimen])
self.y=tf.placeholder(dtype=tf.float32,shape=[None,10])
self.weight=self.get_weiInit([self.x_dimen,10])
self.bias=self.get_biasInit([10])
self.y_pre=tf.nn.softmax(tf.matmul(self.x,self.weight)+self.bias)
self.correct_mat=tf.equal(tf.argmax(self.y_pre,1),tf.argmax(self.y,1))
#self.loss=tf.reduce_mean(tf.squared_difference(self.y_pre,self.y))
self.loss=-tf.reduce_sum(self.y*tf.log(self.y_pre))
self.train_step = tf.train.GradientDescentOptimizer(0.01).minimize(self.loss)
self.accuracy=tf.reduce_mean(tf.cast(self.correct_mat,"float"))
def next_batch(self,batchsize):
start=self.epoch
self.epoch+=batchsize
if self.epoch>self._num_datas:
perm=np.arange(self._num_datas)
np.random.shuffle(perm)
self.datas=self.datas[perm,:]
self.lables=self.lables[perm,:]
start=0
self.epoch=batchsize
end=self.epoch
return self.datas[start:end,:],self.lables[start:end,:]
def train(self,x_train,y_train,x_test,y_test):
self.datas=x_train
self.lables=y_train
self._num_datas=(self.lables.shape[0])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for i in range(5000):
batch=self.next_batch(100)
sess.run(self.train_step,feed_dict={
self.x:batch[0],
self.y:batch[1]
})
if 1:
train_loss = sess.run(self.loss, feed_dict={
self.x: batch[0],
self.y: batch[1]
})
print("setp %d,test_loss %f" % (i, train_loss))
#print("y_pre",sess.run(self.y_pre,feed_dict={ self.x: batch[0],
# self.y: batch[1]}))
#print("*****************weight********************",sess.run(self.weight))
print(sess.run(self.accuracy,feed_dict={self.x:x_test,self.y:y_test}))

然后是调用方法,包括了对这个mnist数据集的下载

from myTensorflowLinearModle import myLinearModel as mlm
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) if __name__=='__main__': x_train,x_test,y_train,y_test=mnist.train.images,mnist.test.images,mnist.train.labels,mnist.test.labels
linear = mlm(len(x_train[1]))
linear.train(x_train,y_train,x_test,y_test)

下载方法来自tensorflow的官方文档中文版

使用线性回归识别手写阿拉伯数字mnist数据集的更多相关文章

  1. stanford coursera 机器学习编程作业 exercise 3(使用神经网络 识别手写的阿拉伯数字(0-9))

    本作业使用神经网络(neural networks)识别手写的阿拉伯数字(0-9) 关于使用逻辑回归实现多分类问题:识别手写的阿拉伯数字(0-9),请参考:http://www.cnblogs.com ...

  2. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  3. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  4. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  5. Tensorflow搭建卷积神经网络识别手写英语字母

    更新记录: 2018年2月5日 初始文章版本 近几天需要进行英语手写体识别,查阅了很多资料,但是大多数资料都是针对MNIST数据集的,并且主要识别手写数字.为了满足实际的英文手写识别需求,需要从训练集 ...

  6. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  7. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  8. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

  9. 如何用卷积神经网络CNN识别手写数字集?

    前几天用CNN识别手写数字集,后来看到kaggle上有一个比赛是识别手写数字集的,已经进行了一年多了,目前有1179个有效提交,最高的是100%,我做了一下,用keras做的,一开始用最简单的MLP, ...

随机推荐

  1. Linux主要shell命令详解(中)

    shell中的特殊字符 shell中除使用普通字符外,还可以使用一些具有特殊含义和功能的特殊字符.在使用它们时应注意其特殊的含义和作用范围.下面分别对这些特殊字符加以介绍. 1. 通配符 通配符用于模 ...

  2. js跨域问题解释 使用jsonp或jQuery的解决方案

    js跨域及解决方案 1.什么是跨域 我们经常会在页面上使用ajax请求访问其他服务器的数据,此时,客户端会出现跨域问题. 跨域问题是由于javascript语言安全限制中的同源策略造成的. 简单来说, ...

  3. 基于24位AD转换模块HX711的重量称量实验(已补充皮重存储,线性温度漂移修正)

    转载:http://www.geek-workshop.com/thread-2315-1-1.html 以前在X宝上买过一个称重放大器,180+大洋.原理基本上就是把桥式拉力传感器输出的mV级信号放 ...

  4. Mac Apache ZooKeeper 配置

    1.配置准备工作 1)配置 ZooKeeper 准备工作 下载相关软件 apache-zookeeper-v3.4.10.zip ZooKeeper 官网 ZooKeeper 配置软件下载地址,密码: ...

  5. 【C语言】练习3-5

     题目来源:<The C programming language>中的习题P51  练习2-1:  编写函数itob(n, s, b),将整数n转换为以b为底的数,并将转换结果以字符的形 ...

  6. [企业化NET]Window Server 2008 R2[3]-SVN 服务端 和 客户端 基本使用

    1.  服务器基本安装即问题解决记录      √ 2.  SVN环境搭建和客户端使用 2.1  服务端 和 客户端 安装    √ 2.2  项目建立与基本使用     √ 2.3  基本冲突解决, ...

  7. SharePoint利用HttpModule的Init方法实现全局初始化

    接上篇 我们知道,HttpRuntime中会对每一个Request创建一个HttpApplication对象(HttpApplicationFactory从一个HttpApplication池来拿). ...

  8. 数据库查询实例(包含所有where条件例子)

    查询指定列 [例1] 查询全体学生的学号与姓名. SELECT Sno,Sname FROM Student: [例2] 查询全体学生的姓名.学号.所在系. SELECT Sname,Sno,Sdep ...

  9. proguard的简单配置说明

    #需要转换的jar文件路径-injars 'D:\fs-np.jar'#转换后的jar文件名称-outjars 'D:\fs-np-sec.jar' #关联的第三方jar-libraryjars 'C ...

  10. Maven pom.xml中的元素modules、parent、properties以及import(转)

    前言 项目中用到了maven,而且用到的内容不像利用maven/eclipse搭建ssm(spring+spring mvc+mybatis)用的那么简单:maven的核心是pom.xml,那么我就它 ...