修改pytorch官方实例适用于自己的二分类迁移学习项目
本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:
- 根据AUC来迭代最优参数;
- 五折交叉验证;
- 输出验证集错误分类图片;
- 输出分类报告并保存AUC结果图片。
import os
import numpy as np
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score, classification_report
from sklearn.model_selection import KFold
from torch.autograd import Variable
import torch.optim as optim
import time
import copy
import shutil
import sys
import scikitplot as skplt
import matplotlib.pyplot as plt
import pandas as pd plt.switch_backend('agg')
N_CLASSES = 2
BATCH_SIZE = 8
DATA_DIR = './data'
LABEL_DICT = {0: 'class_1', 1: 'class_2'} def imshow(inp, title=None):
"""Imshow for Tensor."""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(100) def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
since = time.time()
# 先深拷贝一份当前模型的参数,后面迭代过程中若遇到更优模型则替换
best_model_wts = copy.deepcopy(model.state_dict())
# best_acc = 0.0
# 初始auc
best_auc = 0.0
best_desc = [0, 0, None]
best_img_name = None
plt_auc = [None, None] for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('- ' * 50) for phase in ['train', 'val']:
if phase == 'train':
# 训练的时候进行学习率规划,其定义在下面给出
scheduler.step()
model.train(True)
else:
model.train(False)
phase_pred = np.array([])
phase_label = np.array([])
img_name = np.zeros((1, 2))
prob_pred = np.zeros((1, 2))
running_loss = 0.0
running_corrects = 0
# 这样迭代方便跟踪图片路径,输出错误图片名称
for data, index in zip(dataloaders[phase], dataloaders[phase].batch_sampler):
inputs, labels = data
if use_gpu:
inputs = Variable(inputs.cuda())
labels = Variable(labels.cuda())
else:
inputs, labels = Variable(inputs), Variable(labels) # 梯度参数设为0
optimizer.zero_grad() # forward
outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)
loss = criterion(outputs, labels) # backward + 训练阶段优化
if phase == 'train':
loss.backward()
optimizer.step() if phase == 'val':
img_name = np.append(img_name, np.array(dataloaders[phase].dataset.imgs)[index], axis=0)
prob = outputs.data.cpu().numpy()
prob_pred = np.append(prob_pred, prob, axis=0) phase_pred = np.append(phase_pred, preds.cpu().numpy())
phase_label = np.append(phase_label, labels.data.cpu().numpy())
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data).float()
print()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects / dataset_sizes[phase]
epoch_auc = roc_auc_score(phase_label, phase_pred)
print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
phase, epoch_loss, epoch_acc, epoch_auc))
report = classification_report(phase_label, phase_pred, target_names=class_names)
print(report) img_name = zip(img_name[1:], phase_pred)
# 当验证时遇到了更好的模型则予以保留
if phase == 'val' and epoch_auc > best_auc:
best_auc = epoch_auc
best_desc = epoch_acc, epoch_auc, report
best_img_name = img_name
# 深拷贝模型参数
best_model_wts = copy.deepcopy(model.state_dict())
plt_auc = phase_label, prob_pred[1:] print()
print(plt_auc[0].shape, plt_auc[1].shape)
csv_file = pd.DataFrame(plt_auc[1], columns=['class_1', 'class_2'])
csv_file['true_label'] = pd.DataFrame(plt_auc[0])
csv_file['true_label'] = csv_file['true_label'].apply(lambda x: LABEL_DICT[x])
csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
skplt.metrics.plot_roc_curve(plt_auc[0], plt_auc[1], curves=['each_class'])
plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAclass_2 -> {:4f}\n\n{}'.format(best_desc[0], best_desc[1],
best_desc[2])
report_file.write(reports)
print(reports)
print('List the wrong judgement img ...')
count = 0
for i in best_img_name:
actual_label = int(i[0][1])
pred_label = i[1]
if actual_label != pred_label:
tmp_word = f'{i[0][0].split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, ' \
f'pred: {LABEL_DICT[pred_label]}'
print(tmp_word)
label_file.write(tmp_word + '\n')
count += 1
print(f'This fold has {count} wrong records ...') # 载入最优模型参数
model.load_state_dict(best_model_wts)
return model def plot_img():
for i, data in enumerate(dataloaders['train']):
inputs, classes = data
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes]) # 此函数可以修改适用于自己项目的图片文件名
def move_file(data, file_path, dir_path, root_path):
label_0 = 'class_2'
label_1 = 'class_1'
print(f'start copy the {file_path} file ...')
os.chdir(dir_path)
if os.path.exists(file_path):
print(f'Find exist {file_path} file, the file will be dropped.')
shutil.rmtree(os.path.join(root_path, dir_path, file_path))
print(f'Finish drop the {file_path} file.') os.mkdir(file_path)
tmp_path = os.path.join(os.getcwd(), file_path)
tmp_pre_path = os.getcwd()
for d in data:
pre_path = os.path.join(tmp_pre_path, d)
os.chdir(tmp_path)
if d[:2] == label_0:
if not os.path.exists(label_0):
os.mkdir(label_0)
cur_path = os.path.join(tmp_path, label_0, d)
shutil.copyfile(pre_path, cur_path)
if d[:2] == label_1:
if not os.path.exists(label_1):
os.mkdir(label_1)
cur_path = os.path.join(tmp_path, label_1, d)
shutil.copyfile(pre_path, cur_path)
print('finish this work ...') if __name__ == "__main__":
if not os.path.exists('roc_img'):
os.mkdir('roc_img')
if not os.path.exists('prob_result'):
os.mkdir('prob_result')
if not os.path.exists('report'):
os.mkdir('report')
if not os.path.exists('error_record'):
os.mkdir('error_record')
if not os.path.exists('model'):
os.mkdir('model')
label_file = open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') kf = KFold(n_splits=5, shuffle=True, random_state=1)
origin_path = '/home/project/'
dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))]) for m, n in enumerate(kf.split(dd_list), start=1):
report_file = open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w')
print(f'The {m} fold for copy file and training ...')
move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
os.chdir(origin_path)
move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)
os.chdir(origin_path)
data_transforms = {
'train': transforms.Compose([
# 裁剪到224,224
transforms.RandomResizedCrop(224),
# 随机水平翻转给定的PIL.Image,概率为0.5。即:一半的概率翻转,一半的概率不翻转。
transforms.RandomHorizontalFlip(),
# transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), # HSV以及对比度变化
transforms.ToTensor(),
# 把一个取值范围是[0,255]的PIL.Image或者shape为(H,W,C)的numpy.ndarray,转换成形状为[C,H,W],取值范围是[0,1.0]的FloadTensor
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
} image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
shuffle=True, num_workers=8, pin_memory=False)
for x in ['train', 'val']} dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} class_names = image_datasets['train'].classes
size = len(class_names)
print('label mapping: ')
print(image_datasets['train'].class_to_idx)
use_gpu = torch.cuda.is_available()
model_ft = None
if sys.argv[1] == 'resnet':
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(
nn.Linear(num_ftrs, N_CLASSES),
nn.Sigmoid()
) # 这边可以自行把inception模型加进去
if sys.argv[1] == 'inception':
raise Exception("not provide inception model ...")
# model_ft = models.inception_v3(pretrained=True) if sys.argv[1] == 'desnet':
model_ft = models.densenet121(pretrained=True)
num_ftrs = model_ft.classifier.in_features
model_ft.classifier = nn.Sequential(
nn.Linear(num_ftrs, N_CLASSES),
nn.Sigmoid()
)
# use_gpu = False if use_gpu:
model_ft = model_ft.cuda() criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# 每7个epoch衰减0.1倍
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1], num_epochs=25)
print('Start save the model ...')
torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
print(f'The mission of the fold {m} finished.')
print('# '*50)
report_file.close()
label_file.close()
修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章
- Unity-2017.3官方实例教程Space-Shooter(二)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...
- Unity-2017.2官方实例教程Roll-a-ball(二)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...
- 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)
.caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...
- Unity-2017.3官方实例教程Space-Shooter(一)
由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...
- Unity-2017.2官方实例教程Roll-a-ball(一)
声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...
- NLP(二十二)利用ALBERT实现文本二分类
在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...
- 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。
这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展
教程来源于:Unity官方实例教程 Space Shooter(一)-(五) http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...
随机推荐
- Django的MTV模式详解
参考博客:https://www.cnblogs.com/yuanchenqi/articles/7629939.html 一.MVC模型 Web服务器开发领域里著名的MVC模式. 所谓MVC就是把W ...
- javascript实现数据结构与算法系列
1.线性表(Linear list) 线性表--简单示例及线性表的顺序表示和实现 线性表--线性链表(链式存储结构) 线性表的静态单链表存储结构 循环链表与双向链表 功能完整的线性链表 线性链表的例子 ...
- python数据分析工具安装集合
用python做数据分析离不开几个好的轮子(或称为科学棧/第三方包等),比如matplotlib,numpy, scipy, pandas, scikit-learn, gensim等,这些包的功能强 ...
- 关于 C# 中接口的一些小结
< 关于 C# 中“接口”的一些小结 > 对于 C# 这样的不支持多重继承的语言,很好的体现的层次性,但是有些时候多重继承的确有一些用武之地. 比如,在 Stream 类 . 图形设备 ...
- android 调试卡在:Waiting for Debugger - Application XXX is waiting for the debugger to Attach" 解决方法
解决方法:重启adb. 步骤:cmd进入命令行,进入adb所在目录先后执行adb kill-server,adb start-server.
- CSS3中REM使用详解
px 在Web页面制作中,我们一般使用“px”来设置我们的文本,因为他比较稳定和精确.但是这种方法存在一个问题,当用户在浏览器中浏览我们制作的Web页面时,他改变了浏览器的字体大小(虽然一般人不会去改 ...
- SSH 学习记录及在SSH模式下使用XShell连接服务器
传统的网络服务程序,如rsh.FTP.POP和Telnet其本质上都是不安全的:因为它们在网络上用明文传送数据.用户帐号和用户口令,很容易受到中间人(man-in-the-middle)攻击方式的攻击 ...
- Windows安装时的几个命令(摘录)
Windows无法安装到这个磁盘.选中的磁盘采用GPT分区形式. 1.在系统提示无法安装的那一步,按住“shift+f10”,呼出“cmd”命令符. 2.输入:diskpart,回车.进入diskpa ...
- scope的四种作用域的使用
如何使用spring的作用域: <bean id="role" class="spring.chapter2.maryGame.Role" scope=& ...
- jstl 中substring,length等函数用法
引入jstl库:<%@ taglib prefix="fn" uri="http://java.sun.com/jsp/jstl/functions"%& ...