tensorflow学习笔记五----------逻辑回归
在逻辑回归中使用mnist数据集。导入相应的包以及数据集。
- import numpy as np
- import tensorflow as tf
- import matplotlib.pyplot as plt
- from tensorflow.examples.tutorials.mnist import input_data
- mnist = input_data.read_data_sets('data/', one_hot=True)
- trainimg = mnist.train.images
- trainlabel = mnist.train.labels
- testimg = mnist.test.images
- testlabel = mnist.test.labels
- print ("MNIST loaded")
使用tensorflow中函数进行逻辑回归的构建。调用softmax函数进行逻辑回归模型的构建;构造损失函数【y*tf.log(actv)】;构造梯度下降训练器;
- x = tf.placeholder("float", [None, 784]) #784代表照片像素为28*28 10代表共有十个数字
- y = tf.placeholder("float", [None, 10]) # None is for infinite
- W = tf.Variable(tf.zeros([784, 10]))
- b = tf.Variable(tf.zeros([10]))
- # LOGISTIC REGRESSION MODEL
#get the predict number- actv = tf.nn.softmax(tf.matmul(x, W) + b)
- # COST FUNCTION
- cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv), reduction_indices=1))
- # OPTIMIZER
- learning_rate = 0.01
- optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
- # PREDICTION
- pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
- # ACCURACY
- accr = tf.reduce_mean(tf.cast(pred, "float"))
- # INITIALIZER
- init = tf.global_variables_initializer()
循环五十次,每五次打印一次结果,每次训练取100个样本
- training_epochs = 50
- batch_size = 100
- display_step = 5
- # SESSION
- sess = tf.Session()
- sess.run(init)
- # MINI-BATCH LEARNING
- for epoch in range(training_epochs):
- avg_cost = 0.
- num_batch = int(mnist.train.num_examples/batch_size)
- for i in range(num_batch):
- batch_xs, batch_ys = mnist.train.next_batch(batch_size)
- sess.run(optm, feed_dict={x: batch_xs, y: batch_ys})
- feeds = {x: batch_xs, y: batch_ys}
- avg_cost += sess.run(cost, feed_dict=feeds)/num_batch
- # DISPLAY
- if epoch % display_step == 0:
- feeds_train = {x: batch_xs, y: batch_ys}
- feeds_test = {x: mnist.test.images, y: mnist.test.labels}
- train_acc = sess.run(accr, feed_dict=feeds_train)
- test_acc = sess.run(accr, feed_dict=feeds_test)
- print ("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
- % (epoch, training_epochs, avg_cost, train_acc, test_acc))
- print ("DONE")
tensorflow学习笔记五----------逻辑回归的更多相关文章
- tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)
mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...
- Python学习笔记之逻辑回归
# -*- coding: utf-8 -*- """ Created on Wed Apr 22 17:39:19 2015 @author: 90Zeng " ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- tensorflow学习笔记(4)-学习率
tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...
- tensorflow学习笔记(2)-反向传播
tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)
续集请点击我:tensorflow学习笔记——使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...
- TensorFlow学习笔记10-卷积网络
卷积网络 卷积神经网络(Convolutional Neural Network,CNN)专门处理具有类似网格结构的数据的神经网络.如: 时间序列数据(在时间轴上有规律地采样形成的一维网格): 图像数 ...
随机推荐
- 使用 CSS 显示 XML
通过使用 CSS,可为 XML 文档添加显示信息. 使用 CSS 显示您的 XML? 使用 CSS 来格式化 XML 文档是有可能的. 下面的例子就是关于如何使用 CSS 样式表来格式化 XML 文档 ...
- Spring Boot教程(十一) springboot程序构建一个docker镜像
准备工作 环境: linux环境或mac,不要用windows jdk 8 maven 3.0 docker 对docker一无所知的看docker教程. 创建一个springboot工程 引入web ...
- [BZOJ1902]:[NOIP2004]虫食算(搜索)
题目传送门 题目描述 所谓虫食算,就是原先的算式中有一部分被虫子啃掉了,需要我们根据剩下的数字来判定被啃掉的字母. 来看一个简单的例子: 43#98650#45+8468#6633=444455069 ...
- smartload跨浏览器极速缓存插件用法
smartload由39smart团队原创,主要实现前端css/js的一次加载请求,永久缓存的加速效果,在移动端效果非常明显. 插件特点: 支持平台: PC和移动端所有版本浏览器,ie6+,firef ...
- vue根据参数不同的路由跳转以及name的作用
最近在做VUE路由跳转根据参数的值不同但是跳转的是同一个路由的功能.点击左边的目录,根据目录ID跳转不同的列表.如下图. 路由跳转的代码: this.$router.push({path: '/RFI ...
- java生成二维码的几种方式
1: 使用SwetakeQRCode在Java项目中生成二维码 http://swetake.com/qr/ 下载地址 或着http://sourceforge.jp/projects/qrcode/ ...
- spring-servlet.xml
<?xml version="1.0" encoding="UTF-8"?> <beans xmlns="http://www.sp ...
- IE浏览器中图片路径正确< img ... />标签不显示图片
如下图所示,下面的html要去加载上面的jpg图片: 代码如下: <img src="luzhanshi1.jpg" alt="图片加载失败"> 使 ...
- 3、maven导入外部自定义jar包
有些时候我们自己有一些jar包需要导入到我们的仓库中,然后在maven项目里的pom.xml文件加入这些jar包的依赖即可使用这些jar包了 1.确保行执行mvn -v没有问题 2.把需要引入的jar ...
- 系统编码 python编码
编码一直都是一个很让人头疼的问题,尤其是在python里面.花了几天时间,终于把这个问题给弄明白了. 一,什么是编码,编码过程是怎样的?常见的编码方式有哪些? 编码是从一个字符,比如'哈',到一段二进 ...