在之前,我们实现了使用torch自带的层对fashion_mnist数据集进行分类。这次,我们加入一个自己实现的block,实现一个四层的多层感知机进行softmax分类,作为对“自定义块”的代码实现的一个练习。

我们设计的多层感知机是这样的:输入维度为784,在展平层过后,第一层为全连接层,输入输出维度分别为784,256;第二层为全连接层,输入输出维度分别为256,128;第三层为全连接层,输入输出维度分别为128,64;第四层为全连接层(输出层),输入输出维度分别为64,10.代码如下:

  1. import torch
  2. from d2l import torch as d2l
  3. from torch import nn
  4. from torch.nn import functional as F
  5.  
  6. batch_size = 256
  7. train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
  8. num_inputs = 784
  9. num_outputs = 10
  10.  
  11. #输入层784; 隐藏层一784,256;隐藏层二256,128; 隐藏层三128,64; 输出层64,10
  12. #我们用自定义Module实现隐藏层二、隐藏层三。
  13. class practice_Module(nn.Module):
  14. def __init__(self):
  15. super().__init__()
  16. self.lin1 = nn.Linear(256,128)
  17. self.lin2 = nn.Linear(128,64)
  18. nn.init.normal_(self.lin1.weight,std=0.01)
  19. nn.init.normal_(self.lin2.weight,std=0.01)
  20. def forward(self,X):
  21. X = self.lin1(X)
  22. X = F.relu(X)
  23. X = self.lin2(X)
  24. X = F.relu(X)
  25. return X
  26.  
  27. manual_block = practice_Module()
  28. net = nn.Sequential(nn.Flatten(),
  29. nn.Linear(784,256),
  30. nn.ReLU(),
  31. nn.Dropout(0.2),
  32. manual_block,
  33. nn.Dropout(0.3),
  34. nn.Linear(64,10)
  35. )
  36.  
  37. def init_weight(m):
  38. if m == nn.Linear:
  39. nn.init.normal_(m.weight,std=0.01)
  40. return
  41. net.apply(init_weight)
  42.  
  43. loss = torch.nn.CrossEntropyLoss(reduction='none')
  44. trainer = torch.optim.SGD(net.parameters(),lr=0.1)
  45. num_epochs = 20
  46. d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)

首先在我们自定义的模块中,初始化函数__init__中定义我们需要的两个层lin1和lin2.上面的代码在抽象的类practice_Module中的初始化函数__init__中进行了参数初始化,也就是说默认情况下用这个类创建的所有对象都会进行这样的默认初始化。

当然,也可以按我们的需要对具体的模块对象进行参数初始化。

然后在forward函数中定义这个模块进行的操作,即先让数据经过线性层lin1,激活,再经过线性层lin2,激活,然后输出。return X语句,return的X值作为输出,就会作为nn.Sequential中的下一层输入。

注意:这里面的前向传播函数名必须是forward,而不能是其他的,改成其他的就会报错:

  1. Module [practice_Module] is missing the required "forward" function

这也是为什么practice_Module类的实例在nn.Sequential中可以自动计算的原因,是因为系统会自动找到该实例中的方法forward并执行。

下面的语句初始化了一个practice_Module类的实例。可以这样理解:practice_Module是一个抽象的网络结构,而manual_block这个实例才是一个具体的我们需要的模型(包含具体参数)。

可以用如下代码对自定义的模块实例进行初始化。这里,nn.init.normal_()可以对nn.Module的子类的某一具体的层进行参数初始化。

在nn.Sequential中加入我们自定义的模块是非常简单的:

init_weight()函数对torch中定义好了的层进行参数初始化:

加入自定义块对fashion_mnist数据集进行softmax分类的更多相关文章

  1. tensorflow 离线使用 fashion_mnist 数据集

    在tensflow中加载 fashion_mnist 数据集时,由于网络原因.可能会长时间加载不到或报错 此时我们可以通过离线的方式加载 1.首先下载数据集:fashion_mnist (下载后解压) ...

  2. 学习笔记TF010:softmax分类

    回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...

  3. 从零和使用mxnet实现softmax分类

    1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...

  4. 器学习算法(六)基于天气数据集的XGBoost分类预测

    1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...

  5. tensorflow 使用 5 mnist 数据集, softmax 函数

    用于分类  softmax 函数 手写数据识别:

  6. softmax分类算法原理(用python实现)

    逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...

  7. 动手学深度学习7-从零开始完成softmax分类

    获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...

  8. 利用keras自带路透社数据集进行多分类训练

    import numpy as np from keras.datasets import reuters from keras import layers from keras import mod ...

  9. 用MATLAB的Classficiation Learner工具箱对12个数据集进行各种分类与验证

    准备材料 以所有的特征集作为variable进行像Bayes吖.SVM吖.决策树吖......分类.同时对数据进行预处理,选出相关度高的特征子集作为新的一组data进行分类(预处理的代码不必放出来). ...

  10. 机器学习-MNIST数据集使用二分类

    一.二分类训练MNIST数据集练习 %matplotlib inlineimport matplotlibimport numpy as npimport matplotlib.pyplot as p ...

随机推荐

  1. Springboot K8s Job 一次性任务 如何禁用端口监听

    问题:SpringBoot一次性任务执行时,也会默认监听服务端口,当使用k8s job运行时,可能多个pod执行存在端口冲突 解决办法:命令行禁用SpringBoot一次性任务启动时端口占用 java ...

  2. 数据处理——IF函数求同时满足多个条件 多个条件满足一个以上

    以满足两个条件为例,满足多个条件类似 以如下案例为例进行说明: 一.IF公式同时满足多个条件 此例也可使用函数的嵌套,对于函数使用掌握不牢的新手,嵌套使用会有些困难,以下方法针对刚入门学习参考 1.利 ...

  3. Git进阶命令-reset

    之前有关Git,写过一片文章: Git五个常见问题及解决方法 一.reset命令使用场景 有时候我们提交了一些错误的或者不完善的代码,需要回退到之前的某个稳定的版本,面对这种情况有两种解决方法: 解决 ...

  4. C++ 字面值的前缀和后缀

    一般字符字面值用前缀,数字字面值用后缀: --C++ Primer第五版2.1.3

  5. 记一次由虚假唤醒产生的bug

    记一次由虚假唤醒产生的bug 用int a代表产品数量最少0最多10,有两个生产者,三个消费者,用多线程和条件变量模拟生产消费过程: #include <sys/types.h> #inc ...

  6. 使用Servlet实现文件下载

    一位朋友最近在学习JavaWeb开发,开始学习文件下载操作,他自己尝试着去网上看一些教程,总的来说也不是太了解,就让我和他说说,如何实现文件下载功能.我和他说了一下大致的思路,主要分为前端和后端两部分 ...

  7. Salesforce LWC学习(四十三) lwc 零基础学习路径的视频已上传B站

    本篇参考:https://www.bilibili.com/video/BV1QM411G7pN/ 还记得salesforce零基础学习(一百二十五)零基础学习SF路径 中描述的那样,预计今年年底以前 ...

  8. Redis无法向磁盘写入RBD数据

    2020-12-09 11:52:25|21965|ERROR|storage/DRedisAsyncCallback.cpp:394[cbIncrby]Cmd 'INCRBY' failed, ke ...

  9. KingbaseESV8R6识别IO使用率过高

    前言 数据库正常运行离不开I/O的使用,在操作系统上,I/O又离不开存储的性能及使用方式,我们可以在存储层利用raid条带化技术使IOPS达到最佳性能. 本篇文章有助于确认数据库I/O使用率过高的原因 ...

  10. KingbaseES V8R6备份恢复系列之 -- system-Id不匹配备份故障

    ​ KingbaseES V8R6备份恢复案例之---system-Id不匹配备份故障 案例说明: 在KingbaseES V8R6执行备份时,在sys_log日志中出现system-id不一致的故障 ...