版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。

神经网络可以用来模拟回归问题 (regression),实质上是单输入单输出神经网络模型,例如给下面一组数据,用一条线来对数据进行拟合,并可以预测新输入 x 的输出值。

一、详细解读

我们通过这个简单的例子来熟悉Keras构建神经网络的步骤:

1.导入模块并生成数据

首先导入本例子需要的模块,numpy、Matplotlib、和keras.models、keras.layers模块。Sequential是多个网络层的线性堆叠,可以通过向Sequential模型传递一个layer的list来构造该模型,也可以通过.add()方法一个个的将layer加入模型中。layers.Dense 意思是这个神经层是全连接层。

2.建立模型

然后用 Sequential 建立 model,再用 model.add 添加神经层,添加的是 Dense 全连接神经层。参数有两个,(注意此处Keras 2.0.2版本中有变更)一个是输入数据的维度,另一个units代表神经元数,即输出单元数。如果需要添加下一个神经层的时候,不用再定义输入的纬度,因为它默认就把前一层的输出作为当前层的输入。在这个简单的例子里,只需要一层就够了。

3.激活模型

model.compile来激活模型,参数中,误差函数用的是 mse均方误差;优化器用的是 sgd 随机梯度下降法。

4.训练模型

训练的时候用 model.train_on_batch 一批一批的训练 X_train, Y_train。默认的返回值是 cost,每100步输出一下结果。

5.验证模型

用到的函数是 model.evaluate,输入测试集的x和y,输出 cost,weights 和 biases。其中 weights 和 biases 是取在模型的第一层 model.layers[0] 学习到的参数。从学习到的结果你可以看到, weights 比较接近0.5,bias 接近 2。

Weights= [[ 0.49136472]]

biases= [ 2.00405312]

6.可视化学习结果

最后可以画出预测结果,与测试集的值进行对比。

二、完整代码

  

  1. import numpy as np
  2. np.random.seed(1337)
  3. from keras.models import Sequential
  4. from keras.layers import Dense
  5. import matplotlib.pyplot as plt
  6.  
  7. # 生成数据
  8. X = np.linspace(-1, 1, 200) #在返回(-1, 1)范围内的等差序列
  9. np.random.shuffle(X) # 打乱顺序
  10. Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200, )) #生成Y并添加噪声
  11. # plot
  12. plt.scatter(X, Y)
  13. plt.show()
  14.  
  15. X_train, Y_train = X[:160], Y[:160] # 前160组数据为训练数据集
  16. X_test, Y_test = X[160:], Y[160:] #后40组数据为测试数据集
  17.  
  18. # 构建神经网络模型
  19. model = Sequential()
  20. model.add(Dense(input_dim=1, units=1))
  21.  
  22. # 选定loss函数和优化器
  23. model.compile(loss='mse', optimizer='sgd')
  24.  
  25. # 训练过程
  26. print('Training -----------')
  27. for step in range(501):
  28. cost = model.train_on_batch(X_train, Y_train)
  29. if step % 50 == 0:
  30. print("After %d trainings, the cost: %f" % (step, cost))
  31.  
  32. # 测试过程
  33. print('\nTesting ------------')
  34. cost = model.evaluate(X_test, Y_test, batch_size=40)
  35. print('test cost:', cost)
  36. W, b = model.layers[0].get_weights()
  37. print('Weights=', W, '\nbiases=', b)
  38.  
  39. # 将训练结果绘出
  40. Y_pred = model.predict(X_test)
  41. plt.scatter(X_test, Y_test)
  42. plt.plot(X_test, Y_pred)
  43. plt.show()

  

三、其他补充

1. numpy.linspace

numpy.linspace(start, stop, num=50, endpoint=True,retstep=False,dtype=None)

返回等差序列,序列范围在(start,end),生成num个元素的np数组,如果endpoint为False,则生成num+1个但是返回num个,retstep=True则在其后返回步长.

>>> np.linspace(2.0, 3.0, num=5)
array([ 2. , 2.25, 2.5 , 2.75, 3. ])
>>> np.linspace(2.0, 3.0, num=5, endpoint=False)
array([ 2. , 2.2, 2.4, 2.6, 2.8])
>>> np.linspace(2.0, 3.0, num=5, retstep=True)
(array([ 2. , 2.25, 2.5 , 2.75, 3. ]), 0.25)

Keras上实现简单线性回归模型的更多相关文章

  1. 基于tensorflow的简单线性回归模型

    #!/usr/local/bin/python3 ##ljj [1] ##linear regression model import tensorflow as tf import matplotl ...

  2. 机器学习(2):简单线性回归 | 一元回归 | 损失计算 | MSE

    前文再续书接上一回,机器学习的主要目的,是根据特征进行预测.预测到的信息,叫标签. 从特征映射出标签的诸多算法中,有一个简单的算法,叫简单线性回归.本文介绍简单线性回归的概念. (1)什么是简单线性回 ...

  3. 机器学习——Day 2 简单线性回归

    写在开头 由于某些原因开始了机器学习,为了更好的理解和深入的思考(记录)所以开始写博客. 学习教程来源于github的Avik-Jain的100-Days-Of-MLCode 英文版:https:// ...

  4. Python回归分析五部曲(一)—简单线性回归

    回归最初是遗传学中的一个名词,是由英国生物学家兼统计学家高尔顿首先提出来的,他在研究人类身高的时候发现:高个子回归人类的平均身高,而矮个子则从另一方向回归人类的平均身高: 回归分析整体逻辑 回归分析( ...

  5. day-12 python实现简单线性回归和多元线性回归算法

    1.问题引入  在统计学中,线性回归是利用称为线性回归方程的最小二乘函数对一个或多个自变量和因变量之间关系进行建模的一种回归分析.这种函数是一个或多个称为回归系数的模型参数的线性组合.一个带有一个自变 ...

  6. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  7. 用Tensorflow完成简单的线性回归模型

    思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...

  8. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  9. R语言解读一元线性回归模型

    转载自:http://blog.fens.me/r-linear-regression/ 前言 在我们的日常生活中,存在大量的具有相关性的事件,比如大气压和海拔高度,海拔越高大气压强越小:人的身高和体 ...

随机推荐

  1. git crate&query&delete tag(九)

    root@vmuer-VirtualBox:/opt/myProject# git log --pretty=oneline0169b7a1c4bccb47e76711f353fd8d3864bde9 ...

  2. python3中“->”的含义

    ->:标记返回函数注释,信息作为.__annotations__属性提供 __annotations__属性是字典.键return是用于在箭头后检索值的键.但是在Python中3.5,PEP 4 ...

  3. [SDOI2018]物理实验 set,扫描线,旋转坐标系

    [SDOI2018]物理实验 set,扫描线,旋转坐标系 链接 loj 思路 先将导轨移到原点,然后旋转坐标系,参考博客. 然后分线段,每段的贡献(三角函数值)求出来,用自己喜欢的平衡树,我选set. ...

  4. pcm音频的格式类型

    [文章内容属于多方转载内容] PCM Parameters PCM audio is coded using a combination of various parameters. Resoluti ...

  5. [C#]AdvPropertyGrid的使用示例(第三方控件:DevComponents.DotNetBar2.dll)

    开发环境:Visual Studio 2019 .NET版本:4.5.2 效果如下: 1.初始化界面: 2.属性“人物”-自定义控件显示: 3.属性“地址”-自定义窗体显示: 4.属性“性别”-枚举显 ...

  6. 调用 Dll 中的函数时,出现栈(STACK)的清除问题 -> 故障模块名称: StackHash_0a9e

    在一个名为 test.dll 文件中,有一个 Max() 函数的定义是: #ifdef BUILD_DLL #define DLL_EXPORT __declspec(dllexport) __std ...

  7. 全国自考C++程序设计

    一.单项选择题(本大题共20小题,每小题1分,共20分)在每小题列出的四个备选项中 只有一个是符合题目要求的,请将其代码填写在题后的括号内.错选.多选或未选均无 分. 1. 编写C++程序一般需经过的 ...

  8. 031 Spring Data Elasticsearch学习笔记---重点掌握第5节高级查询和第6节聚合部分

    Elasticsearch提供的Java客户端有一些不太方便的地方: 很多地方需要拼接Json字符串,在java中拼接字符串有多恐怖你应该懂的 需要自己把对象序列化为json存储 查询到结果也需要自己 ...

  9. Nginx Tutorial #1: Basic Concepts(转)

    add by zhj: 文章写的很好,适合初学者 原文:https://www.netguru.com/codestories/nginx-tutorial-basics-concepts Intro ...

  10. 『正睿OI 2019SC Day3』

    容斥原理 容斥原理指的是一种排重,补漏的计算思想,形式化的来说,我们有如下公式: \[\left | \bigcup_{i=1}^nS_i \right |=\sum_{i}|S_i|-\sum_{i ...