MindSpore网络模型类
MindSpore网络模型类
Q:使用MindSpore进行模型训练时,CTCLoss的输入参数有四个:inputs, labels_indices, labels_values, sequence_length,如何使用CTCLoss进行训练?
A:定义的model.train接口里接收的dataset可以是多个数据组成,形如(data1, data2, data3, …),所以dataset是可以包含inputs,labels_indices,labels_values,sequence_length的信息的。只需要定义好相应形式的dataset,传入model.train里就可以。具体的可以了解下相应的数据处理接口。
Q:模型转移时如何把PyTorch的权重加载到MindSpore中?
A:首先输入PyTorch的pth文件,以ResNet-18为例,MindSpore的网络结构和PyTorch保持一致,转完之后可直接加载进网络,这边参数只用到BN和Conv2D,若有其他层ms和PyTorch名称不一致,需要同样的修改名称。
Q:模型已经训练好,如何将模型的输出结果保存为文本或者npy的格式?
A:网络的输出为Tensor,需要使用asnumpy()方法将Tensor转换为numpy,再进行下一步保存。具体可参考:
out = net(x)
np.save("output.npy", out.asnumpy())
Q:使用MindSpore做分割训练,必须将数据转为MindRecords吗?
A:build_seg_data.py是将数据集生成MindRecord的脚本,可以直接使用/适配数据集。或者如果想尝试实现数据集的读取,可以使用GeneratorDataset自定义数据集加载。
Q:MindSpore可以读取TensorFlow的ckpt文件吗?
A:MindSpore的ckpt和TensorFlow的ckpt格式是不通用的,虽然都是使用protobuf协议,但是proto的定义是不同的。当前MindSpore不支持读取TensorFlow或PyTorch的ckpt文件。
Q:如何不将数据处理为MindRecord格式,直接进行训练呢?
A:可以使用自定义的数据加载方式 GeneratorDataset,具体可以参考数据集加载文档中的自定义数据集加载。
Q:MindSpore现支持直接读取哪些其他框架的模型和哪些格式呢?比如PyTorch下训练得到的pth模型可以加载到MindSpore框架下使用吗?
A: MindSpore采用protbuf存储训练参数,无法直接读取其他框架的模型。对于模型文件本质保存的就是参数和对应的值,可以用其他框架的API将参数读取出来之后,拿到参数的键值对,然后再加载到MindSpore中使用。比如想用其他框架训练好的ckpt文件,可以先把参数读取出来,再调用MindSpore的save_checkpoint接口,就可以保存成MindSpore可以读取的ckpt文件格式了。
Q:用MindSpore训练出的模型如何在Ascend 310上使用?可以转换成适用于HiLens Kit用的吗?
A:Ascend 310需要运行专用的OM模型,先使用MindSpore导出ONNX或AIR模型,再转化为Ascend 310支持的OM模型。具体可参考多平台推理。可以,HiLens Kit是以Ascend 310为推理核心,所以前后两个问题本质上是一样的,需要转换为OM模型.
Q:MindSpore如何进行参数(如dropout值)修改?
A:在构造网络的时候可以通过 if self.training: x = dropput(x),验证的时候,执行前设置network.set_train(mode_false),就可以不适用dropout,训练时设置为True就可以使用dropout。
Q:从哪里可以查看MindSpore训练及推理的样例代码或者教程?
A:可以访问MindSpore官网教程训练和MindSpore官网教程推理。
Q:MindSpore支持哪些模型的训练?
A:MindSpore针对典型场景均有模型训练支持,支持情况详见Release note。
Q:MindSpore有哪些现成的推荐类或生成类网络或模型可用?
A:目前正在开发Wide & Deep、DeepFM、NCF等推荐类模型,NLP领域已经支持Bert_NEZHA,正在开发MASS等模型,用户可根据场景需要改造为生成类网络,可以关注MindSpore Model Zoo。
Q:MindSpore模型训练代码能有多简单?
A:除去网络定义,MindSpore提供了Model类的接口,大多数场景只需几行代码就可完成模型训练。
Q:如何使用MindSpore拟合f(x)=a×sin(x)+bf(x)=a×sin(x)+b这类函数?
A:以下拟合案例是基于MindSpore线性拟合官方案例改编而成。
# The fitting function is:f(x)=2*sin(x)+3.
import numpy as np
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Model, context
from mindspore.train.callback import LossMonitor
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
def get_data(num, w=2.0, b=3.0):
# f(x)=w * sin(x) + b
# f(x)=2 * sin(x) +3
for i in range(num):
x = np.random.uniform(-np.pi, np.pi)
noise = np.random.normal(0, 1)
y = w * np.sin(x) + b + noise
yield np.array([np.sin(x)]).astype(np.float32), np.array([y]).astype(np.float32)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
x = self.fc(x)
return x
if __name__ == "__main__":
num_data = 1600
batch_size = 16
repeat_size = 1
lr = 0.005
momentum = 0.9
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), lr, momentum)
model = Model(net, net_loss, opt)
ds_train = create_dataset(num_data, batch_size=batch_size, repeat_size=repeat_size)
model.train(1, ds_train, callbacks=LossMonitor(), dataset_sink_mode=False)
print(net.trainable_params()[0], "\n%s" % net.trainable_params()[1])
Q:如何使用MindSpore拟合f(x)=ax2+bx+cf(x)=ax2+bx+c这类的二次函数?
A:以下代码引用自MindSpore的官方教程的代码仓
在以下几处修改即可很好的拟合f(x)=ax2+bx+cf(x)=ax2+bx+c:
- 数据集生成。
- 拟合网络。
- 优化器。
修改的详细信息如下,附带解释。
# Since the selected optimizer does not support CPU, so the training computing platform is changed to GPU, which requires readers to install the corresponding GPU version of MindSpore.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
# Assuming that the function to be fitted this time is f(x)=2x^2+3x+4, the data generation function is modified as follows:
def get_data(num, a=2.0, b=3.0 ,c = 4):
for i in range(num):
x = np.random.uniform(-10.0, 10.0)
noise = np.random.normal(0, 1)
# The y value is generated by the fitting target function ax^2+bx+c.
y = x * x * a + x * b + c + noise
# When a*x^2+b*x+c is fitted, a and b are weight parameters and c is offset parameter bias. The training data corresponding to the two weights are x^2 and x respectively, so the data set generation mode is changed as follows:
yield np.array([x*x, x]).astype(np.float32), np.array([y]).astype(np.float32)
def create_dataset(num_data, batch_size=16, repeat_size=1):
input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data','label'])
input_data = input_data.batch(batch_size)
input_data = input_data.repeat(repeat_size)
return input_data
class LinearNet(nn.Cell):
def __init__(self):
super(LinearNet, self).__init__()
# Because the full join function inputs two training parameters, the input value is changed to 2, the first Nomral(0.02) will automatically assign random weights to the input two parameters, and the second Normal is the random bias.
self.fc = nn.Dense(2, 1, Normal(0.02), Normal(0.02))
def construct(self, x):
x = self.fc(x)
return x
if __name__ == "__main__":
num_data = 1600
batch_size = 16
repeat_size = 1
lr = 0.005
momentum = 0.9
net = LinearNet()
net_loss = nn.loss.MSELoss()
# RMSProp optimalizer with better effect is selected for quadratic function fitting, Currently, Ascend and GPU computing platforms are supported.
opt = nn.RMSProp(net.trainable_params(), learning_rate=0.1)
model = Model(net, net_loss, opt)
ds_train = create_dataset(num_data, batch_size=batch_size, repeat_size=repeat_size)
model.train(1, ds_train, callbacks=LossMonitor(), dataset_sink_mode=False)
print(net.trainable_params()[0], "\n%s" % net.trainable_params()[1])
MindSpore网络模型类的更多相关文章
- MindSpore特性支持类
MindSpore特性支持类 Q:请问MindSpore支持梯度截断吗? A:支持,可以参考梯度截断的定义和使用. Q:如何在训练神经网络过程中对计算损失的超参数进行改变? A:暂时还未有这样的功能. ...
- Java类的继承与多态特性-入门笔记
相信对于继承和多态的概念性我就不在怎么解释啦!不管你是.Net还是Java面向对象编程都是比不缺少一堂课~~Net如此Java亦也有同样的思想成分包含其中. 继承,多态,封装是Java面向对象的3大特 ...
- 人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型
人脸检测及识别python实现系列(5)——利用keras库训练人脸识别模型 经过前面稍显罗嗦的准备工作,现在,我们终于可以尝试训练我们自己的卷积神经网络模型了.CNN擅长图像处理,keras库的te ...
- MindSpore后端运行类
MindSpore后端运行类 Q:如何在训练过程中监控loss在最低的时候并保存训练参数? A:可以自定义一个Callback.参考ModelCheckpoint的写法,此外再增加判断loss的逻辑: ...
- MindSpore平台系统类
MindSpore平台系统类 Q:MindSpore只能在华为自己的NPU上跑么? A: MindSpore同时支持华为自己的Ascend NPU.GPU与CPU,是支持异构算力的. Q:MindSp ...
- MindSpore算子支持类
MindSpore算子支持类 Q:在使用Conv2D进行卷积定义的时候使用到了group的参数,group的值不是只需要保证可以被输入输出的维度整除即可了吗?group参数的传递方式是怎样的呢? A: ...
- [技术干货-算子使用] Mindspore 控制流中存在原地更新操作类副作用算子时循环值不更新问题记录
关于mindspore 原地更新类算子的一点思考记录如下: 现象记录: 原始测试代码 错误结果复现: 分析: 如果在场景中加入42行的copy()操作此时cpu的结果就会正确,但是gpu的结果则不受c ...
- MindSpore技术理解(上)
MindSpore技术理解(上) 引言 深度学习研究和应用在近几十年得到了爆炸式的发展,掀起了人工智能的第三次浪潮,并且在图像识别.语音识别与合成.无人驾驶.机器视觉等方面取得了巨大的成功.这也对算法 ...
- 如何基于MindSpore实现万亿级参数模型算法?
摘要:近来,增大模型规模成为了提升模型性能的主要手段.特别是NLP领域的自监督预训练语言模型,规模越来越大,从GPT3的1750亿参数,到Switch Transformer的16000亿参数,又是一 ...
随机推荐
- vue.js中使用set方法 this.$set
vue教程中有这样一个注意事项: 第一种具体情况如下: 运行结果: 当利用索引改变数组某一项时,页面不会刷新.解决方法如下: 运行结果: 三种方式都可以解决,使用Vue.set.vm.$set()或者 ...
- Python中Scapy网络嗅探模块的使用
目录 Scapy scapy的安装和使用 发包 发包和收包 抓包 将抓取到的数据包保存 查看抓取到的数据包 格式化输出 过滤抓包 Scapy scapy是python中一个可用于网络嗅探的非常强大的第 ...
- Win64 驱动内核编程-27.强制读写受保护的内存
强制读写受保护的内存 某些时候我们需要读写别的进程的内存,某些时候别的进程已经对自己的内存读写做了保护,这里说四个思路(两个R3的,两个R0的). 方案1(R3):直接修改别人内存 最基本的也最简单的 ...
- Caddy-基于go的微型serve用来做反向代理和Gateway
1.简单配置 2.go实现,直接一个二进制包,没依赖. 3.默认全站https 常用 反向代理,封装多端口gateway 使用:启动直接执行二进制文件 caddy 就行 根据输出信息 直接https: ...
- 截取字符串长度,超出部分用省略号代替 PHP
function subText($text, $length){ if (mb_strlen($text, 'utf8') > $length) { return mb_substr($tex ...
- 老Python带你从浅入深探究Tuple
元组 Python中的元组容器序列(tuple)与列表容器序列(list)具有极大的相似之处,因此也常被称为不可变的列表. 但是两者之间也有很多的差距,元组侧重于数据的展示,而列表侧重于数据的存储与操 ...
- API网关才是大势所趋?SpringCloud Gateway保姆级入门教程
什么是微服务网关 SpringCloud Gateway是Spring全家桶中一个比较新的项目,Spring社区是这么介绍它的: 该项目借助Spring WebFlux的能力,打造了一个API网关.旨 ...
- Nifi:初识nifi
写在前面: 第一次接触这一系统的时候,只有github上的一坨源码和官方的英文文档,用起来只能说是一步一个坑,一踩一个脚印,现在回想那段血泪史,只想 ***,现在用起来算是有了一些经验和总结,这里就做 ...
- 墙裂推荐一波mysql学习资源
在日常工作与学习中,无论是开发.运维.测试,还是架构师,数据库是一门必不可少的"必修课", 也是必备的涨薪神器.在互联网公司中,开源数据库用得比较多的当属 MySQL 了. 但my ...
- 如何更好理解Peterson算法?
如何更好理解Peterson算法? 1 Peterson算法提出的背景 在我们讲述Peterson算法之间,我们先了解一下Peterson算法提出前的背景(即:在这个算法提出之前,前人们都做了哪些工作 ...