这篇文章中,我们将使用TensorFlow.js来根据数据拟合曲线。即使用多项式产生数据然后再改变其中某些数据(点),然后我们会训练模型来找到用于产生这些数据的多项式的系数。简单的说,就是给一些在二维坐标中的散点图,然后我们建立一个系数未知的多项式,通过TensorFlow.js来训练模型,最终找到这些未知的系数,让这个多项式和散点图拟合。

  

一、运行代码

  这篇文章关注的是创建模型以及学习模型的系数,完整的代码在这里可以找到。为了在本地运行,如下所示:

$ git clone https://github.com/tensorflow/tfjs-examples.git
$ cd tfjs-examples/polynomial-regression-core
$ yarn
$ yarn watch

  即首先将核心代码下载到本地,然后进入polynomial-regression-core(即多项式回归核心)部分,最后进行yarn安装并运行。

二、输入数据

  我们的数据在x坐标轴和y坐标轴内,看上去就是将之放在了笛卡尔坐标系,如下所示:

  4

  

  即这个图是 y = ax3 + bx2 + cx + d得到的,而在上图中,我们也看到了其真实系数为a=-0.800,b=-0.200,c=0.900,d=0.500,然后这些点是根据真实的点做了一定的偏移。  

  我们的任务就是通过机器学习得到这个函数的系数a、b、c以及d来最好的匹配这些数据。接下来,我们就看看如何通过TensorFlow.js来学习得到这些数据。

三、学习步骤

第一步 :设置变量

  首先,我们需要创建一些变量。即开始我们是不知道a、b、c、d的值的,所以先给他们一个随机数,入戏所示:

const a = tf.variable(tf.scalar(Math.random()));
const b = tf.variable(tf.scalar(Math.random()));
const c = tf.variable(tf.scalar(Math.random()));
const d = tf.variable(tf.scalar(Math.random()));

第二步:创建模型

  我们可以通过TensorFlow.js中的链式调用操作来实现这个多项式方程  y = ax3 + bx2 + cx + d,下面的代码就创建了一个 predict 函数,这个函数将x作为输入,y作为输出:

function predict(x) {
// y = a * x ^ 3 + b * x ^ 2 + c * x + d
return tf.tidy(() => {
return a.mul(x.pow(tf.scalar())) // a * x^3
.add(b.mul(x.square())) // + b * x ^ 2
.add(c.mul(x)) // + c * x
.add(d); // + d
});
}

  其中,在上一篇文章中,我们讲到tf.tify函数用来清除中间张量,其他的都很好理解。

  接着,让我们把这个多项式函数的系数使用之前得到的随机数,可以看到,得到的图应该是这样:

  因为开始时,我们使用的系数是随机数,所以这个函数和给定的数据匹配的非常差,而我们写的模型就是为了通过学习得到更精确的系数值。

第三步:训练模型

  最后一步就是要训练这个模型使得系数和这些散点更加匹配,而为了训练模型,我们需要定义下面的三样东西:

  • 损失函数(loss function):这个损失函数代表了给定多项式和数据的匹配程度。 损失函数值越小,那么这个多项式和数据就跟匹配。
  • 优化器(optimizer):这个优化器实现了一个算法,它会基于损失函数的输出来修正系数值。所以优化器的目的就是尽可能的减小损失函数的值。
  • 训练迭代器(traing loop):即它会不断地运行这个优化器来减少损失函数。

  所以,上面这三样东西的 关系就非常清楚了: 训练迭代器使得优化器不断运行,使得损失函数的值不断减小,以达到多项式和数据尽可能匹配的目的。这样,最终我们就可以得到a、b、c、d较为精确的值了。

  

四、定义损失函数

  这篇文章中,我们使用MSE(均方误差,mean squared error)作为我们的损失函数。MSE的计算非常简单,就是先根据给定的x得到实际的y值与预测得到的y值之差 的平方,然后在对这些差的平方求平均数即可

  

  于是,我们可以这样定义MSE损失函数:

function loss(predictions, labels) {
// 将labels(实际的值)进行抽象
// 然后获取平均数.
const meanSquareError = predictions.sub(labels).square().mean();
return meanSquareError;
}

   即这个损失函数返回的就是一个均方差,如果这个损失函数的值越小,显然数据和系数就拟合的越好。

  

五、定义优化器

  对于我们的优化器而言,我们选用 SGD (Stochastic Gradient Descent)优化器,即随机梯度下降SGD的工作原理就是利用数据中任意的点的梯度以及使用它们的值来决定增加或者减少我们模型中系数的值

  TensorFlow.js提供了一个很方便的函数用来实现SGD,所以你不需要担心自己不会这些特别复杂的数学运算。 即 tf.train.sdg 将一个学习率(learning rate)作为输入,然后返回一个SGDOptimizer对象,它与优化损失函数的值是有关的。

  在提高它的预测能力时,学习率(learning rate)会控制模型调整幅度将会有多大。低的学习率会使得学习过程运行的更慢一些(更多的训练迭代获得更符合数据的系数),而高的学习率将会加速学习过程但是将会导致最终的模型可能在正确值周围摇摆。简单的说,你既想要学的快,又想要学的好,这是不可能的。

  下面的代码就创建了一个学习率为0.5的SGD优化器。

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

六、定义训练迭代器

  既然我们已经定义了损失函数和优化器,那么现在我们就可以创建一个训练迭代器了,它会不断地运行SGD优化器来使不断修正、完善模型的系数来减小损失(MSE)。下面就是我们创建的训练迭代器:

function train(xs, ys, numIterations = ) {

  const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate); for (let iter = ; iter < numIterations; iter++) {
optimizer.minimize(() => {
const predsYs = predict(xs);
return loss(predsYs, ys);
});
}
}

  现在,让我们一步一步地仔细看看上面的代码。首先,我们定义了训练函数,并且以数据中x和y的值以及制定的迭代次数作为输入:

function train(xs, ys, numIterations) {
...
}

  接下来,我们定义了之前讨论过的学习率(learning rate)以及SGD优化器:

const learningRate = 0.5;
const optimizer = tf.train.sgd(learningRate);

  

  最后,我们定义了一个for循环,这个循环会运行numIterations次训练。在每一次迭代中,我们都调用了optimizer优化器的minimize函数,这就是见证奇迹的地方:

for (let iter = ; iter < numIterations; iter++) {
optimizer.minimize(() => {
const predsYs = predict(xs);
return loss(predsYs, ys);
});
}

  minimize 接受了一个函数作为参数,这个函数做了下面的两件事情:

  1. 首先它对所有的x值通过我们在之前定义的pridict函数预测了y值。
  2. 然后它通过我们之前定义的损失函数返回了这些预测的均方误差。

    

  minimize函数之后会自动调整这些变量(即系数a、b、c、d)来使得损失函数更小。

  在运行训练迭代器之后,a、b、c以及d就会是通过模型75次SGD迭代之后学习到的结果了。

  

七、观察结果吧!

  一旦程序运行结束,我们就可以得到最终的a、b、c和d的结果了,然后使用它们来绘制曲线,如下所示:

  这个结果已经比开始随机分配系数的结果拟合的好得多了!

TensorFlow.js之根据数据拟合曲线的更多相关文章

  1. TensorFlow.js入门(一)一维向量的学习

    TensorFlow的介绍   TensorFlow是谷歌基于DistBelief进行研发的第二代人工智能学习系统,其命名来源于本身的运行原理.Tensor(张量)意味着N维数组,Flow(流)意味着 ...

  2. 转《在浏览器中使用tensorflow.js进行人脸识别的JavaScript API》

    作者 | Vincent Mühle 编译 | 姗姗 出品 | 人工智能头条(公众号ID:AI_Thinker) [导读]随着深度学习方法的应用,浏览器调用人脸识别技术已经得到了更广泛的应用与提升.在 ...

  3. TensorFlow.js之安装与核心概念

    TensorFlow.js是通过WebGL加速.基于浏览器的机器学习js框架.通过tensorflow.js,我们可以在浏览器中开发机器学习.运行现有的模型或者重新训练现有的模型. 一.安装     ...

  4. 大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app

    大前端技术系列:TWA技术+TensorFlow.js => 集成原生和AI功能的app ( 本文内容为melodyWxy原作,git地址:https://github.com/melodyWx ...

  5. TensorFlow.js入门:一维向量的学习

    转载自:https://blog.csdn.net/weixin_34061042/article/details/89700664 一维向量及其运算 tensor 是 TensorFlow.js 的 ...

  6. 【一统江湖的大前端(9)】TensorFlow.js 开箱即用的深度学习工具

    示例代码托管在:http://www.github.com/dashnowords/blogs 博客园地址:<大史住在大前端>原创博文目录 目录 一. 上手TensorFlow.js 二. ...

  7. js声明json数据,打印json数据,遍历json数据

    1.js声明json数据: 2.打印json数据: 3.遍历json数据 //声明JSON var json = {}; json.a = 1; //第一种赋值方式(仿对象型) json['b'] = ...

  8. 通过js获取前台数据向一般处理程序传递Json数据,并解析Json数据,将前台传来的Json数据写入数据库表中

    摘自:http://blog.csdn.net/mazhaojuan/article/details/8592015 通过js获取前台数据向一般处理程序传递Json数据,并解析Json数据,将前台传来 ...

  9. 抓取Js动态生成数据且以滚动页面方式分页的网页

    代码也可以从我的开源项目HtmlExtractor中获取. 当我们在进行数据抓取的时候,如果目标网站是以Js的方式动态生成数据且以滚动页面的方式进行分页,那么我们该如何抓取呢? 如类似今日头条这样的网 ...

随机推荐

  1. 第12章:MongoDB-CRUD操作--文档--查询--游标详解

    ①是什么游标 游标不是查询结果,可以理解为数据在遍历过程中的内部指针,其返回的是一个资源,或者说数据读取接口. 客户端通过对游标进行一些设置就能对查询结果进行有效地控制,如可以限制查询得到的结果数量. ...

  2. (转)JDK安装配置教程

    转自:http://jingyan.baidu.com/article/bea41d435bc695b4c41be648.html JDK作为JAVA开发的环境,不管是做JAVA开发的学生,还是做安卓 ...

  3. Android 百度云推送

    http://developer.baidu.com/wiki/index.php?title=docs/cplat/push/sdk/clientsdk. public class DemoAppl ...

  4. [zjoi2010]cheese

    题目: 贪吃的老鼠(cheese.c/cpp/pas/in/out) 时限:每个测试点10秒 [问题描述] 奶酪店里最近出现了m只老鼠!它们的目标就是把生产出来的所有奶酪都吃掉.奶酪店中一天会生产n块 ...

  5. AngularJS 杂项知识点

    1.要用ngChange要同时使用ngModel,下拉选择获取当前选中值. 2.打包代替动态加载(js文件) requirejs真正的价值在于模块化,不是动态加载,angularjs本身有模块化机制, ...

  6. ASP.NET Web API 框架研究 Controller创建 HttpController 类型解析 选择 创建

    上一篇介绍了HttpController的一些细节,接下来说下HttpController 类型解析.选择和创建.生产HttpController实例的生产线如下图: 一.涉及的类及源码分析 涉及的类 ...

  7. [转载]持续交付和DevOps的前世今生

    作者/分享人:乔梁,20年IT老兵,腾讯公司高级管理顾问,敏捷和精益开发专家,持续交付领域先行者.曾就职于百度,国内多个知名互联网公司的企业教练. 历年QCon技术大会的讲师和专题出品人. 这是一个新 ...

  8. MySQL 5.7并发复制和mysqldump相互阻塞引起的复制延迟

    本来MySQL BINLOG和mysqldump命令属于八竿子打不着的两个事物,但在最近故障排查中,发现主库和从库已经存在很严重的复制延迟,但从库上显示slave_behind_master值为0,复 ...

  9. .NET Core 微服务之grpc 初体验(干货)

    Grpc介绍 GitHub: https://github.com/grpc/grpc gRPC是一个高性能.通用的开源RPC框架,其由Google主要面向移动应用开发并基于HTTP/2协议标准而设计 ...

  10. 在Asp.Net MVC中利用快递100接口实现订阅物流轨迹功能

    前言 分享一篇关于在电商系统中同步物流轨迹到本地服务器的文章,当前方案使用了快递100做为数据来源接口,这个接口是收费的,不过提供的功能还是非常强大的,有专门的售后维护团队.也有免费的方案,类似于快递 ...