# coding:utf8
import numpy as np
import cPickle
import os
import tensorflow as tf class SoftMax:
def __init__(self,MAXT=30,step=0.0025):
self.MAXT = MAXT
self.step = step def load_theta(self,datapath="data/softmax.pkl"):
self.theta = cPickle.load(open(datapath,'rb')) def process_train(self,data,label,typenum=10,batch_size=500):
batches = data.shape[0] / batch_size
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = tf.Variable(tf.zeros([valuenum,typenum]))
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
cross_entropy = -tf.reduce_sum(y_*tf.log(y)) #交叉熵
train_step = tf.train.GradientDescentOptimizer(self.step).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for epoch in range(self.MAXT):
cost_=[]
for index in xrange(batches):
c_,_=sess.run([cross_entropy,train_step], feed_dict={ x: data[index * batch_size: (index + 1) * batch_size],
y_: label[index * batch_size: (index + 1) * batch_size]})
cost_.append(c_)
if epoch % 5 == 0:
print(( 'epoch %i, minibatch %i/%i,averange cost is %f') %
(epoch,index + 1,batches,np.mean(cost_)))
self.theta=sess.run(theta)
if not os.path.exists('data/softmax.pkl'):
f= open("data/softmax.pkl",'wb')
cPickle.dump(self.theta,f)
f.close()
return self.theta def process_test(self,data,label,typenum=10):
valuenum=data.shape[1]
if len(label.shape)==1:
label=self.reshape_data(label,typenum)
x = tf.placeholder("float", [None,valuenum])
theta = self.theta
y = tf.nn.softmax(tf.matmul(x,theta))
y_ = tf.placeholder("float", [None, typenum])
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print "Accuracy: ",sess.run(accuracy, feed_dict={x: data,y_: label}) def h(self,x):
m = np.exp(np.dot(x,self.theta))
sump = np.sum(m,axis=1)
return m/sump def predict(self,x):
return np.argmax(self.h(x),axis=1) def reshape_data(self,label,typenum):
label_=[]
for yl_ in label:
tl_=np.zeros(typenum)
tl_[yl_]=1.0
label_.append(tl_)
return np.mat(label_) if __name__ == '__main__':
f = open('mnist.pkl', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
training_inputs = [np.reshape(x, 784) for x in training_data[0]]
data = np.array(training_inputs)
training_inputs = [np.reshape(x, 784) for x in validation_data[0]]
vdata = np.array(training_inputs)
f.close() softmax = SoftMax()
softmax.process_train(data,training_data[1])
softmax.process_test(vdata,validation_data[1]) #Accuracy: 0.9269
softmax.process_test(data,training_data[1]) #Accuracy: 0.92718

Softmax回归(使用tensorflow)的更多相关文章

  1. 手写数字识别 ----Softmax回归模型官方案例注释(基于Tensorflow,Python)

    # 手写数字识别 ----Softmax回归模型 # regression import os import tensorflow as tf from tensorflow.examples.tut ...

  2. TensorFlow实现Softmax回归(模型存储与加载)

    # -*- coding: utf-8 -*- """ Created on Thu Oct 18 18:02:26 2018 @author: zhen "& ...

  3. 利用TensorFlow识别手写的数字---基于Softmax回归

    1 MNIST数据集 MNIST数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10类,分别对应从0-9,共10个阿拉伯数字.原始的MNIST数据库一共包含下面4个文件,见下表. 训练图像一 ...

  4. 统计学习方法:罗杰斯特回归及Tensorflow入门

    作者:桂. 时间:2017-04-21  21:11:23 链接:http://www.cnblogs.com/xingshansi/p/6743780.html 前言 看到最近大家都在用Tensor ...

  5. 使用Softmax回归将神经网络输出转成概率分布

    神经网络解决多分类问题最常用的方法是设置n个输出节点,其中n为类别的个数.对于每一个样例,神经网络可以得到一个n维数组作为输出结果.数组中的每一个维度(也就是每一个输出节点)对应一个类别,通过前向传播 ...

  6. Haskell手撸Softmax回归实现MNIST手写识别

    Haskell手撸Softmax回归实现MNIST手写识别 前言 初学Haskell,看的书是Learn You a Haskell for Great Good, 才刚看到Making Our Ow ...

  7. Softmax回归

    Reference: http://ufldl.stanford.edu/wiki/index.php/Softmax_regression http://deeplearning.net/tutor ...

  8. Softmax回归(Softmax Regression)

    转载请注明出处:http://www.cnblogs.com/BYRans/ 多分类问题 在一个多分类问题中,因变量y有k个取值,即.例如在邮件分类问题中,我们要把邮件分为垃圾邮件.个人邮件.工作邮件 ...

  9. DeepLearning之路(二)SoftMax回归

    Softmax回归   1. softmax回归模型 softmax回归模型是logistic回归模型在多分类问题上的扩展(logistic回归解决的是二分类问题). 对于训练集,有. 对于给定的测试 ...

  10. Machine Learning 学习笔记 (3) —— 泊松回归与Softmax回归

    本系列文章允许转载,转载请保留全文! [请先阅读][说明&总目录]http://www.cnblogs.com/tbcaaa8/p/4415055.html 1. 泊松回归 (Poisson ...

随机推荐

  1. hdu 2097

    ps:WA了两次好像....Sky数是三个进制下的各位数之和相等...而不是都等于22...我傻逼了... 代码: #include "stdio.h" int inp(int a ...

  2. tornado初步 ppt分享

    组内的tornado分享,初步: http://files.cnblogs.com/files/yuhan-TB/tornado.pptx

  3. ios 从网络上获取图片并在UIImageView中显示

    ios 从网络上获取图片   -(UIImage *) getImageFromURL:(NSString *)fileURL { NSLog(@"执行图片下载函数"); UIIm ...

  4. php大力力 [001节]2015-08-21.php在百度文库的几个基础教程新手上路日记 大力力php 大力同学 2015-08-21 15:28

    php大力力 [001节]2015-08-21.php在百度文库的几个基础教程新手上路日记 大力力php 大力同学 2015-08-21 15:28 话说,嗯嗯,就是我自己说,做事认真要用表格,学习技 ...

  5. Interview----最长连续乘积字串

    题目描述: 给一个浮点数序列,取最大乘积连续子串的值,例如 -2.5,4,0,3,0.5,8,-1,则取出的最大乘积连续子串为3,0.5,8. 也就是说,上述数组中,3 0.5 8这3个数的乘积3*0 ...

  6. linux命令:rm

    1.介绍: rm用来删除文件或者目录,对于链接文件,只删除了链接,不删除源文件.rm是一个非常危险的命令,像rm -rf /这个命令运行后,后果不堪设想. 2.命令格式: rm [选项] 文件/目录 ...

  7. C#里partial关键字的作用(转摘)

    C#里partial关键字的作用(转摘) 1. 什么是局部类型? C# 2.0 引入了局部类型的概念.局部类型允许我们将一个类.结构或接口分成几个部分,分别实现在几个不同的.cs文件中. 局部类型适用 ...

  8. Java多线程的实现

    记得面试的时候,面试官问了Java多线程实现的方式有几种,它们之间的区别是什么?作为一个Java新手,将最近的学习总结如下: 1.Java多线程实现方式 Java多线程实现方式主要有三种:继承Thre ...

  9. Winform 关于委托与Invoke和Begin Invoke的使用

    这方面的文章已经写得很详细了,特地摘引两篇文章 http://www.cnblogs.com/c2303191/articles/826571.html http://www.cnblogs.com/ ...

  10. 开机流程与主引导分区(MBR)——鸟哥私房菜

    在前篇随笔中,已经谈到了CMOS与BIOS,CMOS是记录各项硬件参数(包括系统时间.设备的I/O地址.CPU的电压和频率等)且嵌入到主板上面的存储器,BIOS是一个写入到主板上的韧体(韧体是写入到硬 ...