这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. 马哥Python视频

    链接:https://pan.baidu.com/s/1KMXqdXlaIjZ3OaZ-PUwE9A 密码私聊我

  2. Linux 下的 Redis 安装 && 启动 && 关闭 && 卸载

    转自https://blog.csdn.net/zgf19930504/article/details/51850594 Redis 在Linux 和 在Windows 下的安装是有很大的不同的,和通 ...

  3. 前台获取json未定义问题之两种常用解决办法

    来自博客园的一位朋友解答: 为什么要 eval这里要添加 “("("+data+")");//”呢? 原因在于:eval本身的问题. 由于json是以”{}”的 ...

  4. 1115 洛谷luogu最大子段和

    题目描述 给出一段序列,选出其中连续且非空的一段使得这段和最大. 输入输出格式 输入格式: 第一行是一个正整数NNN,表示了序列的长度. 第二行包含NNN个绝对值不大于100001000010000的 ...

  5. iOS更新惹怒高通:苹果太可耻!

    之前高通同时在德国.中国发起对苹果的专利诉讼,而他们都赢得了最终的胜利,其中包含iPhone 7.8以及X系列机型,统统在禁售机型当中. 从法院公布的细节看,高通对iPhone禁售的理由是,iOS系统 ...

  6. 新版本grafana添加数据源报错!

    前提: grafana配置的数据源url没有错误. 现象: 1)升级完grafana之后发现原来配置的open-facon数据源无效了,一直提示HTTP ERROR NOT FOUND. 2)安装新版 ...

  7. 4.5《想成为黑客,不知道这些命令行可不行》(Learn Enough Command Line to Be Dangerous)—第四章小结

    本章相关重要命令总结在Table 6. 命令 描述 示例 mkdir <name> 创建某目录 $ mkdir foo pwd 显示当前所在目录 $ pwd cd <dir> ...

  8. WPF C#仿ios 安卓 红点消息提示

    原文:WPF C#仿ios 安卓 红点消息提示 先把效果贴出来,大家看看. 代码下载地址: http://download.csdn.net/detail/candyvoice/9730751 点击+ ...

  9. 汇编 OD 标志位 置位相关指令

    知识点: l 标志位 置位相关指令   l 标志寄存器PSW 标志寄存器PSW(程序状态字寄存器PSW)    标志寄存器PSW是一个16为的寄存器.它反映了CPU运算的状态特征并且存放某些控制标志. ...

  10. DevOps知识地图实践指南

    DevOps知识地图   DevOps方法论的主要来源是Agile, Lean 和TOC, 独创的方法论是持续交付. DevOps经典图书: * <DevOps实践指南> * <持续 ...