版权所有,转帖注明出处



Scikit-learn是一个开源Python库,它使用统一的接口实现了一系列机器学习、预处理、交叉验证和可视化算法。

一个基本例子

from sklearn import neighbors, datasets, preprocessing
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = datasets.load_iris()
X, y = iris.data[:, :2], iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=33)
scaler = preprocessing.StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
knn = neighbors.KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
y_pred = knn.predict(X_test)
accuracy_score(y_test, y_pred)

加载数据

数据类型可以是NumPy数组、SciPy稀疏矩阵,或者其他可转换为数组的类型,如panda DataFrame等。

import numpy as np
X = np.random.random((10,5))
y = np.array(['M','M','F','F','M','F','M','M','F','F','F'])
X[X < 0.7] = 0

预处理数据

标准化/Standardization

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().fit(X_train)
standardized_X = scaler.transform(X_train)
standardized_X_test = scaler.transform(X_test)

归一化/Normalization

from sklearn.preprocessing import Normalizer
scaler = Normalizer().fit(X_train)
normalized_X = scaler.transform(X_train)
normalized_X_test = scaler.transform(X_test)

二值化/Binarization

from sklearn.preprocessing import Binarizer
binarizer = Binarizer(threshold=0.0).fit(X)
binary_X = binarizer.transform(X)

类别特征编码

from sklearn.preprocessing import LabelEncoder
enc = LabelEncoder()
y = enc.fit_transform(y)

缺失值估算

>>>from sklearn.preprocessing import Imputer
>>>imp = Imputer(missing_values=0, strategy='mean', axis=0)
>>>imp.fit_transform(X_train)

生成多项式特征

from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(5)
oly.fit_transform(X)

训练与测试数据分组

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,random_state=0)

创建模型

有监督学习模型

线性回归

from sklearn.linear_model import LinearRegression
lr = LinearRegression(normalize=True)

支持向量机(SVM)

from sklearn.svm import SVC
svc = SVC(kernel='linear')

朴素贝叶斯

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

KNN

from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()

无监督学习模型

主成分分析(PCA)

from sklearn.decomposition import PCA
pca = PCA(n_components=0.95)

k均值/K Means

from sklearn.cluster import KMeans
k_means = KMeans(n_clusters=3, random_state=0)

模型拟合

有监督学习

lr.fit(X, y)
knn.fit(X_train, y_train)
svc.fit(X_train, y_train)

无监督学习

k_means.fit(X_train)
pca_model = pca.fit_transform(X_train)

模型预测

有监督学习

y_pred = svc.predict(np.random.random((2,5)))
y_pred = lr.predict(X_test)
y_pred = knn.predict_proba(X_test))

无监督学习

y_pred = k_means.predict(X_test)

评估模型性能

分类指标

准确度

knn.score(X_test, y_test)
from sklearn.metrics import accuracy_score
accuracy_score(y_test, y_pred)

分类报告

from sklearn.metrics import classification_report
print(classification_report(y_test, y_pred)))

混淆矩阵

from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_test, y_pred)))

回归指标

平均绝对误差

from sklearn.metrics import mean_absolute_error
y_true = [3, -0.5, 2])
mean_absolute_error(y_true, y_pred))

均方差

from sklearn.metrics import mean_squared_error
mean_squared_error(y_test, y_pred))

$R^2$分数

from sklearn.metrics import r2_score
r2_score(y_true, y_pred))

聚类指标

调整兰德系数

from sklearn.metrics import adjusted_rand_score
adjusted_rand_score(y_true, y_pred))

同质性/Homogeneity

from sklearn.metrics import homogeneity_score
homogeneity_score(y_true, y_pred))

调和平均指标/V-measure

from sklearn.metrics import v_measure_score
metrics.v_measure_score(y_true, y_pred))

交叉验证

print(cross_val_score(knn, X_train, y_train, cv=4))
print(cross_val_score(lr, X, y, cv=2))

模型调优

网格搜索

from sklearn.grid_search import GridSearchCV
params = {"n_neighbors": np.arange(1,3), "metric": ["euclidean", "cityblock"]}
grid = GridSearchCV(estimator=knn,param_grid=params)
grid.fit(X_train, y_train)
print(grid.best_score_)
print(grid.best_estimator_.n_neighbors)

随机参数优化

from sklearn.grid_search import RandomizedSearchCV
params = {"n_neighbors": range(1,5), "weights": ["uniform", "distance"]}
rsearch = RandomizedSearchCV(estimator=knn,
param_distributions=params,
cv=4,
n_iter=8,
random_state=5)
rsearch.fit(X_train, y_train)
print(rsearch.best_score_)

Sklearn 速查的更多相关文章

  1. 机器学习算法 Python&R 速查表

    sklearn实战-乳腺癌细胞数据挖掘( 博主亲自录制) https://study.163.com/course/introduction.htm?courseId=1005269003&u ...

  2. 常用的14种HTTP状态码速查手册

    分类 1xx \> Information(信息) // 接收的请求正在处理 2xx \> Success(成功) // 请求正常处理完毕 3xx \> Redirection(重定 ...

  3. jQuery 常用速查

    jQuery 速查 基础 $("css 选择器") 选择元素,创建jquery对象 $("html字符串") 创建jquery对象 $(callback) $( ...

  4. 简明 Git 命令速查表(中文版)

    原文引用地址:https://github.com/flyhigher139/Git-Cheat-Sheet/blob/master/Git%20Cheat%20Sheet-Zh.md在Github上 ...

  5. 《zw版·Halcon-delphi系列原创教程》 zw版-Halcon常用函数Top100中文速查手册

    <zw版·Halcon-delphi系列原创教程> zw版-Halcon常用函数Top100中文速查手册 Halcon函数库非常庞大,v11版有1900多个算子(函数). 这个Top版,对 ...

  6. .htaccess下Flags速查表

    Flags是可选参数,当有多个标志同时出现时,彼此间以逗号分隔. 速查表: RewirteRule 标记 含义 描述 R Redirect 发出一个HTTP重定向 F Forbidden 禁止对URL ...

  7. IL指令速查

    名称 说明 Add 将两个值相加并将结果推送到计算堆栈上. Add.Ovf 将两个整数相加,执行溢出检查,并且将结果推送到计算堆栈上. Add.Ovf.Un 将两个无符号整数值相加,执行溢出检查,并且 ...

  8. Linux命令速查手册,超详细Linux命令教程

    一.常用命令速查 ls cd pwd cat more less tail head cp scp mv mkdir rmdir touch rm ps kill top free clear tre ...

  9. 25个有用的和方便的 WordPress 速查手册

    如果你是 WordPress 开发人员,下载一些方便的 WordPress 备忘单可以在你需要的时候快速查找.下面这个列表,我们已经列出了25个有用的和方便的 WordPress 速查手册,赶紧收藏吧 ...

随机推荐

  1. Python 数组

    使用之前要先导入函数库  import numpy as np 数组名=np.zeros(数组大小,数据类型)    初始化为0值,这里的数据类型只能是数值类型,字符类型不能用 一.一维数组 impo ...

  2. C 语言入门第五章--循环结构和选择结构

    C语言中有三大结构,分别是顺序结构.选择结构和循环结构: 逻辑运算: 与运算: && 或运算:|| 非运算:! ==== #include<stdio.h> int mai ...

  3. Pytorch【直播】2019 年县域农业大脑AI挑战赛---初级准备(一)切图

    比赛地址:https://tianchi.aliyun.com/competition/entrance/231717/introduction 这次比赛给的图非常大5万x5万,在训练之前必须要进行数 ...

  4. 新闻网大数据实时分析可视化系统项目——3、Hadoop2.X分布式集群部署

    (一)hadoop2.x版本下载及安装 Hadoop 版本选择目前主要基于三个厂商(国外)如下所示: 1.基于Apache厂商的最原始的hadoop版本, 所有发行版均基于这个版本进行改进. 2.基于 ...

  5. storm的JavaAPI运行报错

    报错:java.lang.NoClassDefFoundError: org/apache/storm/topology/IRichSpout 原因:idea的bug:本地运行时设置scope为pro ...

  6. 使用input选择本地图片,并且实现预览功能

    1.使用input标签选择本地图片文件 用一个盒子来存放预览的图片 2.JS实现预览 首先添加一个input change事件,再用到 URL.createObjectURL() 方法 用来创建 UR ...

  7. 报错信息 Context []startup failed due to previous errors

    文章转自:http://blog.sina.com.cn/s/blog_49b4a1f10100q93e.html 框架搭建好后,启动服务器出现如下的信息: log4j:WARN No appende ...

  8. 使用Spring Cloud Gateway保护反应式微服务(二)

    抽丝剥茧,细说架构那些事——[优锐课] 接着上篇文章:使用Spring Cloud Gateway保护反应式微服务(一) 我们继续~ 将Spring Cloud Gateway与反应式微服务一起使用 ...

  9. 嵊州普及Day3T2

    题意:对于n数列的全排列,有多少种可能,是每项前缀和不能整除3.输出可能性%1000000000037. 思路:全部模三,剩余1.2.0,1.2可这样排:1.1.2.1.2.1.2.……2或2.2.1 ...

  10. QQ企业通----类库的设计----UDPSocket组件等

    知识点: IPEndPoint    将网络端点表示为 IP 地址和端口号. UdpClient   提供用户数据报 (UDP) 网络服务. UdpClient对象.Close 关闭 UDP 连接. ...