这篇文章解释了底部链接的代码。

问题描述

如上图所示,有一些点位于单位正方形内,并做好了标记。要求找到一条线,作为分类的标准。这些点的数据在 inearly_separable_data.csv 文件内。

思路

最初的 SVM 可以形式化为如下:
\[\begin{equation}\min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2\\s.t.\ y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b)\geqslant 1,\ i = 1,2,\cdots ,m.\end{equation} \]

引入软间隔,可以在一定情况下避免过拟合的问题。
引入软间隔之后,问题转化为
\[\begin{equation}
\min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2 + C \sum_{i=1}^{N}max(0,1-y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b))
\end{equation}\]

代码

主要代码在 linear_svm.py 内,plot_boundary_on_data.py 负责画图。

一、引入库和声明

import tensorflow as tf
import numpy as np
import scipy.io as io
from matplotlib import pyplot as plt
import plot_boundary_on_data

二、 定义一些变量

# Global variables.
BATCH_SIZE = 100 # The number of training examples to use per training step. # Define the flags useable from the command line.
tf.app.flags.DEFINE_string('train', None,
'File containing the training data (labels & features).')
tf.app.flags.DEFINE_integer('num_epochs', 1,
'Number of training epochs.')
tf.app.flags.DEFINE_float('svmC', 1,
'The C parameter of the SVM cost function.')
tf.app.flags.DEFINE_boolean('verbose', False, 'Produce verbose output.')
tf.app.flags.DEFINE_boolean('plot', True, 'Plot the final decision boundary on the data.')
FLAGS = tf.app.flags.FLAGS

包括每次训练使用的数据,称为一个 batch,大小定义为 BATCH_SIZE
train 是训练集文件的位置,这里是 inearly_separable_data.csv
num_epochs 是把所有训练集的数据使用几遍。把训练集的数据使用一遍称为一个 epoch。
svmC 即\((2)\)式中 \(C\)的大小。

三、读取训练数据

# Extract it into numpy matrices.
train_data,train_labels = extract_data(train_data_filename) # Convert labels to +1,-1
train_labels[train_labels==0] = -1 # Get the shape of the training data.
train_size,num_features = train_data.shape

读出来的 train_data 是一个 [1000, 2] 的张量,样本的有两个属性,train_labels 是一个 [1000, 1] 的张量。
在读取过程中用到了 numpy 的接口。
标准的 SVM 的标记为 \(\{-1, 1\}\),而文件中标记为 \(\{0, 1\}\)。因此需要做一次转换。

四、构造网络结构

x = tf.placeholder("float", shape=[None, num_features])
y = tf.placeholder("float", shape=[None,1]) W = tf.Variable(tf.zeros([num_features,1]))
b = tf.Variable(tf.zeros([1]))
y_raw = tf.matmul(x,W) + b

线性方程的最终表现形式是 \(\boldsymbol{\omega}^t\boldsymbol{x}+b=0\)。
给定一个样本数据 \(\boldsymbol{x}\),若 \(\boldsymbol{\omega}^t\boldsymbol{x}+b \geqslant 1\),则认为对应的分类为 1,然后和样本的标记对比,若标记为1,则分类正确;否则,分类错误。
若 \(\boldsymbol{\omega}^t\boldsymbol{x}+b \leqslant 1\),则认为对应的分类为 -1,然后和样本的标记对比,若标记为-1,则分类正确;否则,分类错误。
最终要求解的值是一个 shape 为 [2, 1] 的张量 \(W\) 和一个标量 \(b\)。
y_raw 是向量机判定的输出。

五、构造优化目标

regularization_loss = 0.5*tf.reduce_sum(tf.square(W))
hinge_loss = tf.reduce_sum(tf.maximum(tf.zeros([BATCH_SIZE,1]),
1 - y*y_raw));
svm_loss = regularization_loss + svmC*hinge_loss;
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(svm_loss)

即 \( \min_{\boldsymbol{\omega,b}}\frac{1}{2}\|\boldsymbol{\omega}\|^2 + C \sum_{i=1}^{N}max(0,1-y_i(\boldsymbol{\omega}^T\boldsymbol{x}_i+b))\) 的代码表示。
指定用梯度下降法最小化 svm_loss

六、用精度来评价模型的好坏

predicted_class = tf.sign(y_raw);
correct_prediction = tf.equal(y,predicted_class)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

如果 y_raw和样本的标记 y 同符号,即认为预测正确。用预测正确的比例来评价模型的好坏。

七、用数据训练模型

with tf.Session() as s:
# Run all the initializers to prepare the trainable parameters.
tf.initialize_all_variables().run() # Iterate and train.
for step in xrange(num_epochs * train_size // BATCH_SIZE):
offset = (step * BATCH_SIZE) % train_size
batch_data = train_data[offset:(offset + BATCH_SIZE), :]
batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
train_step.run(feed_dict={x: batch_data, y: batch_labels})
print 'loss: ', svm_loss.eval(feed_dict={x: batch_data, y: batch_labels})

首先启动一个 session,每次取 BATCH_SIZE 个数据来训练模型。即用batch_databatch_lables来训练一次,每次得到一个 svm_loss 的值。

运行结果

python linear_svm.py --train linearly_separable_data.csv --svmC 1 --verbose True --num_epochs 10

运行以上命令,指定把数据使用10轮,一次使用100个数据,因此可以得到100次迭代的结果。最后得到的结果及精度如下:

思考

  1. 指定 BATCH_SIZEnum_epochs 是为了减少计算量。
    根据数学理论,应该在整个训练数据集上进行梯度下降法的迭代,每一步迭代都应该选取所有训练数据集的样本。但是这样子做计算量太大,于是在每一次迭代时选用训练数据集的一部分作为输入。
    这么做要求每一步迭代选取的数据子集的分布和总体分布一致,否则得不到正确的结果。

参考

用 TensorFlow 实现 SVM 分类问题的更多相关文章

  1. SVM原理以及Tensorflow 实现SVM分类(附代码)

    1.1. SVM介绍 1.2. 工作原理 1.2.1. 几何间隔和函数间隔 1.2.2. 最大化间隔 - 1.2.2.0.0.1. \(L( {x}^*)\)对$ {x}^*$求导为0 - 1.2.2 ...

  2. Relation Extraction中SVM分类样例unbalance data问题解决 -松弛变量与惩罚因子

    转载自:http://blog.csdn.net/yangliuy/article/details/8152390 1.问题描述 做关系抽取就是要从产品评论中抽取出描述产品特征项的target短语以及 ...

  3. SVM-支持向量机(二)非线性SVM分类

    非线性SVM分类 尽管SVM分类器非常高效,并且在很多场景下都非常实用.但是很多数据集并不是可以线性可分的.一个处理非线性数据集的方法是增加更多的特征,例如多项式特征.在某些情况下,这样可以让数据集变 ...

  4. SVM-支持向量机(一)线性SVM分类

    SVM-支持向量机 SVM(Support Vector Machine)-支持向量机,是一个功能非常强大的机器学习模型,可以处理线性与非线性的分类.回归,甚至是异常检测.它也是机器学习中非常热门的算 ...

  5. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  6. tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    # Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...

  7. 用tensorflow实现SVM

    环境配置 win10 Python 3.6 tensorflow1.15 scipy matplotlib (运行时可能会遇到module tkinter的问题) sklearn 一个基于Python ...

  8. SVM分类与回归

    SVM(支撑向量机模型)是二(多)分类问题中经常使用的方法,思想比较简单,但是具体实现与求解细节对工程人员来说比较复杂,如需了解SVM的入门知识和中级进阶可点此下载.本文从应用的角度出发,使用Libs ...

  9. VQ结合SVM分类方法

    今天整理资料时,发现了在学校时做的这个实验,当时整个过程过重偏向依赖分类器方面,而又很难对分类器性能进行一定程度的改良,所以最后没有选用这个方案,估计以后也不会接触这类机器学习的东西了,希望它对刚入门 ...

随机推荐

  1. Java中的内存划分

    Java程序在运行时,需要在内存中分配空间.为了提高运行效率,就对数据进行了不同的空间划分.因为每一片区域都有特定的数据处理方式和内存管理方式. 具体分为5种内存空间: 程序计数器:保证线程切换后能恢 ...

  2. C语言程序设计50例(一)(经典收藏)

    [程序1]题目:有1.2.3.4个数字,能组成多少个互不相同且无重复数字的三位数?都是多少?1.程序分析:可填在百位.十位.个位的数字都是1.2.3.4.组成所有的排列后再去 掉不满足条件的排列. # ...

  3. py-函数基础

    定义: 函数是指将一组语句的集合通过一个名字(函数名)封装起来,要想执行这个函数,只需调用其函数名即可 特性: 1.减少重复代码2.使程序变的可扩展3.使程序变得易维护 函数参数 形参变量 只有在被调 ...

  4. UVa 1596 Bug Hunt (STL栈)

    题意:给定两种操作,一种是定义一个数组,另一种是赋值,让你找出哪一步时出错了,出错只有两种,一种是数组越界,另一种是访问未定义变量. 析:当初看到这个题时,感觉好麻烦啊,然后就放过去了,而现在要重新回 ...

  5. spring boot打包后windows启动乱码

    事情的起因什么的就不多表了,直接进入主题... 项目都要上线了,结果发现使用 idea mvn install之后的 jar在windows下启动乱码,而使用idea启动却没有问题!!! 这是神马情况 ...

  6. hdu 1799 循环多少次?

    题目 题意:给出n,m,其中m表示有几层循环,求循环的次数 ①如果代码中出现 for(i=1;i<=n;i++) OP ; 那么做了n次OP运算: ②如果代码中出现 fori=1;i<=n ...

  7. KNN和K-Means的区别

    KNN和K-Means的区别 KNN K-Means 1.KNN是分类算法 2.监督学习 3.喂给它的数据集是带label的数据,已经是完全正确的数据 1.K-Means是聚类算法 2.非监督学习 3 ...

  8. ABP 基础设施层——集成 Entity Framework

    本文翻译自ABP的官方教程<EntityFramework Integration>,地址为:http://aspnetboilerplate.com/Pages/Documents/En ...

  9. ASP.NET MVC 4 中Razor 视图中JS无法调试

    解决方法 1.首先检查IE中这2个属性是否勾选了. 2.选择IE浏览器进行调试,调试方法有2种     A:采用debugger;的方法,如下图所示: 这时不用调试断点就会在debugger位置中命中 ...

  10. 【加密算法】3DES

    一.简介 3DES(或称为Triple DES)是三重数据加密算法(TDEA,Triple Data Encryption Algorithm)块密码的通称.它相当于是对每个数据块应用三次DES加密算 ...