TensorFlow笔记-08-过拟合,正则化,matplotlib 区分红蓝点

首先提醒一下,第7讲的最后滑动平均的代码已经更新了,代码要比理论重要

今天是过拟合,和正则化,本篇后面可能或更有兴趣,因为涉及到可视化图形了,而不是纯数据

  • 过拟合:神经网络模型在训练集上的准确率比较高在新的数据进行预测或分类时准确率较低,说明模型泛华能力差
  • 正则化:在损失函数中给每个参数w加上权重,引入模型辅助度指标,从而抑制模型噪声,减小过拟合

使用正则化后,损失函数 loss 变为两项之和:

loss = loss(y与y_) + REGULARIZER*loss(w)

其中,第一项是预测结果与标准答案之间的差距,如之前讲过的交叉熵,均方误差等;第二项是正则化计算结果

看过我爬虫教程的可能了解这个正则 re,re就是这个 regularize

  • 正则化计算方法:

    (1)L1正则化:lossL1 = Σi |wi|

    用 Tensorflow 函数表示:loss(w) = tf.contrib.layers.11_regularizer(REGULARIZER)(w)

    (2)L2正则化:lossL2 = Σi |wi|2

    用 Tensorflow 函数表示:loss(w) = tf.contrib.layers.12_regularizer(REGULARIZER)(w)

  • 用 Tensorflow 函数实现正则化:

    tf.add_to_collection('losses', tf.contrib.layers.12_regularizer(regularizer)(w)

    losss = cem + tf.add_n(tf.get_collection('losses'))

例如:

    用 300 个符合正态分布的点 X[x0, x1]作为数据集,根据点 X[x0, x1] 计算生成标注 Y_,将数据集标注为红色点和蓝色点。

    标注规则:当 x02 + x12 >= 2 时,y_=0,标注为蓝色

    我们分别用无正则化和正则化两种方法,拟合曲线,把红色点和蓝色点。在实际分类时,如果前向传播输出的预测值y接近1则为红色点概率越大,接近0则为蓝色点概率越大,输出的预测值y为0.5是红蓝点概率分界线

    在本例子中,我们使用了之前未用过的模块与函数

  • matplotlib 模块:Python 中可视化工具模块,实现函数可视化

  • matplotlib 的安装

    • 1.在 PyCharm 设置中添加就可以:

    • 2.终端安装指令:

      pip install matplotlib

  • 函数 plt.scatter ():利用指定颜色实现点 (x,y) 的可视化

    plt.scatter (x 坐标,y 坐标,c="颜色")

    plt.show()

  • 收集规定区域内所有的网格坐标点:

    # 找到规定区域以步长为分辨率的行列网格坐标点

    xx,yy = np.mgrid[起:止:步长,起:止:步长]

    # 收集规定区域内所有的网格坐标点

    grid = np.c_[xx.ravel(), yy.ravel()]

  • plt.contour() 函数:告知 x,y 坐标和各点高度,用 levels 指定高度的点瞄上颜色

    plt.contour (x 轴坐标值,y 轴坐标值,该点的高度,levels=[等高线高度])

    plt.show()

    本例代码如下:

  1. #coding:utf-8
  2. #导入模块,生成模拟数据集
  3. import tensorflow as tf
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. BATCH_SIZE = 30
  7. seed = 2
  8. # 基于 seed 产生随机数
  9. rdm = np.random.RandomState(seed)
  10. #随机数返回行列的矩阵,表示组坐标点(x0, x1)作为输入数据集
  11. X = rdm.randn(300,2)
  12. #从X这个300行2列的矩阵中取出一行,判断如果两个坐标的平方和小于2,给Y赋值1,其余值0
  13. #作为输入数据集的标签(正确答案)
  14. Y_ = [int(x0*x0 + x1*x1<2) for (x0,x1) in X]
  15. #遍历Y中的每个元素,1赋值 'red' 其余赋值为 'blue',这样可视化显示时人可以直观区分
  16. Y_c = [['red' if y else 'blue'] for y in Y_]
  17. #对数据集X和标签Y进行shap整理,第一个元素为-1表示,随第二个参数计算得到
  18. #第二个元素表示多少列,把X郑磊为n行2列,把Y整理为n行1列
  19. X = np.vstack(X).reshape(-1,2)
  20. Y_ = np.vstack(Y_).reshape(-1,1)
  21. print(X)
  22. print(Y_)
  23. print(Y_c)
  24. # 用plt.scatter画出数据集X各行中第0列元素和第1列元素的点即各行的(x0,x1),
  25. # 用各行Y_c对应的值表示颜色(c是color的缩写)
  26. plt.scatter(X[:,0], X[:,1],c=np.squeeze(Y_c))
  27. plt.show()
  28. # 定义神经网络的输入,参数和输出,定义前向传播过程
  29. def get_weight(shape, regularizer):
  30. w = tf.Variable(tf.random_normal(shape), dtype=tf.float32)
  31. tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
  32. return w
  33. def get_bias(shape):
  34. b = tf.Variable(tf.constant(0.01, shape=shape))
  35. return b
  36. x = tf.placeholder(tf.float32, shape=(None, 2))
  37. y_ = tf.placeholder(tf.float32, shape=(None, 1))
  38. w1 = get_weight([2,11], 0.01)
  39. b1 = get_bias([11])
  40. y1 = tf.nn.relu(tf.matmul(x, w1)+b1)
  41. w2 = get_weight([11,1], 0.01)
  42. b2 = get_bias([1])
  43. y = tf.matmul(y1, w2)+b2 #输出层不过激活
  44. # 定义损失函数
  45. loss_mse = tf.reduce_mean(tf.square(y-y_))
  46. loss_total = loss_mse + tf.add_n(tf.get_collection('losses'))
  47. # 定义反向传播方法:不含正则化
  48. train_step = tf.train.AdadeltaOptimizer(0.0001).minimize(loss_mse)
  49. with tf.Session() as sess:
  50. init_op = tf.global_variables_initializer()
  51. sess.run(init_op)
  52. STEPS = 40000
  53. for i in range(STEPS):
  54. start = (i*BATCH_SIZE)%300
  55. end = start + BATCH_SIZE
  56. sess.run(train_step,feed_dict={x:X[start:end],y_:Y_[start:end]})
  57. if i % 2000 == 0:
  58. loss_mse_v = sess.run(loss_mse,feed_dict={x:X,y_:Y_})
  59. print("Atfer %d steps, loss is:%f" %(i, loss_mse_v))
  60. #xx在-3到3之间以步长为0.01,yy在-3到3之间以步长0.01,生成二维码网格坐标点
  61. xx,yy = np.mgrid[-3:3:.01, -3:3:.01]
  62. #将xx,yy拉直,并合成一个2列的矩阵,得到一个网格的集合
  63. grid = np.c_[xx.ravel(),yy.ravel()]
  64. #将网格坐标点喂入神经网络,probs为输出
  65. probs = sess.run(y, feed_dict={x:grid})
  66. # probs 的shape调整成xx的样子
  67. probs = probs.reshape(xx.shape)
  68. print("w1:\n",sess.run(w1))
  69. print("b1:\n", sess.run(b1))
  70. print("w2:\n", sess.run(w2))
  71. print("b2:\n", sess.run(b2))
  72. plt.scatter(X[:,0],X[:,1], c=np.squeeze(Y_c))
  73. plt.contour(xx,yy,probs,levels=[.5])
  74. plt.show()
  75. #定义反向传播方法:包含正则化
  76. train_step = tf.train.AdamOptimizer(0.0001).minimize(loss_total)
  77. with tf.Session() as sess:
  78. init_op = tf.global_variables_initializer()
  79. sess.run(init_op)
  80. STEPS = 40000
  81. for i in range(STEPS):
  82. start = (i*BATCH_SIZE)%300
  83. end = start + BATCH_SIZE
  84. sess.run(train_step, feed_dict={x:X[start:end],y_:Y_[start:end]})
  85. if i %2000 ==0:
  86. loss_v = sess.run(loss_total, feed_dict={x:X,y_:Y_})
  87. print("Atfer %d steps, loss is:%f" % (i, loss_v))
  88. # xx在-3到3之间以步长为0.01,yy在-3到3之间以步长0.01,生成二维码网格坐标点
  89. xx, yy = np.mgrid[-3:3:.01, -3:3:.01]
  90. # 将xx,yy拉直,并合成一个2列的矩阵,得到一个网格的集合
  91. grid = np.c_[xx.ravel(), yy.ravel()]
  92. # 将网格坐标点喂入神经网络,probs为输出
  93. probs = sess.run(y, feed_dict={x: grid})
  94. # probs 的shape调整成xx的样子
  95. probs = probs.reshape(xx.shape)
  96. print("w1:\n", sess.run(w1))
  97. print("b1:\n", sess.run(b1))
  98. print("w2:\n", sess.run(w2))
  99. print("b2:\n", sess.run(b2))
  100. plt.scatter(X[:, 0], X[:, 1], c=np.squeeze(Y_c))
  101. plt.contour(xx, yy, probs, levels=[.5])
  102. plt.show()

运行过程可能较慢,请耐心等待

注意:红字提示不是报错,只是提示,现在可以不管

运行结果

主要看输出的三张图:

下面再说一篇每张图的意思:

第一张图:

    只有红蓝点,对随机的点进行数据集可视化,标注规则:当 x02 + x12 >= 2 时,y_=0,标注为蓝色

第二张图:

    代码的注释中说明已经很详细了,就是执行没有正则化训练过程,将红蓝点分开的效果

显然我这个有点失败,但主要就是为了突出第三张图片的效果

第三张图:

执行包含正则化训练过程,将红蓝点分开的效果

更多文章链接:Tensorflow 笔记


- 本笔记不允许任何个人和组织转载

TensorFlow笔记-08-过拟合,正则化,matplotlib 区分红蓝点的更多相关文章

  1. 20180929 北京大学 人工智能实践:Tensorflow笔记08

    https://www.bilibili.com/video/av22530538/?p=28 ———————————————————————————————————————————————————— ...

  2. Tensorflow 笔记

    TensorFlow笔记-08-过拟合,正则化,matplotlib 区分红蓝点 TensorFlow笔记-07-神经网络优化-学习率,滑动平均 TensorFlow笔记-06-神经网络优化-损失函数 ...

  3. tensorflow笔记(二)之构造一个简单的神经网络

    tensorflow笔记(二)之构造一个简单的神经网络 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7425200.html ...

  4. tensorflow笔记:多层LSTM代码分析

    tensorflow笔记:多层LSTM代码分析 标签(空格分隔): tensorflow笔记 tensorflow笔记系列: (一) tensorflow笔记:流程,概念和简单代码注释 (二) ten ...

  5. tensorflow笔记(一)之基础知识

    tensorflow笔记(一)之基础知识 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7399701.html 前言 这篇no ...

  6. tensorflow笔记(三)之 tensorboard的使用

    tensorflow笔记(三)之 tensorboard的使用 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7429344.h ...

  7. tensorflow笔记(五)之MNIST手写识别系列二

    tensorflow笔记(五)之MNIST手写识别系列二 版权声明:本文为博主原创文章,转载请指明转载地址 http://www.cnblogs.com/fydeblog/p/7455233.html ...

  8. TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵

    TensorFlow笔记-06-神经网络优化-损失函数,自定义损失函数,交叉熵 神经元模型:用数学公式比表示为:f(Σi xi*wi + b), f为激活函数 神经网络 是以神经元为基本单位构成的 激 ...

  9. TensorFlow笔记-01-开篇概述

    人工智能实践:TensorFlow笔记-01-开篇概述 从今天开始,从零开始学习TensorFlow,有相同兴趣的同志,可以互相学习笔记,本篇是开篇介绍 Tensorflow,已经人工智能领域的一些名 ...

随机推荐

  1. A Creative Cutout CodeForces - 933D (计数)

    大意:给定$n$个圆, 圆心均在原点, 第$k$个圆半径为$\sqrt{k}$ 定义一个点的美丽值为所有包含这个点的圆的编号和 定义函数$f(n)$为只有$n$个圆时所有点的贡献,求$\sum_{k= ...

  2. 从mysql数据库删除重复记录只保留其中一条

    这两天做了一个调用第三方接口的小程序,因为是实时更新数据,所以请求接口的频率就很高,这样有时会出现往数据库插入重复的数据,对数据库造成压力也不方便管理,因为要通过原生sql语句,解决数据库的去重问题. ...

  3. nothing added to commit but untracked files present.

    当我们使用git的时候 如果我们在工作区修改了某些文件而没有新增文件,可以直接用: $ git commit --all -m "备注信息"                  -- ...

  4. C#加密方法汇总(SHA1加密字符串,MD5加密字符串,可逆加密等)

    using System;using System.Collections.Generic;using System.Text; namespace StringEncry{ class Encode ...

  5. Leetcode 74

    class Solution { public: bool searchMatrix(vector<vector<int>>& matrix, int target) ...

  6. ElasticSearch-hadoop saveToEs源码分析

    ElasticSearch-hadoop saveToEs源码分析: 类的调用路径关系为: EsSpark -> EsRDDWriter -> RestService -> Rest ...

  7. 【webpack系列】1 What is webpack?

    什么是webpack? 现今的网页可以看做是功能丰富的应用,拥有着复杂的js代码和一大堆依赖包.为了简化开发的复杂程度,有了很多好用的实践方法 模块化 让我们可以把复杂的程序细化为小的文件 类似于Ty ...

  8. idea破解更新

    idea破解教程: https://www.cnblogs.com/jpfss/p/8872358.html JetbrainsCrack-3.1-release-enc.jar下载:http://i ...

  9. ShiroFilterFactoryBean 处理拦截资源文件问题(Shiro权限管理)

    一.需要定义ShiroFilterFactoryBean()方法,而ShiroFilterFactoryBean.class是实现了FactoryBean和BeanPostProcessor接口: 1 ...

  10. duilib CEditUI 禁止输入中文字符,禁止复制粘贴

    1.CEditUI 禁止使用中文输入法 在 CEditUI::DoEvent 函数中,添加代码: if(m_bOnlyEnglishChar && m_pWindow &&am ...