[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

  • code
  1. #%%
  2. from matplotlib import pyplot as plt
  3. from mxnet import autograd, nd
  4. import random
  5.  
  6. #%%
  7. num_inputs = 1
  8. num_examples = 100
  9. true_w = 1.56
  10. true_b = 1.24
  11. features = nd.arange(0,10,0.1).reshape((-1, 1))
  12. labels = true_w * features + true_b
  13. labels += nd.random.normal(scale=0.2, shape=labels.shape)
  14.  
  15. features[0], labels[0]
  16.  
  17. #%%
  18. # 本函数已保存在d2lzh包中方便以后使用
  19. def data_iter(batch_size, features, labels):
  20. num_examples = len(features)
  21. indices = list(range(num_examples))
  22. random.shuffle(indices) # 样本的读取顺序是随机的
  23. for i in range(0, num_examples, batch_size):
  24. j = nd.array(indices[i: min(i + batch_size, num_examples)])
  25. yield features.take(j), labels.take(j) # take函数根据索引返回对应元素
  26.  
  27. #%%
  28. batch_size = 10
  29.  
  30. for X, y in data_iter(batch_size, features, labels):
  31. print(X, y)
  32. break
  33.  
  34. #%%
  35. w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
  36. b = nd.zeros(shape=(1,))
  37.  
  38. #%%
  39.  
  40. w.attach_grad()
  41. b.attach_grad()
  42.  
  43. #%%
  44. def linreg(X, w, b): # 本函数已保存在d2lzh包中方便以后使用
  45. return nd.dot(X, w) + b
  46.  
  47. #%%
  48.  
  49. def squared_loss(y_hat, y): # 本函数已保存在d2lzh包中方便以后使用
  50. return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
  51.  
  52. #%%
  53.  
  54. def sgd(params, lr, batch_size): # 本函数已保存在d2lzh包中方便以后使用
  55. for param in params:
  56. param[:] = param - lr * param.grad / batch_size
  57.  
  58. #%%
  59.  
  60. lr = 0.05
  61. num_epochs = 20
  62. net = linreg
  63. loss = squared_loss
  64.  
  65. for epoch in range(num_epochs): # 训练模型一共需要num_epochs个迭代周期
  66. # 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X
  67. # 和y分别是小批量样本的特征和标签
  68. for X, y in data_iter(batch_size, features, labels):
  69. with autograd.record():
  70. l = loss(net(X, w, b), y) # l是有关小批量X和y的损失
  71. l.backward() # 小批量的损失对模型参数求梯度
  72. sgd([w, b], lr, batch_size) # 使用小批量随机梯度下降迭代模型参数
  73. train_l = loss(net(features, w, b), labels)
  74. print('epoch %d, loss %f' % (epoch + 1, train_l.mean().asnumpy()))
  75.  
  76. #%%
  77. true_w, w
  78.  
  79. #%%
  80. true_b, b
  81.  
  82. #%%
  83. plt.scatter(features.asnumpy(), labels.asnumpy(), 1)
  84.  
  85. labels1 = linreg(features,w,b)
  86. plt.scatter(features.asnumpy(), labels1.asnumpy(), 1)
  87. plt.show()

  • out

黄色是原始数据

绿色为拟合数据

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现的更多相关文章

  1. OI回忆录第一章 逐梦之始

    2013年春,初中零年级.GXZ来到吉大高中机房,参加一位老师曾在班级宣传的"计算机培训".同行的有这位老师,以及近80名同学.和同学们一样,GXZ也是为了在机房玩游戏而参加所谓的 ...

  2. 洛谷 P5640 【CSGRound2】逐梦者的初心

    洛谷 P5640 [CSGRound2]逐梦者的初心 洛谷传送门 题目背景 注意:本题时限修改至250ms,并且数据进行大幅度加强.本题强制开启O2优化,并且不再重测,请大家自己重新提交. 由于Y校的 ...

  3. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  4. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记8——载入三维模型&Alpha混合技术&深度测试与Z缓存

    第17章 三维游戏模型的载入 主要是如何从3ds max中导出.X文件,以及如何从X文件加载三维模型到DirextX游戏程序里.因为复杂的3D物体,要用代码去实现,那太反人类了,所以我们需要一些建模软 ...

  5. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记10——三维天空的构建&三维粒子的实现&多游戏模型的载入

    第23章 三维天空的构建 目前描述三维天空的技术主要包括三种类型,直接来介绍使用最广泛的模拟技术,详细的描述可以见作者的博文. 天空盒(Sky Box),即放到场景的是一个立方体.它是目前使用最广泛的 ...

  6. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记9——游戏摄像机&三维地形的构建

    第21章 游戏摄像机的构建 之前的程序示例,都是通过封装的DirectInput类来处理键盘和鼠标的输入,对应地改变我们人物模型的世界矩阵来达到移动物体,改变观察点的效果.其实我们的观察方向乃至观察点 ...

  7. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记7——DirectInput&纹理映射

    第15章 DirectInput接口 DirectInput作为DirectX的组件之一,依然是一些COM对象的集合.DirectInput由IDirectinput8.IDirectInputDev ...

  8. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记6——四大变换&光照与材质

    第13章 四大变换 在Direct3D中,如果为进行任何空间坐标变换而直接绘图的话,图形将始终处于应用程序窗口的中心位置,默认这个位置就成为世界坐标系的原点(0,0,0).而且我们也不能改变观察图形的 ...

  9. 《逐梦旅程 WINDOWS游戏编程之从零开始》笔记5——Direct3D中的顶点缓存和索引缓存

    第12章 Direct3D绘制基础 1. 顶点缓存 计算机所描绘的3D图形是通过多边形网格来构成的,网网格勾勒出轮廓,然后在网格轮廓的表面上贴上相应的图片,这样就构成了一个3D模型.三角形网格是构建物 ...

随机推荐

  1. mysql注释方法【自用】

    原文链接:https://www.jb51.net/article/125991.htm 一.MySQL支持三种注释方式: 1.从‘#'字符从行尾. 2.从‘-- '序列到行尾.请注意‘-- '(双破 ...

  2. web开发中如何使用引用字体

    1.在style中添加代码: @font-face { font-family: mFont; src: url('../font/crapaud_petit.ttf'); } 2.使用 <h1 ...

  3. java自动化-实际使用junit的演示

    本文简单介绍一下我写的http接口后端框架 在经过之前多篇博客介绍之后,读者应掌握如下技能 1,自动运行一个或者多个junit框架编写的java代码 2,对数据驱动以及关键字驱动有一定的了解和认识,甚 ...

  4. vue定义全局组件

    <!DOCTYPE html><html> <head> <meta charset="utf-8"> <title>& ...

  5. [shell] if语句用法

    bash中如何实现条件判断?条件测试类型:    整数测试    字符测试    文件测试 一.条件测试的表达式:    [ expression ]  括号两端必须要有空格    [[ expres ...

  6. win7下配置mysql的my.ini文件

    一.环境 操作系统是win7 x64, mysql是5.6.40. 二. 怎么配置? 修改my.ini文件, 添加[client], 在下面加一行 default-character-set=utf8 ...

  7. angular学习笔记(三)

    1.安装npm install --save @angular/material@2.0.0-beta.72.安装http://chrome-extension-downloader.com安装aug ...

  8. Chrome+postman+postman interceptor调试

    本文使用chrome+postman4.8.3+postman interceptor0.2.23调试使用cookie的请求. postman4.8.3下载地址:https://pan.baidu.c ...

  9. PostgreSQL+PostGIS 的使用

    一.PostGIS中的几何类型 PostGIS支持所有OGC规范的“Simple Features”类型,同时在此基础上扩展了对3DZ.3DM.4D坐标的支持. 1. OGC的WKB和WKT格式 OG ...

  10. [Swift]LeetCode324. 摆动排序 II | Wiggle Sort II

    Given an unsorted array nums, reorder it such that nums[0] < nums[1] > nums[2] < nums[3]... ...