resnet18训练自定义数据集
目录结构
dogsData.py
- import json
- import torch
- import os, glob
- import random, csv
- from PIL import Image
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from torchvision.transforms import InterpolationMode
class Dogs(Dataset):
    """Image-classification dataset laid out as one sub-directory per class.

    On construction the class:
      * maps every sub-directory of ``root`` (sorted by name) to an integer
        label and stores the mapping in ``self.nameLable``;
      * writes that mapping to ``<root>/label.txt`` if the file is missing;
      * builds (or re-reads) ``<root>/images.csv``, a shuffled list of
        ``path,label`` rows, so every instance sees the same ordering;
      * keeps the first 80 % of the rows for ``mode='train'``, the next 10 %
        for ``mode='val'`` and the last 10 % for any other mode (test).

    Args:
        root: dataset directory containing one folder per class.
        resize: side length (pixels) of the square image returned.
        mode: ``'train'``, ``'val'``; anything else selects the test split.
    """

    # Channel statistics used by Normalize and undone by denormalize.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, root, resize, mode):
        super().__init__()
        self.root = root
        self.resize = resize
        self.mode = mode
        # Built lazily on first __getitem__ so construction stays cheap.
        self._transform = None

        # class-directory name -> integer label, in sorted-name order
        self.nameLable = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.nameLable[name] = len(self.nameLable)

        label_path = os.path.join(self.root, 'label.txt')
        if not os.path.exists(label_path):
            with open(label_path, 'w', encoding='utf-8') as f:
                f.write(json.dumps(self.nameLable, ensure_ascii=False))

        self.images, self.labels = self.load_csv('images.csv')

        # 80 / 10 / 10 split over the shuffled csv order.
        total = len(self.images)
        train_end, val_end = int(0.8 * total), int(0.9 * total)
        if mode == 'train':
            span = slice(None, train_end)
        elif mode == 'val':
            span = slice(train_end, val_end)
        else:
            span = slice(val_end, None)
        self.images = self.images[span]
        self.labels = self.labels[span]

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``<root>/<filename>``.

        If the csv does not exist yet it is created first: every ``*.png``,
        ``*.jpg`` and ``*.jpeg`` file of every class directory is collected,
        the list is shuffled with ``random.shuffle`` and written as
        ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)
        if not os.path.exists(csv_path):
            samples = []
            for name, label in self.nameLable.items():
                # The class is known here, so there is no need to parse it
                # back out of the file path afterwards.
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    for path in glob.glob(os.path.join(self.root, name, pattern)):
                        samples.append((path, label))
            random.shuffle(samples)
            with open(csv_path, mode='w', newline='') as f:
                csv.writer(f).writerows(samples)
            print('csv write succesful')

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:  # tolerate blank lines in a hand-edited csv
                    continue
                path, label = row
                images.append(path)
                labels.append(int(label))
        return images, labels

    def denormalize(self, x_hat):
        """Undo the Normalize step: ``x = x_hat * std + mean``.

        Works for a single image ``[c, h, w]`` or a batch ``[b, c, h, w]``
        (the statistics are reshaped to ``[3, 1, 1]`` and broadcast) and for
        tensors on any device.
        """
        # Create the statistics on the input's device so CUDA tensors work.
        mean = torch.tensor(self._MEAN, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(self._STD, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.images)

    def _build_transform(self):
        """Compose the PIL-image -> normalized-tensor pipeline for this split."""
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]
        if self.mode == 'train':
            # Augment only while training; val/test stay deterministic.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ]
        return transforms.Compose(steps)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        path, label = self.images[idx], self.labels[idx]
        if self._transform is None:
            self._transform = self._build_transform()
        # convert() returns an independent image, so the file can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self._transform(img), torch.tensor(label)
def main():
    """Smoke-test the dataset: build the train split and wrap it in a loader.

    Prints the number of batches and the class-name -> label mapping.
    A Visdom client is created so the optional visualisation snippets
    described below can be enabled; a Visdom server must be reachable.
    """
    import time
    import visdom

    viz = visdom.Visdom()

    # Approach 1 (generic): the custom Dogs dataset.
    db = Dogs('Images_Data_Dog', 224, 'train')

    # To inspect a single sample:
    #   x, y = next(iter(db))
    #   viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))

    # Batched access.
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
    print(len(loader))
    print(db.nameLable)

    # To watch batches in Visdom, iterate the loader and call
    #   viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
    #   viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
    #   time.sleep(10)
    #
    # Approach 2: torchvision.datasets.ImageFolder(root=..., transform=tf)
    # gives an equivalent dataset for the same one-folder-per-class layout
    # without any custom code.


if __name__ == '__main__':
    main()
utils.py
- import torch
- from torch import nn
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one.

    ``[b, d1, d2, ...] -> [b, d1 * d2 * ...]``
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.Size.numel() multiplies the trailing dims without allocating
        # a tensor; reshape (unlike view) also accepts non-contiguous input.
        return x.reshape(-1, x.shape[1:].numel())
train.py
- import os
- import sys
- base_path = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(base_path)
- base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
- sys.path.append(base_path)
- import torch
- import visdom
- from torch import optim, nn
- import torchvision
- from torch.utils.data import DataLoader
- from dogs_train.utils import Flatten
- from dogsData import Dogs
- from torchvision.models import resnet18
# Visdom client used for the live loss / accuracy plots.
viz = visdom.Visdom()

# Hyper-parameters.
batchsz = 32
lr = 1e-3
epochs = 20

# Use the GPU when present, otherwise fall back to the CPU so the script
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)

# The three splits share one shuffled csv, so they never overlap.
train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
def evalute(model, loader):
    """Return the top-1 accuracy of ``model`` over every sample in ``loader``.

    The model is switched to eval mode for the duration of the call (so
    BatchNorm uses its running statistics and Dropout is disabled) and its
    previous train/eval mode is restored afterwards.

    Args:
        model: classifier producing ``[batch, num_classes]`` logits.
        loader: DataLoader yielding ``(inputs, integer_labels)`` batches.

    Returns:
        Fraction of correctly classified samples as a float in ``[0, 1]``;
        ``0.0`` when the dataset is empty.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    # Run on whatever device the model lives on; parameter-less models fall
    # back to the module-level ``device``.
    try:
        target = next(model.parameters()).device
    except StopIteration:
        target = device

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target), y.to(target)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return correct / total
def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 on the Dogs dataset.

    The final fully-connected layer of the backbone is replaced by a new
    linear head sized to the number of classes in the training set. Loss is
    plotted to Visdom every step, validation accuracy every second epoch;
    the best-scoring weights are saved to ``best.mdl`` and reloaded for the
    final test-set evaluation.
    """
    # Size the head from the data instead of hard-coding the class count.
    num_classes = len(train_db.nameLable)

    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept as-is for the installed version.
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # backbone without its fc layer
        Flatten(),                             # [b, 512, 1, 1] => [b, 512]
        nn.Linear(512, num_classes),
    ).to(device)

    # Shape sanity check, run in eval mode without gradients so the random
    # input does not leak into the BatchNorm running statistics.
    model.eval()
    with torch.no_grad():
        x = torch.randn(2, 3, 224, 224).to(device)
        print(model(x).shape)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint for the reload further down.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        model.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate every second epoch and keep the best weights.
        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc', best_acc, 'best epoch', best_epoch)

    # map_location lets a checkpoint saved on one device load on another.
    model.load_state_dict(torch.load('best.mdl', map_location=device))
    print('loader from ckpt')

    # From here on the model is used for inference only.
    model.eval()
    test_acc = evalute(model, test_loader)
    print(test_acc)


if __name__ == '__main__':
    main()
resnet18训练自定义数据集的更多相关文章
- MMDetection 快速开始,训练自定义数据集
本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...
- Scaled-YOLOv4 快速开始,训练自定义数据集
代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...
- [炼丹术]YOLOv5训练自定义数据集
YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...
- yolov5训练自定义数据集
yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...
- Tensorflow2 自定义数据集图片完成图片分类任务
对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...
- torch_13_自定义数据集实战
1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...
- tensorflow从训练自定义CNN网络模型到Android端部署tflite
网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...
- Yolo训练自定义目标检测
Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...
- pytorch加载语音类自定义数据集
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...
- PyTorch 自定义数据集
准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...
随机推荐
- 数据库ip被锁了怎么办
由于多次访问失败,导致ip被限制,登录时会报错 Internal error/check (Not system error) 如何解决: 找一台同事的机子,(或者修改自己的ip),然后打开sql 的 ...
- MongoDB:内嵌文档查询匹配 查询集合中的文档
1.db.getCollection('Notification').find({ Title:{$regex:/班/}, "Message.TargetUrl":{$regex: ...
- linux学习之grep
grep 可进行查找内容 如 cat logs/anyproxy.log | grep "2020080321000049" 还可以通过-v 反向过滤 如 tail -f log ...
- 调用mglearn时的报错 TypeError: __init__() got an unexpected keyword argument 'cachedir'
import mglearn的时候发生的报错 原因是调用了joblib包中的memory类,但是cachedir这个参数已经弃用了 查到下面帖子之后改掉cachedir解决问题 https://blo ...
- turtle绘制风轮
题目要求: 使用turtle库,绘制一个风轮效果,其中,每个风轮内角为45度,风轮边长150像素. 我的代码: import turtle turtle.setup(500,500,100,200) ...
- VUE项目中检测网页滑动注意事项
一.this.$nextTick(function () { window.addEventListener('scroll', this.onScroll, true) ...
- 树莓派 wiringPi的BCM与BOARD编码
一.基础命令使用wiringPi库 1.1.获取管教信息 gpio readall ---获取管脚信息 1.2.BOARD编码和BCM一般都在python库中使用 import RPi.GPIO ...
- 在.NET中使用JWT
1.配置文件添加 //jwt配置文件 "JWT": { "SigningKey": "14fa5f2rrwsg627fs256fdgff2r5rf52 ...
- fetch请求方式
Fetch请求的方式 1:GET 请求 // 未传参数 const getData = async () => { const res = await fetch('http://www.xxx ...
- Pytorch基础复习
项目推进中期,重新到头来学Pytorch.five落泪了.(╬▔皿▔)凸 笑死,憋不住了,边更边学. 整篇博客整体采用总分总形式.首先将介绍内容(加黑部分)之间关系进行概括,后拆解,最后以图总结. 全 ...