import numpy as np
import matplotlib.pyplot as plt from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
from sklearn.model_selection import train_test_split
from sklearn import datasets, linear_model,discriminant_analysis def load_data():
# 使用 scikit-learn 自带的 iris 数据集
iris=datasets.load_iris()
X_train=iris.data
y_train=iris.target
return train_test_split(X_train, y_train,test_size=0.25,random_state=0,stratify=y_train) #线性判断分析LinearDiscriminantAnalysis
def test_LinearDiscriminantAnalysis(*data):
X_train,X_test,y_train,y_test=data
lda = discriminant_analysis.LinearDiscriminantAnalysis()
lda.fit(X_train, y_train)
print('Coefficients:%s, intercept %s'%(lda.coef_,lda.intercept_))
print('Score: %.2f' % lda.score(X_test, y_test)) # 产生用于分类的数据集
X_train,X_test,y_train,y_test=load_data()
# 调用 test_LinearDiscriminantAnalysis
test_LinearDiscriminantAnalysis(X_train,X_test,y_train,y_test)

def plot_LDA(converted_X,y):
'''
绘制经过 LDA 转换后的数据
:param converted_X: 经过 LDA转换后的样本集
:param y: 样本集的标记
'''
fig=plt.figure()
ax=Axes3D(fig)
colors='rgb'
markers='o*s'
for target,color,marker in zip([0,1,2],colors,markers):
pos=(y==target).ravel()
X=converted_X[pos,:]
ax.scatter(X[:,0], X[:,1], X[:,2],color=color,marker=marker,label="Label %d"%target)
ax.legend(loc="best")
fig.suptitle("Iris After LDA")
plt.show() def run_plot_LDA():
'''
执行 plot_LDA 。其中数据集来自于 load_data() 函数
'''
X_train,X_test,y_train,y_test=load_data()
X=np.vstack((X_train,X_test))
Y=np.vstack((y_train.reshape(y_train.size,1),y_test.reshape(y_test.size,1)))
lda = discriminant_analysis.LinearDiscriminantAnalysis()
lda.fit(X, Y)
converted_X=np.dot(X,np.transpose(lda.coef_))+lda.intercept_
plot_LDA(converted_X,Y) # 调用 run_plot_LDA
run_plot_LDA()

def test_LinearDiscriminantAnalysis_solver(*data):
'''
测试 LinearDiscriminantAnalysis 的预测性能随 solver 参数的影响
'''
X_train,X_test,y_train,y_test=data
solvers=['svd','lsqr','eigen']
for solver in solvers:
if(solver=='svd'):
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver)
else:
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver=solver,shrinkage=None)
lda.fit(X_train, y_train)
print('Score at solver=%s: %.2f' %(solver, lda.score(X_test, y_test))) # 调用 test_LinearDiscriminantAnalysis_solver
test_LinearDiscriminantAnalysis_solver(X_train,X_test,y_train,y_test)

def test_LinearDiscriminantAnalysis_shrinkage(*data):
'''
测试 LinearDiscriminantAnalysis 的预测性能随 shrinkage 参数的影响
'''
X_train,X_test,y_train,y_test=data
shrinkages=np.linspace(0.0,1.0,num=20)
scores=[]
for shrinkage in shrinkages:
lda = discriminant_analysis.LinearDiscriminantAnalysis(solver='lsqr',shrinkage=shrinkage)
lda.fit(X_train, y_train)
scores.append(lda.score(X_test, y_test))
## 绘图
fig=plt.figure()
ax=fig.add_subplot(1,1,1)
ax.plot(shrinkages,scores)
ax.set_xlabel(r"shrinkage")
ax.set_ylabel(r"score")
ax.set_ylim(0,1.05)
ax.set_title("LinearDiscriminantAnalysis")
plt.show()
# 调用 test_LinearDiscr
test_LinearDiscriminantAnalysis_shrinkage(X_train,X_test,y_train,y_test)

吴裕雄 python 机器学习——线性判断分析LinearDiscriminantAnalysis的更多相关文章

  1. 吴裕雄 python 机器学习——主成份分析PCA降维

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

  2. 吴裕雄--天生自然 人工智能机器学习实战代码:线性判断分析LINEARDISCRIMINANTANALYSIS

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

  3. 吴裕雄 python 机器学习——支持向量机线性分类LinearSVC模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets, linear_model,svm fr ...

  4. 吴裕雄 python 机器学习——局部线性嵌入LLE降维模型

    # -*- coding: utf-8 -*- import numpy as np import matplotlib.pyplot as plt from sklearn import datas ...

  5. 吴裕雄 python 机器学习——人工神经网络与原始感知机模型

    import numpy as np from matplotlib import pyplot as plt from mpl_toolkits.mplot3d import Axes3D from ...

  6. 吴裕雄 python 机器学习——分类决策树模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.model_s ...

  7. 吴裕雄 python 机器学习——回归决策树模型

    import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.model_s ...

  8. 吴裕雄 python 机器学习——逻辑回归

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

  9. 吴裕雄 python 机器学习——ElasticNet回归

    import numpy as np import matplotlib.pyplot as plt from matplotlib import cm from mpl_toolkits.mplot ...

随机推荐

  1. 亲测实验,stm32待机模式和停机模式唤醒程序的区别,以及唤醒后程序入口

    这两天研究了STM32的低功耗知识,低功耗里主要研究的是STM32的待机模式和停机模式.让单片机进入的待机模式和停机模式比较容易,实验中通过设置中断口PA1来响应待机和停机模式. void EXTI1 ...

  2. XML的介绍使用

    一.什么是XML? XML,Extensible Markup Language,扩展性标识语言.文件的后缀名为:.xml.就像HTML的作用是显示数据,XML的作用是传输和存储数据. 据说,java ...

  3. 联想扬天3900c电脑BIOS设置U盘启动图文教程

    有联想扬天3900c的用户反映说,制作好U大侠U盘后,按快捷键却识别不到U盘,不能进行U盘启动,这是怎么回事呢?其实这是BIOS设置的问题,下面U大侠教大家如何对联想扬天3900c电脑进行BIOS设置 ...

  4. Unity 中实现粒子系统的 LOD

    模型的 LOD 比较简单,直接使用 Unity 提供的组件 LODGroup 挂到模型物体上,然后分别指定不同 LOD 级别的 Renderer 即可. LODGroup 并不是用距离来控制 LOD, ...

  5. Java Web开发Session超时设置

    在Java Web开发中,Session为我们提供了很多方便,Session是由浏览器和服务器之间维护的.Session超时理解为:浏览器和服务器之间创建了一个Session,由于客户端长时间(休眠时 ...

  6. mysql循环插入千万级数据

    mysql使用存储过程循环插入大量数据,简单的一条条循环插入,效率会很低,需要考虑批量插入. 测试准备: 1.建表: CREATE TABLE `mysql_genarate` ( `id` ) NO ...

  7. Node文件模块

    在上一篇文章中有提到,Node模块分为核心模块和文件模块,接下来就简单总结一下文件模块. 文件模块则是在运行时动态加载,需要完整的路径分析.文件定位.编译执行过程.速度相比核心模块稍微慢一些,但是用的 ...

  8. 学习笔记《Java多线程编程实战指南》二

    2.1线程属性 属性 属性类型及用途  只读属性  注意事项 编号(id) long型,标识不同线程  是  不适合用作唯一标识 名称(name) String型,区分不同线程  否  设置名称有助于 ...

  9. 阿里云ssl负载均衡证书配置

    https://www.chinassl.net/ssl_install/n683.html

  10. C/C++中指针和java的引用区别

    C++指针  要区分指针变量和指针变量所指对象. 指针变量先是一个变量,它有自己的地址和存储的内容,所以要想清楚是改变指针变量的值(即地址),还是改变指针变量所指对象的值. #include < ...