1、基于MXNET框架的线性回归从零实现例子 

  下面博客是基于MXNET框架下的线性回归从零实现,以一个简单的房屋价格预测作为例子来解释线性回归的基本要素。这个应用的目标是预测一栋房子的售出价格(元)。

  为了简单起见,这里我们假设价格只取决于房屋状况的两个因素,即面积(平方米)和房龄(年)。接下来我们希望探索价格与这两个因素的具体关系:

  设房屋的面积为x1,房龄为x2,售出价格为y。我们需要建立基于输入x1和x2来计算输出yy的表达式,也就是模型(model)。顾名思义,线性回归假设输出与各个输入之间是线性关系:y'=x1w1+x2w2+b

  其中w1和w2是权重(weight),b是偏差(bias),且均为标量。它们是线性回归模型的参数(parameter)。模型输出y'是线性回归对真实价格y的预测或估计。我们通常允许它们之间有一定误差。

2、实现部分(各个部分见代码注释)

2.1、生成数据集(随机生成批量样本数据与高斯噪声)

2.2、读取数据集(遍历数据集并不断读取小批量数据样本)

2.3、初始化模型参数(均值为0、标准差为0.01的正态随机数,偏差则初始化成0)

2.4、定义模型

2.5、定义损失函数(平方损失函数)

2.6、定义优化算法(sgd小批量随机梯度下降算法)

2.7、训练模型(过调用反向函数backward计算小批量随机梯度,并调用优化算法sgd迭代模型参数)

3、代码实现

 1 from IPython import display
2 from matplotlib import pyplot as plt
3 from mxnet import autograd, nd
4 import random
5
6
7 # 生成数据集
8 num_inputs = 2
9 num_examples = 1000
10
11 true_w = [2, -3.4]
12 true_b = 4.2
13 features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))
14
15 labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
16 labels += nd.random.normal(scale=0.01, shape=labels.shape)
17
18 print(features[0], labels[0])
19
20
21 def use_svg_display():
22 # 用矢量图显示
23 display.set_matplotlib_formats('svg')
24
25
26 def set_figsize(figsize=(3.5, 2.5)):
27 use_svg_display()
28 # 设置图的尺寸
29 plt.rcParams['figure.figsize'] = figsize
30
31
32 set_figsize()
33 plt.scatter(features[:, 1].asnumpy(), labels.asnumpy(), 1)
34
35 # plt.scatter(features[:, 0].asnumpy(), labels.asnumpy(), 1)
36 # help(plt.scatter)
37
38
39 # 读取数据集
40 def data_iter(batch_size, features, labels):
41 num_examples = len(features)
42 indices = list(range(num_examples))
43 random.shuffle(indices)
44 for i in range(0, num_examples, batch_size):
45 j = nd.array(indices[i: min(i + batch_size, num_examples)])
46 yield features.take(j), labels.take(j)
47
48
49 batch_size = 10
50
51 for X, y in data_iter(batch_size, features, labels):
52 print(X, y)
53 break
54
55
56 # 初始化模型参数
57 w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
58 b = nd.zeros(shape=(1,))
59 # 之后的模型训练中,需要对这些参数求梯度来迭代参数的值,因此我们需要创建它们的梯度
60 w.attach_grad()
61 b.attach_grad()
62
63
64 # 定义模型
65 def linreg(X, w, b):
66 return nd.dot(X, w) + b
67
68
69 # 定义平方损失函数
70 def squared_loss(y_hat, y):
71 return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
72
73
74 #定义优化算法
75 def sgd(params, lr , batch_size):
76 for param in params:
77 param[:] = param - lr * param.grad / batch_size
78
79
80 #训练模型
81 lr = 0.03
82 num_epochs = 3
83 net = linreg
84 loss = squared_loss
85
86 for epoch in range(num_epochs):
87 for X, y in data_iter(batch_size, features, labels):
88 with autograd.record():
89 l = loss(net(X, w, b), y)
90 l.backward()
91 sgd([w, b], lr, batch_size)
92 train_l = loss(net(features, w, b), labels)
93 print('epoch %d, loss %f' % (epoch + 1, train_l.mean().asnumpy()))
94
95
96
97 plt.show()

4、结果

4.1、特征features[1, :]和[:, 1]与labels之间的散点图

4.2、迭代结果

4.3、线性回归模型真实权重参数与训练得到的参数比较:print(true_w, w)   print(true_b, b)

基于MXNET框架的线性回归从零实现(房价预测为例)的更多相关文章

  1. Tensorflow之多元线性回归问题(以波士顿房价预测为例)

    一.根据波士顿房价信息进行预测,多元线性回归+特征数据归一化 #读取数据 %matplotlib notebook import tensorflow as tf import matplotlib. ...

  2. 基于 Keras 的 LSTM 时间序列分析——以苹果股价预测为例

    简介 时间序列简单的说就是各时间点上形成的数值序列,时间序列分析就是通过观察历史数据预测未来的值.预测未来股价走势是一个再好不过的例子了.在本文中,我们将看到如何在递归神经网络的帮助下执行时间序列分析 ...

  3. 基于netty框架的Socket传输

    一.Netty框架介绍 什么是netty?先看下百度百科的解释:         Netty是由JBOSS提供的一个java开源框架.Netty提供异步的.事件驱动的网络应用程序框架和工具,用以快速开 ...

  4. AWS研究热点:BMXNet – 基于MXNet的开源二进神经网络实现

    http://www.atyun.com/9625.html 最近提出的二进神经网络(BNN)可以通过应用逐位运算替代标准算术运算来大大减少存储器大小和存取率.通过显着提高运行时的效率并降低能耗,让最 ...

  5. 基于laravel框架构建最小内容管理系统

    校园失物招领平台开发 --基于laravel框架构建最小内容管理系统 摘要 ​ 针对目前大学校园人口密度大.人群活动频繁.师生学习生活等物品容易遗失的基本现状,在分析传统失物招领过程中的工作效率低下. ...

  6. 基于Dubbo框架构建分布式服务(一)

    Dubbo是Alibaba开源的分布式服务框架,我们可以非常容易地通过Dubbo来构建分布式服务,并根据自己实际业务应用场景来选择合适的集群容错模式,这个对于很多应用都是迫切希望的,只需要通过简单的配 ...

  7. 基于SSH框架的学生公寓管理系统的质量属性

    系统名称:学生公寓管理系统 首先介绍一下学生公寓管理系统,在学生公寓管理方面,针对学生有关住宿信息问题进行管理,学生公寓管理系统主要包含了1)学生信息记录:包括学号.姓名.性别.院系.班级:2)住宿信 ...

  8. 基于BootStrap框架构建快速响应的GPS部标监控平台

    最近一个客户要求将gps部标平台移植到bootStrap框架作为前端框架,符合交通部796部标只是他们的一个基本要求,重点是要和他们的冷链云物流平台进行适配.我自己先浏览了客户的云物流平台的界面,采用 ...

  9. 基于ssh框架的在线考试系统开发的质量属性

    我做的系统是基于ssh框架的在线考试系统.在线考试系统有以下几点特性:(1)系统响应时间需要非常快,可以迅速的出题,答题.(2)系统的负载量也需要非常大,可以支持多人在线考试(3)还有系统的安全性也需 ...

随机推荐

  1. kubernetes生产实践之mysql

    简介 kubedb mysql 生命周期及特性 Supported MySQL Features Features Availability Clustering ✓ Persistent Volum ...

  2. BeetleX使用bootstrap5开发SPA应用

        在早期版本BeetleX.WebFamily只提供了vuejs+element的集成,由于element只适合PC管理应用开发相对于移动应用适配则没这么方便.在新版本组件集成了bootstra ...

  3. Java 树结构实际应用 一(堆排序2秒排完800w数据)

    堆排序 1 堆排序基本介绍 1) 堆排序是利用堆这种数据结构而设计的一种排序算法,堆排序是一种选择排序,它的最坏,最好,平均时间复 杂度均为 O(nlogn),它也是不稳定排序. 2) 堆是具有以下性 ...

  4. python基础学习之列表的功能方法

    列表:list 格式 li = [1,2,3,4,5,6] 列表内部随意嵌套其他格式:字符串.列表.数字.元组.字典. 列表内部有序,且内容可更改 a = [1,2,3,4]    a[0] = 5  ...

  5. 前端学习 node 快速入门 系列 —— 服务端渲染

    其他章节请看: 前端学习 node 快速入门 系列 服务端渲染 在简易版 Apache一文中,我们用 node 做了一个简单的服务器,能提供静态资源访问的能力. 对于真正的网站,页面中的数据应该来自服 ...

  6. effective解读-第八条 避免使用finalizer和Cleaner

    java9之前finalizer,java9使用cleaner代替了finalizer.相比finalizer,cleaner(它存在于一个独立类Cleaner中,需要时候注入到对应类中即可)不会污染 ...

  7. Java中的集合List - 入门篇

    前言 大家好啊,我是汤圆,今天给大家带来的是<Java中的集合List - 入门篇>,希望对大家有帮助,谢谢 简介 说实话,Java中的集合有很多种,但是这里作为入门级别,先简单介绍第一种 ...

  8. ECMAScript 2018(ES9)新特性简介

    目录 简介 异步遍历 Rest/Spread操作符和对象构建 Rest Spread 创建和拷贝对象 Spread和bject.assign() 的区别 正则表达式 promise.finally 模 ...

  9. 18. vue-router案例-tabBar导航

    目标: 做一个导航tabbar 一. 分析 我们的目标是做一个导航tabbar, 要求 这个导航不仅可以在一个页面使用, 可以在多个页面通用 每个页面的样式可能不一样 每个页面的图标, 文字可能不一样 ...

  10. MySQL实战45讲,丁奇带你搞懂

    之前,你大概都是通过搜索别人的经验来解决问题.如果能够理解MySQL的工作原理,那么在遇到问题的时候,是不是就能更快地直戳问题的本质? 以实战中的常见问题为切入点,带你剖析现象背后的本质原因.为你串起 ...