tensorflow kmeans 聚类
iris:
# -*- coding: utf-8 -*-
# K-means with TensorFlow
#----------------------------------
#
# This script shows how to do k-means with TensorFlow import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn import datasets
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale
from tensorflow.python.framework import ops
ops.reset_default_graph() sess = tf.Session() iris = datasets.load_iris() num_pts = len(iris.data)
num_feats = len(iris.data[0]) # Set k-means parameters
# There are 3 types of iris flowers, see if we can predict them
k=3
generations = 25 data_points = tf.Variable(iris.data)
cluster_labels = tf.Variable(tf.zeros([num_pts], dtype=tf.int64)) # Randomly choose starting points
rand_starts = np.array([iris.data[np.random.choice(len(iris.data))] for _ in range(k)]) centroids = tf.Variable(rand_starts) # In order to calculate the distance between every data point and every centroid, we
# repeat the centroids into a (num_points) by k matrix.
centroid_matrix = tf.reshape(tf.tile(centroids, [num_pts, 1]), [num_pts, k, num_feats])
# Then we reshape the data points into k (3) repeats
point_matrix = tf.reshape(tf.tile(data_points, [1, k]), [num_pts, k, num_feats])
distances = tf.reduce_sum(tf.square(point_matrix - centroid_matrix), axis=2) #Find the group it belongs to with tf.argmin()
centroid_group = tf.argmin(distances, 1) # Find the group average
def data_group_avg(group_ids, data):
# Sum each group
sum_total = tf.unsorted_segment_sum(data, group_ids, 3)
# Count each group
num_total = tf.unsorted_segment_sum(tf.ones_like(data), group_ids, 3)
# Calculate average
avg_by_group = sum_total/num_total
return(avg_by_group) means = data_group_avg(centroid_group, data_points) update = tf.group(centroids.assign(means), cluster_labels.assign(centroid_group)) init = tf.global_variables_initializer() sess.run(init) for i in range(generations):
print('Calculating gen {}, out of {}.'.format(i, generations))
_, centroid_group_count = sess.run([update, centroid_group])
group_count = []
for ix in range(k):
group_count.append(np.sum(centroid_group_count==ix))
print('Group counts: {}'.format(group_count)) [centers, assignments] = sess.run([centroids, cluster_labels]) # Find which group assignments correspond to which group labels
# First, need a most common element function
def most_common(my_list):
return(max(set(my_list), key=my_list.count)) label0 = most_common(list(assignments[0:50]))
label1 = most_common(list(assignments[50:100]))
label2 = most_common(list(assignments[100:150])) group0_count = np.sum(assignments[0:50]==label0)
group1_count = np.sum(assignments[50:100]==label1)
group2_count = np.sum(assignments[100:150]==label2) accuracy = (group0_count + group1_count + group2_count)/150. print('Accuracy: {:.2}'.format(accuracy)) # Also plot the output
# First use PCA to transform the 4-dimensional data into 2-dimensions
pca_model = PCA(n_components=2)
reduced_data = pca_model.fit_transform(iris.data)
# Transform centers
reduced_centers = pca_model.transform(centers) # Step size of mesh for plotting
h = .02 # Plot the decision boundary. For that, we will assign a color to each
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) # Get k-means classifications for the grid points
xx_pt = list(xx.ravel())
yy_pt = list(yy.ravel())
xy_pts = np.array([[x,y] for x,y in zip(xx_pt, yy_pt)])
mytree = cKDTree(reduced_centers)
dist, indexes = mytree.query(xy_pts) # Put the result into a color plot
indexes = indexes.reshape(xx.shape)
plt.figure(1)
plt.clf()
plt.imshow(indexes, interpolation='nearest',
extent=(xx.min(), xx.max(), yy.min(), yy.max()),
cmap=plt.cm.Paired,
aspect='auto', origin='lower') # Plot each of the true iris data groups
symbols = ['o', '^', 'D']
label_name = ['Setosa', 'Versicolour', 'Virginica']
for i in range(3):
temp_group = reduced_data[(i*50):(50)*(i+1)]
plt.plot(temp_group[:, 0], temp_group[:, 1], symbols[i], markersize=10, label=label_name[i])
# Plot the centroids as a white X
plt.scatter(reduced_centers[:, 0], reduced_centers[:, 1],
marker='x', s=169, linewidths=3,
color='w', zorder=10)
plt.title('K-means clustering on Iris Dataset\n'
'Centroids are marked with white cross')
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='lower right')
plt.show()

tensorflow kmeans 聚类的更多相关文章
- 用 TensorFlow 实现 k-means 聚类代码解析
k-means 是聚类中比较简单的一种.用这个例子说一下感受一下 TensorFlow 的强大功能和语法. 一. TensorFlow 的安装 按照官网上的步骤一步一步来即可,我使用的是 virtua ...
- K-Means 聚类算法
K-Means 概念定义: K-Means 是一种基于距离的排他的聚类划分方法. 上面的 K-Means 描述中包含了几个概念: 聚类(Clustering):K-Means 是一种聚类分析(Clus ...
- 用scikit-learn学习K-Means聚类
在K-Means聚类算法原理中,我们对K-Means的原理做了总结,本文我们就来讨论用scikit-learn来学习K-Means聚类.重点讲述如何选择合适的k值. 1. K-Means类概述 在sc ...
- K-Means聚类算法原理
K-Means算法是无监督的聚类算法,它实现起来比较简单,聚类效果也不错,因此应用很广泛.K-Means算法有大量的变体,本文就从最传统的K-Means算法讲起,在其基础上讲述K-Means的优化变体 ...
- K-means聚类算法
聚类分析(英语:Cluster analysis,亦称为群集分析) K-means也是聚类算法中最简单的一种了,但是里面包含的思想却是不一般.最早我使用并实现这个算法是在学习韩爷爷那本数据挖掘的书中, ...
- k-means聚类算法python实现
K-means聚类算法 算法优缺点: 优点:容易实现缺点:可能收敛到局部最小值,在大规模数据集上收敛较慢使用数据类型:数值型数据 算法思想 k-means算法实际上就是通过计算不同样本间的距离来判断他 ...
- K-Means 聚类算法原理分析与代码实现
前言 在前面的文章中,涉及到的机器学习算法均为监督学习算法. 所谓监督学习,就是有训练过程的学习.再确切点,就是有 "分类标签集" 的学习. 现在开始,将进入到非监督学习领域.从经 ...
- Kmeans聚类算法原理与实现
Kmeans聚类算法 1 Kmeans聚类算法的基本原理 K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一.K-means算法的基本思想是:以空间中k个点为中心进行聚类,对 ...
- 机器学习六--K-means聚类算法
机器学习六--K-means聚类算法 想想常见的分类算法有决策树.Logistic回归.SVM.贝叶斯等.分类作为一种监督学习方法,要求必须事先明确知道各个类别的信息,并且断言所有待分类项都有一个类别 ...
随机推荐
- MFC中几个函数的使用
1.GetDlgItem() CWnd* GetDlgItem ( int nID ) const;这个就足够了(在MFC中经常这么用),如果你是在win32API下面写的话,那么一般创建一个窗口 ...
- Linux禁止Ctrl+Alt+Del重新启动
方法1:改动/etc/inittab 屏蔽 ca:12345:ctrlaltdel:/sbin/shutdown -t1 -a -r now 或者删除改行内容 保存退出 适用对象:RedHat4.8 ...
- MySql(八):MySQL性能调优——Query 的优化
一.理解MySQL的Query Optimizer MySQL Optimizer是一个专门负责优化SELECT 语句的优化器模块,它主要的功能就是通过计算分析系统中收集的各种统计信息,为客户端请求的 ...
- 从0开始学习 GitHub 系列汇总笔记
本文学习自Stromzhang, 原文地址请移步:从0开始学习 GitHub 系列汇总 我的笔记: 0x00 从0开始学习GitHub 系列之[初识GitHub] GitHub 影响力 a.全球顶级 ...
- Hibernate demo之使用注解
1.新建maven项目 testHibernate,pom.xml <?xml version="1.0" encoding="UTF-8"?> & ...
- XSS前置课程--同源策略
什么是同源策略: 在用户浏览互联网中的网页的过程中,身份和权限的思想是贯穿始终的 同源策略(Same-Origin Policy),就是为了保证互联网之中,各类资源的安全性而诞生的产物,它实际上是一个 ...
- 30天自制操作系统(三)进入32位模式并导入C语言
1 制作真正的IPL IPL(Initial Program Loader),启动程序装载器,但是之前并没有实质性的装载任何程序,这次作者要开始装载程序了. 虽然现在开发的操作系统啥功能也没有,作者说 ...
- 给定一个递增序列,a1 <a2 <...<an 。定义这个序列的最大间隔为d=max{ai+1 - ai }(1≤i<n),现在要从a2 ,a3 ..an-1 中删除一个元素。问剩余序列的最大间隔最小是多少?
// ConsoleApplication5.cpp : 定义控制台应用程序的入口点. // #include "stdafx.h" #include<vector> ...
- Spark源码分析之五:Task调度(一)
在前四篇博文中,我们分析了Job提交运行总流程的第一阶段Stage划分与提交,它又被细化为三个分阶段: 1.Job的调度模型与运行反馈: 2.Stage划分: 3.Stage提交:对应TaskSet的 ...
- Android错误之Location of the Android SDK has not been setup in the preferences
解决的方法:打开Help-Install new software,更新文件就可以,这时国内的朋友就须要FQ了,详细有代理,能够网上自行搜索.