机器学习tensorflow框架初试
本文来自网易云社区
作者:汪洋
前言
新手学习可以点击参考Google的教程。开始前,我们先在本地安装好 TensorFlow机器学习框架。
首先我们在本地window下安装好python环境,约定安装3.6版本;
安装Anaconda工具集后,创建名为 tensorflow 的conda 环境:conda create -n tensorflow pip python=3.6;
conda切换环境:activate tensorflow;
我们安装支持CPU的TensorFlow版本(快速):pip install --ignore-installed --upgrade tensorflow;
最后验证安装是否成功,进入 python dos命名,输入以下代码校验:
import tensorflow as tf
hello = tf.constant('Hello, TensorFlow')
sess = tf.Session()
print(sess.run(hello))输出Hello, TensorFlow,表示成功了。如果失败的话,就选择低版本重新安装如:pip install --ignore-installed --upgrade tensorflow==1.5.0。
其它安装方式点击参考教程。
监督学习实践
官方针对新手演示了一个入门示例,点击教程可查看,本文就围绕这个教程分享。
1.分类
官方示例里讲解了分类鸢尾花问题的解决,我们想到的就是用监督学习训练机器模型。采用这种学习方式后,我们需要确定用鸢尾花的哪些特征来分类,鸢尾花的特征还是蛮多的,官方示例里用的是花萼和花瓣的长度和宽度。
鸢尾花种类非常多,官方也仅是针对三种进行分类:
expected = ['Setosa', 'Versicolor', 'Virginica']
接下来就是获取大量数据,进行预处理,官方示例里直接引用了他人整理的数据源,省略了前期数据处理步骤,前5条数据结构如下:
SepalLength | SepalWidth | PetalLength | PetalWidth | Species | |
---|---|---|---|---|---|
0 | 6.4 | 2.8 | 5.6 | 2.2 | 2 |
1 | 5.0 | 2.3 | 3.3 | 1.0 | 1 |
2 | 4.9 | 2.5 | 4.5 | 1.7 | 2 |
3 | 4.9 | 3.1 | 1.5 | 0.1 | 0 |
4 | 5.7 | 3.8 | 1.7 | 0.1 | 0 |
说明:
最后一列代表着鸢尾花的品种,也就是说它是监督学习中的标签;
中间四列从左到右表示花萼的长度和宽度、花瓣的长度和宽度;
表格数据代表了从120个样本的数据集中抽集的5个样本;
机器学习一般依赖数值,因此当前数据集中标签值都为数字,对应关系:
0 | 1 | 2 |
---|---|---|
Setosa | Versicolor | Virginica |
接下来将编写代码,先复习下概念,模型指特征和标签之间的关系;训练指机器学习阶段,这个阶段模型不断优化。示例里选择的监督试学习方式,模型通过包含标签的样本进行训练。
2. 导入和解析数据集
首先我们要获取训练集和测试集,其中训练集是训练模型的样本,测试集是评估训练后模型效果的样本。
首先设置我们选择的数据集地址
"""训练集"""TRAN_URL = "http://download.tensorflow.org/data/iris_training.csv""""测试集"""TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"
使用tensorflow.keras.utils.get_file函数下载数据集,该方法第一个参数为文件名称,第二个参数为下载地址,点击查看详细)。
import tensorflow as tfdef download():
train_path = tf.keras.utils.get_file('iris_training.csv', TRAN_URL)
test_path = tf.keras.utils.get_file('iris_test.csv', TEST_URL) return train_path, test_path
然后用pandas.read_csv函数解析下载的数据,解析后生成的格式是一个表格,然后再分成特征列表和标签列表,返回训练集和测试集
import pandas as pd
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']def load_data(y_species='Species'):
train_path, test_path = download()
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
train_x, train_y = train, train.pop(y_species) test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
test_x, test_y = test, test.pop(y_species) return (train_x, train_y), (test_x, test_y)
3. 特征列-数值列
我们已经获取到数据集,在tensorflow中需要将数据转换为模型(Estimator)可以使用的数据结构,这时候调用tf.feature_column模块中的函数来转换。鸢尾花例子中,需将特征数据转换为浮点数,调用tf.feature_column.numeric_column方法。
import iris_data (train_x, train_y), (test_x, test_y) = iris_data.load_data()
my_feature_columns = []for key in train_x.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
其中key是 ['SepalLength' , 'SepalWidth' , 'PetalLength' , 'PetalWidth'] 其中之一。
4. 模型选择
官方例子中选择全连接神经网络解决鸢尾花问题,用神经网络来发现特征与标签之间的复杂关系。tensorflow中,通过实例化一个Estimator类指定模型类型,这里我们使用官方提供的预创建的Estimator类,tf.estimator.DNNClassifier,此Estimator会构建一个对样本进行分类的神经网络。
classifier = tf.estimator.DNNClassifier(
feature_columns = my_feature_columns,
hidden_units = [10,10],
n_classes = 3)
feature_columns 参数指训练的特征列(这里是数值列);
hidden_units 参数定义神经网络内每个隐藏层中的神经元数量,这里设置了2个隐藏层,每个隐藏层中神经元数量都是10个;
n_classes 参数表示要预测的标签数量,这里我们需要预测3个品种;
其它参数点击查看
5. 训练模型
上一步我们已经创建了一个学习模型,接下来将数据导入到模型中进行训练。tensorflow中,调用Estimator对象的train方法训练。
classifier.train(
input_fn = lambda:iris_data.train_input_fn(train_x, train_y, 100)
steps = 1000)
input_fn 参数表示提供训练数据的函数; steps 参数表示训练迭代次数;
在train_input_fn函数里,我们将数据转换为 train方法所需的格式。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
为了保证训练效果,训练样本需随机排序。buffer_size 设置大于样本数(120),可确保数据得到充分的随机化处理。
dataset = dataset.shuffle(1000)
为了保证训练期间,有无限量的训练样本,需调用 tf.data.Dataset.repeat。
dataset = dataset.repeat()
train方法一次处理一批样本, tf.data.Dataset.batch 方法通过组合多个样本创建一个批次,这里组合多个包含100个样本的批次。
dataset = dataset.batch(100)
6. 模型评估
接下来我们将训练好的模型预测效果。tensorflow中,每个Estimator对象提供了evaluate方法。
eval_result = classifier.evaluate(
input_fn = lambda:iris_data.eval_input_fn(test_x, test_y, 100)
)
在eval_input_fn函数里,我们将数据转换为 evaluate方法所需的格式。实现跟训练一样,只是无需随机化处理和无限量重复使用测试集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
dataset.batch(100);return dataset
7. 预测
接下来将该模型对无标签样本进行预测。官方手动提供了三个无标签样本。
predict_x = { 'SepalLength': [5.1, 5.9, 6.9], 'SepalWidth': [3.3, 3.0, 3.1], 'PetalLength': [1.7, 4.2, 5.4], 'PetalWidth': [0.5, 1.5, 2.1],
}
tensorflow中,每个Estimator对象提供了predict方法。
predictions = classifier.predict(
input_fn = lambda:iris_data.eval_input_fn(predict_x, labels=None, 100)
)
改造下eval_input_fn方法,使其能够接受 labels = none 情况
features=dict(features)if labels is None:
inputs = featureselse:
inputs = (features, labels)
dataset = tf.data.Dataset.from_tensor_slices(inputs)
接下来打印下预测结果, predictions 中 class_ids表示可能性最大的品种,probabilities 表示每个品种的概率
for pred_dict in predictions:
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print(class_id, probability)
结果如下:
0 | 0.99706334 |
1 | 0.997407 |
2 | 0.97377485 |
结尾
通过官方例子,新手可初步了解其使用,当然更深入的使用还得学习理论和多使用API。本文是根据官方例子,作为新手重新梳理了一遍。
网易云免费体验馆,0成本体验20+款云产品
更多网易研发、产品、运营经验分享请访问网易云社区。
相关文章:
【推荐】 云计算交互设计师的正确出装姿势
机器学习tensorflow框架初试的更多相关文章
- python机器学习TensorFlow框架
TensorFlow框架 关注公众号"轻松学编程"了解更多. 一.简介 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运 ...
- TensorFlow框架(3)之MNIST机器学习入门
1. MNIST数据集 1.1 概述 Tensorflow框架载tensorflow.contrib.learn.python.learn.datasets包中提供多个机器学习的数据集.本节介绍的是M ...
- TensorFlow框架(5)之机器学习实践
1. Iris data set Iris数据集是常用的分类实验数据集,由Fisher, 1936收集整理.Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集.数据集包含150个数据集,分为3类, ...
- 【TensorFlow篇】--Tensorflow框架初始,实现机器学习中多元线性回归
一.前述 TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着基于数据流图的计算,T ...
- 人工智能 tensorflow框架-->简介及安装01
简介:Tensorflow是google于2015年11月开源的第二代机器学习框架. Tensorflow名字理解:图形边中流动的数据叫张量(Tensor),因此叫Tensorflow 既 张量流动 ...
- (第二章第二部分)TensorFlow框架之读取图片数据
系列博客链接: (第二章第一部分)TensorFlow框架之文件读取流程:https://www.cnblogs.com/kongweisi/p/11050302.html 本文概述: 目标 说明图片 ...
- 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集
一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...
- .NET数据挖掘与机器学习开源框架
1. 数据挖掘与机器学习开源框架 1.1 框架概述 1.1.1 AForge.NET AForge.NET是一个专门为开发者和研究者基于C#框架设计的,他包括计算机视觉与人工智能,图像处理,神经 ...
- 跟我学算法-吴恩达老师(超参数调试, batch归一化, softmax使用,tensorflow框架举例)
1. 在我们学习中,调试超参数是非常重要的. 超参数的调试可以是a学习率,(β1和β2,ε)在Adam梯度下降中使用, layers层数, hidden units 隐藏层的数目, learning_ ...
随机推荐
- 计算结构体、数组、指针的sizeof
1. 结构体的sizeof 题目: sturct aa{ in num; char name[10];}; struct bb{ int a; float b; struct aa c;}; stru ...
- centos7.3上用源代码安装zabbix3.2.7
安装zabbix之前请自行先搭建好LAMP环境! 1.下载源码安装包并解压 1.1 下载 [root@nmserver- ~]# mkdir zabbix [root@nmserver- ~]# cd ...
- golang实现文件上传权限验证(超简单)
Go语言创建web server非常简单,部署也很容易,不像IIS.Apache等那么重量级,需要各种依赖.配置.一些功能单一的web 服务,用Go语言开发特别适合.http文件上传下载服务,在很多地 ...
- framework7中一行的字如果过多就省略号显示的CSS写法
.order-info-title { text-overflow: ellipsis !important; white-space: nowrap !important; overflow: hi ...
- 剑指offer28 字符串的排列
1.全局变量可以在最后去定义并初始化,不一定非要在开头 2.此题有一种特殊情况需要考虑,比如字符串是“aa”,那输出应该是“aa”,而不是“aa,aa”,即相同的不输出.实现这个处理用了c++中的容器 ...
- Mybatis查询报错:There is no getter for property named '*' in 'class java.lang.String
问题: 执行查询时报错:There is no getter for property named '*' in 'class java.lang.String 原因: 传过去的参数为识别.本例为 p ...
- Webpack4 学习笔记四 暴露全局变量、externals
前言 此内容是个人学习笔记,以便日后翻阅.非教程,如有错误还请指出 webpack 暴露全局变量 通过 expose-loader 内联配置 在 webpack中配置 每个模块通过注入的方式 通过CD ...
- 洛谷题解:P1209 【[USACO1.3]修理牛棚 Barn Repair】
原题传送门:https://www.luogu.org/problemnew/show/P1209 首先,这是一道贪心题. 我们先来分析它的贪心策略. 例如,样例: 4 50 18 3 4 6 ...
- Cannot resolve reference to bean 'sessionFactory' while setting bean property 'sessionFactory'; 没有sessionFactory
maven子项目spring配置文件创建bean 没有找到另一个子项目中的bean. 需要引入另一个子项目的配置文件,仅提供测试用 如下: <!-- 仅供测试用 --> <impor ...
- PXE+DHCP+TFTP+Cobbler 无人值守安装centos 7
Cobbler(补鞋匠)是通过将DHCP.TFTP.DNS.HTTP等服务进行集成,创建一个中央管理节点,其可以实现的功能有配置服务,创建存储库,解压缩操作系统媒介,代理或集成一个配置管理系统,控制电 ...