kMeans算法原理见我的上一篇文章。这里介绍K-Means的Java实现方法,参考了Python的实现方法。

一、数据点的实现

package com.meachine.learning.kmeans;

import java.util.ArrayList;

/**
* 数据点,有n维数据
*
*/
public class Point {
private static int num;
private int id;
private int dimensioNum; // 维度
private ArrayList<Double> values;
private int clusterId = -1;
private double minDist = Integer.MAX_VALUE; public Point() {
id = ++num;
values = new ArrayList<>();
} public void add(double e) {
values.add(e);
dimensioNum++;
}
//------set与get省略----------
}

二、数据簇的实现

package com.meachine.learning.kmeans;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.Setter;
import lombok.ToString; /**
* 簇<br>
* 数据集合的基本信息
*
*/
public class Cluster {
// 簇id
private int clusterId;
// 属于该簇的点的个数
private int numOfPoints;
// 簇中心点的信息
private Point center; public Cluster(int id) {
this.clusterId = id;
numOfPoints = 0;
} public Cluster(int id, Point center) {
this.clusterId = id;
this.center = center;
}
//----------set与get省略----------------
}

三、计算数据点距离

package com.meachine.learning.kmeans;

import java.util.List;

/**
* 计算距离接口
*
*/
public interface IDistance<T> {
public double getDis(List<T> p1, List<T> p2);
}

  

package com.meachine.learning.kmeans;

import java.util.List;

/**
* 欧式距离
*
*/
public class OujilidDistance<T extends Number> implements IDistance<T> { public double getDis(List<T> a, List<T> b) {
if (a.size() != b.size()) {
throw new IllegalArgumentException("Size not compatible!");
}
double result = 0;
for (int i = 0; i < a.size(); i++) {
result += Math.pow((a.get(i).doubleValue() - b.get(i).doubleValue()), 2);
}
return Math.sqrt(result);
} }

四、K-Means算法

  

package com.meachine.learning.kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random; /**
* K-Means算法
*
* @author Cang
*
*/
public class KMeans {
// 簇的个数
private int k;
// 维度,即多少个变量
private int dimensioNum;
// 最大迭代次数
private int maxItrNum = 100;
private IDistance<Double> distance;
private List<Point> points;
private List<Cluster> clusters = new ArrayList<Cluster>();
private String dataFileName = "D:/testSet.txt"; public KMeans(int k) {
this.k = k;
} /**
* 初始化数据
*/
public void init() {
points = loadDataSet(dataFileName);
distance = new OujilidDistance<Double>();
initCluster();
} /**
* 加载数据集
*
* @param fileName
* @return
*/
private List<Point> loadDataSet(String fileName) {
List<Point> points = new ArrayList<>();
File file = new File(fileName);
BufferedReader reader = null;
try {
reader = new BufferedReader(new FileReader(file));
String tempString = null;
int i = 0;
while ((tempString = reader.readLine()) != null) {
Point point = new Point();
dimensioNum = tempString.split("\t").length;
for (String data : tempString.split("\t")) {
point.add(Double.parseDouble(data));
}
points.add(point);
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
}
return points;
} /**
* 初始化簇中心
*
* @return
*/
private void initCluster() {
Random ran = new Random();
int id = 0;
while (id < k) {
Cluster c = new Cluster(++id);
int temp = ran.nextInt(points.size());
c.setCenter(points.get(temp));
clusters.add(c);
}
} /**
* kMeans 具体算法
*/
public void clustering() {
boolean finished = false;
int count = 0;
while (!finished) {
// 寻找最近的中心
finished = true;
for (Point point : points) {
for (Cluster cluster : clusters) { double minLen = distance.getDis(cluster.getCenter().getValues(),
point.getValues());
// 更新最小距离
if (minLen < point.getMinDist()) {
if (cluster.getClusterId() != point.getClusterId()) {
finished = false;
point.setClusterId(cluster.getClusterId());
}
point.setMinDist(minLen);
}
}
}
System.out.println("Cluster center info:");
for (Cluster string : clusters) {
System.out.println(string.getCenter().getValues());
}
// 更改中心的位置
changeCentroids();
// 超过循环次数,则跳出循环
if (++count > maxItrNum) {
finished = true;
}
}
} /**
* 改变簇中心
*/
private void changeCentroids() {
for (Cluster cluster : clusters) {
ArrayList<Double> newCenterValue = new ArrayList<Double>();
Point newCenterPoint = new Point();
double result = 0;
for (int i = 0; i < dimensioNum; i++) {
for (Point point : points) {
if (point.getClusterId() == cluster.getClusterId()) {
result += point.getValues().get(i);
}
}
newCenterValue.add(result / points.size());
}
newCenterPoint.setClusterId(cluster.getClusterId());
newCenterPoint.setValues(newCenterValue);
cluster.setCenter(newCenterPoint);
}
} public static void main(String[] args) {
KMeans kmeans = new KMeans(4);
kmeans.init();
kmeans.clustering();
}
}

  

K-Means 算法(Java)的更多相关文章

  1. k近邻算法-java实现

    最近在看<机器学习实战>这本书,因为自己本身很想深入的了解机器学习算法,加之想学python,就在朋友的推荐之下选择了这本书进行学习. 一 . K-近邻算法(KNN)概述 最简单最初级的分 ...

  2. KNN 与 K - Means 算法比较

    KNN K-Means 1.分类算法 聚类算法 2.监督学习 非监督学习 3.数据类型:喂给它的数据集是带label的数据,已经是完全正确的数据 喂给它的数据集是无label的数据,是杂乱无章的,经过 ...

  3. K-means算法

    K-means算法很简单,它属于无监督学习算法中的聚类算法中的一种方法吧,利用欧式距离进行聚合啦. 解决的问题如图所示哈:有一堆没有标签的训练样本,并且它们可以潜在地分为K类,我们怎么把它们划分呢?  ...

  4. k近邻算法的Java实现

    k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系.输入没有标签的新数据之后, ...

  5. KNN算法java实现代码注释

    K近邻算法思想非常简单,总结起来就是根据某种距离度量检测未知数据与已知数据的距离,统计其中距离最近的k个已知数据的类别,以多数投票的形式确定未知数据的类别. 一直想自己实现knn的java实现,但限于 ...

  6. Floyd算法java实现demo

    Floyd算法java实现,如下: https://www.cnblogs.com/Halburt/p/10756572.html package a; /** * ┏┓ ┏┓+ + * ┏┛┻━━━ ...

  7. k-means算法Java一维实现

    这里的程序稍微有点变形.k_means方法返回K-means聚类的若干中心点.代码: import java.util.ArrayList; import java.util.Collections; ...

  8. 感知机学习算法Java实现

    感知机学习算法Java实现. Perceptron类用于实现感知机, 其中的perceptronOriginal()方法用于实现感知机学习算法的原始形式: perceptronAnother()方法用 ...

  9. 一致哈希算法Java实现

    一致哈希算法(Consistent Hashing Algorithms)是一个分布式系统中经常使用的算法. 传统的Hash算法当槽位(Slot)增减时,面临全部数据又一次部署的问题.而一致哈希算法确 ...

  10. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

随机推荐

  1. Android 菜单动态变化【添加或去除】

    <menu xmlns:android="http://schemas.android.com/apk/res/android"> <group android: ...

  2. redis数据持久化(快照/日志):

    1.RDB快照的配置选项: save // 900内,有1条写入,则产生快照 save // 如果300秒内有1000次写入,则产生快照 save // 如果60秒内有10000次写入,则产生快照 ( ...

  3. 微信小程序 --- app.js文件

    app.js文件是项目的入口文件: //app.js App({ onLaunch: function () { // 展示本地存储能力 var logs = wx.getStorageSync('l ...

  4. scikit_learn 中文说明入门

    原文:http://www.cnblogs.com/taceywong/p/4568806.html 原文地址:http://scikit-learn.org/stable/tutorial/basi ...

  5. 【MySQL】为什么不要问我DB极限QPS/TPS

    为什么不要问我DB极限QPS/TPS 背景 相信很多开发都会有这个疑问,DB到底可以支撑多大的业务量,如何去评估?对于这个很专业的问题,DBA也没有办法直接告诉你,更多的都是靠经验提供一个看似靠谱的结 ...

  6. 磁盘 I/O 性能监控的指标

    指标 1:每秒 I/O 数(IOPS 或 tps) 对于磁盘来说,一次磁盘的连续读或者连续写称为一次磁盘 I/O, 磁盘的 IOPS 就是每秒磁盘连续读次数和连续写次数之和.当传输小块不连续数据时,该 ...

  7. Django - Cookie、Session、自定义分页和Django分页器

    2. 今日内容 https://www.cnblogs.com/liwenzhou/p/8343243.html 1. Cookie和Session 1. Cookie 服务端: 1. 生成字符串 2 ...

  8. Girls and Boys---hdu1068(最大独立集=顶点数-最大匹配)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=1068 题意:有n个人,他们之间存在着恋爱关系,现在告诉你每个人和其他人的关系,然后要从这n个人间选出尽 ...

  9. Key Set---hud5363(快速幂)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=5363 #include <iostream> #include <cstdlib&g ...

  10. 【Maven学习】Nexus OSS私服仓库的安装和配置

    背景 公司的代码依赖是通过Maven进行管理的,而Maven的私库我们使用的是Nexus,目前使用的版本是Nexus Repository Manager OSS 2.12.1. 但是由于之前我们搭建 ...