SMOTE(Synthetic Minority Oversampling Technique),合成少数类过采样技术.它是基于随机过采样算法的一种改进方案,由于随机过采样采取简单复制样本的策略来增加少数类样本,这样容易产生模型过拟合的问题,即使得模型学习到的信息过于特别(Specific)而不够泛化(General),SMOTE算法的基本思想是对少数类样本进行分析并根据少数类样本人工合成新样本添加到数据集中,具体如下图所示,算法流程如下。

  • (1)对于少数类中每一个样本x,以欧氏距离为标准计算它到少数类样本集中所有样本的距离,得到其k近邻。
  • (2)根据样本不平衡比例设置一个采样比例以确定采样倍率N,对于每一个少数类样本x,从其k近邻中随机选择若干个样本,假设选择的近邻为o。
  • (3)对于每一个随机选出的近邻o,分别与原样本按照公式o(new)=o+rand(0,1)*(x-o)构建新的样本。
或者:

Smote算法的思想其实很简单,先随机选定n个少类的样本,如下图

 

找出初始扩展的少类样本

再找出最靠近它的m个少类样本,如下图

 

再任选最临近的m个少类样本中的任意一点,

 

在这两点上任选一点,这点就是新增的数据样本

  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pandas as pd
  4. from sklearn.preprocessing import StandardScaler
  5. from numpy import *
  6. import matplotlib.pyplot as plt
  7.  
  8. #读数据
  9. data = pd.read_table('supermarket_second_man_clothes_train.txt', low_memory=False)
  10.  
  11. #简单的预处理
  12. test_date = pd.concat([data['label'], data.iloc[:, 7:10]], axis=1)
  13. test_date = test_date.dropna(how='any')

结果:

  1. test_date.head()
  2. Out[1]:
  3. label max_date_diff max_pay cnt_time
  4. 0 0 23.0 43068.0 15
  5. 1 0 10.0 1899.0 2
  6. 2 0 146.0 3299.0 21
  7. 3 0 30.0 31959.0 35
  8. 4 0 3.0 24165.0 98
  9. test_date['label'][test_date['label']==0].count()/test_date['label'][test_date['label']==1].count()
  10. Out[2]: 67

label是样本类别判别标签,0:1=67:1,需要对label=1的数据进行扩充

  1. # 筛选目标变量
  2. aimed_date = test_date[test_date['label'] == 1]
  3. # 随机筛选少类扩充中心
  4. index = pd.DataFrame(aimed_date.index).sample(frac=0.1, random_state=1)
  5. index.columns = ['id']
  6. number = len(index)
  7. # 生成array格式
  8. aimed_date_new = aimed_date.ix[index.values.ravel(), :]

随机选取了全量少数样本的10%作为数据扩充的中心点

  1. # 自变量标准化
  2. sc = StandardScaler().fit(aimed_date_new)
  3. aimed_date_new = pd.DataFrame(sc.transform(aimed_date_new))
  4. sc1 = StandardScaler().fit(aimed_date)
  5. aimed_date = pd.DataFrame(sc1.transform(aimed_date))
  6.  
  7. # 定义欧式距离计算
  8. def dist(a, b):
  9. a = array(a)
  10. b = array(b)
  11. d = ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2 + (a[2] - b[2]) ** 2 + (a[3] - b[3]) ** 2) ** 0.5
  12. return d

下面定义距离计算的方式,所有算法中,涉及到距离的地方都需要标准化去除冈量,也同时加快了计算的速度
这边采取了欧式距离的方式

  1. # 统计所有检验距离样本个数
  2. row_l1 = aimed_date_new.iloc[:, 0].count()
  3. row_l2 = aimed_date.iloc[:, 0].count()
  4. a = zeros((row_l1, row_l2))
  5. a = pd.DataFrame(a)
  6. # 计算距离矩阵
  7. for i in range(row_l1):
  8. for j in range(row_l2):
  9. d = dist(aimed_date_new.iloc[i, :], aimed_date.iloc[j, :])
  10. a.ix[i, j] = d
  11. b = a.T.apply(lambda x: x.min())

调用上面的计算距离的函数,形成一个距离矩阵

  1. # 找到同类点位置
  2. h = []
  3. z = []
  4. for i in range(number):
  5. for j in range(len(a.iloc[i, :])):
  6. ai = a.iloc[i, j]
  7. bi = b[i]
  8. if ai == bi:
  9. h.append(i)
  10. z.append(j)
  11. else:
  12. continue
  13. new_point = [0, 0, 0, 0]
  14. new_point = pd.DataFrame(new_point)
  15. for i in range(len(h)):
  16. index_a = z[i]
  17. new = aimed_date.iloc[index_a, :]
  18. new_point = pd.concat([new, new_point], axis=1)
  19.  
  20. new_point = new_point.iloc[:, range(len(new_point.columns) - 1)]

再找到位置的情况下,再去原始的数据集中根据位置查找具体的数据

  1. import random
  2. r1 = []
  3. for i in range(len(new_point.columns)):
  4. r1.append(random.uniform(0, 1))
  5. new_point_last = []
  6. new_point_last = pd.DataFrame(new_point_last)
  7. # 求新点 new_x=old_x+rand()*(append_x-old_x)
  8. for i in range(len(new_point.columns)):
  9. new_x = (new_point.iloc[1:4, i] - aimed_date_new.iloc[number - 1 - i, 1:4]) * r1[i] + aimed_date_new.iloc[number - 1 - i, 1:4]
  10. new_point_last = pd.concat([new_point_last, new_x], axis=1)
  11. print new_point_last

最后,再根据smote的计算公式new_x=old_x+rand()*(append_x-old_x),计算出新的点即可。

smote算法的伪代码如下:

  1. import random
  2. from sklearn.neighbors import NearestNeighbors
  3. import numpy as np
  4.  
  5. class Smote:
  6.     def __init__(self,samples,N=1,k=5):
  7.         self.n_samples,self.n_attrs=samples.shape
  8.         self.N=N
  9.         self.k=k
  10.         self.samples=samples
  11.         self.newindex=0
  12.        # self.synthetic=np.zeros((self.n_samples*N,self.n_attrs))
  13.  
  14.     def over_sampling(self):
  15.         N=int(self.N)
  16.         self.synthetic = np.zeros((self.n_samples * N, self.n_attrs))
  17.         neighbors=NearestNeighbors(n_neighbors=self.k).fit(self.samples)
  18.         print('neighbors',neighbors)
  19.         for i in range(len(self.samples)):
  20.             nnarray=neighbors.kneighbors(self.samples[i].reshape(1,-1),return_distance=False)[0]
  21.             #print nnarray
  22.             self._populate(N,i,nnarray)
  23.         return self.synthetic
  24.     
  25.     # for each minority class samples,choose N of the k nearest neighbors and generate N synthetic samples.
  26.     def _populate(self,N,i,nnarray):
  27.         for j in range(N):
  28.             nn=random.randint(0,self.k-1)
  29.             dif=self.samples[nnarray[nn]]-self.samples[i]
  30.             gap=random.random()
  31.             self.synthetic[self.newindex]=self.samples[i]+gap*dif
  32.             self.newindex+=1
  33. a=np.array([[1,2,3],[4,5,6],[2,3,1],[2,1,2],[2,3,4],[2,3,4]])
  34. s=Smote(a,N=2)              #a为少数数据集,N为倍率,即从k-邻居中取出几个样本点
  35. print(s.over_sampling())

SMOTE算法的缺陷

该算法主要存在两方面的问题:一是在近邻选择时,存在一定的盲目性。从上面的算法流程可以看出,在算法执行过程中,需要确定K值,即选择多少个近邻样本,这需要用户自行解决。从K值的定义可以看出,K值的下限是M值(M值为从K个近邻中随机挑选出的近邻样本的个数,且有M< K),M的大小可以根据负类样本数量、正类样本数量和数据集最后需要达到的平衡率决定。但K值的上限没有办法确定,只能根据具体的数据集去反复测试。因此如何确定K值,才能使算法达到最优这是未知的。

另外,该算法无法克服非平衡数据集的数据分布问题,容易产生分布边缘化问题。由于负类样本的分布决定了其可选择的近邻,如果一个负类样本处在负类样本集的分布边缘,则由此负类样本和相邻样本产生的“人造”样本也会处在这个边缘,且会越来越边缘化,从而模糊了正类样本和负类样本的边界,而且使边界变得越来越模糊。这种边界模糊性,虽然使数据集的平衡性得到了改善,但加大了分类算法进行分类的难度.

针对SMOTE算法的进一步改进

针对SMOTE算法存在的边缘化和盲目性等问题,很多人纷纷提出了新的改进办法,在一定程度上改进了算法的性能,但还存在许多需要解决的问题。

Han等人Borderline-SMOTE: A New Over-Sampling Method in Imbalanced Data Sets Learning在SMOTE算法基础上进行了改进,提出了Borderhne.SMOTE算法,解决了生成样本重叠(Overlapping)的问题该算法在运行的过程中,查找一个适当的区域,该区域可以较好地反应数据集的性质,然后在该区域内进行插值,以使新增加的“人造”样本更有效。这个适当的区域一般由经验给定,因此算法在执行的过程中有一定的局限性。

过采样算法之SMOTE的更多相关文章

  1. 机器学习 —— 类不平衡问题与SMOTE过采样算法

    在前段时间做本科毕业设计的时候,遇到了各个类别的样本量分布不均的问题——某些类别的样本数量极多,而有些类别的样本数量极少,也就是所谓的类不平衡(class-imbalance)问题. 本篇简述了以下内 ...

  2. [转]类不平衡问题与SMOTE过采样算法

    在前段时间做本科毕业设计的时候,遇到了各个类别的样本量分布不均的问题——某些类别的样本数量极多,而有些类别的样本数量极少,也就是所谓的类不平衡(class-imbalance)问题. 本篇简述了以下内 ...

  3. 蓄水池采样算法(Reservoir Sampling)

    蓄水池采样算法 问题描述分析 采样问题经常会被遇到,比如: 从 100000 份调查报告中抽取 1000 份进行统计. 从一本很厚的电话簿中抽取 1000 人进行姓氏统计. 从 Google 搜索 & ...

  4. 文本主题模型之LDA(二) LDA求解之Gibbs采样算法

    文本主题模型之LDA(一) LDA基础 文本主题模型之LDA(二) LDA求解之Gibbs采样算法 文本主题模型之LDA(三) LDA求解之变分推断EM算法(TODO) 本文是LDA主题模型的第二篇, ...

  5. WebRTC 音频采样算法 附完整C++示例代码

    之前有大概介绍了音频采样相关的思路,详情见<简洁明了的插值音频重采样算法例子 (附完整C代码)>. 音频方面的开源项目很多很多. 最知名的莫过于谷歌开源的WebRTC, 其中的音频模块就包 ...

  6. MCMC等采样算法

    一.直接采样 直接采样的思想是,通过对均匀分布采样,实现对任意分布的采样.因为均匀分布采样好猜,我们想要的分布采样不好采,那就采取一定的策略通过简单采取求复杂采样. 假设y服从某项分布p(y),其累积 ...

  7. 从信用卡欺诈模型看不平衡数据分类(1)数据层面:使用过采样是主流,过采样通常使用smote,或者少数使用数据复制。过采样后模型选择RF、xgboost、神经网络能够取得非常不错的效果。(2)模型层面:使用模型集成,样本不做处理,将各个模型进行特征选择、参数调优后进行集成,通常也能够取得不错的结果。(3)其他方法:偶尔可以使用异常检测技术,IF为主

    总结:不平衡数据的分类,(1)数据层面:使用过采样是主流,过采样通常使用smote,或者少数使用数据复制.过采样后模型选择RF.xgboost.神经网络能够取得非常不错的效果.(2)模型层面:使用模型 ...

  8. RANSAC随机一致性采样算法学习体会

    The RANSAC algorithm is a learning technique to estimate parameters of a model by random sampling of ...

  9. 机器学习入门-数据过采样(上采样)1. SMOTE

    from imblearn.over_sampling import SMOTE  # 导入 overstamp = SMOTE(random_state=0) # 对训练集的数据进行上采样,测试集的 ...

随机推荐

  1. Django学习之Form表单

    一.Form介绍 普通方式手写注册功能 使用form组件实现注册功能 二.Form那些事儿 1.常用字段与插件 initial error_messages password radioSelect ...

  2. 【工具安装】MAC 安装 netdiscover 使用教程

    日期:2019-06-27 15:54:19 作者:Bay0net 介绍:在 mac os 下,如何安装 netdiscover 及基本使用方法 0x01.当前环境 MAC os 10.14.4 已安 ...

  3. oracle data guard --理论知识回顾01

    之前搭建了rac到单实例的dg环境,最近又在windows下搭建了dg,这一篇关于dg的一些理论知识回顾 官方文档 https://docs.oracle.com/cd/E11882_01/nav/p ...

  4. VS2017使用dotnet命令

    添加引用Microsoft.EntityFrameworkCore.Tools 添加引用后提示未找到命令“dotnet ef”向csprog文件添加如下节点 <ItemGroup> < ...

  5. 解决172.17 或者172.18 机房环境下harbor服务器不通的问题

    直接改docker-compose.yml文件: 把原来的network选项注释掉,自定义 #networks: # harbor: # external: false networks: harbo ...

  6. django 的 MTV 流程图

  7. Nginx跨域问题

    Nginx跨域无法访问,通常报错: Failed to load http://172.18.6.30:8086/CityServlet: No 'Access-Control-Allow-Origi ...

  8. 【扩展事件】跟踪超过3秒的SQL

    msdn 扩展事件:点击打开链接 转自:https://blog.csdn.net/yenange/article/details/52592814 -- 删除事件会话 IF EXISTS(SELEC ...

  9. C++基础-类和对象

    本文为 C++ 学习笔记,参考<Sams Teach Yourself C++ in One Hour a Day>第 8 版.<C++ Primer>第 5 版.<代码 ...

  10. 深入理解 JavaScript中的变量、值、传参

    1. demo 如果你对下面的代码没有任何疑问就能自信的回答出输出的内容,那么本篇文章就不值得你浪费时间了. var var1 = 1 var var2 = true var var3 = [1,2, ...