本文同步自:https://zhuanlan.zhihu.com/p/30738405

本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法

有一组数据:

x=[1,2,3,4,5,6,7,8,9,10]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条过原点的直线,穿过上述所有点

这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis

由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设 =0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢

引入概念代价函数 cost function

在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中, 为2

将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点

那么问题就变成了如何最小化J(θ)

在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低

引入梯度下降法

其中α为步幅

梯度下降法可以解释为:对J(θ)求关于 (本例中只有 )的偏导数并乘以步幅,再用 减去该值,得到的结果赋值给 。此过程需要重复多次

步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛

当选取了一个较大步幅的时候,将震荡收敛

当选取了一个过大步幅的时候,将无法收敛

调整一下例子

x=[5,6,7,8,9,10,11,12,13,14]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条直线,穿过上述所有点

很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])

n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph

plt.figure()
plt.ion()

with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                  "W=", sess.run(W), "b=", sess.run(b))

            plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
            plt.plot(train_X, train_Y, 'ro', label='Original data')
            plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
            plt.pause(0.001)
            plt.clf()

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
    plt.pause(10)

线性回归,附tensorflow实现的更多相关文章

  1. 逻辑回归,附tensorflow实现

    本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...

  2. 简单的线性回归问题-TensorFlow+MATLAB·

    首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...

  3. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  4. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  5. 经典损失函数:交叉熵(附tensorflow)

    每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...

  6. Tensorflow之单变量线性回归问题的解决方法

    跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...

  7. Tensorflow学习笔记01

    Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...

  8. TensorFlow 从零到helloWorld

    目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建   2.1 tensorflow安装   2.2 CUDA安装   2.3 ...

  9. TF linear regression

    本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...

随机推荐

  1. 1.ElasticSearch介绍及基本概念

    一.ElasticSearch介绍 一个采用RESTful API标准的高扩展性的和高可用性的实时性分析的全文搜索工具 基于Lucene[开源的搜索引擎框架]构建 ElasticSearch是一个面向 ...

  2. 4.ElasticSearch的基本api操作

    1. ElasticSearch的Index 1. 索引初始化 在创建索引之前 对索引进行初始化操作 指定shards数量和replicas数量 curl -XPUT 'http://192.168. ...

  3. hdu5696 区间的价值

    区间的价值 我们定义"区间的价值"为一段区间的最大值*最小值. 一个区间左端点在L,右端点在R,那么该区间的长度为(R-L+1). 现在聪明的杰西想要知道,对于长度为k的区间,最大 ...

  4. win10 uwp 切换主题

    本文主要说如何在UWP切换主题,并且如何制作主题. 一般我们的应用都要有多种颜色,一种是正常的白天颜色,一种是晚上的黑夜颜色,还需要一种辅助的高对比颜色.这是微软建议的,一般应用都要包含的颜色. 我们 ...

  5. Javaweb配置最全的数据源配置

    DBCP DBCP是Apache推出的数据库连接池(Database Connection Pool). 操作步骤: 添加jar包: commons-dbcp-1.4.jar commons-pool ...

  6. Table 控件各元素及属性

    功能:在Web页中创建通用表格里. 属性: 1.CellPadding属性:用于设置表中单元格的边框和内容之间的距离(以像素为单位).默认为-(未设置). 2.CellSpacing属性:用于设置表中 ...

  7. C# 使用HtmlAgilityPack抓取网页信息

    前几天看到一篇博文:C# 爬虫 抓取小说 博主使用的是正则表达式获取小说的名字.目录以及内容. 下面使用HtmlAgilityPack来改写原博主的代码 在使用HtmlAgilityPack之前,可以 ...

  8. Maven优雅的添加第三方Jar包

    在利用Maven构建项目的时候会出现某些Jar包无法下载到本地的Repository中,鉴于这种情况比较普遍存在,特归纳以下解决问题办法:以 ojdbc14-10.2.0.4.0.jar为例[其它Ja ...

  9. [hihoCoder]无间道之并查集

    题目大意: #1066 : 无间道之并查集 时间限制:20000ms 单点时限:1000ms 内存限制:256MB 描述 这天天气晴朗.阳光明媚.鸟语花香,空气中弥漫着春天的气息……额,说远了,总之, ...

  10. 为什么国外的 App 很少会有开屏广告?

    前言: 笔者在知乎看到这个问题,觉得这的确是一个值得关注和回答的现象,遂写了回答并整理成本文发布在此抛砖引玉,欢迎讨论. 正文: 古话说得好,先问是不是,再问为什么. 对于「国外的 App 很少有开屏 ...