TensorFlow上实践基于自编码的One Class Learning
--One Class Learning的自白
一、单分类简介
如果将分类算法进行划分,根据类别个数的不同可以分为单分类、二分类、多分类,常见的分类算法主要解决二分类和多分类问题,预测一封邮件是否是垃圾邮件是一个典型的二分类问题,手写体识别是一个典型的多分类问题,这些算法并不能很好的应用在单分类上,但单分类问题在工业界广泛存在,由于每个企业刻画用户的数据都是有限的,很多二分类问题很难找到负样本,比如通过用户的搜索记录预测一个用户是否有小孩,可以通过规则筛选出正样本,如经常搜索“宝宝、早教”之类的词,但很难筛选出合适的负样本,一个用户没有搜索并不代表没有小孩,即便用一些排除法筛选出负样本,负样本也会不纯,不能保证负样本中没有正样本,这样用分类模型跑出的结果存在解释性的困难,预测概率高可以认为是有很有可能有小孩,但概率低是有小孩的的可能小还是无法判断?在只能定义正样本不能定义负样本的场景中,使用单分类算法(One-Class Learning)更适合,单分类算法只关注与样本的相似或匹配程度,对于未知的部分不妄下结论。很多人会说“我不知道什么是爱,但我知道什么是不爱”,他们相亲时会立马排查掉不爱的对象,他们就是一个聪明的One-Class学习器。
当然分类算法根据记录所属类别个数的不同还可以分为单标签分类和多标签分类,预测一个用户喜欢的球类就是一个多标签多分类问题,这里不多加讨论,本文重点讨论单分类问题。
二、单分类算法介绍
One Class Learning比较经典的算法是One-Class-SVM[参考文献1],这个算法的思路非常简单,就是寻求一个超平面将样本中的正例圈起来,预测是就用这个超平面做决策,在圈内的样本就认为是正样本。由于核函数计算比较耗时,在海量数据的场景用得并不多;
one class svm示例图(图片引用自[参考文献2])
另外一个算法是基于神经网络的算法,在深度学习中广泛使用的自编码算法可以应用在单分类的问题上[参考文献3],自编码器是一个BP神经网络,网络输入层和输出层是一样,中间层数可以有多层,中间层的节点个数比输出入层少,最简单的情况就是中间只有一个隐藏层,如下图所示,由于中间层的节点数较少,这样中间层相当于是对数据进行了压缩和抽象,实现无监督的方式学习数据的抽象特征。
如果我们只有正样本数据,没有负样本数据,或者说只关注学习正样本的规律,那么利用正样本训练一个自编码器,编码器就相当于单分类的模型,对全量数据进行预测时,通过比较输入层和输出层的相似度就可以判断记录是否属于正样本。由于自编器采用神经网络实现,可以用GPU来进行加速计算,因此比较适合海量数据的场景。
三、单分类算法实践
这里使用TensorFlow实践了一下基于自编码的One-Class-Learning,代码引用自[参考文献4],使用MNIST数据进行自编码,首先用0-9这10个类别的数据进行训练,输入层和输出层节点数为784,中间层有两层,节点数分别为256、128,下图中第一排是原始图片,第二排是自编码后解码的图片,可以认为第二排图片是128个单元节点的压缩表示。
假设现在是一个单分类问题,我们只有部分数字7的标注数据,现在需要预测一张图片是不是数据7,那么可以用这部分标注数据来训练一个编码器,让编码器来发现数字7的内在规律,预测时通过编码器进行计算,比较输入层与输出层的相似度,如果相似度较高就可以认为是数字7,编码和解码后的图片示例如下:
实现代码如下:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt # Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data # Parameters
learning_rate = 0.008
training_epochs = 130
batch_size = 2560 # Building the encoder
def encoder(x,weights,biases):
# Encoder Hidden layer with sigmoid activation #1
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
biases['encoder_b1']))
# Decoder Hidden layer with sigmoid activation #2
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
biases['encoder_b2']))
return layer_2 # Building the decoder
def decoder(x,weights,biases):
# Encoder Hidden layer with sigmoid activation #1
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
biases['decoder_b1']))
# Decoder Hidden layer with sigmoid activation #2
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
biases['decoder_b2']))
return layer_2 def one_class_learning(dataset,testset,n_input,one_class_label,filename):
# Network Parameters
n_hidden_1 = int(n_input/2)
n_hidden_2 = int(n_input/4)
# tf Graph input (only pictures)
X = tf.placeholder("float", [None, n_input]) weights = {
'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
biases = {
'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
'decoder_b2': tf.Variable(tf.random_normal([n_input])),
}
# Construct model
encoder_op = encoder(X,weights,biases)
decoder_op = decoder(encoder_op,weights,biases)
# Prediction
y_pred = decoder_op
# Targets (Labels) are the input data.
y_true = X
# Define loss and optimizer, minimize the squared error
cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost) # Initializing the variables
init = tf.global_variables_initializer() # Launch the graph
with tf.Session() as sess:
sess.run(init)
total_batch = int(len(dataset['data'])/batch_size)
# Training cycle
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs = dataset['data'][i*batch_size:(i+1)*batch_size]
batch_ys = dataset['label'][i*batch_size:(i+1)*batch_size]
batch_xs = [batch_xs[j] for j in range(len(batch_xs)) if batch_ys[j] == one_class_label] # Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
print("Epoch:", '%04d' % (epoch+1),"cost=", "{:.9f}".format(c)) encode_decode = sess.run(y_pred, feed_dict={X: testset['data']}) #examples_to_show = 14
#f, a = plt.subplots(2, examples_to_show, figsize=(examples_to_show, 2))
#for i in range(examples_to_show):
# print(testset['label'][i],sess.run(tf.reduce_mean(tf.pow(testset['data'][i] - encode_decode[i], 2))))
# a[0][i].imshow(np.reshape(testset['data'][i], (28, 28)))
# a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
#f.show()
#plt.draw()
#plt.waitforbuttonpress()
wf = open(filename,'a+')
for i in range(len(encode_decode)):
wf.write(str(one_class_label)+','+str(testset['label'][i])+','+str(sess.run(tf.reduce_mean(tf.pow(testset['data'][i] - encode_decode[i], 2))))+'\n')
if i % 500 == 0:
print(i)
wf.close() def decode_one_hot(label):
return max([i for i in range(len(label)) if label[i] == 1]) def mnist_test():
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
trainset = {'data':mnist.train.images,'label':[decode_one_hot(label) for label in mnist.train.labels]}
testset = {'data':mnist.test.images,'label':[decode_one_hot(label) for label in mnist.test.labels]}
one_class_learning(trainset,testset,784,7,'label_7.csv') mnist_test()
根据相似度进行排序,还可以计算准确率、召回率、ROC等,这里统计了一下准确率和召回率。
四、参考文献
[1].Manevitz L M, Yousef M. One-class svms for document classification[J]. Journal of Machine Learning Research, 2002, 2(1):139-154.
[2]: http://www.louisdorard.com/blog/when-machine-learning-fails
[3].Manevitz L, Yousef M. One-class document classification via Neural Networks[M]. Elsevier Science Publishers B. V. 2007.
[4].https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py
全文完,转载请注明出处:http://www.cnblogs.com/fengfenggirl/p/One-Class-Learning.html
TensorFlow上实践基于自编码的One Class Learning的更多相关文章
- 基于Python玩转人工智能最火框架 TensorFlow应用实践✍✍✍
基于Python玩转人工智能最火框架 TensorFlow应用实践 随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架.而在昨天机器之心发起 ...
- 基于Python玩转人工智能最火框架 TensorFlow应用实践
慕K网-299元-基于Python玩转人工智能最火框架 TensorFlow应用实践 需要联系我,QQ:1844912514
- Paragraph Vector在Gensim和Tensorflow上的编写以及应用
上一期讨论了Tensorflow以及Gensim的Word2Vec模型的建设以及对比.这一期,我们来看一看Mikolov的另一个模型,即Paragraph Vector模型.目前,Mikolov以及B ...
- Jcompress: 一款基于huffman编码和最小堆的压缩、解压缩小程序
前言 最近基于huffman编码和最小堆排序算法实现了一个压缩.解压缩的小程序.其源代码已经上传到github上面: Jcompress下载地址 .在本人的github上面有一个叫Utility的re ...
- 智能合约最佳实践 之 Solidity 编码规范
每一门语言都有其相应的编码规范, Solidity 也一样, 下面官方推荐的规范及我的总结,供大家参考,希望可以帮助大家写出更好规范的智能合约. 命名规范 避免使用 小写的l,大写的I,大写的O 应该 ...
- Linux及安全实践五——字符集编码
Linux及安全实践五——字符集编码 一.ASCII码 在表中查找出英文字母LXQ相对应的十六进制数值为: 4c 58 51 在终端中输入命令:vim test1.txt 在vim页面输入命令:%!x ...
- 在Hadoop上运行基于RMM中文分词算法的MapReduce程序
原文:http://xiaoxia.org/2011/12/18/map-reduce-program-of-rmm-word-count-on-hadoop/ 在Hadoop上运行基于RMM中文分词 ...
- Python玩转人工智能最火框架 TensorFlow应用实践 ☝☝☝
Python玩转人工智能最火框架 TensorFlow应用实践 (一个人学习或许会很枯燥,但是寻找更多志同道合的朋友一起,学习将会变得更加有意义✌✌) 全民人工智能时代,不甘心只做一个旁观者,那就现在 ...
- ZhuSuan 是建立在Tensorflow上的贝叶斯深层学习的 python 库
ZhuSuan 是建立在Tensorflow上的贝叶斯深层学习的 python 库. 与现有的主要针对监督任务设计的深度学习库不同,ZhuSuan 的特点是深入到贝叶斯推理中,从而支持各种生成模式:传 ...
随机推荐
- linux报错 find: missing argument to `-exec'
在linux下使用find命令时,报错:find: missing argument to `-exec' 具体执行命令为: find /u03 -name server.xml -exec grep ...
- java-小技巧-001-Long序列化到前端js不支持
1.引入:jackson-mapper-asl-1.9.2.jar 2.导入: import org.codehaus.jackson.map.annotate.JsonSerialize;impor ...
- Inception系列
从GoogLeNet的Inceptionv1开始,发展了众多inception,如inception v2.v3.v4与Inception-ResNet-V2. 故事还是要从inception v1开 ...
- [转]Tesseract-OCR (Tesseract的OCR引擎最先由HP实验室于1985年开始研发)
光学字符识别(OCR,Optical Character Recognition)是指对文本资料进行扫描,然后对图像文件进行分析处理,获取文字及版面信息的过程.OCR技术非常专业,一般多是印刷.打印行 ...
- glob.glob()、os.path.split()函数、global和nonlocal关键字
1. glob.glob() glob模块是Python最简单的模块之一, 内容非常少, 用它可以查找符合特定规则的文件路径名, 查找文件时只会用到三个匹配符: * :匹配0个或多个字符 ? : 匹配 ...
- ES6(简)
一. let.const 和 var let和const只在当前块级作用域中有效const用来声明常量var是全局作用域有效的 constants.js 模块export const A = 1;ex ...
- js时钟
<!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...
- Base64编码加密
package liferay; public class Base64 { public static final char EQUAL = '='; public static final cha ...
- redis error It was not possible to connect to the redis server(s); to create a disconnected multiplexer, disable AbortOnConnectFail. SocketFailure on PING
应用redis出现如下错误 It was not possible to connect to the redis server(s); to create a disconnected multip ...
- Nuget的学习总结
Nuget的学习总结 今天研究了一下nuget,发现nuget实在是太有用了,便写下了这篇博客,希望记录一下自己的学习历程,也希望技术圈的朋友看到之后,如果里面哪里写的不够好,可以给我些宝贵的意见,以 ...