本文同步自:https://zhuanlan.zhihu.com/p/30738405

本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法

有一组数据:

x=[1,2,3,4,5,6,7,8,9,10]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条过原点的直线,穿过上述所有点

这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis

由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设 =0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢

引入概念代价函数 cost function

在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中, 为2

将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点

那么问题就变成了如何最小化J(θ)

在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低

引入梯度下降法

其中α为步幅

梯度下降法可以解释为:对J(θ)求关于 (本例中只有 )的偏导数并乘以步幅,再用 减去该值,得到的结果赋值给 。此过程需要重复多次

步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛

当选取了一个较大步幅的时候,将震荡收敛

当选取了一个过大步幅的时候,将无法收敛

调整一下例子

x=[5,6,7,8,9,10,11,12,13,14]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条直线,穿过上述所有点

很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])

n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph

plt.figure()
plt.ion()

with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                  "W=", sess.run(W), "b=", sess.run(b))

            plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
            plt.plot(train_X, train_Y, 'ro', label='Original data')
            plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
            plt.pause(0.001)
            plt.clf()

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
    plt.pause(10)

线性回归,附tensorflow实现的更多相关文章

  1. 逻辑回归,附tensorflow实现

    本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...

  2. 简单的线性回归问题-TensorFlow+MATLAB·

    首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...

  3. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  4. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  5. 经典损失函数:交叉熵(附tensorflow)

    每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...

  6. Tensorflow之单变量线性回归问题的解决方法

    跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...

  7. Tensorflow学习笔记01

    Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...

  8. TensorFlow 从零到helloWorld

    目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建   2.1 tensorflow安装   2.2 CUDA安装   2.3 ...

  9. TF linear regression

    本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...

随机推荐

  1. jquery基本选择器:id选择器、class选择器、标签选择器、通配符选择器

    全栈工程师开发手册 (作者:栾鹏) jquery系列教程1-选择器全解 jquery基本选择器 jquery基本选择器,包括id选择器.class选择器.标签选择器.通配符选择器,同时配合选择器的空格 ...

  2. riot.js教程【一】简介

    Riotjs简介 Riotjs是一款简单的.优雅的.组件化UI前端开发框架: 他支持自定义标签(custom tags),拥有令人愉悦的语法,优雅的API和非常小的体积: 为什么需要一个新的界面库 前 ...

  3. 前端魔法堂——异常不仅仅是try/catch

    前言  编程时我们往往拿到的是业务流程正确的业务说明文档或规范,但实际开发中却布满荆棘和例外情况,而这些例外中包含业务用例的例外,也包含技术上的例外.对于业务用例的例外我们别无它法,必须要求实施人员与 ...

  4. SpringMVC Spring MyBatis整合配置文件

    1.spring管理SqlSessionFactory.mapper 1)在classpath下创建mybatis/sqlMapConfig.xml <?xml version="1. ...

  5. LeetCode 289. Game of Life (生命游戏)

    According to the Wikipedia's article: "The Game of Life, also known simply as Life, is a cellul ...

  6. 蒙特卡罗算法(Monte Carlo method)

    蒙特卡罗方法概述 蒙特卡罗方法又称统计模拟法.随机抽样技术,是一种随机模拟方法,以概率和统计理论方法为基础的一种计算方法,是使用随机数(或更常见的伪随机数)来解决很多计算问题的方法.将所求解的问题同一 ...

  7. Sequence one

    Problem Description Search is important in the acm algorithm. When you want to solve a problem by us ...

  8. JavaScript基础一(js基础函数与运算符)

    [使用js的三种方式] 1.在HTML标签中,直接内嵌js(并不提倡使用) <button onclick=" alert('点就点')"> 点我啊</butto ...

  9. css设置黑体宋体等(转)

    代码如下: .selector{ font-family:"Microsoft YaHei",微软雅黑,"MicrosoftJhengHei",华文细黑,STH ...

  10. Centos7下部署ceph 12.2.1 (luminous)集群及RBD使用

    前言 本文搭建了一个由三节点(master.slave1.slave2)构成的ceph分布式集群,并通过示例使用ceph块存储. 本文集群三个节点基于三台虚拟机进行搭建,节点安装的操作系统为Cento ...