TensorFlow高层次机器学习API (tf.contrib.learn)
TensorFlow高层次机器学习API (tf.contrib.learn)
1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据
2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)
3.classifer.fit 训练模型
4.classifier.evaluate 评价模型
5.classifier.predict 预测新样本
完整代码:

1 from __future__ import absolute_import
2 from __future__ import division
3 from __future__ import print_function
4
5 import tensorflow as tf
6 import numpy as np
7
8 # Data sets
9 IRIS_TRAINING = "iris_training.csv"
10 IRIS_TEST = "iris_test.csv"
11
12 # Load datasets.
13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
14 filename=IRIS_TRAINING,
15 target_dtype=np.int,
16 features_dtype=np.float32)
17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
18 filename=IRIS_TEST,
19 target_dtype=np.int,
20 features_dtype=np.float32)
21
22 # Specify that all features have real-value data
23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
24
25 # Build 3 layer DNN with 10, 20, 10 units respectively.
26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
27 hidden_units=[10, 20, 10],
28 n_classes=3,
29 model_dir="/tmp/iris_model")
30
31 # Fit model.
32 classifier.fit(x=training_set.data,
33 y=training_set.target,
34 steps=2000)
35
36 # Evaluate accuracy.
37 accuracy_score = classifier.evaluate(x=test_set.data,
38 y=test_set.target)["accuracy"]
39 print('Accuracy: {0:f}'.format(accuracy_score))
40
41 # Classify two new flower samples.
42 new_samples = np.array(
43 [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
44 y = list(classifier.predict(new_samples, as_iterable=True))
45 print('Predictions: {}'.format(str(y)))

结果:
Accuracy:0.966667
TensorFlow高层次机器学习API (tf.contrib.learn)的更多相关文章
- TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用
一.TensorFlow高层次机器学习API (tf.contrib.learn) 1.tf.contrib.learn.datasets.base.load_csv_with_header 加载cs ...
- tf.contrib.learn.preprocessing.VocabularyProcessor()
tf.contrib.learn.preprocessing.VocabularyProcessor (max_document_length, min_frequency=0, vocabulary ...
- tensorflow中slim模块api介绍
tensorflow中slim模块api介绍 翻译 2017年08月29日 20:13:35 http://blog.csdn.net/guvcolie/article/details/77686 ...
- tf.contrib.layers.fully_connected参数笔记
tf.contrib.layers.fully_connected 添加完全连接的图层. tf.contrib.layers.fully_connected( inputs, num_ou ...
- TensorFlow——tf.contrib.layers库中的相关API
在TensorFlow中封装好了一个高级库,tf.contrib.layers库封装了很多的函数,使用这个高级库来开发将会提高效率,卷积函数使用tf.contrib.layers.conv2d,池化函 ...
- tensorflow笔记3:CRF函数:tf.contrib.crf.crf_log_likelihood()
在分析训练代码的时候,遇到了,tf.contrib.crf.crf_log_likelihood,这个函数,于是想简单理解下: 函数的目的:使用crf 来计算损失,里面用到的优化方法是:最大似然估计 ...
- tensorflow教程:tf.contrib.rnn.DropoutWrapper
tf.contrib.rnn.DropoutWrapper Defined in tensorflow/python/ops/rnn_cell_impl.py. def __init__(self, ...
- TensorFlow中的L2正则化函数:tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()的用法与异同
tf.nn.l2_loss()与tf.contrib.layers.l2_regularizerd()都是TensorFlow中的L2正则化函数,tf.contrib.layers.l2_regula ...
- 关于tensorflow里面的tf.contrib.rnn.BasicLSTMCell 中num_units参数问题
这里的num_units参数并不是指这一层油多少个相互独立的时序lstm,而是lstm单元内部的几个门的参数,这几个门其实内部是一个神经网络,答案来自知乎: class TRNNConfig(obje ...
随机推荐
- linux下的静态库创建与查看,及如何查看某个可执行依赖于哪些动态库
linux下的静态库创建与查看,及如何查看某个可执行依赖于哪些动态库 创建静态库:ar -rcs test.a *.o查看静态库:ar -tv test.a解压静态库:ar -x test.a 查 ...
- hdoj--1087--Super Jumping! Jumping! Jumping!(贪心)
Super Jumping! Jumping! Jumping! Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/32768 ...
- 2016 提高组c++ 错题
需重做 树的重心 链表 计算机基础知识 无线通讯技术: 蓝牙,wifi,GPRS 现在常用的无线通信技术:FM调频广播(用于收音机): 2G.3G移动通信技术(中国移动.中国联通.中国电信正在运营的网 ...
- 没有被广泛采用的box-sizing属性
在标准盒模型下设置的width和height只是内容的宽和高,但在设置了宽和高的情况下若还要设置border.margin.padding等时,会发生溢出的现象,因此需要将盒模型更改. box-siz ...
- Web api 测试 工具WebApiTestClient
1.打开Nuget 安装 WebApiTestClient 2.在HelpPageConfig.cs 里面添加这段文字 config.SetDocumentationProvider(new Xml ...
- 跟着8张思维导图学习javascript (转)
学习的道路就是要不断的总结归纳,好记性不如烂笔头,so,下面将po出8张javascript相关的思维导图. 思维导图小tips:思维导图又叫心智图,是表达发射性思维的有效的图形思维工具 ,它简单却又 ...
- Python更换pip源,更换conda源
更换pip源: 1.在windows文件管理器中,输入 %APPDATA% 2.在该目录下新建pip文件夹,然后到pip文件夹里面去新建个pip.ini文件 3.在新建的pip.ini文件中输入以下内 ...
- 使用node+mysql进行后端开发
使用koa: koa2是一个类,所以引入koa后,要创建实例化“对象”,才能使用koa内部封装的方法. 设置监听端口: 处理http请求: 1.http请求处理链 A.通过app.use()注册asy ...
- HDU 2303 The Embarrassed Cryptographer
The Embarrassed Cryptographer 题意 给一个两个素数乘积(1e100)K, 给以个数L(1e6), 判断K的两个素数是不是都大于L 题解 对于这么大的范围,素数肯定是要打表 ...
- MySQL 面试题(一)
原文地址:http://www.2cto.com/database/201311/254385.html 作者:黄杉(红黑联盟) 公司招聘MySQL DBA面试心得 1 2年MySQL DBA经 ...