本文参考内容:

https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.2/advanced_usage_of_checkpoint.html?highlight=save_checkpoint

由官方文档内容可知,我们对网络参数的保存不仅可以通过 model 在训练过程中自动保存,也可以使用

from mindspore.train.serialization import save_checkpoint

来进行手动保存。

===========================================================

自动保存参数:

给出在模型训练过程中自动保存参数的代码demo:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5 on MNIST: checkpoints are saved automatically during training.

The ModelCheckpoint callback passed to ``model.train`` writes a checkpoint
every 125 steps and keeps at most 15 checkpoint files.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import save_checkpoint

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration.
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train the model.

    Args:
        args: Parsed command-line arguments (not used inside this function).
        model: The ``Model`` to train.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Dataset root; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback registered for training.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    # Load the training dataset.
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate the model on the "test" sub-directory and print the metrics.

    Args:
        network: The network instance (not used inside this function).
        model: The ``Model`` to evaluate.
        data_path: Dataset root; the "test" sub-directory is loaded.
    """
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
#test_net(net, model, mnist_path)

其中,保存参数代码主要为:

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# In the demo this call lives inside train_net, where ckpoint_cb is the parameter that receives ckpoint
model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)

在模型训练过程中,自动保存参数,不仅会保存网络参数,同时也会保存优化器参数

========================================================================

手动保存参数:

demo:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5 on MNIST: save parameters manually with ``save_checkpoint``.

The network parameters and the optimizer parameters are written to two
separate checkpoint files. Training is intentionally skipped.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import save_checkpoint

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration.
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train the model.

    Args:
        args: Parsed command-line arguments (not used inside this function).
        model: The ``Model`` to train.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Dataset root; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback registered for training.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    # Load the training dataset.
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate the model on the "test" sub-directory and print the metrics.

    Args:
        network: The network instance (not used inside this function).
        model: The ``Model`` to evaluate.
        data_path: Dataset root; the "test" sub-directory is loaded.
    """
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# Training and evaluation are deliberately disabled in this demo.
#train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
#test_net(net, model, mnist_path)

# Save the network and the optimizer to two separate checkpoint files.
save_checkpoint(net, './net_parameters.ckpt')
save_checkpoint(net_opt, './net_opt_parameters.ckpt')

主要代码:

from mindspore.train.serialization import save_checkpoint

save_checkpoint(net, './net_parameters.ckpt')
save_checkpoint(net_opt, './net_opt_parameters.ckpt')


其中,
save_checkpoint(net, './net_parameters.ckpt')            # saves the parameters of the network `net` to the file net_parameters.ckpt

save_checkpoint(net_opt, './net_opt_parameters.ckpt')    # saves the parameters of the optimizer `net_opt` to the file net_opt_parameters.ckpt




可以看到,上面的操作是把网络参数和优化器参数分别保存为了两个文件。

==============================================================================

当然,我们也可以把网络参数和优化器参数保存到一个文件里面, 如下:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5 on MNIST: save network and optimizer parameters into one file.

``save_checkpoint`` is given a list of ``{'name': ..., 'data': ...}`` dicts
built from ``trainable_params()`` of the network and of the optimizer.
Training is intentionally skipped.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import save_checkpoint

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration.
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train the model.

    Args:
        args: Parsed command-line arguments (not used inside this function).
        model: The ``Model`` to train.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Dataset root; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback registered for training.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    # Load the training dataset.
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate the model on the "test" sub-directory and print the metrics.

    Args:
        network: The network instance (not used inside this function).
        model: The ``Model`` to evaluate.
        data_path: Dataset root; the "test" sub-directory is loaded.
    """
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# Training and evaluation are deliberately disabled in this demo.
#train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
#test_net(net, model, mnist_path)

# Build one list of {'name', 'data'} entries: network parameters first,
# then the optimizer's trainable parameters.
param_list = [{'name': param.name, 'data': param} for param in net.trainable_params()]
param_list += [{'name': param.name, 'data': param} for param in net_opt.trainable_params()]
save_checkpoint(param_list, './parameters.ckpt')

其主要思想就是:传给 save_checkpoint 的不一定是 nn.Cell,也可以是一个 list。

list 里面存的是每一个参数对应的字典:字典的 'name' 键对应参数的名称,'data' 键对应参数本身。

按照这个思想,手动保存参数还可以写成下面形式:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5 on MNIST: save network and optimizer parameters into one file.

``save_checkpoint`` is given a list of ``{'name': ..., 'data': ...}`` dicts
built from ``parameters_and_names()`` of the network and of the optimizer.
Training is intentionally skipped.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import save_checkpoint

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration.
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train the model.

    Args:
        args: Parsed command-line arguments (not used inside this function).
        model: The ``Model`` to train.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Dataset root; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback registered for training.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    # Load the training dataset.
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate the model on the "test" sub-directory and print the metrics.

    Args:
        network: The network instance (not used inside this function).
        model: The ``Model`` to evaluate.
        data_path: Dataset root; the "test" sub-directory is loaded.
    """
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# Training and evaluation are deliberately disabled in this demo.
#train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
#test_net(net, model, mnist_path)

# Build one list of {'name', 'data'} entries: network parameters first,
# then the optimizer's parameters, using the names yielded by
# parameters_and_names().
param_list = [{'name': name, 'data': param} for name, param in net.parameters_and_names()]
param_list += [{'name': name, 'data': param} for name, param in net_opt.parameters_and_names()]
save_checkpoint(param_list, './parameters.ckpt')


===============================================================================================

纠正一个问题:

前面我们讨论的时候都是认为优化器中参数是不包含网络参数的,但是实际中优化器的参数是包括网络参数的,给出代码:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5: list the parameters held by the network and by the optimizer.

Prints each parameter's name and ``requires_grad`` flag, plus the value of
``conv1.weight``, first for the network and then for the optimizer, so the
two listings can be compared.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context, Tensor

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# List the parameters registered on the network.
for name, para in net.parameters_and_names():
    print(name)
    print(para.requires_grad)
    if name == 'conv1.weight':
        print(Tensor(para))

print("*"*40)

# List the parameters registered on the optimizer.
for name, para in net_opt.parameters_and_names():
    print(name)
    print(para.requires_grad)
    if name == 'conv1.weight':
        print(Tensor(para))

运行结果:

conv1.weight
True
[[[[-0.00123636 -0.00018152  0.00089368 -0.00056078 -0.02275192]
   [ 0.00409119  0.0010135  -0.00466038 -0.00555031 -0.00918423]
   [-0.00442939  0.00621323  0.00214683 -0.00155054 -0.00546987]
   [ 0.01697116 -0.00559946  0.00357803  0.03168541  0.00407573]
   [-0.00518463 -0.01200203  0.01070325 -0.02007808  0.00738484]]]

 [[[ 0.00341292  0.0079666  -0.00499046 -0.00656943 -0.00331597]
   [-0.01387733  0.00665028 -0.01610895 -0.00282408  0.0092861 ]
   [-0.01939811 -0.01994145  0.01014557  0.00459681  0.00120816]
   [-0.00354739  0.00169169  0.00359304  0.00019773  0.00124371]
   [-0.0075929  -0.02099637  0.01632461 -0.02093766  0.00231244]]]

 [[[ 0.0088163  -0.01221289 -0.01604474 -0.00574877 -0.00278494]
   [ 0.0068464  -0.01448571 -0.00408135 -0.00037711  0.01360335]
   [ 0.00826573  0.0063943  -0.00635501 -0.01091845 -0.01706182]
   [-0.01376995  0.00267098 -0.01873252 -0.00560728 -0.0133691 ]
   [ 0.00562847  0.0048407   0.01391821  0.00568764  0.01011486]]]

 [[[ 0.00413718  0.00476703  0.00920789 -0.01249459  0.01619304]
   [ 0.01443657 -0.02348764  0.0085768   0.00959142 -0.00631981]
   [-0.00826734  0.00130019  0.00431718 -0.01096678  0.00586409]
   [-0.01054094  0.01216885  0.00910433 -0.00326026  0.00994863]
   [ 0.00993542  0.00768977 -0.00420083  0.00905468 -0.0049615 ]]]

 [[[-0.00463446 -0.00677943 -0.00506198 -0.00308914  0.01606419]
   [ 0.00844193 -0.00854285  0.0003332  -0.01010361  0.01140079]
   [ 0.00595709  0.00572435 -0.00393711  0.00326021 -0.00986465]
   [-0.0090545  -0.00300089  0.00010969 -0.03852516  0.00215564]
   [-0.01172458 -0.01011858  0.01508922  0.00723284  0.00269153]]]

 [[[-0.00602197 -0.00078419 -0.00048669  0.00453082  0.00515535]
   [ 0.00237266 -0.00097092 -0.00680392 -0.00715334  0.01152472]
   [-0.00824045 -0.0188182   0.00147573 -0.00263265 -0.00235698]
   [ 0.00553491 -0.01060611  0.01170796  0.00063573  0.00259822]
   [-0.00482674 -0.01767036  0.01275289  0.00904524  0.00328132]]]]
conv2.weight
True
fc1.weight
True
fc1.bias
True
fc2.weight
True
fc2.bias
True
fc3.weight
True
fc3.bias
True
****************************************
learning_rate
True
conv1.weight
True
[[[[-0.00123636 -0.00018152  0.00089368 -0.00056078 -0.02275192]
   [ 0.00409119  0.0010135  -0.00466038 -0.00555031 -0.00918423]
   [-0.00442939  0.00621323  0.00214683 -0.00155054 -0.00546987]
   [ 0.01697116 -0.00559946  0.00357803  0.03168541  0.00407573]
   [-0.00518463 -0.01200203  0.01070325 -0.02007808  0.00738484]]]

 [[[ 0.00341292  0.0079666  -0.00499046 -0.00656943 -0.00331597]
   [-0.01387733  0.00665028 -0.01610895 -0.00282408  0.0092861 ]
   [-0.01939811 -0.01994145  0.01014557  0.00459681  0.00120816]
   [-0.00354739  0.00169169  0.00359304  0.00019773  0.00124371]
   [-0.0075929  -0.02099637  0.01632461 -0.02093766  0.00231244]]]

 [[[ 0.0088163  -0.01221289 -0.01604474 -0.00574877 -0.00278494]
   [ 0.0068464  -0.01448571 -0.00408135 -0.00037711  0.01360335]
   [ 0.00826573  0.0063943  -0.00635501 -0.01091845 -0.01706182]
   [-0.01376995  0.00267098 -0.01873252 -0.00560728 -0.0133691 ]
   [ 0.00562847  0.0048407   0.01391821  0.00568764  0.01011486]]]

 [[[ 0.00413718  0.00476703  0.00920789 -0.01249459  0.01619304]
   [ 0.01443657 -0.02348764  0.0085768   0.00959142 -0.00631981]
   [-0.00826734  0.00130019  0.00431718 -0.01096678  0.00586409]
   [-0.01054094  0.01216885  0.00910433 -0.00326026  0.00994863]
   [ 0.00993542  0.00768977 -0.00420083  0.00905468 -0.0049615 ]]]

 [[[-0.00463446 -0.00677943 -0.00506198 -0.00308914  0.01606419]
   [ 0.00844193 -0.00854285  0.0003332  -0.01010361  0.01140079]
   [ 0.00595709  0.00572435 -0.00393711  0.00326021 -0.00986465]
   [-0.0090545  -0.00300089  0.00010969 -0.03852516  0.00215564]
   [-0.01172458 -0.01011858  0.01508922  0.00723284  0.00269153]]]

 [[[-0.00602197 -0.00078419 -0.00048669  0.00453082  0.00515535]
   [ 0.00237266 -0.00097092 -0.00680392 -0.00715334  0.01152472]
   [-0.00824045 -0.0188182   0.00147573 -0.00263265 -0.00235698]
   [ 0.00553491 -0.01060611  0.01170796  0.00063573  0.00259822]
   [-0.00482674 -0.01767036  0.01275289  0.00904524  0.00328132]]]]
conv2.weight
True
fc1.weight
True
fc1.bias
True
fc2.weight
True
fc2.bias
True
fc3.weight
True
fc3.bias
True
momentum
True
moments.conv1.weight
True
moments.conv2.weight
True
moments.fc1.weight
True
moments.fc1.bias
True
moments.fc2.weight
True
moments.fc2.bias
True
moments.fc3.weight
True
moments.fc3.bias
True



从运行结果中可以看到,优化器的参数中本身就包含了网络的可训练参数,因此上面手动保存参数的代码只需要保存优化器这一部分就可以了,具体如下:

#!/usr/bin/env python
# encoding: UTF-8
"""LeNet-5 on MNIST: save parameters manually through the optimizer only.

Only ``net_opt`` is passed to ``save_checkpoint``; a single checkpoint file
is written. Training is intentionally skipped.
"""

# Handling of input hyper-parameters
import os
import argparse

# Execution context setup
from mindspore import context

# Dataset preprocessing
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

# Neural network construction
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Saving model parameters during training
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import save_checkpoint

# Libraries needed for model training
from mindspore.nn import Accuracy
from mindspore.train.callback import LossMonitor
from mindspore import Model

# Remove checkpoint artifacts left in the working directory by earlier runs.
os.system('rm -f *.ckpt *.meta')

parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'])
args = parser.parse_known_args()[0]

# Configure the MindSpore execution context.
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)


def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed MNIST dataset pipeline.

    Args:
        data_path: Directory passed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch (incomplete batches are dropped).
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count used by every ``map`` operation.

    Returns:
        The dataset after map, shuffle, batch and repeat have been applied.
    """
    # Define the dataset.
    mnist_ds = ds.MnistDataset(data_path)
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # Together these two constants compute (x - 0.1307) / 0.3081.
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Define the map operations.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operations to the dataset columns.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image", num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image", num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)

    return mnist_ds


class LeNet5(nn.Cell):
    """LeNet network structure."""

    def __init__(self, num_class=10, num_channel=1):
        """Create the layers.

        Args:
            num_class: Output size of the last dense layer.
            num_channel: Number of input channels of the first convolution.
        """
        super(LeNet5, self).__init__()
        # Define the operators the network needs.
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Forward pass: two conv/relu/pool stages followed by three dense layers."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x


# Instantiate the network.
net = LeNet5()
# Define the loss function.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Define the optimizer.
net_opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# Checkpoint configuration:
# save the model parameters every 125 steps and keep at most 15 files.
config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=15)
# Create the checkpoint callback from that configuration.
ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)


def train_net(args, model, epoch_size, data_path, repeat_size, ckpoint_cb, sink_mode):
    """Train the model.

    Args:
        args: Parsed command-line arguments (not used inside this function).
        model: The ``Model`` to train.
        epoch_size: Number of epochs passed to ``model.train``.
        data_path: Dataset root; the "train" sub-directory is loaded.
        repeat_size: Repeat count forwarded to ``create_dataset``.
        ckpoint_cb: Checkpoint callback registered for training.
        sink_mode: Value forwarded as ``dataset_sink_mode``.
    """
    # Load the training dataset.
    ds_train = create_dataset(os.path.join(data_path, "train"), 32, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(125)], dataset_sink_mode=sink_mode)


def test_net(network, model, data_path):
    """Evaluate the model on the "test" sub-directory and print the metrics.

    Args:
        network: The network instance (not used inside this function).
        model: The ``Model`` to evaluate.
        data_path: Dataset root; the "test" sub-directory is loaded.
    """
    ds_eval = create_dataset(os.path.join(data_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("{}".format(acc))


mnist_path = "./datasets/MNIST_Data"
train_epoch = 1
dataset_size = 1
model = Model(net, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# Training and evaluation are deliberately disabled in this demo.
#train_net(args, model, train_epoch, mnist_path, dataset_size, ckpoint, False)
#test_net(net, model, mnist_path)

# Save through the optimizer only.
save_checkpoint(net_opt, './net_opt_parameters.ckpt')

而,./net_opt_parameters.ckpt   文件中就已经包含了优化器的参数以及网络中的可训练参数。

所以,手动保存参数的最终代码形式就是:

save_checkpoint(net_opt, './net_opt_parameters.ckpt')


在计算框架MindSpore中手动保存参数变量(Parameter 变量)—— from mindspore.train.serialization import save_checkpoint的更多相关文章

  1. struts2:JSP页面及Action中获取HTTP参数(parameter)的几种方式

    本文演示了JSP中获取HTTP参数的几种方式,还有action中获取HTTP参数的几种方式. 1. 创建JSP页面(testParam.jsp) <%@ page language=" ...

  2. c++中的const参数,const变量,const指针,const对象,以及const成员函数

    const 是constant 的缩写,“恒定不变”的意思.被const 修饰的东西都受到强制保护,可以预防意外的变动,能提高程序的健壮性.所以很多C++程序设计书籍建议:“Use const whe ...

  3. 【Mysql】了解Mysql中的启动参数和系统变量

    一.启动参数 在程序启动时指定的设置项也称之为启动选项(startup options),这些选项控制着程序启动后的行为. 1)在命令行上使用选项 启动服务器程序的命令行后边指定启动选项的通用格式就是 ...

  4. 解决 Flask 项目无法用 .env 文件中解析的参数设置环境变量的错误

    在 Windows 上启动 Flask 项目时,工作目录有 UTF-8 编码的 .env 文件,里面配置的环境变量在 Python2 中识别为 Unicode 类型,导致下述错误: * Serving ...

  5. BS中保存参数

    开发中经常需要将值存起来,当点击某一项时以便知道点击了哪一项. 一:应用JS页面跳转(牛腩中讲到) HTML: <td class="txt c"><a href ...

  6. C语言:根据以下公式计算s,s=1+1/(1+2)+1/(1+2+3)+...+1/(1+2+3+...+n) -在形参s所指字符串中寻找与参数c相同的字符,并在其后插入一个与之相同的字符,

    //根据一下公式计算s,并将计算结果作为函数返回值,n通过形参传入.s=1+1/(1+2)+1/(1+2+3)+...+1/(1+2+3+...+n) #include <stdio.h> ...

  7. spark2.4.5计算框架中各模块的常用实例

    本项目是使用scala语言给出了spark2.4.5计算框架中各模块的常用实例. 温馨提醒:spark的版本与scala的版本号有严格的对应关系,安装请注意. Spark Core RDD以及Pair ...

  8. .NET框架- in ,out, ref , paras使用的代码总结 C#中in,out,ref的作用 C#需知--长度可变参数--Params C#中的 具名参数 和 可选参数 DEMO

    C#.net 提供的4个关键字,in,out,ref,paras开发中会经常用到,那么它们如何使用呢? 又有什么区别? 1 in in只用在委托和接口中: 例子: 1 2 3 4 5 6 7 8 9 ...

  9. 带你学习MindSpore中算子使用方法

    摘要:本文分享下MindSpore中算子的使用和遇到问题时的解决方法. 本文分享自华为云社区<[MindSpore易点通]算子使用问题与解决方法>,作者:chengxiaoli. 简介 算 ...

  10. 如何选取一个神经网络中的超参数hyper-parameters

    1.什么是超参数 所谓超参数,就是机器学习模型里面的框架参数.比如聚类方法里面类的个数,或者话题模型里面话题的个数等等,都称为超参数.它们跟训练过程中学习的参数(权重)是不一样的,通常是手工设定的,经 ...

随机推荐

  1. email邮件(带附件,模拟文件上传,跨服务器)发送核心代码 Couldn't connect to host, port: smtp.163.com, 25; timeout -1;

    邮件(带附件,模拟文件上传,跨服务器)发送核心代码1.测试邮件发送附件接口 /** * 测试邮件发送附件 * @param multipartFile * @return */ @RequestMap ...

  2. Linux命令行配置RIAD5

    环境准备: 系统: redhat6.9 硬盘:300G*3 SAS MegaCli是一款管理维护硬件RAID软件,可以用来查看raid信息等 1. 安装MegaCli rpm -ivh Lib_Uti ...

  3. ARM GIC 系列文章学习(转)

    原文来自:骏的世界 ARM GIC(一) cortex-A 处理器中断简介 对于ARM的处理器,中断给处理器提供了触觉,使处理器能够感知到外界的变化,从而实时的处理.本系列博文,是以ARM corte ...

  4. 79元国产ARM+DSP平台FFT实测分享

    T113-i国产ARM+DSP架构介绍 创龙科技SOM-TLT113是一款基于国产全志T113-i双核ARM Cortex-A7 +  HiFi4 DSP + 玄铁C906 RISC-V异构多核处理器 ...

  5. [golang]在Gin框架中使用JWT鉴权

    什么是JWT JWT,全称 JSON Web Token,是一种开放标准(RFC 7519),用于安全地在双方之间传递信息.尤其适用于身份验证和授权场景.JWT 的设计允许信息在各方之间安全地. co ...

  6. ubuntu 使用natapp配置内网穿透

    前言 在自己的服务器上起了服务,但由于域名还没申请下来,无法使用域名测试微信公众号接口,辛亏看到了这个博客:Natapp内网穿透服务工具.跟随这篇博客,我搭建了自己的内网穿透服务,现在记录如下. 过程 ...

  7. MyBatis学习篇

    什么是MyBatis (1)Mybatis是一个半ORM(对象关系映射)框架,它内部封装了JDBC,开发时只需要关注SQL语句本身,不需要花费精力去处理加载驱动.创建连接.创建statement等繁杂 ...

  8. 全新发布!桌面端效率工具RunFlow

    RunFlow是一款跨平台的生产力工具,可以启动应用程序和搜索文件等,类似于Windows平台的Wox和PowerToys,同样也类似于Mac平台的Alfred和Raycast.但我们并不与这些工具相 ...

  9. [oeasy]python0051_ 转义_escape_字符_character_单引号_双引号_反引号_ 退格键

    转义字符 回忆上次内容 上次研究的是进制转化 10进制可以转化为其他形式 bin oct hex 其他进制也可以转化为10进制 int 可以设置base来决定转为多少进制 回忆一下 我们为什么会有八进 ...

  10. AT_arc149_a 题解

    洛谷链接&Atcoder 链接 本篇题解为此题较简单做法及较少码量,并且码风优良,请放心阅读. 题目简述 求满足以下条件的小于 \(10 ^ n\) 数最大是多少? 每一位数字均相同: 是 \ ...