使用TensorFlow实现分类
这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def reset_graph(seed=42):
tf.reset_default_graph()
tf.set_random_seed(seed)
np.random.seed(seed)
reset_graph()
plt.figure(1,figsize=(8,6))
为了方便观察随机生成一组两维数据
x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
y0 = np.zeros(100)
x1 = np.random.normal(-1,1,size=(100,2))
y1 = np.ones(100)
x = np.concatenate((x0,x1),axis = 0)
y = np.concatenate((y0,y1),axis = 0)
plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
plt.show()
上面生成的两个类别的数据,均值分别为1和-1方差都为1

接下来就是训练模型
#模型
tf_x = tf.placeholder(tf.float32,x.shape)
tf_y = tf.placeholder(tf.int32,y.shape)
output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
output = tf.layers.dense(output,2,name="output")
with tf.name_scope("loss"):
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
loss = tf.reduce_mean(xentropy,name="loss")
with tf.name_scope("train"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
training_op = optimizer.minimize(loss)
#evaluate
with tf.name_scope("eval"):
correct = tf.nn.in_top_k(output,y,1)
accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
init = tf.global_variables_initializer()
plt.ion()
plt.figure(figsize=(8,6))
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
for step in range(100):
_,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
plt.cla()
plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
saver.save(sess, './model', write_meta_graph=False) #保存模型
plt.ioff()
plt.show()
上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章
- Tensorflow二分类处理dense或者sparse(文本分类)的输入数据
这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...
- 『TensorFlow』分类问题与两种交叉熵
关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...
- tensorflow之分类学习
写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...
- 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类
一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...
- tensorflow文本分类实战——卷积神经网络CNN
首先说明使用的工具和环境:python3.6.8 tensorflow1.14.0 centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...
- TensorFlow 实现分类操作的函数学习
函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...
- 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)
# -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...
- 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
- 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)
import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...
随机推荐
- vue学习路由嵌套
1. 路由嵌套和参数传递 传参的两种形式: a.查询字符串:login?name=tom&pwd=123 {{$route.query}} ------ <li><route ...
- ;,&,&&,shell,区别
command1&command2&command3 三个命令同时执行 command1;command2;command3 不管前面命令执行成功没有,后面的命令继续执 ...
- AMD和CMD规范
1.名词解释AMD:Asynchronous Modules Definition异步模块定义,提供定义模块及异步加载该模块依赖的机制.CMD:Common Module Definition 通用模 ...
- 用statefulSet 部署持久化的OA(Tomcat)
1.部署多个副本的OA(Tomcat)集群,其中一个Tomcat的需要加一个定时器,其他代码跟其他的Tomcat的代码一样.需要重启后也还是保持这个状态.代码如下: apiVersion: v1 ki ...
- IDEA 创建和使用tomcat
一.创建一个普通web项目,步骤略,如下图. 二.配置项目相关信息. 1.通过如下方式在Artifacts下添加我们的项目. 2.选中我们的项目. 3.修改项目的默认输出位置,可根据需要修改. 4.如 ...
- I2S音频总线学习
IIS音频总线学习(一)数字音频技术 一.声音的基本概念 声音是通过一定介质传播的连续的波. 图1 声波 重要指标: 振幅:音量的大小 周期:重复出现的时间间隔 频率:指信号每秒钟变化的次数 声音按频 ...
- Linux系列教程(五)——Linux常用命令之链接命令和权限管理命令
前一篇博客我们讲解了Linux文件和目录处理命令,还是老生常淡,对于新手而言,我们不需要完全记住命令的详细语法,记住该命令能完成什么功能,然后需要的时候去查就好了,用的多了我们就自然记住了.这篇博客我 ...
- 有哪些操作会使用到TempDB;如果TempDB异常变大,可能的原因是什么,该如何处理(转载)
有哪些操作会使用到TempDB:如果TempDB异常变大,可能的原因是什么,该如何处理:tempdb的用途: 存储专用和全局临时变量,不考虑数据库上下文: 与Order by 子句,游标,Group ...
- JXOI2018简要题解
JXOI2018简要题解 T1 排序问题 题意 九条可怜是一个热爱思考的女孩子. 九条可怜最近正在研究各种排序的性质,她发现了一种很有趣的排序方法: Gobo sort ! Gobo sort 的算法 ...
- python中和生成器协程相关的yield之最详最强解释,一看就懂(一)
yield是python中一个非常重要的关键词,所有迭代器都是yield实现的,学习python,如果不把这个yield的意思和用法彻底搞清楚,学习python的生成器,协程和异步io的时候,就会彻底 ...