在前边几期的文章中,笔者已经用TensorFlow进行的一些基础性的探索工作,想必大家对TensorFlow框架也是非常的好奇,本着发扬雷锋精神,笔者将详细的阐述TensorFlow框架的基本用法,并尽力做到通俗易懂,对得起读者花费的时间。

行文目录

本文从以下三个方面,展开对TensorFlow的剖析:

  • TensorFlow框架概述

  • TensorFlow基本操作

  • TensorBoard使用

TensorFlow框架概述

2015年11月9日,为加速深度学习的发展,Google发布了深度学习框架TensorFlow,经过几年的发展,TensorFlow成为了最流行的深度学习框架。

TensorFlow从名字上解释就是Tensor(张量)+Flow(流)。什么是张量呢?张量是矢量概念的推广,可以表示任意维度的数据,如一维数组,二维矩阵,N维数据。TensorFlow的运行过程实质就是张量从图的一端流动到另一端的计算过程。下文简单阐述TensorFlow的一些基本概念。

深度学习框架分为“动态计算图”和“静态计算图”,支持动态计算图的被称为动态框架,支持静态计算图的被称为静态框架。

静态框架:先定义计算执行顺序和内存分配策略,然后按照规定的计算顺序和资源进行计算。打个比方,在盖大楼的时候,静态框架就好比设计师团队与施工团队分离,设计师设计好图纸之后,施工团队才开始按照图纸方案进行施工。

动态框架:声明和执行一起执行。类似于设计师和施工团队一块儿工作,设计师说先“打地基”,施工团队就开始施工打地基。

TensorFlow支持静态和动态两种方式,一般TensorFlow程序分为两个阶段,图的构建阶段和图的执行阶段。

  • 操作

图中节点就是一个操作,比如,一次加法运算就是一个操作,构建变量的初始值也是一个操作。构建图的过程中,需要把所有操作确定下来,比如指定操作在哪台设备上执行。一些基本操作如下所示:

1#定义变量a操作
2a = tf.Variable(1.0, name="a")
3#定义操作b=a+1
4b = tf.add(a, 1, name="b")
5#定义操作c=b+1
6c = tf.add(b, 1, name="c")
7#定义操作d=b+10
8d = tf.add(b, 10, name="d")

操作之间存在依赖关系,这种依赖被称为边,操作与边相连接就构成了一张图,如图1所示:

图1 数据流图

  • 会话

TensorFlow的计算需要在会话中执行,当创建一个会话时,如果没有传递参数,会启动默认的图来构造图结构,并将图中定义的操作根据定义情况分发到CPU或者GPU上执行。

 1import tensorflow as tf
2#定义常量
3v1 = tf.constant(1, name="value1")
4v2 = tf.constant(1, name="value2")
5#v1+v2
6add_op = tf.add(v1, v2, name="add_op_name")
7#创建会话
8with tf.Session() as sess:
9  result = sess.run(add_op)
10  print("1 + 1 = %.0f" % result)

TensorFlow基本操作

  • 定义常量、变量和占位符

定义常量:a = tf.constant(1, name="value1")

定义变量:v2 = tf.Variable(0.01,name=”weight2”)

注意:变量需要经过初始化之后才能使用,常量不需要

定义占位符:v3 = tf.placeholder("float")

 1import  tensorflow  as  tf
2#定义常量
3a = tf.constant(1, name="value1")
4#定义变量
5v1 = tf.Variable(0.001)
6v2 = tf.Variable(v1.initialized_value() * 2)
7#定义占位符
8v3 = tf.placeholder("float")
9v4 = tf.placeholder("float")
10y = tf.mul(v3, v4) #构造一个op节点
11
12init = tf.global_variables_initializer()
13with tf.Session() as sess:
14  #变量初始化
15sess.run(init)
16  print("v1 is:")
17  print(sess.run(v1))
18  print("v2 is:")
19  print(sess.run(v2))
20#占位符操作
21      print sess.run(y, feed_dict={v3: 3, v4: 3})
  • TensorFlow函数

笔者对TensorFlow中常用的函数进行了简单汇总,但是并不是非常全面,如果有需要,后续笔者会单独写一篇文章来详细的描述TensorFlow中函数的用法。

表1 TF常用函数汇总

   函数族 函数介绍         常用函数
Math 数学函数 add(加), sub(减), mul(乘), Div(除),mod(取模)、abs(取绝对值)、log(计算log)、sin(正弦)
Array 数组操作 concat(合并), slice(切片), Split(分割)
Matrix 矩阵操作 diag(返回一个只有对角线的矩阵)、matul(矩阵相乘), matrix_inverse(求逆矩阵), matrix_determinant(求行列式)
Activation Functions 激活函数 relu、relu6、sigmoid、tanh、dropout
Convolution 卷积函数 conv2d、conv3d
Pooling 池化函数 avg_pool(平均池化)、max_pool、max_pool_with_argmax、avg_pool3d、max_pool3d
Normalization 数据标准化 l2_normalize(2范数标准化)、normalize_moments(均值方差归一化)
Losses 损失函数 l2_loss(误差平方和)
Classification 分类函数 sigmoid_cross_entropy_with_logits(交叉熵)、softmax、log_softmax、softmax_cross_entropy_with_logits
Recurrent Neural Networks 循环神经网络 rnn、bidirectional_rnn(双向rnn)、state_saving_rnn(可存储调用状态rnn)
Checkpointing 模型保存与加载 save(存储), restore(装载)

TensorBoard使用

当训练一个复杂的神经网络时候,经常会根据情况调整网络结构。比如,在训练过程中根据训练集和验证集的准确率,来判断是否存在过拟合,或者查看训练过程中损失函数。一般可以将这些数据打印到日志文件,但是当信息比较多的时候,直接看日志文件不直观。为了更好的理解、调试和优化网络,TensorFlow提供了一套数据可视化工具TensorBoard。

下边举一个例子来演示一下如何收集训练过程中的数据,并且利用TensorBoard将数据展示出来。本例中要进行线性拟合,拟合的函数大致为,但2和10事先不知道,通过训练得到。训练的完整代码如下:

 1#coding=utf-8
2import numpy as np
3import tensorflow as tf
4
5# 构建图
6x = tf.placeholder(tf.float32)
7y = tf.placeholder(tf.float32)
8weight = tf.get_variable("weight", [], tf.float32, initializer=tf.random_normal_initializer())
9biase  = tf.get_variable("biase", [], tf.float32, initializer=tf.random_normal_initializer())
10pred = tf.add(tf.multiply(x, weight, name="mul_op"), biase, name="add_op")
11
12#损失函数
13loss = tf.square(y - pred, name="loss")
14#优化函数
15optimizer = tf.train.GradientDescentOptimizer(0.01)
16#计算梯度,应用梯度操作
17grads_and_vars = optimizer.compute_gradients(loss)
18train_op = optimizer.apply_gradients(grads_and_vars)
19
20#收集值的操作
21tf.summary.scalar("weight", weight)
22tf.summary.scalar("biase", biase)
23tf.summary.scalar("loss", loss[0])
24
25merged_summary = tf.summary.merge_all()
26
27summary_writer = tf.summary.FileWriter('./log_graph' )
28summary_writer.add_graph(tf.get_default_graph())
29init_op = tf.global_variables_initializer()
30
31with tf.Session() as sess:
32    sess.run(init_op)
33    for step in range(500):
34        train_x = np.random.randn(1)
35        train_y = 2 * train_x + np.random.randn(1) * 0.01  + 10
36        _, summary = sess.run([train_op, merged_summary], feed_dict={x:train_x, y:train_y})
37        summary_writer.add_summary(summary, step)

执行代码之后,会将日志文件写入log_graph文件夹下,执行如下命令运行TensorBoard:

1tensorboard --logdir=./log_graph

然后在浏览器中输入:http://ip:6006,打开TensorBoard的界面,从界面可以看到:

(1)图的结构

图2 数据流图

(2)损失函数变化

图3 损失函数变化图

(3)拟合系数变化:

图4 系数和截距项变化图

从上图可以直观看出,随着训练的次数增加,系数趋近2,截距项趋近10,误差越来越小。

初识TensorFlow的更多相关文章

  1. 机器学习之路: 初识tensorflow 第一个程序

    计算图 tensorflow是一个通过计算图的形式来表示计算的编程系统tensorflow中每一个计算都是计算图上的一个节点节点之间的边描述了计算之间的依赖关系 张量 tensor张量可以简单理解成多 ...

  2. 初识 ❤ TensorFlow |【一见倾心】

    说明

  3. Tensorflow 安装 和 初识

    Windows中 Anaconda,Tensorflow 和 Pycharm的安装和配置   https://blog.csdn.net/zhuiqiuzhuoyue583/article/detai ...

  4. TensorFlow学习(1)-初识

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  5. TensorFlow 基础概念

    初识TensorFlow,看了几天教程后有些无聊,决定写些东西,来夯实一下基础,提供些前进动力. 一.Session.run()和Tensor.eval()的区别: 最主要的区别就是可以使用sess. ...

  6. TensorFlow学习(1)

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  7. TensorFlow从入门到入坑(1)

    初识TensorFlow 一.术语潜知 深度学习:深度学习(deep learning)是机器学习的分支,是一种试图使用包含复杂结构或由多重非线性变换构成的多个处理层对数据进行高层抽象的算法. 深度学 ...

  8. 语音识别(LSTM+CTC)

    完整版请微信关注“大数据技术宅” 序言:语音识别作为人工智能领域重要研究方向,近几年发展迅猛,其中RNN的贡献尤为突出.RNN设计的目的就是让神经网络可以处理序列化的数据.本文笔者将陪同小伙伴们一块儿 ...

  9. 大数据利器Hive

    序言:在大数据领域存在一个现象,那就是组件繁多,粗略估计一下轻松超过20种.如果你是初学者,瞬间就会蒙圈,不知道力往哪里使.那么,为什么会出现这种现象呢?在本文的开头笔者就简单的阐述一下这种现象出现的 ...

随机推荐

  1. php 将图片转成base64

    /** * 获取图片的Base64编码(不支持url) * @date 2017-02-20 19:41:22 * * @param $img_file 传入本地图片地址 * * @return st ...

  2. TCP(控制传输协议)详解

    1.传输层概述 在OSI参考模型中,网络层是面向通信的最高层但同时也是面向用户程序的最底层. 传输层的主要作用: 复用:在发送端,多个应用程序公用一个传输层: 分用:在接收端,传输层把从网络层接收到的 ...

  3. kubernates使用kubeadm安装

    kubeadm是Kubernetes官方提供的用于快速安装Kubernetes集群的工具,伴随Kubernetes每个版本的发布都会同步更新,kubeadm会对集群配置方面的一些实践做调整,通过实验k ...

  4. sass快速入门

    sass十分钟入门 变量 sass中可以定义变量,方便统一修改和维护. //sass style //----------------------------------- $fontStack: H ...

  5. Java-IO流之File操作和Properties操作

    java的File类主要是用来操作文件的元数据,稍作演示如下: 其中方法getAllJavaFile()是使用了过滤器FileFileter,这个过滤器只需要实现accept方法,判断什么样的文件返回 ...

  6. [LeetCode] Rectangle Overlap 矩形重叠

    A rectangle is represented as a list [x1, y1, x2, y2], where (x1, y1) are the coordinates of its bot ...

  7. 二分(HDU2289 Cup)

    贴代码: 题目意思:已知r水的下半径,R水的上半径,H为圆台高度,V为水的体积,求水的高度,如图: 水的高度一定在0-100,所以在这个区间逐步二分,对每次二分出的水的高度,计算相应的体积,看看计算出 ...

  8. mysql根据查询结果批量更新多条数据(插入或更新)

    mysql根据查询结果批量更新多条数据(插入或更新) 1.1 前言 mysql根据查询结果执行批量更新或插入时经常会遇到1093的错误问题.基本上批量插入或新增都会涉及到子查询,mysql是建议不要对 ...

  9. nova file injection的原理和调试过程

    file injection代码 file injection原理来讲是比较简单的,在nova boot命令中,有参数--file,是将文件inject到image中 nova boot --flav ...

  10. CSS Media媒体查询使用大全,完整媒体查询总结

    前面的话 一说到响应式设计,肯定离不开媒体查询media.一般认为媒体查询是CSS3的新增内容,实际上CSS2已经存在了,CSS3新增了媒体属性和使用场景(IE8-浏览器不支持).本文将详细介绍媒体查 ...