加入自定义块对fashion_mnist数据集进行softmax分类
在之前,我们实现了使用torch自带的层对fashion_mnist数据集进行分类。这次,我们加入一个自己实现的block,实现一个四层的多层感知机进行softmax分类,作为对“自定义块”的代码实现的一个练习。
我们设计的多层感知机是这样的:输入维度为784,在展平层过后,第一层为全连接层,输入输出维度分别为784,256;第二层为全连接层,输入输出维度分别为256,128;第三层为全连接层,输入输出维度分别为128,64;第四层为全连接层(输出层),输入输出维度分别为64,10.代码如下:
import torch
from d2l import torch as d2l
from torch import nn
from torch.nn import functional as F batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
num_inputs = 784
num_outputs = 10 #输入层784; 隐藏层一784,256;隐藏层二256,128; 隐藏层三128,64; 输出层64,10
#我们用自定义Module实现隐藏层二、隐藏层三。
class practice_Module(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(256,128)
self.lin2 = nn.Linear(128,64)
nn.init.normal_(self.lin1.weight,std=0.01)
nn.init.normal_(self.lin2.weight,std=0.01)
def forward(self,X):
X = self.lin1(X)
X = F.relu(X)
X = self.lin2(X)
X = F.relu(X)
return X manual_block = practice_Module()
net = nn.Sequential(nn.Flatten(),
nn.Linear(784,256),
nn.ReLU(),
nn.Dropout(0.2),
manual_block,
nn.Dropout(0.3),
nn.Linear(64,10)
) def init_weight(m):
if m == nn.Linear:
nn.init.normal_(m.weight,std=0.01)
return
net.apply(init_weight) loss = torch.nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(),lr=0.1)
num_epochs = 20
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
首先在我们自定义的模块中,初始化函数__init__中定义我们需要的两个层lin1和lin2.上面的代码在抽象的类practice_Module中的初始化函数__init__中进行了参数初始化,也就是说默认情况下用这个类创建的所有对象都会进行这样的默认初始化。
当然,也可以按我们的需要对具体的模块对象进行参数初始化。
然后在forward函数中定义这个模块进行的操作,即先让数据经过线性层lin1,激活,再经过线性层lin2,激活,然后输出。return X语句,return的X值作为输出,就会作为nn.Sequential中的下一层输入。
注意:这里面的前向传播函数名必须是forward,而不能是其他的,改成其他的就会报错:
Module [practice_Module] is missing the required "forward" function
这也是为什么practice_Module类的实例在nn.Sequential中可以自动计算的原因,是因为系统会自动找到该实例中的方法forward并执行。
下面的语句初始化了一个practice_Module类的实例。可以这样理解:practice_Module是一个抽象的网络结构,而manual_block这个实例才是一个具体的我们需要的模型(包含具体参数)。
可以用如下代码对自定义的模块实例进行初始化。这里,nn.init.normal_()可以对nn.Module的子类的某一具体的层进行参数初始化。
在nn.Sequential中加入我们自定义的模块是非常简单的:
init_weight()函数对torch中定义好了的层进行参数初始化:
加入自定义块对fashion_mnist数据集进行softmax分类的更多相关文章
- tensorflow 离线使用 fashion_mnist 数据集
在tensflow中加载 fashion_mnist 数据集时,由于网络原因.可能会长时间加载不到或报错 此时我们可以通过离线的方式加载 1.首先下载数据集:fashion_mnist (下载后解压) ...
- 学习笔记TF010:softmax分类
回答多选项问题,使用softmax函数,对数几率回归在多个可能不同值上的推广.函数返回值是C个分量的概率向量,每个分量对应一个输出类别概率.分量为概率,C个分量和始终为1.每个样本必须属于某个输出类别 ...
- 从零和使用mxnet实现softmax分类
1.softmax从零实现 from mxnet.gluon import data as gdata from sklearn import datasets from mxnet import n ...
- 器学习算法(六)基于天气数据集的XGBoost分类预测
1.机器学习算法(六)基于天气数据集的XGBoost分类预测 1.1 XGBoost的介绍与应用 XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统.严格意义上讲XGBo ...
- tensorflow 使用 5 mnist 数据集, softmax 函数
用于分类 softmax 函数 手写数据识别:
- softmax分类算法原理(用python实现)
逻辑回归神经网络实现手写数字识别 如果更习惯看Jupyter的形式,请戳Gitthub_逻辑回归softmax神经网络实现手写数字识别.ipynb 1 - 导入模块 import numpy as n ...
- 动手学深度学习7-从零开始完成softmax分类
获取和读取数据 初始化模型参数 实现softmax运算 定义模型 定义损失函数 计算分类准确率 训练模型 小结 import torch import torchvision import numpy ...
- 利用keras自带路透社数据集进行多分类训练
import numpy as np from keras.datasets import reuters from keras import layers from keras import mod ...
- 用MATLAB的Classficiation Learner工具箱对12个数据集进行各种分类与验证
准备材料 以所有的特征集作为variable进行像Bayes吖.SVM吖.决策树吖......分类.同时对数据进行预处理,选出相关度高的特征子集作为新的一组data进行分类(预处理的代码不必放出来). ...
- 机器学习-MNIST数据集使用二分类
一.二分类训练MNIST数据集练习 %matplotlib inlineimport matplotlibimport numpy as npimport matplotlib.pyplot as p ...
随机推荐
- 2023山东省“技能兴鲁”职业技能大赛-学生组初赛wp
PWN pwn1 c++ pwn,cin 直接相当于 gets 了,程序有后门,保护基本没开,在 change 的最后一个输入点改掉返回地址为后门地址即可 from pwn import * cont ...
- [原创] KCP 源码分析(上)
KCP 协议是一种可靠的传输协议,对比 TCP 取消了累计确认(延迟 ACK).减小 RTO增长速度.选择性重传而非全部重传.通过用流量换取低时延. KCP 中最重要的两个数据结构IKCPCB和IKC ...
- 【C语言复习笔记】一些要点
[C语言复习笔记]一些要点 按学校教材复习的,整理的是我不熟悉的地方 最近C用的好少,快忘完了就赶紧整理一下(Python真好玩) 第一章 初识C语言 存储器 内存容量的大小,取决于地址总线的数量 \ ...
- 重塑元宇宙体验!3DCAT元宇宙实时云渲染解决方案来了
元宇宙作为人工智能.云计算和数字孪生等前沿技术的结合体,近年来越发受到各大企业重视. 元宇宙的应用场景层出不穷,不仅包括营销推广场景,还有品牌活动和电商销售,能有效提升品宣和商业转化效果. 元宇宙也具 ...
- View事件机制分析
目录介绍 01.Android中事件分发顺序 1.1 事件分发的对象是谁 1.2 事件分发的本质 1.3 事件在哪些对象间进行传递 1.4 事件分发过程涉及方法 1.5 Android中事件分发顺序 ...
- uni-app如何实现USB插入后自动弹出对应软件
这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 最近碰到了一个奇葩需求,要用uni-app来实现usb接入设备的时候,让软件自动弹出来,这里给出我制作的过程和参考的各种思路,希望对大家有 ...
- GIT版本控制学习博客
GIT版本控制学习博客 环境部署 下载git版本控制即可. 用户配置 (1)设置用户及地址 git config --global user.name "Username" git ...
- KingbaseES数据库配置Hikari数据源
Hikari是一个高性能的数据库连接池,它是Spring Boot 2.x中的默认数据源. 一.下载驱动 打开下面网址:选择对应平台的jdbc驱动程序. 人大金仓-成为世界卓越的数据库产品与服务提供商 ...
- KingbaseES V8R3 表加密
前言 透明加密是指将数据库page加密后写入磁盘,当需要读取对应page时进行加密读取.此过程对于用户是透明, 用户无需干预. 该文档进行数据库V8R3版本测试透明加密功能,需要说明,该版本发布时间早 ...
- 实现一个简单的echarts词云图PythonFlask
cloud.html 1 <!DOCTYPE html> 2 <html lang="en"> 3 <head> 4 <meta char ...