这一节使用TF搭建一个简单的神经网络用于分类任务,首先把需要的包引入,另外为了防止在多次运行中一些图中的tensor在内存中影响实验,采取重置操作:

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. def reset_graph(seed=42):
  5. tf.reset_default_graph()
  6. tf.set_random_seed(seed)
  7. np.random.seed(seed)
  8. reset_graph()
  9. plt.figure(1,figsize=(8,6))

为了方便观察随机生成一组两维数据

  1. x0 = np.random.normal(1,1,size=(100,2)) #[(x1,x2),()]
  2. y0 = np.zeros(100)
  3. x1 = np.random.normal(-1,1,size=(100,2))
  4. y1 = np.ones(100)
  5. x = np.concatenate((x0,x1),axis = 0)
  6. y = np.concatenate((y0,y1),axis = 0)
  7. plt.scatter(x[:,0],x[:,1],c=y,cmap='RdYlGn')
  8. plt.show()

上面生成的两个类别的数据,均值分别为1-1方差都为1

接下来就是训练模型

  1. #模型
  2. tf_x = tf.placeholder(tf.float32,x.shape)
  3. tf_y = tf.placeholder(tf.int32,y.shape)
  4. output = tf.layers.dense(tf_x,10,tf.nn.relu,name="hidden")
  5. output = tf.layers.dense(output,2,name="output")
  6. with tf.name_scope("loss"):
  7. xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf_y,logits=output)
  8. loss = tf.reduce_mean(xentropy,name="loss")
  9. with tf.name_scope("train"):
  10. optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  11. training_op = optimizer.minimize(loss)
  12. #evaluate
  13. with tf.name_scope("eval"):
  14. correct = tf.nn.in_top_k(output,y,1)
  15. accuracy = tf.reduce_mean(tf.cast(correct,tf.float32))
  16. init = tf.global_variables_initializer()
  17. plt.ion()
  18. plt.figure(figsize=(8,6))
  19. saver = tf.train.Saver()
  20. with tf.Session() as sess:
  21. sess.run(init)
  22. for step in range(100):
  23. _,acc,pred = sess.run([training_op,accuracy,output],feed_dict={tf_x:x,tf_y:y})
  24. plt.cla()
  25. plt.scatter(x[:,0],x[:,1],c=pred.argmax(1),cmap='RdYlGn')
  26. plt.text(1.5, -2, 'Accuracy=%.2f' % acc, fontdict={'size': 20, 'color': 'red'})
  27. saver.save(sess, './model', write_meta_graph=False) #保存模型
  28. plt.ioff()
  29. plt.show()

上面创建了一个隐含层的网络,使用的是elu,也可以尝试使用其他的激活函数。需要注意的是tf.layers.dense的作用是outputs = activation(inputs.kernel + bias),可以看出在输出层是没有使用激活函数的,如果activation=None就表示使用的是线性映射。模型训练完毕后,我们将其持久化,方便以后的使用。我们来看下最终的结果:

使用TensorFlow实现分类的更多相关文章

  1. Tensorflow二分类处理dense或者sparse(文本分类)的输入数据

    这里做了一些小的修改,感谢谷歌rd的帮助,使得能够统一处理dense的数据,或者类似文本分类这样sparse的输入数据.后续会做进一步学习优化,比如如何多线程处理. 具体如何处理sparse 主要是使 ...

  2. 『TensorFlow』分类问题与两种交叉熵

    关于categorical cross entropy 和 binary cross entropy的比较,差异一般体现在不同的分类(二分类.多分类等)任务目标,可以参考文章keras中两种交叉熵损失 ...

  3. tensorflow之分类学习

    写在前面的话 MNIST教程是tensorflow中文社区的第一课,例程即训练一个 手写数字识别 模型:http://www.tensorfly.cn/tfdoc/tutorials/mnist_be ...

  4. 机器学习框架ML.NET学习笔记【6】TensorFlow图片分类

    一.概述 通过之前两篇文章的学习,我们应该已经了解了多元分类的工作原理,图片的分类其流程和之前完全一致,其中最核心的问题就是特征的提取,只要完成特征提取,分类算法就很好处理了,具体流程如下: 之前介绍 ...

  5. tensorflow文本分类实战——卷积神经网络CNN

    首先说明使用的工具和环境:python3.6.8   tensorflow1.14.0   centos7.0(最好用Ubuntu) 关于环境的搭建只做简单说明,我这边是使用pip搭建了python的 ...

  6. TensorFlow 实现分类操作的函数学习

    函数:tf.nn.sigmoid_cross_entropy_with_logits(logits, targets, name=None) 说明:此函数是计算logits经过sigmod函数后的交叉 ...

  7. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(4)

    # -*- coding: utf-8 -*- import glob import os.path import numpy as np import tensorflow as tf from t ...

  8. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(3)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

  9. 吴裕雄 python 神经网络——TensorFlow 花瓣分类与迁移学习(2)

    import glob import os.path import numpy as np import tensorflow as tf from tensorflow.python.platfor ...

随机推荐

  1. go标准库的学习-hash

    参考:https://studygolang.com/pkgdoc 导入方式: import "hash" hash包提供hash函数的接口. type Hash type Has ...

  2. rac添加新节点的步骤与方法2

    上一篇文章,把节点删除了.这次新增加一个节点 .新增加的节点是host03.如下: #Public IP192.168.16.45 racdb1192.168.16.46 racdb2192.168. ...

  3. 图、dfs、bfs

    graphdfsbfs 1.clone graph2.copy list with random pointer3.topological sorting4.permutations5.subsets ...

  4. Android学习之基础知识五—ListView控件(最常用和最难用的控件)

    ListView控件允许用户通过上下滑动来将屏幕外的数据拉到屏幕内,把屏幕内的数据拉到屏幕外. 一.ListView的简单用法第一步:先创建一个ListViewTest项目,在activity_mia ...

  5. PAM unable to dlopen(/lib/security/pam_limits.so): /lib/security/pam_limits.so: wrong ELF class: ELFCLASS32

    systemctl status sshd● sshd.service - OpenSSH server daemon Loaded: loaded (/usr/lib/systemd/system/ ...

  6. CF809E Surprise me! 莫比乌斯反演、虚树

    传送门 简化题意:给出一棵\(n\)个点的树,编号为\(1\)到\(n\),第\(i\)个点的点权为\(a_i\),保证序列\(a_i\)是一个\(1\)到\(n\)的排列,求 \[ \frac{1} ...

  7. EZ 2018 05 20 NOIP2018 模拟赛(十五)

    这次的比赛充满着玄学的气息,玄学链接 首先讲一下为什么没有第十四场 其实今天早上9点时看到题目就叫了:原题! 没错,整套试卷都做过,我还写了题解 然后老叶就说换一套,但如果仅仅是这样就没什么 但等13 ...

  8. Linux性能评测工具之一:gprof篇

    这些天自己试着对项目作一些压力测试和性能优化,也对用过的测试工具作一些总结,并把相关的资料作一个汇总,以便以后信手拈来! 1 简介 改进应用程序的性能是一项非常耗时耗力的工作,但是究竟程序中是哪些函数 ...

  9. 【php增删改查实例】第十七节 - 用户登录(1)

    新建一个login文件,里面存放的就是用户登录的模块. <html> <head> <meta charset="utf-8"> <sty ...

  10. JavaScript 利用 async await 实现 sleep 效果

    const sleep = (timeountMS) => new Promise((resolve) => { setTimeout(resolve, timeountMS); }); ...