pytorch 入门
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
import torch
from torch import nn # 包含构建神经网络的所有模块
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 训练数据集
training_data = datasets.FashionMNIST(
root="./data", # 存储测试集和训练集的路径
train=True, # 训练集
download=True, # 如果本机没有数据集,就会下载到 root 目录下。
transform=ToTensor() # 对样本数据进行处理,转换为张量数据
)
# 测试数据集
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
# 标签字典,一个key键对应一个label
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
# 设置画布大小
# figure = plt.figure(figsize=(8, 8))
# cols, rows = 3, 3
# for i in range(1, cols * rows + 1):
# # 随机生成一个索引
# sample_idx = torch.randint(len(training_data), size=(1,)).item()
# # 获取样本及其对应的标签
# img, label = training_data[sample_idx]
# figure.add_subplot(rows, cols, i)
# # 设置标题
# plt.title(labels_map[label])
# # 不显示坐标轴
# plt.axis("off")
# # 显示灰度图
# plt.imshow(img.squeeze(), cmap="gray")
# plt.show()
# 训练数据加载器; 根据数据集生成一个迭代对象,用于模型的训练
train_dataloader = DataLoader(
# 定义好的数据集
dataset=training_data,
# 设置批量大小
batch_size=128,
# 线程数,默认为0。在Windows下设置大于0的数可能会报错。
num_workers=0,
# 打乱样本的顺序
shuffle=True)
# 测试数据加载器
test_dataloader = DataLoader(
dataset=test_data,
batch_size=128,
shuffle=True)
# 展示图片和标签
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")
# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(img, cmap="gray")
# plt.show()
# print(f"Label: {label}")
# 模型定义
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__() # 执行父类中的 init 函数
self.flatten = nn.Flatten() # 将每个大小为28x28的图像转换为784个像素值的连续数组
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=28 * 28, out_features=512), # 线性层
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 优化模型参数
def train_loop(dataloader, model, loss_func, optimizer, device):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
loss = loss_func(pred, y)
# 反向传播,优化参数
optimizer.zero_grad() # 将模型的梯度归 0
loss.backward() # 用来存储每个参数的损失梯度
optimizer.step() # 梯度调整完以后,调整参数
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# 测试模型性能
def test_loop(dataloader, model, loss_fn, device):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
test_loss += loss_fn(pred, y).item()
# 计算准确率
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# input
if __name__ == '__main__':
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 定义模型,并将模型移动到设备上
model = Network().to(device)
# 设置超参数
learning_rate = 1e-3
epochs = 20
# 定义损失函数和优化器
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# 训练模型
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, loss_func, optimizer, device)
test_loop(test_dataloader, model, loss_func, device)
print("Done!")
# 保存模型
torch.save(model.state_dict(), 'model_weights.pth')
pytorch 入门的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
随机推荐
- 安装CUDA
https://developer.nvidia.com/cuda-toolkit-archive 使用deb安装的话,有时会报错: dpkg: 处理软件包 nvidia-driver-450 (-- ...
- JavaScript 日期和时间的格式化
一.日期和时间的格式化 1.原生方法 1.1.使用 toLocaleString 方法 Date 对象有一个 toLocaleString 方法,该方法可以根据本地时间和地区设置格式化日期时间.例如: ...
- [学习笔记]SQL server完全备份指南
方式一,使用SQL Server Management Studio 准备工作 连接目标数据库服务器 在目标数据库上右键->属性,将数据库的恢复模式设置为"简单",兼容级别设 ...
- Autoit 制作上传工具完美版
一. 制作上传器 在ui自动化过程中经常遇到需要上传的动作,我们可以使用input标签来送值,但这样不太稳定,所以建议使用autoit制作出来的exe工具. 下面就教大家如何制作上传器,如何使用吧! ...
- node.js 历史版本下载
https://nodejs.org/zh-cn/download/releases/
- LeetCode-2024 考试的最大困扰度
来源:力扣(LeetCode)链接:https://leetcode-cn.com/problems/maximize-the-confusion-of-an-exam 题目描述 一位老师正在出一场由 ...
- winfrom快捷键
1.当活动窗体的快捷键 1 protected override bool ProcessCmdKey(ref Message msg, Keys keyData) 2 { 3 KeyEventArg ...
- 《用Python写网络爬虫》pdf高清版免费下载
<用Python写网络爬虫>pdf高清版免费下载地址: 提取码:clba 内容简介 · · · · · · 作为一种便捷地收集网上信息并从中抽取出可用信息的方式,网络爬虫技术变得越来越有 ...
- Visual Studio 2017(vs2017)绿色便携版-北桃特供
原版的VisualStudio2017有几十G,安装起来特别慢,不少用户叫苦连天.该版本是精简过的vs2017,且简化了原来的安装程序,特别适用于教学.个人开发者.某些要求不高的企业. 该绿色便携版是 ...
- pgbouncer相关概念和使用
pgbouncer相关概念和使用 1.pgbouncer介绍 PG 是多进程结构,每新增一个会话就会新增一个进程,相对而言对数据库的开销就会比较巨大. 因为在正常业务会话中,有不少sessio ...