# encoding: utf-8
import numpy as np
import matplotlib.pyplot as plt
import cPickle
import gzip class SVC(object):
def __init__(self, c=1.0, delta=0.001): # 初始化
self.N = 0
self.delta = delta
self.X = None
self.y = None
self.w = None
self.wn = 0
self.K = np.zeros((self.N, self.N))
self.a = np.zeros((self.N, 1))
self.b = 0
self.C = c
self.stop=1
self.k=0
self.cls=0
self.train_result=[] def kernel_function(self,x1, x2): # 核函数
return np.dot(x1, x2) def kernel_matrix(self, x): # 核矩阵
for i in range(0, len(x)):
for j in range(i, len(x)):
self.K[j][i] = self.K[i][j] = self.kernel_function(self.X[i], self.X[j]) def get_w(self): # 计算更新w
ay = self.a * self.y
w = np.zeros((1, self.wn))
for i in range(0, self.N):
w += self.X[i] * ay[i]
return w def get_b(self, a1, a2, a1_old, a2_old): # 计算更新B
y1 = self.y[a1]
y2 = self.y[a2]
a1_new = self.a[a1]
a2_new = self.a[a2]
b1_new = -self.E[a1] - y1 * self.K[a1][a1] * (a1_new - a1_old) - y2 * self.K[a2][a1] * (
a2_new - a2_old) + self.b
b2_new = -self.E[a2] - y1 * self.K[a1][a2] * (a1_new - a1_old) - y2 * self.K[a2][a2] * (
a2_new - a2_old) + self.b
if (0 < a1_new) and (a1_new < self.C) and (0 < a2_new) and (a2_new < self.C):
return b1_new[0]
else:
return (b1_new[0] + b2_new[0]) / 2.0 def gx(self, x): # 判别函数g(x)
return np.dot(self.w, x) + self.b def satisfy_kkt(self, a): # 判断样本点是否满足kkt条件
index = a[1]
if a[0] == 0 and self.y[index] * self.gx(self.X[index]) > 1:
return 1
elif a[0] < self.C and self.y[index] * self.gx(self.X[index]) == 1:
return 1
elif a[0] == self.C and self.y[index] * self.gx(self.X[index]) < 1:
return 1
return 0 def clip_func(self, a_new, a1_old, a2_old, y1, y2): # 拉格朗日乘子的裁剪函数
if (y1 == y2):
L = max(0, a1_old + a2_old - self.C)
H = min(self.C, a1_old + a2_old)
else:
L = max(0, a2_old - a1_old)
H = min(self.C, self.C + a2_old - a1_old)
if a_new < L:
a_new = L
if a_new > H:
a_new = H
return a_new def update_a(self, a1, a2): # 更新a1,a2
partial_a2 = self.K[a1][a1] + self.K[a2][a2] - 2 * self.K[a1][a2]
if partial_a2 <= 1e-9:
print "error:", partial_a2
a2_new_unc = self.a[a2] + (self.y[a2] * ((self.E[a1] - self.E[a2]) / partial_a2))
a2_new = self.clip_func(a2_new_unc, self.a[a1], self.a[a2], self.y[a1], self.y[a2])
a1_new = self.a[a1] + self.y[a1] * self.y[a2] * (self.a[a2] - a2_new)
if abs(a1_new - self.a[a1]) < self.delta:
return 0
self.a[a1] = a1_new
self.a[a2] = a2_new
self.is_update = 1
return 1 def update(self, first_a): # 更新拉格朗日乘子
for second_a in range(0, self.N):
if second_a == first_a:
continue
a1_old = self.a[first_a]
a2_old = self.a[second_a]
if self.update_a(first_a, second_a) == 0:
return
self.b= self.get_b(first_a, second_a, a1_old, a2_old)
self.w = self.get_w()
self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
self.stop=0 def train(self, x, y, max_iternum=100): # SMO算法
x_len = len(x)
self.X = x
self.N = x_len
self.wn = len(x[0])
self.y = np.array(y).reshape((self.N, 1))
self.K = np.zeros((self.N, self.N))
self.kernel_matrix(self.X)
self.b = 0
self.a = np.zeros((self.N, 1))
self.w = self.get_w()
self.E = [self.gx(self.X[i]) - self.y[i] for i in range(0, self.N)]
self.is_update = 0
for i in range(0, max_iternum):
self.stop=1
data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a))) if x > 0 and x< self.C]
if len(data_on_bound) == 0:
data_on_bound = [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]
for data in data_on_bound:
if self.satisfy_kkt(data) != 1:
self.update(data[1])
if self.is_update == 0:
for data in [[x,y] for x,y in zip(self.a, range(0, len(self.a)))]:
if self.satisfy_kkt(data) != 1:
self.update(data[1])
if self.stop:
break
return self.w, self.b def fit(self,x, y): # 训练模型, 一对一法k(k-1)/2个SVM进行多类分类
self.cls, y = np.unique(y, return_inverse=True)
self.k=len(self.cls)
for i in range(self.k):
for j in range(i):
a,b=self.sub_data(x,y,i,j)
self.train_result.append([i,j,self.train(a,b)]) def predict(self,x_new): # 预测
p=np.zeros(self.k)
for i,j,w in self.train_result:
self.w=w[0]
self.b=w[1]
if self.classfy(x_new)==1:
p[j]+=1
else:
p[i]+=1
return self.cls[np.argmax(p)] def sub_data(self,x,y,i,j): # 数据分类
subx=[]
suby=[]
for a,b in zip(x,y):
if b==i:
subx.append(a)
suby.append(-1)
elif b==j:
subx.append(a)
suby.append(1)
return subx,suby def classfy(self,x_new): # 预测
y_new=self.gx(x_new)
cl = int(np.sign(y_new))
if cl == 0:
cl = 1
return cl def load_data():
f = gzip.open('../data/mnist.pkl.gz', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data) if __name__ == "__main__":
svc = SVC()
np.random.seed(0)
l=1000
training_data, validation_data, test_data = load_data()
svc.fit(training_data[0][:l],training_data[1][:l])
predictions = [svc.predict(a) for a in test_data[0][:l]]
num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1][:l]))
print "%s of %s values correct." % (num_correct, len(test_data[1][:l])) #72/100 #808/1000 #8194/10000(较慢)

使用支持向量机训练mnist数据的更多相关文章

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)

    tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...

  2. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  3. 【OpenCV】opencv3.0中的SVM训练 mnist 手写字体识别

    前言: SVM(支持向量机)一种训练分类器的学习方法 mnist 是一个手写字体图像数据库,训练样本有60000个,测试样本有10000个 LibSVM 一个常用的SVM框架 OpenCV3.0 中的 ...

  4. 使用caffe训练mnist数据集 - caffe教程实战(一)

    个人认为学习一个陌生的框架,最好从例子开始,所以我们也从一个例子开始. 学习本教程之前,你需要首先对卷积神经网络算法原理有些了解,而且安装好了caffe 卷积神经网络原理参考:http://cs231 ...

  5. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  6. LeNet训练MNIST

    jupyter notebook: https://github.com/Penn000/NN/blob/master/notebook/LeNet/LeNet.ipynb LeNet训练MNIST ...

  7. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

  8. TensorFlow训练MNIST数据集(3) —— 卷积神经网络

    前面两篇随笔实现的单层神经网络 和多层神经网络, 在MNIST测试集上的正确率分别约为90%和96%.在换用多层神经网络后,正确率已有很大的提升.这次将采用卷积神经网络继续进行测试. 1.模型基本结构 ...

  9. TensorFlow 训练MNIST数据集(2)—— 多层神经网络

    在我的上一篇随笔中,采用了单层神经网络来对MNIST进行训练,在测试集中只有约90%的正确率.这次换一种神经网络(多层神经网络)来进行训练和测试. 1.获取MNIST数据 MNIST数据集只要一行代码 ...

随机推荐

  1. 2013年9月份第1周51Aspx源码发布详情

    大型B2B家具门户网源码  2013-9-6 [VS2008]功能描述: 1.门户信息管理 安全取数据即使数据库连接中断不会报错 2.稳定性 每句代码经过3次以上检查.此网站还在运营3年了,没有出过问 ...

  2. 获取一个 app 的 URL Scheme 的方法:

    获取一个 app 的 URL Scheme 的方法: 上这个网站 URL Schemes 查一下相应的 app 的 URL Scheme 是否有被收录 第一种方法没找到的话,把相应的 app 的 ip ...

  3. 2016 - 1 - 20 runloop学习

    一:Runloop基本知识 1.本质就是运行循环 2.基本作用: 2.1保证程序持续运行 2.2处理APP中的各种事件:触摸,定时器,selector... 2.3节省CPU资源,系统程序性能:它会让 ...

  4. linux上安装mysql

    linux下mysql 最新版安装图解教程 1.查看当前安装的linux版本 命令:lsb_release -a 如下图所示 通过上图中的数据可以看出安装的版本为RedHat5.4,所以我们需要下载R ...

  5. UIWebView的缓存策略,清除cookie

    缓存策略 NSURLRequestCachePolicy NSURLRequestUseProtocolCachePolicy缓存策略定义在 web 协议实现中,用于请求特定的URL.是默认的URL缓 ...

  6. Linux中的汇编简介

    GNU as汇编语法 GNU汇编语法使用的是AT&T汇编它和Intel汇编的语法主要有以下一些不同: AT&T汇编中的立即操作数前面要加上'$',寄存器操作数名前要加上百分号'%',绝 ...

  7. Android.mk 常用宏和变量

    android ndk开发有一个重要的文件 Android.mk,他虽然重要,但是对它进行深入介绍的文档却比较的少,这里将对Android.mk中常用的宏和变量进行说明: 由于这一部分的内容多,资料零 ...

  8. ubuntu下修改mysql默认字符编码出现的Job failed to start解决办法

    ubuntu下修改mysql默认字符编码出现的Job failed to start解决办法 前几天卸掉了用了好多年的Windows,安装了Ubuntu12.04,就开始各种搭环境.今天装好了MySQ ...

  9. Magento:Paypal付款不成功返回后不要清空购物车产品的解决方案

    经常遇到这个问题,当我们使用第三方支付工具Gateway如paypal支付的时候,如果用户付款不成功或者取消了订单再返回网站时,发现购物车里面的产品已经被清空了,如果是客户主动cancel的还好,但是 ...

  10. 简单的IOS6和IOS7通过图片名适配

    在美工提供图片图片的前提下,只需要下面给UIImage做一个分类,就可以简单的实现在6和7上的图片名字适配. 比如美工在6上面提供的图片叫common_button_big_red_highlight ...