这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. ubuntu16.04下zabbix安装和配置

    介绍 Zabbix是用于网络和应用的开源监控软件. 它提供从服务器,虚拟机和任何其他类型的网络设备收集的数千个度量的实时监控. 这些指标可以帮助您确定IT基础架构的当前运行状况,并在客户投诉之前检测硬 ...

  2. 深度:Hadoop对Spark五大维度正面比拼报告!

    每年,市场上都会出现种种不同的数据管理规模.类型与速度表现的分布式系统.在这些系统中,Spark和hadoop是获得最大关注的两个.然而该怎么判断哪一款适合你? 如果想批处理流量数据,并将其导入HDF ...

  3. img图片加载出错处理(转载)

    为了美观当网页图片不存在时不显示叉叉图片当在页面显示的时候,万一图片被移动了位置或者丢失的话,将会在页面显示一个带X的图片,很是影响用户的体验.即使使用alt属性给出了”图片XX”的提示信息,也起不了 ...

  4. AngularJs的ng-include的使用与实现

    想在angularjs动态加载一个内容,我们可以使用ng-include来实现. 今天Insus.NET就在ASP.NET MVC环境中,举个例子来演示它的功能. 你可以在一个视图动态加载任一其它视图 ...

  5. .net获取excel表的内容(OleDB方法)

    首先引用组件和命名空间 using Microsoft.Office.Interop.Excel; using System.Data.OleDb; 然后把excel上传到指定路径 上传文件方法省略 ...

  6. Codeforces 718C solution

    C. Sasha and Array   time limit per test :  5 seconds memory limit per test :  256 megabytes Descrip ...

  7. python 的zip 函数小例子

    In [57]: name = ('Tome','Rick','Stephon') In [58]: age = (45,23,55) In [59]: for a,n in zip (name,ag ...

  8. for循环两个略骚的写法

    骚写法 或许你知道,总之我觉得很酷,希望你也这么认为. 递增遍历 最常见场景,从 0 到 10 的遍历,不输出 10: for(let i = -1; ++i < 10;) { console. ...

  9. A. Make a triangle!

    题意 给你三条边a,b,c问使得构成三角形,需要增加的最少长度是多少 思路 数学了啦 代码 #include<bits/stdc++.h> using namespace std; #de ...

  10. Linux内核设计第十七章笔记

    第十七章 设备与模块 关于设备驱动和设备管理,四种内核成分 设备类型:在所有unix系统中为了统一普通设备的操作所采用的分类 模块:Linux内核中用于按需加载和卸载目标代码的机制 内核对象:内核数据 ...