内容接前文:

https://www.cnblogs.com/devilmaycry812839668/p/14988686.html

https://www.cnblogs.com/devilmaycry812839668/p/14990021.html

前面是我们自己按照个人理解实现的单步计算,随着对这个计算框架MindSpore的深入了解我们了解到其实官方是提供了单步计算函数的。

具体函数:

from mindspore.nn import TrainOneStepCell, WithLossCell

根据官方资料:

https://www.mindspore.cn/doc/programming_guide/zh-CN/master/network_component.html?highlight=%E5%8D%95%E6%AD%A5%E8%AE%AD%E7%BB%83

根据官方提供的函数,给出如下代码:

import mindspore
import numpy as np # 引入numpy科学计算库
import matplotlib.pyplot as plt # 引入绘图库 np.random.seed(123) # 随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor
from mindspore.nn import TrainOneStepCell, WithLossCell class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1) + self.bias_1
x2 = self.matmul(x1, self.weight_2) + self.bias_2
x3 = self.matmul(x2, self.weight_3) + self.bias_3
return x3 def main():
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
# make net model
# model = Model(net, loss, optim, metrics={'loss': nn.Loss()})
net_with_criterion = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_criterion, optim) # 数据集
x, y = np.array([[0.1]], dtype=np.float32), np.array([[0.1]], dtype=np.float32)
x = Tensor(x)
y = Tensor(y) for i in range(20000*100):
#print(i, '\t', '*' * 100)
train_network.set_train()
res = train_network(x, y) # right
# False, False
# False, True
# True, True xxx # not right
# True, False if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context # 为mindspore设置运行背景context
#context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time a = time.time()
main()
b = time.time()
print(b-a)

运行时间:

1158.24s

1154.29s

1152.69s

=====================================================

前文我们给出的单步计算 model.train  的代码修改如下:

import mindspore
import numpy as np # 引入numpy科学计算库
import matplotlib.pyplot as plt # 引入绘图库 np.random.seed(123) # 随机数生成种子 import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore import ParameterTuple, Parameter
from mindspore import dtype as mstype
from mindspore import Model
import mindspore.dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import LossMonitor class Net(nn.Cell):
def __init__(self, input_dims, output_dims):
super(Net, self).__init__()
self.matmul = ops.MatMul() self.weight_1 = Parameter(Tensor(np.random.randn(input_dims, 128), dtype=mstype.float32), name='weight_1')
self.bias_1 = Parameter(Tensor(np.zeros(128), dtype=mstype.float32), name='bias_1')
self.weight_2 = Parameter(Tensor(np.random.randn(128, 64), dtype=mstype.float32), name='weight_2')
self.bias_2 = Parameter(Tensor(np.zeros(64), dtype=mstype.float32), name='bias_2')
self.weight_3 = Parameter(Tensor(np.random.randn(64, output_dims), dtype=mstype.float32), name='weight_3')
self.bias_3 = Parameter(Tensor(np.zeros(output_dims), dtype=mstype.float32), name='bias_3') def construct(self, x):
x1 = self.matmul(x, self.weight_1) + self.bias_1
x2 = self.matmul(x1, self.weight_2) + self.bias_2
x3 = self.matmul(x2, self.weight_3) + self.bias_3
return x3 def main():
net = Net(1, 1)
# loss function
loss = nn.MSELoss()
# optimizer
optim = nn.SGD(params=net.trainable_params(), learning_rate=0.000001)
# make net model
model = Model(net, loss, optim, metrics={'loss': nn.Loss()}) # 数据集
x, y = np.array([[0.1]], dtype=np.float32), np.array([[0.1]], dtype=np.float32) def generator_multidimensional():
for i in range(1):
a = x*i
b = y*i
#print(a, b)
yield (a, b) dataset = ds.GeneratorDataset(source=generator_multidimensional, column_names=["input", "output"]) for i in range(20000*100):
#print(i, '\t', '*' * 100)
model.train(1, dataset, dataset_sink_mode=False) # right
# False, False
# False, True
# True, True xxx # not right
# True, False if __name__ == '__main__':
""" 设置运行的背景context """
from mindspore import context # 为mindspore设置运行背景context
#context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
context.set_context(mode=context.GRAPH_MODE, device_target='GPU') import time a = time.time()
main()
b = time.time()
print(b-a)

运行时间:

2173.19s

2181.61s

==================================================================

可以看到,在单步计算时,如果使用框架提供的单步训练函数会更好的提升算法运算效率,运算效率提升的幅度也很大,所有在进行单步训练或者非持续数据量训练时使用框架提供的单步训练函数是首选。

单步训练函数:

from mindspore.nn import TrainOneStepCell, WithLossCell

=====================================================================

本文实验环境为  MindSpore1.1  docker版本

宿主机:Ubuntu18.04系统

CPU:I7-8700

GPU:1060ti NVIDIA显卡

(续 2 )在深度计算框架MindSpore中如何对不持续的计算进行处理——对数据集进行一定epoch数量的训练后,进行其他工作处理,再返回来接着进行一定epoch数量的训练——单步计算的更多相关文章

  1. 带你学习MindSpore中算子使用方法

    摘要:本文分享下MindSpore中算子的使用和遇到问题时的解决方法. 本文分享自华为云社区<[MindSpore易点通]算子使用问题与解决方法>,作者:chengxiaoli. 简介 算 ...

  2. TensorFlow - 框架实现中的三种 Graph

    文章目录 TensorFlow - 框架实现中的三种 Graph 1. Graph 2. GraphDef 3. MetaGraph 4. Checkpoint 5. 总结 TensorFlow - ...

  3. SSH框架应用中常用Jar包用途介绍

    struts2需要的几个jar包:1)xwork-core-2.1.62)struts2-core-2.1.83)ognl-2.7.34)freemarker-2.3.155)commons-io-1 ...

  4. 如何在Crystal框架项目中内置启动MetaQ服务?

    当Crystal框架项目中需要使用消息机制,而项目规模不大.性能要求不高时,可内置启动MetaQ服务器. 分步指南 项目引入crystal-extend-metaq模块,如下: <depende ...

  5. 如何在Crystal框架项目中内置启动Zookeeper服务?

    当Crystal框架项目需要使用到Zookeeper服务时(如使用Dubbo RPC时,需要注册服务到Zookeeper),而独立部署和启动Zookeeper服务不仅繁琐,也容易出现错误. 在小型项目 ...

  6. 浅入深出之Java集合框架(中)

    Java中的集合框架(中) 由于Java中的集合框架的内容比较多,在这里分为三个部分介绍Java的集合框架,内容是从浅到深,如果已经有java基础的小伙伴可以直接跳到<浅入深出之Java集合框架 ...

  7. Javscript调用iframe框架页面中函数的方法

    Javscript调用iframe框架页面中函数的方法,可以实现iframe之间传值或修改值了, 访问iframe里面的函数: window.frames['CallCenter_iframe'].h ...

  8. 游戏框架设计中的。绑定binding。。。命令 command 和消息message 以及MVVM

    游戏框架设计中的.绑定binding...命令 command 和消息message

  9. 关于MFC框架程序中CWinApp::OnIdle

    很早之前就发现,我写的图形引擎在MFC框架程序中的刷帧率始终在60FPS左右.好在自己的程序对刷帧率的要求不是很高,所以一直没有太过纠结此事.直到今天看了别人的程序才发现应该在函数CWinApp::O ...

  10. TP框架模板中IF Else 如何使用?

    TP框架模板中IF Else 如何使用? 截个图吧 如果效果出不来,一般就是条件写错了!!!

随机推荐

  1. Promise 期约

    Promise 期约之前 回调地狱 设想这样一个经常发生的场景,我们希望处理Ajax请求的结果,所以我们将处理请求结果的方法作为回调传入,需要将请求结果继续处理,这就导致我们陷入了回调地狱 doSom ...

  2. 报错 ERR !npicode ELIFECYCLE dev: wue-cli-service serve

    在系统变量 Path 里面加上:%SystemRoot%\system32,关掉终端,重新启动项目.

  3. C# .NET 常见DeepCopy 深度拷贝的性能对比

    先上结论 Method Mean Error StdDev Gen0 Gen1 Allocated JSONConvert 2,273.02 ns 43.758 ns 52.091 ns 0.6599 ...

  4. 【vue】利用输入框搜索过滤来选择列表

    方法1 <div id="app"> <input type="text" @input="handleInput()" ...

  5. 03-vi和vim编辑器的使用

    背景 vim是一个类似于vi的著名的功能强大.高度可定制的文本编辑器. vim在vi的基础上改进和增加了很多特性. 如今vi已经是最受IT届欢迎的编辑器之一. 不止在Linux中,主流IDE都支持vi ...

  6. C#/.NET这些实用的技巧和知识点你都知道吗?

    前言 今天大姚给大家分享一些C#/.NET中的实用的技巧和知识点,它们可以帮助我们提升代码质量和编程效率,希望可以帮助到有需要的同学. .NET使用CsvHelper快速读取和写入CSV文件 本文主要 ...

  7. 使用Swig转换C++到别的编程语言

    项目github地址: aoce 设定aoce能分别与UE4/Unity3D/android demo对接,就这三来看,分别是C++/C#/java三种语言. C++导出给别的语言使用,一般来说,分为 ...

  8. ARM+DSP!全志T113-i+玄铁HiFi4开发板硬件说明书(1)

    前 言 本文档主要介绍开发板硬件接口资源以及设计注意事项等内容,测试板卡为全志T113-i+玄铁HiFi4开发板.由于篇幅问题,本篇文章共分为上下两集,点击账户可查看更多内容详情,开发问题欢迎留言,感 ...

  9. mysql语句大全-工作中常用整理(欢迎大家在评论区继续补充)

    1.NOT EXISTS 和 NOT IN SELECT COUNT(ca.aaa) FROM xx ca WHERE NOT EXISTS( SELECT label.* FROM xxx labe ...

  10. Spark3学习【基于Java】2. Spark-Sql核心概念

    SparkSession 从Spark2开始,Spark-SQL引入了SparkSession这个核心类,它是处理DataSet等结构数据的入口.在2.0之前,使用的是spark-core里的Spar ...