PyTorch深度学习实践——多分类问题
多分类问题
课程来源:PyTorch深度学习实践——河北工业大学
《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili
Softmax
这一讲介绍使用softmax分类器实现多分类问题。
上一节课计算的是二分类问题,也就是输出的label可以分类为0,1两类。只要计算出\(P(y=1)\)的概率,那么\(P(y=0)=1-P(y=1)\);所以只需要计算一种类型的概率即可,也就是只要一个参数。
而在使用MINIST对手写数字进行分类的时候一共是有10个分类的(数字0-9)。
处理方式:视为10个二分类问题(一个label和其他9个label),计算每一个label的概率。如下图所示,但是问题在于
- 每一个二分类问题的结果是独立的,不能保证10个结果加起来等于1,且无法解决互相抑制的问题。
- 每一个结果不能保证大于0。

我们希望输出是有竞争关系的,也就是如果有一项很大那么其他项要相对比较小。为了解决上述问题,提出softmax函数,使用结构如下:

Softmax计算公式如下:
\]
Softmax函数计算简单示例如下:

接下来考虑多分类问题中的损失函数如何定义:和上述BCE基本一致,同样使用交叉熵作为损失函数,定义式如下:
\]
上面这种计算方式也就是NLLLoss,这种损失函数的结构如下:

而NLLLoss损失函数加上Softmax就是交叉熵损失,对应PyTorch中的nn.CrossEntropyLoss(),也就是最后一层的非线性变化不需要进行,直接交给上述损失函数即可。如下图所示:

在Minist数据集上实现多分类问题
Minsit 数据介绍:每一个手写图片都可以看做是一个28x28的矩阵,如下图所示:

总体构建模型并训练还是如上四步,在最后一步加上测试过程。
注:1.在视觉处理中,灰度图可以看做单通道图像,而彩色图像事实上就是RGB三通道的矩阵,在PyTorch中要把构造成通道数量的C放在第一维的三维向量。
2.神经网络训练中尽量将图像矩阵转换为0-1分布的数据
模型结构简图:

代码如下:
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
# prepare dataset
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
# 定义模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784, 512)
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x) # 最后一层不做激活,不进行非线性变换
model = Net()
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# training cycle forward, backward, update
def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
optimizer.zero_grad()
# 预测结果
outputs = model(inputs)
# 交叉熵
loss = criterion(outputs, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
running_loss = 0.0
def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0个维度,行是第1个维度
total += labels.size(0)
correct += (predicted == labels).sum().item() # 张量之间的比较运算
print('accuracy on test set: %d %% ' % (100*correct/total))
if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()
作业
Pytorch详解NLLLoss和CrossEntropyLoss 详见如下网址:https://blog.csdn.net/weixin_43593330/article/details/108622747
Otto Group Product Classification Challenge
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim
##str转数值类型
def label2id(labels):
id=[]
target_labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
for label in labels:
id.append(target_labels.index(label))
return id
class MyDataset(Dataset):
def __init__(self,filepath):
data=pd.read_csv(filepath)
labels=data['target']
self.x_data=torch.from_numpy(np.array(data)[:,1:-1].astype(np.float32))
self.y_data=label2id(labels)
self.len=data.shape[0]
def __getitem__(self,index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len
class Module(nn.Module):
def __init__(self):
super(Module,self).__init__()
self.linear1=nn.Linear(93,64)
self.linear2 = nn.Linear(64, 32)
self.linear3 = nn.Linear(32, 16)
self.linear4 = nn.Linear(16, 9)
self.activate=nn.ReLU()
def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.activate(self.linear3(x))
x=self.linear4(x)
return x
def train(epoch):
running_loss=0.0
for batch_idx,data in enumerate(train_loader,1):
x,y=data
y_pred=model(x)
loss=critetion(y_pred,y)
running_loss+=loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.item())
if batch_idx %100 ==0:
print('[%d, %5d] loss = %.3f' % (epoch + 1, batch_idx, running_loss / 100))
running_loss=0.0
if __name__=="__main__":
train_data = MyDataset('train.csv')
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0)
model=Module()
critetion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
loss_list=[]
for epoch in range(30):
train(epoch)
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('step')
plt.ylabel('loss')
plt.show()
def test():
test_data=pd.read_csv('test.csv')
x_test=torch.from_numpy(np.array(test_data)[:,1:].astype(np.float32))
y_pred=model(x_test)
_,pred=torch.max(y_pred,dim=1)
out=pd.get_dummies(pred)#获取one-hot,其实就是0-8
labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
out.columns=labels
out.insert(0,'id',test_data['id'])
result=pd.DataFrame(out)
result.to_csv('otto-group-product_predictions.csv', index=False)
test()
loss可视化:

结果:

PyTorch深度学习实践——多分类问题的更多相关文章
- PyTorch深度学习实践——处理多维特征的输入
处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...
- PyTorch深度学习实践——反向传播
反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...
- PyTorch深度学习实践-Overview
Overview 1.PyTorch简介 PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...
- 万字总结Keras深度学习中文文本分类
摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...
- 深度学习实践系列(2)- 搭建notMNIST的深度神经网络
如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...
- 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络
前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...
- 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码
PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...
- 医学图像 | 使用深度学习实现乳腺癌分类(附python演练)
乳腺癌是全球第二常见的女性癌症.2012年,它占所有新癌症病例的12%,占所有女性癌症病例的25%. 当乳腺细胞生长失控时,乳腺癌就开始了.这些细胞通常形成一个肿瘤,通常可以在x光片上直接看到或感觉到 ...
- 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型
MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集.这个网站 ...
随机推荐
- 用 CSS 让你的文字更有文艺范
透明文字,模糊文字,镂空文字,渐变文字,图片背景文字,用 CSS 让你的文字也有 freestyle- 前言 我们做页面涉及字体的时候,最多就是换个 color 换个 font-family,总是觉得 ...
- 安装Windows11操作系统(不需要绕过TPM检测脚本等) - 初学者系列 - 学习者系列文章
Windows11操作系统是去年微软公司的最新力作.对于该操作系统的安装,网上有很多的教程了.这次主要写的是不需要绕过TPM检测操作安装Windows11操作系统. 1. 制作启动U盘: ...
- Java多线程专题6: Queue和List
合集目录 Java多线程专题6: Queue和List CopyOnWriteArrayList 如何通过写时拷贝实现并发安全的 List? CopyOnWrite(COW), 是计算机程序设计领域中 ...
- 右键没有word?excel?ppt?注解表该改改啦
✿[office 2019]office2010版本以上的都可以(例如:office 2010.office 2016.office 2019) 一.快速方法解决右键没有word: 在电脑桌面右键一个 ...
- 常见线程池之 newCacheThreadPool 缓存线程池 简单使用
package com.aaa.threaddemo; import java.util.concurrent.BlockingQueue; import java.util.concurrent.E ...
- Promise、Generator、Async有什么区别?
前言 我们知道Promise与Async/await函数都是用来解决JavaScript中的异步问题的,从最开始的回调函数处理异步,到Promise处理异步,到Generator处理异步,再到Asyn ...
- 使用Docker快速搭建Halo个人博客到阿里云服务器上[附加主题和使用域名访问]
一.前言 小编买了一个服务器也是一直想整个网站,一直在摸索,看了能够快速搭建博客系统的教程.总结了有以下几种方式,大家按照自己喜欢的去搭建: halo wordpress hexo vuepress ...
- JAVA多线程学习八-多个线程之间共享数据的方式
多个线程访问共享对象和数据的方式 如果每个线程执行的代码相同,可以使用同一个Runnable对象,这个Runnable对象中有那个共享数据,例如,买票系统就可以这么做. 如果每个线程执行的代码不同,这 ...
- HTML-iframe标签
碎碎:这两天在实践中,用到了 iframe,之前对其不甚了解,了解之中遇到好多奇葩问题,今天记录下这两天遇到的相关的内容. 嵌入的 iframe 页面的边框 嵌入的 iframe 页面的背景 嵌入的 ...
- NSMutableString常用方法
1.NSMutableString常用方法 - (void)appendString:(NSString *)aString; 拼接aString到最后面 NSMutableString *strM ...