pytorch 入门
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
import torch
from torch import nn # 包含构建神经网络的所有模块
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 训练数据集
training_data = datasets.FashionMNIST(
root="./data", # 存储测试集和训练集的路径
train=True, # 训练集
download=True, # 如果本机没有数据集,就会下载到 root 目录下。
transform=ToTensor() # 对样本数据进行处理,转换为张量数据
)
# 测试数据集
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
# 标签字典,一个key键对应一个label
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
# 设置画布大小
# figure = plt.figure(figsize=(8, 8))
# cols, rows = 3, 3
# for i in range(1, cols * rows + 1):
# # 随机生成一个索引
# sample_idx = torch.randint(len(training_data), size=(1,)).item()
# # 获取样本及其对应的标签
# img, label = training_data[sample_idx]
# figure.add_subplot(rows, cols, i)
# # 设置标题
# plt.title(labels_map[label])
# # 不显示坐标轴
# plt.axis("off")
# # 显示灰度图
# plt.imshow(img.squeeze(), cmap="gray")
# plt.show()
# 训练数据加载器; 根据数据集生成一个迭代对象,用于模型的训练
train_dataloader = DataLoader(
# 定义好的数据集
dataset=training_data,
# 设置批量大小
batch_size=128,
# 线程数,默认为0。在Windows下设置大于0的数可能会报错。
num_workers=0,
# 打乱样本的顺序
shuffle=True)
# 测试数据加载器
test_dataloader = DataLoader(
dataset=test_data,
batch_size=128,
shuffle=True)
# 展示图片和标签
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")
# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(img, cmap="gray")
# plt.show()
# print(f"Label: {label}")
# 模型定义
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__() # 执行父类中的 init 函数
self.flatten = nn.Flatten() # 将每个大小为28x28的图像转换为784个像素值的连续数组
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=28 * 28, out_features=512), # 线性层
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 优化模型参数
def train_loop(dataloader, model, loss_func, optimizer, device):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
loss = loss_func(pred, y)
# 反向传播,优化参数
optimizer.zero_grad() # 将模型的梯度归 0
loss.backward() # 用来存储每个参数的损失梯度
optimizer.step() # 梯度调整完以后,调整参数
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# 测试模型性能
def test_loop(dataloader, model, loss_fn, device):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
test_loss += loss_fn(pred, y).item()
# 计算准确率
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# input
if __name__ == '__main__':
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 定义模型,并将模型移动到设备上
model = Network().to(device)
# 设置超参数
learning_rate = 1e-3
epochs = 20
# 定义损失函数和优化器
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# 训练模型
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, loss_func, optimizer, device)
test_loop(test_dataloader, model, loss_func, device)
print("Done!")
# 保存模型
torch.save(model.state_dict(), 'model_weights.pth')
pytorch 入门的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
随机推荐
- Github认证
1.前言 Github关闭了密码认证,现在还有两种认证方式 token ssh 本人一直都在使用idea的可视化界面,进行git的操作,第一次使用bash进行初始化时遇到了身份验证的问题.现在简单总结 ...
- 一段简单的对TXT文件的操作代码
1 string txt = @"C:\DetectFolder\IPV4地址.txt"; 2 string path = ""; 3 4 if (File.E ...
- 2021级《JAVA语言程序设计》上机考试试题5
这是系统员功能实现,因为使用到了教师,所以教师的Bean与Dao,以及更新的Servlet与service Teacher package Bean; public class Teacher {pr ...
- Zstack迁移实战记录1
https://blog.csdn.net/weixin_43767046/article/details/113748775 这段时间除了那个重度烤机测试(上面链接),还在做另一件事,想再做一个服务 ...
- react 高效高质量搭建后台系统 系列 —— 结尾
其他章节请看: react 高效高质量搭建后台系统 系列 尾篇 本篇主要介绍表单查询.表单验证.通知(WebSocket).自动构建.最后附上 myspug 项目源码. 项目最终效果: 表单查询 需求 ...
- RocketMQ - 消费者概述
消费流程 消费者组: 一个逻辑概念,在使用消费者时需要指定一个组名.一个消费者组可以订阅多个Topic. 消费者实例: 一个消费者组程序部署了多个进程,每个进程都可以称为一个消费者实例. 订阅关系: ...
- JZOJ 1073. 【GDOI2005】山海经
\(\text{Solution}\) 非常经典的求区间最大字段和 不难想到线段树,考虑处理区间答案的合并 维护前缀后缀最大和与区间答案,合并考虑跨中点贡献即可 代码打得非常恶心... \(\text ...
- NOIP2021游记总结
\(\text{Day-1}\) 惨遭遣返······ 这真是伟大的啊!! \(\text{Day1}\) \(day\) 几好像没有意义,反正只有一天 \(\text{T1}\) 极致 \(H_2O ...
- 代码随想录算法训练营day10 | leetcode 232.用栈实现队列 225. 用队列实现栈
基础知识 使用ArrayDeque 实现栈和队列 stack push pop peek isEmpty() size() queue offer poll peek isEmpty() size() ...
- 【7】java之正则表达式
一.正则标记 所有的正则可以使用的标记都在 java.util.regex.Pattern 类里定义. 1.1 单个字符 字符:表示由一位字符所组成: \\\\:表示转义字符"\\&qu ...