原文地址:https://www.jianshu.com/p/1db700f866ee

问题描述





程序实现

# kNN_RBFN.py
# coding:utf-8 import numpy as np
import matplotlib.pyplot as plt def ReadData(dataFile): with open(dataFile, 'r') as f:
lines = f.readlines()
data_list = []
for line in lines:
line = line.strip().split()
data_list.append([float(l) for l in line])
dataArray = np.array(data_list)
return dataArray def sign(n): if(n>=0):
return 1
else:
return -1 def kNN(k,trainArray,dataX):
num_data=dataX.shape[0]
predY=np.zeros((num_data,))
for n in range(num_data):
distArray=np.sum((trainArray[:,:-1]-dataX[n,:])**2,axis=1)
id_list=np.argsort(distArray,axis=0).tolist()[:k]
for i in id_list:
predY[n]+=trainArray[i,-1]
predY[n]=sign(predY[n])
return predY def GetZeroOneError(predY,dataY):
return (predY!=dataY).sum()/dataY.shape[0] def plot_bar_chart(X,Y,nameX,nameY,saveName):
plt.figure(figsize=(10,6))
plt.bar(left=X,height=Y,width=0.8,align="center",yerr=0.000001)
for (c,w) in zip(X,Y):
plt.text(c,w*1.03,str(round(w,4)))
plt.xlabel(nameX)
plt.ylabel(nameY)
plt.xlim(X[0]-1,X[-1]+1)
plt.xticks(X)
plt.ylim(0,1)
plt.title(nameY+" versus "+nameX)
plt.savefig(saveName)
return def RBFNetwork(k,gamma,trainArray,dataX):
num_data=dataX.shape[0]
predY=np.zeros((num_data,))
for n in range(num_data):
gaussianDistArray=np.exp(-gamma*np.sum((trainArray[:,:-1]-dataX[n,:])**2,axis=1))
id_list=np.argsort(gaussianDistArray,axis=0).tolist()[:k]
for i in id_list:
predY[n]+=trainArray[i,-1]
predY[n]=sign(predY[n])
return predY if __name__=="__main__": dataArray=ReadData("hw8_train.dat")
testArray=ReadData("hw8_test.dat")
k_list=[1,3,5,7,9]
ein_list=[]
eout_list=[]
for k in k_list:
predY=kNN(k,dataArray,dataArray[:,:-1])
ein_list.append(GetZeroOneError(predY,dataArray[:,-1]))
predY=kNN(k,dataArray,testArray[:,:-1])
eout_list.append(GetZeroOneError(predY,testArray[:,-1])) # 12
plot_bar_chart(k_list,ein_list,nameX="k",nameY="Ein(gk-nbor)",saveName="12.png") # 14
plot_bar_chart(k_list,eout_list,nameX='k',nameY="Eout(gk-bor)",saveName="14.png") gamma_list=[-3,-1,0,1,2]
ein_list=[]
eout_list=[]
for gamma in gamma_list:
predY=RBFNetwork(dataArray.shape[0],10**gamma,dataArray,dataArray[:,:-1])
ein_list.append(GetZeroOneError(predY,dataArray[:,-1]))
predY=RBFNetwork(dataArray.shape[0],10**gamma,dataArray,testArray[:,:-1])
eout_list.append(GetZeroOneError(predY,testArray[:,-1])) # 16
plot_bar_chart(X=gamma_list,Y=ein_list,nameX="log10(gamma)",nameY="Ein(guniform)",saveName="16.png") # 18
plot_bar_chart(X=gamma_list,Y=eout_list,nameX="log10(gamma)",nameY="Eout(guniform)",saveName="18.png")
# kMeans.py
# coding:utf-8 from numpy import random
from kNN_RBFN import * def kMeans(t,k,dataArray):
num_data=dataArray.shape[0]
random.seed(t)
centreIDList=random.randint(0,num_data,k).tolist()
nowCentreArray=dataArray[centreIDList,:]
tmpCentreArray=np.array(nowCentreArray)
ein=1000000
nowEin=ein-1
dict={}
while(nowEin<ein):
ein=nowEin
dict = {}
for n in range(num_data):
distArray=np.sum((nowCentreArray-dataArray[n,:])**2,axis=1)
minID=np.argmin(distArray)
tmpCentreArray[minID]=(tmpCentreArray[minID]+dataArray[n,:])/2
try:
dict[minID].append(dataArray[n,:])
except:
dict[minID]=[]
dict[minID].append(dataArray[n,:])
nowCentreArray=np.array(tmpCentreArray)
nowEin=GetEin(nowCentreArray,dict)
return nowCentreArray,dict def GetEin(nowCentreArray,dict):
k=nowCentreArray.shape[0]
ein=0
for i in range(k):
if i not in dict.keys():
continue
data=np.array(dict[i])
ein+=np.average(np.sum((data-nowCentreArray[i])**2,axis=1))
return ein def plot_bar_chart(X,Y,nameX,nameY,saveName):
plt.figure(figsize=(10,6))
plt.bar(left=X,height=Y,width=0.8,align="center",yerr=0.000001)
for (c,w) in zip(X,Y):
plt.text(c,w*1.03,str(round(w,4)))
plt.xlabel(nameX)
plt.ylabel(nameY)
plt.xlim(X[0]-1,X[-1]+1)
plt.xticks(X)
plt.title(nameY+" versus "+nameX)
plt.savefig(saveName)
return if __name__=="__main__": dataArray=ReadData("hw8_nolabel_train.dat")
k_list=[2,4,6,8,10]
ein_list=[]
for k in k_list:
ein=0
for t in range(500):
nowCentreArray,dict=kMeans(t,k,dataArray)
ein+=GetEin(nowCentreArray,dict)
ein_list.append(ein/500) plot_bar_chart(k_list,ein_list,nameX="k",nameY="the average Ein over 500 experiments",saveName="20.png")

运行结果









机器学习技法笔记:Homework #8 kNN&RBF&k-Means相关习题的更多相关文章

  1. 机器学习技法笔记(2)-Linear SVM

    从这一节开始学习机器学习技法课程中的SVM, 这一节主要介绍标准形式的SVM: Linear SVM 引入SVM 首先回顾Percentron Learning Algrithm(感知器算法PLA)是 ...

  2. 机器学习十大算法之KNN(K最近邻,k-NearestNeighbor)算法

    机器学习十大算法之KNN算法 前段时间一直在搞tkinter,机器学习荒废了一阵子.如今想重新写一个,发现遇到不少问题,不过最终还是解决了.希望与大家共同进步. 闲话少说,进入正题. KNN算法也称最 ...

  3. 机器学习技法笔记:Homework #6 AdaBoost&Kernel Ridge Regression相关习题

    原文地址:http://www.jianshu.com/p/9bf9e2add795 AdaBoost 问题描述 程序实现 # coding:utf-8 import math import nump ...

  4. 机器学习技法笔记:Homework #5 特征变换&Soft-Margin SVM相关习题

    原文地址:https://www.jianshu.com/p/6bf801bdc644 特征变换 问题描述 程序实现 # coding: utf-8 import numpy as np from c ...

  5. 机器学习技法笔记:Homework #7 Decision Tree&Random Forest相关习题

    原文地址:https://www.jianshu.com/p/7ff6fd6fc99f 问题描述 程序实现 13-15 # coding:utf-8 # decision_tree.py import ...

  6. 机器学习技法笔记:14 Radial Basis Function Network

    Roadmap RBF Network Hypothesis RBF Network Learning k-Means Algorithm k-Means and RBF Network in Act ...

  7. 机器学习技法笔记:08 Adaptive Boosting

    Roadmap Motivation of Boosting Diversity by Re-weighting Adaptive Boosting Algorithm Adaptive Boosti ...

  8. 机器学习技法笔记:15 Matrix Factorization

    Roadmap Linear Network Hypothesis Basic Matrix Factorization Stochastic Gradient Descent Summary of ...

  9. 机器学习技法笔记:16 Finale

    Roadmap Feature Exploitation Techniques Error Optimization Techniques Overfitting Elimination Techni ...

随机推荐

  1. ()C#打印机

    System.Drawing.Printing下得用来完成打印功能 1.打印设置 2.页面设置 3.打印预览 4.打印

  2. USB3.0 对 2.4G WiFi 影响

    Intel的一篇白皮书<USB 3.0 Radio Frequency Interference Impact on 2.4 GHz Wireless Devices>中即清楚地指出,US ...

  3. 2019 ACM-ICPC 南京 现场赛 H. Prince and Princess

    题意 王子想要娶公主,但是需要完成一个挑战:在一些房间中找出公主在哪. 每个房间有一个人,他们彼此知道谁在哪个房间.可以问他们三种问题: 你是谁? 在某个房间是谁? 公主在哪个房间? 有三类人,一类一 ...

  4. 爬虫(五)—— selenium模块启动浏览器自动化测试

    目录 selenium模块 一.selenium介绍 二.环境搭建 三.使用selenium模块 1.使用chrome并设置为无GUI模式 2.使用chrome有GUI模式 3.查找元素 4.获取标签 ...

  5. Java不可变对象

    在创建状态后无法更改其状态的对象称为不可变对象.一个对象不可变的类称为不可变类.不变的对象可以由程序的不同区域共享而不用担心其状态改变. 不可变对象本质上是线程安全的. 示例 以下代码创建了不可变类的 ...

  6. Eureka 系列(04)客户端源码分析

    Eureka 系列(04)客户端源码分析 [TOC] 0. Spring Cloud 系列目录 - Eureka 篇 在上一篇 Eureka 系列(01)最简使用姿态 中对 Eureka 的简单用法做 ...

  7. axios interceptors 拦截 , 页面跳转, token 验证 Vue+axios实现登陆拦截,axios封装(报错,鉴权,跳转,拦截,提示)

    Vue+axios实现登陆拦截,axios封装(报错,鉴权,跳转,拦截,提示) :https://blog.csdn.net/H1069495874/article/details/80057107 ...

  8. 马士兵对话京东T6阿里P7(薪水):月薪5万,他为何要离职?

    马士兵大佬你知道吗? 你竟然不知道?你怎么可能不知道!你不知道是不可能的! 记得自己的第一行Java代码,你的Hello World是跟着谁学的吗?我的就是马士兵老师! 马士兵是唯一一个在当时讲课是让 ...

  9. 博客中引入了gitment评论系统

    官方github地址:https://github.com/imsun/gitment 官方中文说明地址:https://imsun.net/posts/gitment-introduction/ 官 ...

  10. Git中crlf自动转换的坑

    新上手一个项目,克隆了代码下来搭环境,一路坑.其中一个sh脚本执行不了,报IOException,java日志除了"找不到文件或文件夹"之外看不出任何信息,手动运行脚本才发现是脚本 ...