我不知道什么是爱,但我知道什么是不爱” 

--One Class Learning的自白

 一、单分类简介

  如果将分类算法进行划分,根据类别个数的不同可以分为单分类、二分类、多分类,常见的分类算法主要解决二分类和多分类问题,预测一封邮件是否是垃圾邮件是一个典型的二分类问题,手写体识别是一个典型的多分类问题,这些算法并不能很好的应用在单分类上,但单分类问题在工业界广泛存在,由于每个企业刻画用户的数据都是有限的,很多二分类问题很难找到负样本,比如通过用户的搜索记录预测一个用户是否有小孩,可以通过规则筛选出正样本,如经常搜索“宝宝、早教”之类的词,但很难筛选出合适的负样本,一个用户没有搜索并不代表没有小孩,即便用一些排除法筛选出负样本,负样本也会不纯,不能保证负样本中没有正样本,这样用分类模型跑出的结果存在解释性的困难,预测概率高可以认为是有很有可能有小孩,但概率低是有小孩的的可能小还是无法判断?在只能定义正样本不能定义负样本的场景中,使用单分类算法(One-Class Learning)更适合,单分类算法只关注与样本的相似或匹配程度,对于未知的部分不妄下结论。很多人会说“我不知道什么是爱,但我知道什么是不爱”,他们相亲时会立马排查掉不爱的对象,他们就是一个聪明的One-Class学习器。

  当然分类算法根据记录所属类别个数的不同还可以分为单标签分类和多标签分类,预测一个用户喜欢的球类就是一个多标签多分类问题,这里不多加讨论,本文重点讨论单分类问题。

二、单分类算法介绍

  One Class Learning比较经典的算法是One-Class-SVM[参考文献1],这个算法的思路非常简单,就是寻求一个超平面将样本中的正例圈起来,预测是就用这个超平面做决策,在圈内的样本就认为是正样本。由于核函数计算比较耗时,在海量数据的场景用得并不多;

one class svm示例图(图片引用自[参考文献2])

  

  另外一个算法是基于神经网络的算法,在深度学习中广泛使用的自编码算法可以应用在单分类的问题上[参考文献3],自编码器是一个BP神经网络,网络输入层和输出层是一样,中间层数可以有多层,中间层的节点个数比输出入层少,最简单的情况就是中间只有一个隐藏层,如下图所示,由于中间层的节点数较少,这样中间层相当于是对数据进行了压缩和抽象,实现无监督的方式学习数据的抽象特征。

  如果我们只有正样本数据,没有负样本数据,或者说只关注学习正样本的规律,那么利用正样本训练一个自编码器,编码器就相当于单分类的模型,对全量数据进行预测时,通过比较输入层和输出层的相似度就可以判断记录是否属于正样本。由于自编器采用神经网络实现,可以用GPU来进行加速计算,因此比较适合海量数据的场景。

三、单分类算法实践

  这里使用TensorFlow实践了一下基于自编码的One-Class-Learning,代码引用自[参考文献4],使用MNIST数据进行自编码,首先用0-9这10个类别的数据进行训练,输入层和输出层节点数为784,中间层有两层,节点数分别为256、128,下图中第一排是原始图片,第二排是自编码后解码的图片,可以认为第二排图片是128个单元节点的压缩表示。

  假设现在是一个单分类问题,我们只有部分数字7的标注数据,现在需要预测一张图片是不是数据7,那么可以用这部分标注数据来训练一个编码器,让编码器来发现数字7的内在规律,预测时通过编码器进行计算,比较输入层与输出层的相似度,如果相似度较高就可以认为是数字7,编码和解码后的图片示例如下:

实现代码如下:

 import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt # Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data # Parameters
learning_rate = 0.008
training_epochs = 130
batch_size = 2560 # Building the encoder
def encoder(x,weights,biases):
# Encoder Hidden layer with sigmoid activation #1
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
biases['encoder_b1']))
# Decoder Hidden layer with sigmoid activation #2
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
biases['encoder_b2']))
return layer_2 # Building the decoder
def decoder(x,weights,biases):
# Encoder Hidden layer with sigmoid activation #1
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
biases['decoder_b1']))
# Decoder Hidden layer with sigmoid activation #2
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
biases['decoder_b2']))
return layer_2 def one_class_learning(dataset,testset,n_input,one_class_label,filename):
# Network Parameters
n_hidden_1 = int(n_input/2)
n_hidden_2 = int(n_input/4)
# tf Graph input (only pictures)
X = tf.placeholder("float", [None, n_input]) weights = {
'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
biases = {
'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
'decoder_b2': tf.Variable(tf.random_normal([n_input])),
}
# Construct model
encoder_op = encoder(X,weights,biases)
decoder_op = decoder(encoder_op,weights,biases)
# Prediction
y_pred = decoder_op
# Targets (Labels) are the input data.
y_true = X
# Define loss and optimizer, minimize the squared error
cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost) # Initializing the variables
init = tf.global_variables_initializer() # Launch the graph
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(dataset['data'])/batch_size)
# Training cycle
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs = dataset['data'][i*batch_size:(i+1)*batch_size]
batch_ys = dataset['label'][i*batch_size:(i+1)*batch_size]
batch_xs = [batch_xs[j] for j in range(len(batch_xs)) if batch_ys[j] == one_class_label] # Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
print("Epoch:", '%04d' % (epoch+1),"cost=", "{:.9f}".format(c)) encode_decode = sess.run(y_pred, feed_dict={X: testset['data']}) #examples_to_show = 14
#f, a = plt.subplots(2, examples_to_show, figsize=(examples_to_show, 2))
#for i in range(examples_to_show):
# print(testset['label'][i],sess.run(tf.reduce_mean(tf.pow(testset['data'][i] - encode_decode[i], 2))))
# a[0][i].imshow(np.reshape(testset['data'][i], (28, 28)))
# a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
#f.show()
#plt.draw()
#plt.waitforbuttonpress()
wf = open(filename,'a+')
for i in range(len(encode_decode)):
wf.write(str(one_class_label)+','+str(testset['label'][i])+','+str(sess.run(tf.reduce_mean(tf.pow(testset['data'][i] - encode_decode[i], 2))))+'\n')
if i % 500 == 0:
print(i)
wf.close() def decode_one_hot(label):
return max([i for i in range(len(label)) if label[i] == 1]) def mnist_test():
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
trainset = {'data':mnist.train.images,'label':[decode_one_hot(label) for label in mnist.train.labels]}
testset = {'data':mnist.test.images,'label':[decode_one_hot(label) for label in mnist.test.labels]}
one_class_learning(trainset,testset,784,7,'label_7.csv') mnist_test()

  根据相似度进行排序,还可以计算准确率、召回率、ROC等,这里统计了一下准确率和召回率。

四、参考文献

[1].Manevitz L M, Yousef M. One-class svms for document classification[J]. Journal of Machine Learning Research, 2002, 2(1):139-154.

[2]: http://www.louisdorard.com/blog/when-machine-learning-fails

[3].Manevitz L, Yousef M. One-class document classification via Neural Networks[M]. Elsevier Science Publishers B. V. 2007.

[4].https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py

全文完,转载请注明出处:http://www.cnblogs.com/fengfenggirl/p/One-Class-Learning.html

TensorFlow上实践基于自编码的One Class Learning的更多相关文章

  1. 基于Python玩转人工智能最火框架 TensorFlow应用实践✍✍✍

    基于Python玩转人工智能最火框架  TensorFlow应用实践 随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架.而在昨天机器之心发起 ...

  2. 基于Python玩转人工智能最火框架 TensorFlow应用实践

    慕K网-299元-基于Python玩转人工智能最火框架 TensorFlow应用实践 需要联系我,QQ:1844912514

  3. Paragraph Vector在Gensim和Tensorflow上的编写以及应用

    上一期讨论了Tensorflow以及Gensim的Word2Vec模型的建设以及对比.这一期,我们来看一看Mikolov的另一个模型,即Paragraph Vector模型.目前,Mikolov以及B ...

  4. Jcompress: 一款基于huffman编码和最小堆的压缩、解压缩小程序

    前言 最近基于huffman编码和最小堆排序算法实现了一个压缩.解压缩的小程序.其源代码已经上传到github上面: Jcompress下载地址 .在本人的github上面有一个叫Utility的re ...

  5. 智能合约最佳实践 之 Solidity 编码规范

    每一门语言都有其相应的编码规范, Solidity 也一样, 下面官方推荐的规范及我的总结,供大家参考,希望可以帮助大家写出更好规范的智能合约. 命名规范 避免使用 小写的l,大写的I,大写的O 应该 ...

  6. Linux及安全实践五——字符集编码

    Linux及安全实践五——字符集编码 一.ASCII码 在表中查找出英文字母LXQ相对应的十六进制数值为: 4c 58 51 在终端中输入命令:vim test1.txt 在vim页面输入命令:%!x ...

  7. 在Hadoop上运行基于RMM中文分词算法的MapReduce程序

    原文:http://xiaoxia.org/2011/12/18/map-reduce-program-of-rmm-word-count-on-hadoop/ 在Hadoop上运行基于RMM中文分词 ...

  8. Python玩转人工智能最火框架 TensorFlow应用实践 ☝☝☝

    Python玩转人工智能最火框架 TensorFlow应用实践 (一个人学习或许会很枯燥,但是寻找更多志同道合的朋友一起,学习将会变得更加有意义✌✌) 全民人工智能时代,不甘心只做一个旁观者,那就现在 ...

  9. ZhuSuan 是建立在Tensorflow上的贝叶斯深层学习的 python 库

    ZhuSuan 是建立在Tensorflow上的贝叶斯深层学习的 python 库. 与现有的主要针对监督任务设计的深度学习库不同,ZhuSuan 的特点是深入到贝叶斯推理中,从而支持各种生成模式:传 ...

随机推荐

  1. SpringBoot打成的jar包发布,shell关闭之后一直在服务器运行

    1:可以编写shell脚本, 切换到执行的jar包目录,然后使用nohup  让改命令在服务器一直运行 #!/bin/bash cd /srv/ftp/public nohup java -jar l ...

  2. col-md-1

    .col-md-12 {    width: 100%;  }  .col-md-11 {    width: 91.66666666666666%;  }  .col-md-10 {    widt ...

  3. file_get_post实现post请求

    function Post($url, $post = null){     $context = array();     if (is_array($post)) {       ksort($p ...

  4. DevExpress使用技巧总结

    DevExpress是非常主流的.NET控件,目前全世界和中国都用很多用户使用,不过由于是英文版,初次接触的同学可能会觉得困难,这里就总结DevExpress常见的10个使用技巧. 1.TextEdi ...

  5. linux系统安装 dig和nslookup命令

    Fedora / Centos:1.yum install bind-utils Ubuntu: 1.sudo apt-get install dnsutils Debian: 1.2 apt-get ...

  6. Jackson基础

    一.所需jar包: jackson-core-x.x.x-rc4.jar.jackson-databind-x.x.x-rc4.jar.jackson-annotations-x.x.x-rc4.ja ...

  7. MongoDB ----基于分布式文件存储的数据库

    参考: http://www.cnblogs.com/huangxincheng/category/355399.html http://www.cnblogs.com/daizhj/category ...

  8. js将时间戳转化为日期格式

    function getLocalTime(nS) {        var date = new Date(nS);        var Y = date.getFullYear() + '-'; ...

  9. 手撕vue-cli配置——webpack.dev.conf.js篇

    const utils = require('./utils') const webpack = require('webpack') const config = require('../confi ...

  10. ELK学习笔记之CentOS 7下ELK(6.2.4)++LogStash+Filebeat+Log4j日志集成环境搭建

    0x00 简介 现在的公司由于绝大部分项目都采用分布式架构,很早就采用ELK了,只不过最近因为额外的工作需要,仔细的研究了分布式系统中,怎么样的日志规范和架构才是合理和能够有效提高问题排查效率的. 经 ...