1、准备环境,探索数据

  1. import numpy as np
  2. from keras.models import Sequential
  3. from keras.layers import Dense
  4. import matplotlib.pyplot as plt
  5.  
  6. # 创建数据集
  7. rng = np.random.RandomState(27)
  8. X = np.linspace(-3, 5, 300)
  9. rng.shuffle(X) # 将数据集随机化
  10. y = 0.5 * X + 1 + np.random.normal(0, 0.05, 300) # 假设真实模型为:y = 0.5X + 1
  11.  
  12. # 绘制数据集
  13. plt.scatter(X, y, s=0.5)
  14. plt.show()

2、准备数据训练模型

  1. # 划分训练集和测试集
  2. X_train, y_train = X[:400], y[:400]
  3. X_test, y_test = X[-100:], y[-100:]
  4.  
  5. # 定义模型
  6. model = Sequential () # 用 Keras 序贯模型(Sequential)定义一个单输入单输出的模型 model
  7. model.add(Dense(output_dim=1, input_dim=1)) # 通过 add()方法一层, Dense 是全连接层,第一层需要定义输入
  8.  
  9. # 设置模型参数
  10. model.compile(loss='mse', optimizer='sgd') # 通过compile()方法选择损失函数(均方误差)和 优化器(随机梯度下降)
  11.  
  12. # 开始训练
  13. print('Training ==========')
  14. for step in range(301):
  15. cost = model.train_on_batch(X_train, y_train) # Keras 的 train_on_batch() 函数训练模型
  16. if step % 100 == 0:
  17. print('train cost: ', cost)

3、测试训练好的模型

  1. print('\nTesting ==========')
  2. cost = model.evaluate(X_test, y_test, batch_size=40)
  3. print('test cost:', cost)
  4. W, b = model.layers[0].get_weights() # 查看训练出的网络参数
  5.  
  6. print('Weights=', W, '\nbiases=', b) # 由于网络只有一层,且每次训练的输入和输出只有一个节点,因此第一层训练出 y=WX+b 的模型,其中 W,b 为训练出的参数

最终的测试 cost 为: 0.0026768923737108706

4、可视化测试结果

  1. y_pred = model.predict(X_test) # 用测试集进行预测
  2. plt.scatter(X_test, y_test, s=4) # 绘制测试点图
  3. plt.plot(X_test, y_pred, lw=0.7) # 绘制回归直线
  4. plt.show()

。。。

Keras 训练一个单层全连接网络的线性回归模型的更多相关文章

  1. PRML读书笔记——线性回归模型(上)

    本章开始学习第一个有监督学习模型--线性回归模型."线性"在这里的含义仅限定了模型必须是参数的线性函数.而正如我们接下来要看到的,线性回归模型可以是输入变量\(x\)的非线性函数. ...

  2. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  3. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  4. Keras(一)Sequential与Model模型、Keras基本结构功能

    keras介绍与基本的模型保存 思维导图 1.keras网络结构 2.keras网络配置 3.keras预处理功能 模型的节点信息提取 config = model.get_config() 把mod ...

  5. 线性回归模型的 MXNet 与 TensorFlow 实现

    本文主要探索如何使用深度学习框架 MXNet 或 TensorFlow 实现线性回归模型?并且以 Kaggle 上数据集 USA_Housing 做线性回归任务来预测房价. 回归任务,scikit-l ...

  6. 【scikit-learn】scikit-learn的线性回归模型

     内容概要 怎样使用pandas读入数据 怎样使用seaborn进行数据的可视化 scikit-learn的线性回归模型和用法 线性回归模型的评估測度 特征选择的方法 作为有监督学习,分类问题是预 ...

  7. R语言解读多元线性回归模型

    转载:http://blog.fens.me/r-multi-linear-regression/ 前言 本文接上一篇R语言解读一元线性回归模型.在许多生活和工作的实际问题中,影响因变量的因素可能不止 ...

  8. 机器学习(一) 从一个R语言案例学线性回归

    写在前面的话 按照正常的顺序,本文应该先讲一些线性回归的基本概念,比如什么叫线性回归,线性回规的常用解法等.但既然本文名为<从一个R语言案例学会线性回归>,那就更重视如何使用R语言去解决线 ...

  9. 多元线性回归模型的特征压缩:岭回归和Lasso回归

    多元线性回归模型中,如果所有特征一起上,容易造成过拟合使测试数据误差方差过大:因此减少不必要的特征,简化模型是减小方差的一个重要步骤.除了直接对特征筛选,来也可以进行特征压缩,减少某些不重要的特征系数 ...

随机推荐

  1. 跨交换机VLAN之间的通信(基于Cisco模拟器)

    实验要求: 拓扑结构如下 1.交换机2台:主机4台:网线若干. 2.把主机.交换机进行互联. 3.给2台交换机重命名为A.B. 4.设置2台交换机及主机的ip.注意IP要不冲突 5.在2台交换机上分别 ...

  2. control+shift + o热键冲突?????

    不知道有没有宝贝跟我遇到一样的问题 就是    control +shift+o    热键冲突了 进过我的严密调查. 这是因为你用的是A卡. 只要你把A卡换成N卡就可以了, 但是因为我太贫穷了,只能 ...

  3. Linux--部署Django项目

    简单部署 1.安装虚拟环境virtualenvwrapper,创建虚拟环境目录,进入虚拟环境,我的虚拟环境目录叫venv2 [root@HH ~]# workon venv2 (venv2) [roo ...

  4. raid,磁盘配额,DNS综合测试题

    DNS解析综合学习案例1.用户需把/dev/myvg/mylv逻辑卷以支持磁盘配额的方式挂载到网页目录下2.在网页目录下创建测试文件index.html,内容为用户名称,通过浏览器访问测试3.创建用户 ...

  5. 解决PEnetwork启动的时候提示"An error occured while starting the "TCP/IP Registry Compatibility" Service (2)!"程序将立即退出的问题

    解决PEnetwork启动的时候提示"An error occured while starting the "TCP/IP Registry Compatibility" ...

  6. BBS_02day

    目录 BBS_02day: 展示个人所有文章: 点赞,点彩功能: 评论功能: BBS_02day: 展示个人所有文章: def article_detail(request,username,arti ...

  7. 性感VSCODE在线刷LeetCode的题

    安装Nodejs并勾选添加到PATH VSCODE安装插件LeetCode 注册LeetCode账号(注意CN国区和国际区账号不通用),重启VSCODE并点左边栏那个LeetCode图标sign in ...

  8. intellij idea 解决2019年4月到期延期问题

    56ZS5PQ1RF-eyJsaWNlbnNlSWQiOiI1NlpTNVBRMVJGIiwibGljZW5zZWVOYW1lIjoi5q2j54mI5o6I5p2DIC4iLCJhc3NpZ25lZ ...

  9. Unity C# File类 本地数据保存和游戏存档

    进行本地数据存档和载入在游戏开发中非常常见,几乎任何一款游戏都需要这样的功能. 命名空间: using System.IO; 主要用于引入File类以处理各类文件操作. using System.Ru ...

  10. 阿里云ECS服务器CentOS7.2安装Python2.7.13

    阿里云ECS服务器CentOS7.2安装Python2.7.13 yum中最新的也是Python 2.6.6,只能下载Python 2.7.9的源代码自己编译安装. 操作步骤如下: 检查CentOS7 ...