Caffe felt too clunky to work with, so I recently switched to PyTorch, which is remarkably convenient and very quick to pick up. Below are two small example programs from the official PyTorch tutorials, worked through here to get a feel for how it is used:

Example 1: build a network with the nn module and use it for a small image-classification task.

链接:http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

# Data preprocessing: convert to Tensor, normalize, set up the training/test sets and the number of loader worker processes
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # first tuple is the per-channel mean, second the std
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)  # num_workers=2: load data with two worker subprocesses
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

import matplotlib.pyplot as plt
import numpy as np
import pylab

def imshow(img):
    img = img / 2 + 0.5  # undo the normalization before display
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    pylab.show()

dataiter = iter(trainloader)
images, labels = dataiter.next()
for i in range(4):
    p = plt.subplot()
    p.set_title("label: %5s" % classes[labels[i]])
    imshow(images[i])

# Build the network
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # view() flattens the 16*5*5 feature maps from conv2 into 400-dim vectors for the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
net.cuda()

# define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# train the network
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.data[0]
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0
print('Finished Training')

dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:', ' '.join(classes[labels[j]] for j in range(4)))

outputs = net(Variable(images.cuda()))

_, predicted = torch.max(outputs.data, 1)
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))

# overall accuracy on the test set
correct = 0
total = 0
for data in testloader:
    images, labels = data
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    correct += (predicted == labels.cuda()).sum()
    total += labels.size(0)
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

# per-class accuracy
class_correct = torch.ones(10).cuda()
class_total = torch.ones(10).cuda()
for data in testloader:
    images, labels = data
    outputs = net(Variable(images.cuda()))
    _, predicted = torch.max(outputs.data, 1)
    c = (predicted == labels.cuda()).squeeze()
    # print(predicted.data[0])
    for i in range(4):
        label = labels[i]
        class_correct[label] += c[i]
        class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
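The listing above uses the pre-0.4 PyTorch API (Variable wrappers, loss.data[0], dataiter.next()). As a hedged aside, on PyTorch 0.4 and later Tensors carry autograd state themselves, so the same training loop can be written without Variable, with loss.item() reading out the scalar loss and a torch.device object handling CPU/GPU placement. A minimal sketch, reusing the net, trainloader, criterion and optimizer defined above:

# Sketch for PyTorch >= 0.4 (assumption: a newer PyTorch than the one used in this post).
# Assumes net, trainloader, criterion and optimizer from the listing above are already defined.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device)

for epoch in range(2):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()  # .item() replaces loss.data[0]
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

Likewise, on Python 3 the built-in next(dataiter) replaces the dataiter.next() calls.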

Example 2: fine-tune a pretrained resnet18 model to perform binary classification of ants and bees.

链接:http://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html

# -*- coding:utf-8 -*-
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import pylab

# data preprocessing
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomSizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
dataloders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(class_names)
use_gpu = torch.cuda.is_available()

# show several images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    pylab.show()
    plt.pause(0.001)

inputs, classes = next(iter(dataloders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])

# train the model
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_model_wts = model.state_dict()  # returns a dictionary containing the whole state of the module
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                scheduler.step()  # step the LR scheduler (step_size and gamma are set below)
                model.train(True)  # set model to training mode
            else:
                model.train(False)  # set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0

            # iterate over the data
            for data in dataloders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs = Variable(inputs)
                    labels = Variable(labels)
                optimizer.zero_grad()
                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                # backward + optimize only in the training phase
                if phase == 'train':
                    loss.backward()  # backpropagate gradients
                    optimizer.step()  # parameter update step
                running_loss += loss.data[0]
                running_corrects += torch.sum(preds.data == labels.data)

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # keep the weights with the best validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model

# visualizing the model predictions
def visualize_model(model, num_images=6):
    images_so_far = 0
    fig = plt.figure()

    for i, data in enumerate(dataloders['val']):
        inputs, labels = data
        if use_gpu:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images // 2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[preds[j]]))
            imshow(inputs.cpu().data[j])

            if images_so_far == num_images:
                return

# finetuning the convnet
from torchvision.models.resnet import model_urls
model_urls['resnet18'] = model_urls['resnet18'].replace('https://', 'http://')
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)  # replace the final fully connected layer with a 2-class output
if use_gpu:
    model_ft = model_ft.cuda()
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)  # decay LR by a factor of 0.1 every 7 epochs
# start finetuning
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)
torch.save(model_ft.state_dict(), '/home/zf/resnet18.pth')
visualize_model(model_ft)
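Since the fine-tuned weights are written out with torch.save(model_ft.state_dict(), ...), they can later be restored into a freshly constructed resnet18 for inference. The following is only a sketch under the same assumptions as the listing above (the 2-class head and the saved path /home/zf/resnet18.pth), not part of the original tutorial:

# Sketch: rebuild the fine-tuned model and load the saved weights for inference.
# Assumes the checkpoint written by the listing above exists at this path.
model_infer = models.resnet18(pretrained=False)
model_infer.fc = nn.Linear(model_infer.fc.in_features, 2)  # same 2-class head as during fine-tuning
state = torch.load('/home/zf/resnet18.pth', map_location=lambda storage, loc: storage)  # load onto CPU first
model_infer.load_state_dict(state)
model_infer.train(False)  # evaluation mode, same idiom as in train_model
if use_gpu:
    model_infer = model_infer.cuda()

# run one validation batch through the restored model
inputs, labels = next(iter(dataloders['val']))
inputs = Variable(inputs.cuda()) if use_gpu else Variable(inputs)
outputs = model_infer(inputs)
_, preds = torch.max(outputs.data, 1)
print('predicted:', [class_names[int(p)] for p in preds])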

There are, of course, two ways to fine-tune in this example:

(1) Replace only the last fully connected layer, changing the number of output classes to 2, and fine-tune the whole network starting from the pretrained weights.

(2) Freeze the parameters of the convolutional layers in front of the fully connected layer, i.e. do not backpropagate through them, and only backpropagate through the last layer. To implement this, simply set requires_grad to False on those earlier layers.

The code is shown below:

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

if use_gpu:
    model_conv = model_conv.cuda()

criterion = nn.CrossEntropyLoss()

# Observe that only the parameters of the final layer are being optimized,
# as opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=25)
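To confirm that variant (2) really does update only the new fc layer, it helps to check requires_grad on the model's parameters before training. A small sketch, assuming the model_conv built in the block above:

# Sketch: list which parameters remain trainable after freezing the backbone.
trainable = [name for name, param in model_conv.named_parameters() if param.requires_grad]
print('trainable parameters:', trainable)  # expected: ['fc.weight', 'fc.bias']
num_trainable = sum(param.numel() for param in model_conv.parameters() if param.requires_grad)
print('number of trainable weights:', num_trainable)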

From building the network, to training it, to testing it, everything is written in pure Python style, which makes the whole workflow remarkably convenient.
