迁移学习教程

来自这里

在本教程中,你将学习如何使用迁移学习来训练你的网络。在cs231n notes你可以了解更多关于迁移学习的知识。

  1. 在实践中,很少有人从头开始训练整个卷积网络(使用随机初始化),因为拥有足够大小的数据集相对较少。相反,通常在非常大的数据集(例如ImageNet,它包含120万幅、1000个类别的图像)上对ConvNet进行预训练,然后使用ConvNet作为初始化或固定的特征提取器来执行感兴趣的任务。

两个主要的迁移学习的场景如下:

  • Finetuning the convert:与随机初始化不同,我们使用一个预训练的网络初始化网络,就像在imagenet 1000 dataset上训练的网络一样。其余的训练看起来和往常一样。
  • ConvNet as fixed feature extractor:在这里,我们将冻结所有网络的权重,除了最后的全连接层。最后一个全连接层被替换为一个具有随机权重的新层,并且只训练这一层。
  1. #!/usr/bin/env python3
  2. # License: BSD
  3. # Author: Sasank Chilamkurthy
  4. from __future__ import print_function,division
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. import numpy as np
  9. import torchvision
  10. from torchvision import datasets,models,transforms
  11. import matplotlib.pyplot as plt
  12. import time
  13. import os
  14. import copy
  15. plt.ion() # 交互模式

导入数据

我们使用torchvisiontorch.utils.data包来导入数据。

我们今天要解决的问题是训练一个模型来区分蚂蚁蜜蜂。我们有蚂蚁和蜜蜂的训练图像各120张。每一类有75张验证图片。通常,如果是从零开始训练,这是一个非常小的数据集。因为我们要使用迁移学习,所以我们的例子应该具有很好地代表性。

这个数据集是一个非常小的图像子集。

你可以从这里下载数据并解压到当前目录。

  1. # 训练数据的扩充及标准化
  2. # 只进行标准化验证
  3. data_transforms = {
  4. 'train': transforms.Compose([
  5. transforms.RandomResizedCrop(224),
  6. transforms.RandomHorizontalFlip(),
  7. transforms.ToTensor(),
  8. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  9. ]),
  10. 'val': transforms.Compose([
  11. transforms.Resize(256),
  12. transforms.CenterCrop(224),
  13. transforms.ToTensor(),
  14. transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
  15. ])
  16. }
  17. data_dir = 'data/hymenoptera_data'
  18. image_datasets = {x: datasets.ImageFolder(os.path.join(
  19. data_dir, x), data_transforms[x]) for x in ['train', 'val']}
  20. dataloaders = {x: torch.utils.data.DataLoader(
  21. image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
  22. dataset_size = {x:len(image_datasets[x]) for x in ['train','val']}
  23. class_name = image_datasets['train'].classes
  24. device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

可视化一些图像

为了理解数据扩充,我们可视化一些训练图像。

  1. def imshow(inp, title=None):
  2. inp = inp.numpy().transpose((1, 2, 0))
  3. mean = np.array([0.485, 0.456, 0.406])
  4. std = np.array([0.229, 0.224, 0.225])
  5. inp = std * inp + mean
  6. inp = np.clip(inp, 0, 1)
  7. plt.imshow(inp)
  8. if title is not None:
  9. plt.title(title)
  10. plt.pause(10) # 暂停一会,以便更新绘图
  11. # 获取一批训练数据
  12. inputs, classes = next(iter(dataloaders['train']))
  13. # 从批处理中生成网格
  14. out = torchvision.utils.make_grid(inputs)
  15. imshow(out, title=[class_name[x] for x in classes])

训练模型

现在我们来实现一个通用函数来训练一个模型。在这个函数中,我们将:

  • 调整学习率
  • 保存最优模型

下面例子中,参数schedule是来自torch.optim.lr_scheduler的LR调度对象。

  1. def train_model(model, criterion, optimizer, schduler, num_epochs=25):
  2. since = time.time()
  3. best_model_wts = copy.deepcopy(model.state_dict())
  4. best_acc = 0.0
  5. for epoch in range(num_epochs):
  6. print('Epoch {}/{}'.format(epoch, num_epochs-1))
  7. print('-'*10)
  8. for phase in ['train', 'val']:
  9. if phase == 'train':
  10. schduler.step()
  11. model.train() # 训练模型
  12. else:
  13. model.eval() # 评估模型
  14. running_loss = 0.0
  15. running_corrects = 0
  16. for inputs, labels in dataloaders[phase]:
  17. inputs = inputs.to(device)
  18. labels = labels.to(device)
  19. # 零化参数梯度
  20. optimizer.zero_grad()
  21. # 前向传递
  22. # 如果只是训练的话,追踪历史
  23. with torch.set_grad_enabled(phase == 'train'):
  24. outputs = model(inputs)
  25. _, preds = torch.max(outputs, 1)
  26. loss = criterion(outputs, labels)
  27. # 训练时,反向传播 + 优化
  28. if phase == 'train':
  29. loss.backward()
  30. optimizer.step()
  31. # 统计
  32. running_loss += loss.item() * inputs.size(0)
  33. running_corrects += torch.sum(preds == labels.data)
  34. epoch_loss = running_loss / dataset_size[phase]
  35. epoch_acc = running_corrects.double() / dataset_size[phase]
  36. print('{} Loss: {:.4f} Acc: {:.4f}'.format(
  37. phase, epoch_loss, epoch_acc))
  38. # 很拷贝模型
  39. if phase == 'val' and epoch_acc > best_acc:
  40. best_acc = epoch_acc
  41. best_model_wts = copy.deepcopy(model.state_dict())
  42. print()
  43. time_elapsed = time.time() - since
  44. print('Training complete in {:.0f}m {:.0f}s'.format(
  45. time_elapsed // 60, time_elapsed % 60))
  46. print('Best val Acc: {:4f}'.format(best_acc))
  47. # 导入最优模型权重
  48. model.load_state_dict(best_model_wts)
  49. return model

可视化模型预测

展示部分预测图像的通用函数:

  1. def visualize_model(model, num_images=6):
  2. was_training = model.training
  3. model.eval()
  4. images_so_far = 0
  5. fig = plt.figure()
  6. with torch.no_grad():
  7. for i, (inputs, labels) in enumerate(dataloaders['val']):
  8. inputs = inputs.to(device)
  9. labels = labels.to(device)
  10. outputs = model(inputs)
  11. _, preds = torch.max(outputs, 1)
  12. for j in range(inputs.size()[0]):
  13. images_so_far += 1
  14. ax = plt.subplot(num_images//2, 2, images_so_far)
  15. ax.axis('off')
  16. ax.set_title('predicted: {}'.format(class_name[preds[j]]))
  17. imshow(inputs.cpu().data[j])
  18. if images_so_far == num_images:
  19. model.train(mode=was_training)
  20. return
  21. model.train(mode=was_training)

Finetuning the convnet

加载预处理的模型和重置最后的全连接层:

  1. model_ft = models.resnet18(pretrained=True)
  2. num_ftrs = model_ft.fc.in_features
  3. model_ft.fc = nn.Linear(num_ftrs, 2)
  4. model_ft = model_ft.to(device)
  5. criterion = nn.CrossEntropyLoss()
  6. # 优化所有参数
  7. optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
  8. # 没7次,学习率衰减0.1
  9. exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
  10. optimizer_ft, step_size=7, gamma=0.1)

训练和评估

在CPU上可能会花费15-25分钟,但是在GPU上,少于1分钟。

  1. model_ft = train_model(model_ft, criterion, optimizer_ft,
  2. exp_lr_scheduler, num_epochs=25)
  1. visualize_model(model_ft)

ConvNet作为固定特征提取器

现在,我们冻结除最后一层外的所有网络。我们需要设置requires_grad=False来冻结参数,这样调用backward()时不计算梯度。

你可以从这篇文档中了解更多。

  1. model_conv = models.resnet18(pretrained=True)
  2. for param in model_conv.parameters():
  3. param.requires_grad = False
  4. # 新构造模块的参数默认requires_grad=True
  5. num_ftrs = model_conv.fc.in_features
  6. model_conv.fc = nn.Linear(num_ftrs, 2)
  7. model_conv = model_conv.to(device)
  8. criterion = nn.CrossEntropyLoss()
  9. # 优化所有参数
  10. optimizer_ft = optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9)
  11. # 没7次,学习率衰减0.1
  12. exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
  13. optimizer_ft, step_size=7, gamma=0.1)
  14. model_conv = train_model(model_conv, criterion, optimizer_ft,
  15. exp_lr_scheduler, num_epochs=25)
  1. visualize_model(model_conv)
  2. plt.ioff()
  3. plt.show()

[PyTorch入门]之迁移学习的更多相关文章

  1. PyTorch 计算机视觉的迁移学习教程代码详解 (TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL )

    PyTorch 原文: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html 参考文章: https://www ...

  2. pytorch入门--土堆深度学习快速入门教程

    工具函数 dir函数,让我们直到工具箱,以及工具箱中的分隔区有什么东西 help函数,让我们直到每个工具是如何使用的,工具的使用方法 示例:在pycharm的console环境,输入 import t ...

  3. 修改pytorch官方实例适用于自己的二分类迁移学习项目

    本demo从pytorch官方的迁移学习示例修改而来,增加了以下功能: 根据AUC来迭代最优参数: 五折交叉验证: 输出验证集错误分类图片: 输出分类报告并保存AUC结果图片. import os i ...

  4. Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader

    本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...

  5. PyTorch专栏(五):迁移学习

    专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60分钟入门 PyTorch入门 PyTorch自动微分 PyTorch神经网络 P ...

  6. PyTorch基础——迁移学习

    一.介绍 内容 使机器能够"举一反三"的能力 知识点 使用 PyTorch 的数据集套件从本地加载数据的方法 迁移训练好的大型神经网络模型到自己模型中的方法 迁移学习与普通深度学习 ...

  7. 使用PyTorch进行迁移学习

    概述 迁移学习可以改变你建立机器学习和深度学习模型的方式 了解如何使用PyTorch进行迁移学习,以及如何将其与使用预训练的模型联系起来 我们将使用真实世界的数据集,并比较使用卷积神经网络(CNNs) ...

  8. Pytorch迁移学习实现驾驶场景分类

    Pytorch迁移学习实现驾驶场景分类 源代码:https://github.com/Dalaska/scene_clf 1.安装 pytorch 直接用官网上的方法能装上但下载很慢.通过换源安装发现 ...

  9. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下面修改 ...

随机推荐

  1. swift bannerview 广告轮播图

    class BannerView: UIView,UIScrollViewDelegate{ //图⽚⽔平放置到scrollView上 private var scrollView:UIScrollV ...

  2. 网页滚动条CSS样式

    滚动条样式主要涉及到如下CSS属性: overflow属性: 检索或设置当对象的内容超过其指定高度及宽度时如何显示内容 overflow: auto; 在需要时内容会自动添加滚动条overflow: ...

  3. 7.windows-oracle实战第七课 --约束、索引

    数据的完整性 数据的完整性用于确保数据库数据遵从一定的商业和逻辑规则.数据的完整性使用约束.触发器.函数的方法来实现.在这三个方法中,约束易于维护,具备最好的性能,所以作为首选.  约束:not nu ...

  4. 放贷额度相关的ROI计算

    违约模型得到概率估计, 将概率值划分5档, 每一档确定一个授信系数 新的授信 = 每月收入* 授信系数 - 老的授信 计算新增授信额度 计算余额损失

  5. 关于tomcat启动错误:At least one JAR was scanned for TLDs yet contained no TLDs

    一.问题原因: 1.出现这个问题的原因就是Tomcat启动时会扫描大量jar包,如果含有不符合TLD规范的就会出现这个问题 2.以后基本上不会使用JSP作为视图层,所以我们可能根本不需要TLD这个东西 ...

  6. 18)PHP,可变函数,匿名函数 变量的作用域

    (1)可变函数: 可变函数,就是函数名“可变”——其实跟可变变量一样的道理. $str1 = “f1”;   //只是一个字符串,内容为”f1” $v1 = $str1(3, 4);   //形式上看 ...

  7. 图遍历算法的应用(包括输出长度为l的路径最短最长路径)

    判断从顶点u到v是否有路径 void ExistPath(AdjGraph* G, int u, int v, bool& has) { int w; ArcNode* p; visit[u] ...

  8. Flask pythn Web 框架总结

    Flask pythn Web 框架总结 一, Flask 介绍 Flask 是一个基于Python 实现的web 开发的'小型轻框架' 1. flask介绍 Flask是一个基于Python实现的w ...

  9. AOP实现防止接口重复提交

    项目中对于状态变更接口存在重复提交的问题. package com.yxx.survey.foundation.aop; import com.alibaba.fastjson.JSON; impor ...

  10. macbook安装LightGBM

    一开始直接用pip install lightgbm 报错: OSError: dlopen(/opt/anaconda3/lib/python3.7/site-packages/lightgbm/l ...