这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. vue学习路由嵌套

    1. 路由嵌套和参数传递 传参的两种形式: a.查询字符串:login?name=tom&pwd=123 {{$route.query}} ------ <li><route ...

  2. ;,&,&&,shell,区别

    command1&command2&command3     三个命令同时执行 command1;command2;command3     不管前面命令执行成功没有,后面的命令继续执 ...

  3. AMD和CMD规范

    1.名词解释AMD:Asynchronous Modules Definition异步模块定义,提供定义模块及异步加载该模块依赖的机制.CMD:Common Module Definition 通用模 ...

  4. 用statefulSet 部署持久化的OA(Tomcat)

    1.部署多个副本的OA(Tomcat)集群,其中一个Tomcat的需要加一个定时器,其他代码跟其他的Tomcat的代码一样.需要重启后也还是保持这个状态.代码如下: apiVersion: v1 ki ...

  5. IDEA 创建和使用tomcat

    一.创建一个普通web项目,步骤略,如下图. 二.配置项目相关信息. 1.通过如下方式在Artifacts下添加我们的项目. 2.选中我们的项目. 3.修改项目的默认输出位置,可根据需要修改. 4.如 ...

  6. I2S音频总线学习

    IIS音频总线学习(一)数字音频技术 一.声音的基本概念 声音是通过一定介质传播的连续的波. 图1 声波 重要指标: 振幅:音量的大小 周期:重复出现的时间间隔 频率:指信号每秒钟变化的次数 声音按频 ...

  7. Linux系列教程(五)——Linux常用命令之链接命令和权限管理命令

    前一篇博客我们讲解了Linux文件和目录处理命令,还是老生常淡,对于新手而言,我们不需要完全记住命令的详细语法,记住该命令能完成什么功能,然后需要的时候去查就好了,用的多了我们就自然记住了.这篇博客我 ...

  8. 有哪些操作会使用到TempDB;如果TempDB异常变大,可能的原因是什么,该如何处理(转载)

    有哪些操作会使用到TempDB:如果TempDB异常变大,可能的原因是什么,该如何处理:tempdb的用途: 存储专用和全局临时变量,不考虑数据库上下文: 与Order by 子句,游标,Group ...

  9. JXOI2018简要题解

    JXOI2018简要题解 T1 排序问题 题意 九条可怜是一个热爱思考的女孩子. 九条可怜最近正在研究各种排序的性质,她发现了一种很有趣的排序方法: Gobo sort ! Gobo sort 的算法 ...

  10. python中和生成器协程相关的yield之最详最强解释,一看就懂(一)

    yield是python中一个非常重要的关键词,所有迭代器都是yield实现的,学习python,如果不把这个yield的意思和用法彻底搞清楚,学习python的生成器,协程和异步io的时候,就会彻底 ...