技术背景

近几年在机器学习和传统搜索算法的结合中,逐渐发展出了一种Search To Optimization的思维,旨在通过构造一个特定的机器学习模型,来替代传统算法中的搜索过程,进而加速经典图论等问题的求解。那么这里面就涉及到一个非常关键的工程步骤:把机器学习中训练出来的模型保存成一个文件或者数据库,使得其他人可以重复的使用这个已经训练出来的模型。甚至是可以发布在云端,通过API接口进行调用。那么本文的内容就是介绍给予MindSpore的模型保存与加载,官方文档可以参考这个链接

保存模型

这里我们使用的模型来自于这篇博客,是一个非常基础的线性神经网络模型,用于拟合一个给定的函数。因为这里我们是基于线性的模型,因此当我们需要去拟合一个更加高阶的函数的话,需要手动的处理,比如这里我们使用的平方函数。总结起来该模型可以抽象为:

\[f(x)=ax^2+b
\]

而这里产生训练集时,我们是加了一个随机的噪声,也就是:

\[D(x)=ax^2+b+noise
\]

完整的代码如下所示:

# save_model.py

from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
import numpy as np
from mindspore import dataset as ds
from mindspore import nn, Tensor, Model
import time
from mindspore.train.callback import Callback, LossMonitor, ModelCheckpoint def get_data(num, a=2.0, b=3.0):
for _ in range(num):
x = np.random.uniform(-1.0, 1.0)
noise = np.random.normal(0, 0.03)
z = a * x ** 2 + b + noise
# 返回数据的时候就返回数据的平方
yield np.array([x**2]).astype(np.float32), np.array([z]).astype(np.float32) def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['x','z'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data data_number = 1600 # 一共产生1600组数据
batch_number = 16 # 分为16组分别进行优化
repeat_number = 2 # 重复2次,可以取到更低的损失函数值 ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number) class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, 0.02, 0.02) def construct(self, x):
x = self.fc(x)
return x net = LinearNet()
model_params = net.trainable_params()
print ('Param Shape is: {}'.format(len(model_params)))
for net_param in net.trainable_params():
print(net_param, net_param.asnumpy())
net_loss = nn.loss.MSELoss() # 设定优化算法,常用的是Momentum和ADAM
optim = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, optim)
ckpt_cb = ModelCheckpoint() epoch = 1
# 设定每8个batch训练完成后就播报一次,这里一共播报25次
model.train(epoch, ds_train, callbacks=[LossMonitor(16), ckpt_cb], dataset_sink_mode=False) for net_param in net.trainable_params():
print(net_param, net_param.asnumpy())

最后是通过ModelCheckpoint这一方法将训练出来的模型保存成.ckpt的checkpoint文件格式。上述代码执行的结果如下:

dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ sudo docker run --rm -v /dev/shm:/dev/shm -v /home/dechin/projects/gitlab/dechin/src/mindspore/:/home/ --runtime=nvidia --privileged=true swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-gpu:1.2.0 /bin/bash -c "cd /home && python save_model.py"
Param Shape is: 2
Parameter (name=fc.weight, shape=(1, 1), dtype=Float32, requires_grad=True) [[0.02]]
Parameter (name=fc.bias, shape=(1,), dtype=Float32, requires_grad=True) [0.02]
epoch: 1 step: 16, loss is 0.9980968
epoch: 1 step: 32, loss is 0.3894167
epoch: 1 step: 48, loss is 0.090543285
epoch: 1 step: 64, loss is 0.060300454
epoch: 1 step: 80, loss is 0.014248277
epoch: 1 step: 96, loss is 0.015697923
epoch: 1 step: 112, loss is 0.014582128
epoch: 1 step: 128, loss is 0.008066677
epoch: 1 step: 144, loss is 0.007225203
epoch: 1 step: 160, loss is 0.0046849623
epoch: 1 step: 176, loss is 0.006007362
epoch: 1 step: 192, loss is 0.004276552
Parameter (name=fc.weight, shape=(1, 1), dtype=Float32, requires_grad=True) [[1.8259585]]
Parameter (name=fc.bias, shape=(1,), dtype=Float32, requires_grad=True) [3.0577476]

可能是由于数据集的问题,可能是由于噪声的问题,使得我们训练出来的参数跟我们最开始所定义的参数有些许的区别,但是具体的误差范围还得通过测试集来验证。运行结束后会在当前目录下生成一系列的.ckpt文件和一个.meta的计算图文件:

dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ ll
total 148
drwxr-xr-x 2 1000 1000 4096 May 16 13:25 ./
drwxr-xr-x 1 root root 4096 May 15 01:55 ../
-r-------- 1 root root 211 May 16 13:25 CKP-1_196.ckpt
-r-------- 1 root root 211 May 16 13:25 CKP-1_197.ckpt
-r-------- 1 root root 211 May 16 13:25 CKP-1_198.ckpt
-r-------- 1 root root 211 May 16 13:25 CKP-1_199.ckpt
-r-------- 1 root root 211 May 16 13:25 CKP-1_200.ckpt
-r-------- 1 root root 3705 May 16 13:25 CKP-graph.meta
-rw-r--r-- 1 1000 1000 2087 May 16 13:25 save_model.py

接下来就可以开始加载这些文件中所给出的模型,用于测试集的验证。

加载模型

在模型的加载中,我们依然还是需要原始的神经网络对象LinearNet

# load_model.py

from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
import mindspore.dataset as ds
from mindspore import load_checkpoint, load_param_into_net
from mindspore import nn, Tensor, Model
import numpy as np class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1,1,0.02,0.02) def construct(self, x):
print ('x:', x)
x = self.fc(x)
print ('z:', x)
return x net = LinearNet()
param_dict = load_checkpoint("CKP-1_200.ckpt")
load_param_into_net(net, param_dict) net_loss = nn.loss.MSELoss()
model = Model(net, net_loss, metrics={"accuracy"}) data = {'x':np.array([[0.01],[0.25],[1],[4],[9]]).astype(np.float32),
'z':np.array([3.02,3.5,5,11,21]).astype(np.float32)}
dataset = ds.NumpySlicesDataset(data=data)
dataset = dataset.batch(1) acc = model.eval(dataset, dataset_sink_mode=False)

这里我们只是定义了一个非常简单的测试集,在实际场景中往往测试集也是非常大量的数据。这里为了简化本地环境,使用的是Docker的编程环境,并且在Docker容器的拉起中使用了--rm选项,在运行结束后会删除这个容器,确保环境的整洁。关于在Docker容器环境下安装和部署MindSpore的方案,可以参考这篇CPU部署博客,以及这篇GPU部署博客。上述代码的执行结果如下:

dechin@ubuntu2004:~/projects/gitlab/dechin/src/mindspore$ sudo docker run --rm -v /dev/shm:/dev/shm -v /home/dechin/projects/gitlab/dechin/src/mindspore/:/home/ --runtime=nvidia --privileged=true swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-gpu:1.2.0 /bin/bash -c "cd /home && python load_model.py"
x:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 9.00000000e+00]])
z:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 1.94898682e+01]])
x:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 2.50000000e-01]])
z:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 3.52294850e+00]])
x:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 9.99999978e-03]])
z:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 3.08499861e+00]])
x:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 1.00000000e+00]])
z:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 4.89154148e+00]])
x:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 4.00000000e+00]])
z:
Tensor(shape=[1, 1], dtype=Float32, value=
[[ 1.03659143e+01]])

从这个结果中我们可以看到,由于在训练集中我们使用的数据集中在比较小的范围,因此在测试这个范围外的数据时,误差相较于这个区域内的数据点会更大一些,这也是属于正常的现象。

总结概要

本文主要从工程实现的角度测试了一下MindSpore的机器学习模型保存与加载的功能,通过这个功能,我们可以将自己训练好的机器学习模型发布出去供更多的人使用,我们也可以直接使用别人在更好的硬件体系上训练好的模型,或者应用于迁移学习。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/sl.html

作者ID:DechinPhy

更多原著文章请参考:https://www.cnblogs.com/dechinphy/

MindSpore保存与加载模型的更多相关文章

  1. TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结

    写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...

  2. [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题

    [深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...

  3. tensorflow学习笔记(三十四):Saver(保存与加载模型)

    Savertensorflow 中的 Saver 对象是用于 参数保存和恢复的.如何使用呢? 这里介绍了一些基本的用法. 官网中给出了这么一个例子: v1 = tf.Variable(..., nam ...

  4. sklearn训练模型的保存与加载

    使用joblib模块保存于加载模型 在机器学习的过程中,我们会进行模型的训练,最常用的就是sklearn中的库,而对于训练好的模型,我们当然是要进行保存的,不然下次需要进行预测的时候就需要重新再进行训 ...

  5. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  6. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  7. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  8. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  9. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

随机推荐

  1. 微信小程序 | flex布局属性

    flex-direction 主轴方向 row: 横向 row-reverse: 横向倒序 column: 纵向 column-reverse: 纵向倒序; flex-wrap 元素排列换行 nowr ...

  2. Springboot进行Http接口交互实现邮件告警

    本项目采用idea编辑器,依赖maven环境,相关搭建请自行百度一.引入相关依赖    本文Http接口交互使用hutool工具类与阿里FastJson解析报文. <dependencies&g ...

  3. WordPress的SEO优化技巧

    世界上大约有30%的网站都是由Wordpress搭建的,因为Wordpress自身构架清晰,代码规范,且网页评论直接书写在整个页面里,能够被搜索引擎检索到,因此对搜索引擎很友好.但有时候还是会出现只被 ...

  4. 201871030106-陈鑫莲 实验三 结对项目—《D{0-1}KP 实例数据集算法实验平台》项目报告

    项目 内容 课程班级博客链接 班级博客 这个作业要求链接 作业要求 我的课程学习目标 1.学会结对学习,体会结对学习的快乐2.了解并实践结对编程 3.加深对D{0-1}问题的解法的理解4.复习并熟悉P ...

  5. 北航OO第四单元作业总结(4.1~4.3)及课程总结

    前言 在学习过JML规格描述语言之后,本单元进行了UML(Unified Modeling Language)的学习.和JML单纯用语言描述的形式不同,UML通过可视化的图形形式,对一系列有关类的元素 ...

  6. JavaWeb 补充(Servlet)

    Servlet: server applet 概念: 运行在服务器端的小程序     * Servlet就是一个接口,定义了Java类被浏览器访问到(tomcat识别)的规则.     * 将来我们自 ...

  7. 【spring源码系列】之【xml解析】

    1. 读源码的方法 java程序员都知道读源码的重要性,尤其是spring的源码,代码设计不仅优雅,而且功能越来越强大,几乎可以与很多开源框架整合,让应用更易于专注业务领域开发.但是能把spring的 ...

  8. 自动化kolla-ansible部署ubuntu20.04+openstack-victoria之镜像制作win2008r2-19

    自动化kolla-ansible部署ubuntu20.04+openstack-victoria之镜像制作win2008r2-19 欢迎加QQ群:1026880196 进行交流学习 制作OpenSta ...

  9. JavaCV 视频滤镜(LOGO、滚动字幕、画中画、NxN宫格)

    其实,在JavaCV中除了FFmpegFrameGrabber和FFmpegFrameRecorder之外,还有一个重要的类,那就是FFmpegFrameFilter. FFmpegFrameFilt ...

  10. 使用var和不使用var的区别(全局变量/局部变量)

    https://blog.csdn.net/czh500/article/details/80429133