1. tf.matmul(X, w) # 进行点乘操作

参数说明:X,w都表示输入的数据,

2.tf.equal(x, y) # 比较两个数据对应位置的数是否相等,返回值为True,或者False

参数说明:x,y表示需要比较的两组数

3.tf.cast(y, 'float') # 将布尔类型转换为数字类型

参数说明:y表示输入的数据,‘float’表示转换的数据类型

4.tf.argmax(y, 1) # 返回每一行的最大值的索引

参数说明:y表示输入数据,1表示每一行的最大值的索引,0表示每一列最大值得索引

5.tf.nn.softmax(y) # 对每一行数据,按照softmax方式,计算其概率值

参数说明:y表示输入的数据

6.tf.train.GradientDescentOptimizer(0.05).minimize(loss) # 对损失值采用梯度下降的方法进行降低操作

参数说明:0.05表示学习率的大小,loss表示损失值

代码说明:相较于上一个线性拟合的代码,逻辑回归的代码,使用了tf.placeholder()进行数据输入的占位,

预测结果等价于tf.matmul(X, w) + b,作为预测结果

代码:

第一步:载入mnist数据集

第二步:超参数的设置,包括输入样本的维度,分类的类别数,一个batch的大小,以及迭代的次数

第三步:使用tf.placeholder() 构造X特征和y标签

第四步:使用tf.Variable(tf.random_normal([inputSize, num_classes]))构造W和b

第五步:使用tf.matmul获得预测得分值,使用tf.nn.softmax() 将得分值转换为概率值

第六步:使用L2损失值,即均分误差作为损失值

第七步:使用tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss), 对损失值进行梯度下降

第八步:使用tf.equal(tf.argmax(y_pred), tf.argmax(y)) 获得两个标签的最大值索引是否一致,再使用tf.reduce_mean(tf.cast()) 计算准确率

第九步:循环,使用mnist.train.next_batch进行训练数据的部分数据的读取

第十步:sess.run([optm, loss], feed_dict={X:batch[0], y:batch[1]}) # 执行损失值和梯度下降降低损失值,从而更新参数

第十一步:每迭代1000次,打印部分数据的准确率

第十二步:打印测试数据部分的准确率

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data # 第一步:mnist数据集的载入
mnist = input_data.read_data_sets('/data', one_hot=True) # 第二步:初始超参数的设置
# 输入层的维度
inputSize = 784
# 分类的类别数
num_classes = 10
# 每一次参数更新的batchSize
batch_size = 64
# 循环的次数
trainIteration = 50000 # 第三步:输入数据初始化操作, 进行一个占位操作
# X的维度为[-1, inputSize]
X = tf.placeholder(tf.float32, [None, inputSize])
# y的维度为[-1, num_classes]
y = tf.placeholder(tf.float32, [None, num_classes])
# 第四步:变量的初始化操作
# W1进行正态参数初始化,维度为[inputSize, num_classes]
W1 = tf.Variable(tf.random_normal([inputSize, num_classes], stddev=0.1))
# b1进行全零初始化,维度为[num_classes]
b1 = tf.Variable(tf.zeros([num_classes]))
# 第五步:使用tf.matmul进行预测操作, 使用softmax计算概率值,扩大类别差距
y_pred = tf.nn.softmax(tf.matmul(X, W1) + b1)
# 第六步:使用L2 loss计算损失值, tf.reduce_mean求出所有的损失值的平均值
loss = tf.reduce_mean(tf.square(y-y_pred))
# 第七步:使用梯度下降法,对loss进行降低
opt = tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss)
# 第八步:判断预测结果的最大值和真实值之间是否相等,使用True和False表示
correct_pred = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
# 使用tf.cast将布尔类型转换为float类型,使用reduce_mean求平均值
accr = tf.reduce_mean(tf.cast(correct_pred, 'float'))
# 变量初始化
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(trainIteration):
# 第九步:使用mnist.train.next_batch 获得一个batch的数据
bacth = mnist.train.next_batch(batch_size)
# 第十步:使用sess.run执行梯度下降和loss操作,输入的参数为batch_X和batch_y数据
_, loss = sess.run([opt, loss], feed_dict={X:bacth[0], y:bacth[1]})
# 第十一步:如果迭代1000次,打印当前训练数据的准确率
if i % 1000 == 0:
print('train accr %g'%(sess.run(accr, feed_dict={X:bacth[0], y:bacth[1]})))
# 第十二步:打印测试数据的batch_size的准确率
batch = mnist.test.next_batch(batch_size)
print('test accr %g'%(sess.run(accr, feed_dict={X: batch[0], y: batch[1]})))

深度学习原理与框架-Tensorflow基本操作-mnist数据集的逻辑回归 1.tf.matmul(点乘操作) 2.tf.equal(对应位置是否相等) 3.tf.cast(将布尔类型转换为数值类型) 4.tf.argmax(返回最大值的索引) 5.tf.nn.softmax(计算softmax概率值) 6.tf.train.GradientDescentOptimizer(损失值梯度下降器)的更多相关文章

  1. 深度学习原理与框架-Tensorflow基本操作-变量常用操作 1.tf.random_normal(生成正态分布随机数) 2.tf.random_shuffle(进行洗牌操作) 3. tf.assign(赋值操作) 4.tf.convert_to_tensor(转换为tensor类型) 5.tf.add(相加操作) tf.divide(相乘操作) 6.tf.placeholder(输入数据占位

    1. 使用tf.random_normal([2, 3], mean=-1, stddev=4) 创建一个正态分布的随机数 参数说明:[2, 3]表示随机数的维度,mean表示平均值,stddev表示 ...

  2. 深度学习原理与框架-Tensorflow基本操作-实现线性拟合

    代码:使用tensorflow进行数据点的线性拟合操作 第一步:使用np.random.normal生成正态分布的数据 第二步:将数据分为X_data 和 y_data 第三步:对参数W和b, 使用t ...

  3. 深度学习原理与框架-Tensorflow基本操作-Tensorflow中的变量

    1.tf.Variable([[1, 2]])  # 创建一个变量 参数说明:[[1, 2]] 表示输入的数据,为一行二列的数据 2.tf.global_variables_initializer() ...

  4. 深度学习原理与框架-Tensorflow卷积神经网络-cifar10图片分类(代码) 1.tf.nn.lrn(局部响应归一化操作) 2.random.sample(在列表中随机选值) 3.tf.one_hot(对标签进行one_hot编码)

    1.tf.nn.lrn(pool_h1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75) # 局部响应归一化,使用相同位置的前后的filter进行响应归一化操作 参数 ...

  5. 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)

    1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME')  # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...

  6. 深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类

    使用tensorflow构造神经网络用来进行mnist数据集的分类 相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数, 神经网络使用的损失值为 ...

  7. 深度学习原理与框架-Alexnet(迁移学习代码) 1.sys.argv[1:](控制台输入的参数获取第二个参数开始) 2.tf.split(对数据进行切分操作) 3.tf.concat(对数据进行合并操作) 4.tf.variable_scope(指定w的使用范围) 5.tf.get_variable(构造和获得参数) 6.np.load(加载.npy文件)

    1. sys.argv[1:]  # 在控制台进行参数的输入时,只使用第二个参数以后的数据 参数说明:控制台的输入:python test.py what, 使用sys.argv[1:],那么将获得w ...

  8. 深度学习原理与框架-递归神经网络-时间序列预测(代码) 1.csv.reader(进行csv文件的读取) 2.X.tolist(将数据转换为列表类型)

    1. csv.reader(csvfile) # 进行csv文件的读取操作 参数说明:csvfile表示已经有with oepn 打开的文件 2. X.tolist() 将数据转换为列表类型 参数说明 ...

  9. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

随机推荐

  1. HashSet的自定义实现

    package com.cy.collection; import java.util.HashMap; /** * HashSet自定义实现 * 是使用hashMap实现的 * 可以看一下HashS ...

  2. php给app写接口进行接口的加密

    <?php/**inc解析接口客户端接口传输规则:1.用cmd参数(base64)来动态调用不同的接口,接口地址统一为 http://a.lovexpp.com2.将要传过来的参数组成一个数组, ...

  3. host文件的工作原理及应用

    host文件的工作原理及应用 Hosts文件是一个用于存储计算机网络中节点信息的文件,它可以将主机名映射到相应的IP地址,实现DNS的功能,它可以由计算机的用户进行控制. 一.Hosts文件基本介绍 ...

  4. 静态路由、Track与NQA联动配置举例

    原文: http://www.h3c.com/cn/d_201708/1018729_30005_0.htm#_Toc488338732 1.6.4  静态路由.Track与NQA联动配置举例 1. ...

  5. Unreal Engine 4 动态切割模型实现

    转自:http://gad.qq.com/article/detail/33199 <合金装备:复仇>里面,有一个很有趣的设定,游戏里大部分的场景和物件都可以用主角的刀动态切割. UE4中 ...

  6. MySQL MHA 搭建&测试(环境:CentOS7 + MySQL5.7.23)

    MySQL MHA架构介绍: MHA(Master High Availability)目前在MySQL高可用方面是一个相对成熟的解决方案,它由日本DeNA公司youshimaton(现就职于Face ...

  7. DateTimepicker中的星期问题

    开发机:win10 64+VS2013 客户机:win7 32bit 在项目中使用DateTimepicker,需要将时间获取到,然后转换为string,然后再转换为DateTime类型.开发机器上测 ...

  8. django前篇

    http协议 HTTP简介 HTTP协议是Hyper Text Transfer Protocol(超文本传输协议)的缩写,是用于从万维网(WWW:World Wide Web )服务器传输超文本到本 ...

  9. Masonry基本语法

    添加约束的方式: 1.通过使用NSLayoutConstraints添加约束到约束数组中,之前必须设置translatesAutoresizingMaskIntoConstraints = NO,即取 ...

  10. is 和 == 区别,id() ,回顾编码,encode(),decode()

    1. is 和 == 区别 id()函数 == 判断两边的值 is 判断内存地址例 s = "alex 是 大 xx"# abc = id(s) # 得到内存地址# print(a ...