本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能:

  1. 根据AUC来迭代最优参数;
  2. 五折交叉验证;
  3. 输出验证集错误分类图片;
  4. 输出分类报告并保存AUC结果图片。
     import os
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.optim import lr_scheduler
    import torchvision
    from torchvision import datasets, models, transforms
    from torch.utils.data import DataLoader
    from sklearn.metrics import roc_auc_score, classification_report
    from sklearn.model_selection import KFold
    from torch.autograd import Variable
    import torch.optim as optim
    import time
    import copy
    import shutil
    import sys
    import scikitplot as skplt
    import matplotlib.pyplot as plt
    import pandas as pd

    # Select the non-interactive Agg backend; figures are only written to disk
    # with plt.savefig().
    plt.switch_backend('agg')

    # Number of output classes of the replaced classifier head.
    N_CLASSES = 2
    # Mini-batch size used for both the training and the validation loader.
    BATCH_SIZE = 8
    # Directory holding the raw images and the generated train/val folders.
    DATA_DIR = './data'
    # Maps a class index to the label text used in reports and CSV output.
    LABEL_DICT = {0: 'class_1', 1: 'class_2'}


    def imshow(inp, title=None):
        """Display a normalized image tensor with matplotlib.

        Args:
            inp: Tensor that is converted with ``inp.numpy()`` and transposed
                with ``(1, 2, 0)`` before plotting.
            title: Optional title for the figure.
        """
        inp = inp.numpy().transpose((1, 2, 0))
        # Same constants as the Normalize() transform applied to the datasets;
        # ``std * inp + mean`` reverses that normalization.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(100)


    def _sequential_loader(loader):
        """Build a non-shuffling DataLoader over the same dataset as ``loader``.

        Without shuffling, batches are drawn in dataset order, so the items of
        each batch can be matched to ``dataset.imgs`` by position.
        """
        return DataLoader(loader.dataset, batch_size=loader.batch_size,
                          shuffle=False, num_workers=loader.num_workers,
                          pin_memory=loader.pin_memory)


    def _wrong_prediction_line(path, actual_label, pred_label):
        """Return the log line for a misclassified image, or None if correct."""
        if actual_label == pred_label:
            return None
        return (f'{path.split("/")[-1]}, actual: {LABEL_DICT[actual_label]}, '
                f'pred: {LABEL_DICT[pred_label]}')


    def train_model(model, criterion, optimizer, scheduler, fold, name, num_epochs=25):
        """Train ``model`` and keep the weights of the best validation-AUC epoch.

        Besides its arguments the function reads these module-level names, which
        the script's entry point defines: ``dataloaders``, ``dataset_sizes``,
        ``class_names``, ``use_gpu``, ``report_file`` and ``label_file``.

        Side effects: writes ``./prob_result/{name}_fold_{fold}_porb.csv`` and
        ``./roc_img/{name}_fold_{fold}_roc.png`` for the best epoch, writes the
        summary to ``report_file`` and one line per misclassified validation
        image to ``label_file``.

        Args:
            model: Network to train.
            criterion: Loss called as ``criterion(outputs, labels)``.
            optimizer: Optimizer stepped once per training batch.
            scheduler: Learning-rate scheduler stepped once per epoch.
            fold: Fold number, used only in output file names.
            name: Model name, used only in output file names.
            num_epochs: Number of epochs to run.

        Returns:
            ``model`` with the best validation-AUC weights loaded.
        """
        since = time.time()
        # Snapshot of the current weights; replaced whenever a better epoch is found.
        best_model_wts = copy.deepcopy(model.state_dict())
        best_auc = 0.0
        best_desc = [0, 0, None]
        best_records = []
        plt_auc = None

        # Validation runs through a non-shuffling loader so that every
        # prediction can be traced back to its image path. Pairing a shuffled
        # loader with a second pass over its batch_sampler would draw two
        # independent permutations and mismatch paths and predictions.
        loaders = {
            'train': dataloaders['train'],
            'val': _sequential_loader(dataloaders['val']),
        }

        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('- ' * 50)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)
                loader = loaders[phase]

                pred_batches = []
                label_batches = []
                prob_batches = []
                records = []
                running_loss = 0.0
                running_corrects = 0
                seen = 0

                for inputs, labels in loader:
                    if use_gpu:
                        inputs = inputs.cuda()
                        labels = labels.cuda()

                    # Reset the gradients accumulated by the previous batch.
                    optimizer.zero_grad()

                    # Only build the autograd graph while training.
                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
                        if is_train:
                            loss.backward()
                            optimizer.step()

                    batch_len = inputs.size(0)
                    if not is_train:
                        # Sequential loader: this batch is the next
                        # ``batch_len`` entries of ``dataset.imgs``.
                        records.extend(loader.dataset.imgs[seen:seen + batch_len])
                        prob_batches.append(outputs.detach().cpu().numpy())
                    seen += batch_len

                    pred_batches.append(preds.cpu().numpy())
                    label_batches.append(labels.cpu().numpy())
                    running_loss += loss.item() * batch_len
                    running_corrects += torch.sum(preds == labels).item()

                if is_train:
                    # Step the schedule after the epoch's optimizer updates.
                    scheduler.step()

                phase_pred = np.concatenate(pred_batches)
                phase_label = np.concatenate(label_batches)

                print()
                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]
                # NOTE(review): AUC is computed from the hard class predictions,
                # as before; a score-based AUC would rank by the class scores.
                epoch_auc = roc_auc_score(phase_label, phase_pred)
                print('{} Loss: {:.4f} Acc: {:.4f} Auc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_auc))
                report = classification_report(phase_label, phase_pred, target_names=class_names)
                print(report)

                # Keep the model whenever validation AUC improves.
                if phase == 'val' and epoch_auc > best_auc:
                    best_auc = epoch_auc
                    best_desc = epoch_acc, epoch_auc, report
                    best_records = list(zip(records, phase_pred))
                    best_model_wts = copy.deepcopy(model.state_dict())
                    plt_auc = phase_label, np.concatenate(prob_batches, axis=0)

            print()

        if plt_auc is not None:
            true_labels, scores = plt_auc
            print(true_labels.shape, scores.shape)
            csv_file = pd.DataFrame(scores, columns=['class_1', 'class_2'])
            csv_file['true_label'] = [LABEL_DICT[int(label)] for label in true_labels]
            csv_file.to_csv(f'./prob_result/{name}_fold_{fold}_porb.csv', index=False)
            skplt.metrics.plot_roc_curve(true_labels, scores, curves=['each_class'])
            plt.savefig(f'./roc_img/{name}_fold_{fold}_roc.png', dpi=600)
            # Release the figure so figures do not pile up across folds.
            plt.close()
        else:
            print('No epoch improved the validation AUC; skip the CSV and ROC output.')

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        # The second value is the best validation AUC, so it is labelled as such.
        reports = 'The Desc according to the Best val Auc: \nACC -> {:4f}\nAUC -> {:4f}\n\n{}'.format(
            best_desc[0], best_desc[1], best_desc[2])
        report_file.write(reports)
        print(reports)

        print('List the wrong judgement img ...')
        count = 0
        for (path, actual_label), pred_label in best_records:
            tmp_word = _wrong_prediction_line(path, int(actual_label), int(pred_label))
            if tmp_word is None:
                continue
            print(tmp_word)
            label_file.write(tmp_word + '\n')
            count += 1
        print(f'This fold has {count} wrong records ...')

        # Load the best weights found during training.
        model.load_state_dict(best_model_wts)
        return model


    def plot_img():
        """Show each batch of the module-level training loader as an image grid."""
        for inputs, classes in dataloaders['train']:
            out = torchvision.utils.make_grid(inputs)
            imshow(out, title=[class_names[x] for x in classes])


    # The function below can be adapted to the image file names of your own project.
    def move_file(data, file_path, dir_path, root_path):
        """Copy the files of one split into per-class sub-folders.

        Recreates ``<dir_path>/<file_path>`` from scratch and copies every name
        in ``data`` from ``dir_path`` into ``<dir_path>/<file_path>/<label>/``,
        where ``label`` is the class prefix the file name starts with. Names
        that start with neither prefix are skipped. The working directory is
        left unchanged.

        Args:
            data: Iterable of file names located directly inside ``dir_path``.
            file_path: Name of the split folder to (re)create, e.g. ``'train'``.
            dir_path: Directory holding the source files; a relative path is
                resolved against the current working directory.
            root_path: Kept for backward compatibility; all paths are derived
                from ``dir_path``, so this value is not used.
        """
        labels = ('class_2', 'class_1')
        print(f'start copy the {file_path} file ...')
        source_dir = os.path.abspath(dir_path)
        target_dir = os.path.join(source_dir, file_path)
        if os.path.exists(target_dir):
            print(f'Find exist {file_path} file, the file will be dropped.')
            # Remove exactly the folder whose existence was just checked.
            shutil.rmtree(target_dir)
            print(f'Finish drop the {file_path} file.')
        os.mkdir(target_dir)
        for d in data:
            for label in labels:
                # Compare the whole prefix; a fixed two-character slice can
                # never equal these longer labels.
                if d.startswith(label):
                    class_dir = os.path.join(target_dir, label)
                    os.makedirs(class_dir, exist_ok=True)
                    shutil.copyfile(os.path.join(source_dir, d), os.path.join(class_dir, d))
                    break
        print('finish this work ...')


    def _build_model(name):
        """Return a pretrained backbone whose head outputs ``N_CLASSES`` scores.

        Raises:
            Exception: For ``'inception'``, which is not provided.
            ValueError: For any other unknown model name.
        """
        # NOTE(review): nn.CrossEntropyLoss is documented to take unnormalized
        # scores, yet both heads end in nn.Sigmoid(). Kept unchanged so trained
        # weights and exported scores stay comparable; confirm it is intended.
        if name == 'resnet':
            model = models.resnet50(pretrained=True)
            num_ftrs = model.fc.in_features
            model.fc = nn.Sequential(
                nn.Linear(num_ftrs, N_CLASSES),
                nn.Sigmoid()
            )
            return model
        # An inception model can be added here.
        if name == 'inception':
            raise Exception("not provide inception model ...")
        # 'desnet' is the spelling the command line has always accepted.
        if name == 'desnet':
            model = models.densenet121(pretrained=True)
            num_ftrs = model.classifier.in_features
            model.classifier = nn.Sequential(
                nn.Linear(num_ftrs, N_CLASSES),
                nn.Sigmoid()
            )
            return model
        raise ValueError(f'unknown model name: {name!r}')


    if __name__ == "__main__":
        if len(sys.argv) < 2:
            sys.exit('usage: python <script> {resnet|desnet}')
        if sys.argv[1] not in ('resnet', 'inception', 'desnet'):
            sys.exit(f'unknown model name: {sys.argv[1]!r}')

        for out_dir in ('roc_img', 'prob_result', 'report', 'error_record', 'model'):
            os.makedirs(out_dir, exist_ok=True)

        kf = KFold(n_splits=5, shuffle=True, random_state=1)
        # All relative paths below are resolved against the launch directory.
        origin_path = os.getcwd()
        dd_list = np.array([o for o in os.listdir(DATA_DIR) if os.path.isfile(os.path.join(DATA_DIR, o))])

        data_transforms = {
            'train': transforms.Compose([
                # Random crop resized to 224 x 224.
                transforms.RandomResizedCrop(224),
                # Flip the PIL image horizontally with probability 0.5.
                transforms.RandomHorizontalFlip(),
                # transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),  # HSV and contrast jitter
                # Convert a PIL image in [0, 255] or an (H, W, C) ndarray into
                # a FloatTensor of shape [C, H, W] with values in [0, 1.0].
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
        }
        use_gpu = torch.cuda.is_available()

        # train_model() reads label_file, report_file, dataloaders, dataset_sizes,
        # class_names and use_gpu as module-level names, so they stay global here.
        with open(f'./error_record/{sys.argv[1]}_img_name_actual_pred.txt', 'w') as label_file:
            for m, n in enumerate(kf.split(dd_list), start=1):
                with open(f'./report/{sys.argv[1]}_fold_{m}_report.txt', 'w') as report_file:
                    print(f'The {m} fold for copy file and training ...')
                    move_file(dd_list[n[0]], 'train', DATA_DIR, origin_path)
                    move_file(dd_list[n[1]], 'val', DATA_DIR, origin_path)

                    image_datasets = {x: datasets.ImageFolder(os.path.join(DATA_DIR, x),
                                                              data_transforms[x])
                                      for x in ['train', 'val']}
                    # Only the training data is shuffled; validation keeps the
                    # dataset order so predictions can be traced to image paths.
                    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=BATCH_SIZE,
                                                                  shuffle=(x == 'train'), num_workers=8,
                                                                  pin_memory=False)
                                   for x in ['train', 'val']}
                    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
                    class_names = image_datasets['train'].classes
                    size = len(class_names)
                    print('label mapping: ')
                    print(image_datasets['train'].class_to_idx)

                    model_ft = _build_model(sys.argv[1])
                    # use_gpu = False
                    if use_gpu:
                        model_ft = model_ft.cuda()

                    criterion = nn.CrossEntropyLoss()
                    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
                    # Multiply the learning rate by 0.1 every 7 epochs.
                    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
                    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, m, sys.argv[1],
                                           num_epochs=25)
                    print('Start save the model ...')
                    torch.save(model_ft.state_dict(), f'./model/fold_{m}_{sys.argv[1]}.pkl')
                    print(f'The mission of the fold {m} finished.')
                    print('# '*50)

修改pytorch官方实例适用于自己的二分类迁移学习项目的更多相关文章

  1. Unity-2017.3官方实例教程Space-Shooter(二)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(一) 章节列表: 一.创建小行星Prefab 二.创建敌机和 ...

  2. Unity-2017.2官方实例教程Roll-a-ball(二)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/97b630a23234 上一节Unity-2017.2官方实例教程Roll ...

  3. 利用sklearn对MNIST手写数据集开始一个简单的二分类判别器项目(在这个过程中学习关于模型性能的评价指标,如accuracy,precision,recall,混淆矩阵)

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  4. Unity-2017.3官方实例教程Space-Shooter(一)

    由于初学Unity,写下此文作为笔记,文中难免会有疏漏,不当之处还望指正. Unity-2017.3官方实例教程Space-Shooter(二) 章节列表: 一.从Asset Store中下载资源并导 ...

  5. Unity-2017.2官方实例教程Roll-a-ball(一)

    声明: 本文系转载,由于Unity版本不同,文中有一些小的改动,原文地址:http://www.jianshu.com/p/6e4b0435e30e Unity-2017.2官方实例教程Roll-a- ...

  6. NLP(二十二)利用ALBERT实现文本二分类

      在文章NLP(二十)利用BERT实现文本二分类中,笔者介绍了如何使用BERT来实现文本二分类功能,以判别是否属于出访类事件为例子.但是呢,利用BERT在做模型预测的时候存在预测时间较长的问题.因此 ...

  7. 对《[Unity官方实例教程 秘密行动] Unity官方教程《秘密行动》(十二) 角色移动》的一些笔记和个人补充,解决角色在地形上移动时穿透问题。

    这里素材全是网上找的. 教程看这里: [Unity官方实例教程 秘密行动] Unity官方教程<秘密行动>(九) 角色初始设定 一.模型设置: 1.首先设置模型的动作无限循环. 不设置的话 ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. 源于《Unity官方实例教程 “Space Shooter”》思路分析及相应扩展

    教程来源于:Unity官方实例教程 Space Shooter(一)-(五)       http://www.jianshu.com/p/8cc3a2109d3b 一.经验总结 教程中步骤清晰,并且 ...

随机推荐

  1. cf888G. Xor-MST(Boruvka最小生成树 Trie树)

    题意 题目链接 给出\(n\)点,每个点有一个点权\(a[i]\),相邻两点之间的边权为\(a[i] \oplus a[j]\),求最小生成树的值 Sol 非常interesting的一道题,我做过两 ...

  2. JavaScript中8个容易犯的错误

    这里dbestech针对JavaScript初学者给出一些技巧和列出一些陷阱. 1. 你是否尝试过对数组元素进行排序? JavaScript默认使用字典序(alphanumeric)来排序.因此,[1 ...

  3. iPython与notebook的基本用法

    1 Ipython 安装 pip install ipython 2 Notebooke 基本用法 启动ipython使用ipython 启动notebook 使用 ipython notebook ...

  4. HTML 5入门知识(四)

    表单的作用 表单不是表格,既不用来显示数据,也不用来布局网页.表单提供一个界面,一个入口,便于用户把数据提交给后台程序进行处理. 表单的数据传递方式method属性 表单的method属性用于指定在数 ...

  5. 深度搜索C语言伪代码

    bool DFS(Node n, int d){ if (d == 4){//路径长度为返回true,表示此次搜索有解 return true; } for (Node nextNode in n){ ...

  6. Zookeeper的集群配置和Java测试程序

    Zookeeper是Apache下的项目之一,倾向于对大型应用的协同维护管理工作.IBM则给出了IBM对ZooKeeper的认知: Zookeeper 分布式服务框架是 Apache Hadoop 的 ...

  7. 小故事学设计模式之Decorate: (二)老婆的新衣服

    老婆有一件蓝色的裙子和一件粉色的裙子, 不管怎么穿,她还是原来的老婆. 但是在软件里就不一定了, 如果把老婆比作一个class的话, 有一种做法是会因为增加了两个新的Property而继承出两个子类: ...

  8. 高效实时的网络会议数据传输库—UDT

    在视频会议系统的研发当中,我们的音.视频数据必须要有相应的可靠性作为保障,因为视频会议系统是一个实时性非常强的系统,如果其数据在网络不太好的情况下,有可能会出现丢包.数据延迟.数据堵塞等现象,出现这些 ...

  9. 转 tcp协议里rst字段讲解

    TCP协议的原理来谈谈rst复位攻击 http://russelltao.iteye.com/blog/1405349 几种TCP连接中出现RST的情况 https://blog.csdn.net/c ...

  10. GreenPlum 与hadoop什么关系?(转)

    没关系. gp 可以处理大量数据, hadoop 可以处理海量. gp 只能处理湖量,或者河量. 无法处理海量. 作者:SallyLeo链接:https://www.zhihu.com/questio ...