目录结构

dogsData.py

import json

import torch
import os, glob
import random, csv from PIL import Image
from torch.utils.data import Dataset, DataLoader from torchvision import transforms
from torchvision.transforms import InterpolationMode class Dogs(Dataset): def __init__(self, root, resize, mode):
super().__init__()
self.root = root
self.resize = resize
self.nameLable = {}
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root, name)):
continue
self.nameLable[name] = len(self.nameLable.keys()) if not os.path.exists(os.path.join(self.root, 'label.txt')):
with open(os.path.join(self.root, 'label.txt'), 'w', encoding='utf-8') as f:
f.write(json.dumps(self.nameLable, ensure_ascii=False)) # print(self.nameLable)
self.images, self.labels = self.load_csv('images.csv')
# print(self.labels) if mode == 'train':
self.images = self.images[:int(0.8*len(self.images))]
self.labels = self.labels[:int(0.8*len(self.labels))]
elif mode == 'val':
self.images = self.images[int(0.8*len(self.images)):int(0.9*len(self.images))]
self.labels = self.labels[int(0.8*len(self.labels)):int(0.9*len(self.labels))]
else:
self.images = self.images[int(0.9*len(self.images)):]
self.labels = self.labels[int(0.9*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.nameLable.keys():
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# print(len(images)) random.shuffle(images)
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
name = img.split(os.sep)[-2]
label = self.nameLable[name]
writer.writerow([img, label])
print('csv write succesful') images, labels = [], []
with open(os.path.join(self.root, filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label) assert len(images) == len(labels) return images, labels def denormalize(self, x_hat):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
# x_hot = (x-mean)/std
# x = x_hat * std = mean
# x : [c, w, h]
# mean [3] => [3, 1, 1]
mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
std = torch.tensor(std).unsqueeze(1).unsqueeze(1) x = x_hat * std + mean
return x def __len__(self):
return len(self.images) def __getitem__(self, idx):
# print(idx, len(self.images), len(self.labels))
img, label = self.images[idx], self.labels[idx] # 将字符串路径转换为tensor数据
# print(self.resize, type(self.resize))
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'),
transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))),
transforms.RandomRotation(15),
transforms.CenterCrop(self.resize),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = tf(img) label = torch.tensor(label) return img, label def main(): import visdom
import time viz = visdom.Visdom() # func1 通用
db = Dogs('Images_Data_Dog', 224, 'train')
# 取一张
# x,y = next(iter(db))
# print(x.shape, y)
# viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x')) # 取一个batch
loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)
print(len(loader))
print(db.nameLable)
# for x, y in loader:
# # print(x.shape, y)
# viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) # # fun2
# import torchvision
# tf = transforms.Compose([
# transforms.Resize((64, 64)),
# transforms.RandomRotation(15),
# transforms.ToTensor(),
# ])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# loader = DataLoader(db, batch_size=32, shuffle=True)
# print(len(loader))
# for x, y in loader:
# # print(x.shape, y)
# viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
# viz.text(str(y.numpy()), win='label', opts=dict(title='batch_y'))
# time.sleep(10) if __name__ == '__main__':
main()

utils.py

import torch
from torch import nn class Flatten(nn.Module):
def __init__(self):
super().__init__() def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)

train.py

import os
import sys
base_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(base_path)
base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_path)
import torch
import visdom
from torch import optim, nn
import torchvision from torch.utils.data import DataLoader from dogs_train.utils import Flatten
from dogsData import Dogs from torchvision.models import resnet18 viz = visdom.Visdom() batchsz = 32
lr = 1e-3
epochs = 20 device = torch.device('cuda')
torch.manual_seed(1234) train_db = Dogs('Images_Data_Dog', 224, mode='train')
val_db = Dogs('Images_Data_Dog', 224, mode='val')
test_db = Dogs('Images_Data_Dog', 224, mode='test') train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2) def evalute(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
logist = model(x)
pred = logist.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct/total def main(): # model = ResNet18(5).to(device)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],
Flatten(), # [b, 512, 1, 1] => [b, 512]
nn.Linear(512, 27)
).to(device) x = torch.randn(2, 3, 224, 224).to(device)
print(model(x).shape) optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss() best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs): for step, (x, y) in enumerate(train_loader):
x = x.to(device)
y = y.to(device) logits = model(x)
loss = criteon(logits, y) optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 2 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
torch.save(model.state_dict(), 'best.mdl') viz.line([val_acc], [global_step], win='val_acc', update='append') print('best acc', best_acc, 'best epoch', best_epoch) model.load_state_dict(torch.load('best.mdl'))
print('loader from ckpt') test_acc = evalute(model, test_loader)
print(test_acc) if __name__ == '__main__':
main()

resnet18训练自定义数据集的更多相关文章

  1. MMDetection 快速开始,训练自定义数据集

    本文将快速引导使用 MMDetection ,记录了实践中需注意的一些问题. 环境准备 基础环境 Nvidia 显卡的主机 Ubuntu 18.04 系统安装,可见 制作 USB 启动盘,及系统安装 ...

  2. Scaled-YOLOv4 快速开始,训练自定义数据集

    代码: https://github.com/ikuokuo/start-scaled-yolov4 Scaled-YOLOv4 代码: https://github.com/WongKinYiu/S ...

  3. [炼丹术]YOLOv5训练自定义数据集

    YOLOv5训练自定义数据 一.开始之前的准备工作 克隆 repo 并在Python>=3.6.0环境中安装requirements.txt,包括PyTorch>=1.7.模型和数据集会从 ...

  4. yolov5训练自定义数据集

    yolov5训练自定义数据 step1:参考文献及代码 博客 https://blog.csdn.net/weixin_41868104/article/details/107339535 githu ...

  5. Tensorflow2 自定义数据集图片完成图片分类任务

    对于自定义数据集的图片任务,通用流程一般分为以下几个步骤: Load data Train-Val-Test Build model Transfer Learning 其中大部分精力会花在数据的准备 ...

  6. torch_13_自定义数据集实战

    1.将图片的路径和标签写入csv文件并实现读取 # 创建一个文件,包含image,存放方式:label pokemeon\\mew\\0001.jpg,0 def load_csv(self,file ...

  7. tensorflow从训练自定义CNN网络模型到Android端部署tflite

    网上有很多关于tensorflow lite在安卓端部署的教程,但是大多只讲如何把训练好的模型部署到安卓端,不讲如何训练,而实际上在部署的时候,需要知道训练模型时预处理的细节,这就导致了自己训练的模型 ...

  8. Yolo训练自定义目标检测

    Yolo训练自定义目标检测 参考darknet:https://pjreddie.com/darknet/yolo/ 1. 下载darknet 在 https://github.com/pjreddi ...

  9. pytorch加载语音类自定义数据集

    pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合 torch.u ...

  10. PyTorch 自定义数据集

    准备数据 准备 COCO128 数据集,其是 COCO train2017 前 128 个数据.按 YOLOv5 组织的目录: $ tree ~/datasets/coco128 -L 2 /home ...

随机推荐

  1. 社团管理系统(AMS)个人总结

    一.展示 所在小组:13组 源代码链接:https://github.com/xupppp/ams 博文链接: https://www.cnblogs.com/xupppp/p/11795218.ht ...

  2. 字节过滤流 缓冲流-->BufferedInputStream用法

    1创建字节输入节点流FileInputStream fis = new FileInputStream("文件读取的路径");2创建字节输入过滤流,包装一个字节输入节点流Buffe ...

  3. js字符串常用的方法

    1.  charAt( ) 获取指定下标处的字符 let str = 'hello' console.log(str.charAt(0));//h 2.  charCodeAt 获取下标出的字符的Un ...

  4. 源代码管理工具-Github

    一.实验目的 个人编程:每个开发人员电脑上有自己的代码.硬盘坏了,所有的数据和资料不能找回或是很难复原.安全意识强一些的公司会要求开发人员将代码隔一段时间放到一个集中的计算机上,以日期为文件夹进行备份 ...

  5. git push 报错error: remote unpack failed: error Short read of block

    1.解决办法:找管理代码的人给你开权限. 2如果你的push的命令写错的话,也是会出现远端拒绝的提示,所以记得检查自己的push 命令是否正确 另一种明显的权限拒绝的例子: 英语学习:to push ...

  6. 多线程JUC练习

    package com.aliyun.test.learn; import java.util.concurrent.*; import java.util.concurrent.locks.Reen ...

  7. element-ui el-tree 内容过多出现横向滚动条

    /deep/ .el-tree>.el-tree-node { display: inline-block; min-width: 100%;}

  8. WSL终端无法启动

    WSL终端无法启动 1.状态 (1)宿主系统重新启动后,通过菜单项无法启动终端. (2)在powershell中,通过命令:wsl -l -v,查看子系统均为为stopped状态. (3)Powers ...

  9. 狂神说SpringBoot笔记之编写一个http接口

    编写一个http接口 1.1.在主程序的同级目录下,新建一个controller包,一定要在同级目录下,否则识别不到 2.代码 1 package com.example.app01.demo.api ...

  10. [HCTF 2018]WarmUp 1

    主页面是一个滑稽 得到source.php 观看源码,提示source.php 访问看到源码 <?php highlight_file(__FILE__); class emmm { publi ...