PyTorch 实战(模型训练、模型加载、模型测试)
本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型
自定义数据集
参考我的上一篇博客:自定义数据集处理数据加载
默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。
好吧,还是简单的说一下吧:
我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:- 导入必要的包
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
- 加载数据
可以发现和MNIST 、CIFAR的加载基本上是一样的
train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
- 搭建神经网络
ResNet-18网络结构:
ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:
- 训练模型
def evalute(model, loader):
model.eval()
correct = 0
total = len(loader.dataset)
for x, y in loader:
x, y = x.to(device), y.to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct / total
def main():
model = ResNet18(5).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 1 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
best_acc = val_acc
viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best acc:', best_acc, 'best epoch:', best_epoch)
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
test_acc = evalute(model, test_loader)
- 迁移学习
提升模型的准确率:
# model = ResNet18(5).to(device)
trained_model=resnet18(pretrained=True) # 此时是一个非常好的model
model = nn.Sequential(*list(trained_model.children())[:-1], # 此时使用的是前17层的网络 0-17 *:随机打散
Flatten(),
nn.Linear(512,5)
).to(device)
# x=torch.randn(2,3,224,224)
# print(model(x).shape)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
- 保存、加载模型
pytorch保存模型的方式有两种:
第一种:将整个网络都都保存下来
第二种:仅保存和加载模型参数(推荐使用这样的方法)
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
可以看到这是我保存的模型:
其中best.mdl是第二中方法保存的
model.pkl则是第一种方法保存的
- 测试模型
这里是训练时的情况
看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:
import torch
from PIL import Image
from torchvision import transforms
device = torch.device('cuda')
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
def prediect(img_path):
net=torch.load('model.pkl')
net=net.to(device)
torch.no_grad()
img=Image.open(img_path)
img=transform(img).unsqueeze(0)
img_ = img.to(device)
outputs = net(img_)
_, predicted = torch.max(outputs, 1)
# print(predicted)
print('this picture maybe :',classes[predicted[0]])
if __name__ == '__main__':
prediect('./test/name.jpg')
实际的测试结果:
效果还是可以的,完整的代码:
https://github.com/huzixuan1/Loader_DateSet
数据集下载链接:
https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg
由于笔者能力水平有限,在表述上可能有些不准确;有问题可以联系QQ:1017190168
PyTorch 实战(模型训练、模型加载、模型测试)的更多相关文章
- Pytorch实战学习(四):加载数据集
<PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili Dataset & Dataloader 1.Dataset & Dataloader作用 ※Datas ...
- 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)
文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...
- [译]Vulkan教程(31)加载模型
[译]Vulkan教程(31)加载模型 Loading models 加载模型 Introduction 入门 Your program is now ready to render textured ...
- PyTorch保存模型与加载模型+Finetune预训练模型使用
Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- 全面解析Pytorch框架下模型存储,加载以及冻结
最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...
- 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)
1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...
- [Pytorch]Pytorch 保存模型与加载模型(转)
转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
随机推荐
- PanGu-Coder2:从排序中学习,激发大模型潜力
本文分享自华为云社区<PanGu-Coder2:从排序中学习,激发大模型潜力>,作者: 华为云软件分析Lab . 2022年7月,华为云PaaS技术创新Lab联合华为诺亚方舟语音语义实验室 ...
- Vika and Her Friends
Smiling & Weeping ----早知道思念那么浓烈,不分手就好了 题目链接:Problem - A - Codeforces 题目大意:有n个Vika的朋友在一个n*m的方格中去捉 ...
- .NET使用quartz+topshelf实现定时执行任务调度服务
一.项目开发 1.新建控制台应用(.NET Framework) 2.配置新项目,自行修改项目名称.位置和框架(建议使用.NET Framework4.5以上版本) 创建好的项目目录如下: 3.右键引 ...
- 我的 Windows 文件管理哲学
前言 作为一个不合格的 Geek,我经常面临把 Windows 弄崩溃的尴尬处境,我的系统因此重装了一遍又一遍--不过在一次次的重装中,我逐渐总结出了于我个人而言行之有效的文件管理哲学,在此略做总 ...
- Solution -「洛谷 P2000」拯救世界
Description Link. 概括什么好麻烦哦 w. Solution 生成函数裸题. 把所有情况罗列出来: kkk: 金: \(1+x^6+x^{12}+\dots=\frac{1}{1-x^ ...
- Record - Dec. 1st, 2020 - Exam. REC
Prob. 1 Desc. & Link. 行走的形式是比较自由的,因为只要走到了最优答案处就可以不管了,所以不需要考虑游戏的结束. 考虑二分答案. 然后预处理出每个节点到 \(s\)(另一棵 ...
- Python爬虫-IP隐藏技术与代理爬取
在进行爬虫程序开发和运行时,常常会遇到目标网站的反爬虫机制,最常见的就是IP封禁,这时需要使用IP隐藏技术和代理爬取. 一.IP隐藏技术 IP隐藏技术,即伪装IP地址,使得爬虫请求的IP地址不被目标网 ...
- Teamcenter RAC开发 GoToHelper
RAC开发,有时候会用到发送到我的Teamcenter 可以参考 com.teamcenter.rac.tcapps 包下 package com.teamcenter.rac.tracelinks; ...
- c++ 常用的 STL
c++ 中常用的 STL vector //vector 变长数组 倍增的思想(倍增:系统为每一个程序分配空间的时候,所需要的时间和空间大小无关,与请求次数相关)尽量减少请求的次数 /* 返回元素的个 ...
- 2D物理引擎 Box2D for javascript Games 第四章 将力作用到刚体上
2D物理引擎 Box2D for javascript Games 第四章 将力作用到刚体上 将力作用到刚体上 Box2D 是一个在力作用下的世界,它可以将力作用于刚体上,从而给我们一个更加真实的模拟 ...