MindSpore函数拟合
技术背景
在前面一篇博客中我们介绍过基于docker的mindspore编程环境配置,这里我们基于这个环境,使用mindspore来拟合一个线性的函数,演示一下mindspore的基本用法。
环境准备
在Manjaro Linux上先用如下命令启动docker容器服务,启动后可用status查看状态:
[dechin-manjaro mindspore]# systemctl start docker
[dechin-manjaro mindspore]# systemctl status docker
● docker.service - Docker Application Container Engine
Loaded: loaded (/usr/lib/systemd/system/docker.service; disabled; vendor preset: disabled)
Active: active (running) since Wed 2021-04-14 16:32:38 CST; 9s ago
TriggeredBy: ● docker.socket
Docs: https://docs.docker.com
Main PID: 298485 (dockerd)
Tasks: 99 (limit: 47875)
Memory: 186.0M
CGroup: /system.slice/docker.service
├─298485 /usr/bin/dockerd -H fd://
└─298496 containerd --config /var/run/docker/containerd/containerd.toml --log-level info
在按照这篇博客的方法下载下来mindspore的容器镜像之后,可以在本地的镜像仓库中查询到该镜像:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4
容器的启动方式可以参考如下指令:
[dechin-root mindspore]# docker run -it 98a3
root@2a6c33894e53:~# python
Python 3.7.5 (default, Feb 8 2021, 02:21:05)
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>
这里可以看到在这个容器镜像中是预装了python3.7.5版本的mindspore的,可以在python的命令行中用如下的方法进行验证:
>>> from mindspore import context
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(20:139876984823936,MainProcess):2021-04-14-08:37:40.331.840 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
>>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
除了mindspore自身之外,我们还经常可能用到一些第三方的库,如matplotlib等,我们可以自行安装:
root@2a6c33894e53:~# python -m pip install matplotlib
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting matplotlib
Downloading http://mirrors.aliyun.com/pypi/packages/ce/63/74c0b6184b6b169b121bb72458818ee60a7d7c436d7b1907bd5874188c55/matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3MB)
|████████████████████████████████| 10.3MB 4.4MB/s
Collecting cycler>=0.10 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting kiwisolver>=1.0.1 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/d2/46/231de802ade4225b76b96cffe419cf3ce52bbe92e3b092cf12db7d11c207/kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1MB)
|████████████████████████████████| 1.1MB 13.9MB/s
Collecting python-dateutil>=2.7 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/d4/70/d60450c3dd48ef87586924207ae8907090de0b306af2bce5d134d78615cb/python_dateutil-2.8.1-py2.py3-none-any.whl (227kB)
|████████████████████████████████| 235kB 4.6MB/s
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (2.4.7)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (8.1.0)
Requirement already satisfied: numpy>=1.16 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (1.17.5)
Requirement already satisfied: six in /usr/local/python-3.7.5/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)
Installing collected packages: cycler, kiwisolver, python-dateutil, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1 python-dateutil-2.8.1
WARNING: You are using pip version 19.2.3, however version 21.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
root@2a6c33894e53:~# python -m pip install --upgrade pip
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting pip
Downloading http://mirrors.aliyun.com/pypi/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
|████████████████████████████████| 1.5MB 1.3MB/s
Installing collected packages: pip
Found existing installation: pip 19.2.3
Uninstalling pip-19.2.3:
Successfully uninstalled pip-19.2.3
Successfully installed pip-21.0.1
同样的方法我们再安装一下ipython:
root@b8955ba28950:/home# python -m pip install IPython
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting IPython
Downloading http://mirrors.aliyun.com/pypi/packages/c9/b1/82cbe2b856386f44f37fdae54d9b425813bd86fe33385c9d658d64826098/ipython-7.22.0-py3-none-any.whl (785 kB)
|████████████████████████████████| 785 kB 1.8 MB/s
Collecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0
Downloading http://mirrors.aliyun.com/pypi/packages/eb/e6/4b4ca4fa94462d4560ba2f4e62e62108ab07be2e16a92e594e43b12d3300/prompt_toolkit-3.0.18-py3-none-any.whl (367 kB)
|████████████████████████████████| 367 kB 818 kB/s
Collecting pickleshare
Downloading http://mirrors.aliyun.com/pypi/packages/9a/41/220f49aaea88bc6fa6cba8d05ecf24676326156c23b991e80b3f2fc24c77/pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)
Collecting pygments
Downloading http://mirrors.aliyun.com/pypi/packages/3a/80/a52c0a7c5939737c6dca75a831e89658ecb6f590fb7752ac777d221937b9/Pygments-2.8.1-py3-none-any.whl (983 kB)
|████████████████████████████████| 983 kB 2.7 MB/s
Requirement already satisfied: decorator in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (4.4.2)
Collecting traitlets>=4.2
Downloading http://mirrors.aliyun.com/pypi/packages/f6/7d/3ecb0ebd0ce8dcdfa7bd47ab85c1d4a521e6770ef283d0824f5804994dfe/traitlets-5.0.5-py3-none-any.whl (100 kB)
|████████████████████████████████| 100 kB 4.0 MB/s
Collecting pexpect>4.3
Downloading http://mirrors.aliyun.com/pypi/packages/39/7b/88dbb785881c28a102619d46423cb853b46dbccc70d3ac362d99773a78ce/pexpect-4.8.0-py2.py3-none-any.whl (59 kB)
|████████████████████████████████| 59 kB 5.9 MB/s
Collecting jedi>=0.16
Downloading http://mirrors.aliyun.com/pypi/packages/f9/36/7aa67ae2663025b49e8426ead0bad983fee1b73f472536e9790655da0277/jedi-0.18.0-py2.py3-none-any.whl (1.4 MB)
|████████████████████████████████| 1.4 MB 3.7 MB/s
Collecting backcall
Downloading http://mirrors.aliyun.com/pypi/packages/4c/1c/ff6546b6c12603d8dd1070aa3c3d273ad4c07f5771689a7b69a550e8c951/backcall-0.2.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: setuptools>=18.5 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (41.2.0)
Collecting parso<0.9.0,>=0.8.0
Downloading http://mirrors.aliyun.com/pypi/packages/a9/c4/d5476373088c120ffed82f34c74b266ccae31a68d665b837354d4d8dc8be/parso-0.8.2-py2.py3-none-any.whl (94 kB)
|████████████████████████████████| 94 kB 6.0 MB/s
Collecting ptyprocess>=0.5
Downloading http://mirrors.aliyun.com/pypi/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)
Collecting wcwidth
Downloading http://mirrors.aliyun.com/pypi/packages/59/7c/e39aca596badaf1b78e8f547c807b04dae603a433d3e7a7e04d67f2ef3e5/wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
Collecting ipython-genutils
Downloading http://mirrors.aliyun.com/pypi/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl (26 kB)
Installing collected packages: wcwidth, ptyprocess, parso, ipython-genutils, traitlets, pygments, prompt-toolkit, pickleshare, pexpect, jedi, backcall, IPython
WARNING: The script pygmentize is installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
WARNING: The scripts iptest, iptest3, ipython and ipython3 are installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
Successfully installed IPython-7.22.0 backcall-0.2.0 ipython-genutils-0.2.0 jedi-0.18.0 parso-0.8.2 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.18 ptyprocess-0.7.0 pygments-2.8.1 traitlets-5.0.5 wcwidth-0.2.5
安装过程中都没有出现其他的依赖问题,接下来我们可以在docker容器中保存这些已经安装的库,避免下一次使用的时候还需要再安装一次。在用exit
退出当前容器镜像之后,可以用docker ps
指令查看近期的操作记录:
[dechin-root mindspore]# docker ps -n 3
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
2a6c33894e53 98a3 "/bin/bash" 13 minutes ago Exited (0) 7 seconds ago upbeat_tharp
625ee5f4ee95 ea1c "bash" 9 days ago Exited (0) 9 days ago zealous_mccarthy
ded2cb29290a kivy/buildozer "buildozer bash -c '…" 9 days ago Exited (1) 9 days ago exciting_lumiere
这里第一个操作记录就是我们需要保存的mindspore的镜像,那么我们可以用docker commit
的指令将操作保存到一个新的镜像里面:
[dechin-root mindspore]# docker commit 2a6c mindspore
sha256:3a6951d9b9009f93027748ecec78078efff1fb36599a5786bcbc667e72119392
上面的执行反馈表示运行成功了,再次查看本地镜像内容:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID CREATED SIZE
mindspore latest 3a6951d9b900 31 seconds ago 1.22GB
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4 2 months ago 1.18GB
可以看到我们的基础镜像环境已经制作完成了,在原镜像的基础上多了40M左右的空间。本章节的最后我们也说明一下,mindspore提供的这个镜像的基础系统环境为Ubuntu18.04
:
root@b8955ba28950:/home# cat /etc/issue
Ubuntu 18.04.5 LTS \n \l
MindSpore线性函数拟合
假设有如下图中红点所示的一系列散点,或者可以认为是需要我们来执行训练的数据。而图中的绿线表示真实的函数,也就是说我们是基于这样一个真实的线性函数,来生成了一系列加随机噪声的散点。最终我们的目的当然是希望能够通过这些散点将线性的函数再拟合出来,这样就可以用来预测下一个位置的函数值,相关技术用在量化金融领域,就可以预测下一步股市的价格,当然那样的函数就会更加的复杂。
对应于图中的函数,我们给定的是:
\]
生成散点数据集
加噪声的方法在get_data
函数中体现,其中生成数据集的方法为:先在\([-10,10]\)的范围内生成一系列的随机\(x\)自变量值,然后生成一系列的正态分布随机数作为噪声,把这些噪声加到自变量值所对应的\(f(x)\)函数值上,就得到了原始数据。当然,这里没有用return
进行返回,而是用yield
的形式逐一返回。
第二步我们需要将这些数据集转化为mindspore所能够识别的数据格式:mindspore.dataset.GeneratorDataset
,除了可以给\(x\)和\(y\)分别配置一个变量名之外,还可以指定这些数据集的分组(batch)和重复次数,其中分组数量的配置是有可能影响到最终的训练速率的。
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
def get_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50)) # 生成50个带噪声的随机点
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3 # 期望的函数值
x_eval_label,y_eval_label = zip(*eval_data)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())
print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
上述代码的执行效果如下:
root@b8955ba28950:/home# python test_linear.py
The dataset size of ds_train: 100
dict_keys(['data', 'label'])
The x label value shape: (16, 1)
The y label value shape: (16, 1)
到这里为止,我们就已经构造了一个1600个训练的数据,并且分为了100个batch进行训练,每个batch的大小为16。
构建拟合模型与初始参数
用mindspore.nn.Dense
的方法我们可以构造一个线性拟合的模型:
\]
关于该激活函数的官方文档说明如下:
而这里面的weight
和bias
的初始化参数是由一个张量形式的数据结构来定义的,我们给了一个入参nn.Dense(1, 1, Normal(0.02), Normal(0.02))
表示两组参数,都是一维的张量(或称为1阶的张量),而这两个初始化张量的元素是由两个\(N(0,\sigma)\)正态分布所生成的随机化初始数据,比如在该案例中我们可以试着将这些初始化的参数打印出来:
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn
def get_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
x = self.fc(x)
return x
net = LinearNet()
model_params = net.trainable_params()
for param in model_params:
print(param, param.asnumpy())
执行结果如下,是两个一维的数组:数组
root@b8955ba28950:/home# python test_linear.py
Parameter (name=fc.weight) [[-0.00252427]]
Parameter (name=fc.bias) [0.00694926]
在上述代码中虽然打印了两个参数值,但是并不是很直观,我们可以将这组参数值所对应的函数图画在刚才的散点图中看看效果:
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor
def get_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
x = self.fc(x)
return x
net = LinearNet()
model_params = net.trainable_params()
x_model_label = np.array([-10, 10, 0.1])
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
Tensor(model_params[1]).asnumpy()[0])
plt.axis([-10, 10, -20, 25])
plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
plt.plot(x_model_label, y_model_label, color="blue")
plt.plot(x_target_label, y_target_label, color="green")
plt.savefig('initial.png')
执行后会在当前目录生成一个名为initial.png
的图片:
可以看到此时的参数所对应的函数距离我们所预期的还是比较远的。
训练与可视化
在前面的技术铺垫之后,这一步终于可以开始训练了。在机器学习中,我们需要先定义好一个用于衡量结果好坏的函数,一般可以称之为损失函数(Loss Function)。损失函数值越小,代表结果就越好,在我们面对的这个函数拟合问题中所代表的就是,拟合的效果越好。这里我们采取的是均方误差函数(Mean Square Error,简称MSE):
均方误差是最常使用的损失函数,因为不管是往哪个方向的偏移,都会导致损失函数值的急剧增大。在定义好损失函数之后,我们需要定义一个前向传播网络,用于执行损失函数的计算,这里我们直接使用了mindspore定义好的接口:mindspore.nn.loss.MSELoss
:
在计算好对应参数的损失函数值之后,我们需要更新迭代参数,计算下一组参数的损失函数值,以确定向哪个方向“前进”才能找到最终的最低损失函数值。这个参数迭代的功能由反向传播网络实现,常用的参数更新算法有梯度下降等,关于梯度下降算法,在前面写过的这篇博客中有比较详细的介绍。其基本计算公式如下:
在mindspore中优化函数的接口为mindspore.nn.Momentum
:
这些模型都定义好之后,可以用mindspore.Model
进行封装和训练。
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor, Model
import time
from IPython import display
from mindspore.train.callback import Callback, LossMonitor
def get_data(num, w=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
y = x * w + b + noise
yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label,y_eval_label = zip(*eval_data)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
data_number = 1600
batch_number = 16
repeat_number = 1
ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
x = self.fc(x)
return x
net = LinearNet()
model_params = net.trainable_params()
x_model_label = np.array([-10, 10, 0.1])
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
Tensor(model_params[1]).asnumpy()[0])
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)
fig = plt.figure()
ims = []
def plot_model_and_datasets(net, eval_data):
weight = net.trainable_params()[0]
bias = net.trainable_params()[1]
x = np.arange(-10, 10, 0.1)
y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
x1, y1 = zip(*eval_data)
x_target = x
y_target = x_target * 2 + 3
plt.axis([-11, 11, -20, 25])
plt.scatter(x1, y1, color="red", s=5)
im = plt.plot(x, y, color="blue")
ims.append(im)
im1 = plt.plot(x_target, y_target, color="green")
ims.append(im1)
time.sleep(0.2)
class ImageShowCallback(Callback):
def __init__(self, net, eval_data):
self.net = net
self.eval_data = eval_data
def step_end(self, run_context):
plot_model_and_datasets(self.net, self.eval_data)
display.clear_output(wait=True)
epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)
plot_model_and_datasets(net, eval_data)
for net_param in net.trainable_params():
print(net_param, net_param.asnumpy())
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
执行结果如下:
root@b8955ba28950:/home# python test_linear.py
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(444:140374496206976,MainProcess):2021-04-14-09:28:58.738.627 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
Parameter (name=fc.weight) [[1.8964282]]
Parameter (name=fc.bias) [3.0266616]
执行完成后会在当前目录下生成一个名为train.gif
的动态图,演示整个训练优化的过程:
其中红色散点是训练数据,绿色直线是原始函数,蓝色直线是训练后的函数,可以看到两个函数是越来越接近的。最后拟合出来的函数为:
\]
与我们所预期的:
\]
还是略有差距,但是这其中的可能原因有很多,有可能是生成的随机散点的问题,也有可能是在这个范围内的线段拟合就是有这么大的误差,这里我们不做展开。到这里为止,我们就成功的使用mindspore完成了一个函数拟合的任务。
python绘制动态函数图
在上一个章节中我们演示了使用mindspore完成了一个线性函数的拟合,最后的代码中其实已经使用到了动态图的绘制方法,这里单独抽取出来作为一个章节来介绍。我们所使用到的工具是matplotlib.animation
,使用的第一步是在训练的外部先生成一个动态图像的对象:
fig = plt.figure()
ims = []
其中ims
是用于存储每一帧的数据绘制内容。第二步是将训练过程中需要变化的绘图对象添加到ims
中:
im = plt.plot(x, y, color="blue")
ims.append(im)
im1 = plt.plot(x_target, y_target, color="green")
ims.append(im1)
最后根据绘制的图的对象fig
和变化的图像集合ims
来生成一个动态图并且保存到本地文件中:
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
关于animation.ArtistAnimation
的接口参数如下所示:
这里每一帧之间的间隔时间我们定义为500ms
,重复播放1000次,基本可以认为是一直在重复播放的。最终的效果图在上一个章节中已经做了展示,这里就不再重复说明。需要注意的是,生成动态图的过程会比较漫长,而且只有通过animation才能够生成和保存gif
动态图,直接通过plt.savefig
是无法保存为动态图的。
总结概要
很多机器学习的算法的基础就是函数的拟合,这里我们考虑的是其中一种最简单也最常见的场景:线性函数的拟合,并且我们要通过mindspore来实现这个数据的训练。通过构造均方误差函数,配合前向传播网络与反向传播网络的使用,最终大体成功的拟合了给定的一个线性函数。文末我们还顺带介绍了使用matplotlib的animation来生成动态图的功能,可视化的展现了整个训练的过程。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/linear.html
作者ID:DechinPhy
更多原著文章请参考:https://www.cnblogs.com/dechinphy/
参考链接
- https://www.mindspore.cn/tutorial/training/zh-CN/master/quick_start/linear_regression.html
- https://blog.csdn.net/clksjx/article/details/105720120
- https://www.cnblogs.com/dechinphy/p/gradient.html
MindSpore函数拟合的更多相关文章
- matlab函数拟合
1 函数拟合 函数拟合在工程(如采样校正)和数据分析(如隶属函数确定)中都是非常有用的工具.我这里将函数拟合分为三类:分别是多项式拟合,已知函数类型的拟合和未知函数类型的拟合.matlab中关于函数的 ...
- MATLAB用“fitgmdist”函数拟合高斯混合模型(一维数据)
MATLAB用“fitgmdist”函数拟合高斯混合模型(一维数据) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 在MATLAB中“fitgmdis ...
- matlab多项式拟合以及指定函数拟合
clc;clear all;close all;%% 多项式拟合指令:% X = [1 2 3 4 5 6 7 8 9 ];% Y = [9 7 6 3 -1 2 5 7 20]; % P= poly ...
- TensorFlow-正弦函数拟合
MNIST的代码还是有点复杂,一大半内容全在搞数据,看了半天全是一滩烂泥.最关键的是最后输出就是一个accuracy,我根本就不关心你准确率是0.98还是0.99好吗?我就想看到我手写一个5,你程序给 ...
- MATLAB神经网络(2) BP神经网络的非线性系统建模——非线性函数拟合
2.1 案例背景 在工程应用中经常会遇到一些复杂的非线性系统,这些系统状态方程复杂,难以用数学方法准确建模.在这种情况下,可以建立BP神经网络表达这些非线性系统.该方法把未知系统看成是一个黑箱,首先用 ...
- MATLAB神经网络(3) 遗传算法优化BP神经网络——非线性函数拟合
3.1 案例背景 遗传算法(Genetic Algorithms)是一种模拟自然界遗传机制和生物进化论而形成的一种并行随机搜索最优化方法. 其基本要素包括:染色体编码方法.适应度函数.遗传操作和运行参 ...
- 使用MindSpore的线性神经网络拟合非线性函数
技术背景 在前面的几篇博客中,我们分别介绍了MindSpore的CPU版本在Docker下的安装与配置方案.MindSpore的线性函数拟合以及MindSpore后来新推出的GPU版本的Docker编 ...
- cftool拟合&函数逼近
cftool拟合&函数逼近 cftool 真是神奇,之前我们搞的一些线性拟合解方程,多项式拟合,函数拟合求参数啊,等等. 已经超级多了,为啥还得搞一个cftool拟合啊?而且毫无数学理论. 如 ...
- Python数据处理——绘制函数图形以及数据拟合
1.多项式拟合 对散点进行多项式拟合并打印出拟合函数以及拟合后的图形import matplotlib.pyplot as pltimport numpy as npx=np.arange(1,17, ...
随机推荐
- BGV再掀DeFi投资热潮,NGK全球启动大会圆满落幕
此次全球启动大会的主题为"BGV再掀DeFi投资热潮,后市发展如何". 首先发言的是NGK灵石团队首席技术官STEPHEN先生,他先是对出席此次大会的嘉宾.到场的媒体记者以及NGK ...
- Captain Technology INC浅谈新能源汽车的未来
近日全世界上最大的资管公司贝莱德向位于的英国电动汽车初创公司Arrival投资1.18亿美元,且该公司已有投资者亚马逊和美国第二大汽车制造商福特汽车参投.中国最知名的电动车公司蔚来股价单日大涨22%, ...
- 离场定高转弯DF与CF的对比
也许是刚学会CAD的缘故,配合风螺旋插件,画图的感觉真是蛮爽的,忍不住画了一张又一张. 接着昨天的离场保护区,我们来聊一下PBN指定高度转弯保护区的画法.指定高度转弯的计算本身没有太多复杂的地方,真正 ...
- Spring 中的 MetaData 接口
什么是元数据(MetaData) 先直接贴一个英文解释: Metadata is simply data about data. It means it is a description and co ...
- css故障文字动画
免费分享95套java实战项目,不仅有源码还有对应的开发视频,关注公众号『勾玉技术』回复"95"即可获取 首先给内容上hover和before, .glitch:hover:bef ...
- 下载com.springsource.org.aspectj.weaver-1.6.8.RELEASE.jar
看别人都说在repo.maven.com下载,没想到竟然要登录 索性我直接在国内阿里云的镜像仓库下载好了,速度又快又方便 搜索aspectj 下载地址:https://maven.aliyun.com ...
- Ajax的基本用法
1.介绍 2.基本用法 2.1原生写法 $.ajax({ url: url, //是否是异步请求,默认是 // async: false, //请求方式,默认是get //type:'get', // ...
- (数据科学学习手札109)Python+Dash快速web应用开发——静态部件篇(中)
本文示例代码已上传至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes 1 简介 这是我的系列教程Python+Dash快速web ...
- MongoDB 在评论中台的实践
本文主要讲述 vivo 评论中台在数据库设计上的技术探索和实践. 一.业务背景 随着公司业务发展和用户规模的增多,很多项目都在打造自己的评论功能,而评论的业务形态基本类似.当时各项目都是各自设计实现, ...
- go 报 need type assertion
responese_total := m["responses"].([]interface{})[0].(map[string]interface{})["hits&q ...