https://mxnet.incubator.apache.org/tutorials/basic/module.html

  1. import logging
  2. import random
  3. logging.getLogger().setLevel(logging.INFO)
  4.  
  5. import mxnet as mx
  6. import numpy as np
  7.  
  8. mx.random.seed(1234)
  9. np.random.seed(1234)
  10. random.seed(1234)
  11.  
  12. # 准备数据
  13. fname = mx.test_utils.download('https://s3.us-east-2.amazonaws.com/mxnet-public/letter_recognition/letter-recognition.data')
  14. data = np.genfromtxt(fname=fname,delimiter=',')[:,1:]
  15. label = np.array([ord(l.split(',')[0])-ord('A') for l in open(fname, 'r')])
  16.  
  17. batch_size = 32
  18. ntrain = int(data.shape[0]*0.8)
  19.  
  20. train_iter = mx.io.NDArrayIter(data[:ntrain,:],label[:ntrain],batch_size,shuffle=True)
  21. val_iter = mx.io.NDArrayIter(data[ntrain:,:],label[ntrain:],batch_size)
  22.  
  23. # 定义网络
  24. net = mx.sym.Variable('data')
  25. net = mx.sym.FullyConnected(net, name='fc1', num_hidden=64)
  26. net = mx.sym.Activation(net, name='relu1', act_type="relu")
  27. net = mx.sym.FullyConnected(net, name='fc2', num_hidden=26)
  28. net = mx.sym.SoftmaxOutput(net, name='softmax')
  29. mx.viz.plot_network(net, node_attrs={"shape":"oval","fixedsize":"false"})
  30.  
  31. # # 创建模块
  32. mod = mx.mod.Module(symbol=net,
  33. context=mx.cpu(),
  34. data_names=['data'],
  35. label_names=['softmax_label'])
  36.  
  37. # # 中层接口
  38. # # 训练模型
  39. # mod.bind(data_shapes=train_iter.provide_data,label_shapes=train_iter.provide_label)
  40. # mod.init_params(initializer=mx.init.Uniform(scale=.1))
  41. # mod.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',0.1),))
  42. # metric = mx.metric.create('acc')
  43. #
  44. # for epoch in range(100):
  45. # train_iter.reset()
  46. # metric.reset()
  47. # for batch in train_iter:
  48. # mod.forward(batch,is_train=True)
  49. # mod.update_metric(metric,batch.label)
  50. # mod.backward()
  51. # mod.update()
  52. # print('Epoch %d,Training %s' % (epoch,metric.get()))
  53.  
  54. # fit 高层接口
  55. train_iter.reset()
  56. mod = mx.mod.Module(symbol=net,
  57. context=mx.cpu(),
  58. data_names=['data'],
  59. label_names=['softmax_label'])
  60.  
  61. mod.fit(train_iter,
  62. eval_data=val_iter,
  63. optimizer='sgd',
  64. optimizer_params={'learning_rate':0.1},
  65. eval_metric='acc',
  66. num_epoch=10)
  67.  
  68. # 预测和评估
  69. y = mod.predict(val_iter)
  70. assert y.shape == (4000,26)
  71.  
  72. # 评分
  73. score = mod.score(val_iter,['acc'])
  74. print("Accuracy score is %f"%(score[0][1]))
  75. assert score[0][1] > 0.76, "Achieved accuracy (%f) is less than expected (0.76)" % score[0][1]
  76.  
  77. # 保存和加载
  78. # 构造一个回调函数保存检查点
  79. model_prefix = 'mx_mlp'
  80. checkpoint = mx.callback.do_checkpoint(model_prefix)
  81.  
  82. mod = mx.mod.Module(symbol=net)
  83. mod.fit(train_iter,num_epoch=5,epoch_end_callback=checkpoint)
  84.  
  85. sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
  86. assert sym.tojson() == net.tojson()
  87.  
  88. # assign the loaded parameters to the module
  89. mod.set_params(arg_params, aux_params)
  90.  
  91. mod = mx.mod.Module(symbol=sym)
  92. mod.fit(train_iter,
  93. num_epoch=21,
  94. arg_params=arg_params,
  95. aux_params=aux_params,
  96. begin_epoch=3)
  97. assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

mxnet 神经网络训练和预测的更多相关文章

  1. 吴裕雄 python 神经网络——TensorFlow 使用卷积神经网络训练和预测MNIST手写数据集

    import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_dat ...

  2. 利用Matlab神经网络计算包预测近四天除湖北外新增确诊人数:拐点已现

    数据来源: 国家卫健委 已经7连降咯! 1.20-2.10图示(更新中): 神经网络训练并预测数据: clear %除湖北以外全国新增确诊病例数 2020.1.20-2.9 num=[5,44,62, ...

  3. ResNet网络的训练和预测

    ResNet网络的训练和预测 简介 Introduction 图像分类与CNN 图像分类 是指将图像信息中所反映的不同特征,把不同类别的目标区分开来的图像处理方法,是计算机视觉中其他任务,比如目标检测 ...

  4. 神经网络训练中的Tricks之高效BP(反向传播算法)

    神经网络训练中的Tricks之高效BP(反向传播算法) 神经网络训练中的Tricks之高效BP(反向传播算法) zouxy09@qq.com http://blog.csdn.net/zouxy09 ...

  5. mxnet的训练过程——从python到C++

    mxnet的训练过程--从python到C++ mxnet(github-mxnet)的python接口相当完善,我们可以完全不看C++的代码就能直接训练模型,如果我们要学习它的C++的代码,从pyt ...

  6. 神经网络训练tricks

    神经网络构建好,训练不出好的效果怎么办?明明说好的拟合任意函数(一般连续)(为什么?可以参考http://neuralnetworksanddeeplearning.com/),说好的足够多的数据(h ...

  7. tesorflow - create neural network+结果可视化+加速神经网络训练+Optimizer+TensorFlow

    以下仅为了自己方便查看,绝大部分参考来源:莫烦Python,建议去看原博客 一.添加层 def add_layer() 定义 add_layer()函数 在 Tensorflow 里定义一个添加层的函 ...

  8. TensorFlow实战第三课(可视化、加速神经网络训练)

    matplotlib可视化 构件图形 用散点图描述真实数据之间的关系(plt.ion()用于连续显示) # plot the real data fig = plt.figure() ax = fig ...

  9. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

随机推荐

  1. webpack+react

    一直提醒我这个.闹心最后发现是在不同的js 里引入组件的时候 import React from 'react'; 和 import React from 'React'; 就是大小写的问题. 解决办 ...

  2. [转]Add Bootstrap Glyphicon to Input Box

    本文转自:http://stackoverflow.com/questions/18838964/add-bootstrap-glyphicon-to-input-box How can I add ...

  3. Mysql插入Emoji表情出错

    Caused by: java.sql.SQLException: Incorrect at com.mysql.jdbc.SQLError.createSQLException(SQLError.j ...

  4. MongoDb 学习笔记(一) --- MongoDb 数据库介绍、安装、使用

    1.数据库和文件的主要区别 . 数据库有数据库表.行和列的概念,让我们存储操作数据更方便 . 数据库提供了非常方便的接口,可以让 nodejs.php java .net 很方便的实现增加修改删除功能 ...

  5. linux创建日期文件名

    linux创建文件名添加当前系统日期时间的方法 使用`date +%y%m%d` Example: mkdir `date +%y%m%d` tar cfvz /tmp/bak.`date +%y%m ...

  6. spring AOP Capability and Goals(面向方面编程功能和目标归纳)

    原官方文档链接: https://docs.spring.io/spring/docs/5.1.6.RELEASE/spring-framework-reference/core.html#aop-i ...

  7. CSS3自定义loading效果

    效果: 使用CSS3完成loading的制作 css样式: <style type="text/css"> .mask { position: fixed; left: ...

  8. FLASK日志记录

    from flask import Flask from flask_restful import Resource, Api import logging app = Flask(__name__) ...

  9. DOM Tree

    DOM Tree   什么是DOM树:网页的所有内容在内存当中,其实是以树形结构存储的 何时创建:当浏览器,读取html中内容的时候,会马上开始创建DOM树. 如何创建: 1.读到HTML的时候还没有 ...

  10. cf1043C. Smallest Word(贪心)

    题意 题目链接 Sol 这题打cf的时候真的是脑残,自己造了个abcdad的数据开心的玩了半天一脸懵逼...最后还好ycr大佬给了个思路不然就凉透了... 首先不难看出我们最后一定可以把字符串弄成\( ...