训练一个神经网络往往只需要简单的几步:

  1. 准备训练数据
  2. 初始化模型的参数
  3. 模型向往计算与向后计算
  4. 更新模型参数
  5. 设置相关的checkpoint

如果上述的每个步骤都需要我们写Python的代码去一步步实现,未免显的繁琐,好在MXNet提供了Module模块来解决这个问题,Module把训练和推理中一些常用到的步骤代码进行了封装。对于一定已经用Symbol定义好的神经网络,我们可以很容易的使用Module提供的一些高层次接口或一些中间层次的接口来让整个训练或推理容易操作起来。

下面我们将通过在UCI letter recognition数据集上训练一个多层感知机来说明Module模块的用法。

第一步 加载一个数据集

我们先下载一个数据集,然后按80:20的比例划分训练集与测试集。我们通过MXNet的IO模块提供的数据迭代器每次返回一个batch size =32的训练样本

import logging
logging.getLogger().setLevel(logging.INFO)
import mxnet as mx
import numpy as np # 数据以文本形式保存,每行一个样本,每一行数据之间用','分割,每一个字符为label
fname = mx.test_utils.download('http://archive.ics.uci.edu/ml/machine-learning-databases/letter-recognition/letter-recognition.data')
data = np.genfromtxt(fname, delimiter=',')[:,1:]
label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')]) batch_size = 32
ntrain = int(data.shape[0]*0.8)
train_iter = mx.io.NDArrayIter(data[:ntrain, :], label[:ntrain], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data[ntrain:, :], label[ntrain:], batch_size)

第二步 定义一个network

net = mx.sym.var('data')
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type='relu')
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=26)
net = mx.sym.SoftmaxOutput(net, name='softmax')
mx.viz.plot_network(net)

第三步 创建一个Module

我们可以通过mx.mod.Module接口创建一个Module对象,它接收下面几个参数:

  • symbol:神经网络的定义
  • context:执行运算的设备
  • data_names:网络输入数据的列表
  • label_names:网络输入标签的列表

对于我们在第二步定义的net,只有一个输入数据即data,输入标签名为softmax_label,这个是我们在使用SoftmaxOutput操作时,自动命名的。

mod = mx.mod.Module(symbol=net,
context=mx.cpu(),
data_names=['data'],
label_names=['softmax_label'])

Module的中间层次的接口

中间层次的接口主要是为了给开发者足够的灵活性,也方便排查问题。我们下面会先列出来Moduel模块有哪些常见的中间层API,然后再利用这个API来训练我们刚才定义的网络。

  • bind:绑定输入数据的形状,分配内存
  • init_params:初始化网络参数
  • init_optimizer:指定优化方法,比如sgd
  • metric.create:指定评价方法
  • forward:向前计算
  • update_metric:根据上一次的forward结果,更新评价指标
  • backward:反射传播
  • update:根据优化方法和梯度更新模型的参数
# allocate memory given the input data and label shapes
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# initialize parameters by uniform random numbers
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# use SGD with learning rate 0.1 to train
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# use accuracy as the metric
metric = mx.metric.create('acc')
# train 5 epochs, i.e. going over the data iter one pass
for epoch in range(5):
train_iter.reset()
metric.reset()
for batch in train_iter:
mod.forward(batch, is_train=True) # compute predictions
mod.update_metric(metric, batch.label) # accumulate prediction accuracy
mod.backward() # compute gradients
mod.update() # update parameters
print('Epoch %d, Training %s' % (epoch, metric.get()))

Module 高层次的API

训练

Moudle模块同时提供了高层次的API来完成训练、预测和评估。不像使用中间层次API那样繁琐,我们只需要一个接口fit就可以完成上面的步骤。

# reset train_iter to the beginning
train_iter.reset() # create a module
mod = mx.mod.Module(symbol=net,
context=mx.cpu(),
data_names=['data'],
label_names=['softmax_label']) # fit the module
mod.fit(train_iter,
eval_data=val_iter,
optimizer='sgd',
optimizer_params={'learning_rate':0.1},
eval_metric='acc',
num_epoch=8)

预测和评估

使用Moudle.predict可以得到数据的predict的结果。如果我们对结果不关心,我们可以使用score接口直接计算验证数据集的准确率。

y = mod.predict(val_iter)
score = mod.score(val_iter, ['acc'])
print("Accuracy score is %f" % (score[0][1]))

上面的代码中我们使用了acc来计算准确率,我们还可以设置其他评估方法,如:top_k_acc,F1,RMSE,MSE,MAE,ce等。

训练模型的保存

我们可以通过设计一个checkpoint calback来在训练过程中每个epoch结束后保存模型的参数

# construct a callback function to save checkpoints
model_prefix = 'mx_mlp'
checkpoint = mx.callback.do_checkpoint(model_prefix) mod = mx.mod.Module(symbol=net)
mod.fit(train_iter, num_epoch=5, epoch_end_callback=checkpoint)

使用load_checkpoint来加载已经保存的模型参数,随后我们可以把这些参数加载到Moudle中

sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
# assign the loaded parameters to the module
mod.set_params(arg_params, aux_params)

我们也可以不使用set_params,而是直接在fit接口中指定已经保存的checkpoint的参数,这些保存的参数会替代fit原本的参数初始化。

mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter,
num_epoch=21,
arg_params=arg_params,
aux_params=aux_params,
begin_epoch=3)

利用Module模块把构建的神经网络跑起来的更多相关文章

  1. 利用iisnode模块,让你的Node.js应用跑在Windows系统IIS中

    最近比较喜欢用Node.js做一些简单的应用,一直想要部署到生产环境中,但是手上只有一台windows server 2008服务器,并且已经开启了IIS服务,运行了很多.Net开发的网站,80端口已 ...

  2. Azure DevOps(二)利用Azure DevOps Pipeline 构建基础设施资源

    一,引言 上一篇文章记录了利用 Azure DevOps 跨云进行构建 Docker images,并且将构建好的 Docker Images 推送到 AWS 的 ECR 中.今天我们继续讲解 Azu ...

  3. 利用 Rational ClearCase ClearMake 构建高性能的企业级构建环境

    转载地址:http://www.ibm.com/developerworks/cn/rational/r-cn-clearmakebuild/ 构建管理是 IBM® Rational® ClearCa ...

  4. TensorFlow构建卷积神经网络/模型保存与加载/正则化

    TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...

  5. maven多模块项目构建

    描述 一个大的企业级项目通常跨越了数十万行代码,牵涉了数十或数百软件人员的努力.如果开发者在同一个项目下开   发,那么项目的管理.构建将会变得很难控制.因此设计人员会将项目划分为多个模块,多个模块独 ...

  6. docker微服务部署之:五、利用DockerMaven插件自动构建镜像

    docker微服务部署之:四.安装docker.docker中安装mysql和jdk1.8.手动构建镜像.部署项目 在上一篇文章中,我们是手动构建镜像,即: 4.1.2.5.1.2.6.1.2中的将d ...

  7. 利用angular4和nodejs-express构建一个简单的网站(五)—用户的注册和登录-HttpClient

    上一节简单介绍了一下利用angular构建的主路由模块,根据上一节的介绍,主页面加载时直接跳转到用户管理界面,下面就来介绍一下用户管理模块.启动应用后,初始界面应该是这样的: 用户管理模块(users ...

  8. nginx利用geo模块做限速白名单以及geo实现全局负载均衡的操作记录

    geo指令使用ngx_http_geo_module模块提供的.默认情况下,nginx有加载这个模块,除非人为的 --without-http_geo_module.ngx_http_geo_modu ...

  9. 利用Nutch和Tomcat构建搜索引擎

    利用Nutch和Tomcat构建搜索引擎 1.安装环境及软件版本介绍 本教程是在Linux Ubuntu 12.04 desktop i386操作系统上搭建,结合使用了Nutch-1.2和Apache ...

随机推荐

  1. CentOS7+CDH5.14.0安装全流程记录,图文详解全程实测-总目录

    CentOS7+CDH5.14.0安装全流程记录,图文详解全程实测-总目录: 0.Windows 10本机下载Xshell,以方便往Linux主机上上传大文件 1.CentOS7+CDH5.14.0安 ...

  2. vue全局API

    一.Vue.extend() 顾名思义  extend  继承,官方给出的解释是   (使用基础 Vue 构造器,创建一个“子类”.参数是一个包含组件选项的对象.) Vue构造器是指  vue是一个构 ...

  3. spring的事务操作(重点)

    这篇文章一起来回顾复习下spring的事务操作.事务是spring的重点, 也是面试的必问知识点之一. 说来这次面试期间,也问到了我,由于平时用到的比较少,也没有关注过这一块的东西,所以回答的不是特别 ...

  4. python词频统计及其效能分析

    1) 博客开头给出自己的基本信息,格式建议如下: 学号2017****7128 姓名:肖文秀 词频统计及其效能分析仓库:https://gitee.com/aichenxi/word_frequenc ...

  5. 设计模式之jdk动态代理模式、责任链模式-java实现

    设计模式之JDK动态代理模式.责任链模式 需求场景 当我们的代码中的类随着业务量的增大而不断增大仿佛没有尽头时,我们可以考虑使用动态代理设计模式,代理类的代码量被固定下来,不会随着业务量的增大而增大. ...

  6. vba判断文件是否存在的两种方法(转)

    方法1. 用VBA自带的dir()判断,代码如下: 在 Microsoft Windows 中, Dir 支持多字符 (*)和单字符 (?) 的通配符来指定多重文件 Function IsFileEx ...

  7. EasyPR源码剖析(6):车牌判断之LBP特征

    一.LBP特征 LBP指局部二值模式,英文全称:Local Binary Pattern,是一种用来描述图像局部特征的算子,LBP特征具有灰度不变性和旋转不变性等显著优点. 原始的LBP算子定义在像素 ...

  8. java上传文件常见几种方式

    1.ServletFileUpload 表单提交中当提交数据类型是multipare/form-data类型的时候,如果我们用servlet去做处理的话,该http请求就会被servlet容器,包装成 ...

  9. MySQL优化:使用show status查看MySQL服务器状态信息

    在网站开发过程中,有些时候我们需要了解MySQL的服务器状态信息,譬如当前MySQL启动后的运行时间,当前MySQL的客户端会话连接数,当前MySQL服务器执行的慢查询数,当前MySQL执行了多少SE ...

  10. Django之发送邮件

    Django的发送邮件是基于django的一个组件进行操作的,EmailMessage 基本使用方法: def send_html_mail(subject, html_content, from_a ...