技术背景

在前面一篇博客中我们介绍过基于docker的mindspore编程环境配置,这里我们基于这个环境,使用mindspore来拟合一个线性的函数,演示一下mindspore的基本用法。

环境准备

在Manjaro Linux上先用如下命令启动docker容器服务,启动后可用status查看状态:

  1. [dechin-manjaro mindspore]# systemctl start docker
  2. [dechin-manjaro mindspore]# systemctl status docker
  3. docker.service - Docker Application Container Engine
  4. Loaded: loaded (/usr/lib/systemd/system/docker.service; disabled; vendor preset: disabled)
  5. Active: active (running) since Wed 2021-04-14 16:32:38 CST; 9s ago
  6. TriggeredBy: docker.socket
  7. Docs: https://docs.docker.com
  8. Main PID: 298485 (dockerd)
  9. Tasks: 99 (limit: 47875)
  10. Memory: 186.0M
  11. CGroup: /system.slice/docker.service
  12. ├─298485 /usr/bin/dockerd -H fd://
  13. └─298496 containerd --config /var/run/docker/containerd/containerd.toml --log-level info

在按照这篇博客的方法下载下来mindspore的容器镜像之后,可以在本地的镜像仓库中查询到该镜像:

  1. [dechin-root mindspore]# docker images
  2. REPOSITORY TAG IMAGE ID
  3. swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4

容器的启动方式可以参考如下指令:

  1. [dechin-root mindspore]# docker run -it 98a3
  2. root@2a6c33894e53:~# python
  3. Python 3.7.5 (default, Feb 8 2021, 02:21:05)
  4. [GCC 7.5.0] on linux
  5. Type "help", "copyright", "credits" or "license" for more information.
  6. >>>

这里可以看到在这个容器镜像中是预装了python3.7.5版本的mindspore的,可以在python的命令行中用如下的方法进行验证:

  1. >>> from mindspore import context
  2. WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
  3. [WARNING] ME(20:139876984823936,MainProcess):2021-04-14-08:37:40.331.840 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
  4. >>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

除了mindspore自身之外,我们还经常可能用到一些第三方的库,如matplotlib等,我们可以自行安装:

  1. root@2a6c33894e53:~# python -m pip install matplotlib
  2. Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
  3. Collecting matplotlib
  4. Downloading http://mirrors.aliyun.com/pypi/packages/ce/63/74c0b6184b6b169b121bb72458818ee60a7d7c436d7b1907bd5874188c55/matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3MB)
  5. |████████████████████████████████| 10.3MB 4.4MB/s
  6. Collecting cycler>=0.10 (from matplotlib)
  7. Downloading http://mirrors.aliyun.com/pypi/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
  8. Collecting kiwisolver>=1.0.1 (from matplotlib)
  9. Downloading http://mirrors.aliyun.com/pypi/packages/d2/46/231de802ade4225b76b96cffe419cf3ce52bbe92e3b092cf12db7d11c207/kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1MB)
  10. |████████████████████████████████| 1.1MB 13.9MB/s
  11. Collecting python-dateutil>=2.7 (from matplotlib)
  12. Downloading http://mirrors.aliyun.com/pypi/packages/d4/70/d60450c3dd48ef87586924207ae8907090de0b306af2bce5d134d78615cb/python_dateutil-2.8.1-py2.py3-none-any.whl (227kB)
  13. |████████████████████████████████| 235kB 4.6MB/s
  14. Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (2.4.7)
  15. Requirement already satisfied: pillow>=6.2.0 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (8.1.0)
  16. Requirement already satisfied: numpy>=1.16 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (1.17.5)
  17. Requirement already satisfied: six in /usr/local/python-3.7.5/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)
  18. Installing collected packages: cycler, kiwisolver, python-dateutil, matplotlib
  19. Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1 python-dateutil-2.8.1
  20. WARNING: You are using pip version 19.2.3, however version 21.0.1 is available.
  21. You should consider upgrading via the 'pip install --upgrade pip' command.
  22. root@2a6c33894e53:~# python -m pip install --upgrade pip
  23. Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
  24. Collecting pip
  25. Downloading http://mirrors.aliyun.com/pypi/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
  26. |████████████████████████████████| 1.5MB 1.3MB/s
  27. Installing collected packages: pip
  28. Found existing installation: pip 19.2.3
  29. Uninstalling pip-19.2.3:
  30. Successfully uninstalled pip-19.2.3
  31. Successfully installed pip-21.0.1

同样的方法我们再安装一下ipython:

  1. root@b8955ba28950:/home# python -m pip install IPython
  2. Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
  3. Collecting IPython
  4. Downloading http://mirrors.aliyun.com/pypi/packages/c9/b1/82cbe2b856386f44f37fdae54d9b425813bd86fe33385c9d658d64826098/ipython-7.22.0-py3-none-any.whl (785 kB)
  5. |████████████████████████████████| 785 kB 1.8 MB/s
  6. Collecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0
  7. Downloading http://mirrors.aliyun.com/pypi/packages/eb/e6/4b4ca4fa94462d4560ba2f4e62e62108ab07be2e16a92e594e43b12d3300/prompt_toolkit-3.0.18-py3-none-any.whl (367 kB)
  8. |████████████████████████████████| 367 kB 818 kB/s
  9. Collecting pickleshare
  10. Downloading http://mirrors.aliyun.com/pypi/packages/9a/41/220f49aaea88bc6fa6cba8d05ecf24676326156c23b991e80b3f2fc24c77/pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)
  11. Collecting pygments
  12. Downloading http://mirrors.aliyun.com/pypi/packages/3a/80/a52c0a7c5939737c6dca75a831e89658ecb6f590fb7752ac777d221937b9/Pygments-2.8.1-py3-none-any.whl (983 kB)
  13. |████████████████████████████████| 983 kB 2.7 MB/s
  14. Requirement already satisfied: decorator in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (4.4.2)
  15. Collecting traitlets>=4.2
  16. Downloading http://mirrors.aliyun.com/pypi/packages/f6/7d/3ecb0ebd0ce8dcdfa7bd47ab85c1d4a521e6770ef283d0824f5804994dfe/traitlets-5.0.5-py3-none-any.whl (100 kB)
  17. |████████████████████████████████| 100 kB 4.0 MB/s
  18. Collecting pexpect>4.3
  19. Downloading http://mirrors.aliyun.com/pypi/packages/39/7b/88dbb785881c28a102619d46423cb853b46dbccc70d3ac362d99773a78ce/pexpect-4.8.0-py2.py3-none-any.whl (59 kB)
  20. |████████████████████████████████| 59 kB 5.9 MB/s
  21. Collecting jedi>=0.16
  22. Downloading http://mirrors.aliyun.com/pypi/packages/f9/36/7aa67ae2663025b49e8426ead0bad983fee1b73f472536e9790655da0277/jedi-0.18.0-py2.py3-none-any.whl (1.4 MB)
  23. |████████████████████████████████| 1.4 MB 3.7 MB/s
  24. Collecting backcall
  25. Downloading http://mirrors.aliyun.com/pypi/packages/4c/1c/ff6546b6c12603d8dd1070aa3c3d273ad4c07f5771689a7b69a550e8c951/backcall-0.2.0-py2.py3-none-any.whl (11 kB)
  26. Requirement already satisfied: setuptools>=18.5 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (41.2.0)
  27. Collecting parso<0.9.0,>=0.8.0
  28. Downloading http://mirrors.aliyun.com/pypi/packages/a9/c4/d5476373088c120ffed82f34c74b266ccae31a68d665b837354d4d8dc8be/parso-0.8.2-py2.py3-none-any.whl (94 kB)
  29. |████████████████████████████████| 94 kB 6.0 MB/s
  30. Collecting ptyprocess>=0.5
  31. Downloading http://mirrors.aliyun.com/pypi/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)
  32. Collecting wcwidth
  33. Downloading http://mirrors.aliyun.com/pypi/packages/59/7c/e39aca596badaf1b78e8f547c807b04dae603a433d3e7a7e04d67f2ef3e5/wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
  34. Collecting ipython-genutils
  35. Downloading http://mirrors.aliyun.com/pypi/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl (26 kB)
  36. Installing collected packages: wcwidth, ptyprocess, parso, ipython-genutils, traitlets, pygments, prompt-toolkit, pickleshare, pexpect, jedi, backcall, IPython
  37. WARNING: The script pygmentize is installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
  38. Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
  39. WARNING: The scripts iptest, iptest3, ipython and ipython3 are installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
  40. Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
  41. Successfully installed IPython-7.22.0 backcall-0.2.0 ipython-genutils-0.2.0 jedi-0.18.0 parso-0.8.2 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.18 ptyprocess-0.7.0 pygments-2.8.1 traitlets-5.0.5 wcwidth-0.2.5

安装过程中都没有出现其他的依赖问题,接下来我们可以在docker容器中保存这些已经安装的库,避免下一次使用的时候还需要再安装一次。在用exit退出当前容器镜像之后,可以用docker ps指令查看近期的操作记录:

  1. [dechin-root mindspore]# docker ps -n 3
  2. CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
  3. 2a6c33894e53 98a3 "/bin/bash" 13 minutes ago Exited (0) 7 seconds ago upbeat_tharp
  4. 625ee5f4ee95 ea1c "bash" 9 days ago Exited (0) 9 days ago zealous_mccarthy
  5. ded2cb29290a kivy/buildozer "buildozer bash -c '…" 9 days ago Exited (1) 9 days ago exciting_lumiere

这里第一个操作记录就是我们需要保存的mindspore的镜像,那么我们可以用docker commit的指令将操作保存到一个新的镜像里面:

  1. [dechin-root mindspore]# docker commit 2a6c mindspore
  2. sha256:3a6951d9b9009f93027748ecec78078efff1fb36599a5786bcbc667e72119392

上面的执行反馈表示运行成功了,再次查看本地镜像内容:

  1. [dechin-root mindspore]# docker images
  2. REPOSITORY TAG IMAGE ID CREATED SIZE
  3. mindspore latest 3a6951d9b900 31 seconds ago 1.22GB
  4. swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4 2 months ago 1.18GB

可以看到我们的基础镜像环境已经制作完成了,在原镜像的基础上多了40M左右的空间。本章节的最后我们也说明一下,mindspore提供的这个镜像的基础系统环境为Ubuntu18.04

  1. root@b8955ba28950:/home# cat /etc/issue
  2. Ubuntu 18.04.5 LTS \n \l

MindSpore线性函数拟合

假设有如下图中红点所示的一系列散点,或者可以认为是需要我们来执行训练的数据。而图中的绿线表示真实的函数,也就是说我们是基于这样一个真实的线性函数,来生成了一系列加随机噪声的散点。最终我们的目的当然是希望能够通过这些散点将线性的函数再拟合出来,这样就可以用来预测下一个位置的函数值,相关技术用在量化金融领域,就可以预测下一步股市的价格,当然那样的函数就会更加的复杂。



对应于图中的函数,我们给定的是:

\[f(x)=2x+3
\]

生成散点数据集

加噪声的方法在get_data函数中体现,其中生成数据集的方法为:先在\([-10,10]\)的范围内生成一系列的随机\(x\)自变量值,然后生成一系列的正态分布随机数作为噪声,把这些噪声加到自变量值所对应的\(f(x)\)函数值上,就得到了原始数据。当然,这里没有用return进行返回,而是用yield的形式逐一返回。

第二步我们需要将这些数据集转化为mindspore所能够识别的数据格式:mindspore.dataset.GeneratorDataset,除了可以给\(x\)和\(y\)分别配置一个变量名之外,还可以指定这些数据集的分组(batch)和重复次数,其中分组数量的配置是有可能影响到最终的训练速率的。

  1. # test_linear.py
  2. from mindspore import context
  3. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from mindspore import dataset as ds
  7. def get_data(num, w=2.0, b=3.0):
  8. for _ in range(num):
  9. x = np.random.uniform(-10.0, 10.0)
  10. noise = np.random.normal(0, 1)
  11. y = x * w + b + noise
  12. yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
  13. eval_data = list(get_data(50)) # 生成50个带噪声的随机点
  14. x_target_label = np.array([-10, 10, 0.1])
  15. y_target_label = x_target_label * 2 + 3 # 期望的函数值
  16. x_eval_label,y_eval_label = zip(*eval_data)
  17. def create_dataset(num_data, batch_size=16, repeat_size=1):
  18. input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
  19. input_data = input_data.batch(batch_size)
  20. input_data = input_data.repeat(repeat_size)
  21. return input_data
  22. data_number = 1600
  23. batch_number = 16
  24. repeat_number = 1
  25. ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
  26. print("The dataset size of ds_train:", ds_train.get_dataset_size())
  27. dict_datasets = next(ds_train.create_dict_iterator())
  28. print(dict_datasets.keys())
  29. print("The x label value shape:", dict_datasets["data"].shape)
  30. print("The y label value shape:", dict_datasets["label"].shape)

上述代码的执行效果如下:

  1. root@b8955ba28950:/home# python test_linear.py
  2. The dataset size of ds_train: 100
  3. dict_keys(['data', 'label'])
  4. The x label value shape: (16, 1)
  5. The y label value shape: (16, 1)

到这里为止,我们就已经构造了一个1600个训练的数据,并且分为了100个batch进行训练,每个batch的大小为16。

构建拟合模型与初始参数

mindspore.nn.Dense的方法我们可以构造一个线性拟合的模型:

\[f(x)=wx+b
\]

关于该激活函数的官方文档说明如下:



而这里面的weightbias的初始化参数是由一个张量形式的数据结构来定义的,我们给了一个入参nn.Dense(1, 1, Normal(0.02), Normal(0.02))表示两组参数,都是一维的张量(或称为1阶的张量),而这两个初始化张量的元素是由两个\(N(0,\sigma)\)正态分布所生成的随机化初始数据,比如在该案例中我们可以试着将这些初始化的参数打印出来:

  1. # test_linear.py
  2. from mindspore import context
  3. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from mindspore import dataset as ds
  7. from mindspore.common.initializer import Normal
  8. from mindspore import nn
  9. def get_data(num, w=2.0, b=3.0):
  10. for _ in range(num):
  11. x = np.random.uniform(-10.0, 10.0)
  12. noise = np.random.normal(0, 1)
  13. y = x * w + b + noise
  14. yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
  15. eval_data = list(get_data(50))
  16. x_target_label = np.array([-10, 10, 0.1])
  17. y_target_label = x_target_label * 2 + 3
  18. x_eval_label,y_eval_label = zip(*eval_data)
  19. def create_dataset(num_data, batch_size=16, repeat_size=1):
  20. input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
  21. input_data = input_data.batch(batch_size)
  22. input_data = input_data.repeat(repeat_size)
  23. return input_data
  24. data_number = 1600
  25. batch_number = 16
  26. repeat_number = 1
  27. ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
  28. dict_datasets = next(ds_train.create_dict_iterator())
  29. class LinearNet(nn.Cell):
  30. def __init__(self):
  31. super(LinearNet, self).__init__()
  32. self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
  33. def construct(self, x):
  34. x = self.fc(x)
  35. return x
  36. net = LinearNet()
  37. model_params = net.trainable_params()
  38. for param in model_params:
  39. print(param, param.asnumpy())

执行结果如下,是两个一维的数组:数组

  1. root@b8955ba28950:/home# python test_linear.py
  2. Parameter (name=fc.weight) [[-0.00252427]]
  3. Parameter (name=fc.bias) [0.00694926]

在上述代码中虽然打印了两个参数值,但是并不是很直观,我们可以将这组参数值所对应的函数图画在刚才的散点图中看看效果:

  1. # test_linear.py
  2. from mindspore import context
  3. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. from mindspore import dataset as ds
  7. from mindspore.common.initializer import Normal
  8. from mindspore import nn, Tensor
  9. def get_data(num, w=2.0, b=3.0):
  10. for _ in range(num):
  11. x = np.random.uniform(-10.0, 10.0)
  12. noise = np.random.normal(0, 1)
  13. y = x * w + b + noise
  14. yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
  15. eval_data = list(get_data(50))
  16. x_target_label = np.array([-10, 10, 0.1])
  17. y_target_label = x_target_label * 2 + 3
  18. x_eval_label,y_eval_label = zip(*eval_data)
  19. def create_dataset(num_data, batch_size=16, repeat_size=1):
  20. input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
  21. input_data = input_data.batch(batch_size)
  22. input_data = input_data.repeat(repeat_size)
  23. return input_data
  24. data_number = 1600
  25. batch_number = 16
  26. repeat_number = 1
  27. ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
  28. dict_datasets = next(ds_train.create_dict_iterator())
  29. class LinearNet(nn.Cell):
  30. def __init__(self):
  31. super(LinearNet, self).__init__()
  32. self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
  33. def construct(self, x):
  34. x = self.fc(x)
  35. return x
  36. net = LinearNet()
  37. model_params = net.trainable_params()
  38. x_model_label = np.array([-10, 10, 0.1])
  39. y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
  40. Tensor(model_params[1]).asnumpy()[0])
  41. plt.axis([-10, 10, -20, 25])
  42. plt.scatter(x_eval_label, y_eval_label, color="red", s=5)
  43. plt.plot(x_model_label, y_model_label, color="blue")
  44. plt.plot(x_target_label, y_target_label, color="green")
  45. plt.savefig('initial.png')

执行后会在当前目录生成一个名为initial.png的图片:



可以看到此时的参数所对应的函数距离我们所预期的还是比较远的。

训练与可视化

在前面的技术铺垫之后,这一步终于可以开始训练了。在机器学习中,我们需要先定义好一个用于衡量结果好坏的函数,一般可以称之为损失函数(Loss Function)。损失函数值越小,代表结果就越好,在我们面对的这个函数拟合问题中所代表的就是,拟合的效果越好。这里我们采取的是均方误差函数(Mean Square Error,简称MSE):



均方误差是最常使用的损失函数,因为不管是往哪个方向的偏移,都会导致损失函数值的急剧增大。在定义好损失函数之后,我们需要定义一个前向传播网络,用于执行损失函数的计算,这里我们直接使用了mindspore定义好的接口:mindspore.nn.loss.MSELoss:



在计算好对应参数的损失函数值之后,我们需要更新迭代参数,计算下一组参数的损失函数值,以确定向哪个方向“前进”才能找到最终的最低损失函数值。这个参数迭代的功能由反向传播网络实现,常用的参数更新算法有梯度下降等,关于梯度下降算法,在前面写过的这篇博客中有比较详细的介绍。其基本计算公式如下:



在mindspore中优化函数的接口为mindspore.nn.Momentum



这些模型都定义好之后,可以用mindspore.Model进行封装和训练。

  1. # test_linear.py
  2. from mindspore import context
  3. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  4. import numpy as np
  5. import matplotlib.pyplot as plt
  6. import matplotlib.animation as animation
  7. from mindspore import dataset as ds
  8. from mindspore.common.initializer import Normal
  9. from mindspore import nn, Tensor, Model
  10. import time
  11. from IPython import display
  12. from mindspore.train.callback import Callback, LossMonitor
  13. def get_data(num, w=2.0, b=3.0):
  14. for _ in range(num):
  15. x = np.random.uniform(-10.0, 10.0)
  16. noise = np.random.normal(0, 1)
  17. y = x * w + b + noise
  18. yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)
  19. eval_data = list(get_data(50))
  20. x_target_label = np.array([-10, 10, 0.1])
  21. y_target_label = x_target_label * 2 + 3
  22. x_eval_label,y_eval_label = zip(*eval_data)
  23. def create_dataset(num_data, batch_size=16, repeat_size=1):
  24. input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
  25. input_data = input_data.batch(batch_size)
  26. input_data = input_data.repeat(repeat_size)
  27. return input_data
  28. data_number = 1600
  29. batch_number = 16
  30. repeat_number = 1
  31. ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
  32. dict_datasets = next(ds_train.create_dict_iterator())
  33. class LinearNet(nn.Cell):
  34. def __init__(self):
  35. super(LinearNet, self).__init__()
  36. self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
  37. def construct(self, x):
  38. x = self.fc(x)
  39. return x
  40. net = LinearNet()
  41. model_params = net.trainable_params()
  42. x_model_label = np.array([-10, 10, 0.1])
  43. y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0] +
  44. Tensor(model_params[1]).asnumpy()[0])
  45. net = LinearNet()
  46. net_loss = nn.loss.MSELoss()
  47. opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
  48. model = Model(net, net_loss, opt)
  49. fig = plt.figure()
  50. ims = []
  51. def plot_model_and_datasets(net, eval_data):
  52. weight = net.trainable_params()[0]
  53. bias = net.trainable_params()[1]
  54. x = np.arange(-10, 10, 0.1)
  55. y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
  56. x1, y1 = zip(*eval_data)
  57. x_target = x
  58. y_target = x_target * 2 + 3
  59. plt.axis([-11, 11, -20, 25])
  60. plt.scatter(x1, y1, color="red", s=5)
  61. im = plt.plot(x, y, color="blue")
  62. ims.append(im)
  63. im1 = plt.plot(x_target, y_target, color="green")
  64. ims.append(im1)
  65. time.sleep(0.2)
  66. class ImageShowCallback(Callback):
  67. def __init__(self, net, eval_data):
  68. self.net = net
  69. self.eval_data = eval_data
  70. def step_end(self, run_context):
  71. plot_model_and_datasets(self.net, self.eval_data)
  72. display.clear_output(wait=True)
  73. epoch = 1
  74. imageshow_cb = ImageShowCallback(net, eval_data)
  75. model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)
  76. plot_model_and_datasets(net, eval_data)
  77. for net_param in net.trainable_params():
  78. print(net_param, net_param.asnumpy())
  79. ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
  80. ani.save('train.gif', writer='pillow')

执行结果如下:

  1. root@b8955ba28950:/home# python test_linear.py
  2. WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
  3. [WARNING] ME(444:140374496206976,MainProcess):2021-04-14-09:28:58.738.627 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
  4. Parameter (name=fc.weight) [[1.8964282]]
  5. Parameter (name=fc.bias) [3.0266616]

执行完成后会在当前目录下生成一个名为train.gif的动态图,演示整个训练优化的过程:



其中红色散点是训练数据,绿色直线是原始函数,蓝色直线是训练后的函数,可以看到两个函数是越来越接近的。最后拟合出来的函数为:

\[y=1.8964282x+3.0266616
\]

与我们所预期的:

\[y=2x+3
\]

还是略有差距,但是这其中的可能原因有很多,有可能是生成的随机散点的问题,也有可能是在这个范围内的线段拟合就是有这么大的误差,这里我们不做展开。到这里为止,我们就成功的使用mindspore完成了一个函数拟合的任务。

python绘制动态函数图

在上一个章节中我们演示了使用mindspore完成了一个线性函数的拟合,最后的代码中其实已经使用到了动态图的绘制方法,这里单独抽取出来作为一个章节来介绍。我们所使用到的工具是matplotlib.animation,使用的第一步是在训练的外部先生成一个动态图像的对象:

  1. fig = plt.figure()
  2. ims = []

其中ims是用于存储每一帧的数据绘制内容。第二步是将训练过程中需要变化的绘图对象添加到ims中:

  1. im = plt.plot(x, y, color="blue")
  2. ims.append(im)
  3. im1 = plt.plot(x_target, y_target, color="green")
  4. ims.append(im1)

最后根据绘制的图的对象fig和变化的图像集合ims来生成一个动态图并且保存到本地文件中:

  1. ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
  2. ani.save('train.gif', writer='pillow')

关于animation.ArtistAnimation的接口参数如下所示:



这里每一帧之间的间隔时间我们定义为500ms,重复播放1000次,基本可以认为是一直在重复播放的。最终的效果图在上一个章节中已经做了展示,这里就不再重复说明。需要注意的是,生成动态图的过程会比较漫长,而且只有通过animation才能够生成和保存gif动态图,直接通过plt.savefig是无法保存为动态图的。

总结概要

很多机器学习的算法的基础就是函数的拟合,这里我们考虑的是其中一种最简单也最常见的场景:线性函数的拟合,并且我们要通过mindspore来实现这个数据的训练。通过构造均方误差函数,配合前向传播网络与反向传播网络的使用,最终大体成功的拟合了给定的一个线性函数。文末我们还顺带介绍了使用matplotlib的animation来生成动态图的功能,可视化的展现了整个训练的过程。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/linear.html

作者ID:DechinPhy

更多原著文章请参考:https://www.cnblogs.com/dechinphy/

参考链接

  1. https://www.mindspore.cn/tutorial/training/zh-CN/master/quick_start/linear_regression.html
  2. https://blog.csdn.net/clksjx/article/details/105720120
  3. https://www.cnblogs.com/dechinphy/p/gradient.html

MindSpore函数拟合的更多相关文章

  1. matlab函数拟合

    1 函数拟合 函数拟合在工程(如采样校正)和数据分析(如隶属函数确定)中都是非常有用的工具.我这里将函数拟合分为三类:分别是多项式拟合,已知函数类型的拟合和未知函数类型的拟合.matlab中关于函数的 ...

  2. MATLAB用“fitgmdist”函数拟合高斯混合模型(一维数据)

    MATLAB用“fitgmdist”函数拟合高斯混合模型(一维数据) 作者:凯鲁嘎吉 - 博客园 http://www.cnblogs.com/kailugaji/ 在MATLAB中“fitgmdis ...

  3. matlab多项式拟合以及指定函数拟合

    clc;clear all;close all;%% 多项式拟合指令:% X = [1 2 3 4 5 6 7 8 9 ];% Y = [9 7 6 3 -1 2 5 7 20]; % P= poly ...

  4. TensorFlow-正弦函数拟合

    MNIST的代码还是有点复杂,一大半内容全在搞数据,看了半天全是一滩烂泥.最关键的是最后输出就是一个accuracy,我根本就不关心你准确率是0.98还是0.99好吗?我就想看到我手写一个5,你程序给 ...

  5. MATLAB神经网络(2) BP神经网络的非线性系统建模——非线性函数拟合

    2.1 案例背景 在工程应用中经常会遇到一些复杂的非线性系统,这些系统状态方程复杂,难以用数学方法准确建模.在这种情况下,可以建立BP神经网络表达这些非线性系统.该方法把未知系统看成是一个黑箱,首先用 ...

  6. MATLAB神经网络(3) 遗传算法优化BP神经网络——非线性函数拟合

    3.1 案例背景 遗传算法(Genetic Algorithms)是一种模拟自然界遗传机制和生物进化论而形成的一种并行随机搜索最优化方法. 其基本要素包括:染色体编码方法.适应度函数.遗传操作和运行参 ...

  7. 使用MindSpore的线性神经网络拟合非线性函数

    技术背景 在前面的几篇博客中,我们分别介绍了MindSpore的CPU版本在Docker下的安装与配置方案.MindSpore的线性函数拟合以及MindSpore后来新推出的GPU版本的Docker编 ...

  8. cftool拟合&函数逼近

    cftool拟合&函数逼近 cftool 真是神奇,之前我们搞的一些线性拟合解方程,多项式拟合,函数拟合求参数啊,等等. 已经超级多了,为啥还得搞一个cftool拟合啊?而且毫无数学理论. 如 ...

  9. Python数据处理——绘制函数图形以及数据拟合

    1.多项式拟合 对散点进行多项式拟合并打印出拟合函数以及拟合后的图形import matplotlib.pyplot as pltimport numpy as npx=np.arange(1,17, ...

随机推荐

  1. 炒币亏损一万美金?不如抢SPC空投!

    币圈的市场可以用风云变幻来形容,1月9日的时候比特币震荡,其他币种争先上涨,连平时都不涨的币种都拉出了10%-20%的领先涨幅,市场惊呼牛市来了,但喜悦还没有维持一天,1月10日(昨天)市场就走向另一 ...

  2. NGK生态商城即将上线官网,推动生态落地应用

    NGK生态商城即将上线官网,以推动生态落地应用.此举意味着NGK生态将跻身区块链顶尖之列,同时,NGK代币.NGK Dapp游戏 "呼叫河马" 以及NGK DeFi项目Baccar ...

  3. Mac mini M1使用简单体验(编程、游戏、深度学习)

    好久不见了各位! 前一阵子忍不住剁手买了M1芯片的mac mini,为了弥补自己的内疚感就卖了自己的旧的mbp2017款.数据也完全迁移到了新机器上,之前的工作也就由mbp2017彻底换成mac mi ...

  4. 进阶高阶IoT架构-教你如何简单实现一个消息队列

    前言 消息队列是软件系统领域用来实现系统间通信最广泛的中间件.基于消息队列的方式是指由应用中的某个系统负责发送消息,由关心这条消息的相关系统负责接收消息,并在收到消息后进行各自系统内的业务处理.消息可 ...

  5. C语言经典88案例,我文科妹妹说她都学会了!

    案例ex01: 将字符串转换为一个整数 1 题目 函数:fun() 功能:将字符串转换为一个整数 描述: [不能使用C语言提供的字符串函数] 输入:字符串"-1234" 输出:整型 ...

  6. Linux速通08 网络原理及基础设置、软件包管理

    使用 ifconfig命令来维护网络 # ifconfig 命令:显示所有正在启动的网卡的详细信息或设定系统中网卡的 IP地址 # 应用 ifconfig命令设定网卡的 IP地址: * 例:修改 et ...

  7. 03-Spring默认标签解析

    默认标签的解析 上一篇分析了整体的 xml 文件解析,形成 BeanDefinition 并注册到 IOC 容器中,但并没有详细的说明具体的解析,这一篇主要说一下 默认标签的解析,下一篇主要说自定义标 ...

  8. 将表单数据转换成json字符串

    $("#theForm").serialize(); 可以获取表单的数据,但是是json字符串 需要转换成json才能正常使用

  9. 【odoo14】第十五章、网站客户端开发

    odoo的web客户端.后台是员工经常使用的地方.在第九章中,我们了解了如何使用后台提供的各种可能性.本章,我们将了解如何扩展这种可能性.其中web模块包含了我们在使用odoo中的各种交互行为. 本章 ...

  10. 新的颜色对比度算法-感知对比度算法APCA

    目录 对比度 在控制台查看 插件或网站 感知对比度算法(APCA) APCA Math 原理 js 实现的 SAPC 最后 灵感的源泉来源于不断的接受新鲜事物. Chrome 89 新功能一览,性能提 ...